Triplet Loss,即三元組損失,用于訓(xùn)練差異性較小的數(shù)據(jù)集,數(shù)據(jù)集中標(biāo)簽較多,標(biāo)簽的樣本較少。輸入數(shù)據(jù)包括錨(Anchor)示例??、正(Positive)示例和負(fù)(Negative)示例,通過(guò)優(yōu)化模型,使得錨示例與正示例的距離小于錨示例與負(fù)示例的距離,實(shí)現(xiàn)樣本的相似性計(jì)算。其中錨示例是樣本集中隨機(jī)選取的一個(gè)樣本,正示例與錨示例屬于同一類(lèi)的樣本,而負(fù)示例與錨示例屬于不同類(lèi)的樣本。
歡迎Follow我的GitHub:https://github.com/SpikeKing

在訓(xùn)練Triplet Loss模型時(shí),只需要輸入樣本,不需要輸入標(biāo)簽,這樣避免標(biāo)簽過(guò)多、同標(biāo)簽樣本過(guò)少的問(wèn)題,模型只關(guān)心樣本編碼,不關(guān)心樣本類(lèi)別。Triplet Loss在相似性計(jì)算和檢索中的效果較好,可以學(xué)習(xí)到樣本與變換樣本之間的關(guān)聯(lián),檢索出與當(dāng)前樣本最相似的其他樣本。
Triplet Loss通常應(yīng)用于個(gè)體級(jí)別的細(xì)粒度識(shí)別,比如分類(lèi)貓與狗等是大類(lèi)別的識(shí)別,但是有些需求要精確至個(gè)體級(jí)別,比如識(shí)別不同種類(lèi)不同配色的貓??等,所以Triplet Loss最主要的應(yīng)用也是在細(xì)粒度檢索領(lǐng)域中。
Triplet Loss的對(duì)比:
- 如果把不同個(gè)體作為類(lèi)別進(jìn)行分類(lèi)訓(xùn)練,Softmax維度可能遠(yuǎn)大于Feature維度,精度無(wú)法保證。
- Triplet Loss一般比分類(lèi)能學(xué)習(xí)到更好的特征,在度量樣本距離時(shí),效果較好;
- Triplet Loss支持調(diào)整閾值Margin,控制正負(fù)樣本的距離,當(dāng)特征歸一化之后,通過(guò)調(diào)節(jié)閾值提升置信度。
Triplet Loss的公式:

