def binary_focal_loss(gamma=2.0, alpha=0.25):
    """
    Build a binary focal loss function (Lin et al., "Focal Loss for Dense
    Object Detection") for binary classification.

    Formula:
        loss = -alpha_t * ((1 - p_t) ^ gamma) * log(p_t)
        p_t = y_pred,        if y_true == 1
        p_t = 1 - y_pred,    otherwise
        alpha_t = alpha,     if y_true == 1
        alpha_t = 1 - alpha, otherwise
        cross_entropy = -log(p_t)

    Parameters:
        gamma -- focusing parameter of the modulating factor (1 - p_t)^gamma
        alpha -- weighting factor for the positive class (the same as the
                 weighting factor in balanced cross entropy)

    Default values:
        gamma -- 2.0 as mentioned in the paper
        alpha -- 0.25 as mentioned in the paper

    Returns:
        A function ``focal_loss(y_true, y_pred)`` that returns the element-wise
        focal loss summed over axis 1.
    """
    def focal_loss(y_true, y_pred):
        # Compare and combine labels in the prediction dtype; integer labels
        # would otherwise make tf.where / the arithmetic below fail on dtype.
        y_true = tf.cast(y_true, y_pred.dtype)
        # Clip predictions away from exactly 0 and 1 so log(p_t) stays finite
        # and backpropagation does not produce NaN/inf.
        epsilon = K.epsilon()
        y_pred = K.clip(y_pred, epsilon, 1.0 - epsilon)
        is_positive = K.equal(y_true, 1)
        # p_t: predicted probability assigned to the true label.
        p_t = tf.where(is_positive, y_pred, 1.0 - y_pred)
        # alpha_t: alpha for positive labels, 1 - alpha for negative labels.
        alpha_factor = K.ones_like(y_true) * alpha
        alpha_t = tf.where(is_positive, alpha_factor, 1.0 - alpha_factor)
        cross_entropy = -K.log(p_t)
        # Modulating factor (1 - p_t)^gamma down-weights well-classified examples.
        weight = alpha_t * K.pow(1.0 - p_t, gamma)
        loss = weight * cross_entropy
        # Sum the per-element losses along axis 1.
        return K.sum(loss, axis=1)

    return focal_loss
def catergorical_focal_loss(gamma=2.0, alpha=0.25):
    """
    Build a categorical focal loss function (Lin et al., "Focal Loss for
    Dense Object Detection") for multiclass classification.

    Formula:
        loss = -alpha * ((1 - p) ^ gamma) * y_true * log(p)
        where p is the clipped predicted probability for each class.

    Parameters:
        gamma -- focusing parameter of the modulating factor (1 - p)^gamma
        alpha -- weighting factor (the same as the weighting factor in
                 balanced cross entropy)

    Default values:
        gamma -- 2.0 as mentioned in the paper
        alpha -- 0.25 as mentioned in the paper

    Returns:
        A function ``focal_loss(y_true, y_pred)`` that returns the per-class
        focal loss summed over axis 1.
    """
    def focal_loss(y_true, y_pred):
        # Combine labels in the prediction dtype; integer labels would
        # otherwise fail when multiplied with float predictions.
        y_true = tf.cast(y_true, y_pred.dtype)
        # Clip predictions away from exactly 0 and 1 so log(y_pred) stays
        # finite and backpropagation does not produce NaN/inf.
        epsilon = K.epsilon()
        y_pred = K.clip(y_pred, epsilon, 1.0 - epsilon)
        # Cross entropy already carries the y_true factor, so the weight below
        # must not multiply by y_true again (that would square soft labels).
        cross_entropy = -y_true * K.log(y_pred)
        # Weighting factor times modulating factor (1 - p)^gamma.
        weight = alpha * K.pow(1.0 - y_pred, gamma)
        loss = weight * cross_entropy
        # Sum the per-class losses along axis 1.
        return K.sum(loss, axis=1)

    return focal_loss