Cuando entrenamos nuestros modelos requerimos de monitorear su desempeño durante el entrenamiento, esto nos dará tiempo de reacción para entender por qué rinden como rinden. En esta ocasión manejaremos callbacks personalizados para monitorear el desempeño de nuestra red en términos de precisión y pérdida.

Aplicando callbacks a nuestro modelo

Para esta ocasión crearemos un modelo llamado model_callbacks, virtualmente será idéntico al último modelo trabajado, pero cambiaremos levemente el comportamiento del entrenamiento cuando sea llamado.

model_callback = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(75, (3,3), activation = "relu", input_shape = (28, 28, 1)),
    tf.keras.layers.MaxPool2D((2,2)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(256, kernel_regularizer = regularizers.l2(1e-5), activation = "relu"),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(128, kernel_regularizer = regularizers.l2(1e-5), activation = "relu"),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(len(classes), activation = "softmax")
])

model_callback.summary()

model_callback.compile(optimizer = "adam", loss = "categorical_crossentropy", metrics = ["accuracy"])

Crearemos nuestro propio callback desde la clase Callback que nos ofrece Keras.

Podemos activar un callback en cualquier momento del ciclo de vida del modelo, para esta ocasión podemos elegir si activarlo al inicio de cada época, durante el entrenamiento o al final, para esta ocasión elegiremos el último caso.

Crearemos nuestra clase TrainingCallback que heredará de Callback, definiremos la función on_epoch_end que se activará cada que termine una época y recibirá como parámetros el objeto mismo, la época y los logs (que contendrán las métricas de la red).

Obtenemos la precisión de los logs y la comparamos, para esta ocasión determinaremos que el modelo se detenga si es mayor a 95% o 0.95, si es así, entonces daremos un pequeño mensaje pantalla y setearemos la variable self.model.stop_training en verdadero para detenerlo prematuramente.

from tensorflow.keras.callbacks import Callback

class TrainingCallback(Callback):
  def on_epoch_end(self, epoch, logs = {}):
    if logs.get("accuracy") > 0.95:
      print("Lo logramos, nuestro modelo llego a 95%, detenemos nuestro modelo")
      self.model.stop_training = True

Para hacer efectivo este callback, creamos una instancia y lo inyectamos al momento de entrenar el modelo en el parámetro de callbacks, puedes notar que recibe una lista, por lo que puedes agregar cuantos quieras.