# Source code for deephyper.hpo._random
from typing import Literal, Optional
import numpy as np
from deephyper.hpo._search import Search
from deephyper.hpo._solution import SolutionSelection
from deephyper.hpo.utils import get_inactive_value_of_hyperparameter
# Public names exported by this module.
__all__ = ["RandomSearch"]
# [docs]
class RandomSearch(Search):
    """Random search algorithm used as an example for the API to implement new search algorithms.

    .. list-table::
        :widths: 25 25 25
        :header-rows: 1

        * - Single-Objective
          - Multi-Objectives
          - Failures
        * - ✅
          - ✅
          - ✅

    Args:
        problem:
            object describing the search/optimization problem.
        random_state (np.random.RandomState, optional):
            Initial random state of the search. Defaults to ``None``.
        log_dir (str, optional):
            Path to the directory where results of the search are stored. Defaults to ``"."``.
        verbose (int, optional):
            Use verbose mode. Defaults to ``0``.
        stopper (Stopper, optional):
            a stopper to leverage multi-fidelity when evaluating the
            function. Defaults to ``None`` which does not use any stopper.
        checkpoint_history_to_csv (bool, optional):
            whether the results from progressively collected evaluations should be
            checkpointed regularly to disk as a csv. Defaults to ``True``.
        solution_selection (Literal["argmax_obs", "argmax_est"] | SolutionSelection, optional):
            the solution selection strategy. It can be a string where ``"argmax_obs"`` would
            select the argmax of observed objective values, and ``"argmax_est"`` would select the
            argmax of estimated objective values (through a predictive model).
            Defaults to ``None``.
    """

    def __init__(
        self,
        problem,
        random_state=None,
        log_dir=".",
        verbose=0,
        stopper=None,
        checkpoint_history_to_csv: bool = True,
        solution_selection: Optional[
            Literal["argmax_obs", "argmax_est"] | SolutionSelection
        ] = None,
    ):
        super().__init__(
            problem,
            random_state,
            log_dir,
            verbose,
            stopper,
            checkpoint_history_to_csv,
            solution_selection,
        )
        # Seed the search space's sampler from the search's own random state.
        # Presumably ``self._random_state`` is built by ``Search.__init__`` from
        # ``random_state``, so this makes sampling reproducible; confirm in the base class.
        self._problem.space.seed(self._random_state.randint(0, np.iinfo(np.int32).max))

    def _ask(self, n: int = 1) -> list[dict[str, Optional[str | int | float]]]:
        """Ask the search for new configurations to evaluate.

        Configurations are sampled from the problem's search space. Every returned
        configuration contains a key for each of the problem's hyperparameter names,
        and its values are converted from NumPy types to native Python objects.

        Args:
            n (int, optional): The number of configurations to ask. Defaults to 1.

        Returns:
            list[dict]: a list of hyperparameter configurations to evaluate.
        """
        import warnings

        # Silence DeprecationWarnings raised while sampling from the search space.
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=DeprecationWarning)
            new_samples = self._problem.space.sample_configuration(size=n)

        # A single sample may come back as a bare configuration rather than a list.
        if not isinstance(new_samples, list):
            new_samples = [new_samples]

        configs = []
        for sample in new_samples:
            config = dict(sample)
            for hp_name in self._problem.hyperparameter_names:
                # Inactive (conditional) hyperparameters are absent from the sample.
                # Fill them with the value chosen by
                # ``get_inactive_value_of_hyperparameter`` so all configurations share
                # the same keys and equivalent configurations have one representation.
                if hp_name not in config:
                    config[hp_name] = get_inactive_value_of_hyperparameter(
                        self._problem.space[hp_name]
                    )
                # Convert NumPy scalars/arrays to native Python objects so the
                # configuration is JSON serializable.
                value = config[hp_name]
                if isinstance(value, (np.generic, np.ndarray)):
                    config[hp_name] = value.tolist()
            configs.append(config)
        return configs

    def _tell(
        self, results: list[tuple[dict[str, Optional[str | int | float]], str | int | float]]
    ):
        """Tell the search the results of the evaluations.

        Random search does not use past evaluations to choose new configurations,
        so the results are not stored. The loop is kept as an illustration of how
        a search algorithm receives ``(configuration, objective)`` pairs.

        Args:
            results (list[tuple[dict[str, Optional[str | int | float]], str | int | float]]):
                a list of ``(configuration, objective)`` pairs, one per evaluation.
        """
        for config, obj in results:
            pass