[Keras] Three ways to use custom validation metrics in Keras

Keras offers some basic metrics for validating on the test data set, such as accuracy, binary accuracy or categorical accuracy. However, sometimes other metrics are more suitable for evaluating your model. In this post I will show three different approaches to apply your custom metrics in Keras.



Simple metric functions

The simplest approach is described in the official Keras documentation. It is just a function that accepts the true values and the predictions and returns a score, here the Matthews correlation coefficient:

from keras import backend as K

def matthews_correlation(y_true, y_pred):
    # Clip predictions to [0, 1] and round them to hard class assignments
    y_pred_pos = K.round(K.clip(y_pred, 0, 1))
    y_pred_neg = 1 - y_pred_pos

    y_pos = K.round(K.clip(y_true, 0, 1))
    y_neg = 1 - y_pos

    # Confusion matrix entries
    tp = K.sum(y_pos * y_pred_pos)
    tn = K.sum(y_neg * y_pred_neg)

    fp = K.sum(y_neg * y_pred_pos)
    fn = K.sum(y_pos * y_pred_neg)

    numerator = (tp * tn - fp * fn)
    denominator = K.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))

    # Epsilon avoids division by zero for degenerate batches
    return numerator / (denominator + K.epsilon())

To use this metric, we just pass it to the model compilation:

model.compile(optimizer='rmsprop',
              loss='binary_crossentropy',
              metrics=['accuracy', matthews_correlation])

During training, Keras prints the score for both the training and the validation data:

245063/245063 [==============================] - 63s 256us/step - matthews_correlation: 0.0032 - val_matthews_correlation: 0.0039
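Note that a model compiled with a custom metric also needs that function when it is loaded again. A minimal sketch, assuming the model was saved to a hypothetical model.h5:

from keras.models import load_model

# The custom metric has to be registered via custom_objects when reloading,
# otherwise Keras cannot deserialize the compiled model.
model = load_model('model.h5',
                   custom_objects={'matthews_correlation': matthews_correlation})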

Interval metrics on custom validation data

To calculate a metric only every few epochs, and on a validation set of your choice, you can use a custom callback in Keras like this:

from sklearn.metrics import roc_auc_score
from keras.callbacks import Callback

class IntervalEvaluation(Callback):
    def __init__(self, validation_data=(), interval=10):
        super(IntervalEvaluation, self).__init__()

        self.interval = interval
        self.X_val, self.y_val = validation_data

    def on_epoch_end(self, epoch, logs={}):
        # Only evaluate every `interval` epochs
        if epoch % self.interval == 0:
            y_pred = self.model.predict(self.X_val, verbose=0)
            score = roc_auc_score(self.y_val, y_pred)
            print("interval evaluation - epoch: {:d} - score: {:.6f}".format(epoch, score))

The evaluation object is passed to fit() as a callback:

ival = IntervalEvaluation(validation_data=(x_test2, y_test2), interval=1)
model.fit(x_train, y_train,
          batch_size=8196,
          epochs=256,
          validation_data=(x_test, y_test),
          class_weight=class_weight,
          callbacks=[ival],
          verbose=1)

It prints the score after each interval:

interval evaluation - epoch: 0 - score: 0.545038
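If you also want the score to show up in the regular Keras history, or to make it usable by other callbacks such as EarlyStopping, one possible extension (a sketch, not part of the original callback; the class name IntervalEvaluationLogged is made up here) is to write it into the logs dictionary:

from sklearn.metrics import roc_auc_score
from keras.callbacks import Callback

class IntervalEvaluationLogged(Callback):
    """Variant of IntervalEvaluation that also stores the score in `logs`."""

    def __init__(self, validation_data=(), interval=10):
        super(IntervalEvaluationLogged, self).__init__()
        self.interval = interval
        self.X_val, self.y_val = validation_data

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        if epoch % self.interval == 0:
            y_pred = self.model.predict(self.X_val, verbose=0)
            # Writing into `logs` makes the value visible to the History object
            # and to callbacks listed after this one in the callbacks argument.
            logs['roc_auc'] = roc_auc_score(self.y_val, y_pred)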

Persisted metrics

To persist all calculated metrics, it is also possible to use a callback and store the results on the callback object itself.

import numpy as np

from sklearn.metrics import precision_score
from sklearn.metrics import recall_score
from sklearn.metrics import f1_score
from sklearn.metrics import cohen_kappa_score
from sklearn.metrics import confusion_matrix
from keras.callbacks import Callback

class SkMetrics(Callback):
    def on_train_begin(self, logs={}):
        # One list per metric, one entry per epoch
        self.confusion = []
        self.precision = []
        self.recall = []
        self.f1s = []
        self.kappa = []

    def on_epoch_end(self, epoch, logs={}):
        # self.validation_data holds the data passed to fit(validation_data=...)
        score = np.asarray(self.model.predict(self.validation_data[0]))
        predict = np.round(score)
        targ = self.validation_data[1]

        self.confusion.append(confusion_matrix(targ, predict))
        self.precision.append(precision_score(targ, predict))
        self.recall.append(recall_score(targ, predict))
        self.f1s.append(f1_score(targ, predict))
        self.kappa.append(cohen_kappa_score(targ, predict))

Pass the metric object as a callback:

skmetrics = SkMetrics()
model.fit(x_train, y_train,
          batch_size=8196,
          epochs=256,
          validation_data=(x_test, y_test),
          class_weight=class_weight,
          callbacks=[skmetrics],
          verbose=1)

To access the metrics after training, iterate over the lists stored on the callback object:

for k in skmetrics.confusion:
    print(k)
[[   156 103782]
 [     0   1090]]
....
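If you want a quick overview of how the scores develop over the epochs, a small sketch (assuming matplotlib is available) could look like this:

import matplotlib.pyplot as plt

# Plot the per-epoch scores collected by SkMetrics
epochs = range(len(skmetrics.f1s))
plt.plot(epochs, skmetrics.precision, label='precision')
plt.plot(epochs, skmetrics.recall, label='recall')
plt.plot(epochs, skmetrics.f1s, label='f1')
plt.xlabel('epoch')
plt.ylabel('score')
plt.legend()
plt.show()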