其他請(qǐng)參考Triplet Loss算法的論文。
本文使用MXNet/Gluon深度學(xué)習(xí)框架,數(shù)據(jù)集選用MNIST,實(shí)現(xiàn)Triplet Loss算法。
本文的源碼:https://github.com/SpikeKing/triplet-loss-gluon
數(shù)據(jù)集
安裝MXNet庫(kù):
pip install mxnet
推薦豆瓣源下載,速度較快,-i https://pypi.douban.com/simple
MNIST就是著名的手寫(xiě)數(shù)字識(shí)別庫(kù),其中包含0至9等10個(gè)數(shù)字的手寫(xiě)體,圖片大小為28*28的灰度圖,目標(biāo)是根據(jù)圖片識(shí)別正確的數(shù)字。
使用MNIST類(lèi)加載數(shù)據(jù)集,獲取訓(xùn)練集mnist_train和測(cè)試集mnist_test的數(shù)據(jù)和標(biāo)簽。
mnist_train = MNIST(train=True) # 加載訓(xùn)練
tr_data = mnist_train._data.reshape((-1, 28 * 28)) # 數(shù)據(jù)
tr_label = mnist_train._label # 標(biāo)簽
mnist_test = MNIST(train=False) # 加載測(cè)試
te_data = mnist_test._data.reshape((-1, 28 * 28)) # 數(shù)據(jù)
te_label = mnist_test._label # 標(biāo)簽
Triplet Loss訓(xùn)練的一個(gè)關(guān)鍵步驟就是準(zhǔn)備訓(xùn)練數(shù)據(jù)。本例繼承Dataset類(lèi)創(chuàng)建Triplet的數(shù)據(jù)集類(lèi)TripletDataset:
- 在構(gòu)造器中:
- 傳入原始數(shù)據(jù)rd、原始標(biāo)簽rl;
-
_data和_label是標(biāo)準(zhǔn)的數(shù)據(jù)和標(biāo)簽變量; -
_transform是標(biāo)準(zhǔn)的轉(zhuǎn)換變量; - 調(diào)用
_get_data(),完成_data和_label的賦值;
-
__getitem__是數(shù)據(jù)處理接口,根據(jù)索引idx返回?cái)?shù)據(jù),支持調(diào)用_transform執(zhí)行數(shù)據(jù)轉(zhuǎn)換; -
__len__是數(shù)據(jù)的總數(shù); -
_get_data()是數(shù)據(jù)賦值的核心方法:- 分離索引,獲取標(biāo)簽相同數(shù)據(jù)的索引值Index列表
digit_indices; - 創(chuàng)建三元組,即錨示例、正示例和負(fù)示例的索引組合矩陣;
- 數(shù)據(jù)是三元組,標(biāo)簽是ones矩陣,因?yàn)闃?biāo)簽在Triplet Loss中沒(méi)有實(shí)際意義;
- 分離索引,獲取標(biāo)簽相同數(shù)據(jù)的索引值Index列表
具體實(shí)現(xiàn):
class TripletDataset(dataset.Dataset):
def __init__(self, rd, rl, transform=None):
self.__rd = rd # 原始數(shù)據(jù)
self.__rl = rl # 原始標(biāo)簽
self._data = None
self._label = None
self._transform = transform
self._get_data()
def __getitem__(self, idx):
if self._transform is not None:
return self._transform(self._data[idx], self._label[idx])
return self._data[idx], self._label[idx]
def __len__(self):
return len(self._label)
def _get_data(self):
label_list = np.unique(self.__rl)
digit_indices = [np.where(self.__rl == i)[0] for i in label_list]
tl_pairs = create_pairs(self.__rd, digit_indices, len(label_list))
self._data = tl_pairs
self._label = mx.nd.ones(tl_pairs.shape[0])
create_pairs()是創(chuàng)建三元組的核心邏輯:
- 確定不同標(biāo)簽的選擇樣本數(shù),選擇最少的標(biāo)簽樣本數(shù);
- 將標(biāo)簽d的索引值隨機(jī)洗牌(Shuffle),選擇樣本i和i+1作為錨和正示例;
- 隨機(jī)選擇(Randrange)其他標(biāo)簽dn中的樣本i作為負(fù)示例;
- 循環(huán)全部標(biāo)簽和全部樣本,生成含有錨、正、負(fù)示例的隨機(jī)組合。
這樣所創(chuàng)建的組合矩陣,保證樣本的分布均勻,既避免組合過(guò)大(對(duì)比于全排列),又引入足夠的隨機(jī)性(雙重隨機(jī))。注意:由于滑動(dòng)窗口為2,即i和i+1,則19個(gè)樣本生成18個(gè)樣本組。
具體實(shí)現(xiàn),如下:
@staticmethod
def create_pairs(x, digit_indices, num_classes):
x = x.asnumpy() # 轉(zhuǎn)換數(shù)據(jù)格式
pairs = []
n = min([len(digit_indices[d]) for d in range(num_classes)]) - 1 # 最小類(lèi)別數(shù)
for d in range(num_classes):
for i in range(n):
np.random.shuffle(digit_indices[d])
z1, z2 = digit_indices[d][i], digit_indices[d][i + 1]
inc = random.randrange(1, num_classes)
dn = (d + inc) % num_classes
z3 = digit_indices[dn][i]
pairs += [[x[z1], x[z2], x[z3]]]
return np.asarray(pairs))
使用DataLoader將TripletDataset封裝為迭代器train_data和test_data,支持按批次batch輸出樣本。train_data用于訓(xùn)練網(wǎng)絡(luò),test_data用于驗(yàn)證網(wǎng)絡(luò)。
def transform(data_, label_):
return data_.astype(np.float32) / 255., label_.astype(np.float32)
train_data = DataLoader(
TripletDataset(rd=tr_data, rl=tr_label, transform=transform),
batch_size, shuffle=True)
test_data = DataLoader(
TripletDataset(rd=te_data, rl=te_label, transform=transform),
batch_size, shuffle=True)
網(wǎng)絡(luò)和訓(xùn)練
Triplet Loss的基礎(chǔ)網(wǎng)絡(luò),選用非常簡(jiǎn)單的多層感知機(jī),主要為了驗(yàn)證Triplet Loss的效果。
base_net = Sequential()
with base_net.name_scope():
base_net.add(Dense(256, activation='relu'))
base_net.add(Dense(128, activation='relu'))
base_net.collect_params().initialize(mx.init.Uniform(scale=0.1), ctx=ctx)
初始化參數(shù),使用uniform均勻分布,范圍是[-0.1, 0.1],效果類(lèi)似如下:

