Source code for deephyper.ensemble._ensemble
from typing import Dict, Sequence
import numpy as np
from deephyper.ensemble.aggregator import Aggregator
from deephyper.evaluator import Evaluator, RunningJob
from deephyper.evaluator.callback import TqdmCallback
from deephyper.evaluator.storage import NullStorage
from deephyper.predictor import Predictor, PredictorLoader
def predict_with_predictor(predictor: Predictor | PredictorLoader, X: np.ndarray):
if isinstance(predictor, PredictorLoader):
predictor = predictor.load()
return predictor.predict(X)
def _wrapper_predict_with_predictor(job: RunningJob):
try:
return predict_with_predictor(**job.parameters)
except Exception as exception:
return exception
[docs]
class EnsemblePredictor(Predictor):
"""A predictor that is itself an ensemble of multiple predictors.
Args:
predictors (Sequence[Predictor | PredictorLoader]): the list of predictors to put in the
ensemble. The sequence can be composed of ``Predictor`` (i.e., the model is already
loaded in memory) or ``PredictorLoader`` to perform the loading remotely and in
parallel. In the later case, the ``.load()`` function is called for each inference.
aggregator (Aggregator): the aggregation function to fuse the predictions of the predictors
into one prediction.
weights (Sequence[float], optional): the weights of the predictors in the aggregation.
Defaults to ``None``.
evaluator (str | Dict, optional): The parallel strategy to compute predictions from the
list of predictions. If it is a ``str`` it must be a possible ``method`` of
``Evaluator.create(..., method=...)``. If it is a ``dict`` it must have two keys
``method`` and ``method_kwargs`` such as ``Evaluator.create(...)``. Defaults to
``None`` which is equivalent to ``evaluator="serial"`` for serial evaluations.
Raises:
ValueError: when the type of the ``evaluator`` argument is not ``str`` or ``dict``.
"""
def __init__(
self,
predictors: Sequence[Predictor | PredictorLoader],
aggregator: Aggregator,
weights: Sequence[float] = None,
evaluator: str | Dict = None,
):
self.predictors = predictors
self.aggregator = aggregator
self.weights = weights
self._evaluator = None
if evaluator is None:
self.evaluator_method = "thread"
self.evaluator_method_kwargs = {}
elif isinstance(evaluator, str):
self.evaluator_method = evaluator
self.evaluator_method_kwargs = {}
elif isinstance(evaluator, dict):
self.evaluator_method = evaluator.get("method", "serial")
self.evaluator_method_kwargs = evaluator.get("method_kwargs", {})
else:
raise ValueError(f"evaluator must be either None or str or dict, got {type(evaluator)}")
self.init_evaluator()
[docs]
def init_evaluator(self):
"""Initialize an evaluator for the ensemble.
Returns:
Evaluator: An evaluator instance.
"""
method_kwargs = {
"storage": NullStorage(),
"run_function_kwargs": {},
}
method_kwargs.update(self.evaluator_method_kwargs)
self._evaluator = Evaluator.create(
run_function=_wrapper_predict_with_predictor,
method=self.evaluator_method,
method_kwargs=method_kwargs,
)
[docs]
def predict(self, X: np.ndarray):
"""Compute the prediction of the ensemble.
Args:
X (np.ndarray): the input query for the prediction.
Returns:
np.ndarray: the target prediction.
"""
y_predictors = self.predictions_from_predictors(X, self.predictors)
y = self.aggregator.aggregate(y_predictors, weights=self.weights)
return y
[docs]
def predictions_from_predictors(
self, X: np.ndarray, predictors: Sequence[Predictor | PredictorLoader]
):
"""Compute the predictions of a list of predictors.
Args:
X (np.ndarray): the input query for the predictions.
predictors (Sequence[Predictor]): the list of predictors to compute the predictions.
Returns:
List[np.ndarray]: the sequence of predictions in the same order that the list of
predictors.
"""
n_jobs_submitted = len(predictors)
for cb in self._evaluator._callbacks:
if isinstance(cb, TqdmCallback):
cb.set_max_evals(n_jobs_submitted)
self._evaluator.submit(
[
{
"predictor": predictor,
"X": X,
}
for predictor in predictors
]
)
jobs_done = []
while len(jobs_done) != n_jobs_submitted:
new_jobs_done = self._evaluator.gather("BATCH", size=1)
jobs_done.extend(new_jobs_done)
jobs_done = list(sorted(jobs_done, key=lambda j: int(j.id.split(".")[-1])))
self._evaluator.close()
y_pred = []
for i, job in enumerate(jobs_done):
if isinstance(job.output, Exception):
try:
raise job.output
except Exception:
raise RuntimeError(
f"Failed to call .predict(X) with predictors[{i}]: {predictors[i]}"
)
else:
y_pred.append(job.output)
return y_pred