Source code for deephyper.predictor.torch._predictor_torch
from typing import List
import torch
from deephyper.predictor import Predictor, PredictorFileLoader
[docs]
class TorchPredictor(Predictor):
"""Represents a frozen torch model that can only predict."""
def __init__(self, module: torch.nn.Module):
if not isinstance(module, torch.nn.Module):
raise ValueError(
f"The given module is of type {type(module)} when it should be of type "
f"torch.nn.Module!"
)
self.module = module
def pre_process_inputs(self, X):
X = torch.from_numpy(X)
return X
def post_process_predictions(self, y):
if isinstance(y, torch.Tensor):
y = y.detach().numpy()
elif isinstance(y, dict):
y = {k: v.detach().numpy() for k, v in y.items()}
elif isinstance(y, list):
y = [yi.detach().numpy() for yi in y]
return y
[docs]
def predict(self, X):
X = self.pre_process_inputs(X)
training = self.module.training
if training:
self.module.eval()
if hasattr(self.module, "predict_proba"):
y = self.module.predict_proba(X)
else:
y = self.module(X)
self.module.train(training)
y = self.post_process_predictions(y)
return y
[docs]
class TorchPredictorFileLoader(PredictorFileLoader):
"""Loads a predictor from a file for the Pytorch backend.
Args:
path_predictor_file (str): the path to the predictor file.
"""
def __init__(self, path_predictor_file: str):
super().__init__(path_predictor_file)
[docs]
def load(self) -> TorchPredictor:
model = torch.load(self.path_predictor_file, weights_only=False)
return TorchPredictor(model)
[docs]
@staticmethod
def find_predictor_files(path_directory: str, file_extension: str = "pt") -> List[str]:
"""Finds the predictor files in a directory given a specific extension.
Args:
path_directory (str): the directory path.
file_extension (str, optional): the file extension. Defaults to ``"pt"``.
Returns:
List[str]: the list of predictor files found in the directory.
"""
return PredictorFileLoader.find_predictor_files(path_directory, file_extension)