Gluon中自帶TripletLoss損失函數(shù),非常贊??,產(chǎn)學(xué)結(jié)合的非常好!初始化損失函數(shù)triplet_loss和訓(xùn)練器trainer_triplet。
triplet_loss = gluon.loss.TripletLoss() # TripletLoss損失函數(shù)
trainer_triplet = gluon.Trainer(base_net.collect_params(), 'sgd', {'learning_rate': 0.05})
Triplet Loss的訓(xùn)練過(guò)程:
- 循環(huán)執(zhí)行epoch,共10輪;
-
train_data迭代輸出每個(gè)批次的訓(xùn)練數(shù)據(jù)data; - 指定訓(xùn)練的執(zhí)行環(huán)境
as_in_context(),MXNet的數(shù)據(jù)環(huán)境就是訓(xùn)練環(huán)境; - 數(shù)據(jù)來(lái)源于TripletDataset,可以直接分為三個(gè)示例;
- 三個(gè)示例共享模型
base_net,計(jì)算triplet_loss的損失函數(shù); - 調(diào)用loss.backward(),反向傳播求導(dǎo);
- 設(shè)置訓(xùn)練器
trainer_triplet的step是batch_size; - 計(jì)算損失函數(shù)的均值
curr_loss; - 使用測(cè)試數(shù)據(jù)
test_data評(píng)估網(wǎng)絡(luò)base_net;
具體實(shí)現(xiàn):
for epoch in range(10):
curr_loss = 0.0
for i, (data, _) in enumerate(train_data):
data = data.as_in_context(ctx)
anc_ins, pos_ins, neg_ins = data[:, 0], data[:, 1], data[:, 2]
with autograd.record():
inter1 = base_net(anc_ins)
inter2 = base_net(pos_ins)
inter3 = base_net(neg_ins)
loss = triplet_loss(inter1, inter2, inter3) # Triplet Loss
loss.backward()
trainer_triplet.step(batch_size)
curr_loss = mx.nd.mean(loss).asscalar()
# print('Epoch: %s, Batch: %s, Triplet Loss: %s' % (epoch, i, curr_loss))
print('Epoch: %s, Triplet Loss: %s' % (epoch, curr_loss))
evaluate_net(base_net, test_data, ctx=ctx) # 評(píng)估網(wǎng)絡(luò)
評(píng)估網(wǎng)絡(luò)也是一個(gè)重要的過(guò)程,驗(yàn)證網(wǎng)絡(luò)的泛化能力:
- 設(shè)置
triplet_loss損失函數(shù),margin設(shè)置為0; -
test_data迭代輸出每個(gè)批次的驗(yàn)證數(shù)據(jù)data; - 指定驗(yàn)證數(shù)據(jù)的環(huán)境,需要與訓(xùn)練一致,因?yàn)槭窃谟?xùn)練的過(guò)程中驗(yàn)證;
- 通過(guò)模型,預(yù)測(cè)三元數(shù)據(jù),計(jì)算損失函數(shù);
- 由于TripletLoss的margin是0,因此只有0才是預(yù)測(cè)正確,其余全部預(yù)測(cè)錯(cuò)誤;
- 統(tǒng)計(jì)整體的樣本總數(shù)和正確樣本數(shù),計(jì)算全部測(cè)試數(shù)據(jù)的正確率;
具體實(shí)現(xiàn):
def evaluate_net(model, test_data, ctx):
triplet_loss = gluon.loss.TripletLoss(margin=0)
sum_correct = 0
sum_all = 0
rate = 0.0
for i, (data, _) in enumerate(test_data):
data = data.as_in_context(ctx)
anc_ins, pos_ins, neg_ins = data[:, 0], data[:, 1], data[:, 2]
inter1 = model(anc_ins) # 訓(xùn)練的時(shí)候組合
inter2 = model(pos_ins)
inter3 = model(neg_ins)
loss = triplet_loss(inter1, inter2, inter3)
loss = loss.asnumpy()
n_all = loss.shape[0]
n_correct = np.sum(np.where(loss == 0, 1, 0))
sum_correct += n_correct
sum_all += n_all
rate = safe_div(sum_correct, sum_all)
print('準(zhǔn)確率: %.4f (%s / %s)' % (rate, sum_correct, sum_all))
return rate
在實(shí)驗(yàn)輸出的效果中,Loss值逐漸減少,驗(yàn)證準(zhǔn)確率逐步上升,模型收斂效果較好。具體如下:
Epoch: 0, Triplet Loss: 0.26367417
準(zhǔn)確率: 0.9052 (8065 / 8910)
Epoch: 1, Triplet Loss: 0.18126598
準(zhǔn)確率: 0.9297 (8284 / 8910)
Epoch: 2, Triplet Loss: 0.15365836
準(zhǔn)確率: 0.9391 (8367 / 8910)
Epoch: 3, Triplet Loss: 0.13773362
準(zhǔn)確率: 0.9448 (8418 / 8910)
Epoch: 4, Triplet Loss: 0.12188278
準(zhǔn)確率: 0.9495 (8460 / 8910)
Epoch: 5, Triplet Loss: 0.115614936
準(zhǔn)確率: 0.9520 (8482 / 8910)
Epoch: 6, Triplet Loss: 0.10390957
準(zhǔn)確率: 0.9544 (8504 / 8910)
Epoch: 7, Triplet Loss: 0.087059245
準(zhǔn)確率: 0.9569 (8526 / 8910)
Epoch: 8, Triplet Loss: 0.10168926
準(zhǔn)確率: 0.9588 (8543 / 8910)
Epoch: 9, Triplet Loss: 0.06260935
準(zhǔn)確率: 0.9606 (8559 / 8910)
可視化
Triplet Loss的核心功能就是將數(shù)據(jù)編碼為具有可區(qū)分性的特征。使用PCA降維,將樣本特征轉(zhuǎn)換為可視化的二維分布,通過(guò)觀察可知,樣本特征具有一定的區(qū)分性。效果如下:

