引入
triplet loss在图像和NLP领域都有使用, 主要是用来做一些匹配问题. 由于匹配的目标不是一个非黑即白的问题, 更贴近一个Ranking模型, 因此不适合二分类对应的损失函数, 通常是答案之间进行比较出最好的选择.
triplet loss常用于文本匹配(相似, FAQ等), 图像精细识别(人脸识别)等领域.
triplet loss是一个三元损失. 之所以称为三元, 是因为一次训练中包含了三个样本:
Anchor, 锚样本, 即另外两个样本都与这个样本做交互, 判断谁与这个样本关系更近
Positive, 正样本, 事先知道的与Anchor关系更近的样本, 如FAQ中的正确答案
Negative, 负样本, 如FAQ中的错误样本
只需要要求正样本与锚样本的分值大于负样本的即可, 这种对差异性的度量, 也有助于细节的区分.
原理
以FAQ任务为例, 先分别将答案和问题都encode成为一个同样长度的向量,然后比较它们的cos值, 值越大越匹配. 而且, 我们只要求正确答案的cos值比错误的大就好, 至于大多少不重要, 因此triplet loss可以记为:
loss=max(0,m+cos(q,Awrong)−cos(q,Aright))
其中的m是一个大于0的正数.
当然这里衡量答案和问题(锚样本和正样本或负样本)关系的函数是cos, 在其他情况下也可以换成其他的函数, 例如欧式距离等等.
最小化loss, 根据上式, triplet loss的思想就是: 只希望正确比错误答案的差距大一点(并不是越大越好), 超过m就不再需要管(此时没有损失, 不会再有反向梯度更新参数), 集中精力关心那些还没有拉开的样本吧. 这也体现了这个损失函数专注样本之间差异性的度量.
在构造训练样本时, 肯定有问题和正确答案, 再从错误答案中随机抽取即可. 在预测阶段, 只需要传入问题和我们要确定的答案即可, 可以使用阈值的方法, 判断是否是正确答案.
特点
优点
使用分类损失训练模型, 对于类别较多的任务(人脸识别), 在输出层引入的参数量是非常大, 甚至难以承受的
如果目的是为了得到feature embedding, 使用triplet loss训练得到的结果往往比分类任务训练得到的效果更好
缺点
容易过拟合, 因为可能还有很多情况未考虑到, 在样本量少的情况下常见
代码
from keras.layers import Input,Embedding,LSTM,Dense,Lambda
from keras.layers.merge import dot
from keras.models import Model
from keras import backend as K
word_size = 128
nb_features = 10000
nb_classes = 10
encode_size = 64
margin = 0.1
embedding = Embedding(nb_features,word_size)
lstm_encoder = LSTM(encode_size)
def encode(input):
return lstm_encoder(embedding(input))
q_input = Input(shape=(None,))
a_right = Input(shape=(None,))
a_wrong = Input(shape=(None,))
q_encoded = encode(q_input)
a_right_encoded = encode(a_right)
a_wrong_encoded = encode(a_wrong)
q_encoded = Dense(encode_size)(q_encoded) #一般的做法是,直接讲问题和答案用同样的方法encode成向量后直接匹配,但我认为这是不合理的,我认为至少经过某个变换。
right_cos = dot([q_encoded,a_right_encoded], -1, normalize=True)
wrong_cos = dot([q_encoded,a_wrong_encoded], -1, normalize=True)
loss = Lambda(lambda x: K.relu(margin+x[0]-x[1]))([wrong_cos,right_cos])
model_train = Model(inputs=[q_input,a_right,a_wrong], outputs=loss)
model_q_encoder = Model(inputs=q_input, outputs=q_encoded)
model_a_encoder = Model(inputs=a_right, outputs=a_right_encoded)
model_train.compile(optimizer='adam', loss=lambda y_true,y_pred: y_pred)
model_q_encoder.compile(optimizer='adam', loss='mse')
model_a_encoder.compile(optimizer='adam', loss='mse')
model_train.fit([q,a1,a2], y, epochs=10)
#其中q,a1,a2分别是问题、正确答案、错误答案的batch,y是任意形状为(len(q),1)的矩阵
参考资料