2. Hyperparameter search for classification with Tabular data (Keras)#


In this tutorial we present how to use hyperparameter optimization on a basic example from the Keras documentation.

Reference: This tutorial is based on materials from the Keras Documentation: Structured data classification from scratch

Let us start with installing DeepHyper!

Warning

This tutorial should be run with tensorflow>=2.6.

[1]:
try:
    import deephyper
    print(deephyper.__version__)
except (ImportError, ModuleNotFoundError):
    !pip install deephyper

try:
    import ray
except (ImportError, ModuleNotFoundError):
    !pip install ray
0.6.0

Note

The following environment variables can be used to suppress some TensorFlow DEBUG, INFO, and WARNING log statements.

[2]:
import os


os.environ["TF_CPP_MIN_LOG_LEVEL"] = str(3)
os.environ["AUTOGRAPH_VERBOSITY"] = str(0)

2.1. Imports#

Warning

It is important to follow the import strategy import tensorflow as tf to prevent serialization errors that would crash the search.

The import strategy from the original Keras tutorial (shown below):

from tensorflow import keras
from tensorflow.keras import layers
...
from tensorflow.keras.layers import IntegerLookup
from tensorflow.keras.layers import Normalization
from tensorflow.keras.layers import StringLookup

resulted in non-serializable data, preventing the search from executing in parallel.

[1]:
import json

import numpy as np
import pandas as pd
import tensorflow as tf
import tf_keras as tfk
import tf_keras.backend as K

Note

The following can be used to detect if GPU devices are available on the current host, so that this notebook can automatically adapt its parallel execution to the locally available resources. However, this simple check will not detect the resources of multiple nodes.

[2]:
from tensorflow.python.client import device_lib


def get_available_gpus():
    local_device_protos = device_lib.list_local_devices()
    return [x.name for x in local_device_protos if x.device_type == "GPU"]


n_gpus = len(get_available_gpus())
if n_gpus > 1:
    n_gpus -= 1

is_gpu_available = n_gpus > 0

if is_gpu_available:
    print(f"{n_gpus} GPU{'s are' if n_gpus > 1 else ' is'} available.")
else:
    print("No GPU available")
No GPU available

2.2. The dataset (from Keras.io)#

The dataset is provided by the Cleveland Clinic Foundation for Heart Disease. It’s a CSV file with 303 rows. Each row contains information about a patient (a sample), and each column describes an attribute of the patient (a feature). We use the features to predict whether a patient has a heart disease (binary classification).

Here’s the description of each feature:

| Column   | Description                                            | Feature Type                 |
|----------|--------------------------------------------------------|------------------------------|
| Age      | Age in years                                           | Numerical                    |
| Sex      | (1 = male; 0 = female)                                 | Categorical                  |
| CP       | Chest pain type (0, 1, 2, 3, 4)                        | Categorical                  |
| Trestbpd | Resting blood pressure (in mm Hg on admission)         | Numerical                    |
| Chol     | Serum cholesterol in mg/dl                             | Numerical                    |
| FBS      | Fasting blood sugar > 120 mg/dl (1 = true; 0 = false)  | Categorical                  |
| RestECG  | Resting electrocardiogram results (0, 1, 2)            | Categorical                  |
| Thalach  | Maximum heart rate achieved                            | Numerical                    |
| Exang    | Exercise induced angina (1 = yes; 0 = no)              | Categorical                  |
| Oldpeak  | ST depression induced by exercise relative to rest     | Numerical                    |
| Slope    | Slope of the peak exercise ST segment                  | Numerical                    |
| CA       | Number of major vessels (0-3) colored by fluoroscopy   | Both numerical & categorical |
| Thal     | 3 = normal; 6 = fixed defect; 7 = reversible defect    | Categorical                  |
| Target   | Diagnosis of heart disease (1 = true; 0 = false)       | Target                       |

[3]:
def load_data():
    # Download the CSV file and load it into a Pandas dataframe
    file_url = "http://storage.googleapis.com/download.tensorflow.org/data/heart.csv"
    dataframe = pd.read_csv(file_url)

    # Hold out 20% of the samples for validation
    val_dataframe = dataframe.sample(frac=0.2, random_state=1337)
    train_dataframe = dataframe.drop(val_dataframe.index)

    return train_dataframe, val_dataframe


def dataframe_to_dataset(dataframe):
    dataframe = dataframe.copy()
    labels = dataframe.pop("target")
    # Build a tf.data.Dataset yielding (features-dict, label) pairs
    ds = tf.data.Dataset.from_tensor_slices((dict(dataframe), labels))
    ds = ds.shuffle(buffer_size=len(dataframe))
    return ds

2.3. Preprocessing & encoding of features#

The next cells use tfk.layers.Normalization() to apply standard scaling on the features.

Then, the tfk.layers.StringLookup and tfk.layers.IntegerLookup are used to encode categorical variables.

[7]:
def encode_numerical_feature(feature, name, dataset):
    # Create a Normalization layer for our feature
    normalizer = tfk.layers.Normalization()

    # Prepare a Dataset that only yields our feature
    feature_ds = dataset.map(lambda x, y: x[name])
    feature_ds = feature_ds.map(lambda x: tf.expand_dims(x, -1))

    # Learn the statistics of the data
    normalizer.adapt(feature_ds)

    # Normalize the input feature
    encoded_feature = normalizer(feature)
    return encoded_feature


def encode_categorical_feature(feature, name, dataset, is_string):
    lookup_class = (
        tfk.layers.StringLookup if is_string else tfk.layers.IntegerLookup
    )
    # Create a lookup layer which will turn strings into integer indices
    lookup = lookup_class(output_mode="binary")

    # Prepare a Dataset that only yields our feature
    feature_ds = dataset.map(lambda x, y: x[name])
    feature_ds = feature_ds.map(lambda x: tf.expand_dims(x, -1))

    # Learn the set of possible string values and assign them a fixed integer index
    lookup.adapt(feature_ds)

    # Turn the string input into integer indices
    encoded_feature = lookup(feature)
    return encoded_feature

2.4. Define the run-function#

The run-function defines how the objective that we want to maximize is computed. It takes a config dictionary as input and typically returns a scalar value (or, as here, a dictionary with an "objective" entry). The config contains a sampled value for each hyperparameter that we want to tune. In this example we will search for:

  • units (default value: 32)

  • activation (default value: "relu")

  • dropout_rate (default value: 0.5)

  • num_epochs (default value: 50)

  • batch_size (default value: 32)

  • learning_rate (default value: 1e-3)

A hyperparameter value can be accessed easily in the dictionary through the corresponding key, for example config["units"].
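To make this contract concrete before the full Keras example below, here is a toy sketch of a run-function (the objective is made up for illustration, not the model trained in this tutorial):

def toy_run(config: dict) -> float:
    # A made-up objective depending on two of the tuned hyperparameters.
    return config["units"] / 128 - (config["dropout_rate"] - 0.3) ** 2

print(toy_run({"units": 32, "dropout_rate": 0.5}))  # ~0.21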

[18]:
def count_params(model: tfk.Model) -> dict:
    """Evaluate the number of parameters of a Keras model.

    Args:
        model (tfk.Model): a Keras model.

    Returns:
        dict: a dictionary with the number of trainable ``"num_parameters_train"`` and
        non-trainable parameters ``"num_parameters"``.
    """

    def count_or_null(p):
        try:
            return K.count_params(p)
        except Exception:
            return 0

    num_parameters_train = int(
        np.sum([count_or_null(p) for p in model.trainable_weights])
    )
    num_parameters = int(
        np.sum([count_or_null(p) for p in model.non_trainable_weights])
    )
    return {
        "num_parameters": num_parameters,
        "num_parameters_train": num_parameters_train,
    }
[19]:
def run(config: dict):
    tf.autograph.set_verbosity(0)
    # Load data and split into validation set
    train_dataframe, val_dataframe = load_data()
    train_ds = dataframe_to_dataset(train_dataframe)
    val_ds = dataframe_to_dataset(val_dataframe)
    train_ds = train_ds.batch(config["batch_size"])
    val_ds = val_ds.batch(config["batch_size"])

    # Categorical features encoded as integers
    sex = tfk.Input(shape=(1,), name="sex", dtype="int64")
    cp = tfk.Input(shape=(1,), name="cp", dtype="int64")
    fbs = tfk.Input(shape=(1,), name="fbs", dtype="int64")
    restecg = tfk.Input(shape=(1,), name="restecg", dtype="int64")
    exang = tfk.Input(shape=(1,), name="exang", dtype="int64")
    ca = tfk.Input(shape=(1,), name="ca", dtype="int64")

    # Categorical feature encoded as string
    thal = tfk.Input(shape=(1,), name="thal", dtype="string")

    # Numerical features
    age = tfk.Input(shape=(1,), name="age")
    trestbps = tfk.Input(shape=(1,), name="trestbps")
    chol = tfk.Input(shape=(1,), name="chol")
    thalach = tfk.Input(shape=(1,), name="thalach")
    oldpeak = tfk.Input(shape=(1,), name="oldpeak")
    slope = tfk.Input(shape=(1,), name="slope")

    all_inputs = [
        sex,
        cp,
        fbs,
        restecg,
        exang,
        ca,
        thal,
        age,
        trestbps,
        chol,
        thalach,
        oldpeak,
        slope,
    ]

    # Integer categorical features
    sex_encoded = encode_categorical_feature(sex, "sex", train_ds, False)
    cp_encoded = encode_categorical_feature(cp, "cp", train_ds, False)
    fbs_encoded = encode_categorical_feature(fbs, "fbs", train_ds, False)
    restecg_encoded = encode_categorical_feature(restecg, "restecg", train_ds, False)
    exang_encoded = encode_categorical_feature(exang, "exang", train_ds, False)
    ca_encoded = encode_categorical_feature(ca, "ca", train_ds, False)

    # String categorical features
    thal_encoded = encode_categorical_feature(thal, "thal", train_ds, True)

    # Numerical features
    age_encoded = encode_numerical_feature(age, "age", train_ds)
    trestbps_encoded = encode_numerical_feature(trestbps, "trestbps", train_ds)
    chol_encoded = encode_numerical_feature(chol, "chol", train_ds)
    thalach_encoded = encode_numerical_feature(thalach, "thalach", train_ds)
    oldpeak_encoded = encode_numerical_feature(oldpeak, "oldpeak", train_ds)
    slope_encoded = encode_numerical_feature(slope, "slope", train_ds)

    all_features = tfk.layers.concatenate(
        [
            sex_encoded,
            cp_encoded,
            fbs_encoded,
            restecg_encoded,
            exang_encoded,
            slope_encoded,
            ca_encoded,
            thal_encoded,
            age_encoded,
            trestbps_encoded,
            chol_encoded,
            thalach_encoded,
            oldpeak_encoded,
        ]
    )
    x = tfk.layers.Dense(config["units"], activation=config["activation"])(
        all_features
    )
    x = tfk.layers.Dropout(config["dropout_rate"])(x)
    output = tfk.layers.Dense(1, activation="sigmoid")(x)
    model = tfk.Model(all_inputs, output)

    optimizer = tfk.optimizers.Adam(learning_rate=config["learning_rate"])
    model.compile(optimizer, "binary_crossentropy", metrics=["accuracy"])

    history = model.fit(
        train_ds, epochs=config["num_epochs"], validation_data=val_ds, verbose=0
    )

    objective = history.history["val_accuracy"][-1]
    metadata = {
        "loss": history.history["loss"],
        "val_loss": history.history["val_loss"],
        "accuracy": history.history["accuracy"],
        "val_accuracy": history.history["val_accuracy"],
    }
    # Serialize the learning curves so each fits in a single results column
    metadata = {k: json.dumps(v) for k, v in metadata.items()}
    metadata.update(count_params(model))

    return {"objective": objective, "metadata": metadata}
Note

The objective maximized by DeepHyper is the scalar value returned by the run-function (or the "objective" entry when a dictionary is returned, as above).

In this tutorial it corresponds to the validation accuracy of the last epoch of training which we retrieve in the History object returned by the model.fit(...) call.

...
history = model.fit(
    train_ds, epochs=config["num_epochs"], validation_data=val_ds, verbose=0
)
return history.history["val_accuracy"][-1]
...

Using an objective like max(history.history['val_accuracy']) can have undesired side effects.

For example, the validation accuracy may spike at an early epoch before settling; rewarding that transient peak selects configurations whose final model generalizes less well to new data.
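For instance, given the History object above, the two choices compare as follows; this tutorial uses the first:

objective_last = history.history["val_accuracy"][-1]  # final-epoch value (used in this tutorial)
objective_max = max(history.history["val_accuracy"])  # best epoch, possibly a transient peak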

2.5. Define the Hyperparameter optimization problem#

Hyperparameter ranges are defined using the following syntax:

  • Discrete integer ranges are generated from a tuple (lower: int, upper: int)

  • Continuous parameters are generated from a tuple (lower: float, upper: float)

  • Categorical or non-ordinal hyperparameter ranges can be given as a list of possible values [val1, val2, ...]

[20]:
from deephyper.hpo import HpProblem


# Creation of an hyperparameter problem
problem = HpProblem()

# Discrete hyperparameter (sampled with uniform prior)
problem.add_hyperparameter((8, 128), "units", default_value=32)
problem.add_hyperparameter((10, 100), "num_epochs", default_value=50)


# Categorical hyperparameter (sampled with uniform prior)
ACTIVATIONS = [
    "elu", "gelu", "hard_sigmoid", "linear", "relu", "selu",
    "sigmoid", "softplus", "softsign", "swish", "tanh",
]
problem.add_hyperparameter(ACTIVATIONS, "activation", default_value="relu")


# Real hyperparameter (sampled with uniform prior)
problem.add_hyperparameter((0.0, 0.6), "dropout_rate", default_value=0.5)


# Discrete and Real hyperparameters (sampled with log-uniform)
problem.add_hyperparameter((8, 256, "log-uniform"), "batch_size", default_value=32)
problem.add_hyperparameter((1e-5, 1e-2, "log-uniform"), "learning_rate", default_value=1e-3)

problem
[20]:
Configuration space object:
  Hyperparameters:
    activation, Type: Categorical, Choices: {elu, gelu, hard_sigmoid, linear, relu, selu, sigmoid, softplus, softsign, swish, tanh}, Default: relu
    batch_size, Type: UniformInteger, Range: [8, 256], Default: 32, on log-scale
    dropout_rate, Type: UniformFloat, Range: [0.0, 0.6], Default: 0.5
    learning_rate, Type: UniformFloat, Range: [1e-05, 0.01], Default: 0.001, on log-scale
    num_epochs, Type: UniformInteger, Range: [10, 100], Default: 50
    units, Type: UniformInteger, Range: [8, 128], Default: 32

2.6. Evaluate a default configuration#

We evaluate the performance of the default set of hyperparameters provided in the Keras tutorial.

[21]:
import ray


# We launch the Ray run-time depending on the detected local resources
# and execute the `run` function with the default configuration.
# WARNING: when GPUs are present it is important to follow this scheme
# to prevent multiple processes (Ray workers vs. the current process)
# from locking the same GPU.
if is_gpu_available:
    if not ray.is_initialized():
        ray.init(num_cpus=n_gpus, num_gpus=n_gpus, log_to_driver=False)

    run_default = ray.remote(num_cpus=1, num_gpus=1)(run)
    out = ray.get(run_default.remote(problem.default_configuration))
else:
    if not ray.is_initialized():
        ray.init(num_cpus=1, log_to_driver=False)
    run_default = run
    out = run_default(problem.default_configuration)

objective_default = out["objective"]
metadata_default = out["metadata"]

print(f"Accuracy Default Configuration:  {objective_default:.3f}")

print("Metadata Default Configuration")
for k, v in metadata_default.items():
    print(f"\t- {k}: {v}")
WARNING:absl:At this time, the v2.11+ optimizer `tf.keras.optimizers.Adam` runs slowly on M1/M2 Macs, please use the legacy TF-Keras optimizer instead, located at `tf.keras.optimizers.legacy.Adam`.
Accuracy Default Configuration:  0.803
Metadata Default Configuration
        - loss: [0.7089458107948303, 0.7017484307289124, 0.6555561423301697, 0.5975483059883118, 0.5408881306648254, 0.5184577107429504, 0.4933375418186188, 0.5026713013648987, 0.46867308020591736, 0.43367132544517517, 0.4040399491786957, 0.4410407245159149, 0.41532158851623535, 0.42353206872940063, 0.40005776286125183, 0.39258643984794617, 0.385629266500473, 0.3648240566253662, 0.37516599893569946, 0.3589401841163635, 0.36665764451026917, 0.34737372398376465, 0.3585016429424286, 0.3524617850780487, 0.32876622676849365, 0.34414738416671753, 0.3265135884284973, 0.3396293520927429, 0.33635473251342773, 0.32959404587745667, 0.32769283652305603, 0.3066641390323639, 0.3406123220920563, 0.3300285041332245, 0.3249838352203369, 0.31367143988609314, 0.2998355031013489, 0.28332749009132385, 0.28868308663368225, 0.3160744905471802, 0.290345698595047, 0.30034708976745605, 0.25978073477745056, 0.2852059602737427, 0.25552764534950256, 0.2629491984844208, 0.2993203401565552, 0.2639424502849579, 0.2701556980609894, 0.2878676652908325]
        - val_loss: [0.6440699696540833, 0.5855709314346313, 0.5408179759979248, 0.5036494135856628, 0.4742478132247925, 0.45106351375579834, 0.4320058226585388, 0.4158112704753876, 0.4028586745262146, 0.39374834299087524, 0.3855816125869751, 0.37890157103538513, 0.3737941086292267, 0.3691202998161316, 0.36556652188301086, 0.36313873529434204, 0.36203548312187195, 0.3613433837890625, 0.3603682816028595, 0.3586443066596985, 0.3586690127849579, 0.35930439829826355, 0.3583105206489563, 0.35828936100006104, 0.3591354489326477, 0.3601025342941284, 0.36102163791656494, 0.36082351207733154, 0.36077219247817993, 0.3612353503704071, 0.3610706627368927, 0.36180663108825684, 0.36208730936050415, 0.36208462715148926, 0.36210131645202637, 0.36168861389160156, 0.36216410994529724, 0.3641442358493805, 0.36540892720222473, 0.366742879152298, 0.36642923951148987, 0.3666687607765198, 0.3686812222003937, 0.37019941210746765, 0.3713856339454651, 0.37242332100868225, 0.37362170219421387, 0.3748522102832794, 0.37537622451782227, 0.37630414962768555]
        - accuracy: [0.5330578684806824, 0.586776852607727, 0.6198347210884094, 0.6570248007774353, 0.7314049601554871, 0.7520661354064941, 0.7768595218658447, 0.7479338645935059, 0.7561983466148376, 0.8016529083251953, 0.8181818127632141, 0.7768595218658447, 0.7851239442825317, 0.8264462947845459, 0.797520637512207, 0.8140496015548706, 0.8347107172012329, 0.8264462947845459, 0.797520637512207, 0.8264462947845459, 0.8305785059928894, 0.8429751992225647, 0.8388429880142212, 0.8305785059928894, 0.8636363744735718, 0.8471074104309082, 0.8719007968902588, 0.8512396812438965, 0.8305785059928894, 0.8429751992225647, 0.8471074104309082, 0.8677685856819153, 0.8347107172012329, 0.8719007968902588, 0.8471074104309082, 0.85537189245224, 0.8760330677032471, 0.8925619721412659, 0.8801652789115906, 0.8595041036605835, 0.8966942429542542, 0.8595041036605835, 0.8884297609329224, 0.8842975497245789, 0.9049586653709412, 0.9008264541625977, 0.8719007968902588, 0.9090909361839294, 0.8719007968902588, 0.8842975497245789]
        - val_accuracy: [0.6393442749977112, 0.7540983557701111, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8360655903816223, 0.8360655903816223, 0.8196721076965332, 0.8196721076965332, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889]
        - num_parameters: 18
        - num_parameters_train: 1217

2.7. Define the evaluator object#

The Evaluator object lets you change the parallelization backend used by DeepHyper.
It is a standalone object that schedules the execution of remote tasks. Every evaluator needs a run_function to be instantiated.
A keyword argument method then selects the backend (e.g., "ray"), and method_kwargs holds the keyword arguments of the chosen method.
evaluator = Evaluator.create(run_function, method, method_kwargs)

Once created, evaluator.num_workers gives the number of available parallel workers.

Finally, to submit tasks to the evaluator and collect the results, one just needs the following interface:

configs = [{"units": 8, ...}, ...]
evaluator.submit(configs)
...
# To collect the first finished task (asynchronous)
tasks_done = evaluator.gather("BATCH", size=1)

# To collect all of the pending tasks (synchronous)
tasks_done = evaluator.gather("ALL")

Warning

Each Evaluator saves its own state; therefore, it is crucial to create a new evaluator when launching a fresh search.

[22]:
from deephyper.evaluator import Evaluator
from deephyper.evaluator.callback import TqdmCallback


def get_evaluator(run_function):
    # Default arguments for Ray: 1 CPU in total and 1 CPU per task (i.e., a single worker)
    method_kwargs = {
        "num_cpus": 1,
        "num_cpus_per_task": 1,
        "callbacks": [TqdmCallback()]
    }

    # If GPU devices are detected then it will create 'n_gpus' workers
    # and use 1 worker for each evaluation
    if is_gpu_available:
        method_kwargs["num_cpus"] = n_gpus
        method_kwargs["num_gpus"] = n_gpus
        method_kwargs["num_cpus_per_task"] = 1
        method_kwargs["num_gpus_per_task"] = 1

    evaluator = Evaluator.create(
        run_function,
        method="ray",
        method_kwargs=method_kwargs
    )
    print(f"Created new evaluator with {evaluator.num_workers} worker{'s' if evaluator.num_workers > 1 else ''} and config: {method_kwargs}", )

    return evaluator

evaluator_1 = get_evaluator(run)
Created new evaluator with 1 worker and config: {'num_cpus': 1, 'num_cpus_per_task': 1, 'callbacks': [<deephyper.evaluator.callback.TqdmCallback object at 0x3df0ce2a0>]}

2.8. Define and run the centralized Bayesian optimization search (CBO)#

A primary pillar of hyperparameter search in DeepHyper is given by a centralized Bayesian optimization search (henceforth CBO). CBO may be described in the following algorithm:

[Figure: schematic of the CBO algorithm]


Following the parallel evaluation of an initial set of hyperparameter configurations, a low-fidelity but highly efficient model (henceforth "the surrogate") is fit to reproduce the relationship between the inputs of the model (i.e., the choice of hyperparameters) and the outputs (generally a measure of validation accuracy).

After obtaining this surrogate of the validation accuracy, we may utilize ideas from classical methods in the Bayesian optimization literature to adaptively sample the hyperparameter search space.

First, the surrogate is used to obtain an estimate for the mean value of the validation accuracy at a certain sampling location \(x\) in addition to an estimated variance. The latter requirement restricts us to the use of high efficiency data-driven modeling strategies that have inbuilt variance estimates (such as a Gaussian process or Random Forest regressor).

Regions where the mean is high represent opportunities for exploitation, and regions where the variance is high represent opportunities for exploration. An optimistic acquisition function, the upper confidence bound (UCB), can be constructed from these two quantities:

\[L_{\text{UCB}}(x) = \mu(x) + \kappa \cdot \sigma(x)\]

The unevaluated hyperparameter configurations that maximize the acquisition function are chosen for the next batch of evaluations.

Note that the choice of the variance weighting parameter \(\kappa\) controls the degree of exploration in the hyperparameter search with zero indicating purely exploitation (unseen configurations where the predicted accuracy is highest will be sampled).
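As a rough illustration (a minimal sketch using scikit-learn, not DeepHyper's internal implementation), the per-tree predictions of a Random Forest provide the mean and standard deviation needed by the UCB formula:

import numpy as np
from sklearn.ensemble import RandomForestRegressor

rng = np.random.default_rng(42)
# Hypothetical history: 20 already-evaluated configurations of 3 hyperparameters.
X, y = rng.uniform(size=(20, 3)), rng.uniform(size=20)
surrogate = RandomForestRegressor(n_estimators=100).fit(X, y)

def ucb(X_cand, kappa=1.96):
    # Empirical mean and standard deviation across the trees of the forest.
    preds = np.stack([tree.predict(X_cand) for tree in surrogate.estimators_])
    return preds.mean(axis=0) + kappa * preds.std(axis=0)  # kappa=0 -> pure exploitation

candidates = rng.uniform(size=(1000, 3))
x_next = candidates[np.argmax(ucb(candidates))]  # configuration proposed next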

The top s configurations are selected for the new batch. The following schematic demonstrates this process:

[Figure: schematic of the batch-selection process]

The process of obtaining s configurations relies on the “constant-liar” strategy where a sampled configuration is mapped to a dummy output given by a bulk metric of all the evaluated configurations thus far (such as the maximum, mean or median validation accuracy).

Prior to sampling the next configuration by acquisition function maximization, the surrogate is retrained with the dummy output as a data point. As the true validation accuracy becomes available for one of the sampled configurations, the dummy output is replaced and the surrogate is updated.

This allows for scalable asynchronous (or batch synchronous) sampling of new hyperparameter configurations.
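A minimal sketch of this batch-sampling loop (a hypothetical setup; the actual DeepHyper scheduler is asynchronous and more involved) could look like:

import numpy as np
from sklearn.ensemble import RandomForestRegressor

rng = np.random.default_rng(0)
# Hypothetical history of 10 evaluated configurations and their objectives.
X_hist = list(rng.uniform(size=(10, 3)))
y_hist = list(rng.uniform(size=10))

batch = []
for _ in range(4):  # propose a batch of s = 4 configurations
    surrogate = RandomForestRegressor(n_estimators=50).fit(X_hist, y_hist)
    cand = rng.uniform(size=(512, 3))
    preds = np.stack([t.predict(cand) for t in surrogate.estimators_])
    x_next = cand[np.argmax(preds.mean(0) + 1.96 * preds.std(0))]  # UCB
    batch.append(x_next)
    X_hist.append(x_next)       # constant liar: pretend x_next was evaluated...
    y_hist.append(max(y_hist))  # ...with a bulk metric (here the current maximum)
# `batch` now holds 4 configurations to run in parallel; once a true objective
# arrives, its dummy value is replaced and the surrogate is refit.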

2.8.1. Choice of surrogate model#

Users should note that our surrogate of choice is the Random Forest regressor, due to its ability to handle non-ordinal data (hyperparameter configurations may not be purely continuous or even numerical). Evidence that it outperforms other methods (such as Gaussian processes) is also available in [1].


2.8.1.1. Setup CBO#

We create the CBO using the problem and evaluator defined above.

[26]:
from deephyper.hpo import CBO
# Uncomment the following line to show the arguments of CBO.
# help(CBO)
[28]:
# Instantiate the search with the problem and the evaluator that we created before
search = CBO(problem, evaluator_1, acq_optimizer="mixedga", acq_optimizer_freq=1, initial_points=[problem.default_configuration])

Note

All DeepHyper search algorithms have two stopping criteria:

  • max_evals (int): the maximum number of evaluations to perform. Defaults to -1 for an infinite number.

  • timeout (int): a time budget (in seconds) before stopping the search. Defaults to None for an infinite time budget.
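Both criteria can be combined; for example, the call below would stop after 100 evaluations or 10 minutes of search, whichever comes first:

results = search.search(max_evals=100, timeout=600)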

[29]:
results = search.search(max_evals=50)

Warning

The search call does not output any information about the current status of the search. However, a results.csv file is created in the local directory and can be inspected to see the finished tasks.

The returned results is a Pandas DataFrame in which columns starting with "p:" are hyperparameters and columns starting with "m:" are additional metadata (from the user or from the Evaluator); the objective value and the job_id are also reported:

  • job_id is a unique identifier corresponding to the order of creation of tasks.

  • objective is the value returned by the run-function.

  • m:timestamp_submit is the time (in seconds), counted from the creation of the evaluator, at which the task was submitted.

  • m:timestamp_gather is the time (in seconds), counted from the creation of the evaluator, at which the finished task was collected.
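The results can also be reloaded from this file with pandas (a small sketch, assuming results.csv was written to the current working directory):

import pandas as pd

results = pd.read_csv("results.csv")
hp_cols = [c for c in results.columns if c.startswith("p:")]
# Display the best configurations found so far.
print(results.sort_values("objective", ascending=False)[hp_cols + ["objective"]].head())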

[30]:
results
[30]:
p:activation p:batch_size p:dropout_rate p:learning_rate p:num_epochs p:units objective job_id m:timestamp_submit m:loss m:val_loss m:accuracy m:val_accuracy m:num_parameters m:num_parameters_train m:timestamp_gather
0 relu 32 0.500000 0.001000 50 32 0.803279 0 50.325637 [0.706670343875885, 0.6361508369445801, 0.5850... [0.6415352821350098, 0.5993135571479797, 0.564... [0.5123966932296753, 0.6280992031097412, 0.710... [0.6721311211585999, 0.7540983557701111, 0.770... 18 1217 56.067836
1 linear 114 0.562393 0.002370 47 128 0.803279 1 56.123101 [0.6474491953849792, 0.49577072262763977, 0.41... [0.45269274711608887, 0.39127492904663086, 0.3... [0.6652892827987671, 0.7685950398445129, 0.805... [0.7868852615356445, 0.7704917788505554, 0.819... 18 4865 57.814659
2 linear 9 0.533665 0.000075 69 84 0.836066 2 58.013101 [0.6138803958892822, 0.5668570399284363, 0.563... [0.5532547235488892, 0.533083975315094, 0.5151... [0.6776859760284424, 0.7438016533851624, 0.706... [0.7868852615356445, 0.8032786846160889, 0.803... 18 3193 61.268282
3 linear 230 0.319679 0.001157 45 59 0.836066 3 61.298386 [0.7759286761283875, 0.6881203055381775, 0.633... [0.6713256239891052, 0.6128526329994202, 0.564... [0.5123966932296753, 0.6198347210884094, 0.669... [0.6065573692321777, 0.688524603843689, 0.7049... 18 2243 63.091872
4 linear 50 0.011336 0.000046 96 79 0.852459 4 63.144296 [0.7343717813491821, 0.7315843105316162, 0.720... [0.753821849822998, 0.7444661855697632, 0.7352... [0.4793388545513153, 0.4958677589893341, 0.537... [0.39344263076782227, 0.44262295961380005, 0.4... 18 3003 65.508172
5 softsign 68 0.194614 0.000110 71 98 0.770492 5 65.540140 [0.7060982584953308, 0.7222269773483276, 0.696... [0.706885576248169, 0.6966867446899414, 0.6866... [0.5041322112083435, 0.4628099203109741, 0.528... [0.5245901346206665, 0.5573770403862, 0.573770... 18 3725 67.730778
6 linear 12 0.032802 0.003319 94 124 0.803279 6 67.762288 [0.3936194181442261, 0.298471599817276, 0.2840... [0.44452720880508423, 0.4467885494232178, 0.46... [0.797520637512207, 0.8595041036605835, 0.8636... [0.8524590134620667, 0.8360655903816223, 0.819... 18 4713 70.985557
7 gelu 85 0.212161 0.002263 29 53 0.836066 7 71.017057 [0.7266819477081299, 0.650373101234436, 0.6081... [0.6597949862480164, 0.5888913869857788, 0.532... [0.4917355477809906, 0.6652892827987671, 0.714... [0.6721311211585999, 0.7540983557701111, 0.786... 18 2015 72.507068
8 selu 164 0.540524 0.001172 14 48 0.836066 8 72.537765 [0.9216533899307251, 0.8793801665306091, 0.743... [0.7923454642295837, 0.7235696315765381, 0.663... [0.44214877486228943, 0.4876033067703247, 0.53... [0.5245901346206665, 0.5409836173057556, 0.639... 18 1825 74.019521
9 relu 139 0.356863 0.000101 59 94 0.737705 9 74.053959 [0.6047406196594238, 0.6085638999938965, 0.576... [0.6019086241722107, 0.598895788192749, 0.5958... [0.6652892827987671, 0.6776859760284424, 0.739... [0.6721311211585999, 0.6721311211585999, 0.672... 18 3573 75.723831
10 linear 148 0.005233 0.000012 100 91 0.557377 10 76.229987 [0.8135287761688232, 0.8092227578163147, 0.812... [0.7794104218482971, 0.7784043550491333, 0.777... [0.41735535860061646, 0.42561984062194824, 0.4... [0.39344263076782227, 0.39344263076782227, 0.3... 18 3459 78.200037
11 linear 50 0.014352 0.000025 100 79 0.786885 11 78.871375 [0.9058765769004822, 0.8957070708274841, 0.892... [0.9036123752593994, 0.8989002704620361, 0.894... [0.2933884263038635, 0.3264462947845459, 0.314... [0.32786884903907776, 0.32786884903907776, 0.3... 18 3003 81.243957
12 linear 109 0.000487 0.000045 96 79 0.803279 12 81.892266 [0.8505944013595581, 0.8450313210487366, 0.839... [0.8944583535194397, 0.888396143913269, 0.8825... [0.2975206673145294, 0.3016528785228729, 0.305... [0.26229506731033325, 0.26229506731033325, 0.2... 18 3003 83.936618
13 relu 50 0.011727 0.000049 94 79 0.786885 13 84.554558 [0.772240400314331, 0.7702962160110474, 0.7616... [0.7718428373336792, 0.7672584652900696, 0.762... [0.35537189245224, 0.35537189245224, 0.3801652... [0.3442623019218445, 0.3442623019218445, 0.360... 18 3003 86.888750
14 linear 48 0.010876 0.000047 96 86 0.786885 14 87.831352 [0.7737085223197937, 0.7620970010757446, 0.758... [0.7752880454063416, 0.7666113972663879, 0.757... [0.41322314739227295, 0.43388429284095764, 0.4... [0.44262295961380005, 0.4754098355770111, 0.52... 18 3269 90.291341
15 linear 50 0.012980 0.000046 96 81 0.852459 15 91.166686 [0.6715313196182251, 0.6666457056999207, 0.660... [0.5990051031112671, 0.5933732986450195, 0.587... [0.6404958963394165, 0.6611570119857788, 0.669... [0.6721311211585999, 0.6721311211585999, 0.704... 18 3079 93.352006
16 linear 49 0.010183 0.000047 98 81 0.803279 16 94.176494 [0.9817445278167725, 0.9637849926948547, 0.955... [1.0305403470993042, 1.0171641111373901, 1.003... [0.32231405377388, 0.3388429880142212, 0.35950... [0.3606557250022888, 0.3606557250022888, 0.360... 18 3079 96.356579
17 linear 51 0.011189 0.000035 95 81 0.754098 17 97.268516 [0.8748018741607666, 0.8668413162231445, 0.861... [0.8491439819335938, 0.8422006368637085, 0.835... [0.3264462947845459, 0.3181818127632141, 0.347... [0.2786885201931, 0.2950819730758667, 0.295081... 18 3079 99.473332
18 linear 50 0.019094 0.000046 95 79 0.836066 18 100.300250 [0.6110990047454834, 0.6006994843482971, 0.598... [0.5857220888137817, 0.580162525177002, 0.5746... [0.7438016533851624, 0.7479338645935059, 0.760... [0.8360655903816223, 0.8360655903816223, 0.836... 18 3003 102.724646
19 linear 51 0.012842 0.000041 96 81 0.803279 19 103.713039 [0.6344509720802307, 0.6267118453979492, 0.626... [0.6352565884590149, 0.6293896436691284, 0.623... [0.6900826692581177, 0.6983470916748047, 0.710... [0.6557376980781555, 0.6557376980781555, 0.655... 18 3079 105.967294
20 relu 171 0.508927 0.001910 10 83 0.770492 20 106.621811 [0.7741369009017944, 0.6997448801994324, 0.641... [0.6997786164283752, 0.6350088119506836, 0.582... [0.46694216132164, 0.557851254940033, 0.648760... [0.5737704634666443, 0.6557376980781555, 0.737... 18 3155 107.922894
21 gelu 247 0.291055 0.000012 14 57 0.393443 21 108.269415 [0.7637456655502319, 0.7536334991455078, 0.767... [0.770416259765625, 0.7701885104179382, 0.7699... [0.4586776793003082, 0.42975205183029175, 0.44... [0.39344263076782227, 0.39344263076782227, 0.3... 18 2167 109.827998
22 elu 247 0.330785 0.000270 32 57 0.754098 22 110.438860 [0.7152367234230042, 0.7155708074569702, 0.698... [0.7250485420227051, 0.7178604006767273, 0.710... [0.5123966932296753, 0.4958677589893341, 0.570... [0.44262295961380005, 0.44262295961380005, 0.4... 18 2167 111.909659
23 gelu 256 0.174206 0.001193 22 54 0.836066 23 112.662811 [0.7355976700782776, 0.7147685885429382, 0.687... [0.7190053462982178, 0.6970692276954651, 0.676... [0.44214877486228943, 0.4834710657596588, 0.55... [0.3606557250022888, 0.4754098355770111, 0.573... 18 2053 114.152527
24 elu 50 0.012164 0.000047 96 73 0.786885 24 115.373385 [0.7974379062652588, 0.7888069748878479, 0.779... [0.8222996592521667, 0.8145594596862793, 0.806... [0.39669421315193176, 0.40082645416259766, 0.4... [0.37704917788505554, 0.39344263076782227, 0.3... 18 2775 117.593125
25 hard_sigmoid 244 0.215005 0.000582 13 56 0.229508 25 118.166421 [0.8906719088554382, 0.9096777439117432, 0.914... [0.9311334490776062, 0.9191083312034607, 0.907... [0.3140496015548706, 0.28925618529319763, 0.30... [0.2295081913471222, 0.2295081913471222, 0.229... 18 2129 119.727548
26 hard_sigmoid 228 0.195278 0.001142 16 48 0.770492 26 120.442391 [0.7672627568244934, 0.7242721319198608, 0.727... [0.7250207662582397, 0.701765239238739, 0.6803... [0.3636363744735718, 0.43801653385162354, 0.45... [0.26229506731033325, 0.4590163826942444, 0.62... 18 1825 121.826843
27 selu 160 0.546215 0.000789 13 11 0.688525 27 122.283213 [0.7485548257827759, 0.7522017955780029, 0.763... [0.7076205015182495, 0.6924997568130493, 0.678... [0.5950413346290588, 0.586776852607727, 0.6157... [0.6065573692321777, 0.6065573692321777, 0.622... 18 419 123.619181
28 linear 8 0.526041 0.000060 13 82 0.786885 28 124.196094 [0.8352355360984802, 0.7898195385932922, 0.765... [0.7591118812561035, 0.7175912857055664, 0.682... [0.43388429284095764, 0.5206611752510071, 0.49... [0.4590163826942444, 0.5245901346206665, 0.606... 18 3117 125.941018
29 gelu 256 0.188568 0.001515 21 8 0.737705 29 126.460014 [0.7105452418327332, 0.7035156488418579, 0.698... [0.7025739550590515, 0.696033239364624, 0.6896... [0.5537189841270447, 0.5371900796890259, 0.574... [0.6229507923126221, 0.6393442749977112, 0.639... 18 305 128.180563
30 gelu 256 0.244248 0.001660 23 56 0.786885 30 128.727047 [0.6546282768249512, 0.6317970752716064, 0.637... [0.572919487953186, 0.556894838809967, 0.54188... [0.6818181872367859, 0.7066115736961365, 0.706... [0.7540983557701111, 0.7540983557701111, 0.754... 18 2129 130.153736
31 linear 224 0.305334 0.000660 44 61 0.819672 31 130.919080 [0.9484014511108398, 0.9179948568344116, 0.877... [0.924122154712677, 0.8768399953842163, 0.8322... [0.32231405377388, 0.28925618529319763, 0.3512... [0.19672131538391113, 0.19672131538391113, 0.2... 18 2319 132.589723
32 linear 50 0.479327 0.000068 96 17 0.737705 32 133.278981 [0.9485551118850708, 0.9460268616676331, 0.900... [0.8557782769203186, 0.8498371243476868, 0.843... [0.5123966932296753, 0.5247933864593506, 0.483... [0.5245901346206665, 0.5245901346206665, 0.524... 18 647 135.438342
33 selu 164 0.561031 0.000726 14 55 0.819672 33 136.143825 [0.8745442628860474, 0.8545739054679871, 0.796... [0.7864746451377869, 0.7405412793159485, 0.699... [0.4752066135406494, 0.5, 0.5247933864593506, ... [0.4262295067310333, 0.5081967115402222, 0.606... 18 2091 137.718102
34 selu 254 0.322091 0.001231 12 47 0.819672 34 138.606087 [0.7325422763824463, 0.7275435328483582, 0.714... [0.6849062442779541, 0.6536070108413696, 0.624... [0.6115702390670776, 0.56611567735672, 0.59090... [0.5573770403862, 0.5901639461517334, 0.622950... 18 1787 140.006373
35 gelu 96 0.235332 0.000678 35 49 0.819672 35 140.362036 [0.6482117176055908, 0.640116274356842, 0.6336... [0.6331732869148254, 0.6124953031539917, 0.593... [0.64462810754776, 0.6776859760284424, 0.70247... [0.7377049326896667, 0.7049180269241333, 0.721... 18 1863 141.949003
36 linear 50 0.327453 0.000119 96 81 0.819672 36 142.978975 [0.7736309170722961, 0.6837102770805359, 0.716... [0.6871227622032166, 0.6735551953315735, 0.660... [0.6033057570457458, 0.6776859760284424, 0.623... [0.7213114500045776, 0.7213114500045776, 0.770... 18 3079 145.525915
37 gelu 134 0.098413 0.000580 22 54 0.786885 37 146.242206 [0.6934930682182312, 0.6824401617050171, 0.672... [0.7072174549102783, 0.6886430382728577, 0.670... [0.4793388545513153, 0.5206611752510071, 0.582... [0.4590163826942444, 0.5245901346206665, 0.655... 18 2053 147.982779
38 hard_sigmoid 244 0.317880 0.001168 45 53 0.786885 38 148.651484 [0.9446930289268494, 0.9471455216407776, 0.897... [0.9400205612182617, 0.9183498620986938, 0.897... [0.3140496015548706, 0.3305785059928894, 0.322... [0.2295081913471222, 0.2295081913471222, 0.229... 18 2015 150.211455
39 linear 11 0.539959 0.000073 71 42 0.836066 39 151.001553 [0.7638189792633057, 0.768632173538208, 0.8030... [0.7098137140274048, 0.6863061785697937, 0.662... [0.5289255976676941, 0.5, 0.5289255976676941, ... [0.5573770403862, 0.5573770403862, 0.557377040... 18 1597 153.791880
40 linear 9 0.594717 0.000072 75 22 0.836066 40 154.191006 [0.9886115193367004, 0.8805288672447205, 0.913... [0.8028327226638794, 0.7799118161201477, 0.759... [0.45041322708129883, 0.4958677589893341, 0.45... [0.49180328845977783, 0.5245901346206665, 0.54... 18 837 157.699540
41 linear 9 0.583705 0.000074 70 11 0.704918 41 158.364740 [1.2854955196380615, 1.2296828031539917, 1.140... [1.1728463172912598, 1.1491509675979614, 1.125... [0.2975206673145294, 0.3429751992225647, 0.371... [0.2295081913471222, 0.26229506731033325, 0.27... 18 419 161.862414
42 linear 10 0.579446 0.000080 70 21 0.819672 42 162.550486 [0.7621033191680908, 0.7450456023216248, 0.790... [0.6448240280151367, 0.630098819732666, 0.6165... [0.5619834661483765, 0.5495867729187012, 0.570... [0.5901639461517334, 0.6065573692321777, 0.639... 18 799 165.567682
43 linear 11 0.542695 0.000010 72 41 0.704918 43 166.191566 [0.7994367480278015, 0.7612443566322327, 0.755... [0.701274037361145, 0.6975817084312439, 0.6937... [0.5082644820213318, 0.5454545617103577, 0.557... [0.5245901346206665, 0.5245901346206665, 0.540... 18 1559 169.025697
44 hard_sigmoid 9 0.531649 0.000072 72 56 0.803279 44 169.882809 [0.6588014364242554, 0.6326671242713928, 0.658... [0.5621706247329712, 0.5559288263320923, 0.550... [0.6115702390670776, 0.6652892827987671, 0.657... [0.7704917788505554, 0.7704917788505554, 0.770... 18 2129 172.994597
45 selu 149 0.537188 0.001174 14 30 0.836066 45 173.636141 [0.9731682538986206, 0.9632638692855835, 0.834... [0.7772341370582581, 0.7184216976165771, 0.666... [0.41322314739227295, 0.4586776793003082, 0.50... [0.4262295067310333, 0.49180328845977783, 0.59... 18 1141 175.037412
46 linear 49 0.011794 0.000046 96 69 0.754098 46 175.385063 [0.9058528542518616, 0.8961911797523499, 0.890... [0.9788206815719604, 0.9681548476219177, 0.957... [0.3677685856819153, 0.3595041334629059, 0.371... [0.26229506731033325, 0.26229506731033325, 0.2... 18 2623 177.826115
47 selu 157 0.534877 0.001165 16 23 0.803279 47 178.533461 [1.0175780057907104, 0.9338928461074829, 0.922... [0.9525404572486877, 0.9009444713592529, 0.852... [0.42148759961128235, 0.46694216132164, 0.4710... [0.32786884903907776, 0.4098360538482666, 0.45... 18 875 179.892400
48 gelu 86 0.196295 0.002204 11 54 0.819672 48 180.589500 [0.5843045711517334, 0.5370311141014099, 0.489... [0.5310717821121216, 0.48231208324432373, 0.44... [0.7355371713638306, 0.7685950398445129, 0.789... [0.7868852615356445, 0.7868852615356445, 0.754... 18 2053 181.977024
49 linear 9 0.591808 0.000048 80 20 0.819672 49 182.668562 [1.0278195142745972, 1.0783154964447021, 1.051... [0.9265504479408264, 0.9104623198509216, 0.894... [0.42975205183029175, 0.37603306770324707, 0.3... [0.3606557250022888, 0.3606557250022888, 0.377... 18 761 185.918045

The search can be continued without any issue.

[31]:
results = search.search(max_evals=5)

results
[31]:
p:activation p:batch_size p:dropout_rate p:learning_rate p:num_epochs p:units objective job_id m:timestamp_submit m:loss m:val_loss m:accuracy m:val_accuracy m:num_parameters m:num_parameters_train m:timestamp_gather
0 relu 32 0.500000 0.001000 50 32 0.803279 0 50.325637 [0.706670343875885, 0.6361508369445801, 0.5850... [0.6415352821350098, 0.5993135571479797, 0.564... [0.5123966932296753, 0.6280992031097412, 0.710... [0.6721311211585999, 0.7540983557701111, 0.770... 18 1217 56.067836
1 linear 114 0.562393 0.002370 47 128 0.803279 1 56.123101 [0.6474491953849792, 0.49577072262763977, 0.41... [0.45269274711608887, 0.39127492904663086, 0.3... [0.6652892827987671, 0.7685950398445129, 0.805... [0.7868852615356445, 0.7704917788505554, 0.819... 18 4865 57.814659
2 linear 9 0.533665 0.000075 69 84 0.836066 2 58.013101 [0.6138803958892822, 0.5668570399284363, 0.563... [0.5532547235488892, 0.533083975315094, 0.5151... [0.6776859760284424, 0.7438016533851624, 0.706... [0.7868852615356445, 0.8032786846160889, 0.803... 18 3193 61.268282
3 linear 230 0.319679 0.001157 45 59 0.836066 3 61.298386 [0.7759286761283875, 0.6881203055381775, 0.633... [0.6713256239891052, 0.6128526329994202, 0.564... [0.5123966932296753, 0.6198347210884094, 0.669... [0.6065573692321777, 0.688524603843689, 0.7049... 18 2243 63.091872
4 linear 50 0.011336 0.000046 96 79 0.852459 4 63.144296 [0.7343717813491821, 0.7315843105316162, 0.720... [0.753821849822998, 0.7444661855697632, 0.7352... [0.4793388545513153, 0.4958677589893341, 0.537... [0.39344263076782227, 0.44262295961380005, 0.4... 18 3003 65.508172
5 softsign 68 0.194614 0.000110 71 98 0.770492 5 65.540140 [0.7060982584953308, 0.7222269773483276, 0.696... [0.706885576248169, 0.6966867446899414, 0.6866... [0.5041322112083435, 0.4628099203109741, 0.528... [0.5245901346206665, 0.5573770403862, 0.573770... 18 3725 67.730778
6 linear 12 0.032802 0.003319 94 124 0.803279 6 67.762288 [0.3936194181442261, 0.298471599817276, 0.2840... [0.44452720880508423, 0.4467885494232178, 0.46... [0.797520637512207, 0.8595041036605835, 0.8636... [0.8524590134620667, 0.8360655903816223, 0.819... 18 4713 70.985557
7 gelu 85 0.212161 0.002263 29 53 0.836066 7 71.017057 [0.7266819477081299, 0.650373101234436, 0.6081... [0.6597949862480164, 0.5888913869857788, 0.532... [0.4917355477809906, 0.6652892827987671, 0.714... [0.6721311211585999, 0.7540983557701111, 0.786... 18 2015 72.507068
8 selu 164 0.540524 0.001172 14 48 0.836066 8 72.537765 [0.9216533899307251, 0.8793801665306091, 0.743... [0.7923454642295837, 0.7235696315765381, 0.663... [0.44214877486228943, 0.4876033067703247, 0.53... [0.5245901346206665, 0.5409836173057556, 0.639... 18 1825 74.019521
9 relu 139 0.356863 0.000101 59 94 0.737705 9 74.053959 [0.6047406196594238, 0.6085638999938965, 0.576... [0.6019086241722107, 0.598895788192749, 0.5958... [0.6652892827987671, 0.6776859760284424, 0.739... [0.6721311211585999, 0.6721311211585999, 0.672... 18 3573 75.723831
10 linear 148 0.005233 0.000012 100 91 0.557377 10 76.229987 [0.8135287761688232, 0.8092227578163147, 0.812... [0.7794104218482971, 0.7784043550491333, 0.777... [0.41735535860061646, 0.42561984062194824, 0.4... [0.39344263076782227, 0.39344263076782227, 0.3... 18 3459 78.200037
11 linear 50 0.014352 0.000025 100 79 0.786885 11 78.871375 [0.9058765769004822, 0.8957070708274841, 0.892... [0.9036123752593994, 0.8989002704620361, 0.894... [0.2933884263038635, 0.3264462947845459, 0.314... [0.32786884903907776, 0.32786884903907776, 0.3... 18 3003 81.243957
12 linear 109 0.000487 0.000045 96 79 0.803279 12 81.892266 [0.8505944013595581, 0.8450313210487366, 0.839... [0.8944583535194397, 0.888396143913269, 0.8825... [0.2975206673145294, 0.3016528785228729, 0.305... [0.26229506731033325, 0.26229506731033325, 0.2... 18 3003 83.936618
13 relu 50 0.011727 0.000049 94 79 0.786885 13 84.554558 [0.772240400314331, 0.7702962160110474, 0.7616... [0.7718428373336792, 0.7672584652900696, 0.762... [0.35537189245224, 0.35537189245224, 0.3801652... [0.3442623019218445, 0.3442623019218445, 0.360... 18 3003 86.888750
14 linear 48 0.010876 0.000047 96 86 0.786885 14 87.831352 [0.7737085223197937, 0.7620970010757446, 0.758... [0.7752880454063416, 0.7666113972663879, 0.757... [0.41322314739227295, 0.43388429284095764, 0.4... [0.44262295961380005, 0.4754098355770111, 0.52... 18 3269 90.291341
15 linear 50 0.012980 0.000046 96 81 0.852459 15 91.166686 [0.6715313196182251, 0.6666457056999207, 0.660... [0.5990051031112671, 0.5933732986450195, 0.587... [0.6404958963394165, 0.6611570119857788, 0.669... [0.6721311211585999, 0.6721311211585999, 0.704... 18 3079 93.352006
16 linear 49 0.010183 0.000047 98 81 0.803279 16 94.176494 [0.9817445278167725, 0.9637849926948547, 0.955... [1.0305403470993042, 1.0171641111373901, 1.003... [0.32231405377388, 0.3388429880142212, 0.35950... [0.3606557250022888, 0.3606557250022888, 0.360... 18 3079 96.356579
17 linear 51 0.011189 0.000035 95 81 0.754098 17 97.268516 [0.8748018741607666, 0.8668413162231445, 0.861... [0.8491439819335938, 0.8422006368637085, 0.835... [0.3264462947845459, 0.3181818127632141, 0.347... [0.2786885201931, 0.2950819730758667, 0.295081... 18 3079 99.473332
18 linear 50 0.019094 0.000046 95 79 0.836066 18 100.300250 [0.6110990047454834, 0.6006994843482971, 0.598... [0.5857220888137817, 0.580162525177002, 0.5746... [0.7438016533851624, 0.7479338645935059, 0.760... [0.8360655903816223, 0.8360655903816223, 0.836... 18 3003 102.724646
19 linear 51 0.012842 0.000041 96 81 0.803279 19 103.713039 [0.6344509720802307, 0.6267118453979492, 0.626... [0.6352565884590149, 0.6293896436691284, 0.623... [0.6900826692581177, 0.6983470916748047, 0.710... [0.6557376980781555, 0.6557376980781555, 0.655... 18 3079 105.967294
20 relu 171 0.508927 0.001910 10 83 0.770492 20 106.621811 [0.7741369009017944, 0.6997448801994324, 0.641... [0.6997786164283752, 0.6350088119506836, 0.582... [0.46694216132164, 0.557851254940033, 0.648760... [0.5737704634666443, 0.6557376980781555, 0.737... 18 3155 107.922894
21 gelu 247 0.291055 0.000012 14 57 0.393443 21 108.269415 [0.7637456655502319, 0.7536334991455078, 0.767... [0.770416259765625, 0.7701885104179382, 0.7699... [0.4586776793003082, 0.42975205183029175, 0.44... [0.39344263076782227, 0.39344263076782227, 0.3... 18 2167 109.827998
22 elu 247 0.330785 0.000270 32 57 0.754098 22 110.438860 [0.7152367234230042, 0.7155708074569702, 0.698... [0.7250485420227051, 0.7178604006767273, 0.710... [0.5123966932296753, 0.4958677589893341, 0.570... [0.44262295961380005, 0.44262295961380005, 0.4... 18 2167 111.909659
23 gelu 256 0.174206 0.001193 22 54 0.836066 23 112.662811 [0.7355976700782776, 0.7147685885429382, 0.687... [0.7190053462982178, 0.6970692276954651, 0.676... [0.44214877486228943, 0.4834710657596588, 0.55... [0.3606557250022888, 0.4754098355770111, 0.573... 18 2053 114.152527
24 elu 50 0.012164 0.000047 96 73 0.786885 24 115.373385 [0.7974379062652588, 0.7888069748878479, 0.779... [0.8222996592521667, 0.8145594596862793, 0.806... [0.39669421315193176, 0.40082645416259766, 0.4... [0.37704917788505554, 0.39344263076782227, 0.3... 18 2775 117.593125
25 hard_sigmoid 244 0.215005 0.000582 13 56 0.229508 25 118.166421 [0.8906719088554382, 0.9096777439117432, 0.914... [0.9311334490776062, 0.9191083312034607, 0.907... [0.3140496015548706, 0.28925618529319763, 0.30... [0.2295081913471222, 0.2295081913471222, 0.229... 18 2129 119.727548
26 hard_sigmoid 228 0.195278 0.001142 16 48 0.770492 26 120.442391 [0.7672627568244934, 0.7242721319198608, 0.727... [0.7250207662582397, 0.701765239238739, 0.6803... [0.3636363744735718, 0.43801653385162354, 0.45... [0.26229506731033325, 0.4590163826942444, 0.62... 18 1825 121.826843
27 selu 160 0.546215 0.000789 13 11 0.688525 27 122.283213 [0.7485548257827759, 0.7522017955780029, 0.763... [0.7076205015182495, 0.6924997568130493, 0.678... [0.5950413346290588, 0.586776852607727, 0.6157... [0.6065573692321777, 0.6065573692321777, 0.622... 18 419 123.619181
28 linear 8 0.526041 0.000060 13 82 0.786885 28 124.196094 [0.8352355360984802, 0.7898195385932922, 0.765... [0.7591118812561035, 0.7175912857055664, 0.682... [0.43388429284095764, 0.5206611752510071, 0.49... [0.4590163826942444, 0.5245901346206665, 0.606... 18 3117 125.941018
29 gelu 256 0.188568 0.001515 21 8 0.737705 29 126.460014 [0.7105452418327332, 0.7035156488418579, 0.698... [0.7025739550590515, 0.696033239364624, 0.6896... [0.5537189841270447, 0.5371900796890259, 0.574... [0.6229507923126221, 0.6393442749977112, 0.639... 18 305 128.180563
30 gelu 256 0.244248 0.001660 23 56 0.786885 30 128.727047 [0.6546282768249512, 0.6317970752716064, 0.637... [0.572919487953186, 0.556894838809967, 0.54188... [0.6818181872367859, 0.7066115736961365, 0.706... [0.7540983557701111, 0.7540983557701111, 0.754... 18 2129 130.153736
31 linear 224 0.305334 0.000660 44 61 0.819672 31 130.919080 [0.9484014511108398, 0.9179948568344116, 0.877... [0.924122154712677, 0.8768399953842163, 0.8322... [0.32231405377388, 0.28925618529319763, 0.3512... [0.19672131538391113, 0.19672131538391113, 0.2... 18 2319 132.589723
32 linear 50 0.479327 0.000068 96 17 0.737705 32 133.278981 [0.9485551118850708, 0.9460268616676331, 0.900... [0.8557782769203186, 0.8498371243476868, 0.843... [0.5123966932296753, 0.5247933864593506, 0.483... [0.5245901346206665, 0.5245901346206665, 0.524... 18 647 135.438342
33 selu 164 0.561031 0.000726 14 55 0.819672 33 136.143825 [0.8745442628860474, 0.8545739054679871, 0.796... [0.7864746451377869, 0.7405412793159485, 0.699... [0.4752066135406494, 0.5, 0.5247933864593506, ... [0.4262295067310333, 0.5081967115402222, 0.606... 18 2091 137.718102
34 selu 254 0.322091 0.001231 12 47 0.819672 34 138.606087 [0.7325422763824463, 0.7275435328483582, 0.714... [0.6849062442779541, 0.6536070108413696, 0.624... [0.6115702390670776, 0.56611567735672, 0.59090... [0.5573770403862, 0.5901639461517334, 0.622950... 18 1787 140.006373
35 gelu 96 0.235332 0.000678 35 49 0.819672 35 140.362036 [0.6482117176055908, 0.640116274356842, 0.6336... [0.6331732869148254, 0.6124953031539917, 0.593... [0.64462810754776, 0.6776859760284424, 0.70247... [0.7377049326896667, 0.7049180269241333, 0.721... 18 1863 141.949003
36 linear 50 0.327453 0.000119 96 81 0.819672 36 142.978975 [0.7736309170722961, 0.6837102770805359, 0.716... [0.6871227622032166, 0.6735551953315735, 0.660... [0.6033057570457458, 0.6776859760284424, 0.623... [0.7213114500045776, 0.7213114500045776, 0.770... 18 3079 145.525915
37 gelu 134 0.098413 0.000580 22 54 0.786885 37 146.242206 [0.6934930682182312, 0.6824401617050171, 0.672... [0.7072174549102783, 0.6886430382728577, 0.670... [0.4793388545513153, 0.5206611752510071, 0.582... [0.4590163826942444, 0.5245901346206665, 0.655... 18 2053 147.982779
38 hard_sigmoid 244 0.317880 0.001168 45 53 0.786885 38 148.651484 [0.9446930289268494, 0.9471455216407776, 0.897... [0.9400205612182617, 0.9183498620986938, 0.897... [0.3140496015548706, 0.3305785059928894, 0.322... [0.2295081913471222, 0.2295081913471222, 0.229... 18 2015 150.211455
39 linear 11 0.539959 0.000073 71 42 0.836066 39 151.001553 [0.7638189792633057, 0.768632173538208, 0.8030... [0.7098137140274048, 0.6863061785697937, 0.662... [0.5289255976676941, 0.5, 0.5289255976676941, ... [0.5573770403862, 0.5573770403862, 0.557377040... 18 1597 153.791880
40 linear 9 0.594717 0.000072 75 22 0.836066 40 154.191006 [0.9886115193367004, 0.8805288672447205, 0.913... [0.8028327226638794, 0.7799118161201477, 0.759... [0.45041322708129883, 0.4958677589893341, 0.45... [0.49180328845977783, 0.5245901346206665, 0.54... 18 837 157.699540
41 linear 9 0.583705 0.000074 70 11 0.704918 41 158.364740 [1.2854955196380615, 1.2296828031539917, 1.140... [1.1728463172912598, 1.1491509675979614, 1.125... [0.2975206673145294, 0.3429751992225647, 0.371... [0.2295081913471222, 0.26229506731033325, 0.27... 18 419 161.862414
42 linear 10 0.579446 0.000080 70 21 0.819672 42 162.550486 [0.7621033191680908, 0.7450456023216248, 0.790... [0.6448240280151367, 0.630098819732666, 0.6165... [0.5619834661483765, 0.5495867729187012, 0.570... [0.5901639461517334, 0.6065573692321777, 0.639... 18 799 165.567682
43 linear 11 0.542695 0.000010 72 41 0.704918 43 166.191566 [0.7994367480278015, 0.7612443566322327, 0.755... [0.701274037361145, 0.6975817084312439, 0.6937... [0.5082644820213318, 0.5454545617103577, 0.557... [0.5245901346206665, 0.5245901346206665, 0.540... 18 1559 169.025697
44 hard_sigmoid 9 0.531649 0.000072 72 56 0.803279 44 169.882809 [0.6588014364242554, 0.6326671242713928, 0.658... [0.5621706247329712, 0.5559288263320923, 0.550... [0.6115702390670776, 0.6652892827987671, 0.657... [0.7704917788505554, 0.7704917788505554, 0.770... 18 2129 172.994597
45 selu 149 0.537188 0.001174 14 30 0.836066 45 173.636141 [0.9731682538986206, 0.9632638692855835, 0.834... [0.7772341370582581, 0.7184216976165771, 0.666... [0.41322314739227295, 0.4586776793003082, 0.50... [0.4262295067310333, 0.49180328845977783, 0.59... 18 1141 175.037412
46 linear 49 0.011794 0.000046 96 69 0.754098 46 175.385063 [0.9058528542518616, 0.8961911797523499, 0.890... [0.9788206815719604, 0.9681548476219177, 0.957... [0.3677685856819153, 0.3595041334629059, 0.371... [0.26229506731033325, 0.26229506731033325, 0.2... 18 2623 177.826115
47 selu 157 0.534877 0.001165 16 23 0.803279 47 178.533461 [1.0175780057907104, 0.9338928461074829, 0.922... [0.9525404572486877, 0.9009444713592529, 0.852... [0.42148759961128235, 0.46694216132164, 0.4710... [0.32786884903907776, 0.4098360538482666, 0.45... 18 875 179.892400
48 gelu 86 0.196295 0.002204 11 54 0.819672 48 180.589500 [0.5843045711517334, 0.5370311141014099, 0.489... [0.5310717821121216, 0.48231208324432373, 0.44... [0.7355371713638306, 0.7685950398445129, 0.789... [0.7868852615356445, 0.7868852615356445, 0.754... 18 2053 181.977024
49 linear 9 0.591808 0.000048 80 20 0.819672 49 182.668562 [1.0278195142745972, 1.0783154964447021, 1.051... [0.9265504479408264, 0.9104623198509216, 0.894... [0.42975205183029175, 0.37603306770324707, 0.3... [0.3606557250022888, 0.3606557250022888, 0.377... 18 761 185.918045
50 linear 9 0.591808 0.000048 80 20 0.770492 50 306.613144 [0.7385859489440918, 0.6464691758155823, 0.692... [0.61632239818573, 0.6080429553985596, 0.60057... [0.5909090638160706, 0.6570248007774353, 0.590... [0.6721311211585999, 0.6721311211585999, 0.704... 18 761 309.925543
51 elu 79 0.210961 0.002270 29 53 0.803279 51 310.659269 [0.7370269298553467, 0.6033796668052673, 0.516... [0.6075996160507202, 0.5017809271812439, 0.442... [0.42975205183029175, 0.6570248007774353, 0.76... [0.7213114500045776, 0.8360655903816223, 0.852... 18 2015 312.720539
52 softplus 148 0.547227 0.001533 13 30 0.770492 52 313.843650 [1.0147786140441895, 0.9364703297615051, 0.922... [0.8291641473770142, 0.7705664038658142, 0.717... [0.44628098607063293, 0.4586776793003082, 0.49... [0.3606557250022888, 0.4754098355770111, 0.590... 18 1141 315.235044
53 linear 18 0.574718 0.000072 74 42 0.836066 53 315.653283 [0.7922735810279846, 0.7768206596374512, 0.744... [0.6496352553367615, 0.6359923481941223, 0.622... [0.5206611752510071, 0.5454545617103577, 0.541... [0.6229507923126221, 0.6065573692321777, 0.655... 18 1597 318.200261
54 linear 11 0.581862 0.000073 96 42 0.836066 54 319.290598 [0.8693579435348511, 0.8212579488754272, 0.779... [0.8195937275886536, 0.7894597053527832, 0.761... [0.4545454680919647, 0.5247933864593506, 0.491... [0.44262295961380005, 0.44262295961380005, 0.4... 18 1597 322.698709

Now that the search is over, we can look at how the performance of the explored models evolved.

[34]:
import matplotlib.pyplot as plt
from deephyper.analysis.hpo import plot_search_trajectory_single_objective_hpo

fig, ax = plot_search_trajectory_single_objective_hpo(results)
plt.title("Search Trajectory")
plt.show()
[Figure: Search Trajectory]

Then, we can look at the best configuration.

[49]:
i_max = results.objective.argmax()
best_job = results.iloc[i_max].to_dict()


print(f"The default configuration has an accuracy of {objective_default:.3f}. \n"
      f"The best configuration found by DeepHyper has an accuracy {results['objective'].iloc[i_max]:.3f}, \n"
      f"discovered after {results['m:timestamp_gather'].iloc[i_max]:.2f} secondes of search.\n")

best_job
The default configuration has an accuracy of 0.803.
The best configuration found by DeepHyper has an accuracy of 0.852,
discovered after 65.51 seconds of search.

[49]:
{'p:activation': 'linear',
 'p:batch_size': 50,
 'p:dropout_rate': 0.0113359066468359,
 'p:learning_rate': 4.570626021696808e-05,
 'p:num_epochs': 96,
 'p:units': 79,
 'objective': 0.8524590134620667,
 'job_id': 4,
 'm:timestamp_submit': 63.14429593086243,
 'm:loss': '[0.7343717813491821, 0.7315843105316162, 0.7203551530838013, 0.7077977657318115, 0.7045767903327942, 0.6942116022109985, 0.6890315413475037, 0.684148907661438, 0.6750453114509583, 0.6661133170127869, 0.6599052548408508, 0.6488535404205322, 0.6500421166419983, 0.646197497844696, 0.6356723308563232, 0.6293549537658691, 0.6263395547866821, 0.6230587363243103, 0.6143192648887634, 0.6077391505241394, 0.6039666533470154, 0.599918782711029, 0.5972365736961365, 0.5854122638702393, 0.5852221846580505, 0.5819635391235352, 0.5733869671821594, 0.5693132281303406, 0.5666898488998413, 0.5630242228507996, 0.5568211078643799, 0.5544881820678711, 0.5535055994987488, 0.5494033694267273, 0.544603705406189, 0.5400131940841675, 0.5392252802848816, 0.5353421568870544, 0.531427800655365, 0.5265730023384094, 0.5229849815368652, 0.522513747215271, 0.5185511708259583, 0.510654091835022, 0.510829508304596, 0.5096814036369324, 0.5069150328636169, 0.5036590695381165, 0.4942742884159088, 0.4953521490097046, 0.4973622262477875, 0.4932021498680115, 0.48927953839302063, 0.48885059356689453, 0.4867109954357147, 0.4817766845226288, 0.4790145754814148, 0.48116615414619446, 0.47567224502563477, 0.47586676478385925, 0.474057674407959, 0.4685772955417633, 0.46800220012664795, 0.46771240234375, 0.46357816457748413, 0.4627784490585327, 0.46076714992523193, 0.46169304847717285, 0.45370250940322876, 0.4523642659187317, 0.452671080827713, 0.45055699348449707, 0.4527966380119324, 0.4467359185218811, 0.4463013708591461, 0.44568803906440735, 0.44608384370803833, 0.44155585765838623, 0.4365282952785492, 0.43844056129455566, 0.4361962378025055, 0.43178996443748474, 0.43474602699279785, 0.4314608871936798, 0.4278643727302551, 0.4289937913417816, 0.42474204301834106, 0.4234481155872345, 0.42462724447250366, 0.4230608344078064, 0.42143714427948, 0.4172115623950958, 0.4172593951225281, 0.41717463731765747, 0.41767194867134094, 0.4153216481208801]',
 'm:val_loss': '[0.753821849822998, 0.7444661855697632, 0.7352833151817322, 0.7264078855514526, 0.7177568674087524, 0.709143877029419, 0.7008553147315979, 0.6926632523536682, 0.6846651434898376, 0.6767299175262451, 0.6690924763679504, 0.6618218421936035, 0.6545751690864563, 0.6477261185646057, 0.6409757137298584, 0.6344925761222839, 0.6279751658439636, 0.6217548251152039, 0.6155738234519958, 0.6095430850982666, 0.6038309335708618, 0.5982047319412231, 0.5927146673202515, 0.5874735116958618, 0.582288920879364, 0.5773451924324036, 0.5724054574966431, 0.5675630569458008, 0.562959611415863, 0.5584943294525146, 0.5539917349815369, 0.5496424436569214, 0.545595645904541, 0.5415534377098083, 0.537651002407074, 0.5336986780166626, 0.5300918221473694, 0.526334822177887, 0.5227644443511963, 0.5193122029304504, 0.5160749554634094, 0.5127428770065308, 0.5095983743667603, 0.5065170526504517, 0.5034927129745483, 0.5005301833152771, 0.49770525097846985, 0.4948558509349823, 0.4920128583908081, 0.4893036186695099, 0.4867943823337555, 0.48432719707489014, 0.4818807542324066, 0.4794256389141083, 0.47698256373405457, 0.4746597409248352, 0.47240322828292847, 0.47020772099494934, 0.46811091899871826, 0.4660354256629944, 0.46405380964279175, 0.46211710572242737, 0.4601345956325531, 0.4583086669445038, 0.4563702642917633, 0.4545257091522217, 0.4527909457683563, 0.4510558843612671, 0.4493982195854187, 0.44778916239738464, 0.4462117552757263, 0.4445821940898895, 0.44306841492652893, 0.4415777623653412, 0.4401895999908447, 0.4388270080089569, 0.43738314509391785, 0.4360154867172241, 0.43456223607063293, 0.4332718551158905, 0.4320412278175354, 0.4307521879673004, 0.4295240044593811, 0.42836472392082214, 0.4272339642047882, 0.42612510919570923, 0.4250316321849823, 0.42397260665893555, 0.42292481660842896, 0.4219026267528534, 0.42088067531585693, 0.41994309425354004, 0.41888922452926636, 0.41796427965164185, 0.4170234501361847, 0.4160982072353363]',
 'm:accuracy': '[0.4793388545513153, 0.4958677589893341, 0.5371900796890259, 0.5413222908973694, 0.5743801593780518, 0.5743801593780518, 0.586776852607727, 0.5743801593780518, 0.5950413346290588, 0.6115702390670776, 0.6363636255264282, 0.64462810754776, 0.6404958963394165, 0.6363636255264282, 0.6611570119857788, 0.6735537052154541, 0.6611570119857788, 0.6776859760284424, 0.6776859760284424, 0.6818181872367859, 0.6818181872367859, 0.6818181872367859, 0.6942148804664612, 0.702479362487793, 0.71074378490448, 0.71074378490448, 0.71074378490448, 0.7272727489471436, 0.7355371713638306, 0.7396694421768188, 0.7479338645935059, 0.7479338645935059, 0.7396694421768188, 0.7396694421768188, 0.7355371713638306, 0.7685950398445129, 0.7644628286361694, 0.7561983466148376, 0.7644628286361694, 0.7851239442825317, 0.78925621509552, 0.78925621509552, 0.797520637512207, 0.797520637512207, 0.7933884263038635, 0.7809917330741882, 0.797520637512207, 0.797520637512207, 0.7933884263038635, 0.797520637512207, 0.797520637512207, 0.7933884263038635, 0.797520637512207, 0.8016529083251953, 0.78925621509552, 0.8057851195335388, 0.8140496015548706, 0.8099173307418823, 0.8016529083251953, 0.8016529083251953, 0.8140496015548706, 0.8099173307418823, 0.8181818127632141, 0.78925621509552, 0.8223140239715576, 0.8016529083251953, 0.8181818127632141, 0.8140496015548706, 0.8181818127632141, 0.8140496015548706, 0.8099173307418823, 0.8140496015548706, 0.8099173307418823, 0.8140496015548706, 0.8016529083251953, 0.8099173307418823, 0.8140496015548706, 0.8181818127632141, 0.8181818127632141, 0.8140496015548706, 0.8140496015548706, 0.8223140239715576, 0.8140496015548706, 0.8223140239715576, 0.8181818127632141, 0.8181818127632141, 0.8223140239715576, 0.8223140239715576, 0.8140496015548706, 0.8223140239715576, 0.8305785059928894, 0.8388429880142212, 0.8099173307418823, 0.8181818127632141, 0.8223140239715576, 0.8181818127632141]',
 'm:val_accuracy': '[0.39344263076782227, 0.44262295961380005, 0.4590163826942444, 0.4590163826942444, 0.49180328845977783, 0.49180328845977783, 0.5245901346206665, 0.5409836173057556, 0.5409836173057556, 0.5245901346206665, 0.5409836173057556, 0.5737704634666443, 0.5901639461517334, 0.5901639461517334, 0.6229507923126221, 0.6229507923126221, 0.6229507923126221, 0.6229507923126221, 0.6393442749977112, 0.6393442749977112, 0.6393442749977112, 0.6557376980781555, 0.6721311211585999, 0.6721311211585999, 0.688524603843689, 0.688524603843689, 0.7049180269241333, 0.7213114500045776, 0.7213114500045776, 0.7213114500045776, 0.7377049326896667, 0.7377049326896667, 0.7377049326896667, 0.7377049326896667, 0.7213114500045776, 0.7213114500045776, 0.7213114500045776, 0.7213114500045776, 0.7540983557701111, 0.7540983557701111, 0.7540983557701111, 0.7540983557701111, 0.7540983557701111, 0.7540983557701111, 0.7540983557701111, 0.7540983557701111, 0.7704917788505554, 0.7704917788505554, 0.7704917788505554, 0.7704917788505554, 0.7704917788505554, 0.7704917788505554, 0.7868852615356445, 0.7868852615356445, 0.7868852615356445, 0.7868852615356445, 0.7868852615356445, 0.7868852615356445, 0.7868852615356445, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667]',
 'm:num_parameters': 18,
 'm:num_parameters_train': 3003,
 'm:timestamp_gather': 65.50817203521729}
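If needed, the best configuration can be re-evaluated by hand. A minimal sketch, assuming the run-function run defined earlier is in scope: the parameter columns of the results are prefixed with "p:", so we strip that prefix to recover the keys the run-function expects. Since training is stochastic, the re-evaluated objective may differ slightly from the recorded one.

# Rebuild a plain config dict from the "p:"-prefixed columns (sketch)
best_config = {k[len("p:"):]: v for k, v in best_job.items() if k.startswith("p:")}
output = run(best_config)
print(f"Re-evaluated objective: {output['objective']:.3f}")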
[51]:
import matplotlib.pyplot as plt

plt.figure()
plt.plot(json.loads(metadata_default["val_accuracy"]), color="skyblue", label="Default (val)")
plt.plot(json.loads(metadata_default["accuracy"]), color="skyblue", linestyle="--", label="Default (train)")
plt.plot(json.loads(best_job["m:val_accuracy"]), color="coral", linewidth=2, label="Best Job(val)")
plt.plot(json.loads(best_job["m:accuracy"]), color="coral", linestyle="--", linewidth=2, label="Best Job (train)")
plt.legend()
plt.ylim(0.3, 0.9)
plt.grid()
plt.show()
../../../../_images/tutorials_tutorials_colab_HPS_basic_classification_with_tabular_data_notebook_37_0.png

2.9. Restart from a checkpoint#

It can often be useful to continue the search from previous results, for example, when the requested allocation ran out or an unexpected crash interrupted the run. The CBO search provides the fit_surrogate(dataframe_of_results) method for this use case.

To simulate this, we create a second evaluator evaluator_2 and start a fresh CBO search with strong exploitation, kappa=0.001 (kappa scales the weight given to the surrogate's predicted uncertainty in the acquisition function, so a small value biases the search toward exploiting the predicted mean).

[42]:
# Create a new evaluator
evaluator_2 = get_evaluator(run)

# Create a new CBO search with strong exploitation (i.e., small kappa)
search_from_checkpoint = CBO(problem, evaluator_2, kappa=0.001, acq_optimizer="mixedga", acq_optimizer_freq=1)

# Initialize the surrogate model of the Bayesian optimization
# with the results of the previous search
search_from_checkpoint.fit_surrogate(results)
WARNING:root:Results file already exists, it will be renamed to /Users/romainegele/Documents/Argonne/deephyper-tutorials/tutorials/colab/HPS_basic_classification_with_tabular_data/results_20240805-185838.csv
Created new evaluator with 1 worker and config: {'num_cpus': 1, 'num_cpus_per_task': 1, 'callbacks': [<deephyper.evaluator.callback.TqdmCallback object at 0x3e6806c60>]}
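In a real crash-recovery scenario the previous results dataframe would not be in memory; it can be reloaded from the CSV file written by the search. A minimal sketch, assuming the previous run saved its results to results.csv (the actual file name depends on your search directory):

import pandas as pd

# Reload previous results from disk and warm-start the surrogate (sketch)
previous_results = pd.read_csv("results.csv")
search_from_checkpoint.fit_surrogate(previous_results)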
[43]:
results_from_checkpoint = search_from_checkpoint.search(max_evals=25)
[44]:
results_from_checkpoint
[44]:
p:activation p:batch_size p:dropout_rate p:learning_rate p:num_epochs p:units objective job_id m:timestamp_submit m:loss m:val_loss m:accuracy m:val_accuracy m:num_parameters m:num_parameters_train m:timestamp_gather
0 linear 11 0.580120 0.000073 95 42 0.852459 0 4.672708 [0.8041970729827881, 0.7837277054786682, 0.789... [0.760534405708313, 0.736428439617157, 0.71280... [0.5247933864593506, 0.5, 0.5289255976676941, ... [0.5409836173057556, 0.5573770403862, 0.540983... 18 1597 7.694214
1 swish 11 0.580091 0.000073 95 42 0.803279 1 8.683520 [0.7580281496047974, 0.8090499639511108, 0.768... [0.7929872870445251, 0.7761178612709045, 0.758... [0.5206611752510071, 0.44628098607063293, 0.46... [0.3606557250022888, 0.37704917788505554, 0.40... 18 1597 12.092570
2 linear 12 0.580188 0.000036 95 50 0.786885 2 13.035506 [1.0248744487762451, 1.0645800828933716, 0.993... [0.9728235006332397, 0.9541761875152588, 0.937... [0.3347107470035553, 0.3099173605442047, 0.326... [0.24590164422988892, 0.26229506731033325, 0.2... 18 1901 16.033196
3 linear 11 0.579794 0.000073 95 72 0.836066 3 16.548202 [0.7747272849082947, 0.7421781420707703, 0.706... [0.6965786218643188, 0.6639910936355591, 0.633... [0.5, 0.5702479481697083, 0.5785123705863953, ... [0.5573770403862, 0.6557376980781555, 0.737704... 18 2737 19.609617
4 linear 19 0.580786 0.000068 95 43 0.803279 4 20.466129 [0.8868724703788757, 0.8863998651504517, 0.783... [0.7910361289978027, 0.7758963108062744, 0.761... [0.43801653385162354, 0.4586776793003082, 0.53... [0.4262295067310333, 0.4262295067310333, 0.442... 18 1635 23.013769
5 linear 11 0.421636 0.000072 95 41 0.836066 5 23.887383 [0.7451992034912109, 0.7273797988891602, 0.748... [0.5972851514816284, 0.581959068775177, 0.5678... [0.5702479481697083, 0.56611567735672, 0.57851... [0.688524603843689, 0.7213114500045776, 0.7213... 18 1559 27.051523
6 linear 11 0.001346 0.000073 93 27 0.852459 6 27.735532 [0.720139741897583, 0.6974911689758301, 0.6783... [0.664795994758606, 0.6413559913635254, 0.6196... [0.5289255976676941, 0.557851254940033, 0.6033... [0.5409836173057556, 0.5573770403862, 0.590163... 18 1027 31.108865
7 linear 11 0.015196 0.000073 93 8 0.836066 7 31.640157 [0.8719937205314636, 0.8656114339828491, 0.857... [0.8745443820953369, 0.8604922890663147, 0.846... [0.5041322112083435, 0.5123966932296753, 0.516... [0.44262295961380005, 0.4590163826942444, 0.45... 18 305 34.892566
8 linear 11 0.000146 0.000073 85 12 0.836066 8 35.989811 [0.9494009613990784, 0.9227721095085144, 0.896... [0.9361185431480408, 0.9073224663734436, 0.880... [0.3057851195335388, 0.3057851195335388, 0.334... [0.3606557250022888, 0.37704917788505554, 0.37... 18 457 38.953501
9 linear 8 0.000002 0.000073 93 27 0.836066 9 39.846366 [0.8590447902679443, 0.815822958946228, 0.7771... [0.8581119179725647, 0.813581109046936, 0.7709... [0.42975205183029175, 0.43801653385162354, 0.4... [0.4754098355770111, 0.49180328845977783, 0.52... 18 1027 43.352862
10 swish 9 0.001305 0.000074 30 28 0.803279 10 43.974943 [0.8331950306892395, 0.8155596256256104, 0.798... [0.8349722027778625, 0.8154217004776001, 0.796... [0.3057851195335388, 0.3181818127632141, 0.330... [0.26229506731033325, 0.2786885201931, 0.29508... 18 1065 45.914327
11 linear 9 0.002338 0.000073 96 15 0.819672 11 46.717739 [0.5619564056396484, 0.5478906035423279, 0.538... [0.6027979254722595, 0.5921545624732971, 0.582... [0.7272727489471436, 0.7396694421768188, 0.752... [0.7213114500045776, 0.7049180269241333, 0.721... 18 571 50.088631
12 linear 11 0.001406 0.000047 92 27 0.819672 12 51.044316 [0.7591452598571777, 0.748119592666626, 0.7355... [0.7010340094566345, 0.6882603168487549, 0.676... [0.43801653385162354, 0.46694216132164, 0.4876... [0.6065573692321777, 0.6065573692321777, 0.639... 18 1027 54.554788
13 linear 11 0.001531 0.000073 93 65 0.819672 13 55.431078 [0.7318987846374512, 0.6924680471420288, 0.658... [0.7065127491950989, 0.6669818162918091, 0.632... [0.5041322112083435, 0.5702479481697083, 0.619... [0.5409836173057556, 0.5901639461517334, 0.688... 18 2471 58.463386
14 selu 11 0.051216 0.000073 93 27 0.852459 14 59.775333 [0.7977746725082397, 0.7680606245994568, 0.733... [0.8183743357658386, 0.7818477153778076, 0.746... [0.4793388545513153, 0.5206611752510071, 0.553... [0.4754098355770111, 0.4754098355770111, 0.540... 18 1027 62.897967
15 selu 11 0.599294 0.000072 93 9 0.819672 15 63.684615 [0.7609134912490845, 0.8461434245109558, 0.769... [0.6055606007575989, 0.5996733903884888, 0.594... [0.56611567735672, 0.5041322112083435, 0.53305... [0.6229507923126221, 0.6393442749977112, 0.639... 18 343 66.736602
16 linear 12 0.002261 0.000073 93 27 0.819672 16 67.926674 [0.724230170249939, 0.6960060596466064, 0.6756... [0.7928970456123352, 0.7624304294586182, 0.734... [0.5, 0.5454545617103577, 0.5950413346290588, ... [0.4098360538482666, 0.44262295961380005, 0.47... 18 1027 71.041350
17 selu 8 0.051312 0.000074 75 11 0.836066 17 71.875614 [0.5699914693832397, 0.555569589138031, 0.5570... [0.5412390828132629, 0.5306512117385864, 0.520... [0.7685950398445129, 0.7603305578231812, 0.727... [0.7213114500045776, 0.7540983557701111, 0.754... 18 419 75.466510
18 softplus 11 0.074423 0.000073 94 29 0.770492 18 75.929389 [0.6149121522903442, 0.6076364517211914, 0.600... [0.6102054715156555, 0.5992701053619385, 0.589... [0.71074378490448, 0.71074378490448, 0.6983470... [0.7377049326896667, 0.7213114500045776, 0.721... 18 1103 79.442036
19 sigmoid 11 0.047733 0.000083 93 27 0.803279 19 80.129976 [0.7153695821762085, 0.7007089257240295, 0.698... [0.5909143090248108, 0.5860286355018616, 0.581... [0.7148760557174683, 0.7148760557174683, 0.714... [0.7704917788505554, 0.7704917788505554, 0.770... 18 1027 83.201642
20 selu 11 0.069503 0.000073 94 29 0.852459 20 84.463790 [1.1732215881347656, 1.1224826574325562, 1.058... [1.1363171339035034, 1.0872124433517456, 1.042... [0.3264462947845459, 0.3677685856819153, 0.338... [0.31147539615631104, 0.3606557250022888, 0.36... 18 1103 87.813287
21 selu 11 0.105616 0.000073 100 29 0.836066 21 88.577887 [0.5842286348342896, 0.5772862434387207, 0.562... [0.4800376296043396, 0.47080090641975403, 0.46... [0.7148760557174683, 0.71074378490448, 0.70661... [0.8524590134620667, 0.8524590134620667, 0.868... 18 1103 92.231860
22 selu 18 0.066401 0.000073 93 26 0.836066 22 93.099906 [1.31602144241333, 1.320502758026123, 1.265641... [1.3727202415466309, 1.3392741680145264, 1.307... [0.2975206673145294, 0.2851239740848541, 0.293... [0.26229506731033325, 0.2950819730758667, 0.29... 18 989 95.827970
23 selu 17 0.057178 0.000073 70 12 0.786885 23 97.016439 [0.7276808023452759, 0.7200209498405457, 0.720... [0.6873724460601807, 0.6807807087898254, 0.673... [0.5123966932296753, 0.5206611752510071, 0.537... [0.6229507923126221, 0.6393442749977112, 0.639... 18 457 99.855316
24 selu 11 0.058280 0.000081 94 27 0.803279 24 100.785771 [0.79051274061203, 0.7398812770843506, 0.72031... [0.8464049696922302, 0.811725378036499, 0.7795... [0.42561984062194824, 0.5165289044380188, 0.52... [0.37704917788505554, 0.4098360538482666, 0.45... 18 1027 103.862200
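Since the surrogate was warm-started, the checkpointed search reaches a good objective almost immediately. A quick way to visualize this is to compare the running best objective of the two searches; a sketch with pandas and matplotlib (assuming both dataframes are ordered by completion time, as returned by the search):

import matplotlib.pyplot as plt

plt.figure()
plt.plot(results["m:timestamp_gather"], results["objective"].cummax(), label="From scratch")
plt.plot(
    results_from_checkpoint["m:timestamp_gather"],
    results_from_checkpoint["objective"].cummax(),
    label="From checkpoint",
)
plt.xlabel("Time (s)")
plt.ylabel("Best objective so far")
plt.legend()
plt.grid()
plt.show()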
[52]:
i_max = results_from_checkpoint.objective.argmax()
best_job = results_from_checkpoint.iloc[i_max].to_dict()


print(f"The default configuration has an accuracy of {objective_default:.3f}. \n"
      f"The best configuration found by DeepHyper has an accuracy {results_from_checkpoint['objective'].iloc[i_max]:.3f}, \n"
      f"discovered after {results_from_checkpoint['m:timestamp_gather'].iloc[i_max]:.2f} secondes of search.\n")

best_job
The default configuration has an accuracy of 0.803.
The best configuration found by DeepHyper has an accuracy of 0.852,
discovered after 7.69 seconds of search.

[52]:
{'p:activation': 'linear',
 'p:batch_size': 11,
 'p:dropout_rate': 0.5801200292603258,
 'p:learning_rate': 7.270169218267874e-05,
 'p:num_epochs': 95,
 'p:units': 42,
 'objective': 0.8524590134620667,
 'job_id': 0,
 'm:timestamp_submit': 4.672708034515381,
 'm:loss': '[0.8041970729827881, 0.7837277054786682, 0.7897129058837891, 0.7576928734779358, 0.7714273929595947, 0.7486253380775452, 0.7247172594070435, 0.6684877872467041, 0.6340616941452026, 0.652440071105957, 0.6518965363502502, 0.6469079852104187, 0.6724880337715149, 0.5748777389526367, 0.5973073244094849, 0.5889704823493958, 0.6270110011100769, 0.541706919670105, 0.5298312902450562, 0.582248330116272, 0.49765029549598694, 0.5214836597442627, 0.47826045751571655, 0.5435822010040283, 0.5364992618560791, 0.5255090594291687, 0.4721006453037262, 0.4848361611366272, 0.49145030975341797, 0.5109071731567383, 0.5109598636627197, 0.4852293133735657, 0.47832751274108887, 0.48602572083473206, 0.48633140325546265, 0.44579175114631653, 0.4708606004714966, 0.4509105086326599, 0.4845390319824219, 0.4494284987449646, 0.43613725900650024, 0.434268057346344, 0.4538938105106354, 0.45544442534446716, 0.4280913472175598, 0.4577348232269287, 0.40012580156326294, 0.4429107904434204, 0.44394591450691223, 0.457697331905365, 0.4169984757900238, 0.42198798060417175, 0.4538344144821167, 0.4329397976398468, 0.408939391374588, 0.4074942469596863, 0.4367082118988037, 0.446547269821167, 0.4011521637439728, 0.40441566705703735, 0.36089783906936646, 0.4228549003601074, 0.38539719581604004, 0.41360726952552795, 0.3939768373966217, 0.39181941747665405, 0.38108402490615845, 0.34451887011528015, 0.4110186994075775, 0.3830220103263855, 0.3807251751422882, 0.3751121461391449, 0.3672732710838318, 0.39625462889671326, 0.39395615458488464, 0.37351155281066895, 0.381844699382782, 0.37353140115737915, 0.3986795246601105, 0.36958974599838257, 0.35957154631614685, 0.34944215416908264, 0.37648308277130127, 0.3538556694984436, 0.3683083951473236, 0.37798619270324707, 0.35513031482696533, 0.3352597653865814, 0.36496520042419434, 0.346407949924469, 0.3528498113155365, 0.34690144658088684, 0.3503521978855133, 0.3336299955844879, 0.39004188776016235]',
 'm:val_loss': '[0.760534405708313, 0.736428439617157, 0.7128077149391174, 0.6911187171936035, 0.6703223586082458, 0.6510680317878723, 0.6329866051673889, 0.6171879768371582, 0.6040096879005432, 0.5901919007301331, 0.576133668422699, 0.5650199055671692, 0.5530554056167603, 0.5425328612327576, 0.5322731733322144, 0.5232126116752625, 0.5143499970436096, 0.5066238641738892, 0.4990280270576477, 0.4916485548019409, 0.4852895140647888, 0.47908690571784973, 0.473391592502594, 0.467697411775589, 0.4625626802444458, 0.4571584463119507, 0.4523298442363739, 0.44821926951408386, 0.44439101219177246, 0.44111281633377075, 0.4379344880580902, 0.43438616394996643, 0.4314509332180023, 0.4287097752094269, 0.4262154996395111, 0.42408379912376404, 0.42122986912727356, 0.41910144686698914, 0.41701099276542664, 0.41472113132476807, 0.4127712547779083, 0.410778284072876, 0.40907588601112366, 0.4074136018753052, 0.40630432963371277, 0.40454062819480896, 0.4030597507953644, 0.40178635716438293, 0.40064138174057007, 0.3992939889431, 0.39809659123420715, 0.3967132270336151, 0.3962016999721527, 0.3954685628414154, 0.39448660612106323, 0.3935595750808716, 0.39282742142677307, 0.39161983132362366, 0.3909461796283722, 0.38992923498153687, 0.3893953561782837, 0.3887447416782379, 0.3886417746543884, 0.3885132968425751, 0.3881435692310333, 0.3880539536476135, 0.3878744840621948, 0.3877726197242737, 0.3876495659351349, 0.3874644339084625, 0.3871561884880066, 0.38711845874786377, 0.38701578974723816, 0.38724857568740845, 0.3872767686843872, 0.38656359910964966, 0.38652437925338745, 0.38627997040748596, 0.3856770396232605, 0.3852933943271637, 0.3848276734352112, 0.38475126028060913, 0.38461777567863464, 0.3846655488014221, 0.3846093416213989, 0.384924054145813, 0.3849468231201172, 0.3846433460712433, 0.3851482570171356, 0.3851175606250763, 0.38545405864715576, 0.3855822682380676, 0.385833740234375, 0.38589683175086975, 0.38629570603370667]',
 'm:accuracy': '[0.5247933864593506, 0.5, 0.5289255976676941, 0.5619834661483765, 0.5371900796890259, 0.5371900796890259, 0.5950413346290588, 0.6487603187561035, 0.6611570119857788, 0.6322314143180847, 0.6363636255264282, 0.6818181872367859, 0.6033057570457458, 0.7355371713638306, 0.6900826692581177, 0.6900826692581177, 0.6776859760284424, 0.7644628286361694, 0.7438016533851624, 0.702479362487793, 0.7520661354064941, 0.7066115736961365, 0.7768595218658447, 0.7231404781341553, 0.702479362487793, 0.7396694421768188, 0.7851239442825317, 0.7685950398445129, 0.7479338645935059, 0.7314049601554871, 0.7520661354064941, 0.7561983466148376, 0.7727272510528564, 0.7603305578231812, 0.7685950398445129, 0.8057851195335388, 0.7603305578231812, 0.797520637512207, 0.7479338645935059, 0.7768595218658447, 0.8057851195335388, 0.7809917330741882, 0.7768595218658447, 0.7603305578231812, 0.8181818127632141, 0.7603305578231812, 0.8057851195335388, 0.797520637512207, 0.7851239442825317, 0.7727272510528564, 0.78925621509552, 0.8140496015548706, 0.7520661354064941, 0.78925621509552, 0.797520637512207, 0.8223140239715576, 0.7851239442825317, 0.7685950398445129, 0.8140496015548706, 0.7933884263038635, 0.8264462947845459, 0.78925621509552, 0.8347107172012329, 0.8057851195335388, 0.8181818127632141, 0.8057851195335388, 0.8305785059928894, 0.8595041036605835, 0.8057851195335388, 0.8264462947845459, 0.8016529083251953, 0.8264462947845459, 0.8305785059928894, 0.8099173307418823, 0.8223140239715576, 0.8140496015548706, 0.8140496015548706, 0.8429751992225647, 0.78925621509552, 0.8264462947845459, 0.8429751992225647, 0.8471074104309082, 0.8181818127632141, 0.8471074104309082, 0.8140496015548706, 0.8016529083251953, 0.85537189245224, 0.8471074104309082, 0.8140496015548706, 0.8347107172012329, 0.8264462947845459, 0.8636363744735718, 0.8305785059928894, 0.85537189245224, 0.8223140239715576]',
 'm:val_accuracy': '[0.5409836173057556, 0.5573770403862, 0.5409836173057556, 0.5573770403862, 0.5901639461517334, 0.5901639461517334, 0.5901639461517334, 0.6065573692321777, 0.6721311211585999, 0.688524603843689, 0.7049180269241333, 0.7049180269241333, 0.7049180269241333, 0.7213114500045776, 0.7704917788505554, 0.7540983557701111, 0.7868852615356445, 0.7868852615356445, 0.7868852615356445, 0.7868852615356445, 0.7868852615356445, 0.7868852615356445, 0.7868852615356445, 0.7868852615356445, 0.7868852615356445, 0.7868852615356445, 0.7868852615356445, 0.7868852615356445, 0.7868852615356445, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667]',
 'm:num_parameters': 18,
 'm:num_parameters_train': 1597,
 'm:timestamp_gather': 7.694214105606079}
[53]:
import matplotlib.pyplot as plt

plt.figure()
plt.plot(json.loads(metadata_default["val_accuracy"]), color="skyblue", label="Default (val)")
plt.plot(json.loads(metadata_default["accuracy"]), color="skyblue", linestyle="--", label="Default (train)")
plt.plot(json.loads(best_job["m:val_accuracy"]), color="coral", linewidth=2, label="Best Job(val)")
plt.plot(json.loads(best_job["m:accuracy"]), color="coral", linestyle="--", linewidth=2, label="Best Job (train)")
plt.legend()
plt.ylim(0.5, 0.9)
plt.grid()
plt.show()
../../../../_images/tutorials_tutorials_colab_HPS_basic_classification_with_tabular_data_notebook_43_0.png

2.10. Add conditional hyperparameters#

Now we want to add the possibility of searching for a second fully-connected layer. This only requires two new lines:

if config.get("dense_2", False):
    x = tfk.layers.Dense(config["dense_2:units"], activation=config["dense_2:activation"])(x)
[54]:
def run_with_condition(config: dict):
    tf.autograph.set_verbosity(0)

    train_dataframe, val_dataframe = load_data()

    train_ds = dataframe_to_dataset(train_dataframe)
    val_ds = dataframe_to_dataset(val_dataframe)

    train_ds = train_ds.batch(config["batch_size"])
    val_ds = val_ds.batch(config["batch_size"])

    # Categorical features encoded as integers
    sex = tfk.Input(shape=(1,), name="sex", dtype="int64")
    cp = tfk.Input(shape=(1,), name="cp", dtype="int64")
    fbs = tfk.Input(shape=(1,), name="fbs", dtype="int64")
    restecg = tfk.Input(shape=(1,), name="restecg", dtype="int64")
    exang = tfk.Input(shape=(1,), name="exang", dtype="int64")
    ca = tfk.Input(shape=(1,), name="ca", dtype="int64")

    # Categorical feature encoded as string
    thal = tfk.Input(shape=(1,), name="thal", dtype="string")

    # Numerical features
    age = tfk.Input(shape=(1,), name="age")
    trestbps = tfk.Input(shape=(1,), name="trestbps")
    chol = tfk.Input(shape=(1,), name="chol")
    thalach = tfk.Input(shape=(1,), name="thalach")
    oldpeak = tfk.Input(shape=(1,), name="oldpeak")
    slope = tfk.Input(shape=(1,), name="slope")

    all_inputs = [
        sex,
        cp,
        fbs,
        restecg,
        exang,
        ca,
        thal,
        age,
        trestbps,
        chol,
        thalach,
        oldpeak,
        slope,
    ]

    # Integer categorical features
    sex_encoded = encode_categorical_feature(sex, "sex", train_ds, False)
    cp_encoded = encode_categorical_feature(cp, "cp", train_ds, False)
    fbs_encoded = encode_categorical_feature(fbs, "fbs", train_ds, False)
    restecg_encoded = encode_categorical_feature(restecg, "restecg", train_ds, False)
    exang_encoded = encode_categorical_feature(exang, "exang", train_ds, False)
    ca_encoded = encode_categorical_feature(ca, "ca", train_ds, False)

    # String categorical features
    thal_encoded = encode_categorical_feature(thal, "thal", train_ds, True)

    # Numerical features
    age_encoded = encode_numerical_feature(age, "age", train_ds)
    trestbps_encoded = encode_numerical_feature(trestbps, "trestbps", train_ds)
    chol_encoded = encode_numerical_feature(chol, "chol", train_ds)
    thalach_encoded = encode_numerical_feature(thalach, "thalach", train_ds)
    oldpeak_encoded = encode_numerical_feature(oldpeak, "oldpeak", train_ds)
    slope_encoded = encode_numerical_feature(slope, "slope", train_ds)

    all_features = tfk.layers.concatenate(
        [
            sex_encoded,
            cp_encoded,
            fbs_encoded,
            restecg_encoded,
            exang_encoded,
            slope_encoded,
            ca_encoded,
            thal_encoded,
            age_encoded,
            trestbps_encoded,
            chol_encoded,
            thalach_encoded,
            oldpeak_encoded,
        ]
    )
    x = tfk.layers.Dense(config["units"], activation=config["activation"])(
        all_features
    )

    ### START - NEW LINES
    if config.get("dense_2", False):
        x = tfk.layers.Dense(config["dense_2:units"], activation=config["dense_2:activation"])(x)
    ### END - NEW LINES

    x = tfk.layers.Dropout(config["dropout_rate"])(x)
    output = tfk.layers.Dense(1, activation="sigmoid")(x)
    model = tfk.Model(all_inputs, output)

    optimizer = tfk.optimizers.Adam(learning_rate=config["learning_rate"])
    model.compile(optimizer, "binary_crossentropy", metrics=["accuracy"])

    history = model.fit(
        train_ds, epochs=config["num_epochs"], validation_data=val_ds, verbose=0
    )

    objective = history.history["val_accuracy"][-1]
    metadata = {
        "loss": history.history["loss"],
        "val_loss": history.history["val_loss"],
        "accuracy": history.history["accuracy"],
        "val_accuracy": history.history["val_accuracy"],
    }
    metadata = {k:json.dumps(v) for k,v in metadata.items()}
    metadata.update(count_params(model))

    return {"objective": objective, "metadata": metadata}

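As a quick sanity check, the run-function can be called once by hand. When the second layer is disabled, the dense_2:* keys are simply absent from the configuration, which is why the function uses config.get("dense_2", False). A sketch, assuming the data-loading and encoding helpers defined earlier are in scope:

config_example = {
    "units": 32,
    "activation": "relu",
    "dropout_rate": 0.2,
    "num_epochs": 10,
    "batch_size": 32,
    "learning_rate": 1e-3,
    # no "dense_2" keys: the second dense layer is skipped
}
output = run_with_condition(config_example)
print(f"objective: {output['objective']:.3f}")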
To define conditional hyperparameters we use ConfigSpace. We declare dense_2:units and dense_2:activation as hyperparameters that are active only when dense_2 == True; the EqualsCondition class lets us express this. Then we call

problem_with_condition.add_condition(condition)

to register each new condition with the HpProblem.

[56]:
from ConfigSpace import EqualsCondition

# Define the hyperparameter problem
problem_with_condition = HpProblem()


# Define the same hyperparameters as before
problem_with_condition.add_hyperparameter((8, 128), "units")
problem_with_condition.add_hyperparameter(ACTIVATIONS, "activation")
problem_with_condition.add_hyperparameter((0.0, 0.6), "dropout_rate")
problem_with_condition.add_hyperparameter((10, 100), "num_epochs")
problem_with_condition.add_hyperparameter((8, 256, "log-uniform"), "batch_size")
problem_with_condition.add_hyperparameter((1e-5, 1e-2, "log-uniform"), "learning_rate")


# Add a new hyperparameter "dense_2 (bool)" to decide if a second fully-connected layer should be created
hp_dense_2 = problem_with_condition.add_hyperparameter([True, False], "dense_2")
hp_dense_2_units = problem_with_condition.add_hyperparameter((8, 128), "dense_2:units")
hp_dense_2_activation = problem_with_condition.add_hyperparameter(ACTIVATIONS, "dense_2:activation")

problem_with_condition.add_condition(EqualsCondition(hp_dense_2_units, hp_dense_2, True))
problem_with_condition.add_condition(EqualsCondition(hp_dense_2_activation, hp_dense_2, True))


problem_with_condition
[56]:
Configuration space object:
  Hyperparameters:
    activation, Type: Categorical, Choices: {elu, gelu, hard_sigmoid, linear, relu, selu, sigmoid, softplus, softsign, swish, tanh}, Default: elu
    batch_size, Type: UniformInteger, Range: [8, 256], Default: 45, on log-scale
    dense_2, Type: Categorical, Choices: {True, False}, Default: True
    dense_2:activation, Type: Categorical, Choices: {elu, gelu, hard_sigmoid, linear, relu, selu, sigmoid, softplus, softsign, swish, tanh}, Default: elu
    dense_2:units, Type: UniformInteger, Range: [8, 128], Default: 68
    dropout_rate, Type: UniformFloat, Range: [0.0, 0.6], Default: 0.3
    learning_rate, Type: UniformFloat, Range: [1e-05, 0.01], Default: 0.000316227766, on log-scale
    num_epochs, Type: UniformInteger, Range: [10, 100], Default: 55
    units, Type: UniformInteger, Range: [8, 128], Default: 68
  Conditions:
    dense_2:activation | dense_2 == True
    dense_2:units | dense_2 == True

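We can verify the conditions by sampling a few configurations: when dense_2 is False, the two conditional keys should be absent from the sample. A sketch, assuming HpProblem exposes the underlying ConfigSpace object through its space attribute and that sampled configurations behave like mappings:

# Draw a few random configurations from the conditional space (sketch)
for config in problem_with_condition.space.sample_configuration(5):
    print(dict(config))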
We create a new evaluator evaluator_3 and start a fresh CBO search with this new problem problem_with_condition.

[57]:
evaluator_3 = get_evaluator(run_with_condition)

search_with_condition = CBO(problem_with_condition, evaluator_3)
WARNING:root:Results file already exists, it will be renamed to /Users/romainegele/Documents/Argonne/deephyper-tutorials/tutorials/colab/HPS_basic_classification_with_tabular_data/results_20240805-190449.csv
Created new evaluator with 1 worker and config: {'num_cpus': 1, 'num_cpus_per_task': 1, 'callbacks': [<deephyper.evaluator.callback.TqdmCallback object at 0x3e628d1c0>]}
[58]:
results_with_condition = search_with_condition.search(max_evals=25)
[59]:
results_with_condition
[59]:
p:activation p:batch_size p:dense_2 p:dropout_rate p:learning_rate p:num_epochs p:units p:dense_2:activation p:dense_2:units objective job_id m:timestamp_submit m:loss m:val_loss m:accuracy m:val_accuracy m:num_parameters m:num_parameters_train m:timestamp_gather
0 tanh 15 True 0.366818 0.005029 26 60 softplus 105 0.786885 0 5.803378 [0.4718848168849945, 0.3748987913131714, 0.300... [0.483046293258667, 0.4596226215362549, 0.4899... [0.7438016533851624, 0.8305785059928894, 0.842... [0.8524590134620667, 0.7868852615356445, 0.836... 18 8731 7.808099
1 tanh 8 True 0.390801 0.000175 81 30 hard_sigmoid 68 0.836066 1 8.464106 [0.7436654567718506, 0.6908826231956482, 0.664... [0.6634212136268616, 0.6115420460700989, 0.579... [0.4710743725299835, 0.5247933864593506, 0.586... [0.688524603843689, 0.8196721076965332, 0.7704... 18 3287 11.945705
2 tanh 114 False 0.196952 0.000011 17 100 elu 8 0.491803 2 12.560954 [0.7704780101776123, 0.7623841762542725, 0.757... [0.6971723437309265, 0.6961050033569336, 0.695... [0.44628098607063293, 0.43388429284095764, 0.4... [0.4754098355770111, 0.4754098355770111, 0.475... 18 3801 14.004241
3 selu 115 False 0.077023 0.000019 25 50 elu 8 0.721311 3 14.622597 [0.6315001845359802, 0.630281925201416, 0.6373... [0.6136785745620728, 0.612109899520874, 0.6105... [0.6363636255264282, 0.6280992031097412, 0.603... [0.6721311211585999, 0.6721311211585999, 0.688... 18 1901 16.010743
4 relu 73 False 0.476480 0.003393 19 34 elu 8 0.803279 4 16.631053 [0.5818436741828918, 0.5115591287612915, 0.514... [0.48164546489715576, 0.4386799931526184, 0.41... [0.702479362487793, 0.7438016533851624, 0.7438... [0.8196721076965332, 0.7704917788505554, 0.770... 18 1293 18.078223
5 selu 125 False 0.534733 0.000035 84 74 elu 8 0.245902 5 18.695810 [1.2875844240188599, 1.262080430984497, 1.3596... [1.3276287317276, 1.3229693174362183, 1.318302... [0.24793387949466705, 0.2975206673145294, 0.25... [0.2295081913471222, 0.2295081913471222, 0.245... 18 2813 20.870275
6 softplus 65 False 0.568267 0.004382 34 30 elu 8 0.836066 6 21.482700 [1.0909706354141235, 0.7855749130249023, 0.700... [0.7736652493476868, 0.5466883778572083, 0.439... [0.42561984062194824, 0.5826446413993835, 0.62... [0.44262295961380005, 0.8032786846160889, 0.80... 18 1141 23.006530
7 linear 55 False 0.162118 0.000082 50 44 elu 8 0.819672 7 23.618691 [0.8554414510726929, 0.8225628137588501, 0.824... [0.8711621761322021, 0.8583137392997742, 0.845... [0.3677685856819153, 0.3719008266925812, 0.413... [0.3442623019218445, 0.3442623019218445, 0.360... 18 1673 25.332632
8 hard_sigmoid 11 True 0.352241 0.000210 99 61 elu 125 0.803279 8 25.941547 [0.7294607162475586, 0.6213605999946594, 0.591... [0.5591641664505005, 0.5274262428283691, 0.524... [0.5, 0.71074378490448, 0.7148760557174683, 0.... [0.7704917788505554, 0.7704917788505554, 0.770... 18 10133 29.293199
9 gelu 185 False 0.341791 0.003917 65 17 elu 8 0.786885 9 29.903265 [0.9228700399398804, 0.8316530585289001, 0.809... [0.7922216057777405, 0.724449634552002, 0.6675... [0.28925618529319763, 0.42561984062194824, 0.4... [0.4262295067310333, 0.49180328845977783, 0.63... 18 647 31.662543
10 swish 28 False 0.561951 0.000036 34 30 elu 8 0.442623 10 32.490997 [0.8498828411102295, 0.8368768692016602, 0.829... [0.8152909874916077, 0.8121859431266785, 0.809... [0.40495866537094116, 0.40909090638160706, 0.4... [0.3442623019218445, 0.3442623019218445, 0.344... 18 1141 34.107248
11 swish 24 False 0.577138 0.005004 56 18 elu 8 0.770492 11 34.816861 [0.6796150803565979, 0.5124174356460571, 0.482... [0.46827590465545654, 0.4033910036087036, 0.37... [0.5950413346290588, 0.7479338645935059, 0.776... [0.7868852615356445, 0.8032786846160889, 0.819... 18 685 37.176652
12 relu 9 True 0.523482 0.000046 39 15 linear 62 0.770492 12 37.912493 [0.7733532190322876, 0.7673518657684326, 0.725... [0.7620072364807129, 0.7392509579658508, 0.716... [0.5041322112083435, 0.5330578684806824, 0.570... [0.44262295961380005, 0.49180328845977783, 0.5... 18 1610 40.207185
13 sigmoid 60 False 0.576131 0.004627 63 119 elu 8 0.819672 13 40.919064 [0.6822071075439453, 0.5901292562484741, 0.567... [0.4921690821647644, 0.4631254971027374, 0.408... [0.5743801593780518, 0.7148760557174683, 0.731... [0.7704917788505554, 0.7704917788505554, 0.786... 18 4523 42.717181
14 swish 195 True 0.550026 0.000172 56 30 softplus 72 0.754098 14 43.429739 [0.7369769215583801, 0.6703598499298096, 0.709... [0.5385007858276367, 0.5357068181037903, 0.533... [0.5950413346290588, 0.64462810754776, 0.60330... [0.7704917788505554, 0.7704917788505554, 0.770... 18 3415 45.337643
15 softplus 65 False 0.529968 0.001002 31 36 elu 8 0.819672 15 46.051016 [0.8508321046829224, 0.8252332210540771, 0.771... [0.7166399359703064, 0.6583455801010132, 0.608... [0.43388429284095764, 0.5206611752510071, 0.54... [0.44262295961380005, 0.6229507923126221, 0.72... 18 1369 47.613793
16 softsign 34 True 0.557522 0.000161 78 109 hard_sigmoid 47 0.852459 16 48.333625 [0.8418372273445129, 0.8010062575340271, 0.781... [0.768680214881897, 0.746289074420929, 0.72464... [0.45041322708129883, 0.5082644820213318, 0.46... [0.2295081913471222, 0.2295081913471222, 0.229... 18 9251 51.402913
17 softplus 85 True 0.560434 0.000156 78 118 linear 26 0.852459 17 52.115983 [0.7104208469390869, 0.7846921682357788, 0.714... [0.5288053154945374, 0.5168441534042358, 0.508... [0.6074380278587341, 0.5702479481697083, 0.628... [0.7704917788505554, 0.7704917788505554, 0.770... 18 7487 54.040631
18 softsign 82 True 0.023935 0.000069 78 117 hard_sigmoid 22 0.704918 18 54.757471 [0.8375796675682068, 0.8401011824607849, 0.837... [0.8699012398719788, 0.866222620010376, 0.8625... [0.2851239740848541, 0.2851239740848541, 0.289... [0.2295081913471222, 0.2295081913471222, 0.229... 18 6948 57.191873
19 selu 126 True 0.498169 0.000588 80 112 relu 114 0.803279 19 58.051332 [0.7589589357376099, 0.6722019910812378, 0.575... [0.6904788613319397, 0.6121534109115601, 0.551... [0.4586776793003082, 0.5702479481697083, 0.731... [0.5901639461517334, 0.7540983557701111, 0.770... 18 17141 60.207874
20 softsign 125 True 0.520356 0.000171 99 94 linear 54 0.819672 20 60.940526 [0.8036916851997375, 0.7935233116149902, 0.770... [0.7480814456939697, 0.7282476425170898, 0.709... [0.42561984062194824, 0.40495866537094116, 0.4... [0.4098360538482666, 0.44262295961380005, 0.50... 18 8663 63.042618
21 softplus 30 True 0.517795 0.000126 20 119 hard_sigmoid 44 0.770492 21 63.761625 [0.7613235712051392, 0.7507070899009705, 0.705... [0.649841845035553, 0.6311606764793396, 0.6109... [0.5206611752510071, 0.5247933864593506, 0.590... [0.7704917788505554, 0.7704917788505554, 0.770... 18 9728 65.623381
22 softplus 177 True 0.596195 0.000163 66 119 gelu 37 0.770492 22 66.337983 [0.6425224542617798, 0.6305481195449829, 0.657... [0.5786556005477905, 0.5653592944145203, 0.554... [0.6239669322967529, 0.6611570119857788, 0.628... [0.7704917788505554, 0.7704917788505554, 0.770... 18 8881 68.226429
23 softsign 8 True 0.578768 0.000083 78 105 relu 11 0.819672 23 68.941229 [0.7219522595405579, 0.7058809995651245, 0.676... [0.6714370250701904, 0.6515097618103027, 0.632... [0.5454545617103577, 0.5991735458374023, 0.644... [0.6557376980781555, 0.6557376980781555, 0.737... 18 5063 72.357375
24 tanh 33 False 0.572871 0.000163 86 49 elu 8 0.819672 24 73.081465 [0.7666149139404297, 0.7785250544548035, 0.725... [0.7132001519203186, 0.6926335692405701, 0.673... [0.4876033067703247, 0.4917355477809906, 0.508... [0.5081967115402222, 0.5409836173057556, 0.557... 18 1863 75.528965
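Before looking at the best configuration, we can check how often the search activated the second layer; a one-line pandas sketch:

print(results_with_condition["p:dense_2"].value_counts())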

Finally, let us print out the best configuration found in this conditional search space.

[60]:
i_max = results_with_condition.objective.argmax()
best_job = results_with_condition.iloc[i_max].to_dict()


print(f"The default configuration has an accuracy of {objective_default:.3f}. \n"
      f"The best configuration found by DeepHyper has an accuracy {results_with_condition['objective'].iloc[i_max]:.3f}, \n"
      f"discovered after {results_with_condition['m:timestamp_gather'].iloc[i_max]:.2f} secondes of search.\n")

best_job
The default configuration has an accuracy of 0.803.
The best configuration found by DeepHyper has an accuracy of 0.852,
discovered after 51.40 seconds of search.

[60]:
{'p:activation': 'softsign',
 'p:batch_size': 34,
 'p:dense_2': True,
 'p:dropout_rate': 0.5575222563657,
 'p:learning_rate': 0.0001609081602,
 'p:num_epochs': 78,
 'p:units': 109,
 'p:dense_2:activation': 'hard_sigmoid',
 'p:dense_2:units': 47,
 'objective': 0.8524590134620667,
 'job_id': 16,
 'm:timestamp_submit': 48.333625078201294,
 'm:loss': '[0.8418372273445129, 0.8010062575340271, 0.7812130451202393, 0.7930625677108765, 0.7684123516082764, 0.7609521746635437, 0.7478557229042053, 0.7406894564628601, 0.6992117762565613, 0.6690812706947327, 0.6832402944564819, 0.6622161269187927, 0.7154543399810791, 0.6304647922515869, 0.632775604724884, 0.635278046131134, 0.6395263671875, 0.601736843585968, 0.5866347551345825, 0.5930865406990051, 0.6114271879196167, 0.6148985028266907, 0.5917722582817078, 0.5602957010269165, 0.6068222522735596, 0.5561351776123047, 0.52374666929245, 0.5643391609191895, 0.5197662711143494, 0.5247690081596375, 0.5277420282363892, 0.5073702335357666, 0.46645841002464294, 0.5366184711456299, 0.5077937841415405, 0.5054715871810913, 0.5132231712341309, 0.5002268552780151, 0.4574722945690155, 0.4624318480491638, 0.45751628279685974, 0.45831236243247986, 0.46884045004844666, 0.46934202313423157, 0.43772009015083313, 0.45625945925712585, 0.42081278562545776, 0.4312887191772461, 0.4441389739513397, 0.4487796425819397, 0.44675105810165405, 0.42890456318855286, 0.42480218410491943, 0.4400651156902313, 0.4144305884838104, 0.39338719844818115, 0.359445720911026, 0.4274265468120575, 0.4111899435520172, 0.38475120067596436, 0.3514232337474823, 0.3587934970855713, 0.39239099621772766, 0.37596988677978516, 0.38595589995384216, 0.38566166162490845, 0.3809291422367096, 0.3818380534648895, 0.42398494482040405, 0.35981088876724243, 0.3791501820087433, 0.41108107566833496, 0.36640042066574097, 0.3670456111431122, 0.36325547099113464, 0.3857632577419281, 0.3685024678707123, 0.39088982343673706]',
 'm:val_loss': '[0.768680214881897, 0.746289074420929, 0.7246480584144592, 0.7041943669319153, 0.6852578520774841, 0.6664589643478394, 0.6496288776397705, 0.6337476372718811, 0.6189693808555603, 0.6056522130966187, 0.593632698059082, 0.5819397568702698, 0.569520890712738, 0.5585426688194275, 0.5482839941978455, 0.5388845205307007, 0.5297328233718872, 0.5214434862136841, 0.5139332413673401, 0.5062064528465271, 0.49924206733703613, 0.4916667640209198, 0.48417896032333374, 0.4775199890136719, 0.4708450138568878, 0.46558263897895813, 0.46020370721817017, 0.45523616671562195, 0.4506804049015045, 0.4466724097728729, 0.4432690143585205, 0.43901562690734863, 0.43512803316116333, 0.43151235580444336, 0.42741674184799194, 0.4234548509120941, 0.41981783509254456, 0.41622480750083923, 0.4126134216785431, 0.4092336595058441, 0.4061407446861267, 0.40323206782341003, 0.40096515417099, 0.3985515236854553, 0.3963643014431, 0.3940359652042389, 0.3918095529079437, 0.38982391357421875, 0.387992799282074, 0.3862203359603882, 0.384972482919693, 0.38385680317878723, 0.38305702805519104, 0.38194236159324646, 0.38052645325660706, 0.3791072964668274, 0.377858966588974, 0.3774709105491638, 0.3772493600845337, 0.37700068950653076, 0.3765186667442322, 0.3758242130279541, 0.37525153160095215, 0.37498900294303894, 0.3749580979347229, 0.3747103810310364, 0.374695748090744, 0.37496882677078247, 0.3751734793186188, 0.375230610370636, 0.37512537837028503, 0.3750172555446625, 0.3752257525920868, 0.37517407536506653, 0.37543660402297974, 0.37605175375938416, 0.3765294849872589, 0.37751638889312744]',
 'm:accuracy': '[0.45041322708129883, 0.5082644820213318, 0.4628099203109741, 0.46694216132164, 0.4876033067703247, 0.5082644820213318, 0.5165289044380188, 0.5247933864593506, 0.5371900796890259, 0.5950413346290588, 0.5826446413993835, 0.6115702390670776, 0.586776852607727, 0.5991735458374023, 0.6115702390670776, 0.6115702390670776, 0.6322314143180847, 0.6735537052154541, 0.6694214940071106, 0.6694214940071106, 0.6528925895690918, 0.64462810754776, 0.6611570119857788, 0.6818181872367859, 0.6776859760284424, 0.7190082669258118, 0.7438016533851624, 0.7066115736961365, 0.7644628286361694, 0.7272727489471436, 0.7520661354064941, 0.7479338645935059, 0.7685950398445129, 0.7396694421768188, 0.7479338645935059, 0.7603305578231812, 0.7314049601554871, 0.7272727489471436, 0.78925621509552, 0.7851239442825317, 0.7768595218658447, 0.7644628286361694, 0.7851239442825317, 0.7727272510528564, 0.7561983466148376, 0.7603305578231812, 0.8140496015548706, 0.8140496015548706, 0.7685950398445129, 0.7809917330741882, 0.7933884263038635, 0.8099173307418823, 0.8264462947845459, 0.7851239442825317, 0.7809917330741882, 0.8057851195335388, 0.8388429880142212, 0.8223140239715576, 0.8016529083251953, 0.8140496015548706, 0.8223140239715576, 0.8264462947845459, 0.8223140239715576, 0.8719007968902588, 0.8223140239715576, 0.8264462947845459, 0.8223140239715576, 0.8347107172012329, 0.78925621509552, 0.85537189245224, 0.8099173307418823, 0.8099173307418823, 0.8305785059928894, 0.8429751992225647, 0.8388429880142212, 0.8140496015548706, 0.8471074104309082, 0.8347107172012329]',
 'm:val_accuracy': '[0.2295081913471222, 0.2295081913471222, 0.2295081913471222, 0.31147539615631104, 0.6393442749977112, 0.7540983557701111, 0.7868852615356445, 0.7868852615356445, 0.7868852615356445, 0.8032786846160889, 0.7868852615356445, 0.8032786846160889, 0.7868852615356445, 0.7868852615356445, 0.7868852615356445, 0.7868852615356445, 0.7868852615356445, 0.7868852615356445, 0.7868852615356445, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8196721076965332, 0.7868852615356445, 0.8196721076965332, 0.8032786846160889, 0.7868852615356445, 0.7868852615356445, 0.7868852615356445, 0.7704917788505554, 0.7704917788505554, 0.7704917788505554, 0.7704917788505554, 0.7704917788505554, 0.7704917788505554, 0.7704917788505554, 0.7704917788505554, 0.7704917788505554, 0.7704917788505554, 0.7704917788505554, 0.7704917788505554, 0.7704917788505554, 0.7704917788505554, 0.7704917788505554, 0.7704917788505554, 0.7868852615356445, 0.7868852615356445, 0.7868852615356445, 0.7868852615356445, 0.8032786846160889, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.868852436542511, 0.8524590134620667, 0.8524590134620667, 0.868852436542511, 0.868852436542511, 0.868852436542511, 0.868852436542511, 0.868852436542511, 0.8524590134620667, 0.8524590134620667]',
 'm:num_parameters': 18,
 'm:num_parameters_train': 9251,
 'm:timestamp_gather': 51.402913093566895}
[61]:
import matplotlib.pyplot as plt

plt.figure()
plt.plot(json.loads(metadata_default["val_accuracy"]), color="skyblue", label="Default (val)")
plt.plot(json.loads(metadata_default["accuracy"]), color="skyblue", linestyle="--", label="Default (train)")
plt.plot(json.loads(best_job["m:val_accuracy"]), color="coral", linewidth=2, label="Best Job(val)")
plt.plot(json.loads(best_job["m:accuracy"]), color="coral", linestyle="--", linewidth=2, label="Best Job (train)")
plt.legend()
plt.ylim(0.5, 0.9)
plt.grid()
plt.show()
../../../../_images/tutorials_tutorials_colab_HPS_basic_classification_with_tabular_data_notebook_54_0.png