而原始的數(shù)據(jù)分布,效果較差:

在訓(xùn)練結(jié)束時(shí),執(zhí)行可視化數(shù)據(jù):
- 原始的數(shù)據(jù)和標(biāo)簽
- Triplet Loss網(wǎng)絡(luò)輸出的數(shù)據(jù)和標(biāo)簽
具體實(shí)現(xiàn):
te_data, te_label = transform(te_data, te_label)
tb_projector(te_data, te_label, os.path.join(ROOT_DIR, 'logs', 'origin'))
te_res = base_net(te_data)
tb_projector(te_res.asnumpy(), te_label, os.path.join(ROOT_DIR, 'logs', 'triplet'))
可視化工具以tensorboard為基礎(chǔ),通過(guò)嵌入向量的可視化接口實(shí)現(xiàn)數(shù)據(jù)分布的可視化。在tb_projector()方法中,輸入數(shù)據(jù)、標(biāo)簽和路徑,即可生成可視化的數(shù)據(jù)格式。
具體實(shí)現(xiàn):
def tb_projector(X_test, y_test, log_dir):
metadata = os.path.join(log_dir, 'metadata.tsv')
images = tf.Variable(X_test)
with open(metadata, 'w') as metadata_file: # 把標(biāo)簽寫(xiě)入metadata
for row in y_test:
metadata_file.write('%d\n' % row)
with tf.Session() as sess:
saver = tf.train.Saver([images]) # 把數(shù)據(jù)存儲(chǔ)為矩陣
sess.run(images.initializer) # 圖像初始化
saver.save(sess, os.path.join(log_dir, 'images.ckpt')) # 圖像存儲(chǔ)
config = projector.ProjectorConfig() # 配置
embedding = config.embeddings.add() # 嵌入向量添加
embedding.tensor_name = images.name # Tensor名稱(chēng)
embedding.metadata_path = metadata # Metadata的路徑
projector.visualize_embeddings(tf.summary.FileWriter(log_dir), config) # 可視化嵌入向量
TensorBoard在可視化方面的功能較多,一些其他框架也是使用TensorBoard進(jìn)行數(shù)據(jù)可視化,如tensorboard-pytorch等,可視化為深度學(xué)習(xí)理論提供驗(yàn)證。
TensorBoard需要額外安裝TensorFlow:
pip install tensorflow
Triplet Loss在數(shù)據(jù)編碼領(lǐng)域中,有著重要的作用,算法也非常巧妙,適合相似性推薦等需求,是重要的工業(yè)界需求之一,如推薦菜譜、推薦音樂(lè)、推薦視頻等。Triplet Loss模型可以學(xué)習(xí)到數(shù)據(jù)集中不同樣本的相似性。除了傳統(tǒng)的Triplet Loss損失計(jì)算方法,還有一些有趣的優(yōu)化,如Lossless Triplet Loss等。
OK, that's all! Enjoy it!