2. Hyperparameter search for classification with Tabular data (Keras)#
In this tutorial we present how to use hyperparameter optimization on a basic example from the Keras documentation.
Reference: This tutorial is based on materials from the Keras Documentation: Structured data classification from scratch
Let us start with installing DeepHyper!
Warning
This tutorial should be run with `tensorflow>=2.6`.
[1]:
try:
    import deephyper
    print(deephyper.__version__)
except (ImportError, ModuleNotFoundError):
    !pip install deephyper
try:
    import ray
except (ImportError, ModuleNotFoundError):
    !pip install ray
0.6.0
Note
The following environment variables can be used to avoid the logging of some Tensorflow DEBUG, INFO and WARNING statements.
[2]:
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = str(3)
os.environ["AUTOGRAPH_VERBOSITY"] = str(0)
2.1. Imports#
Warning
It is important to follow the import strategy `import tensorflow as tf` to prevent serialization errors that will crash the search.
The import strategy from the original Keras tutorial (shown below):
from tensorflow import keras
from tensorflow.keras import layers
...
from tensorflow.keras.layers import IntegerLookup
from tensorflow.keras.layers import Normalization
from tensorflow.keras.layers import StringLookup
resulted in non-serializable data, preventing the search from executing in parallel.
[1]:
import json
import numpy as np
import pandas as pd
import tensorflow as tf
import tf_keras as tfk
import tf_keras.backend as K
Note
The following can be used to detect whether GPU devices are available on the current host, so that this notebook automatically adapts the parallel execution to the resources available locally. However, this simple code will not detect resources spread over multiple nodes.
[2]:
from tensorflow.python.client import device_lib


def get_available_gpus():
    local_device_protos = device_lib.list_local_devices()
    return [x.name for x in local_device_protos if x.device_type == "GPU"]


n_gpus = len(get_available_gpus())
if n_gpus > 1:
    n_gpus -= 1
is_gpu_available = n_gpus > 0
if is_gpu_available:
    print(f"{n_gpus} GPU{'s are' if n_gpus > 1 else ' is'} available.")
else:
    print("No GPU available")
No GPU available
2.2. The dataset (from Keras.io)#
The dataset is provided by the Cleveland Clinic Foundation for Heart Disease. It’s a CSV file with 303 rows. Each row contains information about a patient (a sample), and each column describes an attribute of the patient (a feature). We use the features to predict whether a patient has a heart disease (binary classification).
Here’s the description of each feature:
Column | Description | Feature Type |
---|---|---|
Age | Age in years | Numerical |
Sex | (1 = male; 0 = female) | Categorical |
CP | Chest pain type (0, 1, 2, 3, 4) | Categorical |
Trestbps | Resting blood pressure (in mm Hg on admission) | Numerical |
Chol | Serum cholesterol in mg/dl | Numerical |
FBS | Fasting blood sugar > 120 mg/dl (1 = true; 0 = false) | Categorical |
RestECG | Resting electrocardiogram results (0, 1, 2) | Categorical |
Thalach | Maximum heart rate achieved | Numerical |
Exang | Exercise induced angina (1 = yes; 0 = no) | Categorical |
Oldpeak | ST depression induced by exercise relative to rest | Numerical |
Slope | Slope of the peak exercise ST segment | Numerical |
CA | Number of major vessels (0-3) colored by fluoroscopy | Both numerical & categorical |
Thal | 3 = normal; 6 = fixed defect; 7 = reversible defect | Categorical |
Target | Diagnosis of heart disease (1 = true; 0 = false) | Target |
[3]:
def load_data():
    file_url = "http://storage.googleapis.com/download.tensorflow.org/data/heart.csv"
    dataframe = pd.read_csv(file_url)
    # Hold out 20% of the rows for validation (fixed seed for reproducibility)
    val_dataframe = dataframe.sample(frac=0.2, random_state=1337)
    train_dataframe = dataframe.drop(val_dataframe.index)
    return train_dataframe, val_dataframe


def dataframe_to_dataset(dataframe):
    dataframe = dataframe.copy()
    labels = dataframe.pop("target")
    ds = tf.data.Dataset.from_tensor_slices((dict(dataframe), labels))
    ds = ds.shuffle(buffer_size=len(dataframe))
    return ds
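As a rough, library-free illustration of the hold-out split performed by `load_data`, the sketch below reproduces the arithmetic on row indices only. The helper name `split_indices` is hypothetical, and Python's `random` module will not pick the same rows as pandas with `random_state=1337`; only the resulting sizes are comparable, under the assumption that the sample size is rounded to the nearest integer.

```python
import random


def split_indices(n_rows, frac=0.2, seed=1337):
    """Return (train_idx, val_idx) for a random hold-out split of row indices."""
    rng = random.Random(seed)
    # Assumption: the number of validation rows is frac * n_rows rounded to the nearest integer
    n_val = round(frac * n_rows)
    val_idx = sorted(rng.sample(range(n_rows), n_val))
    val_set = set(val_idx)
    train_idx = [i for i in range(n_rows) if i not in val_set]
    return train_idx, val_idx


train_idx, val_idx = split_indices(303)
print(len(train_idx), len(val_idx))
```

Under these assumptions the 303 rows are divided into 242 training rows and 61 validation rows, which matches the granularity of the accuracies reported later in this tutorial (for example, 0.803 is 49/61).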
2.3. Preprocessing & encoding of features#
The next cells use `tfk.layers.Normalization()` to apply standard scaling on the features.
Then, the `tfk.layers.StringLookup` and `tfk.layers.IntegerLookup` layers are used to encode categorical variables.
[7]:
def encode_numerical_feature(feature, name, dataset):
    # Create a Normalization layer for our feature
    normalizer = tfk.layers.Normalization()
    # Prepare a Dataset that only yields our feature
    feature_ds = dataset.map(lambda x, y: x[name])
    feature_ds = feature_ds.map(lambda x: tf.expand_dims(x, -1))
    # Learn the statistics of the data
    normalizer.adapt(feature_ds)
    # Normalize the input feature
    encoded_feature = normalizer(feature)
    return encoded_feature


def encode_categorical_feature(feature, name, dataset, is_string):
    lookup_class = (
        tfk.layers.StringLookup if is_string else tfk.layers.IntegerLookup
    )
    # Create a lookup layer which will turn strings into integer indices
    lookup = lookup_class(output_mode="binary")
    # Prepare a Dataset that only yields our feature
    feature_ds = dataset.map(lambda x, y: x[name])
    feature_ds = feature_ds.map(lambda x: tf.expand_dims(x, -1))
    # Learn the set of possible string values and assign them a fixed integer index
    lookup.adapt(feature_ds)
    # Turn the string input into integer indices
    encoded_feature = lookup(feature)
    return encoded_feature
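To make the two encodings more concrete, here is a minimal pure-Python sketch of what standard scaling and a binary (0/1 vector) lookup encoding compute. It is a simplified illustration and not the Keras implementation: the helper names are hypothetical, population statistics are assumed for the scaling, and reserving index 0 for unseen values as well as the alphabetical vocabulary order are assumptions made for this example.

```python
import math


def standard_scale(values):
    """Scale values to zero mean and unit variance (population statistics assumed)."""
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    std = math.sqrt(var)
    return [(v - mean) / std for v in values]


def fit_vocabulary(values):
    """Map each distinct value to an index; index 0 is kept for unseen values (assumption)."""
    return {v: i + 1 for i, v in enumerate(sorted(set(values)))}


def encode_binary(value, vocab):
    """Return a 0/1 vector with a single 1 at the index assigned to `value`."""
    vec = [0] * (len(vocab) + 1)
    vec[vocab.get(value, 0)] = 1
    return vec


ages = [40.0, 50.0, 60.0]
print(standard_scale(ages))

thal_vocab = fit_vocabulary(["normal", "fixed", "reversible", "normal"])
print(encode_binary("fixed", thal_vocab))    # known category
print(encode_binary("unknown", thal_vocab))  # falls back to the reserved slot
```

The statistics and the vocabulary are learned once (the equivalent of `adapt`) and then reused to encode every sample, which is why both helpers in the tutorial adapt on the training dataset only.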
2.4. Define the run-function#
The run-function defines how the objective that we want to maximize is computed. It takes a `config` dictionary as input and often returns a scalar value that we want to maximize. The `config` contains a sample value of hyperparameters that we want to tune. In this example we will search for:
- `units` (default value: `32`)
- `activation` (default value: `"relu"`)
- `dropout_rate` (default value: `0.5`)
- `num_epochs` (default value: `50`)
- `batch_size` (default value: `32`)
- `learning_rate` (default value: `1e-3`)
A hyperparameter value can be accessed easily in the dictionary through the corresponding key, for example `config["units"]`.
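As a small illustration of this contract, the sketch below builds a configuration dictionary with the default values listed above and passes it to a stand-in function. The function `toy_run` and its scoring formula are hypothetical and made up for this example; it does not train any model and only shows how values are read from `config` and how a scalar is returned.

```python
default_config = {
    "units": 32,
    "activation": "relu",
    "dropout_rate": 0.5,
    "num_epochs": 50,
    "batch_size": 32,
    "learning_rate": 1e-3,
}


def toy_run(config):
    """Hypothetical stand-in for a run-function: read hyperparameters, return a scalar."""
    units = config["units"]
    dropout_rate = config["dropout_rate"]
    # Arbitrary made-up score used for illustration only (no model is trained)
    return 1.0 / (1.0 + abs(units - 64) / 64 + dropout_rate)


print(toy_run(default_config))
```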
[18]:
def count_params(model: tfk.Model) -> dict:
    """Evaluate the number of parameters of a Keras model.

    Args:
        model (tfk.Model): a Keras model.

    Returns:
        dict: a dictionary with the number of trainable ``"num_parameters_train"`` and
        non-trainable parameters ``"num_parameters"``.
    """

    def count_or_null(p):
        # Weights whose size cannot be counted are treated as having 0 parameters
        try:
            return K.count_params(p)
        except Exception:
            return 0

    num_parameters_train = int(
        np.sum([count_or_null(p) for p in model.trainable_weights])
    )
    num_parameters = int(
        np.sum([count_or_null(p) for p in model.non_trainable_weights])
    )
    return {
        "num_parameters": num_parameters,
        "num_parameters_train": num_parameters_train,
    }
[19]:
def run(config: dict):
    tf.autograph.set_verbosity(0)
    # Load data and split into validation set
    train_dataframe, val_dataframe = load_data()
    train_ds = dataframe_to_dataset(train_dataframe)
    val_ds = dataframe_to_dataset(val_dataframe)
    train_ds = train_ds.batch(config["batch_size"])
    val_ds = val_ds.batch(config["batch_size"])

    # Categorical features encoded as integers
    sex = tfk.Input(shape=(1,), name="sex", dtype="int64")
    cp = tfk.Input(shape=(1,), name="cp", dtype="int64")
    fbs = tfk.Input(shape=(1,), name="fbs", dtype="int64")
    restecg = tfk.Input(shape=(1,), name="restecg", dtype="int64")
    exang = tfk.Input(shape=(1,), name="exang", dtype="int64")
    ca = tfk.Input(shape=(1,), name="ca", dtype="int64")

    # Categorical feature encoded as string
    thal = tfk.Input(shape=(1,), name="thal", dtype="string")

    # Numerical features
    age = tfk.Input(shape=(1,), name="age")
    trestbps = tfk.Input(shape=(1,), name="trestbps")
    chol = tfk.Input(shape=(1,), name="chol")
    thalach = tfk.Input(shape=(1,), name="thalach")
    oldpeak = tfk.Input(shape=(1,), name="oldpeak")
    slope = tfk.Input(shape=(1,), name="slope")

    all_inputs = [
        sex,
        cp,
        fbs,
        restecg,
        exang,
        ca,
        thal,
        age,
        trestbps,
        chol,
        thalach,
        oldpeak,
        slope,
    ]

    # Integer categorical features
    sex_encoded = encode_categorical_feature(sex, "sex", train_ds, False)
    cp_encoded = encode_categorical_feature(cp, "cp", train_ds, False)
    fbs_encoded = encode_categorical_feature(fbs, "fbs", train_ds, False)
    restecg_encoded = encode_categorical_feature(restecg, "restecg", train_ds, False)
    exang_encoded = encode_categorical_feature(exang, "exang", train_ds, False)
    ca_encoded = encode_categorical_feature(ca, "ca", train_ds, False)

    # String categorical features
    thal_encoded = encode_categorical_feature(thal, "thal", train_ds, True)

    # Numerical features
    age_encoded = encode_numerical_feature(age, "age", train_ds)
    trestbps_encoded = encode_numerical_feature(trestbps, "trestbps", train_ds)
    chol_encoded = encode_numerical_feature(chol, "chol", train_ds)
    thalach_encoded = encode_numerical_feature(thalach, "thalach", train_ds)
    oldpeak_encoded = encode_numerical_feature(oldpeak, "oldpeak", train_ds)
    slope_encoded = encode_numerical_feature(slope, "slope", train_ds)

    all_features = tfk.layers.concatenate(
        [
            sex_encoded,
            cp_encoded,
            fbs_encoded,
            restecg_encoded,
            exang_encoded,
            slope_encoded,
            ca_encoded,
            thal_encoded,
            age_encoded,
            trestbps_encoded,
            chol_encoded,
            thalach_encoded,
            oldpeak_encoded,
        ]
    )
    x = tfk.layers.Dense(config["units"], activation=config["activation"])(
        all_features
    )
    x = tfk.layers.Dropout(config["dropout_rate"])(x)
    output = tfk.layers.Dense(1, activation="sigmoid")(x)
    model = tfk.Model(all_inputs, output)

    optimizer = tfk.optimizers.Adam(learning_rate=config["learning_rate"])
    model.compile(optimizer, "binary_crossentropy", metrics=["accuracy"])

    history = model.fit(
        train_ds, epochs=config["num_epochs"], validation_data=val_ds, verbose=0
    )

    # The objective is the validation accuracy of the last epoch
    objective = history.history["val_accuracy"][-1]
    metadata = {
        "loss": history.history["loss"],
        "val_loss": history.history["val_loss"],
        "accuracy": history.history["accuracy"],
        "val_accuracy": history.history["val_accuracy"],
    }
    # Serialize the training curves so that they can be stored as metadata
    metadata = {k: json.dumps(v) for k, v in metadata.items()}
    metadata.update(count_params(model))
    return {"objective": objective, "metadata": metadata}
The objective maximised by DeepHyper is the scalar value returned by the `run`-function (here, the value stored under the `"objective"` key of the returned dictionary).
In this tutorial it corresponds to the validation accuracy of the last epoch of training, which we retrieve from the `History` object returned by the `model.fit(...)` call.
...
history = model.fit(
    train_ds, epochs=config["num_epochs"], validation_data=val_ds, verbose=0
)
objective = history.history["val_accuracy"][-1]
...
Using an objective like `max(history.history['val_accuracy'])` can have undesired side effects.
For example, training may overshoot the epoch at which the validation accuracy peaks. The maximum then describes weights that the final model no longer has, so the selected configuration may adapt to new data less well than its objective suggests.
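The difference between the two choices can be seen on a short validation-accuracy curve. The values below are a handful of points (rounded to three decimals) taken from the `val_accuracy` history of the default configuration printed in section 2.6; the subsampling is only for brevity.

```python
# A few rounded points from the default configuration's validation accuracy curve
val_accuracy = [0.639, 0.820, 0.852, 0.836, 0.803]

last = val_accuracy[-1]   # what this tutorial uses as the objective
best = max(val_accuracy)  # what a max-based objective would report
best_index = val_accuracy.index(best)

print(f"last value: {last:.3f}")
print(f"best value: {best:.3f} (position {best_index} in this shortened curve)")
```

The max-based objective would report 0.852 even though the model available at the end of training scores 0.803 on the validation set.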
2.5. Define the Hyperparameter optimization problem#
Hyperparameter ranges are defined using the following syntax:
- Discrete integer ranges are generated from a tuple `(lower: int, upper: int)`
- Continuous parameters are generated from a tuple `(lower: float, upper: float)`
- Categorical or nonordinal hyperparameter ranges can be given as a list of possible values `[val1, val2, ...]`
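To give an intuition for what each kind of range means, the following sketch draws random values from specifications written in the same tuple/list syntax, including the optional `"log-uniform"` prior used later in this tutorial. It is a hypothetical, simplified sampler written for this example (the names `sample` and `sample_configuration` are not part of DeepHyper) and makes no claim about how DeepHyper samples internally.

```python
import math
import random


def sample(spec, rng):
    """Draw one value from a range specification (illustrative only)."""
    if isinstance(spec, list):  # categorical: pick one of the listed values
        return rng.choice(spec)
    lower, upper = spec[0], spec[1]
    is_int = isinstance(lower, int) and isinstance(upper, int)
    if len(spec) == 3 and spec[2] == "log-uniform":
        # Uniform in log-space, then mapped back and clamped to the bounds
        value = math.exp(rng.uniform(math.log(lower), math.log(upper)))
        value = min(max(value, lower), upper)
        return round(value) if is_int else value
    return rng.randint(lower, upper) if is_int else rng.uniform(lower, upper)


def sample_configuration(space, rng):
    """Sample every hyperparameter of the space independently."""
    return {name: sample(spec, rng) for name, spec in space.items()}


space = {
    "units": (8, 128),
    "dropout_rate": (0.0, 0.6),
    "batch_size": (8, 256, "log-uniform"),
    "learning_rate": (1e-5, 1e-2, "log-uniform"),
    "activation": ["relu", "tanh", "gelu"],
}

print(sample_configuration(space, random.Random(42)))
```

With a log-uniform prior, each order of magnitude of the range is equally likely, which is why it is a common choice for quantities such as the learning rate.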
[20]:
from deephyper.hpo import HpProblem
# Creation of an hyperparameter problem
problem = HpProblem()
# Discrete hyperparameter (sampled with uniform prior)
problem.add_hyperparameter((8, 128), "units", default_value=32)
problem.add_hyperparameter((10, 100), "num_epochs", default_value=50)
# Categorical hyperparameter (sampled with uniform prior)
ACTIVATIONS = [
"elu", "gelu", "hard_sigmoid", "linear", "relu", "selu",
"sigmoid", "softplus", "softsign", "swish", "tanh",
]
problem.add_hyperparameter(ACTIVATIONS, "activation", default_value="relu")
# Real hyperparameter (sampled with uniform prior)
problem.add_hyperparameter((0.0, 0.6), "dropout_rate", default_value=0.5)
# Discrete and Real hyperparameters (sampled with log-uniform)
problem.add_hyperparameter((8, 256, "log-uniform"), "batch_size", default_value=32)
problem.add_hyperparameter((1e-5, 1e-2, "log-uniform"), "learning_rate", default_value=1e-3)
problem
[20]:
Configuration space object:
Hyperparameters:
activation, Type: Categorical, Choices: {elu, gelu, hard_sigmoid, linear, relu, selu, sigmoid, softplus, softsign, swish, tanh}, Default: relu
batch_size, Type: UniformInteger, Range: [8, 256], Default: 32, on log-scale
dropout_rate, Type: UniformFloat, Range: [0.0, 0.6], Default: 0.5
learning_rate, Type: UniformFloat, Range: [1e-05, 0.01], Default: 0.001, on log-scale
num_epochs, Type: UniformInteger, Range: [10, 100], Default: 50
units, Type: UniformInteger, Range: [8, 128], Default: 32
2.6. Evaluate a default configuration#
We evaluate the performance of the default set of hyperparameters provided in the Keras tutorial.
[21]:
import ray

# We launch the Ray run-time depending on the detected local resources
# and execute the `run` function with the default configuration.
# WARNING: in the case of GPUs it is important to follow this scheme
# to prevent multiple processes (Ray workers vs current process) from
# locking the same GPU.
if is_gpu_available:
    if not ray.is_initialized():
        ray.init(num_cpus=n_gpus, num_gpus=n_gpus, log_to_driver=False)
    run_default = ray.remote(num_cpus=1, num_gpus=1)(run)
    out = ray.get(run_default.remote(problem.default_configuration))
else:
    if not ray.is_initialized():
        ray.init(num_cpus=1, log_to_driver=False)
    run_default = run
    out = run_default(problem.default_configuration)

objective_default = out["objective"]
metadata_default = out["metadata"]
print(f"Accuracy Default Configuration: {objective_default:.3f}")
print("Metadata Default Configuration")
for k, v in out["metadata"].items():
    print(f"\t- {k}: {v}")
WARNING:absl:At this time, the v2.11+ optimizer `tf.keras.optimizers.Adam` runs slowly on M1/M2 Macs, please use the legacy TF-Keras optimizer instead, located at `tf.keras.optimizers.legacy.Adam`.
Accuracy Default Configuration: 0.803
Metadata Default Configuration
- loss: [0.7089458107948303, 0.7017484307289124, 0.6555561423301697, 0.5975483059883118, 0.5408881306648254, 0.5184577107429504, 0.4933375418186188, 0.5026713013648987, 0.46867308020591736, 0.43367132544517517, 0.4040399491786957, 0.4410407245159149, 0.41532158851623535, 0.42353206872940063, 0.40005776286125183, 0.39258643984794617, 0.385629266500473, 0.3648240566253662, 0.37516599893569946, 0.3589401841163635, 0.36665764451026917, 0.34737372398376465, 0.3585016429424286, 0.3524617850780487, 0.32876622676849365, 0.34414738416671753, 0.3265135884284973, 0.3396293520927429, 0.33635473251342773, 0.32959404587745667, 0.32769283652305603, 0.3066641390323639, 0.3406123220920563, 0.3300285041332245, 0.3249838352203369, 0.31367143988609314, 0.2998355031013489, 0.28332749009132385, 0.28868308663368225, 0.3160744905471802, 0.290345698595047, 0.30034708976745605, 0.25978073477745056, 0.2852059602737427, 0.25552764534950256, 0.2629491984844208, 0.2993203401565552, 0.2639424502849579, 0.2701556980609894, 0.2878676652908325]
- val_loss: [0.6440699696540833, 0.5855709314346313, 0.5408179759979248, 0.5036494135856628, 0.4742478132247925, 0.45106351375579834, 0.4320058226585388, 0.4158112704753876, 0.4028586745262146, 0.39374834299087524, 0.3855816125869751, 0.37890157103538513, 0.3737941086292267, 0.3691202998161316, 0.36556652188301086, 0.36313873529434204, 0.36203548312187195, 0.3613433837890625, 0.3603682816028595, 0.3586443066596985, 0.3586690127849579, 0.35930439829826355, 0.3583105206489563, 0.35828936100006104, 0.3591354489326477, 0.3601025342941284, 0.36102163791656494, 0.36082351207733154, 0.36077219247817993, 0.3612353503704071, 0.3610706627368927, 0.36180663108825684, 0.36208730936050415, 0.36208462715148926, 0.36210131645202637, 0.36168861389160156, 0.36216410994529724, 0.3641442358493805, 0.36540892720222473, 0.366742879152298, 0.36642923951148987, 0.3666687607765198, 0.3686812222003937, 0.37019941210746765, 0.3713856339454651, 0.37242332100868225, 0.37362170219421387, 0.3748522102832794, 0.37537622451782227, 0.37630414962768555]
- accuracy: [0.5330578684806824, 0.586776852607727, 0.6198347210884094, 0.6570248007774353, 0.7314049601554871, 0.7520661354064941, 0.7768595218658447, 0.7479338645935059, 0.7561983466148376, 0.8016529083251953, 0.8181818127632141, 0.7768595218658447, 0.7851239442825317, 0.8264462947845459, 0.797520637512207, 0.8140496015548706, 0.8347107172012329, 0.8264462947845459, 0.797520637512207, 0.8264462947845459, 0.8305785059928894, 0.8429751992225647, 0.8388429880142212, 0.8305785059928894, 0.8636363744735718, 0.8471074104309082, 0.8719007968902588, 0.8512396812438965, 0.8305785059928894, 0.8429751992225647, 0.8471074104309082, 0.8677685856819153, 0.8347107172012329, 0.8719007968902588, 0.8471074104309082, 0.85537189245224, 0.8760330677032471, 0.8925619721412659, 0.8801652789115906, 0.8595041036605835, 0.8966942429542542, 0.8595041036605835, 0.8884297609329224, 0.8842975497245789, 0.9049586653709412, 0.9008264541625977, 0.8719007968902588, 0.9090909361839294, 0.8719007968902588, 0.8842975497245789]
- val_accuracy: [0.6393442749977112, 0.7540983557701111, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8360655903816223, 0.8360655903816223, 0.8196721076965332, 0.8196721076965332, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889]
- num_parameters: 18
- num_parameters_train: 1217
2.7. Define the evaluator object#
The `Evaluator` object allows changing the parallelization backend used by DeepHyper. An evaluator needs a `run_function` to be instantiated. The `method` argument defines the backend (e.g., `"ray"`) and `method_kwargs` corresponds to the keyword arguments of the chosen `method`.
evaluator = Evaluator.create(run_function, method, method_kwargs)
Once created, `evaluator.num_workers` gives access to the number of available parallel workers.
Finally, to submit tasks to the evaluator and collect their results, one just needs to use the following interface:
configs = [{"units": 8, ...}, ...]
evaluator.submit(configs)
...
# To collect the first finished task (asynchronous)
tasks_done = evaluator.get("BATCH", size=1)
# To collect all of the pending tasks (synchronous)
tasks_done = evaluator.get("ALL")
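The difference between collecting the first finished task and collecting all pending tasks can be illustrated with Python's standard `concurrent.futures` module. This is only an analogy built on the standard library, not DeepHyper's implementation; the function `toy_run` and the `delay` field are hypothetical and exist only to make tasks finish at different times.

```python
import time
from concurrent.futures import ALL_COMPLETED, FIRST_COMPLETED, ThreadPoolExecutor, wait


def toy_run(config):
    """Hypothetical task: sleep for a while, then return a made-up score."""
    time.sleep(config["delay"])
    return config["units"] / 128


configs = [
    {"units": 8, "delay": 0.05},
    {"units": 64, "delay": 0.2},
    {"units": 128, "delay": 0.4},
]

with ThreadPoolExecutor(max_workers=3) as pool:
    # Submit every configuration without waiting for the results
    pending = {pool.submit(toy_run, c) for c in configs}

    # Roughly like get("BATCH", size=1): return once at least one task is done
    done, pending = wait(pending, return_when=FIRST_COMPLETED)
    first_results = [f.result() for f in done]

    # Roughly like get("ALL"): block until every remaining task is done
    done, pending = wait(pending, return_when=ALL_COMPLETED)
    remaining_results = [f.result() for f in done]

print(len(first_results), len(remaining_results))
```

Collecting results as soon as they are available lets a search propose new configurations without waiting for the slowest task of a batch.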
Warning
Each `Evaluator` saves its own state, therefore it is crucial to create a new evaluator when launching a fresh search.
[22]:
from deephyper.evaluator import Evaluator
from deephyper.evaluator.callback import TqdmCallback
def get_evaluator(run_function):
    # Default arguments for Ray: 1 CPU in total and 1 CPU per evaluation
    method_kwargs = {
        "num_cpus": 1,
        "num_cpus_per_task": 1,
        "callbacks": [TqdmCallback()],
    }

    # If GPU devices are detected then it will create 'n_gpus' workers
    # and use 1 GPU for each evaluation
    if is_gpu_available:
        method_kwargs["num_cpus"] = n_gpus
        method_kwargs["num_gpus"] = n_gpus
        method_kwargs["num_cpus_per_task"] = 1
        method_kwargs["num_gpus_per_task"] = 1

    evaluator = Evaluator.create(
        run_function,
        method="ray",
        method_kwargs=method_kwargs,
    )
    print(f"Created new evaluator with {evaluator.num_workers} worker{'s' if evaluator.num_workers > 1 else ''} and config: {method_kwargs}")

    return evaluator


evaluator_1 = get_evaluator(run)
Created new evaluator with 1 worker and config: {'num_cpus': 1, 'num_cpus_per_task': 1, 'callbacks': [<deephyper.evaluator.callback.TqdmCallback object at 0x3df0ce2a0>]}
2.8. Define and run the centralized Bayesian optimization search (CBO)#
A primary pillar of hyperparameter search in DeepHyper is given by a centralized Bayesian optimization search (henceforth CBO). CBO may be described in the following algorithm:
Following the parallelized evaluation of an initial set of hyperparameter configurations, a low-fidelity and high-efficiency model (henceforth “the surrogate”) is devised to reproduce the relationship between the input variables involved in the model (i.e., the choice of hyperparameters) and the outputs (which are generally a measure of validation data accuracy).
After obtaining this surrogate of the validation accuracy, we may utilize ideas from classical methods in the Bayesian optimization literature to adaptively sample the search space of hyperparameters.
First, the surrogate is used to obtain an estimate for the mean value of the validation accuracy at a certain sampling location \(x\) in addition to an estimated variance. The latter requirement restricts us to the use of high efficiency data-driven modeling strategies that have inbuilt variance estimates (such as a Gaussian process or Random Forest regressor).
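Regions where the mean is high represent opportunities for exploitation and regions where the variance is high represent opportunities for exploration. An optimistic acquisition function called the upper confidence bound (UCB) can be constructed using these two quantities:
\[ \mathrm{UCB}(x) = \mu(x) + \kappa \cdot \sigma(x) \]
where \(\mu(x)\) is the mean predicted by the surrogate at \(x\) and \(\sigma(x)\) is the square root of the estimated variance.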
The unevaluated hyperparameter configurations that maximize the acquisition function are chosen for the next batch of evaluations.
Note that the choice of the variance weighting parameter \(\kappa\) controls the degree of exploration in the hyperparameter search, with zero indicating pure exploitation (unseen configurations where the predicted accuracy is highest will be sampled).
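The effect of \(\kappa\) can be checked on a toy example. The sketch below assumes the common form of the upper confidence bound, the predicted mean plus \(\kappa\) times the predicted standard deviation, and uses made-up predictions for three hypothetical candidate configurations named `A`, `B` and `C`.

```python
def ucb(mean, std, kappa):
    """Upper confidence bound: optimistic estimate of the objective."""
    return mean + kappa * std


# Made-up surrogate predictions: (predicted mean, predicted standard deviation)
candidates = {
    "A": (0.84, 0.01),  # good and well known
    "B": (0.80, 0.05),
    "C": (0.60, 0.20),  # mediocre prediction but very uncertain
}


def best_candidate(kappa):
    """Return the candidate that maximizes the acquisition function."""
    return max(candidates, key=lambda name: ucb(*candidates[name], kappa))


print(best_candidate(0.0))  # pure exploitation
print(best_candidate(2.0))  # exploration is rewarded
```

With \(\kappa = 0\) the candidate with the highest predicted mean (`A`) is selected, whereas with \(\kappa = 2\) the most uncertain candidate (`C`) wins because its score becomes 0.60 + 2 × 0.20 = 1.00.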
The top `s` configurations are selected for the new batch. The following schematic demonstrates this process:
The process of obtaining `s` configurations relies on the “constant-liar” strategy where a sampled configuration is mapped to a dummy output given by a bulk metric of all the evaluated configurations thus far (such as the maximum, mean or median validation accuracy).
Prior to sampling the next configuration by acquisition function maximization, the surrogate is retrained with the dummy output as a data point. As the true validation accuracy becomes available for one of the sampled configurations, the dummy output is replaced and the surrogate is updated.
This allows for scalable asynchronous (or batch synchronous) sampling of new hyperparameter configurations.
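The constant-liar loop described above can be sketched in a few lines on a one-dimensional toy problem. Everything in this example is an assumption made for illustration: the surrogate is a crude nearest-neighbour predictor (not a Random Forest), the "lie" is the mean of the observed objectives, and the names `surrogate` and `propose_batch` are hypothetical.

```python
def surrogate(x, observations):
    """Toy surrogate: predict the value of the nearest observed point.
    The uncertainty grows with the distance to that point."""
    nearest_x, nearest_y = min(observations, key=lambda obs: abs(obs[0] - x))
    return nearest_y, abs(nearest_x - x)


def propose_batch(candidates, observations, batch_size, kappa=0.1):
    """Select `batch_size` points one by one using the constant-liar strategy."""
    observations = list(observations)  # work on a copy of the real observations
    lie = sum(y for _, y in observations) / len(observations)  # "mean" liar
    batch = []
    for _ in range(batch_size):

        def score(x):
            mean, std = surrogate(x, observations)
            return mean + kappa * std  # UCB-style acquisition

        remaining = [x for x in candidates if x not in batch]
        x_next = max(remaining, key=score)
        batch.append(x_next)
        # Pretend the lie was observed so the next pick moves elsewhere
        observations.append((x_next, lie))
    return batch


candidates = list(range(11))        # toy search space: integers 0..10
observed = [(0, 0.5), (10, 0.7)]    # made-up (configuration, objective) pairs
batch = propose_batch(candidates, observed, batch_size=3)
print(batch)
```

Because each selected point immediately receives a dummy observation, the uncertainty around it collapses and the following selections are pushed towards other regions, which is what makes the batch diverse without waiting for real evaluations.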
2.8.1. Choice of surrogate model#
Users should note that our choice of surrogate is the Random Forest regressor, due to its ability to handle non-ordinal data (hyperparameter configurations may not be purely continuous or even numerical). Evidence that it outperforms other methods (such as Gaussian processes) is also available in [1].
2.8.1.1. Setup CBO#
We create the CBO using the `problem` and `evaluator` defined above.
[26]:
from deephyper.hpo import CBO
# Uncomment the following line to show the arguments of CBO.
# help(CBO)
[28]:
# Instantiate the search with the problem and the evaluator that we created before
search = CBO(problem, evaluator_1, acq_optimizer="mixedga", acq_optimizer_freq=1, initial_points=[problem.default_configuration])
Note
All of DeepHyper’s search algorithms have two stopping criteria:
- `max_evals (int)`: Defines the maximum number of evaluations that we want to perform. Defaults to `-1` for an infinite number.
- `timeout (int)`: Defines a time budget (in seconds) before stopping the search. Defaults to `None` for an infinite time budget.
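The interplay of the two criteria can be pictured with a simple sequential loop that stops as soon as either budget is exhausted. This is a hypothetical sketch (the function `toy_search` is not part of DeepHyper and ignores parallelism); it only mirrors the meaning of the two parameters and their defaults as described above.

```python
import time


def toy_search(evaluate, max_evals=-1, timeout=None):
    """Evaluate jobs one by one until one of the two budgets is exhausted.
    With max_evals=-1 and timeout=None this loop would never stop."""
    start = time.monotonic()
    results = []
    while True:
        if max_evals >= 0 and len(results) >= max_evals:
            break  # evaluation budget exhausted
        if timeout is not None and time.monotonic() - start >= timeout:
            break  # time budget exhausted
        results.append(evaluate(len(results)))
    return results


# Stop after a fixed number of evaluations
by_count = toy_search(lambda job_id: job_id * 0.1, max_evals=5)
print(len(by_count))


# Stop after a time budget (each fake evaluation takes about 10 ms)
def slow_evaluate(job_id):
    time.sleep(0.01)
    return job_id


by_time = toy_search(slow_evaluate, timeout=0.1)
print(len(by_time) > 0)
```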
[29]:
results = search.search(max_evals=50)
Warning
The `search` call does not output any information about the current status of the search. However, a `results.csv` file is created in the local directory and can be inspected to see finished tasks.
The returned `results` is a Pandas DataFrame where columns starting with `"p:"` are hyperparameters and columns starting with `"m:"` are additional metadata (from the user or from the `Evaluator`). The remaining columns are the `objective` value and the `job_id`:
- `job_id` is a unique identifier corresponding to the order of creation of tasks.
- `objective` is the value returned by the run-function.
- `m:timestamp_submit` is the time (in seconds) when the task was created by the evaluator since the creation of the evaluator.
- `m:timestamp_gather` is the time (in seconds) when the task was received after finishing by the evaluator since the creation of the evaluator.
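As an example of how these columns can be used, the sketch below parses a small CSV excerpt with Python's standard `csv` module, separates the hyperparameter columns by their `p:` prefix, selects the best row and computes the duration of that job from the two timestamps. The three rows are copied from the results table shown below (jobs 0, 4 and 10), restricted to a few columns for brevity; reading the real `results.csv` the same way is an assumption about its layout.

```python
import csv
import io

# Three rows copied from the results table, restricted to a few columns
csv_text = """p:activation,p:units,objective,job_id,m:timestamp_submit,m:timestamp_gather
relu,32,0.803279,0,50.325637,56.067836
linear,79,0.852459,4,63.144296,65.508172
linear,91,0.557377,10,76.229987,78.200037
"""

rows = list(csv.DictReader(io.StringIO(csv_text)))

# Columns prefixed with "p:" are hyperparameters
hyperparameters = [name for name in rows[0] if name.startswith("p:")]

# Row with the highest objective (values are read as strings, hence float())
best = max(rows, key=lambda row: float(row["objective"]))

# Time between submission and gathering of the best job, in seconds
duration = float(best["m:timestamp_gather"]) - float(best["m:timestamp_submit"])

print(hyperparameters)
print({name: best[name] for name in hyperparameters}, best["objective"])
print(f"job {best['job_id']} took about {duration:.2f} s")
```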
[30]:
results
[30]:
 | p:activation | p:batch_size | p:dropout_rate | p:learning_rate | p:num_epochs | p:units | objective | job_id | m:timestamp_submit | m:loss | m:val_loss | m:accuracy | m:val_accuracy | m:num_parameters | m:num_parameters_train | m:timestamp_gather |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | relu | 32 | 0.500000 | 0.001000 | 50 | 32 | 0.803279 | 0 | 50.325637 | [0.706670343875885, 0.6361508369445801, 0.5850... | [0.6415352821350098, 0.5993135571479797, 0.564... | [0.5123966932296753, 0.6280992031097412, 0.710... | [0.6721311211585999, 0.7540983557701111, 0.770... | 18 | 1217 | 56.067836 |
1 | linear | 114 | 0.562393 | 0.002370 | 47 | 128 | 0.803279 | 1 | 56.123101 | [0.6474491953849792, 0.49577072262763977, 0.41... | [0.45269274711608887, 0.39127492904663086, 0.3... | [0.6652892827987671, 0.7685950398445129, 0.805... | [0.7868852615356445, 0.7704917788505554, 0.819... | 18 | 4865 | 57.814659 |
2 | linear | 9 | 0.533665 | 0.000075 | 69 | 84 | 0.836066 | 2 | 58.013101 | [0.6138803958892822, 0.5668570399284363, 0.563... | [0.5532547235488892, 0.533083975315094, 0.5151... | [0.6776859760284424, 0.7438016533851624, 0.706... | [0.7868852615356445, 0.8032786846160889, 0.803... | 18 | 3193 | 61.268282 |
3 | linear | 230 | 0.319679 | 0.001157 | 45 | 59 | 0.836066 | 3 | 61.298386 | [0.7759286761283875, 0.6881203055381775, 0.633... | [0.6713256239891052, 0.6128526329994202, 0.564... | [0.5123966932296753, 0.6198347210884094, 0.669... | [0.6065573692321777, 0.688524603843689, 0.7049... | 18 | 2243 | 63.091872 |
4 | linear | 50 | 0.011336 | 0.000046 | 96 | 79 | 0.852459 | 4 | 63.144296 | [0.7343717813491821, 0.7315843105316162, 0.720... | [0.753821849822998, 0.7444661855697632, 0.7352... | [0.4793388545513153, 0.4958677589893341, 0.537... | [0.39344263076782227, 0.44262295961380005, 0.4... | 18 | 3003 | 65.508172 |
5 | softsign | 68 | 0.194614 | 0.000110 | 71 | 98 | 0.770492 | 5 | 65.540140 | [0.7060982584953308, 0.7222269773483276, 0.696... | [0.706885576248169, 0.6966867446899414, 0.6866... | [0.5041322112083435, 0.4628099203109741, 0.528... | [0.5245901346206665, 0.5573770403862, 0.573770... | 18 | 3725 | 67.730778 |
6 | linear | 12 | 0.032802 | 0.003319 | 94 | 124 | 0.803279 | 6 | 67.762288 | [0.3936194181442261, 0.298471599817276, 0.2840... | [0.44452720880508423, 0.4467885494232178, 0.46... | [0.797520637512207, 0.8595041036605835, 0.8636... | [0.8524590134620667, 0.8360655903816223, 0.819... | 18 | 4713 | 70.985557 |
7 | gelu | 85 | 0.212161 | 0.002263 | 29 | 53 | 0.836066 | 7 | 71.017057 | [0.7266819477081299, 0.650373101234436, 0.6081... | [0.6597949862480164, 0.5888913869857788, 0.532... | [0.4917355477809906, 0.6652892827987671, 0.714... | [0.6721311211585999, 0.7540983557701111, 0.786... | 18 | 2015 | 72.507068 |
8 | selu | 164 | 0.540524 | 0.001172 | 14 | 48 | 0.836066 | 8 | 72.537765 | [0.9216533899307251, 0.8793801665306091, 0.743... | [0.7923454642295837, 0.7235696315765381, 0.663... | [0.44214877486228943, 0.4876033067703247, 0.53... | [0.5245901346206665, 0.5409836173057556, 0.639... | 18 | 1825 | 74.019521 |
9 | relu | 139 | 0.356863 | 0.000101 | 59 | 94 | 0.737705 | 9 | 74.053959 | [0.6047406196594238, 0.6085638999938965, 0.576... | [0.6019086241722107, 0.598895788192749, 0.5958... | [0.6652892827987671, 0.6776859760284424, 0.739... | [0.6721311211585999, 0.6721311211585999, 0.672... | 18 | 3573 | 75.723831 |
10 | linear | 148 | 0.005233 | 0.000012 | 100 | 91 | 0.557377 | 10 | 76.229987 | [0.8135287761688232, 0.8092227578163147, 0.812... | [0.7794104218482971, 0.7784043550491333, 0.777... | [0.41735535860061646, 0.42561984062194824, 0.4... | [0.39344263076782227, 0.39344263076782227, 0.3... | 18 | 3459 | 78.200037 |
11 | linear | 50 | 0.014352 | 0.000025 | 100 | 79 | 0.786885 | 11 | 78.871375 | [0.9058765769004822, 0.8957070708274841, 0.892... | [0.9036123752593994, 0.8989002704620361, 0.894... | [0.2933884263038635, 0.3264462947845459, 0.314... | [0.32786884903907776, 0.32786884903907776, 0.3... | 18 | 3003 | 81.243957 |
12 | linear | 109 | 0.000487 | 0.000045 | 96 | 79 | 0.803279 | 12 | 81.892266 | [0.8505944013595581, 0.8450313210487366, 0.839... | [0.8944583535194397, 0.888396143913269, 0.8825... | [0.2975206673145294, 0.3016528785228729, 0.305... | [0.26229506731033325, 0.26229506731033325, 0.2... | 18 | 3003 | 83.936618 |
13 | relu | 50 | 0.011727 | 0.000049 | 94 | 79 | 0.786885 | 13 | 84.554558 | [0.772240400314331, 0.7702962160110474, 0.7616... | [0.7718428373336792, 0.7672584652900696, 0.762... | [0.35537189245224, 0.35537189245224, 0.3801652... | [0.3442623019218445, 0.3442623019218445, 0.360... | 18 | 3003 | 86.888750 |
14 | linear | 48 | 0.010876 | 0.000047 | 96 | 86 | 0.786885 | 14 | 87.831352 | [0.7737085223197937, 0.7620970010757446, 0.758... | [0.7752880454063416, 0.7666113972663879, 0.757... | [0.41322314739227295, 0.43388429284095764, 0.4... | [0.44262295961380005, 0.4754098355770111, 0.52... | 18 | 3269 | 90.291341 |
15 | linear | 50 | 0.012980 | 0.000046 | 96 | 81 | 0.852459 | 15 | 91.166686 | [0.6715313196182251, 0.6666457056999207, 0.660... | [0.5990051031112671, 0.5933732986450195, 0.587... | [0.6404958963394165, 0.6611570119857788, 0.669... | [0.6721311211585999, 0.6721311211585999, 0.704... | 18 | 3079 | 93.352006 |
16 | linear | 49 | 0.010183 | 0.000047 | 98 | 81 | 0.803279 | 16 | 94.176494 | [0.9817445278167725, 0.9637849926948547, 0.955... | [1.0305403470993042, 1.0171641111373901, 1.003... | [0.32231405377388, 0.3388429880142212, 0.35950... | [0.3606557250022888, 0.3606557250022888, 0.360... | 18 | 3079 | 96.356579 |
17 | linear | 51 | 0.011189 | 0.000035 | 95 | 81 | 0.754098 | 17 | 97.268516 | [0.8748018741607666, 0.8668413162231445, 0.861... | [0.8491439819335938, 0.8422006368637085, 0.835... | [0.3264462947845459, 0.3181818127632141, 0.347... | [0.2786885201931, 0.2950819730758667, 0.295081... | 18 | 3079 | 99.473332 |
18 | linear | 50 | 0.019094 | 0.000046 | 95 | 79 | 0.836066 | 18 | 100.300250 | [0.6110990047454834, 0.6006994843482971, 0.598... | [0.5857220888137817, 0.580162525177002, 0.5746... | [0.7438016533851624, 0.7479338645935059, 0.760... | [0.8360655903816223, 0.8360655903816223, 0.836... | 18 | 3003 | 102.724646 |
19 | linear | 51 | 0.012842 | 0.000041 | 96 | 81 | 0.803279 | 19 | 103.713039 | [0.6344509720802307, 0.6267118453979492, 0.626... | [0.6352565884590149, 0.6293896436691284, 0.623... | [0.6900826692581177, 0.6983470916748047, 0.710... | [0.6557376980781555, 0.6557376980781555, 0.655... | 18 | 3079 | 105.967294 |
20 | relu | 171 | 0.508927 | 0.001910 | 10 | 83 | 0.770492 | 20 | 106.621811 | [0.7741369009017944, 0.6997448801994324, 0.641... | [0.6997786164283752, 0.6350088119506836, 0.582... | [0.46694216132164, 0.557851254940033, 0.648760... | [0.5737704634666443, 0.6557376980781555, 0.737... | 18 | 3155 | 107.922894 |
21 | gelu | 247 | 0.291055 | 0.000012 | 14 | 57 | 0.393443 | 21 | 108.269415 | [0.7637456655502319, 0.7536334991455078, 0.767... | [0.770416259765625, 0.7701885104179382, 0.7699... | [0.4586776793003082, 0.42975205183029175, 0.44... | [0.39344263076782227, 0.39344263076782227, 0.3... | 18 | 2167 | 109.827998 |
22 | elu | 247 | 0.330785 | 0.000270 | 32 | 57 | 0.754098 | 22 | 110.438860 | [0.7152367234230042, 0.7155708074569702, 0.698... | [0.7250485420227051, 0.7178604006767273, 0.710... | [0.5123966932296753, 0.4958677589893341, 0.570... | [0.44262295961380005, 0.44262295961380005, 0.4... | 18 | 2167 | 111.909659 |
23 | gelu | 256 | 0.174206 | 0.001193 | 22 | 54 | 0.836066 | 23 | 112.662811 | [0.7355976700782776, 0.7147685885429382, 0.687... | [0.7190053462982178, 0.6970692276954651, 0.676... | [0.44214877486228943, 0.4834710657596588, 0.55... | [0.3606557250022888, 0.4754098355770111, 0.573... | 18 | 2053 | 114.152527 |
24 | elu | 50 | 0.012164 | 0.000047 | 96 | 73 | 0.786885 | 24 | 115.373385 | [0.7974379062652588, 0.7888069748878479, 0.779... | [0.8222996592521667, 0.8145594596862793, 0.806... | [0.39669421315193176, 0.40082645416259766, 0.4... | [0.37704917788505554, 0.39344263076782227, 0.3... | 18 | 2775 | 117.593125 |
25 | hard_sigmoid | 244 | 0.215005 | 0.000582 | 13 | 56 | 0.229508 | 25 | 118.166421 | [0.8906719088554382, 0.9096777439117432, 0.914... | [0.9311334490776062, 0.9191083312034607, 0.907... | [0.3140496015548706, 0.28925618529319763, 0.30... | [0.2295081913471222, 0.2295081913471222, 0.229... | 18 | 2129 | 119.727548 |
26 | hard_sigmoid | 228 | 0.195278 | 0.001142 | 16 | 48 | 0.770492 | 26 | 120.442391 | [0.7672627568244934, 0.7242721319198608, 0.727... | [0.7250207662582397, 0.701765239238739, 0.6803... | [0.3636363744735718, 0.43801653385162354, 0.45... | [0.26229506731033325, 0.4590163826942444, 0.62... | 18 | 1825 | 121.826843 |
27 | selu | 160 | 0.546215 | 0.000789 | 13 | 11 | 0.688525 | 27 | 122.283213 | [0.7485548257827759, 0.7522017955780029, 0.763... | [0.7076205015182495, 0.6924997568130493, 0.678... | [0.5950413346290588, 0.586776852607727, 0.6157... | [0.6065573692321777, 0.6065573692321777, 0.622... | 18 | 419 | 123.619181 |
28 | linear | 8 | 0.526041 | 0.000060 | 13 | 82 | 0.786885 | 28 | 124.196094 | [0.8352355360984802, 0.7898195385932922, 0.765... | [0.7591118812561035, 0.7175912857055664, 0.682... | [0.43388429284095764, 0.5206611752510071, 0.49... | [0.4590163826942444, 0.5245901346206665, 0.606... | 18 | 3117 | 125.941018 |
29 | gelu | 256 | 0.188568 | 0.001515 | 21 | 8 | 0.737705 | 29 | 126.460014 | [0.7105452418327332, 0.7035156488418579, 0.698... | [0.7025739550590515, 0.696033239364624, 0.6896... | [0.5537189841270447, 0.5371900796890259, 0.574... | [0.6229507923126221, 0.6393442749977112, 0.639... | 18 | 305 | 128.180563 |
30 | gelu | 256 | 0.244248 | 0.001660 | 23 | 56 | 0.786885 | 30 | 128.727047 | [0.6546282768249512, 0.6317970752716064, 0.637... | [0.572919487953186, 0.556894838809967, 0.54188... | [0.6818181872367859, 0.7066115736961365, 0.706... | [0.7540983557701111, 0.7540983557701111, 0.754... | 18 | 2129 | 130.153736 |
31 | linear | 224 | 0.305334 | 0.000660 | 44 | 61 | 0.819672 | 31 | 130.919080 | [0.9484014511108398, 0.9179948568344116, 0.877... | [0.924122154712677, 0.8768399953842163, 0.8322... | [0.32231405377388, 0.28925618529319763, 0.3512... | [0.19672131538391113, 0.19672131538391113, 0.2... | 18 | 2319 | 132.589723 |
32 | linear | 50 | 0.479327 | 0.000068 | 96 | 17 | 0.737705 | 32 | 133.278981 | [0.9485551118850708, 0.9460268616676331, 0.900... | [0.8557782769203186, 0.8498371243476868, 0.843... | [0.5123966932296753, 0.5247933864593506, 0.483... | [0.5245901346206665, 0.5245901346206665, 0.524... | 18 | 647 | 135.438342 |
33 | selu | 164 | 0.561031 | 0.000726 | 14 | 55 | 0.819672 | 33 | 136.143825 | [0.8745442628860474, 0.8545739054679871, 0.796... | [0.7864746451377869, 0.7405412793159485, 0.699... | [0.4752066135406494, 0.5, 0.5247933864593506, ... | [0.4262295067310333, 0.5081967115402222, 0.606... | 18 | 2091 | 137.718102 |
34 | selu | 254 | 0.322091 | 0.001231 | 12 | 47 | 0.819672 | 34 | 138.606087 | [0.7325422763824463, 0.7275435328483582, 0.714... | [0.6849062442779541, 0.6536070108413696, 0.624... | [0.6115702390670776, 0.56611567735672, 0.59090... | [0.5573770403862, 0.5901639461517334, 0.622950... | 18 | 1787 | 140.006373 |
35 | gelu | 96 | 0.235332 | 0.000678 | 35 | 49 | 0.819672 | 35 | 140.362036 | [0.6482117176055908, 0.640116274356842, 0.6336... | [0.6331732869148254, 0.6124953031539917, 0.593... | [0.64462810754776, 0.6776859760284424, 0.70247... | [0.7377049326896667, 0.7049180269241333, 0.721... | 18 | 1863 | 141.949003 |
36 | linear | 50 | 0.327453 | 0.000119 | 96 | 81 | 0.819672 | 36 | 142.978975 | [0.7736309170722961, 0.6837102770805359, 0.716... | [0.6871227622032166, 0.6735551953315735, 0.660... | [0.6033057570457458, 0.6776859760284424, 0.623... | [0.7213114500045776, 0.7213114500045776, 0.770... | 18 | 3079 | 145.525915 |
37 | gelu | 134 | 0.098413 | 0.000580 | 22 | 54 | 0.786885 | 37 | 146.242206 | [0.6934930682182312, 0.6824401617050171, 0.672... | [0.7072174549102783, 0.6886430382728577, 0.670... | [0.4793388545513153, 0.5206611752510071, 0.582... | [0.4590163826942444, 0.5245901346206665, 0.655... | 18 | 2053 | 147.982779 |
38 | hard_sigmoid | 244 | 0.317880 | 0.001168 | 45 | 53 | 0.786885 | 38 | 148.651484 | [0.9446930289268494, 0.9471455216407776, 0.897... | [0.9400205612182617, 0.9183498620986938, 0.897... | [0.3140496015548706, 0.3305785059928894, 0.322... | [0.2295081913471222, 0.2295081913471222, 0.229... | 18 | 2015 | 150.211455 |
39 | linear | 11 | 0.539959 | 0.000073 | 71 | 42 | 0.836066 | 39 | 151.001553 | [0.7638189792633057, 0.768632173538208, 0.8030... | [0.7098137140274048, 0.6863061785697937, 0.662... | [0.5289255976676941, 0.5, 0.5289255976676941, ... | [0.5573770403862, 0.5573770403862, 0.557377040... | 18 | 1597 | 153.791880 |
40 | linear | 9 | 0.594717 | 0.000072 | 75 | 22 | 0.836066 | 40 | 154.191006 | [0.9886115193367004, 0.8805288672447205, 0.913... | [0.8028327226638794, 0.7799118161201477, 0.759... | [0.45041322708129883, 0.4958677589893341, 0.45... | [0.49180328845977783, 0.5245901346206665, 0.54... | 18 | 837 | 157.699540 |
41 | linear | 9 | 0.583705 | 0.000074 | 70 | 11 | 0.704918 | 41 | 158.364740 | [1.2854955196380615, 1.2296828031539917, 1.140... | [1.1728463172912598, 1.1491509675979614, 1.125... | [0.2975206673145294, 0.3429751992225647, 0.371... | [0.2295081913471222, 0.26229506731033325, 0.27... | 18 | 419 | 161.862414 |
42 | linear | 10 | 0.579446 | 0.000080 | 70 | 21 | 0.819672 | 42 | 162.550486 | [0.7621033191680908, 0.7450456023216248, 0.790... | [0.6448240280151367, 0.630098819732666, 0.6165... | [0.5619834661483765, 0.5495867729187012, 0.570... | [0.5901639461517334, 0.6065573692321777, 0.639... | 18 | 799 | 165.567682 |
43 | linear | 11 | 0.542695 | 0.000010 | 72 | 41 | 0.704918 | 43 | 166.191566 | [0.7994367480278015, 0.7612443566322327, 0.755... | [0.701274037361145, 0.6975817084312439, 0.6937... | [0.5082644820213318, 0.5454545617103577, 0.557... | [0.5245901346206665, 0.5245901346206665, 0.540... | 18 | 1559 | 169.025697 |
44 | hard_sigmoid | 9 | 0.531649 | 0.000072 | 72 | 56 | 0.803279 | 44 | 169.882809 | [0.6588014364242554, 0.6326671242713928, 0.658... | [0.5621706247329712, 0.5559288263320923, 0.550... | [0.6115702390670776, 0.6652892827987671, 0.657... | [0.7704917788505554, 0.7704917788505554, 0.770... | 18 | 2129 | 172.994597 |
45 | selu | 149 | 0.537188 | 0.001174 | 14 | 30 | 0.836066 | 45 | 173.636141 | [0.9731682538986206, 0.9632638692855835, 0.834... | [0.7772341370582581, 0.7184216976165771, 0.666... | [0.41322314739227295, 0.4586776793003082, 0.50... | [0.4262295067310333, 0.49180328845977783, 0.59... | 18 | 1141 | 175.037412 |
46 | linear | 49 | 0.011794 | 0.000046 | 96 | 69 | 0.754098 | 46 | 175.385063 | [0.9058528542518616, 0.8961911797523499, 0.890... | [0.9788206815719604, 0.9681548476219177, 0.957... | [0.3677685856819153, 0.3595041334629059, 0.371... | [0.26229506731033325, 0.26229506731033325, 0.2... | 18 | 2623 | 177.826115 |
47 | selu | 157 | 0.534877 | 0.001165 | 16 | 23 | 0.803279 | 47 | 178.533461 | [1.0175780057907104, 0.9338928461074829, 0.922... | [0.9525404572486877, 0.9009444713592529, 0.852... | [0.42148759961128235, 0.46694216132164, 0.4710... | [0.32786884903907776, 0.4098360538482666, 0.45... | 18 | 875 | 179.892400 |
48 | gelu | 86 | 0.196295 | 0.002204 | 11 | 54 | 0.819672 | 48 | 180.589500 | [0.5843045711517334, 0.5370311141014099, 0.489... | [0.5310717821121216, 0.48231208324432373, 0.44... | [0.7355371713638306, 0.7685950398445129, 0.789... | [0.7868852615356445, 0.7868852615356445, 0.754... | 18 | 2053 | 181.977024 |
49 | linear | 9 | 0.591808 | 0.000048 | 80 | 20 | 0.819672 | 49 | 182.668562 | [1.0278195142745972, 1.0783154964447021, 1.051... | [0.9265504479408264, 0.9104623198509216, 0.894... | [0.42975205183029175, 0.37603306770324707, 0.3... | [0.3606557250022888, 0.3606557250022888, 0.377... | 18 | 761 | 185.918045 |
The search can be resumed by calling `search` again. Here, 5 more evaluations are requested and appended to the existing results.
[31]:
results = search.search(max_evals=5)
results
[31]:
p:activation | p:batch_size | p:dropout_rate | p:learning_rate | p:num_epochs | p:units | objective | job_id | m:timestamp_submit | m:loss | m:val_loss | m:accuracy | m:val_accuracy | m:num_parameters | m:num_parameters_train | m:timestamp_gather | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | relu | 32 | 0.500000 | 0.001000 | 50 | 32 | 0.803279 | 0 | 50.325637 | [0.706670343875885, 0.6361508369445801, 0.5850... | [0.6415352821350098, 0.5993135571479797, 0.564... | [0.5123966932296753, 0.6280992031097412, 0.710... | [0.6721311211585999, 0.7540983557701111, 0.770... | 18 | 1217 | 56.067836 |
1 | linear | 114 | 0.562393 | 0.002370 | 47 | 128 | 0.803279 | 1 | 56.123101 | [0.6474491953849792, 0.49577072262763977, 0.41... | [0.45269274711608887, 0.39127492904663086, 0.3... | [0.6652892827987671, 0.7685950398445129, 0.805... | [0.7868852615356445, 0.7704917788505554, 0.819... | 18 | 4865 | 57.814659 |
2 | linear | 9 | 0.533665 | 0.000075 | 69 | 84 | 0.836066 | 2 | 58.013101 | [0.6138803958892822, 0.5668570399284363, 0.563... | [0.5532547235488892, 0.533083975315094, 0.5151... | [0.6776859760284424, 0.7438016533851624, 0.706... | [0.7868852615356445, 0.8032786846160889, 0.803... | 18 | 3193 | 61.268282 |
3 | linear | 230 | 0.319679 | 0.001157 | 45 | 59 | 0.836066 | 3 | 61.298386 | [0.7759286761283875, 0.6881203055381775, 0.633... | [0.6713256239891052, 0.6128526329994202, 0.564... | [0.5123966932296753, 0.6198347210884094, 0.669... | [0.6065573692321777, 0.688524603843689, 0.7049... | 18 | 2243 | 63.091872 |
4 | linear | 50 | 0.011336 | 0.000046 | 96 | 79 | 0.852459 | 4 | 63.144296 | [0.7343717813491821, 0.7315843105316162, 0.720... | [0.753821849822998, 0.7444661855697632, 0.7352... | [0.4793388545513153, 0.4958677589893341, 0.537... | [0.39344263076782227, 0.44262295961380005, 0.4... | 18 | 3003 | 65.508172 |
5 | softsign | 68 | 0.194614 | 0.000110 | 71 | 98 | 0.770492 | 5 | 65.540140 | [0.7060982584953308, 0.7222269773483276, 0.696... | [0.706885576248169, 0.6966867446899414, 0.6866... | [0.5041322112083435, 0.4628099203109741, 0.528... | [0.5245901346206665, 0.5573770403862, 0.573770... | 18 | 3725 | 67.730778 |
6 | linear | 12 | 0.032802 | 0.003319 | 94 | 124 | 0.803279 | 6 | 67.762288 | [0.3936194181442261, 0.298471599817276, 0.2840... | [0.44452720880508423, 0.4467885494232178, 0.46... | [0.797520637512207, 0.8595041036605835, 0.8636... | [0.8524590134620667, 0.8360655903816223, 0.819... | 18 | 4713 | 70.985557 |
7 | gelu | 85 | 0.212161 | 0.002263 | 29 | 53 | 0.836066 | 7 | 71.017057 | [0.7266819477081299, 0.650373101234436, 0.6081... | [0.6597949862480164, 0.5888913869857788, 0.532... | [0.4917355477809906, 0.6652892827987671, 0.714... | [0.6721311211585999, 0.7540983557701111, 0.786... | 18 | 2015 | 72.507068 |
8 | selu | 164 | 0.540524 | 0.001172 | 14 | 48 | 0.836066 | 8 | 72.537765 | [0.9216533899307251, 0.8793801665306091, 0.743... | [0.7923454642295837, 0.7235696315765381, 0.663... | [0.44214877486228943, 0.4876033067703247, 0.53... | [0.5245901346206665, 0.5409836173057556, 0.639... | 18 | 1825 | 74.019521 |
9 | relu | 139 | 0.356863 | 0.000101 | 59 | 94 | 0.737705 | 9 | 74.053959 | [0.6047406196594238, 0.6085638999938965, 0.576... | [0.6019086241722107, 0.598895788192749, 0.5958... | [0.6652892827987671, 0.6776859760284424, 0.739... | [0.6721311211585999, 0.6721311211585999, 0.672... | 18 | 3573 | 75.723831 |
10 | linear | 148 | 0.005233 | 0.000012 | 100 | 91 | 0.557377 | 10 | 76.229987 | [0.8135287761688232, 0.8092227578163147, 0.812... | [0.7794104218482971, 0.7784043550491333, 0.777... | [0.41735535860061646, 0.42561984062194824, 0.4... | [0.39344263076782227, 0.39344263076782227, 0.3... | 18 | 3459 | 78.200037 |
11 | linear | 50 | 0.014352 | 0.000025 | 100 | 79 | 0.786885 | 11 | 78.871375 | [0.9058765769004822, 0.8957070708274841, 0.892... | [0.9036123752593994, 0.8989002704620361, 0.894... | [0.2933884263038635, 0.3264462947845459, 0.314... | [0.32786884903907776, 0.32786884903907776, 0.3... | 18 | 3003 | 81.243957 |
12 | linear | 109 | 0.000487 | 0.000045 | 96 | 79 | 0.803279 | 12 | 81.892266 | [0.8505944013595581, 0.8450313210487366, 0.839... | [0.8944583535194397, 0.888396143913269, 0.8825... | [0.2975206673145294, 0.3016528785228729, 0.305... | [0.26229506731033325, 0.26229506731033325, 0.2... | 18 | 3003 | 83.936618 |
13 | relu | 50 | 0.011727 | 0.000049 | 94 | 79 | 0.786885 | 13 | 84.554558 | [0.772240400314331, 0.7702962160110474, 0.7616... | [0.7718428373336792, 0.7672584652900696, 0.762... | [0.35537189245224, 0.35537189245224, 0.3801652... | [0.3442623019218445, 0.3442623019218445, 0.360... | 18 | 3003 | 86.888750 |
14 | linear | 48 | 0.010876 | 0.000047 | 96 | 86 | 0.786885 | 14 | 87.831352 | [0.7737085223197937, 0.7620970010757446, 0.758... | [0.7752880454063416, 0.7666113972663879, 0.757... | [0.41322314739227295, 0.43388429284095764, 0.4... | [0.44262295961380005, 0.4754098355770111, 0.52... | 18 | 3269 | 90.291341 |
15 | linear | 50 | 0.012980 | 0.000046 | 96 | 81 | 0.852459 | 15 | 91.166686 | [0.6715313196182251, 0.6666457056999207, 0.660... | [0.5990051031112671, 0.5933732986450195, 0.587... | [0.6404958963394165, 0.6611570119857788, 0.669... | [0.6721311211585999, 0.6721311211585999, 0.704... | 18 | 3079 | 93.352006 |
16 | linear | 49 | 0.010183 | 0.000047 | 98 | 81 | 0.803279 | 16 | 94.176494 | [0.9817445278167725, 0.9637849926948547, 0.955... | [1.0305403470993042, 1.0171641111373901, 1.003... | [0.32231405377388, 0.3388429880142212, 0.35950... | [0.3606557250022888, 0.3606557250022888, 0.360... | 18 | 3079 | 96.356579 |
17 | linear | 51 | 0.011189 | 0.000035 | 95 | 81 | 0.754098 | 17 | 97.268516 | [0.8748018741607666, 0.8668413162231445, 0.861... | [0.8491439819335938, 0.8422006368637085, 0.835... | [0.3264462947845459, 0.3181818127632141, 0.347... | [0.2786885201931, 0.2950819730758667, 0.295081... | 18 | 3079 | 99.473332 |
18 | linear | 50 | 0.019094 | 0.000046 | 95 | 79 | 0.836066 | 18 | 100.300250 | [0.6110990047454834, 0.6006994843482971, 0.598... | [0.5857220888137817, 0.580162525177002, 0.5746... | [0.7438016533851624, 0.7479338645935059, 0.760... | [0.8360655903816223, 0.8360655903816223, 0.836... | 18 | 3003 | 102.724646 |
19 | linear | 51 | 0.012842 | 0.000041 | 96 | 81 | 0.803279 | 19 | 103.713039 | [0.6344509720802307, 0.6267118453979492, 0.626... | [0.6352565884590149, 0.6293896436691284, 0.623... | [0.6900826692581177, 0.6983470916748047, 0.710... | [0.6557376980781555, 0.6557376980781555, 0.655... | 18 | 3079 | 105.967294 |
20 | relu | 171 | 0.508927 | 0.001910 | 10 | 83 | 0.770492 | 20 | 106.621811 | [0.7741369009017944, 0.6997448801994324, 0.641... | [0.6997786164283752, 0.6350088119506836, 0.582... | [0.46694216132164, 0.557851254940033, 0.648760... | [0.5737704634666443, 0.6557376980781555, 0.737... | 18 | 3155 | 107.922894 |
21 | gelu | 247 | 0.291055 | 0.000012 | 14 | 57 | 0.393443 | 21 | 108.269415 | [0.7637456655502319, 0.7536334991455078, 0.767... | [0.770416259765625, 0.7701885104179382, 0.7699... | [0.4586776793003082, 0.42975205183029175, 0.44... | [0.39344263076782227, 0.39344263076782227, 0.3... | 18 | 2167 | 109.827998 |
22 | elu | 247 | 0.330785 | 0.000270 | 32 | 57 | 0.754098 | 22 | 110.438860 | [0.7152367234230042, 0.7155708074569702, 0.698... | [0.7250485420227051, 0.7178604006767273, 0.710... | [0.5123966932296753, 0.4958677589893341, 0.570... | [0.44262295961380005, 0.44262295961380005, 0.4... | 18 | 2167 | 111.909659 |
23 | gelu | 256 | 0.174206 | 0.001193 | 22 | 54 | 0.836066 | 23 | 112.662811 | [0.7355976700782776, 0.7147685885429382, 0.687... | [0.7190053462982178, 0.6970692276954651, 0.676... | [0.44214877486228943, 0.4834710657596588, 0.55... | [0.3606557250022888, 0.4754098355770111, 0.573... | 18 | 2053 | 114.152527 |
24 | elu | 50 | 0.012164 | 0.000047 | 96 | 73 | 0.786885 | 24 | 115.373385 | [0.7974379062652588, 0.7888069748878479, 0.779... | [0.8222996592521667, 0.8145594596862793, 0.806... | [0.39669421315193176, 0.40082645416259766, 0.4... | [0.37704917788505554, 0.39344263076782227, 0.3... | 18 | 2775 | 117.593125 |
25 | hard_sigmoid | 244 | 0.215005 | 0.000582 | 13 | 56 | 0.229508 | 25 | 118.166421 | [0.8906719088554382, 0.9096777439117432, 0.914... | [0.9311334490776062, 0.9191083312034607, 0.907... | [0.3140496015548706, 0.28925618529319763, 0.30... | [0.2295081913471222, 0.2295081913471222, 0.229... | 18 | 2129 | 119.727548 |
26 | hard_sigmoid | 228 | 0.195278 | 0.001142 | 16 | 48 | 0.770492 | 26 | 120.442391 | [0.7672627568244934, 0.7242721319198608, 0.727... | [0.7250207662582397, 0.701765239238739, 0.6803... | [0.3636363744735718, 0.43801653385162354, 0.45... | [0.26229506731033325, 0.4590163826942444, 0.62... | 18 | 1825 | 121.826843 |
27 | selu | 160 | 0.546215 | 0.000789 | 13 | 11 | 0.688525 | 27 | 122.283213 | [0.7485548257827759, 0.7522017955780029, 0.763... | [0.7076205015182495, 0.6924997568130493, 0.678... | [0.5950413346290588, 0.586776852607727, 0.6157... | [0.6065573692321777, 0.6065573692321777, 0.622... | 18 | 419 | 123.619181 |
28 | linear | 8 | 0.526041 | 0.000060 | 13 | 82 | 0.786885 | 28 | 124.196094 | [0.8352355360984802, 0.7898195385932922, 0.765... | [0.7591118812561035, 0.7175912857055664, 0.682... | [0.43388429284095764, 0.5206611752510071, 0.49... | [0.4590163826942444, 0.5245901346206665, 0.606... | 18 | 3117 | 125.941018 |
29 | gelu | 256 | 0.188568 | 0.001515 | 21 | 8 | 0.737705 | 29 | 126.460014 | [0.7105452418327332, 0.7035156488418579, 0.698... | [0.7025739550590515, 0.696033239364624, 0.6896... | [0.5537189841270447, 0.5371900796890259, 0.574... | [0.6229507923126221, 0.6393442749977112, 0.639... | 18 | 305 | 128.180563 |
30 | gelu | 256 | 0.244248 | 0.001660 | 23 | 56 | 0.786885 | 30 | 128.727047 | [0.6546282768249512, 0.6317970752716064, 0.637... | [0.572919487953186, 0.556894838809967, 0.54188... | [0.6818181872367859, 0.7066115736961365, 0.706... | [0.7540983557701111, 0.7540983557701111, 0.754... | 18 | 2129 | 130.153736 |
31 | linear | 224 | 0.305334 | 0.000660 | 44 | 61 | 0.819672 | 31 | 130.919080 | [0.9484014511108398, 0.9179948568344116, 0.877... | [0.924122154712677, 0.8768399953842163, 0.8322... | [0.32231405377388, 0.28925618529319763, 0.3512... | [0.19672131538391113, 0.19672131538391113, 0.2... | 18 | 2319 | 132.589723 |
32 | linear | 50 | 0.479327 | 0.000068 | 96 | 17 | 0.737705 | 32 | 133.278981 | [0.9485551118850708, 0.9460268616676331, 0.900... | [0.8557782769203186, 0.8498371243476868, 0.843... | [0.5123966932296753, 0.5247933864593506, 0.483... | [0.5245901346206665, 0.5245901346206665, 0.524... | 18 | 647 | 135.438342 |
33 | selu | 164 | 0.561031 | 0.000726 | 14 | 55 | 0.819672 | 33 | 136.143825 | [0.8745442628860474, 0.8545739054679871, 0.796... | [0.7864746451377869, 0.7405412793159485, 0.699... | [0.4752066135406494, 0.5, 0.5247933864593506, ... | [0.4262295067310333, 0.5081967115402222, 0.606... | 18 | 2091 | 137.718102 |
34 | selu | 254 | 0.322091 | 0.001231 | 12 | 47 | 0.819672 | 34 | 138.606087 | [0.7325422763824463, 0.7275435328483582, 0.714... | [0.6849062442779541, 0.6536070108413696, 0.624... | [0.6115702390670776, 0.56611567735672, 0.59090... | [0.5573770403862, 0.5901639461517334, 0.622950... | 18 | 1787 | 140.006373 |
35 | gelu | 96 | 0.235332 | 0.000678 | 35 | 49 | 0.819672 | 35 | 140.362036 | [0.6482117176055908, 0.640116274356842, 0.6336... | [0.6331732869148254, 0.6124953031539917, 0.593... | [0.64462810754776, 0.6776859760284424, 0.70247... | [0.7377049326896667, 0.7049180269241333, 0.721... | 18 | 1863 | 141.949003 |
36 | linear | 50 | 0.327453 | 0.000119 | 96 | 81 | 0.819672 | 36 | 142.978975 | [0.7736309170722961, 0.6837102770805359, 0.716... | [0.6871227622032166, 0.6735551953315735, 0.660... | [0.6033057570457458, 0.6776859760284424, 0.623... | [0.7213114500045776, 0.7213114500045776, 0.770... | 18 | 3079 | 145.525915 |
37 | gelu | 134 | 0.098413 | 0.000580 | 22 | 54 | 0.786885 | 37 | 146.242206 | [0.6934930682182312, 0.6824401617050171, 0.672... | [0.7072174549102783, 0.6886430382728577, 0.670... | [0.4793388545513153, 0.5206611752510071, 0.582... | [0.4590163826942444, 0.5245901346206665, 0.655... | 18 | 2053 | 147.982779 |
38 | hard_sigmoid | 244 | 0.317880 | 0.001168 | 45 | 53 | 0.786885 | 38 | 148.651484 | [0.9446930289268494, 0.9471455216407776, 0.897... | [0.9400205612182617, 0.9183498620986938, 0.897... | [0.3140496015548706, 0.3305785059928894, 0.322... | [0.2295081913471222, 0.2295081913471222, 0.229... | 18 | 2015 | 150.211455 |
39 | linear | 11 | 0.539959 | 0.000073 | 71 | 42 | 0.836066 | 39 | 151.001553 | [0.7638189792633057, 0.768632173538208, 0.8030... | [0.7098137140274048, 0.6863061785697937, 0.662... | [0.5289255976676941, 0.5, 0.5289255976676941, ... | [0.5573770403862, 0.5573770403862, 0.557377040... | 18 | 1597 | 153.791880 |
40 | linear | 9 | 0.594717 | 0.000072 | 75 | 22 | 0.836066 | 40 | 154.191006 | [0.9886115193367004, 0.8805288672447205, 0.913... | [0.8028327226638794, 0.7799118161201477, 0.759... | [0.45041322708129883, 0.4958677589893341, 0.45... | [0.49180328845977783, 0.5245901346206665, 0.54... | 18 | 837 | 157.699540 |
41 | linear | 9 | 0.583705 | 0.000074 | 70 | 11 | 0.704918 | 41 | 158.364740 | [1.2854955196380615, 1.2296828031539917, 1.140... | [1.1728463172912598, 1.1491509675979614, 1.125... | [0.2975206673145294, 0.3429751992225647, 0.371... | [0.2295081913471222, 0.26229506731033325, 0.27... | 18 | 419 | 161.862414 |
42 | linear | 10 | 0.579446 | 0.000080 | 70 | 21 | 0.819672 | 42 | 162.550486 | [0.7621033191680908, 0.7450456023216248, 0.790... | [0.6448240280151367, 0.630098819732666, 0.6165... | [0.5619834661483765, 0.5495867729187012, 0.570... | [0.5901639461517334, 0.6065573692321777, 0.639... | 18 | 799 | 165.567682 |
43 | linear | 11 | 0.542695 | 0.000010 | 72 | 41 | 0.704918 | 43 | 166.191566 | [0.7994367480278015, 0.7612443566322327, 0.755... | [0.701274037361145, 0.6975817084312439, 0.6937... | [0.5082644820213318, 0.5454545617103577, 0.557... | [0.5245901346206665, 0.5245901346206665, 0.540... | 18 | 1559 | 169.025697 |
44 | hard_sigmoid | 9 | 0.531649 | 0.000072 | 72 | 56 | 0.803279 | 44 | 169.882809 | [0.6588014364242554, 0.6326671242713928, 0.658... | [0.5621706247329712, 0.5559288263320923, 0.550... | [0.6115702390670776, 0.6652892827987671, 0.657... | [0.7704917788505554, 0.7704917788505554, 0.770... | 18 | 2129 | 172.994597 |
45 | selu | 149 | 0.537188 | 0.001174 | 14 | 30 | 0.836066 | 45 | 173.636141 | [0.9731682538986206, 0.9632638692855835, 0.834... | [0.7772341370582581, 0.7184216976165771, 0.666... | [0.41322314739227295, 0.4586776793003082, 0.50... | [0.4262295067310333, 0.49180328845977783, 0.59... | 18 | 1141 | 175.037412 |
46 | linear | 49 | 0.011794 | 0.000046 | 96 | 69 | 0.754098 | 46 | 175.385063 | [0.9058528542518616, 0.8961911797523499, 0.890... | [0.9788206815719604, 0.9681548476219177, 0.957... | [0.3677685856819153, 0.3595041334629059, 0.371... | [0.26229506731033325, 0.26229506731033325, 0.2... | 18 | 2623 | 177.826115 |
47 | selu | 157 | 0.534877 | 0.001165 | 16 | 23 | 0.803279 | 47 | 178.533461 | [1.0175780057907104, 0.9338928461074829, 0.922... | [0.9525404572486877, 0.9009444713592529, 0.852... | [0.42148759961128235, 0.46694216132164, 0.4710... | [0.32786884903907776, 0.4098360538482666, 0.45... | 18 | 875 | 179.892400 |
48 | gelu | 86 | 0.196295 | 0.002204 | 11 | 54 | 0.819672 | 48 | 180.589500 | [0.5843045711517334, 0.5370311141014099, 0.489... | [0.5310717821121216, 0.48231208324432373, 0.44... | [0.7355371713638306, 0.7685950398445129, 0.789... | [0.7868852615356445, 0.7868852615356445, 0.754... | 18 | 2053 | 181.977024 |
49 | linear | 9 | 0.591808 | 0.000048 | 80 | 20 | 0.819672 | 49 | 182.668562 | [1.0278195142745972, 1.0783154964447021, 1.051... | [0.9265504479408264, 0.9104623198509216, 0.894... | [0.42975205183029175, 0.37603306770324707, 0.3... | [0.3606557250022888, 0.3606557250022888, 0.377... | 18 | 761 | 185.918045 |
50 | linear | 9 | 0.591808 | 0.000048 | 80 | 20 | 0.770492 | 50 | 306.613144 | [0.7385859489440918, 0.6464691758155823, 0.692... | [0.61632239818573, 0.6080429553985596, 0.60057... | [0.5909090638160706, 0.6570248007774353, 0.590... | [0.6721311211585999, 0.6721311211585999, 0.704... | 18 | 761 | 309.925543 |
51 | elu | 79 | 0.210961 | 0.002270 | 29 | 53 | 0.803279 | 51 | 310.659269 | [0.7370269298553467, 0.6033796668052673, 0.516... | [0.6075996160507202, 0.5017809271812439, 0.442... | [0.42975205183029175, 0.6570248007774353, 0.76... | [0.7213114500045776, 0.8360655903816223, 0.852... | 18 | 2015 | 312.720539 |
52 | softplus | 148 | 0.547227 | 0.001533 | 13 | 30 | 0.770492 | 52 | 313.843650 | [1.0147786140441895, 0.9364703297615051, 0.922... | [0.8291641473770142, 0.7705664038658142, 0.717... | [0.44628098607063293, 0.4586776793003082, 0.49... | [0.3606557250022888, 0.4754098355770111, 0.590... | 18 | 1141 | 315.235044 |
53 | linear | 18 | 0.574718 | 0.000072 | 74 | 42 | 0.836066 | 53 | 315.653283 | [0.7922735810279846, 0.7768206596374512, 0.744... | [0.6496352553367615, 0.6359923481941223, 0.622... | [0.5206611752510071, 0.5454545617103577, 0.541... | [0.6229507923126221, 0.6065573692321777, 0.655... | 18 | 1597 | 318.200261 |
54 | linear | 11 | 0.581862 | 0.000073 | 96 | 42 | 0.836066 | 54 | 319.290598 | [0.8693579435348511, 0.8212579488754272, 0.779... | [0.8195937275886536, 0.7894597053527832, 0.761... | [0.4545454680919647, 0.5247933864593506, 0.491... | [0.44262295961380005, 0.44262295961380005, 0.4... | 18 | 1597 | 322.698709 |
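To look only at the evaluations added by the resumed call, one option is to filter the results on `job_id`. The following is a minimal sketch, assuming a DataFrame with the same `job_id` and `objective` columns as `results`. The small frame `results_demo` is a stand-in built from a few values copied from the table above, and `n_previous` is a hypothetical name for the number of evaluations completed before resuming.

```python
import pandas as pd

# Stand-in for `results`: (job_id, objective) pairs copied from the table above.
results_demo = pd.DataFrame(
    {
        "job_id": [48, 49, 50, 51, 52, 53, 54],
        "objective": [0.819672, 0.819672, 0.770492, 0.803279, 0.770492, 0.836066, 0.836066],
    }
)

# Hypothetical name: number of evaluations finished before the search was resumed.
n_previous = 50

# Keep only the jobs submitted by the second call to `search`.
new_jobs = results_demo[results_demo["job_id"] >= n_previous]
print(new_jobs)
print(f"Best objective among new evaluations: {new_jobs['objective'].max():.3f}")
```

With the real `results` DataFrame, the same filter should apply directly, since `job_id` increases with each submitted evaluation in the table above.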
Now that the search is over, we can look at how the performance of the explored models evolved.
[34]:
import matplotlib.pyplot as plt
from deephyper.analysis.hpo import plot_search_trajectory_single_objective_hpo
fig, ax = plot_search_trajectory_single_objective_hpo(results)
plt.title("Search Trajectory")
plt.show()
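A search trajectory is often read through the best objective found so far. As a rough illustration (not necessarily what `plot_search_trajectory_single_objective_hpo` computes internally), the running best can be obtained with a cumulative maximum. The sketch below uses the first six objective values copied from the results table above; `objective` and `running_best` are names chosen for this example.

```python
import pandas as pd

# First six objective values copied from the results table above.
objective = pd.Series([0.803279, 0.803279, 0.836066, 0.836066, 0.852459, 0.770492])

# Best objective observed up to and including each evaluation.
running_best = objective.cummax()
print(running_best.tolist())
```

Applied to the real data, `results["objective"].cummax()` would give the same kind of curve, which never decreases even when a later evaluation scores lower.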
Then, we can look at the best configuration.
[49]:
i_max = results.objective.argmax()
best_job = results.iloc[i_max].to_dict()
print(f"The default configuration has an accuracy of {objective_default:.3f}. \n"
f"The best configuration found by DeepHyper has an accuracy of {results['objective'].iloc[i_max]:.3f}, \n"
f"discovered after {results['m:timestamp_gather'].iloc[i_max]:.2f} seconds of search.\n")
best_job
The default configuration has an accuracy of 0.803.
The best configuration found by DeepHyper has an accuracy of 0.852,
discovered after 65.51 seconds of search.
[49]:
{'p:activation': 'linear',
'p:batch_size': 50,
'p:dropout_rate': 0.0113359066468359,
'p:learning_rate': 4.570626021696808e-05,
'p:num_epochs': 96,
'p:units': 79,
'objective': 0.8524590134620667,
'job_id': 4,
'm:timestamp_submit': 63.14429593086243,
'm:loss': '[0.7343717813491821, 0.7315843105316162, 0.7203551530838013, 0.7077977657318115, 0.7045767903327942, 0.6942116022109985, 0.6890315413475037, 0.684148907661438, 0.6750453114509583, 0.6661133170127869, 0.6599052548408508, 0.6488535404205322, 0.6500421166419983, 0.646197497844696, 0.6356723308563232, 0.6293549537658691, 0.6263395547866821, 0.6230587363243103, 0.6143192648887634, 0.6077391505241394, 0.6039666533470154, 0.599918782711029, 0.5972365736961365, 0.5854122638702393, 0.5852221846580505, 0.5819635391235352, 0.5733869671821594, 0.5693132281303406, 0.5666898488998413, 0.5630242228507996, 0.5568211078643799, 0.5544881820678711, 0.5535055994987488, 0.5494033694267273, 0.544603705406189, 0.5400131940841675, 0.5392252802848816, 0.5353421568870544, 0.531427800655365, 0.5265730023384094, 0.5229849815368652, 0.522513747215271, 0.5185511708259583, 0.510654091835022, 0.510829508304596, 0.5096814036369324, 0.5069150328636169, 0.5036590695381165, 0.4942742884159088, 0.4953521490097046, 0.4973622262477875, 0.4932021498680115, 0.48927953839302063, 0.48885059356689453, 0.4867109954357147, 0.4817766845226288, 0.4790145754814148, 0.48116615414619446, 0.47567224502563477, 0.47586676478385925, 0.474057674407959, 0.4685772955417633, 0.46800220012664795, 0.46771240234375, 0.46357816457748413, 0.4627784490585327, 0.46076714992523193, 0.46169304847717285, 0.45370250940322876, 0.4523642659187317, 0.452671080827713, 0.45055699348449707, 0.4527966380119324, 0.4467359185218811, 0.4463013708591461, 0.44568803906440735, 0.44608384370803833, 0.44155585765838623, 0.4365282952785492, 0.43844056129455566, 0.4361962378025055, 0.43178996443748474, 0.43474602699279785, 0.4314608871936798, 0.4278643727302551, 0.4289937913417816, 0.42474204301834106, 0.4234481155872345, 0.42462724447250366, 0.4230608344078064, 0.42143714427948, 0.4172115623950958, 0.4172593951225281, 0.41717463731765747, 0.41767194867134094, 0.4153216481208801]',
'm:val_loss': '[0.753821849822998, 0.7444661855697632, 0.7352833151817322, 0.7264078855514526, 0.7177568674087524, 0.709143877029419, 0.7008553147315979, 0.6926632523536682, 0.6846651434898376, 0.6767299175262451, 0.6690924763679504, 0.6618218421936035, 0.6545751690864563, 0.6477261185646057, 0.6409757137298584, 0.6344925761222839, 0.6279751658439636, 0.6217548251152039, 0.6155738234519958, 0.6095430850982666, 0.6038309335708618, 0.5982047319412231, 0.5927146673202515, 0.5874735116958618, 0.582288920879364, 0.5773451924324036, 0.5724054574966431, 0.5675630569458008, 0.562959611415863, 0.5584943294525146, 0.5539917349815369, 0.5496424436569214, 0.545595645904541, 0.5415534377098083, 0.537651002407074, 0.5336986780166626, 0.5300918221473694, 0.526334822177887, 0.5227644443511963, 0.5193122029304504, 0.5160749554634094, 0.5127428770065308, 0.5095983743667603, 0.5065170526504517, 0.5034927129745483, 0.5005301833152771, 0.49770525097846985, 0.4948558509349823, 0.4920128583908081, 0.4893036186695099, 0.4867943823337555, 0.48432719707489014, 0.4818807542324066, 0.4794256389141083, 0.47698256373405457, 0.4746597409248352, 0.47240322828292847, 0.47020772099494934, 0.46811091899871826, 0.4660354256629944, 0.46405380964279175, 0.46211710572242737, 0.4601345956325531, 0.4583086669445038, 0.4563702642917633, 0.4545257091522217, 0.4527909457683563, 0.4510558843612671, 0.4493982195854187, 0.44778916239738464, 0.4462117552757263, 0.4445821940898895, 0.44306841492652893, 0.4415777623653412, 0.4401895999908447, 0.4388270080089569, 0.43738314509391785, 0.4360154867172241, 0.43456223607063293, 0.4332718551158905, 0.4320412278175354, 0.4307521879673004, 0.4295240044593811, 0.42836472392082214, 0.4272339642047882, 0.42612510919570923, 0.4250316321849823, 0.42397260665893555, 0.42292481660842896, 0.4219026267528534, 0.42088067531585693, 0.41994309425354004, 0.41888922452926636, 0.41796427965164185, 0.4170234501361847, 0.4160982072353363]',
'm:accuracy': '[0.4793388545513153, 0.4958677589893341, 0.5371900796890259, 0.5413222908973694, 0.5743801593780518, 0.5743801593780518, 0.586776852607727, 0.5743801593780518, 0.5950413346290588, 0.6115702390670776, 0.6363636255264282, 0.64462810754776, 0.6404958963394165, 0.6363636255264282, 0.6611570119857788, 0.6735537052154541, 0.6611570119857788, 0.6776859760284424, 0.6776859760284424, 0.6818181872367859, 0.6818181872367859, 0.6818181872367859, 0.6942148804664612, 0.702479362487793, 0.71074378490448, 0.71074378490448, 0.71074378490448, 0.7272727489471436, 0.7355371713638306, 0.7396694421768188, 0.7479338645935059, 0.7479338645935059, 0.7396694421768188, 0.7396694421768188, 0.7355371713638306, 0.7685950398445129, 0.7644628286361694, 0.7561983466148376, 0.7644628286361694, 0.7851239442825317, 0.78925621509552, 0.78925621509552, 0.797520637512207, 0.797520637512207, 0.7933884263038635, 0.7809917330741882, 0.797520637512207, 0.797520637512207, 0.7933884263038635, 0.797520637512207, 0.797520637512207, 0.7933884263038635, 0.797520637512207, 0.8016529083251953, 0.78925621509552, 0.8057851195335388, 0.8140496015548706, 0.8099173307418823, 0.8016529083251953, 0.8016529083251953, 0.8140496015548706, 0.8099173307418823, 0.8181818127632141, 0.78925621509552, 0.8223140239715576, 0.8016529083251953, 0.8181818127632141, 0.8140496015548706, 0.8181818127632141, 0.8140496015548706, 0.8099173307418823, 0.8140496015548706, 0.8099173307418823, 0.8140496015548706, 0.8016529083251953, 0.8099173307418823, 0.8140496015548706, 0.8181818127632141, 0.8181818127632141, 0.8140496015548706, 0.8140496015548706, 0.8223140239715576, 0.8140496015548706, 0.8223140239715576, 0.8181818127632141, 0.8181818127632141, 0.8223140239715576, 0.8223140239715576, 0.8140496015548706, 0.8223140239715576, 0.8305785059928894, 0.8388429880142212, 0.8099173307418823, 0.8181818127632141, 0.8223140239715576, 0.8181818127632141]',
'm:val_accuracy': '[0.39344263076782227, 0.44262295961380005, 0.4590163826942444, 0.4590163826942444, 0.49180328845977783, 0.49180328845977783, 0.5245901346206665, 0.5409836173057556, 0.5409836173057556, 0.5245901346206665, 0.5409836173057556, 0.5737704634666443, 0.5901639461517334, 0.5901639461517334, 0.6229507923126221, 0.6229507923126221, 0.6229507923126221, 0.6229507923126221, 0.6393442749977112, 0.6393442749977112, 0.6393442749977112, 0.6557376980781555, 0.6721311211585999, 0.6721311211585999, 0.688524603843689, 0.688524603843689, 0.7049180269241333, 0.7213114500045776, 0.7213114500045776, 0.7213114500045776, 0.7377049326896667, 0.7377049326896667, 0.7377049326896667, 0.7377049326896667, 0.7213114500045776, 0.7213114500045776, 0.7213114500045776, 0.7213114500045776, 0.7540983557701111, 0.7540983557701111, 0.7540983557701111, 0.7540983557701111, 0.7540983557701111, 0.7540983557701111, 0.7540983557701111, 0.7540983557701111, 0.7704917788505554, 0.7704917788505554, 0.7704917788505554, 0.7704917788505554, 0.7704917788505554, 0.7704917788505554, 0.7868852615356445, 0.7868852615356445, 0.7868852615356445, 0.7868852615356445, 0.7868852615356445, 0.7868852615356445, 0.7868852615356445, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667]',
'm:num_parameters': 18,
'm:num_parameters_train': 3003,
'm:timestamp_gather': 65.50817203521729}
[51]:
import matplotlib.pyplot as plt
plt.figure()
plt.plot(json.loads(metadata_default["val_accuracy"]), color="skyblue", label="Default (val)")
plt.plot(json.loads(metadata_default["accuracy"]), color="skyblue", linestyle="--", label="Default (train)")
plt.plot(json.loads(best_job["m:val_accuracy"]), color="coral", linewidth=2, label="Best Job (val)")
plt.plot(json.loads(best_job["m:accuracy"]), color="coral", linestyle="--", linewidth=2, label="Best Job (train)")
plt.legend()
plt.ylim(0.3, 0.9)
plt.grid()
plt.show()
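The learning curves plotted above are stored in the results as JSON strings, which is why json.loads is applied before plotting. The following minimal sketch shows the decoding step on a shortened, hand-written stand-in for one result row (the real lists contain one value per epoch). It also checks that the objective is the validation accuracy of the last epoch, as returned by the run-function.

```python
import json

# Shortened stand-in for one results row; the real lists have one value per epoch.
job = {
    "objective": 0.8524590134620667,
    "m:val_accuracy": "[0.5409836173057556, 0.5573770403862, 0.8524590134620667]",
}

# The metadata value is a string and must be decoded into a list of floats.
val_accuracy = json.loads(job["m:val_accuracy"])

# Index of the epoch with the highest validation accuracy (0-based).
best_epoch = max(range(len(val_accuracy)), key=val_accuracy.__getitem__)

print(type(job["m:val_accuracy"]).__name__, "->", type(val_accuracy).__name__)
print(f"epochs recorded: {len(val_accuracy)}")
print(f"best epoch: {best_epoch}, val_accuracy={val_accuracy[best_epoch]:.3f}")
print(f"objective is the last epoch value: {job['objective'] == val_accuracy[-1]}")
```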
2.9. Restart from a checkpoint#
It is often useful to continue a search from previous results, for example when the requested allocation was too short or an unexpected crash interrupted the run. The CBO search provides the fit_surrogate(dataframe_of_results) method for this use case.
To simulate this, we create a second evaluator, evaluator_2, and start a fresh CBO search with strong exploitation (kappa=0.001).
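To see why a small kappa means strong exploitation, consider a toy upper-confidence-bound score, mean + kappa * std, computed for three candidates. This is only a simplified sketch: the surrogate predictions are made up, and the `ucb` and `pick` helpers are hypothetical names, not DeepHyper's internal implementation.

```python
# Hypothetical surrogate predictions for three candidate configurations:
# (name, predicted mean accuracy, predicted standard deviation)
candidates = [
    ("near_best_known", 0.85, 0.01),
    ("moderately_explored", 0.80, 0.05),
    ("unexplored_region", 0.70, 0.15),
]


def ucb(mean, std, kappa):
    """Upper confidence bound score for a maximization problem."""
    return mean + kappa * std


def pick(kappa):
    """Return the name of the candidate with the highest UCB score."""
    return max(candidates, key=lambda c: ucb(c[1], c[2], kappa))[0]


for kappa in (0.001, 1.96):
    print(f"kappa={kappa}: {pick(kappa)}")
```

With kappa=0.001 the uncertainty term is negligible and the candidate with the highest predicted mean is selected. With kappa=1.96 the most uncertain candidate wins instead.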
[42]:
# Create a new evaluator
evaluator_2 = get_evaluator(run)
# Create a new CBO search with strong exploitation (i.e., small kappa)
search_from_checkpoint = CBO(problem, evaluator_2, kappa=0.001, acq_optimizer="mixedga", acq_optimizer_freq=1)
# Initialize the surrogate model of Bayesian optimization
# with the results of the previous search
search_from_checkpoint.fit_surrogate(results)
WARNING:root:Results file already exists, it will be renamed to /Users/romainegele/Documents/Argonne/deephyper-tutorials/tutorials/colab/HPS_basic_classification_with_tabular_data/results_20240805-185838.csv
Created new evaluator with 1 worker and config: {'num_cpus': 1, 'num_cpus_per_task': 1, 'callbacks': [<deephyper.evaluator.callback.TqdmCallback object at 0x3e6806c60>]}
[43]:
results_from_checkpoint = search_from_checkpoint.search(max_evals=25)
[44]:
results_from_checkpoint
[44]:
| | p:activation | p:batch_size | p:dropout_rate | p:learning_rate | p:num_epochs | p:units | objective | job_id | m:timestamp_submit | m:loss | m:val_loss | m:accuracy | m:val_accuracy | m:num_parameters | m:num_parameters_train | m:timestamp_gather |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | linear | 11 | 0.580120 | 0.000073 | 95 | 42 | 0.852459 | 0 | 4.672708 | [0.8041970729827881, 0.7837277054786682, 0.789... | [0.760534405708313, 0.736428439617157, 0.71280... | [0.5247933864593506, 0.5, 0.5289255976676941, ... | [0.5409836173057556, 0.5573770403862, 0.540983... | 18 | 1597 | 7.694214 |
1 | swish | 11 | 0.580091 | 0.000073 | 95 | 42 | 0.803279 | 1 | 8.683520 | [0.7580281496047974, 0.8090499639511108, 0.768... | [0.7929872870445251, 0.7761178612709045, 0.758... | [0.5206611752510071, 0.44628098607063293, 0.46... | [0.3606557250022888, 0.37704917788505554, 0.40... | 18 | 1597 | 12.092570 |
2 | linear | 12 | 0.580188 | 0.000036 | 95 | 50 | 0.786885 | 2 | 13.035506 | [1.0248744487762451, 1.0645800828933716, 0.993... | [0.9728235006332397, 0.9541761875152588, 0.937... | [0.3347107470035553, 0.3099173605442047, 0.326... | [0.24590164422988892, 0.26229506731033325, 0.2... | 18 | 1901 | 16.033196 |
3 | linear | 11 | 0.579794 | 0.000073 | 95 | 72 | 0.836066 | 3 | 16.548202 | [0.7747272849082947, 0.7421781420707703, 0.706... | [0.6965786218643188, 0.6639910936355591, 0.633... | [0.5, 0.5702479481697083, 0.5785123705863953, ... | [0.5573770403862, 0.6557376980781555, 0.737704... | 18 | 2737 | 19.609617 |
4 | linear | 19 | 0.580786 | 0.000068 | 95 | 43 | 0.803279 | 4 | 20.466129 | [0.8868724703788757, 0.8863998651504517, 0.783... | [0.7910361289978027, 0.7758963108062744, 0.761... | [0.43801653385162354, 0.4586776793003082, 0.53... | [0.4262295067310333, 0.4262295067310333, 0.442... | 18 | 1635 | 23.013769 |
5 | linear | 11 | 0.421636 | 0.000072 | 95 | 41 | 0.836066 | 5 | 23.887383 | [0.7451992034912109, 0.7273797988891602, 0.748... | [0.5972851514816284, 0.581959068775177, 0.5678... | [0.5702479481697083, 0.56611567735672, 0.57851... | [0.688524603843689, 0.7213114500045776, 0.7213... | 18 | 1559 | 27.051523 |
6 | linear | 11 | 0.001346 | 0.000073 | 93 | 27 | 0.852459 | 6 | 27.735532 | [0.720139741897583, 0.6974911689758301, 0.6783... | [0.664795994758606, 0.6413559913635254, 0.6196... | [0.5289255976676941, 0.557851254940033, 0.6033... | [0.5409836173057556, 0.5573770403862, 0.590163... | 18 | 1027 | 31.108865 |
7 | linear | 11 | 0.015196 | 0.000073 | 93 | 8 | 0.836066 | 7 | 31.640157 | [0.8719937205314636, 0.8656114339828491, 0.857... | [0.8745443820953369, 0.8604922890663147, 0.846... | [0.5041322112083435, 0.5123966932296753, 0.516... | [0.44262295961380005, 0.4590163826942444, 0.45... | 18 | 305 | 34.892566 |
8 | linear | 11 | 0.000146 | 0.000073 | 85 | 12 | 0.836066 | 8 | 35.989811 | [0.9494009613990784, 0.9227721095085144, 0.896... | [0.9361185431480408, 0.9073224663734436, 0.880... | [0.3057851195335388, 0.3057851195335388, 0.334... | [0.3606557250022888, 0.37704917788505554, 0.37... | 18 | 457 | 38.953501 |
9 | linear | 8 | 0.000002 | 0.000073 | 93 | 27 | 0.836066 | 9 | 39.846366 | [0.8590447902679443, 0.815822958946228, 0.7771... | [0.8581119179725647, 0.813581109046936, 0.7709... | [0.42975205183029175, 0.43801653385162354, 0.4... | [0.4754098355770111, 0.49180328845977783, 0.52... | 18 | 1027 | 43.352862 |
10 | swish | 9 | 0.001305 | 0.000074 | 30 | 28 | 0.803279 | 10 | 43.974943 | [0.8331950306892395, 0.8155596256256104, 0.798... | [0.8349722027778625, 0.8154217004776001, 0.796... | [0.3057851195335388, 0.3181818127632141, 0.330... | [0.26229506731033325, 0.2786885201931, 0.29508... | 18 | 1065 | 45.914327 |
11 | linear | 9 | 0.002338 | 0.000073 | 96 | 15 | 0.819672 | 11 | 46.717739 | [0.5619564056396484, 0.5478906035423279, 0.538... | [0.6027979254722595, 0.5921545624732971, 0.582... | [0.7272727489471436, 0.7396694421768188, 0.752... | [0.7213114500045776, 0.7049180269241333, 0.721... | 18 | 571 | 50.088631 |
12 | linear | 11 | 0.001406 | 0.000047 | 92 | 27 | 0.819672 | 12 | 51.044316 | [0.7591452598571777, 0.748119592666626, 0.7355... | [0.7010340094566345, 0.6882603168487549, 0.676... | [0.43801653385162354, 0.46694216132164, 0.4876... | [0.6065573692321777, 0.6065573692321777, 0.639... | 18 | 1027 | 54.554788 |
13 | linear | 11 | 0.001531 | 0.000073 | 93 | 65 | 0.819672 | 13 | 55.431078 | [0.7318987846374512, 0.6924680471420288, 0.658... | [0.7065127491950989, 0.6669818162918091, 0.632... | [0.5041322112083435, 0.5702479481697083, 0.619... | [0.5409836173057556, 0.5901639461517334, 0.688... | 18 | 2471 | 58.463386 |
14 | selu | 11 | 0.051216 | 0.000073 | 93 | 27 | 0.852459 | 14 | 59.775333 | [0.7977746725082397, 0.7680606245994568, 0.733... | [0.8183743357658386, 0.7818477153778076, 0.746... | [0.4793388545513153, 0.5206611752510071, 0.553... | [0.4754098355770111, 0.4754098355770111, 0.540... | 18 | 1027 | 62.897967 |
15 | selu | 11 | 0.599294 | 0.000072 | 93 | 9 | 0.819672 | 15 | 63.684615 | [0.7609134912490845, 0.8461434245109558, 0.769... | [0.6055606007575989, 0.5996733903884888, 0.594... | [0.56611567735672, 0.5041322112083435, 0.53305... | [0.6229507923126221, 0.6393442749977112, 0.639... | 18 | 343 | 66.736602 |
16 | linear | 12 | 0.002261 | 0.000073 | 93 | 27 | 0.819672 | 16 | 67.926674 | [0.724230170249939, 0.6960060596466064, 0.6756... | [0.7928970456123352, 0.7624304294586182, 0.734... | [0.5, 0.5454545617103577, 0.5950413346290588, ... | [0.4098360538482666, 0.44262295961380005, 0.47... | 18 | 1027 | 71.041350 |
17 | selu | 8 | 0.051312 | 0.000074 | 75 | 11 | 0.836066 | 17 | 71.875614 | [0.5699914693832397, 0.555569589138031, 0.5570... | [0.5412390828132629, 0.5306512117385864, 0.520... | [0.7685950398445129, 0.7603305578231812, 0.727... | [0.7213114500045776, 0.7540983557701111, 0.754... | 18 | 419 | 75.466510 |
18 | softplus | 11 | 0.074423 | 0.000073 | 94 | 29 | 0.770492 | 18 | 75.929389 | [0.6149121522903442, 0.6076364517211914, 0.600... | [0.6102054715156555, 0.5992701053619385, 0.589... | [0.71074378490448, 0.71074378490448, 0.6983470... | [0.7377049326896667, 0.7213114500045776, 0.721... | 18 | 1103 | 79.442036 |
19 | sigmoid | 11 | 0.047733 | 0.000083 | 93 | 27 | 0.803279 | 19 | 80.129976 | [0.7153695821762085, 0.7007089257240295, 0.698... | [0.5909143090248108, 0.5860286355018616, 0.581... | [0.7148760557174683, 0.7148760557174683, 0.714... | [0.7704917788505554, 0.7704917788505554, 0.770... | 18 | 1027 | 83.201642 |
20 | selu | 11 | 0.069503 | 0.000073 | 94 | 29 | 0.852459 | 20 | 84.463790 | [1.1732215881347656, 1.1224826574325562, 1.058... | [1.1363171339035034, 1.0872124433517456, 1.042... | [0.3264462947845459, 0.3677685856819153, 0.338... | [0.31147539615631104, 0.3606557250022888, 0.36... | 18 | 1103 | 87.813287 |
21 | selu | 11 | 0.105616 | 0.000073 | 100 | 29 | 0.836066 | 21 | 88.577887 | [0.5842286348342896, 0.5772862434387207, 0.562... | [0.4800376296043396, 0.47080090641975403, 0.46... | [0.7148760557174683, 0.71074378490448, 0.70661... | [0.8524590134620667, 0.8524590134620667, 0.868... | 18 | 1103 | 92.231860 |
22 | selu | 18 | 0.066401 | 0.000073 | 93 | 26 | 0.836066 | 22 | 93.099906 | [1.31602144241333, 1.320502758026123, 1.265641... | [1.3727202415466309, 1.3392741680145264, 1.307... | [0.2975206673145294, 0.2851239740848541, 0.293... | [0.26229506731033325, 0.2950819730758667, 0.29... | 18 | 989 | 95.827970 |
23 | selu | 17 | 0.057178 | 0.000073 | 70 | 12 | 0.786885 | 23 | 97.016439 | [0.7276808023452759, 0.7200209498405457, 0.720... | [0.6873724460601807, 0.6807807087898254, 0.673... | [0.5123966932296753, 0.5206611752510071, 0.537... | [0.6229507923126221, 0.6393442749977112, 0.639... | 18 | 457 | 99.855316 |
24 | selu | 11 | 0.058280 | 0.000081 | 94 | 27 | 0.803279 | 24 | 100.785771 | [0.79051274061203, 0.7398812770843506, 0.72031... | [0.8464049696922302, 0.811725378036499, 0.7795... | [0.42561984062194824, 0.5165289044380188, 0.52... | [0.37704917788505554, 0.4098360538482666, 0.45... | 18 | 1027 | 103.862200 |
[52]:
i_max = results_from_checkpoint.objective.argmax()
best_job = results_from_checkpoint.iloc[i_max].to_dict()
print(f"The default configuration has an accuracy of {objective_default:.3f}. \n"
      f"The best configuration found by DeepHyper has an accuracy of {results_from_checkpoint['objective'].iloc[i_max]:.3f}, \n"
      f"discovered after {results_from_checkpoint['m:timestamp_gather'].iloc[i_max]:.2f} seconds of search.\n")
best_job
The default configuration has an accuracy of 0.803.
The best configuration found by DeepHyper has an accuracy of 0.852,
discovered after 7.69 seconds of search.
[52]:
{'p:activation': 'linear',
'p:batch_size': 11,
'p:dropout_rate': 0.5801200292603258,
'p:learning_rate': 7.270169218267874e-05,
'p:num_epochs': 95,
'p:units': 42,
'objective': 0.8524590134620667,
'job_id': 0,
'm:timestamp_submit': 4.672708034515381,
'm:loss': '[0.8041970729827881, 0.7837277054786682, 0.7897129058837891, 0.7576928734779358, 0.7714273929595947, 0.7486253380775452, 0.7247172594070435, 0.6684877872467041, 0.6340616941452026, 0.652440071105957, 0.6518965363502502, 0.6469079852104187, 0.6724880337715149, 0.5748777389526367, 0.5973073244094849, 0.5889704823493958, 0.6270110011100769, 0.541706919670105, 0.5298312902450562, 0.582248330116272, 0.49765029549598694, 0.5214836597442627, 0.47826045751571655, 0.5435822010040283, 0.5364992618560791, 0.5255090594291687, 0.4721006453037262, 0.4848361611366272, 0.49145030975341797, 0.5109071731567383, 0.5109598636627197, 0.4852293133735657, 0.47832751274108887, 0.48602572083473206, 0.48633140325546265, 0.44579175114631653, 0.4708606004714966, 0.4509105086326599, 0.4845390319824219, 0.4494284987449646, 0.43613725900650024, 0.434268057346344, 0.4538938105106354, 0.45544442534446716, 0.4280913472175598, 0.4577348232269287, 0.40012580156326294, 0.4429107904434204, 0.44394591450691223, 0.457697331905365, 0.4169984757900238, 0.42198798060417175, 0.4538344144821167, 0.4329397976398468, 0.408939391374588, 0.4074942469596863, 0.4367082118988037, 0.446547269821167, 0.4011521637439728, 0.40441566705703735, 0.36089783906936646, 0.4228549003601074, 0.38539719581604004, 0.41360726952552795, 0.3939768373966217, 0.39181941747665405, 0.38108402490615845, 0.34451887011528015, 0.4110186994075775, 0.3830220103263855, 0.3807251751422882, 0.3751121461391449, 0.3672732710838318, 0.39625462889671326, 0.39395615458488464, 0.37351155281066895, 0.381844699382782, 0.37353140115737915, 0.3986795246601105, 0.36958974599838257, 0.35957154631614685, 0.34944215416908264, 0.37648308277130127, 0.3538556694984436, 0.3683083951473236, 0.37798619270324707, 0.35513031482696533, 0.3352597653865814, 0.36496520042419434, 0.346407949924469, 0.3528498113155365, 0.34690144658088684, 0.3503521978855133, 0.3336299955844879, 0.39004188776016235]',
'm:val_loss': '[0.760534405708313, 0.736428439617157, 0.7128077149391174, 0.6911187171936035, 0.6703223586082458, 0.6510680317878723, 0.6329866051673889, 0.6171879768371582, 0.6040096879005432, 0.5901919007301331, 0.576133668422699, 0.5650199055671692, 0.5530554056167603, 0.5425328612327576, 0.5322731733322144, 0.5232126116752625, 0.5143499970436096, 0.5066238641738892, 0.4990280270576477, 0.4916485548019409, 0.4852895140647888, 0.47908690571784973, 0.473391592502594, 0.467697411775589, 0.4625626802444458, 0.4571584463119507, 0.4523298442363739, 0.44821926951408386, 0.44439101219177246, 0.44111281633377075, 0.4379344880580902, 0.43438616394996643, 0.4314509332180023, 0.4287097752094269, 0.4262154996395111, 0.42408379912376404, 0.42122986912727356, 0.41910144686698914, 0.41701099276542664, 0.41472113132476807, 0.4127712547779083, 0.410778284072876, 0.40907588601112366, 0.4074136018753052, 0.40630432963371277, 0.40454062819480896, 0.4030597507953644, 0.40178635716438293, 0.40064138174057007, 0.3992939889431, 0.39809659123420715, 0.3967132270336151, 0.3962016999721527, 0.3954685628414154, 0.39448660612106323, 0.3935595750808716, 0.39282742142677307, 0.39161983132362366, 0.3909461796283722, 0.38992923498153687, 0.3893953561782837, 0.3887447416782379, 0.3886417746543884, 0.3885132968425751, 0.3881435692310333, 0.3880539536476135, 0.3878744840621948, 0.3877726197242737, 0.3876495659351349, 0.3874644339084625, 0.3871561884880066, 0.38711845874786377, 0.38701578974723816, 0.38724857568740845, 0.3872767686843872, 0.38656359910964966, 0.38652437925338745, 0.38627997040748596, 0.3856770396232605, 0.3852933943271637, 0.3848276734352112, 0.38475126028060913, 0.38461777567863464, 0.3846655488014221, 0.3846093416213989, 0.384924054145813, 0.3849468231201172, 0.3846433460712433, 0.3851482570171356, 0.3851175606250763, 0.38545405864715576, 0.3855822682380676, 0.385833740234375, 0.38589683175086975, 0.38629570603370667]',
'm:accuracy': '[0.5247933864593506, 0.5, 0.5289255976676941, 0.5619834661483765, 0.5371900796890259, 0.5371900796890259, 0.5950413346290588, 0.6487603187561035, 0.6611570119857788, 0.6322314143180847, 0.6363636255264282, 0.6818181872367859, 0.6033057570457458, 0.7355371713638306, 0.6900826692581177, 0.6900826692581177, 0.6776859760284424, 0.7644628286361694, 0.7438016533851624, 0.702479362487793, 0.7520661354064941, 0.7066115736961365, 0.7768595218658447, 0.7231404781341553, 0.702479362487793, 0.7396694421768188, 0.7851239442825317, 0.7685950398445129, 0.7479338645935059, 0.7314049601554871, 0.7520661354064941, 0.7561983466148376, 0.7727272510528564, 0.7603305578231812, 0.7685950398445129, 0.8057851195335388, 0.7603305578231812, 0.797520637512207, 0.7479338645935059, 0.7768595218658447, 0.8057851195335388, 0.7809917330741882, 0.7768595218658447, 0.7603305578231812, 0.8181818127632141, 0.7603305578231812, 0.8057851195335388, 0.797520637512207, 0.7851239442825317, 0.7727272510528564, 0.78925621509552, 0.8140496015548706, 0.7520661354064941, 0.78925621509552, 0.797520637512207, 0.8223140239715576, 0.7851239442825317, 0.7685950398445129, 0.8140496015548706, 0.7933884263038635, 0.8264462947845459, 0.78925621509552, 0.8347107172012329, 0.8057851195335388, 0.8181818127632141, 0.8057851195335388, 0.8305785059928894, 0.8595041036605835, 0.8057851195335388, 0.8264462947845459, 0.8016529083251953, 0.8264462947845459, 0.8305785059928894, 0.8099173307418823, 0.8223140239715576, 0.8140496015548706, 0.8140496015548706, 0.8429751992225647, 0.78925621509552, 0.8264462947845459, 0.8429751992225647, 0.8471074104309082, 0.8181818127632141, 0.8471074104309082, 0.8140496015548706, 0.8016529083251953, 0.85537189245224, 0.8471074104309082, 0.8140496015548706, 0.8347107172012329, 0.8264462947845459, 0.8636363744735718, 0.8305785059928894, 0.85537189245224, 0.8223140239715576]',
'm:val_accuracy': '[0.5409836173057556, 0.5573770403862, 0.5409836173057556, 0.5573770403862, 0.5901639461517334, 0.5901639461517334, 0.5901639461517334, 0.6065573692321777, 0.6721311211585999, 0.688524603843689, 0.7049180269241333, 0.7049180269241333, 0.7049180269241333, 0.7213114500045776, 0.7704917788505554, 0.7540983557701111, 0.7868852615356445, 0.7868852615356445, 0.7868852615356445, 0.7868852615356445, 0.7868852615356445, 0.7868852615356445, 0.7868852615356445, 0.7868852615356445, 0.7868852615356445, 0.7868852615356445, 0.7868852615356445, 0.7868852615356445, 0.7868852615356445, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667]',
'm:num_parameters': 18,
'm:num_parameters_train': 1597,
'm:timestamp_gather': 7.694214105606079}
[53]:
import matplotlib.pyplot as plt
plt.figure()
plt.plot(json.loads(metadata_default["val_accuracy"]), color="skyblue", label="Default (val)")
plt.plot(json.loads(metadata_default["accuracy"]), color="skyblue", linestyle="--", label="Default (train)")
plt.plot(json.loads(best_job["m:val_accuracy"]), color="coral", linewidth=2, label="Best Job (val)")
plt.plot(json.loads(best_job["m:accuracy"]), color="coral", linestyle="--", linewidth=2, label="Best Job (train)")
plt.legend()
plt.ylim(0.5, 0.9)
plt.grid()
plt.show()
2.10. Add conditional hyperparameters#
Next, we let the search decide whether to add a second fully-connected layer. Only two new lines are needed in the run-function:
if config.get("dense_2", False):
    x = tfk.layers.Dense(config["dense_2:units"], activation=config["dense_2:activation"])(x)
[54]:
def run_with_condition(config: dict):
    tf.autograph.set_verbosity(0)
    train_dataframe, val_dataframe = load_data()
    train_ds = dataframe_to_dataset(train_dataframe)
    val_ds = dataframe_to_dataset(val_dataframe)
    train_ds = train_ds.batch(config["batch_size"])
    val_ds = val_ds.batch(config["batch_size"])
    # Categorical features encoded as integers
    sex = tfk.Input(shape=(1,), name="sex", dtype="int64")
    cp = tfk.Input(shape=(1,), name="cp", dtype="int64")
    fbs = tfk.Input(shape=(1,), name="fbs", dtype="int64")
    restecg = tfk.Input(shape=(1,), name="restecg", dtype="int64")
    exang = tfk.Input(shape=(1,), name="exang", dtype="int64")
    ca = tfk.Input(shape=(1,), name="ca", dtype="int64")
    # Categorical feature encoded as string
    thal = tfk.Input(shape=(1,), name="thal", dtype="string")
    # Numerical features
    age = tfk.Input(shape=(1,), name="age")
    trestbps = tfk.Input(shape=(1,), name="trestbps")
    chol = tfk.Input(shape=(1,), name="chol")
    thalach = tfk.Input(shape=(1,), name="thalach")
    oldpeak = tfk.Input(shape=(1,), name="oldpeak")
    slope = tfk.Input(shape=(1,), name="slope")
    all_inputs = [
        sex,
        cp,
        fbs,
        restecg,
        exang,
        ca,
        thal,
        age,
        trestbps,
        chol,
        thalach,
        oldpeak,
        slope,
    ]
    # Integer categorical features
    sex_encoded = encode_categorical_feature(sex, "sex", train_ds, False)
    cp_encoded = encode_categorical_feature(cp, "cp", train_ds, False)
    fbs_encoded = encode_categorical_feature(fbs, "fbs", train_ds, False)
    restecg_encoded = encode_categorical_feature(restecg, "restecg", train_ds, False)
    exang_encoded = encode_categorical_feature(exang, "exang", train_ds, False)
    ca_encoded = encode_categorical_feature(ca, "ca", train_ds, False)
    # String categorical features
    thal_encoded = encode_categorical_feature(thal, "thal", train_ds, True)
    # Numerical features
    age_encoded = encode_numerical_feature(age, "age", train_ds)
    trestbps_encoded = encode_numerical_feature(trestbps, "trestbps", train_ds)
    chol_encoded = encode_numerical_feature(chol, "chol", train_ds)
    thalach_encoded = encode_numerical_feature(thalach, "thalach", train_ds)
    oldpeak_encoded = encode_numerical_feature(oldpeak, "oldpeak", train_ds)
    slope_encoded = encode_numerical_feature(slope, "slope", train_ds)
    all_features = tfk.layers.concatenate(
        [
            sex_encoded,
            cp_encoded,
            fbs_encoded,
            restecg_encoded,
            exang_encoded,
            slope_encoded,
            ca_encoded,
            thal_encoded,
            age_encoded,
            trestbps_encoded,
            chol_encoded,
            thalach_encoded,
            oldpeak_encoded,
        ]
    )
    x = tfk.layers.Dense(config["units"], activation=config["activation"])(
        all_features
    )
    ### START - NEW LINES
    if config.get("dense_2", False):
        x = tfk.layers.Dense(config["dense_2:units"], activation=config["dense_2:activation"])(x)
    ### END - NEW LINES
    x = tfk.layers.Dropout(config["dropout_rate"])(x)
    output = tfk.layers.Dense(1, activation="sigmoid")(x)
    model = tfk.Model(all_inputs, output)
    optimizer = tfk.optimizers.Adam(learning_rate=config["learning_rate"])
    model.compile(optimizer, "binary_crossentropy", metrics=["accuracy"])
    history = model.fit(
        train_ds, epochs=config["num_epochs"], validation_data=val_ds, verbose=0
    )
    objective = history.history["val_accuracy"][-1]
    metadata = {
        "loss": history.history["loss"],
        "val_loss": history.history["val_loss"],
        "accuracy": history.history["accuracy"],
        "val_accuracy": history.history["val_accuracy"],
    }
    metadata = {k: json.dumps(v) for k, v in metadata.items()}
    metadata.update(count_params(model))
    return {"objective": objective, "metadata": metadata}
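The effect of the config.get("dense_2", False) guard on model size can be checked without TensorFlow. The sketch below counts the trainable weights of the Dense stack. The input width of 36 encoded features is an assumption inferred from the m:num_parameters_train values reported in the results (for example 1597 for units=42); it is not stated in the tutorial. `dense_params` and `trainable_params` are hypothetical helpers.

```python
N_FEATURES = 36  # assumption: inferred from the reported parameter counts


def dense_params(n_in, n_out):
    """Weights plus biases of a fully-connected layer."""
    return n_in * n_out + n_out


def trainable_params(config):
    """Count the trainable parameters of the Dense stack of the run-function."""
    width = config["units"]
    total = dense_params(N_FEATURES, width)
    # Same guard as in the model: a missing or False "dense_2" skips the layer.
    if config.get("dense_2", False):
        total += dense_params(width, config["dense_2:units"])
        width = config["dense_2:units"]
    # Dropout adds no parameters; the output layer has a single unit.
    return total + dense_params(width, 1)


print(trainable_params({"units": 42}))
print(trainable_params({"units": 100, "dense_2": False, "dense_2:units": 8}))
print(trainable_params({"units": 60, "dense_2": True, "dense_2:units": 105}))
```

Under this assumption the three calls give 1597, 3801 and 8731, which agree with the m:num_parameters_train values listed for the corresponding configurations in the result tables of this tutorial.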
To define conditional hyperparameters we use ConfigSpace. We declare dense_2:units and dense_2:activation as active only when dense_2 == True, using ConfigSpace's EqualsCondition. We then call problem_with_condition.add_condition(condition) to register each new condition with the HpProblem.
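The notion of an active hyperparameter can be sketched in plain Python, independently of ConfigSpace. The `CONDITIONS` mapping and the `active_parameters` helper are hypothetical names used only for illustration; they mimic the two conditions described above and are not part of DeepHyper or ConfigSpace.

```python
# child hyperparameter -> (parent hyperparameter, value that activates the child)
CONDITIONS = {
    "dense_2:units": ("dense_2", True),
    "dense_2:activation": ("dense_2", True),
}


def active_parameters(config):
    """Keep only the hyperparameters whose condition (if any) is satisfied."""
    active = {}
    for name, value in config.items():
        condition = CONDITIONS.get(name)
        if condition is None or config.get(condition[0]) == condition[1]:
            active[name] = value
    return active


config = {"units": 100, "dense_2": False, "dense_2:units": 8, "dense_2:activation": "elu"}
print(active_parameters(config))
```

In the search results of this tutorial, rows with dense_2 set to False still list values for the two child hyperparameters. The run-function never reads them, because the second layer is only built when dense_2 is True.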
[56]:
from ConfigSpace import EqualsCondition
# Define the hyperparameter problem
problem_with_condition = HpProblem()
# Define the same hyperparameters as before
problem_with_condition.add_hyperparameter((8, 128), "units")
problem_with_condition.add_hyperparameter(ACTIVATIONS, "activation")
problem_with_condition.add_hyperparameter((0.0, 0.6), "dropout_rate")
problem_with_condition.add_hyperparameter((10, 100), "num_epochs")
problem_with_condition.add_hyperparameter((8, 256, "log-uniform"), "batch_size")
problem_with_condition.add_hyperparameter((1e-5, 1e-2, "log-uniform"), "learning_rate")
# Add a new hyperparameter "dense_2 (bool)" to decide if a second fully-connected layer should be created
hp_dense_2 = problem_with_condition.add_hyperparameter([True, False], "dense_2")
hp_dense_2_units = problem_with_condition.add_hyperparameter((8, 128), "dense_2:units")
hp_dense_2_activation = problem_with_condition.add_hyperparameter(ACTIVATIONS, "dense_2:activation")
problem_with_condition.add_condition(EqualsCondition(hp_dense_2_units, hp_dense_2, True))
problem_with_condition.add_condition(EqualsCondition(hp_dense_2_activation, hp_dense_2, True))
problem_with_condition
[56]:
Configuration space object:
Hyperparameters:
activation, Type: Categorical, Choices: {elu, gelu, hard_sigmoid, linear, relu, selu, sigmoid, softplus, softsign, swish, tanh}, Default: elu
batch_size, Type: UniformInteger, Range: [8, 256], Default: 45, on log-scale
dense_2, Type: Categorical, Choices: {True, False}, Default: True
dense_2:activation, Type: Categorical, Choices: {elu, gelu, hard_sigmoid, linear, relu, selu, sigmoid, softplus, softsign, swish, tanh}, Default: elu
dense_2:units, Type: UniformInteger, Range: [8, 128], Default: 68
dropout_rate, Type: UniformFloat, Range: [0.0, 0.6], Default: 0.3
learning_rate, Type: UniformFloat, Range: [1e-05, 0.01], Default: 0.000316227766, on log-scale
num_epochs, Type: UniformInteger, Range: [10, 100], Default: 55
units, Type: UniformInteger, Range: [8, 128], Default: 68
Conditions:
dense_2:activation | dense_2 == True
dense_2:units | dense_2 == True
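The numerical defaults printed above are consistent with taking the midpoint of each range: the arithmetic mean for uniform ranges and the geometric mean for log-scale ranges. The sketch below reproduces the numbers. Whether ConfigSpace computes its defaults in exactly this way, including the rounding for integer ranges, is an assumption; `uniform_default` and `log_uniform_default` are hypothetical helpers.

```python
import math


def uniform_default(low, high):
    """Midpoint of a uniform range."""
    return (low + high) / 2


def log_uniform_default(low, high):
    """Midpoint in log space, i.e. the geometric mean of the bounds."""
    return math.sqrt(low * high)


print("units:", uniform_default(8, 128))
print("num_epochs:", uniform_default(10, 100))
print("dropout_rate:", uniform_default(0.0, 0.6))
print("batch_size:", round(log_uniform_default(8, 256)))
print("learning_rate:", f"{log_uniform_default(1e-5, 1e-2):.12f}")
```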
We create a new evaluator, evaluator_3, and start a fresh CBO search with the new problem, problem_with_condition.
[57]:
evaluator_3 = get_evaluator(run_with_condition)
search_with_condition = CBO(problem_with_condition, evaluator_3)
WARNING:root:Results file already exists, it will be renamed to /Users/romainegele/Documents/Argonne/deephyper-tutorials/tutorials/colab/HPS_basic_classification_with_tabular_data/results_20240805-190449.csv
Created new evaluator with 1 worker and config: {'num_cpus': 1, 'num_cpus_per_task': 1, 'callbacks': [<deephyper.evaluator.callback.TqdmCallback object at 0x3e628d1c0>]}
[58]:
results_with_condition = search_with_condition.search(max_evals=25)
[59]:
results_with_condition
[59]:
| | p:activation | p:batch_size | p:dense_2 | p:dropout_rate | p:learning_rate | p:num_epochs | p:units | p:dense_2:activation | p:dense_2:units | objective | job_id | m:timestamp_submit | m:loss | m:val_loss | m:accuracy | m:val_accuracy | m:num_parameters | m:num_parameters_train | m:timestamp_gather |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | tanh | 15 | True | 0.366818 | 0.005029 | 26 | 60 | softplus | 105 | 0.786885 | 0 | 5.803378 | [0.4718848168849945, 0.3748987913131714, 0.300... | [0.483046293258667, 0.4596226215362549, 0.4899... | [0.7438016533851624, 0.8305785059928894, 0.842... | [0.8524590134620667, 0.7868852615356445, 0.836... | 18 | 8731 | 7.808099 |
1 | tanh | 8 | True | 0.390801 | 0.000175 | 81 | 30 | hard_sigmoid | 68 | 0.836066 | 1 | 8.464106 | [0.7436654567718506, 0.6908826231956482, 0.664... | [0.6634212136268616, 0.6115420460700989, 0.579... | [0.4710743725299835, 0.5247933864593506, 0.586... | [0.688524603843689, 0.8196721076965332, 0.7704... | 18 | 3287 | 11.945705 |
2 | tanh | 114 | False | 0.196952 | 0.000011 | 17 | 100 | elu | 8 | 0.491803 | 2 | 12.560954 | [0.7704780101776123, 0.7623841762542725, 0.757... | [0.6971723437309265, 0.6961050033569336, 0.695... | [0.44628098607063293, 0.43388429284095764, 0.4... | [0.4754098355770111, 0.4754098355770111, 0.475... | 18 | 3801 | 14.004241 |
3 | selu | 115 | False | 0.077023 | 0.000019 | 25 | 50 | elu | 8 | 0.721311 | 3 | 14.622597 | [0.6315001845359802, 0.630281925201416, 0.6373... | [0.6136785745620728, 0.612109899520874, 0.6105... | [0.6363636255264282, 0.6280992031097412, 0.603... | [0.6721311211585999, 0.6721311211585999, 0.688... | 18 | 1901 | 16.010743 |
4 | relu | 73 | False | 0.476480 | 0.003393 | 19 | 34 | elu | 8 | 0.803279 | 4 | 16.631053 | [0.5818436741828918, 0.5115591287612915, 0.514... | [0.48164546489715576, 0.4386799931526184, 0.41... | [0.702479362487793, 0.7438016533851624, 0.7438... | [0.8196721076965332, 0.7704917788505554, 0.770... | 18 | 1293 | 18.078223 |
5 | selu | 125 | False | 0.534733 | 0.000035 | 84 | 74 | elu | 8 | 0.245902 | 5 | 18.695810 | [1.2875844240188599, 1.262080430984497, 1.3596... | [1.3276287317276, 1.3229693174362183, 1.318302... | [0.24793387949466705, 0.2975206673145294, 0.25... | [0.2295081913471222, 0.2295081913471222, 0.245... | 18 | 2813 | 20.870275 |
6 | softplus | 65 | False | 0.568267 | 0.004382 | 34 | 30 | elu | 8 | 0.836066 | 6 | 21.482700 | [1.0909706354141235, 0.7855749130249023, 0.700... | [0.7736652493476868, 0.5466883778572083, 0.439... | [0.42561984062194824, 0.5826446413993835, 0.62... | [0.44262295961380005, 0.8032786846160889, 0.80... | 18 | 1141 | 23.006530 |
7 | linear | 55 | False | 0.162118 | 0.000082 | 50 | 44 | elu | 8 | 0.819672 | 7 | 23.618691 | [0.8554414510726929, 0.8225628137588501, 0.824... | [0.8711621761322021, 0.8583137392997742, 0.845... | [0.3677685856819153, 0.3719008266925812, 0.413... | [0.3442623019218445, 0.3442623019218445, 0.360... | 18 | 1673 | 25.332632 |
8 | hard_sigmoid | 11 | True | 0.352241 | 0.000210 | 99 | 61 | elu | 125 | 0.803279 | 8 | 25.941547 | [0.7294607162475586, 0.6213605999946594, 0.591... | [0.5591641664505005, 0.5274262428283691, 0.524... | [0.5, 0.71074378490448, 0.7148760557174683, 0.... | [0.7704917788505554, 0.7704917788505554, 0.770... | 18 | 10133 | 29.293199 |
9 | gelu | 185 | False | 0.341791 | 0.003917 | 65 | 17 | elu | 8 | 0.786885 | 9 | 29.903265 | [0.9228700399398804, 0.8316530585289001, 0.809... | [0.7922216057777405, 0.724449634552002, 0.6675... | [0.28925618529319763, 0.42561984062194824, 0.4... | [0.4262295067310333, 0.49180328845977783, 0.63... | 18 | 647 | 31.662543 |
10 | swish | 28 | False | 0.561951 | 0.000036 | 34 | 30 | elu | 8 | 0.442623 | 10 | 32.490997 | [0.8498828411102295, 0.8368768692016602, 0.829... | [0.8152909874916077, 0.8121859431266785, 0.809... | [0.40495866537094116, 0.40909090638160706, 0.4... | [0.3442623019218445, 0.3442623019218445, 0.344... | 18 | 1141 | 34.107248 |
11 | swish | 24 | False | 0.577138 | 0.005004 | 56 | 18 | elu | 8 | 0.770492 | 11 | 34.816861 | [0.6796150803565979, 0.5124174356460571, 0.482... | [0.46827590465545654, 0.4033910036087036, 0.37... | [0.5950413346290588, 0.7479338645935059, 0.776... | [0.7868852615356445, 0.8032786846160889, 0.819... | 18 | 685 | 37.176652 |
12 | relu | 9 | True | 0.523482 | 0.000046 | 39 | 15 | linear | 62 | 0.770492 | 12 | 37.912493 | [0.7733532190322876, 0.7673518657684326, 0.725... | [0.7620072364807129, 0.7392509579658508, 0.716... | [0.5041322112083435, 0.5330578684806824, 0.570... | [0.44262295961380005, 0.49180328845977783, 0.5... | 18 | 1610 | 40.207185 |
13 | sigmoid | 60 | False | 0.576131 | 0.004627 | 63 | 119 | elu | 8 | 0.819672 | 13 | 40.919064 | [0.6822071075439453, 0.5901292562484741, 0.567... | [0.4921690821647644, 0.4631254971027374, 0.408... | [0.5743801593780518, 0.7148760557174683, 0.731... | [0.7704917788505554, 0.7704917788505554, 0.786... | 18 | 4523 | 42.717181 |
14 | swish | 195 | True | 0.550026 | 0.000172 | 56 | 30 | softplus | 72 | 0.754098 | 14 | 43.429739 | [0.7369769215583801, 0.6703598499298096, 0.709... | [0.5385007858276367, 0.5357068181037903, 0.533... | [0.5950413346290588, 0.64462810754776, 0.60330... | [0.7704917788505554, 0.7704917788505554, 0.770... | 18 | 3415 | 45.337643 |
15 | softplus | 65 | False | 0.529968 | 0.001002 | 31 | 36 | elu | 8 | 0.819672 | 15 | 46.051016 | [0.8508321046829224, 0.8252332210540771, 0.771... | [0.7166399359703064, 0.6583455801010132, 0.608... | [0.43388429284095764, 0.5206611752510071, 0.54... | [0.44262295961380005, 0.6229507923126221, 0.72... | 18 | 1369 | 47.613793 |
16 | softsign | 34 | True | 0.557522 | 0.000161 | 78 | 109 | hard_sigmoid | 47 | 0.852459 | 16 | 48.333625 | [0.8418372273445129, 0.8010062575340271, 0.781... | [0.768680214881897, 0.746289074420929, 0.72464... | [0.45041322708129883, 0.5082644820213318, 0.46... | [0.2295081913471222, 0.2295081913471222, 0.229... | 18 | 9251 | 51.402913 |
17 | softplus | 85 | True | 0.560434 | 0.000156 | 78 | 118 | linear | 26 | 0.852459 | 17 | 52.115983 | [0.7104208469390869, 0.7846921682357788, 0.714... | [0.5288053154945374, 0.5168441534042358, 0.508... | [0.6074380278587341, 0.5702479481697083, 0.628... | [0.7704917788505554, 0.7704917788505554, 0.770... | 18 | 7487 | 54.040631 |
18 | softsign | 82 | True | 0.023935 | 0.000069 | 78 | 117 | hard_sigmoid | 22 | 0.704918 | 18 | 54.757471 | [0.8375796675682068, 0.8401011824607849, 0.837... | [0.8699012398719788, 0.866222620010376, 0.8625... | [0.2851239740848541, 0.2851239740848541, 0.289... | [0.2295081913471222, 0.2295081913471222, 0.229... | 18 | 6948 | 57.191873 |
19 | selu | 126 | True | 0.498169 | 0.000588 | 80 | 112 | relu | 114 | 0.803279 | 19 | 58.051332 | [0.7589589357376099, 0.6722019910812378, 0.575... | [0.6904788613319397, 0.6121534109115601, 0.551... | [0.4586776793003082, 0.5702479481697083, 0.731... | [0.5901639461517334, 0.7540983557701111, 0.770... | 18 | 17141 | 60.207874 |
20 | softsign | 125 | True | 0.520356 | 0.000171 | 99 | 94 | linear | 54 | 0.819672 | 20 | 60.940526 | [0.8036916851997375, 0.7935233116149902, 0.770... | [0.7480814456939697, 0.7282476425170898, 0.709... | [0.42561984062194824, 0.40495866537094116, 0.4... | [0.4098360538482666, 0.44262295961380005, 0.50... | 18 | 8663 | 63.042618 |
21 | softplus | 30 | True | 0.517795 | 0.000126 | 20 | 119 | hard_sigmoid | 44 | 0.770492 | 21 | 63.761625 | [0.7613235712051392, 0.7507070899009705, 0.705... | [0.649841845035553, 0.6311606764793396, 0.6109... | [0.5206611752510071, 0.5247933864593506, 0.590... | [0.7704917788505554, 0.7704917788505554, 0.770... | 18 | 9728 | 65.623381 |
22 | softplus | 177 | True | 0.596195 | 0.000163 | 66 | 119 | gelu | 37 | 0.770492 | 22 | 66.337983 | [0.6425224542617798, 0.6305481195449829, 0.657... | [0.5786556005477905, 0.5653592944145203, 0.554... | [0.6239669322967529, 0.6611570119857788, 0.628... | [0.7704917788505554, 0.7704917788505554, 0.770... | 18 | 8881 | 68.226429 |
23 | softsign | 8 | True | 0.578768 | 0.000083 | 78 | 105 | relu | 11 | 0.819672 | 23 | 68.941229 | [0.7219522595405579, 0.7058809995651245, 0.676... | [0.6714370250701904, 0.6515097618103027, 0.632... | [0.5454545617103577, 0.5991735458374023, 0.644... | [0.6557376980781555, 0.6557376980781555, 0.737... | 18 | 5063 | 72.357375 |
24 | tanh | 33 | False | 0.572871 | 0.000163 | 86 | 49 | elu | 8 | 0.819672 | 24 | 73.081465 | [0.7666149139404297, 0.7785250544548035, 0.725... | [0.7132001519203186, 0.6926335692405701, 0.673... | [0.4876033067703247, 0.4917355477809906, 0.508... | [0.5081967115402222, 0.5409836173057556, 0.557... | 18 | 1863 | 75.528965 |
Finally, let us print out the best configuration found in this conditioned search space.
[60]:
i_max = results_with_condition.objective.argmax()
best_job = results_with_condition.iloc[i_max].to_dict()
print(f"The default configuration has an accuracy of {objective_default:.3f}. \n"
      f"The best configuration found by DeepHyper has an accuracy of {results_with_condition['objective'].iloc[i_max]:.3f}, \n"
      f"discovered after {results_with_condition['m:timestamp_gather'].iloc[i_max]:.2f} seconds of search.\n")
best_job
The default configuration has an accuracy of 0.803.
The best configuration found by DeepHyper has an accuracy of 0.852,
discovered after 51.40 seconds of search.
[60]:
{'p:activation': 'softsign',
'p:batch_size': 34,
'p:dense_2': True,
'p:dropout_rate': 0.5575222563657,
'p:learning_rate': 0.0001609081602,
'p:num_epochs': 78,
'p:units': 109,
'p:dense_2:activation': 'hard_sigmoid',
'p:dense_2:units': 47,
'objective': 0.8524590134620667,
'job_id': 16,
'm:timestamp_submit': 48.333625078201294,
'm:loss': '[0.8418372273445129, 0.8010062575340271, 0.7812130451202393, 0.7930625677108765, 0.7684123516082764, 0.7609521746635437, 0.7478557229042053, 0.7406894564628601, 0.6992117762565613, 0.6690812706947327, 0.6832402944564819, 0.6622161269187927, 0.7154543399810791, 0.6304647922515869, 0.632775604724884, 0.635278046131134, 0.6395263671875, 0.601736843585968, 0.5866347551345825, 0.5930865406990051, 0.6114271879196167, 0.6148985028266907, 0.5917722582817078, 0.5602957010269165, 0.6068222522735596, 0.5561351776123047, 0.52374666929245, 0.5643391609191895, 0.5197662711143494, 0.5247690081596375, 0.5277420282363892, 0.5073702335357666, 0.46645841002464294, 0.5366184711456299, 0.5077937841415405, 0.5054715871810913, 0.5132231712341309, 0.5002268552780151, 0.4574722945690155, 0.4624318480491638, 0.45751628279685974, 0.45831236243247986, 0.46884045004844666, 0.46934202313423157, 0.43772009015083313, 0.45625945925712585, 0.42081278562545776, 0.4312887191772461, 0.4441389739513397, 0.4487796425819397, 0.44675105810165405, 0.42890456318855286, 0.42480218410491943, 0.4400651156902313, 0.4144305884838104, 0.39338719844818115, 0.359445720911026, 0.4274265468120575, 0.4111899435520172, 0.38475120067596436, 0.3514232337474823, 0.3587934970855713, 0.39239099621772766, 0.37596988677978516, 0.38595589995384216, 0.38566166162490845, 0.3809291422367096, 0.3818380534648895, 0.42398494482040405, 0.35981088876724243, 0.3791501820087433, 0.41108107566833496, 0.36640042066574097, 0.3670456111431122, 0.36325547099113464, 0.3857632577419281, 0.3685024678707123, 0.39088982343673706]',
'm:val_loss': '[0.768680214881897, 0.746289074420929, 0.7246480584144592, 0.7041943669319153, 0.6852578520774841, 0.6664589643478394, 0.6496288776397705, 0.6337476372718811, 0.6189693808555603, 0.6056522130966187, 0.593632698059082, 0.5819397568702698, 0.569520890712738, 0.5585426688194275, 0.5482839941978455, 0.5388845205307007, 0.5297328233718872, 0.5214434862136841, 0.5139332413673401, 0.5062064528465271, 0.49924206733703613, 0.4916667640209198, 0.48417896032333374, 0.4775199890136719, 0.4708450138568878, 0.46558263897895813, 0.46020370721817017, 0.45523616671562195, 0.4506804049015045, 0.4466724097728729, 0.4432690143585205, 0.43901562690734863, 0.43512803316116333, 0.43151235580444336, 0.42741674184799194, 0.4234548509120941, 0.41981783509254456, 0.41622480750083923, 0.4126134216785431, 0.4092336595058441, 0.4061407446861267, 0.40323206782341003, 0.40096515417099, 0.3985515236854553, 0.3963643014431, 0.3940359652042389, 0.3918095529079437, 0.38982391357421875, 0.387992799282074, 0.3862203359603882, 0.384972482919693, 0.38385680317878723, 0.38305702805519104, 0.38194236159324646, 0.38052645325660706, 0.3791072964668274, 0.377858966588974, 0.3774709105491638, 0.3772493600845337, 0.37700068950653076, 0.3765186667442322, 0.3758242130279541, 0.37525153160095215, 0.37498900294303894, 0.3749580979347229, 0.3747103810310364, 0.374695748090744, 0.37496882677078247, 0.3751734793186188, 0.375230610370636, 0.37512537837028503, 0.3750172555446625, 0.3752257525920868, 0.37517407536506653, 0.37543660402297974, 0.37605175375938416, 0.3765294849872589, 0.37751638889312744]',
'm:accuracy': '[0.45041322708129883, 0.5082644820213318, 0.4628099203109741, 0.46694216132164, 0.4876033067703247, 0.5082644820213318, 0.5165289044380188, 0.5247933864593506, 0.5371900796890259, 0.5950413346290588, 0.5826446413993835, 0.6115702390670776, 0.586776852607727, 0.5991735458374023, 0.6115702390670776, 0.6115702390670776, 0.6322314143180847, 0.6735537052154541, 0.6694214940071106, 0.6694214940071106, 0.6528925895690918, 0.64462810754776, 0.6611570119857788, 0.6818181872367859, 0.6776859760284424, 0.7190082669258118, 0.7438016533851624, 0.7066115736961365, 0.7644628286361694, 0.7272727489471436, 0.7520661354064941, 0.7479338645935059, 0.7685950398445129, 0.7396694421768188, 0.7479338645935059, 0.7603305578231812, 0.7314049601554871, 0.7272727489471436, 0.78925621509552, 0.7851239442825317, 0.7768595218658447, 0.7644628286361694, 0.7851239442825317, 0.7727272510528564, 0.7561983466148376, 0.7603305578231812, 0.8140496015548706, 0.8140496015548706, 0.7685950398445129, 0.7809917330741882, 0.7933884263038635, 0.8099173307418823, 0.8264462947845459, 0.7851239442825317, 0.7809917330741882, 0.8057851195335388, 0.8388429880142212, 0.8223140239715576, 0.8016529083251953, 0.8140496015548706, 0.8223140239715576, 0.8264462947845459, 0.8223140239715576, 0.8719007968902588, 0.8223140239715576, 0.8264462947845459, 0.8223140239715576, 0.8347107172012329, 0.78925621509552, 0.85537189245224, 0.8099173307418823, 0.8099173307418823, 0.8305785059928894, 0.8429751992225647, 0.8388429880142212, 0.8140496015548706, 0.8471074104309082, 0.8347107172012329]',
'm:val_accuracy': '[0.2295081913471222, 0.2295081913471222, 0.2295081913471222, 0.31147539615631104, 0.6393442749977112, 0.7540983557701111, 0.7868852615356445, 0.7868852615356445, 0.7868852615356445, 0.8032786846160889, 0.7868852615356445, 0.8032786846160889, 0.7868852615356445, 0.7868852615356445, 0.7868852615356445, 0.7868852615356445, 0.7868852615356445, 0.7868852615356445, 0.7868852615356445, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8196721076965332, 0.7868852615356445, 0.8196721076965332, 0.8032786846160889, 0.7868852615356445, 0.7868852615356445, 0.7868852615356445, 0.7704917788505554, 0.7704917788505554, 0.7704917788505554, 0.7704917788505554, 0.7704917788505554, 0.7704917788505554, 0.7704917788505554, 0.7704917788505554, 0.7704917788505554, 0.7704917788505554, 0.7704917788505554, 0.7704917788505554, 0.7704917788505554, 0.7704917788505554, 0.7704917788505554, 0.7704917788505554, 0.7868852615356445, 0.7868852615356445, 0.7868852615356445, 0.7868852615356445, 0.8032786846160889, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.868852436542511, 0.8524590134620667, 0.8524590134620667, 0.868852436542511, 0.868852436542511, 0.868852436542511, 0.868852436542511, 0.868852436542511, 0.8524590134620667, 0.8524590134620667]',
'm:num_parameters': 18,
'm:num_parameters_train': 9251,
'm:timestamp_gather': 51.402913093566895}
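The row above mixes three kinds of entries: hyperparameters (keys prefixed with `p:`), metadata (keys prefixed with `m:`), and the `objective` and `job_id` fields. In this row the objective coincides with the last entry of the `m:val_accuracy` history. The following is a minimal sketch, assuming only the key convention visible in that output, of how the hyperparameters could be separated from the metadata and a JSON-encoded history decoded. The helper name `split_job` and the shortened `example_job` dictionary are hypothetical and used for illustration only.

```python
import json


def split_job(job):
    """Split a results row into hyperparameters and metadata by key prefix.

    Assumes the "p:" / "m:" prefixes shown in the output above.
    """
    params = {k[2:]: v for k, v in job.items() if k.startswith("p:")}
    metadata = {k[2:]: v for k, v in job.items() if k.startswith("m:")}
    return params, metadata


# Shortened, illustrative copy of the row above (histories truncated by hand).
example_job = {
    "p:activation": "softsign",
    "p:batch_size": 34,
    "p:dense_2": True,
    "p:dense_2:units": 47,
    "objective": 0.8524590134620667,
    "job_id": 16,
    "m:val_accuracy": "[0.2295081913471222, 0.8032786846160889, 0.8524590134620667]",
    "m:timestamp_gather": 51.402913093566895,
}

params, metadata = split_job(example_job)

# Histories are stored as JSON strings, so they need decoding before use.
val_accuracy = json.loads(metadata["val_accuracy"])

print(params)
print(val_accuracy[-1])
```

Only the first prefix is stripped, so conditional hyperparameters such as `dense_2:units` keep the name of the parameter they depend on.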
[61]:
import matplotlib.pyplot as plt
plt.figure()
# Learning curves of the default configuration (histories are JSON-encoded lists).
plt.plot(json.loads(metadata_default["val_accuracy"]), color="skyblue", label="Default (val)")
plt.plot(json.loads(metadata_default["accuracy"]), color="skyblue", linestyle="--", label="Default (train)")
# Learning curves of the best job found by the search.
plt.plot(json.loads(best_job["m:val_accuracy"]), color="coral", linewidth=2, label="Best Job (val)")
plt.plot(json.loads(best_job["m:accuracy"]), color="coral", linestyle="--", linewidth=2, label="Best Job (train)")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.legend()
plt.ylim(0.5, 0.9)
plt.grid()
plt.show()
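The same histories can also be summarized numerically. For instance, the `m:val_accuracy` history of the best job printed earlier peaks at about 0.869 before ending at 0.852, the value reported as the objective. The following is a rough sketch, assuming a history stored as a JSON-encoded list as in the metadata above. The helper name `summarize_curve` and the short example history are hypothetical.

```python
import json


def summarize_curve(encoded_history):
    """Return the final value, best value, and 1-based epoch of the best value."""
    values = json.loads(encoded_history)
    # Index of the first occurrence of the maximum value.
    best_index = max(range(len(values)), key=values.__getitem__)
    return {
        "final": values[-1],
        "best": values[best_index],
        "best_epoch": best_index + 1,
        "num_epochs": len(values),
    }


# Hand-written example history, not taken from the search results.
summary = summarize_curve("[0.23, 0.80, 0.87, 0.85]")
print(summary)
```

In the notebook, the same function could be applied to `best_job["m:val_accuracy"]` to compare the final validation accuracy with the best one reached during training.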