3. Multi-Fidelity Hyperparameter Optimization with Keras#
In this tutorial we present how to use hyperparameter optimization on a basic example from the Keras documentation. We follow the previous tutorial based on the same example and add multi-fidelity to it. The purpose of multi-fidelity is to dynamically manage the budget (also called fidelity) allocated to evaluate a hyperparameter configuration. For example, when training a deep neural network, the training can be continued or stopped after each epoch based on the currently observed performance and some policy.
In DeepHyper, the multi-fidelity agent is designed separately from the hyperparameter search agent. Both can communicate, but from an API perspective they are different objects. The multi-fidelity agents are called Stopper in DeepHyper and their documentation can be found at deephyper.stopper.
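At a high level, a Stopper observes the (budget, objective) pairs reported by a running job and decides when the evaluation should be halted. The sketch below illustrates this interaction for a generic iterative workload; it is an assumption-laden illustration, not code from this tutorial: train_one_step is a hypothetical placeholder, while job.record and job.stopped reflect the reporting pattern used by DeepHyper's stopper integrations.

from deephyper.evaluator import RunningJob

def train_one_step(parameters, step):
    # hypothetical stand-in for one epoch of training; returns a toy score
    return 1.0 - 1.0 / (step + 1)

def run(job: RunningJob, max_budget: int = 100):
    objective = 0.0
    for budget in range(1, max_budget + 1):
        objective = train_one_step(job.parameters, budget)
        job.record(budget, objective)  # report the (budget, objective) pair to the attached stopper
        if job.stopped():  # the stopper decided to halt this evaluation early
            break
    return {"objective": objective, "metadata": {"budget": budget}}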
In this notebook, we will demonstrate how to use multi-fidelity inside sequential Bayesian optimization. When moving to a distributed setting, it is important to use a shared database accessible by all workers; otherwise the multi-fidelity scheme may not work properly. An example of database instantiation for parallel computing is explained in: Introduction to Distributed Bayesian Optimization (DBO) with MPI (Communication) and Redis (Storage).
Reference: This tutorial is based on materials from the Keras Documentation: Structured data classification from scratch
Let us start with installing DeepHyper!
Warning
This tutorial should be run with tensorflow>=2.6.
[1]:
# !pip install "deephyper[jax-cpu]"
import deephyper
print(deephyper.__version__)
0.6.0
Note
The following environment variables can be used to avoid the logging of some TensorFlow DEBUG, INFO, and WARNING statements.
[2]:
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = str(4)
os.environ["AUTOGRAPH_VERBOSITY"] = str(0)
3.1. Imports#
[3]:
import pandas as pd
import tensorflow as tf
tf.get_logger().setLevel("ERROR")
3.2. The dataset (from Keras.io)#
The dataset is provided by the Cleveland Clinic Foundation for Heart Disease. It’s a CSV file with 303 rows. Each row contains information about a patient (a sample), and each column describes an attribute of the patient (a feature). We use the features to predict whether a patient has heart disease (binary classification).
Here’s the description of each feature:
| Column | Description | Feature Type |
|---|---|---|
| Age | Age in years | Numerical |
| Sex | (1 = male; 0 = female) | Categorical |
| CP | Chest pain type (0, 1, 2, 3, 4) | Categorical |
| Trestbpd | Resting blood pressure (in mm Hg on admission) | Numerical |
| Chol | Serum cholesterol in mg/dl | Numerical |
| FBS | Fasting blood sugar in 120 mg/dl (1 = true; 0 = false) | Categorical |
| RestECG | Resting electrocardiogram results (0, 1, 2) | Categorical |
| Thalach | Maximum heart rate achieved | Numerical |
| Exang | Exercise induced angina (1 = yes; 0 = no) | Categorical |
| Oldpeak | ST depression induced by exercise relative to rest | Numerical |
| Slope | Slope of the peak exercise ST segment | Numerical |
| CA | Number of major vessels (0-3) colored by fluoroscopy | Both numerical & categorical |
| Thal | 3 = normal; 6 = fixed defect; 7 = reversible defect | Categorical |
| Target | Diagnosis of heart disease (1 = true; 0 = false) | Target |
[4]:
def load_data():
    file_url = "http://storage.googleapis.com/download.tensorflow.org/data/heart.csv"
    dataframe = pd.read_csv(file_url)
    val_dataframe = dataframe.sample(frac=0.2, random_state=1337)
    train_dataframe = dataframe.drop(val_dataframe.index)
    return train_dataframe, val_dataframe


def dataframe_to_dataset(dataframe):
    dataframe = dataframe.copy()
    labels = dataframe.pop("target")
    ds = tf.data.Dataset.from_tensor_slices((dict(dataframe), labels))
    ds = ds.shuffle(buffer_size=len(dataframe))
    return ds
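As a quick sanity check (a hedged example, not part of the original tutorial), the two helpers can be exercised as follows; with frac=0.2 the split is roughly 242 training rows and 61 validation rows:

train_df, val_df = load_data()
print(len(train_df), "train rows,", len(val_df), "validation rows")

# Peek at one (features, label) sample from the unbatched dataset
ds = dataframe_to_dataset(train_df)
for features, label in ds.take(1):
    print(sorted(features.keys()), label.numpy())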
3.3. Preprocessing & encoding of features#
The next cells use tf.keras.layers.Normalization() to apply standard scaling on the features. Then, tf.keras.layers.StringLookup and tf.keras.layers.IntegerLookup are used to encode the categorical variables.
[5]:
def encode_numerical_feature(feature, name, dataset):
    # Create a Normalization layer for our feature
    normalizer = tf.keras.layers.Normalization()

    # Prepare a Dataset that only yields our feature
    feature_ds = dataset.map(lambda x, y: x[name])
    feature_ds = feature_ds.map(lambda x: tf.expand_dims(x, -1))

    # Learn the statistics of the data
    normalizer.adapt(feature_ds)

    # Normalize the input feature
    encoded_feature = normalizer(feature)
    return encoded_feature


def encode_categorical_feature(feature, name, dataset, is_string):
    lookup_class = (
        tf.keras.layers.StringLookup if is_string else tf.keras.layers.IntegerLookup
    )

    # Create a lookup layer which will turn strings into integer indices
    lookup = lookup_class(output_mode="binary")

    # Prepare a Dataset that only yields our feature
    feature_ds = dataset.map(lambda x, y: x[name])
    feature_ds = feature_ds.map(lambda x: tf.expand_dims(x, -1))

    # Learn the set of possible string values and assign them a fixed integer index
    lookup.adapt(feature_ds)

    # Turn the string input into integer indices
    encoded_feature = lookup(feature)
    return encoded_feature
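To see what these encoders produce, here is a small hedged example (assuming the cells above have been run): the numerical encoder outputs a single standardized column, while the categorical encoder outputs a multi-hot vector over the observed categories.

# Hypothetical smoke test of the feature encoders
train_df, _ = load_data()
demo_ds = dataframe_to_dataset(train_df).batch(32)

age_input = tf.keras.Input(shape=(1,), name="age")
age_encoded = encode_numerical_feature(age_input, "age", demo_ds)
print(age_encoded.shape)  # (None, 1): one standardized column

cp_input = tf.keras.Input(shape=(1,), name="cp", dtype="int64")
cp_encoded = encode_categorical_feature(cp_input, "cp", demo_ds, is_string=False)
print(cp_encoded.shape)  # (None, k): multi-hot over the k observed categories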
3.4. Define the run-function with multi-fidelity#
The run-function defines how the objective that we want to maximize is computed. It takes a job (see deephyper.evaluator.RunningJob) as input and outputs a scalar value or a dictionary (see deephyper.evaluator). The objective is always maximized in DeepHyper. The job.parameters attribute contains a suggested configuration of hyperparameters that we want to evaluate. In this example we will search for:

- units (default value: 32)
- activation (default value: "relu")
- dropout_rate (default value: 0.5)
- batch_size (default value: 32)
- learning_rate (default value: 1e-3)
A hyperparameter value can be accessed easily in the dictionary through the corresponding key: for example, job["units"] and job.parameters["units"] are both valid. Unlike the previous tutorial, in this example we want to use multi-fidelity to dynamically choose the budget allocated to each evaluation. Therefore we use the TensorFlow/Keras integration of stoppers, deephyper.stopper.integration.TFKerasStopperCallback. The multi-fidelity agent will monitor the validation accuracy (val_accuracy) in the context of maximization. This stopper_callback is then added to the callbacks used by the model during training. In order to collect more information about the execution of our job, we use the @profile decorator on the run-function, which collects execution timings (timestamp_start and timestamp_end). We also add "metadata" to the output of our function to know how many epochs were used to evaluate each model. To learn more about how the @profile decorator can be used, check our tutorial on Understanding the pros and cons of Evaluator parallel backends.
stopper_callback = TFKerasStopperCallback(
    job,
    monitor="val_accuracy",
    mode="max",
)

history = model.fit(
    train_ds,
    epochs=100,
    validation_data=val_ds,
    verbose=0,
    callbacks=[stopper_callback],
)

objective = history.history["val_accuracy"][-1]
metadata = {"budget": stopper_callback.budget}
return {"objective": objective, "metadata": metadata}
[15]:
import json

from deephyper.evaluator import profile, RunningJob
from deephyper.stopper.integration.tensorflow import TFKerasStopperCallback


@profile
def run(job):
    config = job.parameters

    tf.autograph.set_verbosity(0)
    import absl.logging
    absl.logging.set_verbosity(absl.logging.ERROR)

    # Load data and split into validation set
    train_dataframe, val_dataframe = load_data()
    train_ds = dataframe_to_dataset(train_dataframe)
    val_ds = dataframe_to_dataset(val_dataframe)
    train_ds = train_ds.batch(config["batch_size"])
    val_ds = val_ds.batch(config["batch_size"])

    # Categorical features encoded as integers
    sex = tf.keras.Input(shape=(1,), name="sex", dtype="int64")
    cp = tf.keras.Input(shape=(1,), name="cp", dtype="int64")
    fbs = tf.keras.Input(shape=(1,), name="fbs", dtype="int64")
    restecg = tf.keras.Input(shape=(1,), name="restecg", dtype="int64")
    exang = tf.keras.Input(shape=(1,), name="exang", dtype="int64")
    ca = tf.keras.Input(shape=(1,), name="ca", dtype="int64")

    # Categorical feature encoded as string
    thal = tf.keras.Input(shape=(1,), name="thal", dtype="string")

    # Numerical features
    age = tf.keras.Input(shape=(1,), name="age")
    trestbps = tf.keras.Input(shape=(1,), name="trestbps")
    chol = tf.keras.Input(shape=(1,), name="chol")
    thalach = tf.keras.Input(shape=(1,), name="thalach")
    oldpeak = tf.keras.Input(shape=(1,), name="oldpeak")
    slope = tf.keras.Input(shape=(1,), name="slope")

    all_inputs = [
        sex,
        cp,
        fbs,
        restecg,
        exang,
        ca,
        thal,
        age,
        trestbps,
        chol,
        thalach,
        oldpeak,
        slope,
    ]

    # Integer categorical features
    sex_encoded = encode_categorical_feature(sex, "sex", train_ds, False)
    cp_encoded = encode_categorical_feature(cp, "cp", train_ds, False)
    fbs_encoded = encode_categorical_feature(fbs, "fbs", train_ds, False)
    restecg_encoded = encode_categorical_feature(restecg, "restecg", train_ds, False)
    exang_encoded = encode_categorical_feature(exang, "exang", train_ds, False)
    ca_encoded = encode_categorical_feature(ca, "ca", train_ds, False)

    # String categorical features
    thal_encoded = encode_categorical_feature(thal, "thal", train_ds, True)

    # Numerical features
    age_encoded = encode_numerical_feature(age, "age", train_ds)
    trestbps_encoded = encode_numerical_feature(trestbps, "trestbps", train_ds)
    chol_encoded = encode_numerical_feature(chol, "chol", train_ds)
    thalach_encoded = encode_numerical_feature(thalach, "thalach", train_ds)
    oldpeak_encoded = encode_numerical_feature(oldpeak, "oldpeak", train_ds)
    slope_encoded = encode_numerical_feature(slope, "slope", train_ds)

    all_features = tf.keras.layers.concatenate(
        [
            sex_encoded,
            cp_encoded,
            fbs_encoded,
            restecg_encoded,
            exang_encoded,
            slope_encoded,
            ca_encoded,
            thal_encoded,
            age_encoded,
            trestbps_encoded,
            chol_encoded,
            thalach_encoded,
            oldpeak_encoded,
        ]
    )

    x = tf.keras.layers.Dense(config["units"], activation=config["activation"])(
        all_features
    )
    x = tf.keras.layers.Dropout(config["dropout_rate"])(x)
    output = tf.keras.layers.Dense(1, activation="sigmoid")(x)
    model = tf.keras.Model(all_inputs, output)

    optimizer = tf.keras.optimizers.Adam(learning_rate=config["learning_rate"])
    model.compile(optimizer, "binary_crossentropy", metrics=["accuracy"])

    stopper_callback = TFKerasStopperCallback(
        job,
        monitor="val_accuracy",
        mode="max",
    )

    history = model.fit(
        train_ds,
        epochs=100,
        validation_data=val_ds,
        verbose=0,
        callbacks=[stopper_callback],
    )

    objective = history.history["val_accuracy"][-1]

    metadata = {
        "loss": history.history["loss"],
        "val_loss": history.history["val_loss"],
        "accuracy": history.history["accuracy"],
        "val_accuracy": history.history["val_accuracy"],
    }
    metadata = {k: json.dumps(v) for k, v in metadata.items()}
    metadata["budget"] = stopper_callback.budget

    return {"objective": objective, "metadata": metadata}
The objective maximized by DeepHyper is the "objective" value returned by the run-function. In this tutorial it corresponds to the validation accuracy of the last epoch of training, which we retrieve from the History object returned by the model.fit(...) call.
...
objective = history.history["val_accuracy"][-1]
...
Using an objective like max(history.history['val_accuracy']) can have undesired side effects. For example, the training curve may overshoot a local maximum, resulting in a model without the capacity to flexibly adapt to new data in the future.
3.5. Define the Hyperparameter optimization problem#
Hyperparameter ranges are defined using the following syntax:

- Discrete integer ranges are generated from a tuple (lower: int, upper: int)
- Continuous parameters are generated from a tuple (lower: float, upper: float)
- Categorical or nonordinal hyperparameter ranges can be given as a list of possible values [val1, val2, ...]
[16]:
from deephyper.problem import HpProblem
# Creation of a hyperparameter problem
problem = HpProblem()
# Discrete hyperparameter (sampled with uniform prior)
problem.add_hyperparameter((8, 128), "units", default_value=32)
# Categorical hyperparameter (sampled with uniform prior)
ACTIVATIONS = [
"elu", "gelu", "hard_sigmoid", "linear", "relu", "selu",
"sigmoid", "softplus", "softsign", "swish", "tanh",
]
problem.add_hyperparameter(ACTIVATIONS, "activation", default_value="relu")
# Real hyperparameter (sampled with uniform prior)
problem.add_hyperparameter((0.0, 0.6), "dropout_rate", default_value=0.5)
# Discrete and Real hyperparameters (sampled with log-uniform)
problem.add_hyperparameter((8, 256, "log-uniform"), "batch_size", default_value=32)
problem.add_hyperparameter((1e-5, 1e-2, "log-uniform"), "learning_rate", default_value=1e-3)
problem
[16]:
Configuration space object:
Hyperparameters:
activation, Type: Categorical, Choices: {elu, gelu, hard_sigmoid, linear, relu, selu, sigmoid, softplus, softsign, swish, tanh}, Default: relu
batch_size, Type: UniformInteger, Range: [8, 256], Default: 32, on log-scale
dropout_rate, Type: UniformFloat, Range: [0.0, 0.6], Default: 0.5
learning_rate, Type: UniformFloat, Range: [1e-05, 0.01], Default: 0.001, on log-scale
units, Type: UniformInteger, Range: [8, 128], Default: 32
3.6. Evaluate a default configuration#
We evaluate the performance of the default set of hyperparameters provided in the Keras tutorial.
[18]:
out = run(RunningJob(parameters=problem.default_configuration))
objective_default = out["objective"]
metadata_default = out["metadata"]
print(f"Accuracy of the default configuration is {objective_default:.3f}\n with a budget of {metadata_default['budget']}")
out
Accuracy of the default configuration is 0.787
with a budget of 100
[18]:
{'objective': 0.7868852615356445,
'metadata': {'timestamp_start': 1692627006.076231,
'timestamp_end': 1692627008.611656,
'loss': '[0.7388495802879333, 0.6523264050483704, 0.6074849367141724, 0.5875506401062012, 0.5609573125839233, 0.5364176630973816, 0.5154970288276672, 0.5150707364082336, 0.49714744091033936, 0.47235098481178284, 0.44334346055984497, 0.40525874495506287, 0.4218010902404785, 0.404602587223053, 0.40189820528030396, 0.38025644421577454, 0.3819867670536041, 0.3854610025882721, 0.41191190481185913, 0.3633726239204407, 0.38953328132629395, 0.38333550095558167, 0.37049752473831177, 0.3487849533557892, 0.32601648569107056, 0.37537238001823425, 0.34992167353630066, 0.3610842227935791, 0.3291627764701843, 0.3573709726333618, 0.3185168504714966, 0.3198896646499634, 0.342100590467453, 0.31612464785575867, 0.33958184719085693, 0.2965970039367676, 0.3082221746444702, 0.3110814690589905, 0.3165033161640167, 0.3095366954803467, 0.3058503568172455, 0.26944229006767273, 0.28905966877937317, 0.29246464371681213, 0.298763245344162, 0.29010289907455444, 0.3088240921497345, 0.2651433050632477, 0.28673067688941956, 0.28051823377609253, 0.30437788367271423, 0.27532726526260376, 0.27114972472190857, 0.2815909683704376, 0.26282840967178345, 0.2893137037754059, 0.2620639204978943, 0.2733112573623657, 0.29566553235054016, 0.24813897907733917, 0.2784141004085541, 0.2645460367202759, 0.2866308391094208, 0.2698310613632202, 0.2550477385520935, 0.2571200430393219, 0.2692863345146179, 0.28535759449005127, 0.25309813022613525, 0.23560240864753723, 0.2546529769897461, 0.2909449338912964, 0.27611395716667175, 0.2572629451751709, 0.25273746252059937, 0.26408591866493225, 0.2512419521808624, 0.25369662046432495, 0.2681562900543213, 0.26099127531051636, 0.26887956261634827, 0.27868539094924927, 0.25104275345802307, 0.24246439337730408, 0.2736146152019501, 0.229806587100029, 0.23607467114925385, 0.25247693061828613, 0.26732784509658813, 0.23494768142700195, 0.24175147712230682, 0.24835185706615448, 0.24178683757781982, 0.23670029640197754, 0.25046640634536743, 0.2326313853263855, 0.252770334482193, 0.20456381142139435, 0.22638487815856934, 0.23603281378746033]',
'val_loss': '[0.6194877624511719, 0.5722625255584717, 0.5385775566101074, 0.508735716342926, 0.4839463233947754, 0.4639243483543396, 0.4494287371635437, 0.43661004304885864, 0.4258253872394562, 0.4169422388076782, 0.4105747938156128, 0.4043523371219635, 0.3982822299003601, 0.393393874168396, 0.3889952003955841, 0.38584354519844055, 0.38407716155052185, 0.3818555176258087, 0.3788350224494934, 0.3772242069244385, 0.3753066062927246, 0.37455886602401733, 0.37446901202201843, 0.3741632103919983, 0.3731308579444885, 0.3719276189804077, 0.371856689453125, 0.3722780644893646, 0.37168246507644653, 0.37138059735298157, 0.37127938866615295, 0.371003121137619, 0.3717569410800934, 0.3721306622028351, 0.37286439538002014, 0.3737015724182129, 0.3748771846294403, 0.37533822655677795, 0.3756699562072754, 0.37536323070526123, 0.37565532326698303, 0.3775019943714142, 0.37845736742019653, 0.3793482482433319, 0.37970390915870667, 0.38008853793144226, 0.38097918033599854, 0.3820502758026123, 0.38251516222953796, 0.38319170475006104, 0.3819979727268219, 0.3817276656627655, 0.38192903995513916, 0.38186657428741455, 0.38266661763191223, 0.3833865225315094, 0.38367798924446106, 0.38322150707244873, 0.3829233646392822, 0.3829481899738312, 0.3825083374977112, 0.383163720369339, 0.38400524854660034, 0.3840462863445282, 0.3839638829231262, 0.38363203406333923, 0.3847125172615051, 0.3835242986679077, 0.3827455937862396, 0.38354116678237915, 0.38440680503845215, 0.3843563199043274, 0.3848268389701843, 0.38419896364212036, 0.38506948947906494, 0.385724812746048, 0.38643878698349, 0.38729530572891235, 0.3885875642299652, 0.3890146315097809, 0.39021605253219604, 0.39056646823883057, 0.39105290174484253, 0.391281396150589, 0.3917790353298187, 0.39285603165626526, 0.39422619342803955, 0.3948151469230652, 0.39457252621650696, 0.3959348797798157, 0.39613062143325806, 0.39701947569847107, 0.3969540297985077, 0.39677679538726807, 0.3973367512226105, 0.39705222845077515, 0.39696934819221497, 0.39737534523010254, 0.39815035462379456, 0.3987182378768921]',
'accuracy': '[0.5454545617103577, 0.6570248007774353, 0.6983470916748047, 0.6776859760284424, 0.7272727489471436, 0.7479338645935059, 0.7561983466148376, 0.7396694421768188, 0.7355371713638306, 0.7561983466148376, 0.7644628286361694, 0.8429751992225647, 0.8057851195335388, 0.8223140239715576, 0.8099173307418823, 0.8471074104309082, 0.8181818127632141, 0.8388429880142212, 0.7851239442825317, 0.8471074104309082, 0.8223140239715576, 0.8264462947845459, 0.8429751992225647, 0.8512396812438965, 0.8760330677032471, 0.8264462947845459, 0.8636363744735718, 0.8223140239715576, 0.8760330677032471, 0.8429751992225647, 0.85537189245224, 0.85537189245224, 0.8429751992225647, 0.8677685856819153, 0.8388429880142212, 0.8925619721412659, 0.8512396812438965, 0.8636363744735718, 0.8636363744735718, 0.8677685856819153, 0.8719007968902588, 0.8925619721412659, 0.8925619721412659, 0.8719007968902588, 0.8760330677032471, 0.85537189245224, 0.8760330677032471, 0.8966942429542542, 0.8801652789115906, 0.8842975497245789, 0.8595041036605835, 0.8884297609329224, 0.8966942429542542, 0.8842975497245789, 0.8677685856819153, 0.8677685856819153, 0.8842975497245789, 0.8636363744735718, 0.8801652789115906, 0.9090909361839294, 0.8719007968902588, 0.8760330677032471, 0.8760330677032471, 0.8719007968902588, 0.9008264541625977, 0.9173553586006165, 0.8966942429542542, 0.8842975497245789, 0.8884297609329224, 0.8925619721412659, 0.913223147392273, 0.8636363744735718, 0.8801652789115906, 0.8925619721412659, 0.9049586653709412, 0.8842975497245789, 0.8842975497245789, 0.9008264541625977, 0.8801652789115906, 0.8842975497245789, 0.8966942429542542, 0.8760330677032471, 0.9090909361839294, 0.8842975497245789, 0.8677685856819153, 0.9090909361839294, 0.9090909361839294, 0.8966942429542542, 0.8760330677032471, 0.913223147392273, 0.9008264541625977, 0.913223147392273, 0.8966942429542542, 0.9008264541625977, 0.8966942429542542, 0.8966942429542542, 0.8925619721412659, 0.913223147392273, 0.9008264541625977, 0.9049586653709412]',
'val_accuracy': '[0.6393442749977112, 0.7049180269241333, 0.7213114500045776, 0.7213114500045776, 0.7213114500045776, 0.7377049326896667, 0.7704917788505554, 0.7704917788505554, 0.7868852615356445, 0.8196721076965332, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8360655903816223, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.7868852615356445, 0.7868852615356445, 0.7868852615356445, 0.7868852615356445, 0.7868852615356445, 0.7868852615356445, 0.7868852615356445]',
'budget': 100}}
3.7. Execute Multi-Fidelity Bayesian Optimization#
We create the CBO using the problem and run-function defined above. When directly passing the run-function to the search, it is wrapped inside a deephyper.evaluator.SerialEvaluator. Then, we also import the deephyper.stopper.LCModelStopper.
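For reference, this shorthand is equivalent to constructing the evaluator explicitly; a sketch (assuming the Evaluator.create factory with the "serial" method) would look like:

from deephyper.evaluator import Evaluator

# Wrap the run-function in a single-worker (serial) evaluator, which is
# what the search does internally when given the bare function.
evaluator = Evaluator.create(run, method="serial")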
[19]:
from deephyper.search.hps import CBO
from deephyper.stopper import LCModelStopper
[20]:
# Instantiate the search with the problem and the evaluator that we created before
stopper = LCModelStopper(min_steps=1, max_steps=100)

search = CBO(
    problem,
    run,
    initial_points=[problem.default_configuration],
    stopper=stopper,
    verbose=1,
)
Note
All DeepHyper search algorithms have two stopping criteria:

- max_evals (int): Defines the maximum number of evaluations that we want to perform. Defaults to -1 for an infinite number.
- timeout (int): Defines a time budget (in seconds) before stopping the search. Defaults to None for an infinite time budget.
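Both criteria can be combined, as in the following illustrative call (the values here are hypothetical, not the settings used for the run below):

# Stop after 30 evaluations or 10 minutes, whichever comes first
results = search.search(max_evals=30, timeout=600)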
[21]:
results = search.search(max_evals=30)
The returned results is a Pandas DataFrame where columns starting with "p:" are hyperparameters, columns starting with "m:" are additional metadata (from the user or from the Evaluator), alongside the objective value and the job_id:

- job_id is a unique identifier corresponding to the order of creation of tasks.
- objective is the value returned by the run-function.
- m:timestamp_submit is the time (in seconds) when the task was created by the evaluator, relative to the creation of the evaluator.
- m:timestamp_gather is the time (in seconds) when the finished task was collected by the evaluator, relative to the creation of the evaluator.
- m:timestamp_start is the time (in seconds) when the task started to run.
- m:timestamp_end is the time (in seconds) when the task finished running.
- m:budget is the number of epochs consumed by each evaluation.
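As a small hedged example of post-processing these columns, we can measure how much training the multi-fidelity scheme saved compared to a fixed budget of 100 epochs per evaluation:

# Hypothetical post-processing of the results DataFrame
total_epochs = results["m:budget"].sum()
print(f"Consumed {total_epochs} epochs instead of {100 * len(results)}")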
[22]:
results
[22]:
p:activation | p:batch_size | p:dropout_rate | p:learning_rate | p:units | objective | job_id | m:timestamp_submit | m:timestamp_gather | m:timestamp_start | m:timestamp_end | m:loss | m:val_loss | m:accuracy | m:val_accuracy | m:budget | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | relu | 32 | 0.500000 | 0.001000 | 32 | 0.803279 | 0 | 3.554752 | 7.302983 | 1.692627e+09 | 1.692627e+09 | [0.6227220296859741, 0.5570479035377502, 0.538... | [0.5288254618644714, 0.4989699125289917, 0.475... | [0.6487603187561035, 0.7438016533851624, 0.756... | [0.7868852615356445, 0.8032786846160889, 0.819... | 35 |
1 | linear | 9 | 0.147090 | 0.001889 | 49 | 0.803279 | 1 | 7.347784 | 10.729823 | 1.692627e+09 | 1.692627e+09 | [0.49330684542655945, 0.3671776056289673, 0.30... | [0.3687109649181366, 0.3877825140953064, 0.386... | [0.7479338645935059, 0.8264462947845459, 0.851... | [0.8032786846160889, 0.868852436542511, 0.8524... | 27 |
2 | softsign | 12 | 0.499104 | 0.000029 | 100 | 0.786885 | 2 | 10.749194 | 14.626382 | 1.692627e+09 | 1.692627e+09 | [0.710996150970459, 0.7085673213005066, 0.7007... | [0.6798619031906128, 0.6712697148323059, 0.663... | [0.5289255976676941, 0.5495867729187012, 0.561... | [0.5409836173057556, 0.5409836173057556, 0.557... | 59 |
3 | softsign | 27 | 0.582597 | 0.000215 | 97 | 0.868852 | 3 | 14.645453 | 18.470026 | 1.692627e+09 | 1.692627e+09 | [0.6582441329956055, 0.6144227385520935, 0.614... | [0.6063965559005737, 0.5792577862739563, 0.555... | [0.6157024502754211, 0.6528925895690918, 0.648... | [0.7213114500045776, 0.7540983557701111, 0.737... | 64 |
4 | selu | 28 | 0.469018 | 0.000618 | 110 | 0.786885 | 4 | 18.490789 | 21.121490 | 1.692627e+09 | 1.692627e+09 | [0.6250562071800232, 0.5189632177352905, 0.483... | [0.49002861976623535, 0.423252135515213, 0.386... | [0.6818181872367859, 0.7231404781341553, 0.731... | [0.7704917788505554, 0.8196721076965332, 0.786... | 4 |
5 | linear | 15 | 0.289261 | 0.000777 | 68 | 0.819672 | 5 | 21.140275 | 23.720370 | 1.692627e+09 | 1.692627e+09 | [0.6295762062072754, 0.4781874716281891, 0.421... | [0.5124895572662354, 0.433324933052063, 0.4010... | [0.6735537052154541, 0.7603305578231812, 0.801... | [0.8032786846160889, 0.7868852615356445, 0.803... | 4 |
6 | elu | 133 | 0.007165 | 0.000024 | 88 | 0.688525 | 6 | 23.738940 | 26.728599 | 1.692627e+09 | 1.692627e+09 | [0.6447474956512451, 0.6408497095108032, 0.641... | [0.597851037979126, 0.5967557430267334, 0.5956... | [0.6280992031097412, 0.6322314143180847, 0.640... | [0.688524603843689, 0.688524603843689, 0.68852... | 4 |
7 | elu | 184 | 0.365660 | 0.000276 | 126 | 0.704918 | 7 | 26.748085 | 29.293730 | 1.692627e+09 | 1.692627e+09 | [0.6677700281143188, 0.6242449879646301, 0.632... | [0.6279993057250977, 0.612653374671936, 0.5987... | [0.6239669322967529, 0.6652892827987671, 0.673... | [0.7049180269241333, 0.7213114500045776, 0.721... | 4 |
8 | linear | 27 | 0.410357 | 0.003974 | 10 | 0.819672 | 8 | 29.312796 | 32.318316 | 1.692627e+09 | 1.692627e+09 | [0.6348783373832703, 0.4910481572151184, 0.433... | [0.5030055046081543, 0.4379017949104309, 0.406... | [0.7066115736961365, 0.7438016533851624, 0.801... | [0.7540983557701111, 0.8032786846160889, 0.803... | 31 |
9 | gelu | 37 | 0.586930 | 0.001828 | 94 | 0.786885 | 9 | 32.583841 | 35.622722 | 1.692627e+09 | 1.692627e+09 | [0.646159291267395, 0.5308106541633606, 0.4458... | [0.5199227333068848, 0.44885024428367615, 0.40... | [0.6363636255264282, 0.7479338645935059, 0.809... | [0.8196721076965332, 0.7868852615356445, 0.786... | 35 |
10 | softsign | 15 | 0.580355 | 0.000195 | 88 | 0.426230 | 10 | 35.721770 | 37.368073 | 1.692627e+09 | 1.692627e+09 | [0.7734954953193665] | [0.7477120161056519] | [0.41735535860061646] | [0.4262295067310333] | 1 |
11 | swish | 27 | 0.599008 | 0.000148 | 80 | 0.245902 | 11 | 37.469103 | 38.974784 | 1.692627e+09 | 1.692627e+09 | [0.828150749206543] | [0.810135006904602] | [0.3636363744735718] | [0.24590164422988892] | 1 |
12 | softplus | 11 | 0.590109 | 0.000299 | 97 | 0.770492 | 12 | 39.081071 | 41.960308 | 1.692627e+09 | 1.692627e+09 | [0.8678354024887085, 0.7065883874893188, 0.700... | [0.6125381588935852, 0.5351727604866028, 0.495... | [0.5206611752510071, 0.5909090638160706, 0.669... | [0.7049180269241333, 0.7704917788505554, 0.770... | 4 |
13 | softsign | 24 | 0.568922 | 0.000112 | 122 | 0.803279 | 13 | 42.066191 | 45.404559 | 1.692627e+09 | 1.692627e+09 | [0.724600076675415, 0.7176037430763245, 0.6875... | [0.6987729072570801, 0.674211859703064, 0.6525... | [0.4958677589893341, 0.5454545617103577, 0.607... | [0.5573770403862, 0.6065573692321777, 0.655737... | 40 |
14 | softsign | 29 | 0.578839 | 0.000159 | 105 | 0.803279 | 14 | 45.507679 | 48.744906 | 1.692627e+09 | 1.692627e+09 | [0.7335056662559509, 0.7066584229469299, 0.688... | [0.7332508563995361, 0.7050181031227112, 0.679... | [0.46694216132164, 0.5289255976676941, 0.52066... | [0.44262295961380005, 0.49180328845977783, 0.5... | 51 |
15 | linear | 16 | 0.287004 | 0.000726 | 69 | 0.786885 | 15 | 48.853081 | 52.410750 | 1.692627e+09 | 1.692627e+09 | [0.6595832109451294, 0.5204204320907593, 0.485... | [0.48184749484062195, 0.4187839925289154, 0.38... | [0.64462810754776, 0.7148760557174683, 0.74793... | [0.8032786846160889, 0.8032786846160889, 0.819... | 37 |
16 | tanh | 27 | 0.433537 | 0.005760 | 24 | 0.786885 | 16 | 52.515180 | 55.492876 | 1.692627e+09 | 1.692627e+09 | [0.4885702431201935, 0.4026663899421692, 0.338... | [0.3990103304386139, 0.3842548429965973, 0.395... | [0.7644628286361694, 0.8016529083251953, 0.847... | [0.8196721076965332, 0.8032786846160889, 0.852... | 29 |
17 | softsign | 32 | 0.597483 | 0.000243 | 8 | 0.803279 | 17 | 55.600845 | 59.004269 | 1.692627e+09 | 1.692627e+09 | [0.8090429306030273, 0.7903764247894287, 0.808... | [0.7881447076797485, 0.7788082361221313, 0.769... | [0.4586776793003082, 0.44628098607063293, 0.45... | [0.31147539615631104, 0.31147539615631104, 0.3... | 63 |
18 | swish | 24 | 0.584958 | 0.000213 | 111 | 0.852459 | 18 | 59.111170 | 62.916621 | 1.692627e+09 | 1.692627e+09 | [0.6998273134231567, 0.6901867985725403, 0.633... | [0.6759735941886902, 0.6463225483894348, 0.621... | [0.5289255976676941, 0.56611567735672, 0.65289... | [0.6065573692321777, 0.6721311211585999, 0.672... | 36 |
19 | swish | 25 | 0.598112 | 0.000129 | 124 | 0.409836 | 19 | 63.025823 | 65.628007 | 1.692627e+09 | 1.692627e+09 | [0.7799810171127319, 0.7728349566459656, 0.743... | [0.7869507670402527, 0.7617835998535156, 0.738... | [0.39256197214126587, 0.42148759961128235, 0.4... | [0.26229506731033325, 0.3442623019218445, 0.37... | 4 |
20 | swish | 24 | 0.503962 | 0.000188 | 108 | 0.852459 | 20 | 65.737362 | 69.012517 | 1.692627e+09 | 1.692627e+09 | [0.7826627492904663, 0.7401067018508911, 0.716... | [0.7268974184989929, 0.6923832893371582, 0.661... | [0.40495866537094116, 0.4710743725299835, 0.51... | [0.4098360538482666, 0.5409836173057556, 0.672... | 52 |
21 | swish | 23 | 0.508027 | 0.000195 | 121 | 0.786885 | 21 | 69.122440 | 72.220513 | 1.692627e+09 | 1.692627e+09 | [0.6466206312179565, 0.6104539036750793, 0.596... | [0.6054739952087402, 0.5772570967674255, 0.553... | [0.6363636255264282, 0.7355371713638306, 0.739... | [0.688524603843689, 0.7704917788505554, 0.7704... | 16 |
22 | swish | 21 | 0.599638 | 0.000220 | 115 | 0.803279 | 22 | 72.331612 | 75.464572 | 1.692627e+09 | 1.692627e+09 | [0.7254757285118103, 0.6856257915496826, 0.656... | [0.7080947756767273, 0.6708484292030334, 0.636... | [0.4834710657596588, 0.5289255976676941, 0.628... | [0.5081967115402222, 0.6393442749977112, 0.737... | 36 |
23 | swish | 24 | 0.559085 | 0.000214 | 88 | 0.721311 | 23 | 75.574948 | 78.101354 | 1.692627e+09 | 1.692627e+09 | [0.7135123610496521, 0.6844885349273682, 0.662... | [0.6889106631278992, 0.6591018438339233, 0.631... | [0.4834710657596588, 0.5619834661483765, 0.595... | [0.5081967115402222, 0.6229507923126221, 0.704... | 4 |
24 | softplus | 26 | 0.576747 | 0.000199 | 102 | 0.803279 | 24 | 78.210997 | 81.422945 | 1.692627e+09 | 1.692627e+09 | [0.8818848729133606, 0.8496137857437134, 0.756... | [0.7038170695304871, 0.6505393385887146, 0.610... | [0.4793388545513153, 0.4752066135406494, 0.537... | [0.4098360538482666, 0.8360655903816223, 0.819... | 27 |
25 | swish | 24 | 0.567181 | 0.000209 | 107 | 0.754098 | 25 | 81.533879 | 84.056534 | 1.692627e+09 | 1.692627e+09 | [0.7189084887504578, 0.7079837918281555, 0.677... | [0.6937562227249146, 0.6694052219390869, 0.646... | [0.5371900796890259, 0.557851254940033, 0.5785... | [0.5901639461517334, 0.6557376980781555, 0.721... | 4 |
26 | tanh | 24 | 0.503059 | 0.000118 | 100 | 0.754098 | 26 | 84.166216 | 86.679132 | 1.692627e+09 | 1.692627e+09 | [0.6868227124214172, 0.6927969455718994, 0.650... | [0.6453925967216492, 0.6199279427528381, 0.596... | [0.5826446413993835, 0.5785123705863953, 0.648... | [0.6721311211585999, 0.688524603843689, 0.7377... | 4 |
27 | softsign | 27 | 0.552740 | 0.000288 | 101 | 0.836066 | 27 | 86.789922 | 90.000142 | 1.692627e+09 | 1.692627e+09 | [0.6240310072898865, 0.6154924035072327, 0.591... | [0.5968618392944336, 0.5632079243659973, 0.535... | [0.6652892827987671, 0.6363636255264282, 0.739... | [0.8032786846160889, 0.8360655903816223, 0.819... | 27 |
28 | softsign | 26 | 0.585004 | 0.000215 | 104 | 0.836066 | 28 | 90.113099 | 93.519698 | 1.692627e+09 | 1.692627e+09 | [0.6899656057357788, 0.6779984831809998, 0.646... | [0.677489161491394, 0.6455464363098145, 0.6153... | [0.5619834661483765, 0.5743801593780518, 0.644... | [0.5573770403862, 0.6065573692321777, 0.639344... | 64 |
29 | softsign | 19 | 0.582967 | 0.000434 | 96 | 0.803279 | 29 | 93.630995 | 96.207705 | 1.692627e+09 | 1.692627e+09 | [0.7093287110328674, 0.5999704003334045, 0.584... | [0.6647931933403015, 0.5841431021690369, 0.525... | [0.5289255976676941, 0.6983470916748047, 0.706... | [0.5573770403862, 0.7540983557701111, 0.803278... | 4 |
Now that the search is over, let us print the best configuration found during this run.
[24]:
i_max = results.objective.argmax()
best_job = results.iloc[i_max].to_dict()
print(f"The default configuration has an accuracy of {objective_default:.3f}. \n"
f"The best configuration found by DeepHyper has an accuracy {results['objective'].iloc[i_max]:.3f}, \n"
f"discovered after {results['m:timestamp_gather'].iloc[i_max]:.2f} secondes of search.\n")
best_job
The default configuration has an accuracy of 0.787.
The best configuration found by DeepHyper has an accuracy of 0.869,
discovered after 18.47 seconds of search.
[24]:
{'p:activation': 'softsign',
'p:batch_size': 27,
'p:dropout_rate': 0.5825972321970286,
'p:learning_rate': 0.0002146492013936,
'p:units': 97,
'objective': 0.868852436542511,
'job_id': 3,
'm:timestamp_submit': 14.645452976226808,
'm:timestamp_gather': 18.47002601623535,
'm:timestamp_start': 1692627056.0797062,
'm:timestamp_end': 1692627059.903903,
'm:loss': '[0.6582441329956055, 0.6144227385520935, 0.6145278215408325, 0.6029292941093445, 0.5701202154159546, 0.5590323209762573, 0.5266819000244141, 0.5313710570335388, 0.5190641283988953, 0.48551449179649353, 0.4518640339374542, 0.4618057608604431, 0.47630706429481506, 0.4841085374355316, 0.46425509452819824, 0.46001386642456055, 0.4311169981956482, 0.42034876346588135, 0.4447265863418579, 0.4271446466445923, 0.41852933168411255, 0.42826342582702637, 0.4135468304157257, 0.40737172961235046, 0.3894239068031311, 0.38594695925712585, 0.4143550395965576, 0.3991055190563202, 0.39989709854125977, 0.3991136848926544, 0.3789837062358856, 0.3853186070919037, 0.37715306878089905, 0.37205013632774353, 0.37237340211868286, 0.3673132061958313, 0.3666028082370758, 0.35598233342170715, 0.3553876280784607, 0.3808722496032715, 0.35543930530548096, 0.34617236256599426, 0.35386309027671814, 0.37097516655921936, 0.3688899576663971, 0.35417208075523376, 0.35026198625564575, 0.3427666425704956, 0.346967488527298, 0.3532065153121948, 0.33742740750312805, 0.3545960783958435, 0.3325014114379883, 0.33512216806411743, 0.32153061032295227, 0.3290161192417145, 0.3484468460083008, 0.34243884682655334, 0.33464089035987854, 0.32899194955825806, 0.32888177037239075, 0.3334735631942749, 0.3368818759918213, 0.3451095521450043]',
'm:val_loss': '[0.6063965559005737, 0.5792577862739563, 0.5556067824363708, 0.5347539782524109, 0.5161513090133667, 0.5006381869316101, 0.4881114661693573, 0.47605815529823303, 0.4642772376537323, 0.45479875802993774, 0.4466439485549927, 0.43913522362709045, 0.43216830492019653, 0.4264807105064392, 0.42094600200653076, 0.41611552238464355, 0.4121261537075043, 0.4083685278892517, 0.4054298996925354, 0.40238967537879944, 0.3995138108730316, 0.39638280868530273, 0.39376404881477356, 0.3915865421295166, 0.3895185589790344, 0.38762181997299194, 0.38626304268836975, 0.3853980302810669, 0.3846575915813446, 0.38337844610214233, 0.38249704241752625, 0.38157397508621216, 0.38086169958114624, 0.3803839385509491, 0.3804793655872345, 0.3803507685661316, 0.3803969919681549, 0.379818856716156, 0.3797812759876251, 0.37954601645469666, 0.3791370987892151, 0.3784518837928772, 0.3781093955039978, 0.377916544675827, 0.37754306197166443, 0.3776501715183258, 0.37767040729522705, 0.3779224455356598, 0.3777530789375305, 0.37799403071403503, 0.37827208638191223, 0.3791006803512573, 0.3788507282733917, 0.37848344445228577, 0.3786301016807556, 0.3790893256664276, 0.37959209084510803, 0.38035571575164795, 0.3806765675544739, 0.38088637590408325, 0.38094907999038696, 0.3814001977443695, 0.3820967972278595, 0.3824552893638611]',
'm:accuracy': '[0.6157024502754211, 0.6528925895690918, 0.6487603187561035, 0.6735537052154541, 0.71074378490448, 0.7685950398445129, 0.7561983466148376, 0.7520661354064941, 0.7438016533851624, 0.7727272510528564, 0.8099173307418823, 0.7851239442825317, 0.797520637512207, 0.7933884263038635, 0.7851239442825317, 0.78925621509552, 0.8140496015548706, 0.8181818127632141, 0.8140496015548706, 0.8057851195335388, 0.8347107172012329, 0.8264462947845459, 0.8181818127632141, 0.8264462947845459, 0.8347107172012329, 0.8223140239715576, 0.8305785059928894, 0.8181818127632141, 0.8429751992225647, 0.8264462947845459, 0.8305785059928894, 0.8223140239715576, 0.8471074104309082, 0.8512396812438965, 0.8429751992225647, 0.8305785059928894, 0.8429751992225647, 0.8512396812438965, 0.8595041036605835, 0.8181818127632141, 0.8388429880142212, 0.85537189245224, 0.8264462947845459, 0.8181818127632141, 0.8388429880142212, 0.8471074104309082, 0.8347107172012329, 0.8512396812438965, 0.8471074104309082, 0.8471074104309082, 0.85537189245224, 0.8512396812438965, 0.8760330677032471, 0.8388429880142212, 0.8760330677032471, 0.8471074104309082, 0.8347107172012329, 0.8471074104309082, 0.85537189245224, 0.8677685856819153, 0.85537189245224, 0.85537189245224, 0.8471074104309082, 0.8429751992225647]',
'm:val_accuracy': '[0.7213114500045776, 0.7540983557701111, 0.7377049326896667, 0.7540983557701111, 0.7868852615356445, 0.7868852615356445, 0.7868852615356445, 0.7868852615356445, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.868852436542511, 0.868852436542511, 0.868852436542511, 0.868852436542511, 0.868852436542511, 0.868852436542511, 0.868852436542511, 0.868852436542511, 0.868852436542511, 0.868852436542511, 0.868852436542511, 0.868852436542511, 0.868852436542511, 0.868852436542511, 0.868852436542511, 0.868852436542511, 0.868852436542511, 0.868852436542511]',
'm:budget': 64}
[25]:
import matplotlib.pyplot as plt
plt.figure()
plt.plot(json.loads(metadata_default["val_accuracy"]), color="skyblue", label="Default (val)")
plt.plot(json.loads(metadata_default["accuracy"]), color="skyblue", linestyle="--", label="Default (train)")
plt.plot(json.loads(best_job["m:val_accuracy"]), color="coral", linewidth=2, label="Best Job(val)")
plt.plot(json.loads(best_job["m:accuracy"]), color="coral", linestyle="--", linewidth=2, label="Best Job (train)")
plt.legend()
plt.xlim(0,50)
plt.ylim(0.5, 0.9)
plt.grid()
plt.show()

We can observe an improvement of about 8% in validation accuracy (from 0.787 to 0.869). We can also retrieve the corresponding hyperparameter configuration, along with the number of epochs used for this evaluation (64, from the m:budget metadata).
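If needed, the best configuration can be re-evaluated, for example without a stopper attached so that it trains for the full 100 epochs (as the default configuration did above); a hedged sketch:

# Rebuild the hyperparameter dict from the "p:"-prefixed columns of the best row
best_params = {k[len("p:"):]: v for k, v in best_job.items() if k.startswith("p:")}

# Without a stopper attached to the RunningJob, the callback lets training run to completion
out_best = run(RunningJob(parameters=best_params))
print(out_best["objective"], out_best["metadata"]["budget"])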