自定义损失函数

keras中, 如果不是特别的任务或网络结构, 往往不需要自己显式地定义损失(这点与tensorflow不同). 例如对于多分类任务, 往往在输出层使用softmax, 再使用交叉熵作为损失函数. 对于keras, 在搭建模型时, 只需要得到最后的概率输出logit, 然后在model.compile方法中指定loss参数为categorical_crossentropy, 然后fit训练即可.

keras.losses中封装了keras中所有的损失函数, 涵盖了各种任务常用的损失函数, 基本可以满足常见的任务需求.

自定义损失函数

当keras中的损失函数不能满足训练的需要时, 就需要自定义损失函数, 然后融入到keras框架中使用即可.

keras中的损失函数, 例如交叉熵函数定义为:

def categorical_crossentropy(y_true, y_pred):
    return K.categorical_crossentropy(y_true, y_pred)

即拥有两个参数, 前者y_true为真是标签, 后者y_pred为预测的概率值. 因此, 新定义的函数必须满足有且只有y_true和y_pred这两个参数.

例如计算focal loss, 对应的公式为:

Lfl=(1y^t)γlogy^tL_{fl}=-(1-\hat{y}_t)^{\gamma}\log \hat{y}_t

计算loss值时只使用到了真实值和预测值, 因此可以使用如下的损失函数:

def binary_focal_loss(gamma=2.0, alpha=0.25):
    """
        Implementation of Focal Loss from the paper in multiclass classification
        Formula:
            loss = -alpha_t*((1-p_t)^gamma)*log(p_t)

            p_t = y_pred, if y_true = 1
            p_t = 1-y_pred, otherwise

            alpha_t = alpha, if y_true=1
            alpha_t = 1-alpha, otherwise

            cross_entropy = -log(p_t)
        Parameters:
            alpha -- the same as wighting factor in balanced cross entropy
            gamma -- focusing parameter for modulating factor (1-p)
        Default value:
            gamma -- 2.0 as mentioned in the paper
            alpha -- 0.25 as mentioned in the paper
        """

    def focal_loss(y_true, y_pred):
        # Define espislon so that the backpropagation will not result int NaN
        # for 0 divisor case
        epsilon = K.epsilon()
        # Add the epsilon to prediction value
        # y_pred = y_pred + epsilon
        # Clip the prediction value
        y_pred = K.clip(y_pred, epsilon, 1.0 - epsilon)
        # Calculate p_t
        p_t = tf.where(K.equal(y_true, 1), alpha_factor, 1-alpha_factor)
        # Calculate alpha_t
        alpha_factor = K.once_like(y_true)*alpha
        alpha_t = tf.where(K.equal(y_true, 1), alpha_factor, 1-alpha_factor)
        # Calculate cross entropy
        cross_entropy = -K.log(p_t)
        weight = alpha_t * K.pow((1-p_t), gamma)
        # Calculate focal loss
        loss = weight * cross_entropy
        # Sum the losses in mini_batch
        loss = K.sum(loss, axis=1)

        return loss
    return focal_loss

分别是二分类和多分类版本. 多分类版本还有个BUG, alpha应当改为每个类别对应不同值, 否则不同类别缩放比例相同, α\alpha参数就失去了作用. 不过这也大大增加了调参与训练的难度.

另外再举一个例子, 还可以对经常使用的softmax-交叉熵损失进行优化, 这种优化可以防止过拟合, 详情参见参考资料中的第一部分.

参考资料

最后更新于

这有帮助吗?