Source code for deephyper.keras.callbacks.stop_if_unfeasible

import time

import tensorflow as tf


class StopIfUnfeasible(tf.keras.callbacks.Callback):
    """Stop training when the estimated training time exceeds a time limit.

    After every training batch the callback extrapolates the average batch
    duration measured so far to ``self.steps`` batches. Once at least
    ``patience`` batches have been timed, training is stopped
    (``model.stop_training = True``) if that estimate is larger than
    ``time_limit``.

    Args:
        time_limit (float): maximum allowed estimated training time, in
            seconds. Defaults to ``600``.
        patience (int): minimum number of timed batches before the estimate
            is allowed to stop training. Defaults to ``20``.

    Attributes:
        timing (list): duration in seconds of every timed batch; it is never
            reset, so it accumulates across epochs.
        stopped (bool): ``True`` once this callback has requested the model
            to stop training because of ``time_limit``.
        steps (int | None): number of batches the estimate is extrapolated
            to; ``None`` when it could not be determined.
        avr_batch_time (float | None): mean of ``timing``; ``None`` before
            the first batch ends.
        estimate_training_time (float | None): latest extrapolated training
            time in seconds; ``None`` until it can be computed.
    """

    def __init__(self, time_limit=600, patience=20):
        super().__init__()
        self.time_limit = time_limit
        self.patience = patience
        self.timing = list()
        # Running sum of ``timing`` so each batch costs O(1) instead of
        # re-summing an ever-growing list twice per batch.
        self._timing_total = 0.0
        # boolean set to True if the model training has been stopped due to
        # time_limit condition
        self.stopped = False
        # Defaults so these attributes exist even before set_params() or the
        # first completed batch.
        self.steps = None
        self.avr_batch_time = None
        self.estimate_training_time = None

    def set_params(self, params):
        """Store Keras' training parameters and derive ``self.steps``.

        ``self.steps`` is ``params["steps"]`` when present and not ``None``
        (presumably the number of batches per epoch). Otherwise it is
        ``ceil(params["samples"] / params["batch_size"])`` when both values
        are available. If neither source is usable, ``self.steps`` is
        ``None`` and no time estimate will be made.

        Args:
            params (dict): training parameters passed in by Keras.
        """
        self.params = params
        steps = params.get("steps")
        if steps is None:
            samples = params.get("samples")
            batch_size = params.get("batch_size")
            if samples is not None and batch_size:
                # Ceiling division: a trailing partial batch is still a step.
                steps = -(-samples // batch_size)
        self.steps = steps

    def on_batch_begin(self, batch, logs=None):
        """Record the start time of the training batch about to run.

        Args:
            batch (int): index of batch within the current epoch.
            logs (dict): not used by this callback.
        """
        # Temporarily holds a timestamp; on_batch_end replaces it with the
        # batch duration. perf_counter is monotonic, unlike time.time().
        self.timing.append(time.perf_counter())

    def on_batch_end(self, batch, logs=None):
        """Update the time estimate and stop training if it is unfeasible.

        Args:
            batch (int): index of batch within the current epoch.
            logs (dict): not used by this callback.
        """
        duration = time.perf_counter() - self.timing[-1]
        self.timing[-1] = duration
        self._timing_total += duration

        n_timed = len(self.timing)
        self.avr_batch_time = self._timing_total / n_timed

        if self.steps is None:
            # Without a known step count there is nothing to extrapolate to.
            return

        # Time already spent plus the average time for the remaining steps.
        # Since total == avr * n_timed, this equals avr_batch_time * steps.
        self.estimate_training_time = self._timing_total + self.avr_batch_time * (
            self.steps - n_timed
        )
        if (
            n_timed >= self.patience
            and self.estimate_training_time > self.time_limit
        ):
            self.stopped = True
            self.model.stop_training = True