# Source code for deephyper.predictor.tf_keras2._predictor_tf_keras2
from typing import List
import tensorflow as tf
# TODO: Check if this import could be removed. Currently, the following import
# is necessary to avoid a bug if checkpointed models that are loaded have tfp
# layers.
# import tensorflow_probability as tfp
import tf_keras as tfk
from deephyper.predictor import Predictor, PredictorFileLoader
[docs]
class TFKeras2Predictor(Predictor):
    """Represents a frozen TensorFlow/Keras2 model that can only predict.

    Args:
        model (tfk.Model): the model that is called to produce predictions.
    """

    def __init__(self, model: tfk.Model):
        self.model = model

    def pre_process_inputs(self, X):
        """Returns the inputs unchanged (hook applied before calling the model).

        Args:
            X: the inputs given to :meth:`predict`.

        Returns:
            the same object ``X``.
        """
        return X

    def post_process_predictions(self, y):
        """Converts the outputs of the model with ``.numpy()``.

        A single ``tf.Tensor`` is converted directly. For a ``dict``, a
        ``list`` or a ``tuple``, ``.numpy()`` is called on each value/element
        and a container of the same kind is returned. Any other object is
        returned untouched.

        Args:
            y: the raw outputs returned by the model.

        Returns:
            the converted outputs, with the same container kind as ``y``.
        """
        if isinstance(y, tf.Tensor):
            return y.numpy()
        if isinstance(y, dict):
            return {k: v.numpy() for k, v in y.items()}
        if isinstance(y, list):
            return [yi.numpy() for yi in y]
        if isinstance(y, tuple):
            # Tuple outputs used to be returned as raw tensors, unlike list
            # outputs; convert them the same way for consistency.
            return tuple(yi.numpy() for yi in y)
        return y

    def predict(self, X):
        """Computes the predictions of the model for the inputs ``X``.

        The model is called with ``training=False`` between the
        :meth:`pre_process_inputs` and :meth:`post_process_predictions` hooks.

        Args:
            X: the inputs of the model.

        Returns:
            the post-processed outputs of the model.
        """
        X = self.pre_process_inputs(X)
        y = self.model(X, training=False)
        return self.post_process_predictions(y)
[docs]
class TFKeras2PredictorFileLoader(PredictorFileLoader):
    """Predictor file loader for the TensorFlow/Keras2 backend.

    Args:
        path_predictor_file (str): the path to the predictor file.
    """

    def __init__(self, path_predictor_file: str):
        super().__init__(path_predictor_file)

    def load(self) -> TFKeras2Predictor:
        """Loads the model stored at ``path_predictor_file`` as a predictor.

        Returns:
            TFKeras2Predictor: a predictor wrapping the loaded model.
        """
        # The model is only used for inference, hence ``compile=False``.
        # NOTE(review): ``safe_mode=False`` presumably allows unsafe
        # deserialization (e.g., lambda layers) -- confirm that predictor
        # files always come from a trusted source.
        load_options = {"compile": False, "safe_mode": False}
        keras_model = tfk.models.load_model(self.path_predictor_file, **load_options)
        return TFKeras2Predictor(keras_model)

    @staticmethod
    def find_predictor_files(
        path_directory: str, file_extension: str = "keras"
    ) -> List[str]:
        """Finds the predictor files in a directory given a specific extension.

        The search is delegated to ``PredictorFileLoader.find_predictor_files``.

        Args:
            path_directory (str): the directory path.
            file_extension (str, optional): the file extension. Defaults to ``"keras"``.

        Returns:
            List[str]: the list of predictor files found in the directory.
        """
        predictor_files = PredictorFileLoader.find_predictor_files(
            path_directory, file_extension
        )
        return predictor_files