deephyper.predictor.torch.TorchPredictor


class deephyper.predictor.torch.TorchPredictor(module: torch.nn.Module)

Bases: Predictor

Represents a frozen PyTorch model (a torch.nn.Module) that is used only for prediction.
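In plain PyTorch, "frozen" usually means the module is switched to evaluation mode and its parameters stop tracking gradients. The sketch below shows one common way to do that with standard PyTorch calls. It is an illustration under assumptions: this page does not say which of these steps TorchPredictor performs, and the freeze helper and the small example network are hypothetical.

```python
import torch
from torch import nn


def freeze(module: nn.Module) -> nn.Module:
    """Hypothetical helper: put a module in inference-only mode."""
    module.eval()  # disables dropout and makes batch norm use its running statistics
    for param in module.parameters():
        param.requires_grad_(False)  # parameters no longer receive gradients
    return module


# Hypothetical example network that includes a dropout layer.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Dropout(p=0.5), nn.Linear(8, 1))
frozen = freeze(model)

x = torch.randn(3, 4)
with torch.no_grad():  # do not record operations for backpropagation
    y1 = frozen(x)
    y2 = frozen(x)

print(torch.equal(y1, y2))  # True: dropout is inactive in eval mode
print(y1.requires_grad)     # False: no autograd graph was built
```

Because dropout is a no-op in evaluation mode, repeated calls on the same input give identical outputs, which is what you want from a model that "can only predict".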

Methods

post_process_predictions

pre_process_inputs

predict: Predicts the target for the inputs.

predict(X)

Predicts the target for the inputs.

Parameters:

X (np.ndarray) – the inputs.

Returns:

the predicted target.

Return type:

np.ndarray
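The method names listed above suggest a pipeline in which predict converts the NumPy input to a tensor, runs the module, and converts the result back to a NumPy array. The following is a minimal sketch of that flow, assuming this ordering; FrozenTorchPredictorSketch is a hypothetical stand-in, not DeepHyper's implementation, and the float32 conversion in pre_process_inputs is an assumption about what the wrapped module expects.

```python
import numpy as np
import torch
from torch import nn


class FrozenTorchPredictorSketch:
    """Hypothetical stand-in mirroring the documented method names; not DeepHyper's code."""

    def __init__(self, module: nn.Module):
        self.module = module.eval()  # inference mode; eval() returns the module itself

    def pre_process_inputs(self, X: np.ndarray) -> torch.Tensor:
        # Assumption: the module expects float32 tensors (PyTorch's default dtype).
        return torch.as_tensor(X, dtype=torch.float32)

    def post_process_predictions(self, y: torch.Tensor) -> np.ndarray:
        # Detach from autograd, move to CPU, and hand back a NumPy array.
        return y.detach().cpu().numpy()

    def predict(self, X: np.ndarray) -> np.ndarray:
        with torch.no_grad():  # prediction only, no gradient tracking
            y = self.module(self.pre_process_inputs(X))
        return self.post_process_predictions(y)


torch.manual_seed(0)
predictor = FrozenTorchPredictorSketch(nn.Linear(3, 2))
X = np.random.default_rng(0).normal(size=(5, 3))  # NumPy defaults to float64
y_pred = predictor.predict(X)
print(type(y_pred).__name__, y_pred.shape, y_pred.dtype)  # ndarray (5, 2) float32
```

Casting to float32 before the forward pass avoids a dtype mismatch between NumPy's default float64 arrays and a module whose weights are float32.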