# Source code for deephyper.stopper.integration.tensorflow
import warnings
from tensorflow.keras.callbacks import Callback
class TFKerasStopperCallback(Callback):
    def __init__(self, job, monitor="val_loss", mode="min") -> None:
        """Callback to use in conjunction with a DeepHyper ``RunningJob`` to stop the training when the ``Stopper`` is triggered.

        .. code-block:: python

            def run(job):
                callback = TFKerasStopperCallback(job, ...)
                ...
                model.fit(..., callbacks=[callback])
                ...

        Args:
            job (RunningJob): The running job created by DeepHyper.
            monitor (str, optional): The metric to monitor. It can be any metric collected in the ``History``. Defaults to ``"val_loss"``.
            mode (str, optional): If the metric is maximized or minimized. Value in ``["max", "min"]``. Defaults to ``"min"``.

        Raises:
            ValueError: If ``mode`` is not ``"max"`` or ``"min"``.
        """
        # Validate explicitly: an ``assert`` would be stripped under ``python -O``
        # and an invalid mode would then be silently treated as "max".
        if mode not in ("max", "min"):
            raise ValueError(f"mode must be 'max' or 'min', got {mode!r}")

        super().__init__()
        self.job = job
        self.monitor = monitor
        self.mode = mode
        # Incremented once per ``on_epoch_end`` call and reported to the job
        # as the budget consumed so far.
        self.budget = 0

    def on_epoch_end(self, epoch, logs=None):
        """Advance the budget by one epoch and report the monitored metric.

        Args:
            epoch (int): Index of the epoch that just ended (provided by Keras, unused here).
            logs (dict, optional): Metric values for this epoch as passed by Keras.
        """
        self.budget += 1
        self.observe_and_stop(self.budget, logs)

    def observe_and_stop(self, budget, logs):
        """Record the monitored metric in the job and stop training if the job is stopped.

        Does nothing when ``logs`` is ``None``. Emits a warning and records nothing
        when ``self.monitor`` is missing from ``logs``.

        Args:
            budget (int): Budget at which the observation is recorded.
            logs (dict or None): Metric values from which ``self.monitor`` is read.
        """
        if logs is None:
            return

        objective = logs.get(self.monitor)
        if objective is None:
            warnings.warn(
                f"Monitor {self.monitor} is not found in the history logs. Stopper will not be able to stop the training. Available logs are: {list(logs.keys())}"
            )
            return

        # In "min" mode the value is negated before recording, while "max" values
        # are recorded as-is; presumably the job/stopper treats larger objectives
        # as better — confirm against the RunningJob.record contract.
        if self.mode == "min":
            objective = -objective

        self.job.record(budget, objective)
        if self.job.stopped():
            self.model.stop_training = True