3. Multi-Fidelity Hyperparameter Optimization with Keras#
In this tutorial we show how to apply hyperparameter optimization to a basic example from the Keras documentation. We build on the previous tutorial, which uses the same example, and add multi-fidelity to it. The purpose of multi-fidelity is to dynamically manage the budget (also called fidelity) allocated to the evaluation of a hyperparameter configuration. For example, when training a deep neural network, training can be continued for more epochs or stopped early based on the performance observed so far and some policy.
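To make the idea of a stopping policy concrete, here is a minimal plain-Python sketch of a median stopping rule. It is only an illustration under simple assumptions: the function name `should_stop`, the `min_epochs` parameter and the toy learning curves are hypothetical, and this is not how DeepHyper implements its own policies.

```python
from statistics import median


def should_stop(curve, completed_curves, min_epochs=3):
    """Toy median stopping rule (illustrative only).

    curve: validation scores observed so far for the running configuration.
    completed_curves: score histories of previously evaluated configurations.
    """
    step = len(curve)
    if step < min_epochs:
        return False  # always grant a minimum budget
    # Scores reached by earlier configurations at the same epoch
    peers = [c[step - 1] for c in completed_curves if len(c) >= step]
    if not peers:
        return False  # nothing to compare against yet
    # Stop when the best score so far is below the median of the peers
    return max(curve) < median(peers)


history = [
    [0.60, 0.70, 0.78, 0.82],
    [0.55, 0.65, 0.74, 0.80],
    [0.58, 0.69, 0.77, 0.81],
]
weak = [0.50, 0.52, 0.53]
strong = [0.62, 0.73, 0.80]

print(should_stop(weak, history))    # True: well below the median at epoch 3
print(should_stop(strong, history))  # False: above the median at epoch 3
```

The weak configuration is cut after three epochs, so its remaining budget can be spent on more promising configurations.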
In DeepHyper, the multi-fidelity agent is designed separately from the hyperparameter search agent. The two can communicate, but from an API perspective they are different objects. The multi-fidelity agents are called Stopper in DeepHyper and their documentation can be found at deephyper.stopper.
In this notebook, we demonstrate how to use multi-fidelity inside sequential Bayesian optimization. When moving to a distributed setting, it is important to use a shared database accessible by all workers; otherwise the multi-fidelity scheme may not work properly. An example of database instantiation for parallel computing is explained in: Introduction to Distributed Bayesian Optimization (DBO) with MPI (Communication) and Redis (Storage).
Reference: This tutorial is based on materials from the Keras Documentation: Structured data classification from scratch
Let us start with installing DeepHyper!
Warning
This tutorial should be run with tensorflow>=2.6.
[1]:
!pip install "deephyper[jax-cpu]"
import deephyper
print(deephyper.__version__)
0.5.0
Note
The following environment variables can be used to suppress some TensorFlow DEBUG, INFO and WARNING log statements.
[2]:
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = str(4)
os.environ["AUTOGRAPH_VERBOSITY"] = str(0)
3.1. Imports#
[3]:
import pandas as pd
import tensorflow as tf
tf.get_logger().setLevel("ERROR")
3.2. The dataset (from Keras.io)#
The dataset is provided by the Cleveland Clinic Foundation for Heart Disease. It’s a CSV file with 303 rows. Each row contains information about a patient (a sample), and each column describes an attribute of the patient (a feature). We use the features to predict whether a patient has heart disease (binary classification).
Here’s the description of each feature:
Column | Description | Feature Type |
---|---|---|
Age | Age in years | Numerical |
Sex | (1 = male; 0 = female) | Categorical |
CP | Chest pain type (0, 1, 2, 3, 4) | Categorical |
Trestbps | Resting blood pressure (in mm Hg on admission) | Numerical |
Chol | Serum cholesterol in mg/dl | Numerical |
FBS | Fasting blood sugar > 120 mg/dl (1 = true; 0 = false) | Categorical |
RestECG | Resting electrocardiogram results (0, 1, 2) | Categorical |
Thalach | Maximum heart rate achieved | Numerical |
Exang | Exercise induced angina (1 = yes; 0 = no) | Categorical |
Oldpeak | ST depression induced by exercise relative to rest | Numerical |
Slope | Slope of the peak exercise ST segment | Numerical |
CA | Number of major vessels (0-3) colored by fluoroscopy | Both numerical & categorical |
Thal | 3 = normal; 6 = fixed defect; 7 = reversible defect | Categorical |
Target | Diagnosis of heart disease (1 = true; 0 = false) | Target |
[4]:
def load_data():
    file_url = "http://storage.googleapis.com/download.tensorflow.org/data/heart.csv"
    dataframe = pd.read_csv(file_url)
    val_dataframe = dataframe.sample(frac=0.2, random_state=1337)
    train_dataframe = dataframe.drop(val_dataframe.index)
    return train_dataframe, val_dataframe


def dataframe_to_dataset(dataframe):
    dataframe = dataframe.copy()
    labels = dataframe.pop("target")
    ds = tf.data.Dataset.from_tensor_slices((dict(dataframe), labels))
    ds = ds.shuffle(buffer_size=len(dataframe))
    return ds
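To see what the split in load_data produces without downloading the CSV file, the same sample/drop pattern can be tried on a stand-in table. This is a small sketch: the DataFrame `toy` and its columns are made up for illustration, and only the row count (303) mirrors the real dataset.

```python
import pandas as pd

# Toy stand-in for the heart-disease table: 303 rows, one feature and a binary target
toy = pd.DataFrame({"age": range(303), "target": [i % 2 for i in range(303)]})

# Same pattern as load_data: draw 20% of the rows for validation,
# then drop those rows (by index) to obtain the training set
val = toy.sample(frac=0.2, random_state=1337)
train = toy.drop(val.index)

print(len(train), len(val))  # 242 61
print(set(train.index).isdisjoint(val.index))  # True: no row is in both sets
```

Because the validation rows are removed by index, the two sets never overlap, and fixing random_state keeps the split identical across evaluations.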
3.3. Preprocessing & encoding of features#
The next cells use tf.keras.layers.Normalization() to apply standard scaling on the features. Then, the tf.keras.layers.StringLookup and tf.keras.layers.IntegerLookup layers are used to encode categorical variables.
[5]:
def encode_numerical_feature(feature, name, dataset):
    # Create a Normalization layer for our feature
    normalizer = tf.keras.layers.Normalization()

    # Prepare a Dataset that only yields our feature
    feature_ds = dataset.map(lambda x, y: x[name])
    feature_ds = feature_ds.map(lambda x: tf.expand_dims(x, -1))

    # Learn the statistics of the data
    normalizer.adapt(feature_ds)

    # Normalize the input feature
    encoded_feature = normalizer(feature)
    return encoded_feature


def encode_categorical_feature(feature, name, dataset, is_string):
    lookup_class = (
        tf.keras.layers.StringLookup if is_string else tf.keras.layers.IntegerLookup
    )
    # Create a lookup layer which will turn strings into integer indices
    lookup = lookup_class(output_mode="binary")

    # Prepare a Dataset that only yields our feature
    feature_ds = dataset.map(lambda x, y: x[name])
    feature_ds = feature_ds.map(lambda x: tf.expand_dims(x, -1))

    # Learn the set of possible string values and assign them a fixed integer index
    lookup.adapt(feature_ds)

    # Turn the string input into integer indices
    encoded_feature = lookup(feature)
    return encoded_feature
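To make concrete what these two helpers compute, here is a plain-Python sketch of standard scaling and of a binary (multi-hot) lookup encoding. It is a conceptual approximation only: the helper names `standard_scale` and `multi_hot_lookup` are hypothetical, and the real Keras layers differ in details such as vocabulary ordering and numerical safeguards.

```python
from statistics import fmean, pstdev


def standard_scale(values):
    """Center on the mean and divide by the (population) standard deviation."""
    mean = fmean(values)
    std = pstdev(values)
    return [(v - mean) / std for v in values]


def multi_hot_lookup(values):
    """Learn a vocabulary and return an encoder producing a 0/1 vector."""
    vocab = sorted(set(values))
    # Slot 0 is kept for values never seen while learning the vocabulary
    index = {v: i + 1 for i, v in enumerate(vocab)}
    width = len(vocab) + 1

    def encode(value):
        row = [0] * width
        row[index.get(value, 0)] = 1
        return row

    return encode


scaled = standard_scale([40.0, 50.0, 60.0])
print([round(s, 3) for s in scaled])  # [-1.225, 0.0, 1.225]

encode = multi_hot_lookup(["normal", "fixed", "reversible", "normal"])
print(encode("normal"))   # [0, 0, 1, 0]
print(encode("unknown"))  # [1, 0, 0, 0]
```

The statistics and the vocabulary are learned once from the training data (the role of adapt in the cells above) and then applied unchanged to any input.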
3.4. Define the run-function with multi-fidelity#
The run-function defines how the objective that we want to maximize is computed. It takes a job (see deephyper.evaluator.RunningJob) as input and outputs a scalar value or dictionary (see deephyper.evaluator). The objective is always maximized in DeepHyper. The job.parameters attribute contains a suggested configuration of hyperparameters that we want to evaluate. In this example we will search for:
units (default value: 32)
activation (default value: "relu")
dropout_rate (default value: 0.5)
batch_size (default value: 32)
learning_rate (default value: 1e-3)
A hyperparameter value can be accessed through the corresponding key: job["units"] and job.parameters["units"] are both valid. Unlike the previous tutorial, in this example we use multi-fidelity to dynamically choose the budget allocated to each evaluation. To do so, we use the TensorFlow Keras integration of stoppers, deephyper.stopper.integration.TFKerasStopperCallback. The multi-fidelity agent monitors the validation accuracy (val_accuracy), which is maximized. This stopper_callback is then added to the callbacks used by the model during training. To collect more information about the execution of our job, we use the @profile decorator on the run-function, which collects execution timings (timestamp_start and timestamp_end). We also add "metadata" to the output of our function to record how many epochs were used to evaluate each model. To learn more about how the @profile decorator can be used, check our tutorial on Understanding the pros and cons of Evaluator parallel backends.
stopper_callback = TFKerasStopperCallback(
    job,
    monitor="val_accuracy",
    mode="max"
)
history = model.fit(
    train_ds,
    epochs=100,
    validation_data=val_ds,
    verbose=0,
    callbacks=[stopper_callback]
)
objective = history.history["val_accuracy"][-1]
metadata = {"budget": stopper_callback.budget}
return {"objective": objective, "metadata": metadata}
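The interplay between the training loop, the stopper and the reported budget can be mimicked without TensorFlow. The sketch below is a simplified stand-in, not DeepHyper's implementation: the class `ToyStopper`, its patience rule, its method names and the function `train` are all hypothetical and chosen only for illustration.

```python
class ToyStopper:
    """Hypothetical patience-based policy, for illustration only."""

    def __init__(self, max_steps, patience=2):
        self.max_steps = max_steps
        self.patience = patience
        self.best = float("-inf")
        self.stale = 0  # epochs since the last improvement
        self.step = 0   # budget consumed so far

    def observe(self, budget, score):
        self.step = budget
        if score > self.best:
            self.best, self.stale = score, 0
        else:
            self.stale += 1

    def stop(self):
        return self.step >= self.max_steps or self.stale >= self.patience


def train(scores, stopper):
    """Stand-in for model.fit: one score per epoch, stopper consulted after each."""
    history = []
    for epoch, score in enumerate(scores, start=1):
        history.append(score)
        stopper.observe(epoch, score)
        if stopper.stop():
            break
    # Same output layout as the run-function: last score plus the consumed budget
    return {"objective": history[-1], "metadata": {"budget": stopper.step}}


result = train([0.70, 0.75, 0.74, 0.73, 0.80], ToyStopper(max_steps=100))
print(result)  # {'objective': 0.73, 'metadata': {'budget': 4}}
```

Training ends after 4 of the 5 available epochs, and the budget recorded in the metadata shows how much of the maximum was actually used.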
[6]:
from deephyper.evaluator import profile, RunningJob
from deephyper.stopper.integration import TFKerasStopperCallback


@profile
def run(job):
    config = job.parameters
    tf.autograph.set_verbosity(0)

    # Load data and split into validation set
    train_dataframe, val_dataframe = load_data()
    train_ds = dataframe_to_dataset(train_dataframe)
    val_ds = dataframe_to_dataset(val_dataframe)
    train_ds = train_ds.batch(config["batch_size"])
    val_ds = val_ds.batch(config["batch_size"])

    # Categorical features encoded as integers
    sex = tf.keras.Input(shape=(1,), name="sex", dtype="int64")
    cp = tf.keras.Input(shape=(1,), name="cp", dtype="int64")
    fbs = tf.keras.Input(shape=(1,), name="fbs", dtype="int64")
    restecg = tf.keras.Input(shape=(1,), name="restecg", dtype="int64")
    exang = tf.keras.Input(shape=(1,), name="exang", dtype="int64")
    ca = tf.keras.Input(shape=(1,), name="ca", dtype="int64")

    # Categorical feature encoded as string
    thal = tf.keras.Input(shape=(1,), name="thal", dtype="string")

    # Numerical features
    age = tf.keras.Input(shape=(1,), name="age")
    trestbps = tf.keras.Input(shape=(1,), name="trestbps")
    chol = tf.keras.Input(shape=(1,), name="chol")
    thalach = tf.keras.Input(shape=(1,), name="thalach")
    oldpeak = tf.keras.Input(shape=(1,), name="oldpeak")
    slope = tf.keras.Input(shape=(1,), name="slope")

    all_inputs = [
        sex,
        cp,
        fbs,
        restecg,
        exang,
        ca,
        thal,
        age,
        trestbps,
        chol,
        thalach,
        oldpeak,
        slope,
    ]

    # Integer categorical features
    sex_encoded = encode_categorical_feature(sex, "sex", train_ds, False)
    cp_encoded = encode_categorical_feature(cp, "cp", train_ds, False)
    fbs_encoded = encode_categorical_feature(fbs, "fbs", train_ds, False)
    restecg_encoded = encode_categorical_feature(restecg, "restecg", train_ds, False)
    exang_encoded = encode_categorical_feature(exang, "exang", train_ds, False)
    ca_encoded = encode_categorical_feature(ca, "ca", train_ds, False)

    # String categorical features
    thal_encoded = encode_categorical_feature(thal, "thal", train_ds, True)

    # Numerical features
    age_encoded = encode_numerical_feature(age, "age", train_ds)
    trestbps_encoded = encode_numerical_feature(trestbps, "trestbps", train_ds)
    chol_encoded = encode_numerical_feature(chol, "chol", train_ds)
    thalach_encoded = encode_numerical_feature(thalach, "thalach", train_ds)
    oldpeak_encoded = encode_numerical_feature(oldpeak, "oldpeak", train_ds)
    slope_encoded = encode_numerical_feature(slope, "slope", train_ds)

    all_features = tf.keras.layers.concatenate(
        [
            sex_encoded,
            cp_encoded,
            fbs_encoded,
            restecg_encoded,
            exang_encoded,
            slope_encoded,
            ca_encoded,
            thal_encoded,
            age_encoded,
            trestbps_encoded,
            chol_encoded,
            thalach_encoded,
            oldpeak_encoded,
        ]
    )
    x = tf.keras.layers.Dense(config["units"], activation=config["activation"])(
        all_features
    )
    x = tf.keras.layers.Dropout(config["dropout_rate"])(x)
    output = tf.keras.layers.Dense(1, activation="sigmoid")(x)
    model = tf.keras.Model(all_inputs, output)

    optimizer = tf.keras.optimizers.Adam(learning_rate=config["learning_rate"])
    model.compile(optimizer, "binary_crossentropy", metrics=["accuracy"])

    # The stopper callback decides after each epoch whether training continues
    stopper_callback = TFKerasStopperCallback(
        job,
        monitor="val_accuracy",
        mode="max"
    )

    history = model.fit(
        train_ds,
        epochs=100,
        validation_data=val_ds,
        verbose=0,
        callbacks=[stopper_callback]
    )

    objective = history.history["val_accuracy"][-1]
    metadata = {"budget": stopper_callback.budget}
    return {"objective": objective, "metadata": metadata}
The objective maximized by DeepHyper is the "objective" value returned by the run-function.
In this tutorial it corresponds to the validation accuracy of the last epoch of training, which we retrieve from the History object returned by the model.fit(...) call.
...
objective = history.history["val_accuracy"][-1]
...
Using an objective like max(history.history['val_accuracy']) can have undesired side effects.
For example, the validation accuracy may peak at an intermediate epoch and then degrade as training continues. The maximum would then describe a model state that is no longer available at the end of training, and it would overstate how well the returned model adapts to new data.
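The difference between the two choices of objective can be seen on a small example. The sketch below uses a hypothetical validation-accuracy curve (the values are illustrative and do not come from this tutorial's run) and compares the last-epoch score with the maximum over all epochs.

```python
# Hypothetical validation-accuracy curve, one value per epoch.
# These numbers are illustrative only; they are not from the tutorial's run.
val_accuracy = [0.62, 0.74, 0.81, 0.85, 0.83, 0.80]

# Score of the weights the model holds when training ends.
last_epoch_score = val_accuracy[-1]

# Best score seen at any epoch; unless a checkpoint was saved at that epoch,
# the corresponding weights are gone by the end of training.
best_epoch_score = max(val_accuracy)
best_epoch = val_accuracy.index(best_epoch_score) + 1  # 1-based epoch number

print(f"last epoch: {last_epoch_score:.2f}")
print(f"best epoch: {best_epoch_score:.2f} (epoch {best_epoch})")
```

In this made-up curve the maximum is reached before the final epoch, so reporting it would credit the configuration with a score that the returned model does not achieve.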
3.5. Define the Hyperparameter optimization problem#
Hyperparameter ranges are defined using the following syntax:
Discrete integer ranges are generated from a tuple (lower: int, upper: int)
Continuous parameters are generated from a tuple (lower: float, upper: float)
Categorical or nonordinal hyperparameter ranges can be given as a list of possible values [val1, val2, ...]
[7]:
from deephyper.problem import HpProblem
# Creation of a hyperparameter problem
problem = HpProblem()
# Discrete hyperparameter (sampled with uniform prior)
problem.add_hyperparameter((8, 128), "units", default_value=32)
# Categorical hyperparameter (sampled with uniform prior)
ACTIVATIONS = [
"elu", "gelu", "hard_sigmoid", "linear", "relu", "selu",
"sigmoid", "softplus", "softsign", "swish", "tanh",
]
problem.add_hyperparameter(ACTIVATIONS, "activation", default_value="relu")
# Real hyperparameter (sampled with uniform prior)
problem.add_hyperparameter((0.0, 0.6), "dropout_rate", default_value=0.5)
# Discrete and Real hyperparameters (sampled with log-uniform)
problem.add_hyperparameter((8, 256, "log-uniform"), "batch_size", default_value=32)
problem.add_hyperparameter((1e-5, 1e-2, "log-uniform"), "learning_rate", default_value=1e-3)
problem
[7]:
Configuration space object:
Hyperparameters:
activation, Type: Categorical, Choices: {elu, gelu, hard_sigmoid, linear, relu, selu, sigmoid, softplus, softsign, swish, tanh}, Default: relu
batch_size, Type: UniformInteger, Range: [8, 256], Default: 32, on log-scale
dropout_rate, Type: UniformFloat, Range: [0.0, 0.6], Default: 0.5
learning_rate, Type: UniformFloat, Range: [1e-05, 0.01], Default: 0.001, on log-scale
units, Type: UniformInteger, Range: [8, 128], Default: 32
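To see why a log-uniform prior suits a range such as the learning rate (1e-5 to 1e-2), the following sketch compares it with a plain uniform prior using only the Python standard library. It illustrates the concept and is not DeepHyper's sampler; the helper names `sample_uniform`, `sample_log_uniform` and `fraction_below` are hypothetical.

```python
import math
import random

rng = random.Random(0)  # fixed seed so the sketch is reproducible


def sample_uniform(low, high):
    """Draw a value uniformly between low and high."""
    return rng.uniform(low, high)


def sample_log_uniform(low, high):
    """Draw uniformly in log-space, then map back to the original scale."""
    return math.exp(rng.uniform(math.log(low), math.log(high)))


def fraction_below(draws, threshold):
    """Fraction of draws that fall strictly below the threshold."""
    return sum(d < threshold for d in draws) / len(draws)


low, high = 1e-5, 1e-2  # same bounds as the "learning_rate" hyperparameter
n = 10_000
uniform_draws = [sample_uniform(low, high) for _ in range(n)]
log_uniform_draws = [sample_log_uniform(low, high) for _ in range(n)]

# Share of samples in the lowest decade [1e-5, 1e-4) under each prior.
print("uniform    :", fraction_below(uniform_draws, 1e-4))
print("log-uniform:", fraction_below(log_uniform_draws, 1e-4))
```

Under the uniform prior, the lowest decade receives roughly 1% of the samples, whereas the log-uniform prior spreads samples about evenly over the three decades of the range.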
3.6. Evaluate a default configuration#
We evaluate the performance of the default set of hyperparameters provided in the Keras tutorial.
[8]:
objective_default = run(RunningJob(parameters=problem.default_configuration))
print(f"Accuracy of the default configuration is {objective_default['objective']:.3f}\n with a budget of {objective_default['metadata']['budget']}")
Accuracy of the default configuration is 0.803
with a budget of 100
3.7. Execute Multi-Fidelity Bayesian Optimization#
We create the CBO search using the problem and run-function defined above. When the run-function is passed directly to the search, it is wrapped inside a deephyper.evaluator.SerialEvaluator. We also import the deephyper.stopper.LCModelStopper.
[9]:
from deephyper.search.hps import CBO
from deephyper.stopper import LCModelStopper
[10]:
# Instantiate the stopper, then the search with the problem and the run-function defined above
stopper = LCModelStopper(min_steps=1, max_steps=100)
search = CBO(problem, run, initial_points=[problem.default_configuration], stopper=stopper)
/Users/romainegele/Documents/Argonne/deephyper/deephyper/evaluator/_evaluator.py:126: UserWarning: Applying nest-asyncio patch for IPython Shell!
warnings.warn(
Note
All of DeepHyper’s search algorithms have two stopping criteria:
max_evals (int): Defines the maximum number of evaluations that we want to perform. Defaults to -1 for an infinite number.
timeout (int): Defines a time budget (in seconds) before stopping the search. Defaults to None for an infinite time budget.
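How the two criteria interact can be sketched with a small stand-alone loop: the search stops as soon as either limit is reached. The function `toy_search` below is hypothetical and is not DeepHyper's implementation; it only mirrors the semantics of the two arguments described in the note.

```python
import time


def toy_search(evaluate, max_evals=-1, timeout=None):
    """Hypothetical search loop that stops on whichever criterion triggers first.

    With max_evals=-1 and timeout=None it would run forever, which matches
    the "infinite" defaults described in the note.
    """
    start = time.monotonic()
    results = []
    while True:
        if max_evals >= 0 and len(results) >= max_evals:
            break  # evaluation budget exhausted
        if timeout is not None and time.monotonic() - start >= timeout:
            break  # time budget exhausted
        results.append(evaluate(len(results)))
    return results


# Stops after exactly 5 evaluations.
by_count = toy_search(lambda i: i * i, max_evals=5)

# Stops once about 0.1 second has elapsed; each evaluation sleeps for 0.01 s.
by_time = toy_search(lambda i: time.sleep(0.01) or i, timeout=0.1)

print(len(by_count), len(by_time))
```

In the tutorial itself only max_evals is set, so the search ends after 30 evaluations regardless of elapsed time.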
[11]:
results = search.search(max_evals=30)
The returned results is a Pandas DataFrame where columns starting with "p:" are hyperparameters and columns starting with "m:" are additional metadata (from the user or from the Evaluator). It also contains the objective value and the job_id:
job_id is a unique identifier corresponding to the order of creation of tasks.
objective is the value returned by the run-function.
m:timestamp_submit is the time (in seconds) when the task was created by the evaluator, measured since the creation of the evaluator.
m:timestamp_gather is the time (in seconds) when the finished task was received by the evaluator, measured since the creation of the evaluator.
m:timestamp_start is the time (in seconds) when the task started to run.
m:timestamp_end is the time (in seconds) when the task finished running.
m:budget is the number of epochs consumed by each evaluation.
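The column prefixes make it easy to split the table programmatically. The following sketch rebuilds a tiny DataFrame from two rows of this run's results (jobs 0 and 20, with only a subset of the columns) and shows one possible way to separate hyperparameters from metadata and to derive a per-job duration. The `duration` column is an addition made for this illustration and is not produced by DeepHyper.

```python
import pandas as pd

# Two rows (jobs 0 and 20) taken from this run's results, subset of columns.
results = pd.DataFrame(
    {
        "p:activation": ["relu", "hard_sigmoid"],
        "p:units": [32, 69],
        "objective": [0.803279, 0.852459],
        "job_id": [0, 20],
        "m:timestamp_submit": [1.670782, 61.116661],
        "m:timestamp_gather": [5.526985, 63.789852],
        "m:budget": [38, 32],
    }
)

# Split the columns by prefix.
hyperparameter_columns = [c for c in results.columns if c.startswith("p:")]
metadata_columns = [c for c in results.columns if c.startswith("m:")]

# Time between submission and gathering of each job, in seconds
# (illustrative derived column, not part of DeepHyper's output).
results["duration"] = results["m:timestamp_gather"] - results["m:timestamp_submit"]

print(hyperparameter_columns)
print(metadata_columns)
print(results[["job_id", "objective", "m:budget", "duration"]])
```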
[12]:
results
[12]:
| | p:activation | p:batch_size | p:dropout_rate | p:learning_rate | p:units | objective | job_id | m:timestamp_submit | m:timestamp_gather | m:timestamp_start | m:timestamp_end | m:budget |
---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | relu | 32 | 0.500000 | 0.001000 | 32 | 0.803279 | 0 | 1.670782 | 5.526985 | 1.677845e+09 | 1.677845e+09 | 38 |
1 | swish | 14 | 0.489233 | 0.001336 | 41 | 0.803279 | 1 | 5.572275 | 8.168081 | 1.677845e+09 | 1.677845e+09 | 4 |
2 | relu | 8 | 0.587494 | 0.005130 | 20 | 0.803279 | 2 | 8.192211 | 11.887135 | 1.677845e+09 | 1.677845e+09 | 28 |
3 | tanh | 149 | 0.467829 | 0.000165 | 120 | 0.803279 | 3 | 11.910818 | 14.864182 | 1.677845e+09 | 1.677845e+09 | 58 |
4 | selu | 106 | 0.527451 | 0.000020 | 107 | 0.590164 | 4 | 14.888071 | 17.392803 | 1.677845e+09 | 1.677845e+09 | 4 |
5 | softsign | 80 | 0.270408 | 0.000032 | 66 | 0.360656 | 5 | 17.416549 | 19.849482 | 1.677845e+09 | 1.677845e+09 | 4 |
6 | hard_sigmoid | 58 | 0.563340 | 0.009599 | 77 | 0.819672 | 6 | 20.020613 | 22.708393 | 1.677845e+09 | 1.677845e+09 | 29 |
7 | hard_sigmoid | 19 | 0.202024 | 0.000013 | 119 | 0.229508 | 7 | 22.732340 | 25.483728 | 1.677845e+09 | 1.677845e+09 | 4 |
8 | gelu | 87 | 0.393025 | 0.000047 | 13 | 0.557377 | 8 | 25.507511 | 28.061162 | 1.677845e+09 | 1.677845e+09 | 4 |
9 | linear | 150 | 0.405327 | 0.000357 | 82 | 0.819672 | 9 | 28.085102 | 31.010434 | 1.677845e+09 | 1.677845e+09 | 49 |
10 | elu | 172 | 0.113248 | 0.008485 | 82 | 0.803279 | 10 | 31.135696 | 33.961459 | 1.677845e+09 | 1.677845e+09 | 27 |
11 | hard_sigmoid | 223 | 0.550292 | 0.000403 | 79 | 0.770492 | 11 | 34.087655 | 36.807531 | 1.677845e+09 | 1.677845e+09 | 16 |
12 | linear | 57 | 0.418369 | 0.009584 | 65 | 0.819672 | 12 | 36.936843 | 39.796188 | 1.677845e+09 | 1.677845e+09 | 27 |
13 | linear | 96 | 0.353992 | 0.009298 | 90 | 0.803279 | 13 | 39.927335 | 42.940046 | 1.677845e+09 | 1.677845e+09 | 27 |
14 | sigmoid | 54 | 0.422907 | 0.009956 | 78 | 0.803279 | 14 | 43.074251 | 45.870560 | 1.677845e+09 | 1.677845e+09 | 28 |
15 | hard_sigmoid | 172 | 0.588743 | 0.009655 | 71 | 0.819672 | 15 | 46.006324 | 48.938381 | 1.677845e+09 | 1.677845e+09 | 37 |
16 | linear | 95 | 0.451820 | 0.009566 | 73 | 0.803279 | 16 | 49.071774 | 52.043772 | 1.677845e+09 | 1.677845e+09 | 27 |
17 | gelu | 160 | 0.564634 | 0.009373 | 95 | 0.803279 | 17 | 52.176734 | 54.882586 | 1.677845e+09 | 1.677845e+09 | 29 |
18 | hard_sigmoid | 17 | 0.443228 | 0.009888 | 63 | 0.803279 | 18 | 55.017707 | 58.130752 | 1.677845e+09 | 1.677845e+09 | 28 |
19 | hard_sigmoid | 170 | 0.564063 | 0.007612 | 72 | 0.770492 | 19 | 58.266633 | 60.979968 | 1.677845e+09 | 1.677845e+09 | 4 |
20 | hard_sigmoid | 209 | 0.578272 | 0.009821 | 69 | 0.852459 | 20 | 61.116661 | 63.789852 | 1.677845e+09 | 1.677845e+09 | 32 |
21 | elu | 192 | 0.582215 | 0.007338 | 34 | 0.786885 | 21 | 63.926539 | 66.472252 | 1.677845e+09 | 1.677845e+09 | 4 |
22 | linear | 238 | 0.586330 | 0.009814 | 115 | 0.819672 | 22 | 66.609757 | 69.159936 | 1.677845e+09 | 1.677845e+09 | 4 |
23 | tanh | 167 | 0.576660 | 0.009978 | 11 | 0.836066 | 23 | 69.301096 | 71.973102 | 1.677845e+09 | 1.677845e+09 | 36 |
24 | gelu | 230 | 0.572529 | 0.002841 | 67 | 0.836066 | 24 | 72.111029 | 75.021637 | 1.677845e+09 | 1.677845e+09 | 39 |
25 | relu | 210 | 0.596245 | 0.008074 | 58 | 0.786885 | 25 | 75.160838 | 77.813228 | 1.677845e+09 | 1.677845e+09 | 4 |
26 | elu | 214 | 0.097347 | 0.009841 | 40 | 0.819672 | 26 | 77.955130 | 80.345646 | 1.677845e+09 | 1.677845e+09 | 4 |
27 | hard_sigmoid | 206 | 0.572626 | 0.002464 | 25 | 0.786885 | 27 | 80.486534 | 83.256348 | 1.677845e+09 | 1.677845e+09 | 16 |
28 | hard_sigmoid | 209 | 0.536363 | 0.009566 | 125 | 0.770492 | 28 | 83.397966 | 85.788719 | 1.677845e+09 | 1.677845e+09 | 4 |
29 | tanh | 241 | 0.576931 | 0.009229 | 69 | 0.819672 | 29 | 85.929811 | 88.510505 | 1.677845e+09 | 1.677845e+09 | 4 |
Now that the search is over, let us print the best configuration found during this run.
[13]:
i_max = results.objective.argmax()
best_config = results.iloc[i_max].to_dict()
print(f"The default configuration has an accuracy of {objective_default['objective']:.3f}. \n"
f"The best configuration found by DeepHyper has an accuracy of {results['objective'].iloc[i_max]:.3f}, \n"
f"discovered after {results['m:timestamp_gather'].iloc[i_max]:.2f} seconds of search.\n")
best_config
The default configuration has an accuracy of 0.803.
The best configuration found by DeepHyper has an accuracy of 0.852,
discovered after 63.79 seconds of search.
[13]:
{'p:activation': 'hard_sigmoid',
'p:batch_size': 209,
'p:dropout_rate': 0.5782715361012362,
'p:learning_rate': 0.0098209943552909,
'p:units': 69,
'objective': 0.8524590134620667,
'job_id': 20,
'm:timestamp_submit': 61.116661071777344,
'm:timestamp_gather': 63.789851903915405,
'm:timestamp_start': 1677845351.991086,
'm:timestamp_end': 1677845354.663899,
'm:budget': 32}
We observe an improvement of about 4.9 percentage points in validation accuracy (from 0.803 to 0.852). The best row also gives the corresponding hyperparameter configuration together with the number of epochs used for this evaluation (32).
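The m:budget column also shows how much training the stopper saved. As a rough check, the sketch below sums the budgets reported in the results table and compares the total with the cost of training every configuration for the full 100 epochs. The variable names are chosen for this illustration.

```python
# Values of the m:budget column from the results table, in job order (jobs 0 to 29).
budgets = [38, 4, 28, 58, 4, 4, 29, 4, 4, 49, 27, 16, 27, 27, 28,
           37, 27, 29, 28, 4, 32, 4, 4, 36, 39, 4, 4, 16, 4, 4]

max_steps = 100  # epochs per evaluation if no evaluation were stopped early

used = sum(budgets)                # epochs actually consumed
full = max_steps * len(budgets)    # epochs without early stopping

# Number of evaluations stopped at the smallest budget observed in this run.
stopped_at_min = sum(b == min(budgets) for b in budgets)

print(f"{used} of {full} epochs used ({used / full:.1%})")
print(f"{stopped_at_min} evaluations stopped after {min(budgets)} epochs")
```

For this run the 30 evaluations consumed 619 epochs in total, about a fifth of the 3000 epochs that training every configuration to completion would have required.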