Source code for deephyper.keras.callbacks.utils
from typing import Type
import deephyper
import deephyper.core.exceptions
import tensorflow as tf
def import_callback(cb_name: str) -> Type[tf.keras.callbacks.Callback]:
    """Import a callback class from its name.

    The name is looked up in ``tensorflow.keras.callbacks`` first and then in
    ``deephyper.keras.callbacks``. Only attributes that are classes deriving
    from ``tensorflow.keras.callbacks.Callback`` are accepted, so module
    attributes that merely share the requested name (``__doc__``, submodules,
    helper functions, ...) are never returned.

    Args:
        cb_name (str): class name of the callback to import from
            ``tensorflow.keras.callbacks`` or ``deephyper.keras.callbacks``.

    Raises:
        DeephyperRuntimeError: raised if the class name of the callback is not
            registered in corresponding packages.

    Returns:
        tensorflow.keras.callbacks.Callback: the class corresponding to the
        given class name.
    """
    callback_base = tf.keras.callbacks.Callback

    # A non-string name cannot match any attribute; skip the lookup so the
    # caller gets the documented DeephyperRuntimeError instead of the
    # TypeError that ``getattr`` raises for non-string attribute names.
    if isinstance(cb_name, str):
        # Search order matters: TensorFlow's callbacks take precedence over
        # DeepHyper's when both packages expose the same name.
        for package in (tf.keras.callbacks, deephyper.keras.callbacks):
            candidate = getattr(package, cb_name, None)
            if isinstance(candidate, type) and issubclass(candidate, callback_base):
                return candidate

    raise deephyper.core.exceptions.DeephyperRuntimeError(
        f"Callback '{cb_name}' is not registered in tensorflow.keras and deephyper.keras.callbacks."
    )