Source code for deephyper.stopper.integration.deepxde
from deepxde.callbacks import Callback
[docs]
class DeepXDEStopperCallback(Callback):
def __init__(self, job, mode="min", monitor="loss_test"):
super().__init__()
"""Callback to use in conjonction with a DeepHyper ``RunningJob`` to stop the training when the ``Stopper`` is triggered.
.. code-block:: python
def run(job):
callback = DeepXDEStopperCallback(job, ...)
...
model.train(..., callbacks=[callback])
...
Args:
job (RunningJob): The running job created by DeepHyper.
monitor (str, optional): The metric to monitor. It can be any metric collected in the ``History``. Defaults to "loss_test" (equivalent to validation loss).
mode (str, optional): If the metric is maximized or minimized. Value in ``["max", "min"]``. Defaults to "min".
"""
self.job = job
self.monitor = monitor
assert mode in ["max", "min"]
self.mode = mode
self.budget = 0
self.stopped = False
def on_epoch_end(self):
self.budget += 1
self.observe_and_stop(self.budget)
def observe_and_stop(self, budget):
objective = self.get_monitor_value()
if self.mode == "min":
objective = -objective
self.job.record(budget, objective)
if self.job.stopped():
self.model.stop_training = True
self.stopped = True
def get_monitor_value(self):
if self.monitor == "loss_train":
result = sum(self.model.train_state.loss_train)
elif self.monitor == "loss_test":
result = sum(self.model.train_state.loss_test)
else:
raise ValueError("The specified monitor function is incorrect.")
return result