2. Hyperparameter Optimization for Classification with Tabular Data (TensorFlow/Keras2)#


In this tutorial we show how to apply hyperparameter optimization to an example from the Keras documentation.

Reference: This tutorial is based on materials from the Keras Documentation: Structured data classification from scratch

Let us start by installing DeepHyper!

Warning

Since the release of Keras 3.0, this tutorial should be run with tf-keras (available on PyPI).

[1]:
try:
    import deephyper
    from deephyper.evaluator import RayEvaluator
    print(deephyper.__version__)
except (ImportError, ModuleNotFoundError):
    !pip install "deephyper[tf-keras2,ray]"
    import deephyper
    from deephyper.evaluator import RayEvaluator
    print(deephyper.__version__)

0.9.0

Note

The following environment variables can be used to silence some TensorFlow DEBUG, INFO, and WARNING messages.

[2]:
import os


os.environ["TF_CPP_MIN_LOG_LEVEL"] = str(3)
os.environ["AUTOGRAPH_VERBOSITY"] = str(0)

2.1. Imports#

The import strategy from the original Keras tutorial (shown below):

from tensorflow import keras
from tensorflow.keras import layers
...
from tensorflow.keras.layers import IntegerLookup
from tensorflow.keras.layers import Normalization
from tensorflow.keras.layers import StringLookup

resulted in non-serializable objects, preventing the search from executing in parallel. Therefore, we changed it to take advantage of TensorFlow's lazy loading of subpackages.

[3]:
import json
import warnings

import numpy as np
import pandas as pd
import tensorflow as tf
import tf_keras as tfk
import tf_keras.backend as K

Then we detect whether GPU devices are available on the current host, so that the notebook can automatically adapt its parallel execution to the locally available resources. Note that this simple check does not detect resources across multiple nodes.

[4]:
from tensorflow.python.client import device_lib


def get_available_gpus():
    local_device_protos = device_lib.list_local_devices()
    return [x.name for x in local_device_protos if x.device_type == "GPU"]


n_gpus = len(get_available_gpus())
if n_gpus > 1:
    n_gpus -= 1  # when several GPUs are detected, use one fewer than available

is_gpu_available = n_gpus > 0

if is_gpu_available:
    print(f"{n_gpus} GPU{'s are' if n_gpus > 1 else ' is'} available.")
else:
    print("No GPU available")
No GPU available

2.2. The dataset (from Keras.io)#

The dataset is provided by the Cleveland Clinic Foundation for Heart Disease. It’s a CSV file with 303 rows. Each row contains information about a patient (a sample), and each column describes an attribute of the patient (a feature). We use the features to predict whether a patient has a heart disease (binary classification).

Here’s the description of each feature:

Column   | Description                                            | Feature Type
-------- | ------------------------------------------------------ | ----------------------------
Age      | Age in years                                           | Numerical
Sex      | (1 = male; 0 = female)                                 | Categorical
CP       | Chest pain type (0, 1, 2, 3, 4)                        | Categorical
Trestbpd | Resting blood pressure (in mm Hg on admission)         | Numerical
Chol     | Serum cholesterol in mg/dl                             | Numerical
FBS      | Fasting blood sugar > 120 mg/dl (1 = true; 0 = false)  | Categorical
RestECG  | Resting electrocardiogram results (0, 1, 2)            | Categorical
Thalach  | Maximum heart rate achieved                            | Numerical
Exang    | Exercise induced angina (1 = yes; 0 = no)              | Categorical
Oldpeak  | ST depression induced by exercise relative to rest     | Numerical
Slope    | Slope of the peak exercise ST segment                  | Numerical
CA       | Number of major vessels (0-3) colored by fluoroscopy   | Both numerical & categorical
Thal     | 3 = normal; 6 = fixed defect; 7 = reversible defect    | Categorical
Target   | Diagnosis of heart disease (1 = true; 0 = false)       | Target

[5]:
def load_data():
    file_url = "http://storage.googleapis.com/download.tensorflow.org/data/heart.csv"
    dataframe = pd.read_csv(file_url)

    val_dataframe = dataframe.sample(frac=0.2, random_state=1337)
    train_dataframe = dataframe.drop(val_dataframe.index)

    return train_dataframe, val_dataframe


def dataframe_to_dataset(dataframe):
    dataframe = dataframe.copy()
    labels = dataframe.pop("target")
    ds = tf.data.Dataset.from_tensor_slices((dict(dataframe), labels))
    ds = ds.shuffle(buffer_size=len(dataframe))
    return ds
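
As a quick sanity check, one can inspect a single sample from the resulting tf.data.Dataset (a small illustrative sketch, not part of the original notebook):

train_dataframe, _ = load_data()
train_ds = dataframe_to_dataset(train_dataframe)

# Each element is a (features, label) pair where features is a dict of scalar tensors.
for features, label in train_ds.take(1):
    print({name: value.numpy() for name, value in features.items()})
    print("target:", label.numpy())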

2.3. Preprocessing & encoding of features#

The next cells use tfk.layers.Normalization() to apply standard scaling on the features.

Then, the tfk.layers.StringLookup and tfk.layers.IntegerLookup are used to encode categorical variables.

[6]:
def encode_numerical_feature(feature, name, dataset):
    # Create a Normalization layer for our feature
    normalizer = tfk.layers.Normalization()

    # Prepare a Dataset that only yields our feature
    feature_ds = dataset.map(lambda x, y: x[name])
    feature_ds = feature_ds.map(lambda x: tf.expand_dims(x, -1))

    # Learn the statistics of the data
    normalizer.adapt(feature_ds)

    # Normalize the input feature
    encoded_feature = normalizer(feature)
    return encoded_feature


def encode_categorical_feature(feature, name, dataset, is_string):
    lookup_class = (
        tfk.layers.StringLookup if is_string else tfk.layers.IntegerLookup
    )
    # Create a lookup layer which will turn strings into integer indices
    lookup = lookup_class(output_mode="binary")

    # Prepare a Dataset that only yields our feature
    feature_ds = dataset.map(lambda x, y: x[name])
    feature_ds = feature_ds.map(lambda x: tf.expand_dims(x, -1))

    # Learn the set of possible string values and assign them a fixed integer index
    lookup.adapt(feature_ds)

    # Turn the string input into integer indices
    encoded_feature = lookup(feature)
    return encoded_feature

2.4. Define the run-function#

The run-function defines how the objective that we want to maximize is computed. It takes a config dictionary as input and often returns a scalar value that we want to maximize. The config contains a sample value of hyperparameters that we want to tune. In this example we will search for:

  • units (default value: 32)

  • activation (default value: "relu")

  • dropout_rate (default value: 0.5)

  • num_epochs (default value: 50)

  • batch_size (default value: 32)

  • learning_rate (default value: 1e-3)

A hyperparameter value can be accessed easily in the dictionary through the corresponding key, for example config["units"].
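
For instance, with the default values listed above, the dictionary received by the run-function would look like the following (illustrative):

config = {
    "units": 32,
    "activation": "relu",
    "dropout_rate": 0.5,
    "num_epochs": 50,
    "batch_size": 32,
    "learning_rate": 1e-3,
}
config["units"]  # -> 32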

[7]:
def count_params(model: tfk.Model) -> dict:
    """Evaluate the number of parameters of a Keras model.

    Args:
        model (tfk.Model): a Keras model.

    Returns:
        dict: a dictionary with the number of trainable ``"num_parameters_train"`` and
        non-trainable parameters ``"num_parameters"``.
    """

    def count_or_null(p):
        try:
            return K.count_params(p)
        except Exception:
            # If the parameter count cannot be computed, treat it as 0.
            return 0

    num_parameters_train = int(
        np.sum([count_or_null(p) for p in model.trainable_weights])
    )
    num_parameters = int(
        np.sum([count_or_null(p) for p in model.non_trainable_weights])
    )
    return {
        "num_parameters": num_parameters,
        "num_parameters_train": num_parameters_train,
    }
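
As a quick sanity check (a small sketch, not part of the original notebook), count_params can be applied to a toy model:

# A Dense(4) layer on 3 inputs has 3*4 weights + 4 biases = 16 trainable parameters.
toy_model = tfk.Sequential([tfk.layers.Dense(4, input_shape=(3,))])
count_params(toy_model)  # -> {'num_parameters': 0, 'num_parameters_train': 16}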
[8]:
def run(config: dict):
    tf.autograph.set_verbosity(0)
    # Load data and split into validation set
    train_dataframe, val_dataframe = load_data()
    train_ds = dataframe_to_dataset(train_dataframe)
    val_ds = dataframe_to_dataset(val_dataframe)
    train_ds = train_ds.batch(config["batch_size"])
    val_ds = val_ds.batch(config["batch_size"])

    # Categorical features encoded as integers
    sex = tfk.Input(shape=(1,), name="sex", dtype="int64")
    cp = tfk.Input(shape=(1,), name="cp", dtype="int64")
    fbs = tfk.Input(shape=(1,), name="fbs", dtype="int64")
    restecg = tfk.Input(shape=(1,), name="restecg", dtype="int64")
    exang = tfk.Input(shape=(1,), name="exang", dtype="int64")
    ca = tfk.Input(shape=(1,), name="ca", dtype="int64")

    # Categorical feature encoded as string
    thal = tfk.Input(shape=(1,), name="thal", dtype="string")

    # Numerical features
    age = tfk.Input(shape=(1,), name="age")
    trestbps = tfk.Input(shape=(1,), name="trestbps")
    chol = tfk.Input(shape=(1,), name="chol")
    thalach = tfk.Input(shape=(1,), name="thalach")
    oldpeak = tfk.Input(shape=(1,), name="oldpeak")
    slope = tfk.Input(shape=(1,), name="slope")

    all_inputs = [
        sex,
        cp,
        fbs,
        restecg,
        exang,
        ca,
        thal,
        age,
        trestbps,
        chol,
        thalach,
        oldpeak,
        slope,
    ]

    # Integer categorical features
    sex_encoded = encode_categorical_feature(sex, "sex", train_ds, False)
    cp_encoded = encode_categorical_feature(cp, "cp", train_ds, False)
    fbs_encoded = encode_categorical_feature(fbs, "fbs", train_ds, False)
    restecg_encoded = encode_categorical_feature(restecg, "restecg", train_ds, False)
    exang_encoded = encode_categorical_feature(exang, "exang", train_ds, False)
    ca_encoded = encode_categorical_feature(ca, "ca", train_ds, False)

    # String categorical features
    thal_encoded = encode_categorical_feature(thal, "thal", train_ds, True)

    # Numerical features
    age_encoded = encode_numerical_feature(age, "age", train_ds)
    trestbps_encoded = encode_numerical_feature(trestbps, "trestbps", train_ds)
    chol_encoded = encode_numerical_feature(chol, "chol", train_ds)
    thalach_encoded = encode_numerical_feature(thalach, "thalach", train_ds)
    oldpeak_encoded = encode_numerical_feature(oldpeak, "oldpeak", train_ds)
    slope_encoded = encode_numerical_feature(slope, "slope", train_ds)

    all_features = tfk.layers.concatenate(
        [
            sex_encoded,
            cp_encoded,
            fbs_encoded,
            restecg_encoded,
            exang_encoded,
            slope_encoded,
            ca_encoded,
            thal_encoded,
            age_encoded,
            trestbps_encoded,
            chol_encoded,
            thalach_encoded,
            oldpeak_encoded,
        ]
    )
    x = tfk.layers.Dense(config["units"], activation=config["activation"])(
        all_features
    )
    x = tfk.layers.Dropout(config["dropout_rate"])(x)
    output = tfk.layers.Dense(1, activation="sigmoid")(x)
    model = tfk.Model(all_inputs, output)

    optimizer = tfk.optimizers.Adam(learning_rate=config["learning_rate"])
    model.compile(optimizer, "binary_crossentropy", metrics=["accuracy"])

    try:
        history = model.fit(
            train_ds, epochs=config["num_epochs"], validation_data=val_ds, verbose=0
        )
    except Exception:
        # If training fails, return "F_fit" as the objective: DeepHyper
        # interprets a string objective prefixed with "F" as a failed evaluation.
        class History:
            history = {
                "accuracy": None,
                "val_accuracy": ["F_fit"],
                "loss": None,
                "val_loss": None,
            }

        history = History()


    objective = history.history["val_accuracy"][-1]
    metadata = {
        "loss": history.history["loss"],
        "val_loss": history.history["val_loss"],
        "accuracy": history.history["accuracy"],
        "val_accuracy": history.history["val_accuracy"],
    }
    metadata = {k: json.dumps(v) for k, v in metadata.items()}  # serialize lists to strings for the results log
    metadata.update(count_params(model))

    return {"objective": objective, "metadata": metadata}
Note

The objective maximised by DeepHyper is the scalar value returned by the run-function under the "objective" key.

In this tutorial it corresponds to the validation accuracy of the last epoch of training which we retrieve in the History object returned by the model.fit(...) call.

...
history = model.fit(
    train_ds, epochs=config["num_epochs"], validation_data=val_ds, verbose=0
)
return history.history["val_accuracy"][-1]
...

Using an objective like max(history.history['val_accuracy']) can have undesired side effects.

For example, the maximum may correspond to a transient peak in the training curve: the search would then favor configurations that briefly overshoot a local maximum but lack the capacity to flexibly adapt to new data in the future.

2.5. Define the Hyperparameter optimization problem#

Hyperparameter ranges are defined using the following syntax:

  • Discrete integer ranges are generated from a tuple (lower: int, upper: int)

  • Continuous parameters are generated from a tuple (lower: float, upper: float)

  • Categorical or non-ordinal hyperparameter ranges can be given as a list of possible values [val1, val2, ...]

[9]:
from deephyper.hpo import HpProblem


# Creation of a hyperparameter problem
problem = HpProblem()

# Discrete hyperparameter (sampled with uniform prior)
problem.add_hyperparameter((8, 128), "units", default_value=32)
problem.add_hyperparameter((10, 100), "num_epochs", default_value=50)


# Categorical hyperparameter (sampled with uniform prior)
ACTIVATIONS = [
    "elu", "gelu", "hard_sigmoid", "linear", "relu", "selu",
    "sigmoid", "softplus", "softsign", "swish", "tanh",
]
problem.add_hyperparameter(ACTIVATIONS, "activation", default_value="relu")


# Real hyperparameter (sampled with uniform prior)
problem.add_hyperparameter((0.0, 0.6), "dropout_rate", default_value=0.5)


# Discrete and Real hyperparameters (sampled with log-uniform)
problem.add_hyperparameter((8, 256, "log-uniform"), "batch_size", default_value=32)
problem.add_hyperparameter((1e-5, 1e-2, "log-uniform"), "learning_rate", default_value=1e-3)

problem
[9]:
Configuration space object:
  Hyperparameters:
    activation, Type: Categorical, Choices: {elu, gelu, hard_sigmoid, linear, relu, selu, sigmoid, softplus, softsign, swish, tanh}, Default: relu
    batch_size, Type: UniformInteger, Range: [8, 256], Default: 32, on log-scale
    dropout_rate, Type: UniformFloat, Range: [0.0, 0.6], Default: 0.5
    learning_rate, Type: UniformFloat, Range: [1e-05, 0.01], Default: 0.001, on log-scale
    num_epochs, Type: UniformInteger, Range: [10, 100], Default: 50
    units, Type: UniformInteger, Range: [8, 128], Default: 32

2.6. Evaluate a default configuration#

We evaluate the performance of the default set of hyperparameters provided in the Keras tutorial.

[10]:
import ray


# We launch the Ray runtime depending on the detected local resources
# and execute the `run` function with the default configuration.
# WARNING: in the case of GPUs it is important to follow this scheme
# to avoid multiple processes (Ray workers vs. the current process)
# locking the same GPU.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")

    if is_gpu_available:
        if not(ray.is_initialized()):
            ray.init(num_cpus=n_gpus, num_gpus=n_gpus, log_to_driver=False)

        run_default = ray.remote(num_cpus=1, num_gpus=1)(run)
        out = ray.get(run_default.remote(problem.default_configuration))
    else:
        if not(ray.is_initialized()):
            ray.init(num_cpus=1, log_to_driver=False)
        run_default = run
        out = run_default(problem.default_configuration)

objective_default = out["objective"]
metadata_default = out["metadata"]

print(f"Accuracy Default Configuration:  {objective_default:.3f}")

print("Metadata Default Configuration")
for k,v in out["metadata"].items():
    print(f"\t- {k}: {v}")
2024-12-16 14:04:27,079 INFO worker.py:1810 -- Started a local Ray instance. View the dashboard at http://127.0.0.1:8265 
WARNING:absl:At this time, the v2.11+ optimizer `tf.keras.optimizers.Adam` runs slowly on M1/M2 Macs, please use the legacy TF-Keras optimizer instead, located at `tf.keras.optimizers.legacy.Adam`.
Accuracy Default Configuration:  0.820
Metadata Default Configuration
        - loss: [0.7264149188995361, 0.6627438068389893, 0.5929161310195923, 0.5542343854904175, 0.5475674867630005, 0.5502562522888184, 0.5195257067680359, 0.4931687116622925, 0.46538957953453064, 0.440327912569046, 0.3925045430660248, 0.4195988178253174, 0.42187967896461487, 0.3911048173904419, 0.39405590295791626, 0.4064271152019501, 0.3759456276893616, 0.38313448429107666, 0.36988547444343567, 0.3851970136165619, 0.3579743802547455, 0.32884514331817627, 0.34706202149391174, 0.3246057331562042, 0.32372382283210754, 0.3260881304740906, 0.3325619697570801, 0.31495144963264465, 0.33428072929382324, 0.31294286251068115, 0.31025436520576477, 0.33517444133758545, 0.3024511933326721, 0.3178369700908661, 0.30852562189102173, 0.3208227753639221, 0.2828509509563446, 0.30513060092926025, 0.3040333390235901, 0.30188292264938354, 0.29300230741500854, 0.29074907302856445, 0.2864045202732086, 0.28148123621940613, 0.28043219447135925, 0.2753585875034332, 0.28282228112220764, 0.30097872018814087, 0.28231197595596313, 0.28002142906188965]
        - val_loss: [0.588463306427002, 0.5357356071472168, 0.4963495135307312, 0.4686204791069031, 0.4479129910469055, 0.4316793978214264, 0.41860753297805786, 0.4082474708557129, 0.3998509347438812, 0.3935147225856781, 0.38819533586502075, 0.38361212611198425, 0.3806271255016327, 0.37765082716941833, 0.3746330440044403, 0.37243545055389404, 0.36961108446121216, 0.3686935603618622, 0.3676575720310211, 0.36642464995384216, 0.3659955561161041, 0.36697107553482056, 0.3678717017173767, 0.369035929441452, 0.36980104446411133, 0.37020236253738403, 0.37156641483306885, 0.3713993728160858, 0.3708018958568573, 0.3703863322734833, 0.37158429622650146, 0.3729158341884613, 0.3733491599559784, 0.37256723642349243, 0.3729548752307892, 0.3723505139350891, 0.37281492352485657, 0.3737792372703552, 0.37401700019836426, 0.374211847782135, 0.3739033043384552, 0.3738318383693695, 0.3738160729408264, 0.3752853572368622, 0.37610745429992676, 0.3757418096065521, 0.3749243915081024, 0.37448209524154663, 0.3736680746078491, 0.37360718846321106]
        - accuracy: [0.56611567735672, 0.6280992031097412, 0.6942148804664612, 0.6900826692581177, 0.7066115736961365, 0.7066115736961365, 0.71074378490448, 0.7603305578231812, 0.7561983466148376, 0.7809917330741882, 0.8140496015548706, 0.7809917330741882, 0.7809917330741882, 0.8429751992225647, 0.8099173307418823, 0.8016529083251953, 0.8264462947845459, 0.8140496015548706, 0.8264462947845459, 0.8223140239715576, 0.85537189245224, 0.8512396812438965, 0.8471074104309082, 0.85537189245224, 0.8429751992225647, 0.85537189245224, 0.8471074104309082, 0.8512396812438965, 0.8347107172012329, 0.85537189245224, 0.8512396812438965, 0.8595041036605835, 0.8677685856819153, 0.8677685856819153, 0.8471074104309082, 0.8429751992225647, 0.8925619721412659, 0.8842975497245789, 0.8595041036605835, 0.8636363744735718, 0.8595041036605835, 0.8636363744735718, 0.8636363744735718, 0.8719007968902588, 0.8842975497245789, 0.8760330677032471, 0.8760330677032471, 0.8636363744735718, 0.8842975497245789, 0.8884297609329224]
        - val_accuracy: [0.8196721076965332, 0.7868852615356445, 0.8032786846160889, 0.8032786846160889, 0.8360655903816223, 0.8360655903816223, 0.8524590134620667, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8360655903816223, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332, 0.8196721076965332]
        - num_parameters: 18
        - num_parameters_train: 1217

2.7. Define the evaluator object#

The Evaluator object allows you to change the parallelization backend used by DeepHyper.
It is a standalone object which schedules the execution of remote tasks. All evaluators need a run_function to be instantiated.
Then, a keyword argument method defines the backend (e.g., "ray") and method_kwargs corresponds to the keyword arguments of this chosen method:

evaluator = Evaluator.create(run_function, method, method_kwargs)

Once created the evaluator.num_workers gives access to the number of available parallel workers.

Finally, to submit tasks to the evaluator and collect the results, the following interface is used:

configs = [{"units": 8, ...}, ...]
evaluator.submit(configs)
...
# To collect the first finished task (asynchronous)
tasks_done = evaluator.gather("BATCH", size=1)

# To collect all of the pending tasks (synchronous)
tasks_done = evaluator.gather("ALL")

Warning

Each Evaluator saves its own state, therefore it is crucial to create a new evaluator when launching a fresh search.

[11]:
from deephyper.evaluator import Evaluator
from deephyper.evaluator.callback import TqdmCallback


def get_evaluator(run_function):
    # Default arguments for Ray: 1 worker and 1 worker per evaluation
    method_kwargs = {
        "num_cpus": 1,
        "num_cpus_per_task": 1,
        "callbacks": [TqdmCallback()]
    }

    # If GPU devices are detected then it will create 'n_gpus' workers
    # and use 1 worker for each evaluation
    if is_gpu_available:
        method_kwargs["num_cpus"] = n_gpus
        method_kwargs["num_gpus"] = n_gpus
        method_kwargs["num_cpus_per_task"] = 1
        method_kwargs["num_gpus_per_task"] = 1

    evaluator = Evaluator.create(
        run_function,
        method="ray",
        method_kwargs=method_kwargs
    )
    print(f"Created new evaluator with {evaluator.num_workers} worker{'s' if evaluator.num_workers > 1 else ''} and config: {method_kwargs}", )

    return evaluator

evaluator_1 = get_evaluator(run)
Created new evaluator with 1 worker and config: {'num_cpus': 1, 'num_cpus_per_task': 1, 'callbacks': [<deephyper.evaluator.callback.TqdmCallback object at 0x32b4fa5d0>]}

2.8. Define and run the centralized Bayesian optimization search (CBO)#

A primary pillar of hyperparameter search in DeepHyper is the centralized Bayesian optimization search (henceforth CBO), which can be described by the following algorithm:

[Figure: schematic of the CBO algorithm]


Following the parallelized evaluation of an initial set of configurations, a low-fidelity, high-efficiency model (henceforth "the surrogate") is fit to reproduce the relationship between the inputs of the model (i.e., the choice of hyperparameters) and the outputs (generally a measure of validation accuracy).

After obtaining this surrogate of the validation accuracy, we may utilize ideas from classical methods in the Bayesian optimization literature to adaptively sample the search space of hyperparameters.

First, the surrogate is used to obtain an estimate of the mean value of the validation accuracy at a sampling location \(x\), in addition to an estimated variance. The latter requirement restricts us to high-efficiency data-driven modeling strategies with built-in variance estimates (such as a Gaussian process or a Random Forest regressor).

Regions where the mean is high represent opportunities for exploitation, and regions where the variance is high represent opportunities for exploration. An optimistic acquisition function, called the upper confidence bound (UCB), can be constructed from these two quantities:

\[L_{\text{UCB}}(x) = \mu(x) + \kappa \cdot \sigma(x)\]

The unevaluated hyperparameter configurations that maximize the acquisition function are chosen for the next batch of evaluations.

Note that the variance weighting parameter \(\kappa\) controls the degree of exploration in the hyperparameter search, with zero indicating pure exploitation (unseen configurations where the predicted accuracy is highest will be sampled).
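
To make this concrete, here is a minimal sketch (not DeepHyper's internal code) that estimates \(\mu(x)\) and \(\sigma(x)\) from the per-tree predictions of a scikit-learn RandomForestRegressor and ranks candidate points by UCB:

import numpy as np
from sklearn.ensemble import RandomForestRegressor

rng = np.random.default_rng(42)

# Toy history of evaluated configurations: x in [0, 1], y = noisy objective.
X_hist = rng.uniform(0, 1, size=(20, 1))
y_hist = np.sin(3 * X_hist[:, 0]) + 0.1 * rng.normal(size=20)

surrogate = RandomForestRegressor(n_estimators=100).fit(X_hist, y_hist)

# Candidate configurations to score.
X_cand = np.linspace(0, 1, 200).reshape(-1, 1)

# mu(x): forest mean; sigma(x): spread of the per-tree predictions.
per_tree = np.stack([tree.predict(X_cand) for tree in surrogate.estimators_])
mu, sigma = per_tree.mean(axis=0), per_tree.std(axis=0)

kappa = 1.96  # larger kappa -> more exploration
ucb = mu + kappa * sigma
x_next = X_cand[np.argmax(ucb)]  # next configuration to evaluate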

The top s configurations are selected for the new batch. The following schematic demonstrates this process:

[Figure: schematic of selecting the next batch of configurations by maximizing the acquisition function]

The process of obtaining s configurations relies on the “constant-liar” strategy where a sampled configuration is mapped to a dummy output given by a bulk metric of all the evaluated configurations thus far (such as the maximum, mean or median validation accuracy).

Prior to sampling the next configuration by acquisition function maximization, the surrogate is retrained with the dummy output as a data point. As the true validation accuracy becomes available for one of the sampled configurations, the dummy output is replaced and the surrogate is updated.

This allows for scalable asynchronous (or batch synchronous) sampling of new hyperparameter configurations.
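
The following is a schematic sketch of the constant-liar loop; the surrogate_fit and acq_argmax callables are placeholders standing in for the components described above, not DeepHyper APIs:

import numpy as np

def constant_liar_batch(surrogate_fit, acq_argmax, X, y, s):
    # X, y: hyperparameter configurations evaluated so far and their objectives.
    X, y = list(X), list(y)
    lie = float(np.median(y))  # bulk metric of the evaluated configurations
    batch = []
    for _ in range(s):
        surrogate = surrogate_fit(X, y)  # retrain with the dummy outputs included
        x_new = acq_argmax(surrogate)    # maximize the acquisition function
        batch.append(x_new)
        X.append(x_new)
        y.append(lie)  # dummy output, replaced by the true objective once available
    return batch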

2.8.1. Choice of surrogate model#

Users should note that our choice of surrogate is the Random Forest regressor, due to its ability to handle non-ordinal data (hyperparameter configurations may not be purely continuous or even numerical). Evidence that it can outperform other methods (such as Gaussian processes) is available in [1].


2.8.1.1. Setup CBO#

We create the CBO using the problem and evaluator defined above.

[12]:
from deephyper.hpo import CBO
# Uncomment the following line to show the arguments of CBO.
# help(CBO)
[ ]:
# Instantiate the search with the problem and the evaluator that we created before
search = CBO(
    problem,
    evaluator_1,
    acq_func="UCBd",
    acq_optimizer="mixedga",
    acq_optimizer_freq=1,
    initial_points=[problem.default_configuration],
)

Note

All DeepHyper’s search algorithm have two stopping criteria:

  • max_evals (int): Defines the maximum number of evaluations that we want to perform. Default to -1 for an infinite number.

  • timeout (int): Defines a time budget (in seconds) before stopping the search. Default to None for an infinite time budget.
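
Both criteria can be passed to the same call; for example (illustrative values):

results = search.search(max_evals=50, timeout=600)  # stop after 50 evaluations or 10 minutes, whichever comes first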

[14]:
results = search.search(max_evals=50)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[14], line 1
----> 1 results = search.search(max_evals=50)

File ~/Documents/Argonne/deephyper/src/deephyper/hpo/_search.py:192, in Search.search(self, max_evals, timeout, max_evals_strict)
    190     if np.isscalar(timeout) and timeout > 0:
    191         self._evaluator.timeout = timeout
--> 192     self._search(max_evals, timeout)
    193 except TimeoutReached:
    194     self.stopped = True

File ~/Documents/Argonne/deephyper/src/deephyper/hpo/_cbo.py:567, in CBO._search(self, max_evals, timeout, max_evals_strict)
    564 if self._opt is None:
    565     self._setup_optimizer()
--> 567 super()._search(max_evals, timeout, max_evals_strict)

File ~/Documents/Argonne/deephyper/src/deephyper/hpo/_search.py:305, in Search._search(self, max_evals, timeout, max_evals_strict)
    302 logging.info("Gathering jobs...")
    303 t1 = time.time()
--> 305 new_results = self._evaluator.gather(self.gather_type, self.gather_batch_size)
    307 # Check if results are received from other search instances
    308 # connected to the same storage
    309 if isinstance(new_results, tuple) and len(new_results) == 2:

File ~/Documents/Argonne/deephyper/src/deephyper/evaluator/_evaluator.py:422, in Evaluator.gather(self, type, size)
    419     size = len(self._tasks_running)  # Get all tasks.
    421 if size > 0:
--> 422     self.loop.run_until_complete(self._await_at_least_n_tasks(size))
    424 local_results = self.process_local_tasks_done(self._tasks_done)
    426 # Access storage to return results from other processes

File ~/miniforge3/envs/dh-3.12-240724/lib/python3.12/asyncio/base_events.py:663, in BaseEventLoop.run_until_complete(self, future)
    652 """Run until the Future is done.
    653
    654 If the argument is a coroutine, it is wrapped in a Task.
   (...)
    660 Return the Future's result, or raise its exception.
    661 """
    662 self._check_closed()
--> 663 self._check_running()
    665 new_task = not futures.isfuture(future)
    666 future = tasks.ensure_future(future, loop=self)

File ~/miniforge3/envs/dh-3.12-240724/lib/python3.12/asyncio/base_events.py:622, in BaseEventLoop._check_running(self)
    620 def _check_running(self):
    621     if self.is_running():
--> 622         raise RuntimeError('This event loop is already running')
    623     if events._get_running_loop() is not None:
    624         raise RuntimeError(
    625             'Cannot run the event loop while another loop is running')

RuntimeError: This event loop is already running

Warning

The search call does not output any information about the current status of the search. However, a results.csv file is created in the local directory and can be inspected to see the finished tasks.

The returned results object is a Pandas DataFrame: columns starting with "p:" are hyperparameters, columns starting with "m:" are additional metadata (from the user or from the Evaluator), and the remaining columns include the objective value and the job_id:

  • job_id is a unique identifier corresponding to the order of creation of tasks.

  • objective is the value returned by the run-function.

  • m:timestamp_submit is the time (in seconds), measured from the creation of the evaluator, at which the task was submitted.

  • m:timestamp_gather is the time (in seconds), measured from the creation of the evaluator, at which the finished task was collected.
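
For example, the hyperparameter columns can be selected and the results ranked by objective (a small usage sketch):

hp_columns = [c for c in results.columns if c.startswith("p:")]
results.sort_values("objective", ascending=False)[hp_columns + ["objective"]].head()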

[15]:
results
[15]:
p:activation p:batch_size p:dropout_rate p:learning_rate p:num_epochs p:units objective job_id job_status m:timestamp_submit m:loss m:val_loss m:accuracy m:val_accuracy m:num_parameters m:num_parameters_train m:timestamp_gather
0 relu 32 0.500000 0.001000 50 32 0.803279 0 DONE 19.298753 [0.7522220015525818, 0.6794403195381165, 0.633... [0.6496000289916992, 0.5947670340538025, 0.549... [0.5247933864593506, 0.6033057570457458, 0.640... [0.6721311211585999, 0.7704917788505554, 0.770... 18 1217 24.308906
1 swish 24 0.354249 0.000129 61 38 0.786885 1 DONE 24.353600 [0.7639603614807129, 0.7402582168579102, 0.740... [0.7230178117752075, 0.7104435563087463, 0.698... [0.4545454680919647, 0.4545454680919647, 0.520... [0.5245901346206665, 0.5409836173057556, 0.557... 18 1445 26.377773
2 softsign 125 0.448487 0.000104 52 57 0.721311 2 DONE 26.404781 [0.7735962867736816, 0.7760669589042664, 0.748... [0.7457154989242554, 0.7422785758972168, 0.738... [0.42148759961128235, 0.4545454680919647, 0.48... [0.4262295067310333, 0.4262295067310333, 0.426... 18 2167 28.073918
3 swish 38 0.026489 0.002062 32 63 0.803279 3 DONE 28.101012 [0.6325234770774841, 0.5060442686080933, 0.431... [0.5108410120010376, 0.42839738726615906, 0.39... [0.6570248007774353, 0.7933884263038635, 0.818... [0.8032786846160889, 0.8032786846160889, 0.819... 18 2395 29.637213
4 hard_sigmoid 10 0.040314 0.000152 48 91 0.836066 4 DONE 29.664040 [0.6587173342704773, 0.6248915195465088, 0.595... [0.6039978861808777, 0.5679578185081482, 0.541... [0.6652892827987671, 0.71074378490448, 0.71900... [0.7704917788505554, 0.7704917788505554, 0.770... 18 3459 31.929472
5 elu 8 0.092565 0.000090 90 100 0.819672 5 DONE 31.956442 [0.632409930229187, 0.5705690979957581, 0.5305... [0.581476092338562, 0.53741455078125, 0.501731... [0.6859503984451294, 0.7685950398445129, 0.805... [0.7704917788505554, 0.7868852615356445, 0.803... 18 3801 35.583780
6 hard_sigmoid 213 0.062061 0.000015 17 22 0.229508 6 DONE 35.610619 [0.7244776487350464, 0.7173941135406494, 0.715... [0.738309919834137, 0.7380985617637634, 0.7378... [0.40909090638160706, 0.41735535860061646, 0.4... [0.2295081913471222, 0.2295081913471222, 0.229... 18 837 36.913661
7 swish 20 0.161608 0.005545 63 62 0.786885 7 DONE 36.939836 [0.5608718991279602, 0.3559199273586273, 0.302... [0.3681420683860779, 0.39301276206970215, 0.40... [0.71074378490448, 0.8347107172012329, 0.87603... [0.7704917788505554, 0.8360655903816223, 0.836... 18 2357 39.035438
8 hard_sigmoid 42 0.478465 0.007507 46 46 0.819672 8 DONE 39.062334 [0.6264892816543579, 0.5298553705215454, 0.515... [0.46544015407562256, 0.419380247592926, 0.398... [0.6322314143180847, 0.7396694421768188, 0.714... [0.7704917788505554, 0.7868852615356445, 0.770... 18 1749 40.845130
9 hard_sigmoid 22 0.254639 0.000233 31 43 0.786885 9 DONE 40.872630 [0.69364994764328, 0.6741333603858948, 0.64725... [0.6586198210716248, 0.6416864395141602, 0.626... [0.5702479481697083, 0.5743801593780518, 0.644... [0.6557376980781555, 0.6721311211585999, 0.754... 18 1635 42.473509
10 hard_sigmoid 11 0.010323 0.000014 49 91 0.639344 10 DONE 43.029043 [0.9453896880149841, 0.9396271705627441, 0.926... [0.977564811706543, 0.9690657258033752, 0.9601... [0.2851239740848541, 0.2851239740848541, 0.285... [0.2295081913471222, 0.2295081913471222, 0.229... 18 3459 46.093403
11 hard_sigmoid 10 0.004417 0.000148 49 91 0.803279 11 DONE 46.801438 [0.9805307984352112, 0.8879711627960205, 0.808... [0.9651934504508972, 0.8695720434188843, 0.795... [0.2851239740848541, 0.2851239740848541, 0.285... [0.2295081913471222, 0.2295081913471222, 0.229... 18 3459 49.263470
12 hard_sigmoid 10 0.042830 0.000189 15 91 0.786885 12 DONE 49.817882 [0.7628817558288574, 0.6877034902572632, 0.644... [0.7294394969940186, 0.6623942852020264, 0.617... [0.3264462947845459, 0.5413222908973694, 0.706... [0.26229506731033325, 0.7704917788505554, 0.77... 18 3459 51.414139
13 hard_sigmoid 11 0.052243 0.000181 48 88 0.803279 13 DONE 51.918125 [0.7297526001930237, 0.6772193312644958, 0.633... [0.6942347884178162, 0.6416444778442383, 0.600... [0.3719008266925812, 0.5909090638160706, 0.710... [0.44262295961380005, 0.7868852615356445, 0.78... 18 3345 54.095216
14 hard_sigmoid 10 0.041325 0.000049 47 126 0.786885 14 DONE 54.813164 [0.7548640966415405, 0.7348778247833252, 0.708... [0.7591538429260254, 0.7276302576065063, 0.703... [0.3016528785228729, 0.3677685856819153, 0.5, ... [0.21311475336551666, 0.26229506731033325, 0.4... 18 4789 57.250200
15 elu 65 0.089230 0.009213 95 8 0.803279 15 DONE 57.731290 [0.6084141731262207, 0.46208781003952026, 0.37... [0.41744959354400635, 0.3670448660850525, 0.37... [0.6363636255264282, 0.7685950398445129, 0.809... [0.8196721076965332, 0.8524590134620667, 0.836... 18 305 59.692247
16 elu 9 0.277494 0.000029 89 92 0.819672 16 DONE 60.467549 [0.7212059497833252, 0.7207241058349609, 0.674... [0.662761926651001, 0.6511141061782837, 0.6394... [0.56611567735672, 0.586776852607727, 0.611570... [0.6065573692321777, 0.6557376980781555, 0.672... 18 3497 63.682839
17 elu 9 0.367547 0.000011 88 90 0.803279 17 DONE 64.294921 [0.7646175622940063, 0.7496189475059509, 0.769... [0.7482584714889526, 0.7417523860931396, 0.734... [0.4834710657596588, 0.5123966932296753, 0.5, ... [0.49180328845977783, 0.49180328845977783, 0.4... 18 3421 67.540046
18 hard_sigmoid 100 0.484333 0.009965 46 47 0.819672 18 DONE 68.222909 [0.7115792632102966, 0.5767557621002197, 0.520... [0.5003692507743835, 0.47499191761016846, 0.42... [0.6983470916748047, 0.6942148804664612, 0.731... [0.7704917788505554, 0.7868852615356445, 0.770... 18 1787 69.957222
19 hard_sigmoid 214 0.480979 0.009989 46 50 0.852459 19 DONE 70.504195 [1.2610570192337036, 0.8506705164909363, 0.622... [0.8369094133377075, 0.5838184952735901, 0.470... [0.2851239740848541, 0.43388429284095764, 0.63... [0.2295081913471222, 0.7868852615356445, 0.786... 18 1901 72.019933
20 hard_sigmoid 254 0.481125 0.007501 31 127 0.852459 20 DONE 72.619805 [0.7543298602104187, 0.687750518321991, 0.6241... [0.5784507989883423, 0.5200879573822021, 0.499... [0.4958677589893341, 0.6652892827987671, 0.710... [0.7704917788505554, 0.7704917788505554, 0.770... 18 4827 74.013629
21 hard_sigmoid 254 0.479997 0.002899 19 109 0.770492 21 DONE 74.779004 [1.0613369941711426, 1.0073189735412598, 0.875... [0.9813054203987122, 0.8806155920028687, 0.792... [0.3471074402332306, 0.3181818127632141, 0.438... [0.2295081913471222, 0.2295081913471222, 0.229... 18 4143 76.249230
22 hard_sigmoid 225 0.479887 0.005930 27 120 0.852459 22 DONE 77.105368 [0.8292042016983032, 0.6387853622436523, 0.578... [0.5494740009307861, 0.4861133396625519, 0.481... [0.42148759961128235, 0.6776859760284424, 0.71... [0.7704917788505554, 0.7704917788505554, 0.770... 18 4561 78.530955
23 hard_sigmoid 210 0.470153 0.006405 19 123 0.852459 23 DONE 79.051023 [0.7748973369598389, 0.6380338072776794, 0.579... [0.5329240560531616, 0.47213122248649597, 0.44... [0.4793388545513153, 0.6611570119857788, 0.710... [0.7868852615356445, 0.7704917788505554, 0.786... 18 4675 80.376450
24 hard_sigmoid 251 0.469616 0.004705 17 123 0.786885 24 DONE 81.166138 [0.6553329229354858, 0.6020475029945374, 0.584... [0.5192084312438965, 0.4955504536628723, 0.479... [0.6322314143180847, 0.7355371713638306, 0.719... [0.7704917788505554, 0.7704917788505554, 0.770... 18 4675 82.451308
25 softsign 250 0.480790 0.007832 53 125 0.803279 25 DONE 83.011265 [0.7321151494979858, 0.5804288387298584, 0.476... [0.5269145369529724, 0.43378546833992004, 0.38... [0.46694216132164, 0.7148760557174683, 0.78512... [0.8196721076965332, 0.8196721076965332, 0.786... 18 4751 84.752865
26 hard_sigmoid 255 0.480756 0.005885 22 99 0.852459 26 DONE 85.420057 [0.6696801781654358, 0.609772801399231, 0.6074... [0.5007878541946411, 0.4793156087398529, 0.462... [0.6528925895690918, 0.6983470916748047, 0.706... [0.7704917788505554, 0.7704917788505554, 0.770... 18 3763 86.796966
27 hard_sigmoid 252 0.401441 0.006651 20 109 0.852459 27 DONE 87.506137 [1.0473660230636597, 0.8656474351882935, 0.706... [0.8114010095596313, 0.6466981172561646, 0.547... [0.32231405377388, 0.37603306770324707, 0.5619... [0.2295081913471222, 0.7868852615356445, 0.770... 18 4143 88.849887
28 hard_sigmoid 218 0.189225 0.006515 19 99 0.852459 28 DONE 89.431249 [0.6733125448226929, 0.583982527256012, 0.5316... [0.5367051959037781, 0.4850410521030426, 0.448... [0.5909090638160706, 0.7272727489471436, 0.727... [0.7704917788505554, 0.7704917788505554, 0.770... 18 3763 90.752168
29 hard_sigmoid 254 0.481146 0.005438 22 107 0.868852 29 DONE 91.540962 [0.8287509083747864, 0.7461135983467102, 0.665... [0.695465087890625, 0.6011614799499512, 0.5411... [0.38842976093292236, 0.5206611752510071, 0.61... [0.5081967115402222, 0.7704917788505554, 0.770... 18 4067 92.863941
30 hard_sigmoid 254 0.495178 0.004864 22 107 0.803279 30 DONE 93.677080 [0.9318445324897766, 0.8095709681510925, 0.722... [0.776494562625885, 0.6602926850318909, 0.5766... [0.3512396812438965, 0.42561984062194824, 0.56... [0.2295081913471222, 0.7540983557701111, 0.786... 18 4067 95.254718
31 hard_sigmoid 254 0.482023 0.005422 11 107 0.819672 31 DONE 96.240324 [0.7314707040786743, 0.674094021320343, 0.6030... [0.5905916690826416, 0.5288848280906677, 0.495... [0.5413222908973694, 0.6157024502754211, 0.648... [0.7704917788505554, 0.7704917788505554, 0.770... 18 4067 97.500010
32 hard_sigmoid 251 0.041958 0.005440 22 111 0.852459 32 DONE 98.373513 [0.6084542870521545, 0.5798646211624146, 0.559... [0.5280794501304626, 0.5053377747535706, 0.487... [0.7148760557174683, 0.7148760557174683, 0.714... [0.7704917788505554, 0.7704917788505554, 0.770... 18 4219 99.715976
33 linear 256 0.456569 0.005438 67 75 0.819672 33 DONE 100.393998 [0.6855553388595581, 0.5385969281196594, 0.455... [0.47430887818336487, 0.4145543575286865, 0.38... [0.6239669322967529, 0.7561983466148376, 0.776... [0.8032786846160889, 0.8032786846160889, 0.786... 18 2851 102.011602
34 hard_sigmoid 254 0.236034 0.005247 22 124 0.803279 34 DONE 102.587218 [0.6595293283462524, 0.6073529124259949, 0.598... [0.5491055846214294, 0.5108243823051453, 0.490... [0.6198347210884094, 0.6983470916748047, 0.719... [0.7704917788505554, 0.7704917788505554, 0.770... 18 4713 104.143555
35 elu 254 0.597516 0.005437 21 107 0.836066 35 DONE 105.333566 [0.747581958770752, 0.559327244758606, 0.46736... [0.5129536390304565, 0.4232344925403595, 0.378... [0.5330578684806824, 0.702479362487793, 0.8057... [0.7868852615356445, 0.7704917788505554, 0.786... 18 4067 106.656631
36 hard_sigmoid 218 0.462655 0.005438 22 21 0.786885 36 DONE 107.529558 [0.6573420166969299, 0.6241464018821716, 0.620... [0.5412810444831848, 0.5244167447090149, 0.507... [0.6652892827987671, 0.6487603187561035, 0.690... [0.7704917788505554, 0.7704917788505554, 0.770... 18 799 108.890761
37 hard_sigmoid 220 0.151721 0.007159 17 52 0.836066 37 DONE 109.478401 [1.2892435789108276, 0.9626411199569702, 0.750... [1.0040860176086426, 0.7683011293411255, 0.614... [0.2851239740848541, 0.2851239740848541, 0.433... [0.2295081913471222, 0.2295081913471222, 0.803... 18 1977 110.786285
38 hard_sigmoid 210 0.481841 0.005438 22 107 0.868852 38 DONE 111.632075 [0.7545424103736877, 0.6784554123878479, 0.577... [0.5750901699066162, 0.5008385181427002, 0.481... [0.4958677589893341, 0.6157024502754211, 0.739... [0.7704917788505554, 0.7704917788505554, 0.770... 18 4067 113.192420
39 hard_sigmoid 203 0.492826 0.005321 20 102 0.836066 39 DONE 113.682802 [0.8296107649803162, 0.6741348505020142, 0.592... [0.587567150592804, 0.4910070300102234, 0.4631... [0.43388429284095764, 0.5826446413993835, 0.68... [0.7868852615356445, 0.7704917788505554, 0.770... 18 3877 115.029475
40 relu 218 0.488815 0.005438 22 106 0.786885 40 DONE 115.841947 [0.7895721793174744, 0.6154360771179199, 0.539... [0.5899650454521179, 0.4901629090309143, 0.436... [0.5123966932296753, 0.6570248007774353, 0.710... [0.7704917788505554, 0.7213114500045776, 0.770... 18 4029 117.180042
41 linear 211 0.482496 0.005438 23 107 0.803279 41 DONE 118.049994 [0.6696746945381165, 0.45291176438331604, 0.39... [0.3958238661289215, 0.34970206022262573, 0.34... [0.5909090638160706, 0.7685950398445129, 0.814... [0.8196721076965332, 0.7868852615356445, 0.803... 18 4067 119.390520
42 hard_sigmoid 137 0.304672 0.005433 77 106 0.786885 42 DONE 119.806385 [0.6320827007293701, 0.5912678837776184, 0.528... [0.5139389634132385, 0.4769989848136902, 0.447... [0.6528925895690918, 0.7148760557174683, 0.739... [0.7704917788505554, 0.7704917788505554, 0.786... 18 4029 121.544481
43 hard_sigmoid 213 0.464688 0.005430 22 81 0.819672 43 DONE 122.233962 [0.791218638420105, 0.6553912162780762, 0.5934... [0.5946007966995239, 0.5060430765151978, 0.468... [0.4834710657596588, 0.5785123705863953, 0.698... [0.7704917788505554, 0.7704917788505554, 0.770... 18 3079 123.827419
44 hard_sigmoid 207 0.397137 0.005438 58 106 0.836066 44 DONE 124.510809 [1.029747724533081, 0.7393372058868408, 0.6286... [0.7274479269981384, 0.5479812026023865, 0.486... [0.3099173605442047, 0.5, 0.6528925895690918, ... [0.2950819730758667, 0.7704917788505554, 0.770... 18 4029 126.174829
45 hard_sigmoid 254 0.273684 0.006529 70 106 0.803279 45 DONE 126.860459 [0.6634817719459534, 0.5854498147964478, 0.573... [0.5192104578018188, 0.5170759558677673, 0.504... [0.7148760557174683, 0.7148760557174683, 0.743... [0.7704917788505554, 0.7704917788505554, 0.770... 18 4029 128.550364
46 hard_sigmoid 168 0.481791 0.005438 22 107 0.852459 46 DONE 129.907288 [0.8404354453086853, 0.6287763714790344, 0.606... [0.5717772841453552, 0.5005543231964111, 0.480... [0.4545454680919647, 0.6363636255264282, 0.714... [0.7704917788505554, 0.7704917788505554, 0.770... 18 4067 131.242743
47 hard_sigmoid 128 0.000276 0.005500 22 107 0.836066 47 DONE 131.834284 [0.6067198514938354, 0.5579516291618347, 0.525... [0.4999825358390808, 0.4646409749984741, 0.438... [0.71074378490448, 0.7148760557174683, 0.71487... [0.7704917788505554, 0.7704917788505554, 0.786... 18 4067 133.405986
48 hard_sigmoid 21 0.342111 0.006704 19 115 0.803279 48 DONE 133.948273 [0.6007430553436279, 0.42002755403518677, 0.35... [0.42077913880348206, 0.368651419878006, 0.369... [0.6611570119857788, 0.8057851195335388, 0.847... [0.8196721076965332, 0.8360655903816223, 0.852... 18 4371 135.453494
49 hard_sigmoid 210 0.451442 0.010000 35 35 0.836066 49 DONE 136.282800 [0.6316529512405396, 0.6350192427635193, 0.574... [0.5175772905349731, 0.48914796113967896, 0.46... [0.6818181872367859, 0.6942148804664612, 0.710... [0.7704917788505554, 0.7704917788505554, 0.770... 18 1331 137.721099

The search can be continued without any issue.

[16]:
results = search.search(max_evals=5)

results
[16]:
p:activation p:batch_size p:dropout_rate p:learning_rate p:num_epochs p:units objective job_id job_status m:timestamp_submit m:loss m:val_loss m:accuracy m:val_accuracy m:num_parameters m:num_parameters_train m:timestamp_gather
0 relu 32 0.500000 0.001000 50 32 0.803279 0 DONE 19.298753 [0.7522220015525818, 0.6794403195381165, 0.633... [0.6496000289916992, 0.5947670340538025, 0.549... [0.5247933864593506, 0.6033057570457458, 0.640... [0.6721311211585999, 0.7704917788505554, 0.770... 18 1217 24.308906
1 swish 24 0.354249 0.000129 61 38 0.786885 1 DONE 24.353600 [0.7639603614807129, 0.7402582168579102, 0.740... [0.7230178117752075, 0.7104435563087463, 0.698... [0.4545454680919647, 0.4545454680919647, 0.520... [0.5245901346206665, 0.5409836173057556, 0.557... 18 1445 26.377773
2 softsign 125 0.448487 0.000104 52 57 0.721311 2 DONE 26.404781 [0.7735962867736816, 0.7760669589042664, 0.748... [0.7457154989242554, 0.7422785758972168, 0.738... [0.42148759961128235, 0.4545454680919647, 0.48... [0.4262295067310333, 0.4262295067310333, 0.426... 18 2167 28.073918
3 swish 38 0.026489 0.002062 32 63 0.803279 3 DONE 28.101012 [0.6325234770774841, 0.5060442686080933, 0.431... [0.5108410120010376, 0.42839738726615906, 0.39... [0.6570248007774353, 0.7933884263038635, 0.818... [0.8032786846160889, 0.8032786846160889, 0.819... 18 2395 29.637213
4 hard_sigmoid 10 0.040314 0.000152 48 91 0.836066 4 DONE 29.664040 [0.6587173342704773, 0.6248915195465088, 0.595... [0.6039978861808777, 0.5679578185081482, 0.541... [0.6652892827987671, 0.71074378490448, 0.71900... [0.7704917788505554, 0.7704917788505554, 0.770... 18 3459 31.929472
5 elu 8 0.092565 0.000090 90 100 0.819672 5 DONE 31.956442 [0.632409930229187, 0.5705690979957581, 0.5305... [0.581476092338562, 0.53741455078125, 0.501731... [0.6859503984451294, 0.7685950398445129, 0.805... [0.7704917788505554, 0.7868852615356445, 0.803... 18 3801 35.583780
6 hard_sigmoid 213 0.062061 0.000015 17 22 0.229508 6 DONE 35.610619 [0.7244776487350464, 0.7173941135406494, 0.715... [0.738309919834137, 0.7380985617637634, 0.7378... [0.40909090638160706, 0.41735535860061646, 0.4... [0.2295081913471222, 0.2295081913471222, 0.229... 18 837 36.913661
7 swish 20 0.161608 0.005545 63 62 0.786885 7 DONE 36.939836 [0.5608718991279602, 0.3559199273586273, 0.302... [0.3681420683860779, 0.39301276206970215, 0.40... [0.71074378490448, 0.8347107172012329, 0.87603... [0.7704917788505554, 0.8360655903816223, 0.836... 18 2357 39.035438
8 hard_sigmoid 42 0.478465 0.007507 46 46 0.819672 8 DONE 39.062334 [0.6264892816543579, 0.5298553705215454, 0.515... [0.46544015407562256, 0.419380247592926, 0.398... [0.6322314143180847, 0.7396694421768188, 0.714... [0.7704917788505554, 0.7868852615356445, 0.770... 18 1749 40.845130
9 hard_sigmoid 22 0.254639 0.000233 31 43 0.786885 9 DONE 40.872630 [0.69364994764328, 0.6741333603858948, 0.64725... [0.6586198210716248, 0.6416864395141602, 0.626... [0.5702479481697083, 0.5743801593780518, 0.644... [0.6557376980781555, 0.6721311211585999, 0.754... 18 1635 42.473509
10 hard_sigmoid 11 0.010323 0.000014 49 91 0.639344 10 DONE 43.029043 [0.9453896880149841, 0.9396271705627441, 0.926... [0.977564811706543, 0.9690657258033752, 0.9601... [0.2851239740848541, 0.2851239740848541, 0.285... [0.2295081913471222, 0.2295081913471222, 0.229... 18 3459 46.093403
11 hard_sigmoid 10 0.004417 0.000148 49 91 0.803279 11 DONE 46.801438 [0.9805307984352112, 0.8879711627960205, 0.808... [0.9651934504508972, 0.8695720434188843, 0.795... [0.2851239740848541, 0.2851239740848541, 0.285... [0.2295081913471222, 0.2295081913471222, 0.229... 18 3459 49.263470
12 hard_sigmoid 10 0.042830 0.000189 15 91 0.786885 12 DONE 49.817882 [0.7628817558288574, 0.6877034902572632, 0.644... [0.7294394969940186, 0.6623942852020264, 0.617... [0.3264462947845459, 0.5413222908973694, 0.706... [0.26229506731033325, 0.7704917788505554, 0.77... 18 3459 51.414139
13 hard_sigmoid 11 0.052243 0.000181 48 88 0.803279 13 DONE 51.918125 [0.7297526001930237, 0.6772193312644958, 0.633... [0.6942347884178162, 0.6416444778442383, 0.600... [0.3719008266925812, 0.5909090638160706, 0.710... [0.44262295961380005, 0.7868852615356445, 0.78... 18 3345 54.095216
14 hard_sigmoid 10 0.041325 0.000049 47 126 0.786885 14 DONE 54.813164 [0.7548640966415405, 0.7348778247833252, 0.708... [0.7591538429260254, 0.7276302576065063, 0.703... [0.3016528785228729, 0.3677685856819153, 0.5, ... [0.21311475336551666, 0.26229506731033325, 0.4... 18 4789 57.250200
15 elu 65 0.089230 0.009213 95 8 0.803279 15 DONE 57.731290 [0.6084141731262207, 0.46208781003952026, 0.37... [0.41744959354400635, 0.3670448660850525, 0.37... [0.6363636255264282, 0.7685950398445129, 0.809... [0.8196721076965332, 0.8524590134620667, 0.836... 18 305 59.692247
16 elu 9 0.277494 0.000029 89 92 0.819672 16 DONE 60.467549 [0.7212059497833252, 0.7207241058349609, 0.674... [0.662761926651001, 0.6511141061782837, 0.6394... [0.56611567735672, 0.586776852607727, 0.611570... [0.6065573692321777, 0.6557376980781555, 0.672... 18 3497 63.682839
17 elu 9 0.367547 0.000011 88 90 0.803279 17 DONE 64.294921 [0.7646175622940063, 0.7496189475059509, 0.769... [0.7482584714889526, 0.7417523860931396, 0.734... [0.4834710657596588, 0.5123966932296753, 0.5, ... [0.49180328845977783, 0.49180328845977783, 0.4... 18 3421 67.540046
18 hard_sigmoid 100 0.484333 0.009965 46 47 0.819672 18 DONE 68.222909 [0.7115792632102966, 0.5767557621002197, 0.520... [0.5003692507743835, 0.47499191761016846, 0.42... [0.6983470916748047, 0.6942148804664612, 0.731... [0.7704917788505554, 0.7868852615356445, 0.770... 18 1787 69.957222
19 hard_sigmoid 214 0.480979 0.009989 46 50 0.852459 19 DONE 70.504195 [1.2610570192337036, 0.8506705164909363, 0.622... [0.8369094133377075, 0.5838184952735901, 0.470... [0.2851239740848541, 0.43388429284095764, 0.63... [0.2295081913471222, 0.7868852615356445, 0.786... 18 1901 72.019933
20 hard_sigmoid 254 0.481125 0.007501 31 127 0.852459 20 DONE 72.619805 [0.7543298602104187, 0.687750518321991, 0.6241... [0.5784507989883423, 0.5200879573822021, 0.499... [0.4958677589893341, 0.6652892827987671, 0.710... [0.7704917788505554, 0.7704917788505554, 0.770... 18 4827 74.013629
21 hard_sigmoid 254 0.479997 0.002899 19 109 0.770492 21 DONE 74.779004 [1.0613369941711426, 1.0073189735412598, 0.875... [0.9813054203987122, 0.8806155920028687, 0.792... [0.3471074402332306, 0.3181818127632141, 0.438... [0.2295081913471222, 0.2295081913471222, 0.229... 18 4143 76.249230
22 hard_sigmoid 225 0.479887 0.005930 27 120 0.852459 22 DONE 77.105368 [0.8292042016983032, 0.6387853622436523, 0.578... [0.5494740009307861, 0.4861133396625519, 0.481... [0.42148759961128235, 0.6776859760284424, 0.71... [0.7704917788505554, 0.7704917788505554, 0.770... 18 4561 78.530955
23 hard_sigmoid 210 0.470153 0.006405 19 123 0.852459 23 DONE 79.051023 [0.7748973369598389, 0.6380338072776794, 0.579... [0.5329240560531616, 0.47213122248649597, 0.44... [0.4793388545513153, 0.6611570119857788, 0.710... [0.7868852615356445, 0.7704917788505554, 0.786... 18 4675 80.376450
24 hard_sigmoid 251 0.469616 0.004705 17 123 0.786885 24 DONE 81.166138 [0.6553329229354858, 0.6020475029945374, 0.584... [0.5192084312438965, 0.4955504536628723, 0.479... [0.6322314143180847, 0.7355371713638306, 0.719... [0.7704917788505554, 0.7704917788505554, 0.770... 18 4675 82.451308
25 softsign 250 0.480790 0.007832 53 125 0.803279 25 DONE 83.011265 [0.7321151494979858, 0.5804288387298584, 0.476... [0.5269145369529724, 0.43378546833992004, 0.38... [0.46694216132164, 0.7148760557174683, 0.78512... [0.8196721076965332, 0.8196721076965332, 0.786... 18 4751 84.752865
26 hard_sigmoid 255 0.480756 0.005885 22 99 0.852459 26 DONE 85.420057 [0.6696801781654358, 0.609772801399231, 0.6074... [0.5007878541946411, 0.4793156087398529, 0.462... [0.6528925895690918, 0.6983470916748047, 0.706... [0.7704917788505554, 0.7704917788505554, 0.770... 18 3763 86.796966
27 hard_sigmoid 252 0.401441 0.006651 20 109 0.852459 27 DONE 87.506137 [1.0473660230636597, 0.8656474351882935, 0.706... [0.8114010095596313, 0.6466981172561646, 0.547... [0.32231405377388, 0.37603306770324707, 0.5619... [0.2295081913471222, 0.7868852615356445, 0.770... 18 4143 88.849887
28 hard_sigmoid 218 0.189225 0.006515 19 99 0.852459 28 DONE 89.431249 [0.6733125448226929, 0.583982527256012, 0.5316... [0.5367051959037781, 0.4850410521030426, 0.448... [0.5909090638160706, 0.7272727489471436, 0.727... [0.7704917788505554, 0.7704917788505554, 0.770... 18 3763 90.752168
29 hard_sigmoid 254 0.481146 0.005438 22 107 0.868852 29 DONE 91.540962 [0.8287509083747864, 0.7461135983467102, 0.665... [0.695465087890625, 0.6011614799499512, 0.5411... [0.38842976093292236, 0.5206611752510071, 0.61... [0.5081967115402222, 0.7704917788505554, 0.770... 18 4067 92.863941
30 hard_sigmoid 254 0.495178 0.004864 22 107 0.803279 30 DONE 93.677080 [0.9318445324897766, 0.8095709681510925, 0.722... [0.776494562625885, 0.6602926850318909, 0.5766... [0.3512396812438965, 0.42561984062194824, 0.56... [0.2295081913471222, 0.7540983557701111, 0.786... 18 4067 95.254718
31 hard_sigmoid 254 0.482023 0.005422 11 107 0.819672 31 DONE 96.240324 [0.7314707040786743, 0.674094021320343, 0.6030... [0.5905916690826416, 0.5288848280906677, 0.495... [0.5413222908973694, 0.6157024502754211, 0.648... [0.7704917788505554, 0.7704917788505554, 0.770... 18 4067 97.500010
32 hard_sigmoid 251 0.041958 0.005440 22 111 0.852459 32 DONE 98.373513 [0.6084542870521545, 0.5798646211624146, 0.559... [0.5280794501304626, 0.5053377747535706, 0.487... [0.7148760557174683, 0.7148760557174683, 0.714... [0.7704917788505554, 0.7704917788505554, 0.770... 18 4219 99.715976
33 linear 256 0.456569 0.005438 67 75 0.819672 33 DONE 100.393998 [0.6855553388595581, 0.5385969281196594, 0.455... [0.47430887818336487, 0.4145543575286865, 0.38... [0.6239669322967529, 0.7561983466148376, 0.776... [0.8032786846160889, 0.8032786846160889, 0.786... 18 2851 102.011602
34 hard_sigmoid 254 0.236034 0.005247 22 124 0.803279 34 DONE 102.587218 [0.6595293283462524, 0.6073529124259949, 0.598... [0.5491055846214294, 0.5108243823051453, 0.490... [0.6198347210884094, 0.6983470916748047, 0.719... [0.7704917788505554, 0.7704917788505554, 0.770... 18 4713 104.143555
35 elu 254 0.597516 0.005437 21 107 0.836066 35 DONE 105.333566 [0.747581958770752, 0.559327244758606, 0.46736... [0.5129536390304565, 0.4232344925403595, 0.378... [0.5330578684806824, 0.702479362487793, 0.8057... [0.7868852615356445, 0.7704917788505554, 0.786... 18 4067 106.656631
36 hard_sigmoid 218 0.462655 0.005438 22 21 0.786885 36 DONE 107.529558 [0.6573420166969299, 0.6241464018821716, 0.620... [0.5412810444831848, 0.5244167447090149, 0.507... [0.6652892827987671, 0.6487603187561035, 0.690... [0.7704917788505554, 0.7704917788505554, 0.770... 18 799 108.890761
37 hard_sigmoid 220 0.151721 0.007159 17 52 0.836066 37 DONE 109.478401 [1.2892435789108276, 0.9626411199569702, 0.750... [1.0040860176086426, 0.7683011293411255, 0.614... [0.2851239740848541, 0.2851239740848541, 0.433... [0.2295081913471222, 0.2295081913471222, 0.803... 18 1977 110.786285
38 hard_sigmoid 210 0.481841 0.005438 22 107 0.868852 38 DONE 111.632075 [0.7545424103736877, 0.6784554123878479, 0.577... [0.5750901699066162, 0.5008385181427002, 0.481... [0.4958677589893341, 0.6157024502754211, 0.739... [0.7704917788505554, 0.7704917788505554, 0.770... 18 4067 113.192420
39 hard_sigmoid 203 0.492826 0.005321 20 102 0.836066 39 DONE 113.682802 [0.8296107649803162, 0.6741348505020142, 0.592... [0.587567150592804, 0.4910070300102234, 0.4631... [0.43388429284095764, 0.5826446413993835, 0.68... [0.7868852615356445, 0.7704917788505554, 0.770... 18 3877 115.029475
40 relu 218 0.488815 0.005438 22 106 0.786885 40 DONE 115.841947 [0.7895721793174744, 0.6154360771179199, 0.539... [0.5899650454521179, 0.4901629090309143, 0.436... [0.5123966932296753, 0.6570248007774353, 0.710... [0.7704917788505554, 0.7213114500045776, 0.770... 18 4029 117.180042
41 linear 211 0.482496 0.005438 23 107 0.803279 41 DONE 118.049994 [0.6696746945381165, 0.45291176438331604, 0.39... [0.3958238661289215, 0.34970206022262573, 0.34... [0.5909090638160706, 0.7685950398445129, 0.814... [0.8196721076965332, 0.7868852615356445, 0.803... 18 4067 119.390520
42 hard_sigmoid 137 0.304672 0.005433 77 106 0.786885 42 DONE 119.806385 [0.6320827007293701, 0.5912678837776184, 0.528... [0.5139389634132385, 0.4769989848136902, 0.447... [0.6528925895690918, 0.7148760557174683, 0.739... [0.7704917788505554, 0.7704917788505554, 0.786... 18 4029 121.544481
43 hard_sigmoid 213 0.464688 0.005430 22 81 0.819672 43 DONE 122.233962 [0.791218638420105, 0.6553912162780762, 0.5934... [0.5946007966995239, 0.5060430765151978, 0.468... [0.4834710657596588, 0.5785123705863953, 0.698... [0.7704917788505554, 0.7704917788505554, 0.770... 18 3079 123.827419
44 hard_sigmoid 207 0.397137 0.005438 58 106 0.836066 44 DONE 124.510809 [1.029747724533081, 0.7393372058868408, 0.6286... [0.7274479269981384, 0.5479812026023865, 0.486... [0.3099173605442047, 0.5, 0.6528925895690918, ... [0.2950819730758667, 0.7704917788505554, 0.770... 18 4029 126.174829
45 hard_sigmoid 254 0.273684 0.006529 70 106 0.803279 45 DONE 126.860459 [0.6634817719459534, 0.5854498147964478, 0.573... [0.5192104578018188, 0.5170759558677673, 0.504... [0.7148760557174683, 0.7148760557174683, 0.743... [0.7704917788505554, 0.7704917788505554, 0.770... 18 4029 128.550364
46 hard_sigmoid 168 0.481791 0.005438 22 107 0.852459 46 DONE 129.907288 [0.8404354453086853, 0.6287763714790344, 0.606... [0.5717772841453552, 0.5005543231964111, 0.480... [0.4545454680919647, 0.6363636255264282, 0.714... [0.7704917788505554, 0.7704917788505554, 0.770... 18 4067 131.242743
47 hard_sigmoid 128 0.000276 0.005500 22 107 0.836066 47 DONE 131.834284 [0.6067198514938354, 0.5579516291618347, 0.525... [0.4999825358390808, 0.4646409749984741, 0.438... [0.71074378490448, 0.7148760557174683, 0.71487... [0.7704917788505554, 0.7704917788505554, 0.786... 18 4067 133.405986
48 hard_sigmoid 21 0.342111 0.006704 19 115 0.803279 48 DONE 133.948273 [0.6007430553436279, 0.42002755403518677, 0.35... [0.42077913880348206, 0.368651419878006, 0.369... [0.6611570119857788, 0.8057851195335388, 0.847... [0.8196721076965332, 0.8360655903816223, 0.852... 18 4371 135.453494
49 hard_sigmoid 210 0.451442 0.010000 35 35 0.836066 49 DONE 136.282800 [0.6316529512405396, 0.6350192427635193, 0.574... [0.5175772905349731, 0.48914796113967896, 0.46... [0.6818181872367859, 0.6942148804664612, 0.710... [0.7704917788505554, 0.7704917788505554, 0.770... 18 1331 137.721099
50 hard_sigmoid 245 0.483799 0.005438 22 123 0.852459 50 DONE 204.981866 [0.6570180654525757, 0.6534910202026367, 0.595... [0.5275946259498596, 0.5072992444038391, 0.489... [0.6652892827987671, 0.6942148804664612, 0.723... [0.7704917788505554, 0.7704917788505554, 0.770... 18 4675 206.361649
51 hard_sigmoid 250 0.345535 0.009981 51 96 0.836066 51 DONE 207.050688 [0.8433136940002441, 0.6663693785667419, 0.573... [0.6260725259780884, 0.5095583200454712, 0.465... [0.3719008266925812, 0.5702479481697083, 0.710... [0.8196721076965332, 0.7868852615356445, 0.770... 18 3649 208.571498
52 hard_sigmoid 251 0.201286 0.005980 22 123 0.852459 52 DONE 209.124434 [0.585923969745636, 0.5572856068611145, 0.5469... [0.4995526373386383, 0.4735236167907715, 0.453... [0.7355371713638306, 0.7148760557174683, 0.710... [0.7704917788505554, 0.7704917788505554, 0.770... 18 4675 210.710767
53 hard_sigmoid 156 0.281505 0.005449 38 123 0.836066 53 DONE 211.417707 [0.649643063545227, 0.5964835286140442, 0.5450... [0.5178518295288086, 0.4810403287410736, 0.452... [0.6487603187561035, 0.7148760557174683, 0.719... [0.7704917788505554, 0.7704917788505554, 0.770... 18 4675 212.908311
54 gelu 216 0.480514 0.009988 46 28 0.803279 54 DONE 213.778275 [0.5881480574607849, 0.49971804022789, 0.45157... [0.4912894368171692, 0.43112826347351074, 0.39... [0.6859503984451294, 0.7851239442825317, 0.809... [0.8196721076965332, 0.8196721076965332, 0.819... 18 1065 215.346050

Now that the search is over, we can look at how the performance of the explored models evolved over the course of the search.

[22]:
import matplotlib.pyplot as plt
from deephyper.analysis.hpo import plot_search_trajectory_single_objective_hpo

fig, ax = plot_search_trajectory_single_objective_hpo(results)
plt.axhline(results.iloc[0]["objective"], label="Default", color="red", linestyle="--")
plt.title("Search Trajectory")
plt.legend()
plt.show()
[Figure: search trajectory of the objective over the course of the search]
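A complementary view is the running best objective over time, which can be computed directly from the objective and m:timestamp_gather columns of results. A minimal sketch:

# Running best objective (cumulative maximum) against gather time.
best_so_far = results["objective"].cummax()
plt.plot(results["m:timestamp_gather"], best_so_far)
plt.xlabel("Time (s)")
plt.ylabel("Best objective so far")
plt.grid()
plt.show()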

Then, we can look at the best configuration.

[20]:
i_max = results.objective.argmax()
best_job = results.iloc[i_max].to_dict()


print(f"The default configuration has an accuracy of {objective_default:.3f}. \n"
      f"The best configuration found by DeepHyper has an accuracy {results['objective'].iloc[i_max]:.3f}, \n"
      f"discovered after {results['m:timestamp_gather'].iloc[i_max]:.2f} secondes of search.\n")

best_job
The default configuration has an accuracy of 0.836.
The best configuration found by DeepHyper has an accuracy of 0.869,
discovered after 92.86 seconds of search.

[20]:
{'p:activation': 'hard_sigmoid',
 'p:batch_size': 254,
 'p:dropout_rate': 0.4811464422566914,
 'p:learning_rate': 0.0054377979900459,
 'p:num_epochs': 22,
 'p:units': 107,
 'objective': 0.868852436542511,
 'job_id': 29,
 'job_status': 'DONE',
 'm:timestamp_submit': 91.5409619808197,
 'm:loss': '[0.8287509083747864, 0.7461135983467102, 0.6655626893043518, 0.6344414949417114, 0.5947127938270569, 0.5709614753723145, 0.5807883739471436, 0.587151050567627, 0.5646253824234009, 0.6007587909698486, 0.5480071902275085, 0.5170748233795166, 0.5336693525314331, 0.4768647849559784, 0.47153279185295105, 0.4784860610961914, 0.4682762622833252, 0.46586450934410095, 0.43191930651664734, 0.44100987911224365, 0.4295496344566345, 0.4277530312538147]',
 'm:val_loss': '[0.695465087890625, 0.6011614799499512, 0.5411741137504578, 0.5071674585342407, 0.49052631855010986, 0.48139819502830505, 0.4735262393951416, 0.46430036425590515, 0.4532046914100647, 0.4405619502067566, 0.4278675317764282, 0.4162490665912628, 0.40684399008750916, 0.40055155754089355, 0.39762410521507263, 0.3973425030708313, 0.3983728587627411, 0.4002285599708557, 0.4010710120201111, 0.40030139684677124, 0.39769506454467773, 0.39371100068092346]',
 'm:accuracy': '[0.38842976093292236, 0.5206611752510071, 0.6198347210884094, 0.6611570119857788, 0.7272727489471436, 0.7272727489471436, 0.7190082669258118, 0.71074378490448, 0.7272727489471436, 0.7066115736961365, 0.71074378490448, 0.7396694421768188, 0.7272727489471436, 0.7603305578231812, 0.7933884263038635, 0.7561983466148376, 0.7809917330741882, 0.7727272510528564, 0.7851239442825317, 0.8057851195335388, 0.8140496015548706, 0.8016529083251953]',
 'm:val_accuracy': '[0.5081967115402222, 0.7704917788505554, 0.7704917788505554, 0.7704917788505554, 0.7704917788505554, 0.7704917788505554, 0.7704917788505554, 0.7704917788505554, 0.7704917788505554, 0.7704917788505554, 0.7704917788505554, 0.7868852615356445, 0.7868852615356445, 0.8032786846160889, 0.7868852615356445, 0.7868852615356445, 0.8032786846160889, 0.8196721076965332, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.868852436542511]',
 'm:num_parameters': 18,
 'm:num_parameters_train': 4067,
 'm:timestamp_gather': 92.8639407157898}
[21]:
plt.figure()
plt.plot(json.loads(metadata_default["val_accuracy"]), color="skyblue", label="Default (val)")
plt.plot(json.loads(metadata_default["accuracy"]), color="skyblue", linestyle="--", label="Default (train)")
plt.plot(json.loads(best_job["m:val_accuracy"]), color="coral", linewidth=2, label="Best Job (val)")
plt.plot(json.loads(best_job["m:accuracy"]), color="coral", linestyle="--", linewidth=2, label="Best Job (train)")
plt.legend()
plt.ylim(0.3, 0.9)
plt.grid()
plt.show()
[Figure: training and validation accuracy curves, default vs. best configuration]

2.9. Restart from a checkpoint#

It can often be useful to continue a search from previous results, for example when the requested allocation was too short or an unexpected crash occurred. The CBO search provides the fit_surrogate(dataframe_of_results) method for this use case.

To simulate this, we create a second evaluator evaluator_2 and start a fresh CBO search with strong exploitation (kappa=0.001).

[ ]:
# Create a new evaluator
evaluator_2 = get_evaluator(run)

# Create a new CBO search with strong exploitation (i.e., small kappa)
search_from_checkpoint = CBO(
    problem,
    evaluator_2,
    acq_func="UCBd",
    kappa=0.001,
    acq_optimizer="mixedga",
    acq_optimizer_freq=1,
)

# Initialize the surrogate model of the Bayesian optimization
# with the results of the previous search
search_from_checkpoint.fit_surrogate(results)
WARNING:root:Results file already exists, it will be renamed to /Users/romainegele/Documents/Argonne/deephyper-tutorials/tutorials/colab/HPS_basic_classification_with_tabular_data/results_20241216-113606.csv
Created new evaluator with 1 worker and config: {'num_cpus': 1, 'num_cpus_per_task': 1, 'callbacks': [<deephyper.evaluator.callback.TqdmCallback object at 0x3c94ecb90>]}
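If the previous results are no longer in memory (for example, when the earlier search ran in a different session), they can first be reloaded from the CSV file written by the search. A minimal sketch, assuming the results were saved as results.csv:

import pandas as pd

# Hypothetical path: adapt it to the CSV file produced by your own run.
previous_results = pd.read_csv("results.csv")
search_from_checkpoint.fit_surrogate(previous_results)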
[24]:
results_from_checkpoint = search_from_checkpoint.search(max_evals=25)
[25]:
results_from_checkpoint
[25]:
p:activation p:batch_size p:dropout_rate p:learning_rate p:num_epochs p:units objective job_id job_status m:timestamp_submit m:loss m:val_loss m:accuracy m:val_accuracy m:num_parameters m:num_parameters_train m:timestamp_gather
0 hard_sigmoid 208 0.478206 0.005454 22 107 0.819672 0 DONE 33.978275 [0.9971071481704712, 0.7485412955284119, 0.662... [0.7065404653549194, 0.5395659804344177, 0.486... [0.3471074402332306, 0.4958677589893341, 0.628... [0.4754098355770111, 0.7704917788505554, 0.770... 18 4067 35.390350
1 hard_sigmoid 253 0.473896 0.006269 17 127 0.819672 1 DONE 35.950736 [1.3935201168060303, 1.0515316724777222, 0.836... [1.0693024396896362, 0.8184766173362732, 0.641... [0.2933884263038635, 0.3016528785228729, 0.446... [0.2295081913471222, 0.2295081913471222, 0.737... 18 4827 37.273664
2 hard_sigmoid 254 0.481205 0.005385 22 108 0.836066 2 DONE 39.145892 [1.0258198976516724, 0.8684240579605103, 0.718... [0.8224217295646667, 0.6814082264900208, 0.581... [0.3264462947845459, 0.41735535860061646, 0.50... [0.2295081913471222, 0.6229507923126221, 0.819... 18 4105 40.475309
3 gelu 212 0.524199 0.009018 27 116 0.786885 3 DONE 41.019375 [0.8138331770896912, 0.5084949135780334, 0.424... [0.494068443775177, 0.3931865990161896, 0.3644... [0.40082645416259766, 0.7685950398445129, 0.77... [0.8032786846160889, 0.8032786846160889, 0.786... 18 4409 42.734825
4 hard_sigmoid 217 0.524250 0.007515 28 114 0.836066 4 DONE 43.419275 [1.0927315950393677, 0.7291306853294373, 0.562... [0.6644625663757324, 0.5093420743942261, 0.467... [0.3140496015548706, 0.5, 0.7314049601554871, ... [0.6721311211585999, 0.7704917788505554, 0.770... 18 4333 44.829498
5 hard_sigmoid 255 0.481120 0.008677 22 107 0.852459 5 DONE 45.742031 [0.6349105834960938, 0.6272992491722107, 0.558... [0.5131592750549316, 0.4829488694667816, 0.457... [0.6900826692581177, 0.702479362487793, 0.7107... [0.7704917788505554, 0.7704917788505554, 0.770... 18 4067 47.083459
6 hard_sigmoid 223 0.482572 0.009001 22 124 0.836066 6 DONE 47.650057 [1.2967249155044556, 0.7333991527557373, 0.593... [0.7114768624305725, 0.48710745573043823, 0.49... [0.2975206673145294, 0.5206611752510071, 0.698... [0.4098360538482666, 0.7704917788505554, 0.770... 18 4713 48.999962
7 hard_sigmoid 254 0.101842 0.005435 22 102 0.836066 7 DONE 49.465036 [0.9877975583076477, 0.8333274722099304, 0.716... [0.8490940928459167, 0.711073637008667, 0.6106... [0.2851239740848541, 0.2933884263038635, 0.491... [0.2295081913471222, 0.3606557250022888, 0.754... 18 3877 50.795491
8 hard_sigmoid 254 0.467120 0.005438 22 107 0.852459 8 DONE 52.233492 [0.6846663951873779, 0.6341745257377625, 0.613... [0.5275814533233643, 0.5193657875061035, 0.509... [0.71074378490448, 0.7066115736961365, 0.68181... [0.7704917788505554, 0.7704917788505554, 0.770... 18 4067 53.847407
9 hard_sigmoid 252 0.209863 0.006543 19 124 0.852459 9 DONE 54.733154 [0.6476067900657654, 0.5874485373497009, 0.560... [0.5384824872016907, 0.5076740980148315, 0.476... [0.71074378490448, 0.7148760557174683, 0.73553... [0.7704917788505554, 0.7704917788505554, 0.770... 18 4713 56.087044
10 hard_sigmoid 251 0.235999 0.006532 18 124 0.836066 10 DONE 57.097306 [0.6024585962295532, 0.6163465976715088, 0.567... [0.5226372480392456, 0.5190794467926025, 0.495... [0.71074378490448, 0.702479362487793, 0.706611... [0.7704917788505554, 0.7704917788505554, 0.786... 18 4713 58.407120
11 hard_sigmoid 217 0.192716 0.006084 21 94 0.836066 11 DONE 59.171898 [0.6448632478713989, 0.5608571171760559, 0.553... [0.5072973370552063, 0.4817456305027008, 0.454... [0.64462810754776, 0.7272727489471436, 0.71487... [0.7704917788505554, 0.7704917788505554, 0.770... 18 3573 60.519694
12 hard_sigmoid 137 0.188666 0.006409 19 121 0.852459 12 DONE 61.262111 [0.9859618544578552, 0.645491898059845, 0.5461... [0.6592459082603455, 0.49092304706573486, 0.45... [0.2933884263038635, 0.6570248007774353, 0.710... [0.7540983557701111, 0.7704917788505554, 0.770... 18 4599 62.586249
13 hard_sigmoid 237 0.494145 0.005753 24 100 0.836066 13 DONE 63.357738 [0.6676536798477173, 0.5729154944419861, 0.566... [0.4894915521144867, 0.4660187363624573, 0.453... [0.5991735458374023, 0.6735537052154541, 0.739... [0.7868852615356445, 0.7868852615356445, 0.786... 18 3801 64.721456
14 hard_sigmoid 101 0.164318 0.006518 19 125 0.852459 14 DONE 65.256193 [0.8550885319709778, 0.5706097483634949, 0.528... [0.5197867155075073, 0.4509219527244568, 0.431... [0.40909090638160706, 0.7148760557174683, 0.71... [0.7868852615356445, 0.7704917788505554, 0.770... 18 4751 66.915138
15 hard_sigmoid 73 0.178585 0.006472 19 121 0.819672 15 DONE 67.734217 [0.8199191093444824, 0.5523964166641235, 0.541... [0.48327475786209106, 0.4553845524787903, 0.40... [0.4545454680919647, 0.7190082669258118, 0.714... [0.7704917788505554, 0.7704917788505554, 0.786... 18 4599 69.091438
16 hard_sigmoid 214 0.297488 0.009672 59 50 0.819672 16 DONE 70.074940 [0.9355904459953308, 0.6737803220748901, 0.558... [0.6501448750495911, 0.5055423974990845, 0.449... [0.3140496015548706, 0.5785123705863953, 0.723... [0.688524603843689, 0.7704917788505554, 0.7704... 18 1901 71.709189
17 gelu 101 0.093142 0.006611 19 125 0.770492 17 DONE 72.562575 [0.607595682144165, 0.42184796929359436, 0.338... [0.42516565322875977, 0.36241415143013, 0.3848... [0.71074378490448, 0.8181818127632141, 0.84297... [0.7868852615356445, 0.8196721076965332, 0.852... 18 4751 73.914067
18 hard_sigmoid 109 0.078303 0.006389 19 124 0.836066 18 DONE 74.884503 [0.623633623123169, 0.5488640666007996, 0.4732... [0.4851135313510895, 0.4532409906387329, 0.414... [0.71074378490448, 0.71074378490448, 0.7809917... [0.7704917788505554, 0.7704917788505554, 0.803... 18 4713 76.243901
19 hard_sigmoid 254 0.479748 0.005437 22 124 0.819672 19 DONE 77.795482 [0.6718873977661133, 0.6292739510536194, 0.610... [0.5227072834968567, 0.503911018371582, 0.4834... [0.6900826692581177, 0.6652892827987671, 0.698... [0.7704917788505554, 0.7704917788505554, 0.770... 18 4713 79.416509
20 hard_sigmoid 100 0.204713 0.006655 31 127 0.786885 20 DONE 80.211175 [0.6078946590423584, 0.5327951908111572, 0.477... [0.47984498739242554, 0.44338375329971313, 0.4... [0.7066115736961365, 0.7231404781341553, 0.747... [0.7704917788505554, 0.7868852615356445, 0.770... 18 4827 81.721249
21 hard_sigmoid 138 0.202129 0.006370 30 121 0.836066 21 DONE 82.733657 [0.6401527523994446, 0.5989956259727478, 0.549... [0.502800703048706, 0.46909981966018677, 0.433... [0.5991735458374023, 0.7148760557174683, 0.719... [0.7704917788505554, 0.7704917788505554, 0.786... 18 4599 84.153420
22 hard_sigmoid 101 0.167965 0.006624 24 125 0.819672 22 DONE 85.068344 [0.6539877653121948, 0.5791202783584595, 0.496... [0.4972912669181824, 0.44041258096694946, 0.40... [0.6115702390670776, 0.7148760557174683, 0.710... [0.7704917788505554, 0.7704917788505554, 0.836... 18 4751 86.462392
23 hard_sigmoid 214 0.474491 0.009992 46 30 0.819672 23 DONE 87.529800 [0.6667641401290894, 0.5999571084976196, 0.545... [0.49277263879776, 0.46131962537765503, 0.4375... [0.6115702390670776, 0.6818181872367859, 0.743... [0.7704917788505554, 0.7704917788505554, 0.786... 18 1141 89.057380
24 hard_sigmoid 243 0.201349 0.005321 19 123 0.852459 24 DONE 89.542004 [0.846777081489563, 0.7141139507293701, 0.6427... [0.7241750955581665, 0.6084526181221008, 0.538... [0.32231405377388, 0.42975205183029175, 0.6611... [0.31147539615631104, 0.7704917788505554, 0.77... 18 4675 90.894842
[26]:
i_max = results_from_checkpoint.objective.argmax()
best_job = results_from_checkpoint.iloc[i_max].to_dict()


print(f"The default configuration has an accuracy of {objective_default:.3f}. \n"
      f"The best configuration found by DeepHyper has an accuracy {results_from_checkpoint['objective'].iloc[i_max]:.3f}, \n"
      f"discovered after {results_from_checkpoint['m:timestamp_gather'].iloc[i_max]:.2f} secondes of search.\n")

best_job
The default configuration has an accuracy of 0.836.
The best configuration found by DeepHyper has an accuracy of 0.852,
discovered after 47.08 seconds of search.

[26]:
{'p:activation': 'hard_sigmoid',
 'p:batch_size': 255,
 'p:dropout_rate': 0.481119568066835,
 'p:learning_rate': 0.0086765454003444,
 'p:num_epochs': 22,
 'p:units': 107,
 'objective': 0.8524590134620667,
 'job_id': 5,
 'job_status': 'DONE',
 'm:timestamp_submit': 45.74203109741211,
 'm:loss': '[0.6349105834960938, 0.6272992491722107, 0.5584259629249573, 0.5526432394981384, 0.5168532133102417, 0.5305481553077698, 0.4433193504810333, 0.45578300952911377, 0.4224569797515869, 0.42115506529808044, 0.41517341136932373, 0.40223243832588196, 0.3649812936782837, 0.3698986768722534, 0.38615867495536804, 0.34518924355506897, 0.3437093496322632, 0.3397452235221863, 0.35441818833351135, 0.33848950266838074, 0.30461177229881287, 0.34012123942375183]',
 'm:val_loss': '[0.5131592750549316, 0.4829488694667816, 0.4570136070251465, 0.43298590183258057, 0.41293779015541077, 0.3986186981201172, 0.3872995972633362, 0.37787696719169617, 0.3706487715244293, 0.36465173959732056, 0.36144769191741943, 0.35923588275909424, 0.3602728545665741, 0.3635474443435669, 0.36736416816711426, 0.37060967087745667, 0.37457913160324097, 0.3791694939136505, 0.38267746567726135, 0.3848181664943695, 0.3866138458251953, 0.390216201543808]',
 'm:accuracy': '[0.6900826692581177, 0.702479362487793, 0.71074378490448, 0.7066115736961365, 0.7727272510528564, 0.7148760557174683, 0.8016529083251953, 0.7561983466148376, 0.797520637512207, 0.8016529083251953, 0.8057851195335388, 0.8264462947845459, 0.8347107172012329, 0.8388429880142212, 0.8181818127632141, 0.8388429880142212, 0.8347107172012329, 0.8471074104309082, 0.8181818127632141, 0.8677685856819153, 0.8636363744735718, 0.8305785059928894]',
 'm:val_accuracy': '[0.7704917788505554, 0.7704917788505554, 0.7704917788505554, 0.7868852615356445, 0.7704917788505554, 0.7704917788505554, 0.7868852615356445, 0.7868852615356445, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.8196721076965332, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667]',
 'm:num_parameters': 18,
 'm:num_parameters_train': 4067,
 'm:timestamp_gather': 47.08345913887024}
[27]:
plt.figure()
plt.plot(json.loads(metadata_default["val_accuracy"]), color="skyblue", label="Default (val)")
plt.plot(json.loads(metadata_default["accuracy"]), color="skyblue", linestyle="--", label="Default (train)")
plt.plot(json.loads(best_job["m:val_accuracy"]), color="coral", linewidth=2, label="Best Job (val)")
plt.plot(json.loads(best_job["m:accuracy"]), color="coral", linestyle="--", linewidth=2, label="Best Job (train)")
plt.legend()
plt.ylim(0.5, 0.9)
plt.grid()
plt.show()
[Figure: training and validation accuracy curves, default vs. best configuration from the checkpointed search]

2.10. Add conditional hyperparameters#

Now we want to add the option of searching over a second fully-connected layer. This only requires two new lines in the run function:

if config.get("dense_2", False):
    x = tfk.layers.Dense(config["dense_2:units"], activation=config["dense_2:activation"])(x)
[29]:
def run_with_condition(config: dict):
    tf.autograph.set_verbosity(0)

    train_dataframe, val_dataframe = load_data()

    train_ds = dataframe_to_dataset(train_dataframe)
    val_ds = dataframe_to_dataset(val_dataframe)

    train_ds = train_ds.batch(config["batch_size"])
    val_ds = val_ds.batch(config["batch_size"])

    # Categorical features encoded as integers
    sex = tfk.Input(shape=(1,), name="sex", dtype="int64")
    cp = tfk.Input(shape=(1,), name="cp", dtype="int64")
    fbs = tfk.Input(shape=(1,), name="fbs", dtype="int64")
    restecg = tfk.Input(shape=(1,), name="restecg", dtype="int64")
    exang = tfk.Input(shape=(1,), name="exang", dtype="int64")
    ca = tfk.Input(shape=(1,), name="ca", dtype="int64")

    # Categorical feature encoded as string
    thal = tfk.Input(shape=(1,), name="thal", dtype="string")

    # Numerical features
    age = tfk.Input(shape=(1,), name="age")
    trestbps = tfk.Input(shape=(1,), name="trestbps")
    chol = tfk.Input(shape=(1,), name="chol")
    thalach = tfk.Input(shape=(1,), name="thalach")
    oldpeak = tfk.Input(shape=(1,), name="oldpeak")
    slope = tfk.Input(shape=(1,), name="slope")

    all_inputs = [
        sex,
        cp,
        fbs,
        restecg,
        exang,
        ca,
        thal,
        age,
        trestbps,
        chol,
        thalach,
        oldpeak,
        slope,
    ]

    # Integer categorical features
    sex_encoded = encode_categorical_feature(sex, "sex", train_ds, False)
    cp_encoded = encode_categorical_feature(cp, "cp", train_ds, False)
    fbs_encoded = encode_categorical_feature(fbs, "fbs", train_ds, False)
    restecg_encoded = encode_categorical_feature(restecg, "restecg", train_ds, False)
    exang_encoded = encode_categorical_feature(exang, "exang", train_ds, False)
    ca_encoded = encode_categorical_feature(ca, "ca", train_ds, False)

    # String categorical features
    thal_encoded = encode_categorical_feature(thal, "thal", train_ds, True)

    # Numerical features
    age_encoded = encode_numerical_feature(age, "age", train_ds)
    trestbps_encoded = encode_numerical_feature(trestbps, "trestbps", train_ds)
    chol_encoded = encode_numerical_feature(chol, "chol", train_ds)
    thalach_encoded = encode_numerical_feature(thalach, "thalach", train_ds)
    oldpeak_encoded = encode_numerical_feature(oldpeak, "oldpeak", train_ds)
    slope_encoded = encode_numerical_feature(slope, "slope", train_ds)

    all_features = tfk.layers.concatenate(
        [
            sex_encoded,
            cp_encoded,
            fbs_encoded,
            restecg_encoded,
            exang_encoded,
            slope_encoded,
            ca_encoded,
            thal_encoded,
            age_encoded,
            trestbps_encoded,
            chol_encoded,
            thalach_encoded,
            oldpeak_encoded,
        ]
    )
    x = tfk.layers.Dense(config["units"], activation=config["activation"])(
        all_features
    )

    ### START - NEW LINES
    if config.get("dense_2", False):
        x = tfk.layers.Dense(config["dense_2:units"], activation=config["dense_2:activation"])(x)
    ### END - NEW LINES

    x = tfk.layers.Dropout(config["dropout_rate"])(x)
    output = tfk.layers.Dense(1, activation="sigmoid")(x)
    model = tfk.Model(all_inputs, output)

    optimizer = tfk.optimizers.Adam(learning_rate=config["learning_rate"])
    model.compile(optimizer, "binary_crossentropy", metrics=["accuracy"])

    try:
        history = model.fit(
            train_ds, epochs=config["num_epochs"], validation_data=val_ds, verbose=0
        )
    except Exception:
        # Training can fail for some hyperparameter combinations (e.g.,
        # numerical instabilities); return the sentinel "F_fit" so that
        # DeepHyper records the evaluation as failed.
        class History:
            history = {
                "accuracy": None,
                "val_accuracy": ["F_fit"],
                "loss": None,
                "val_loss": None,
            }

        history = History()

    objective = history.history["val_accuracy"][-1]
    metadata = {
        "loss": history.history["loss"],
        "val_loss": history.history["val_loss"],
        "accuracy": history.history["accuracy"],
        "val_accuracy": history.history["val_accuracy"],
    }
    metadata = {k:json.dumps(v) for k,v in metadata.items()}
    metadata.update(count_params(model))

    return {"objective": objective, "metadata": metadata}

To define conditional hyperparameters we use ConfigSpace. We define dense_2:units and dense_2:activation as active hyperparameters only when dense_2 == True. The EqualsCondition class helps us do that. Then we call

problem_with_condition.add_condition(condition)

to register each new condition to the HpProblem.

[30]:
from ConfigSpace import EqualsCondition

# Define the hyperparameter problem
problem_with_condition = HpProblem()


# Define the same hyperparameters as before
problem_with_condition.add_hyperparameter((8, 128), "units")
problem_with_condition.add_hyperparameter(ACTIVATIONS, "activation")
problem_with_condition.add_hyperparameter((0.0, 0.6), "dropout_rate")
problem_with_condition.add_hyperparameter((10, 100), "num_epochs")
problem_with_condition.add_hyperparameter((8, 256, "log-uniform"), "batch_size")
problem_with_condition.add_hyperparameter((1e-5, 1e-2, "log-uniform"), "learning_rate")


# Add a new hyperparameter "dense_2 (bool)" to decide if a second fully-connected layer should be created
hp_dense_2 = problem_with_condition.add_hyperparameter([True, False], "dense_2")
hp_dense_2_units = problem_with_condition.add_hyperparameter((8, 128), "dense_2:units")
hp_dense_2_activation = problem_with_condition.add_hyperparameter(ACTIVATIONS, "dense_2:activation")

problem_with_condition.add_condition(EqualsCondition(hp_dense_2_units, hp_dense_2, True))
problem_with_condition.add_condition(EqualsCondition(hp_dense_2_activation, hp_dense_2, True))


problem_with_condition
[30]:
Configuration space object:
  Hyperparameters:
    activation, Type: Categorical, Choices: {elu, gelu, hard_sigmoid, linear, relu, selu, sigmoid, softplus, softsign, swish, tanh}, Default: elu
    batch_size, Type: UniformInteger, Range: [8, 256], Default: 45, on log-scale
    dense_2, Type: Categorical, Choices: {True, False}, Default: True
    dense_2:activation, Type: Categorical, Choices: {elu, gelu, hard_sigmoid, linear, relu, selu, sigmoid, softplus, softsign, swish, tanh}, Default: elu
    dense_2:units, Type: UniformInteger, Range: [8, 128], Default: 68
    dropout_rate, Type: UniformFloat, Range: [0.0, 0.6], Default: 0.3
    learning_rate, Type: UniformFloat, Range: [1e-05, 0.01], Default: 0.000316227766, on log-scale
    num_epochs, Type: UniformInteger, Range: [10, 100], Default: 55
    units, Type: UniformInteger, Range: [8, 128], Default: 68
  Conditions:
    dense_2:activation | dense_2 == True
    dense_2:units | dense_2 == True
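As a quick sanity check of the conditions, we can sample a few configurations from the space; this is a sketch assuming that problem_with_condition.space exposes the underlying ConfigSpace object:

# When dense_2 is False, the dense_2:* hyperparameters are inactive
# and therefore absent from the sampled configuration.
for config in problem_with_condition.space.sample_configuration(3):
    print(dict(config))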

We create a new evaluator evaluator_3 and start a fresh CBO search with this new problem problem_with_condition.

[32]:
evaluator_3 = get_evaluator(run_with_condition)

search_with_condition = CBO(
    problem_with_condition,
    evaluator_3,
    acq_func="UCBd",
    acq_optimizer="mixedga",
    acq_optimizer_freq=1,
)
Created new evaluator with 1 worker and config: {'num_cpus': 1, 'num_cpus_per_task': 1, 'callbacks': [<deephyper.evaluator.callback.TqdmCallback object at 0x3cb0dee70>]}
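Optionally, the run function can be tested once on the default configuration of the conditional problem before launching the search. A sketch, assuming default_configuration behaves as for the unconditioned problem:

# Single evaluation of the default configuration (no search involved).
default_config = problem_with_condition.default_configuration
output = run_with_condition(default_config)
print(output["objective"])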
[33]:
results_with_condition = search_with_condition.search(max_evals=50)
[34]:
results_with_condition
[34]:
p:activation p:batch_size p:dense_2 p:dropout_rate p:learning_rate p:num_epochs p:units p:dense_2:activation p:dense_2:units objective job_id job_status m:timestamp_submit m:loss m:val_loss m:accuracy m:val_accuracy m:num_parameters m:num_parameters_train m:timestamp_gather
0 tanh 15 True 0.366818 0.005029 26 60 softplus 105 0.7704917788505554 0 DONE 11.082239 [0.5009121298789978, 0.3078427314758301, 0.338... [0.4244871735572815, 0.5283817648887634, 0.486... [0.7479338645935059, 0.8512396812438965, 0.863... [0.8360655903816223, 0.8360655903816223, 0.819... 18 8731 12.913492
1 tanh 8 True 0.390801 0.000175 81 30 hard_sigmoid 68 0.8032786846160889 1 DONE 13.528550 [0.604907214641571, 0.5923880338668823, 0.5909... [0.5101694464683533, 0.4980083405971527, 0.486... [0.71074378490448, 0.7231404781341553, 0.70661... [0.7704917788505554, 0.7704917788505554, 0.770... 18 3287 17.357964
2 tanh 114 False 0.196952 0.000011 17 100 elu 8 0.4754098355770111 2 DONE 18.065723 [0.7861196994781494, 0.7642049789428711, 0.761... [0.7636403441429138, 0.7626331448554993, 0.761... [0.40909090638160706, 0.45041322708129883, 0.4... [0.44262295961380005, 0.44262295961380005, 0.4... 18 3801 19.381279
3 selu 115 False 0.077023 0.000019 25 50 elu 8 0.32786884903907776 3 DONE 19.988189 [0.9111230969429016, 0.9428355097770691, 0.920... [0.9447649121284485, 0.9423429369926453, 0.940... [0.37603306770324707, 0.3636363744735718, 0.37... [0.2950819730758667, 0.2950819730758667, 0.295... 18 1901 21.368746
4 relu 73 False 0.476480 0.003393 19 34 elu 8 0.8032786846160889 4 DONE 21.978687 [0.7194283604621887, 0.6411665678024292, 0.568... [0.6180158257484436, 0.5356436967849731, 0.484... [0.5413222908973694, 0.6611570119857788, 0.698... [0.6393442749977112, 0.688524603843689, 0.7540... 18 1293 23.331681
5 selu 125 False 0.534733 0.000035 84 74 elu 8 0.6721311211585999 5 DONE 23.939715 [0.7867716550827026, 0.7413351535797119, 0.748... [0.7112182378768921, 0.7088605761528015, 0.706... [0.557851254940033, 0.6074380278587341, 0.5619... [0.49180328845977783, 0.49180328845977783, 0.4... 18 2813 25.724129
6 softplus 65 False 0.568267 0.004382 34 30 elu 8 0.8360655903816223 6 DONE 26.333320 [0.7672441005706787, 0.6641848087310791, 0.563... [0.4976312518119812, 0.45320722460746765, 0.42... [0.5991735458374023, 0.6611570119857788, 0.727... [0.8032786846160889, 0.8032786846160889, 0.786... 18 1141 27.797871
7 linear 55 False 0.162118 0.000082 50 44 elu 8 0.7377049326896667 7 DONE 28.413964 [0.7725011110305786, 0.7690296769142151, 0.750... [0.7852056622505188, 0.7744285464286804, 0.763... [0.44628098607063293, 0.4834710657596588, 0.52... [0.3442623019218445, 0.3606557250022888, 0.393... 18 1673 30.404005
8 hard_sigmoid 11 True 0.352241 0.000210 99 61 elu 125 0.8196721076965332 8 DONE 31.013985 [0.7297561168670654, 0.5910099148750305, 0.583... [0.5413499474525452, 0.5040895342826843, 0.501... [0.5165289044380188, 0.7190082669258118, 0.719... [0.7704917788505554, 0.7704917788505554, 0.770... 18 10133 34.436101
9 gelu 185 False 0.341791 0.003917 65 17 elu 8 0.8032786846160889 9 DONE 35.063485 [0.7614551186561584, 0.7111193537712097, 0.664... [0.6993500590324402, 0.6591216921806335, 0.622... [0.4710743725299835, 0.5454545617103577, 0.648... [0.6229507923126221, 0.6393442749977112, 0.655... 18 647 36.736329
10 softplus 66 False 0.584467 0.009327 36 70 elu 8 0.8032786846160889 10 DONE 38.730741 [0.6659552454948425, 0.47290360927581787, 0.41... [0.39005225896835327, 0.3625498414039612, 0.36... [0.6735537052154541, 0.7685950398445129, 0.822... [0.8196721076965332, 0.868852436542511, 0.8524... 18 2661 40.221937
11 softplus 54 False 0.570221 0.000206 26 20 elu 8 0.7868852615356445 11 DONE 41.766595 [0.8241037130355835, 0.8505690097808838, 0.879... [0.7222375273704529, 0.7125732898712158, 0.703... [0.5206611752510071, 0.5082644820213318, 0.5, ... [0.44262295961380005, 0.44262295961380005, 0.4... 18 761 43.202924
12 softsign 64 False 0.246370 0.004419 34 30 elu 8 0.8032786846160889 12 DONE 45.144218 [0.730465292930603, 0.5761306881904602, 0.4992... [0.6262984275817871, 0.49016880989074707, 0.42... [0.56611567735672, 0.7066115736961365, 0.75619... [0.6557376980781555, 0.7868852615356445, 0.836... 18 1141 46.601122
13 softplus 118 False 0.568740 0.004097 34 30 elu 8 0.8360655903816223 13 DONE 48.957806 [1.8636727333068848, 1.5934174060821533, 1.166... [1.5479514598846436, 1.1948175430297852, 0.919... [0.3057851195335388, 0.3264462947845459, 0.388... [0.2295081913471222, 0.2295081913471222, 0.229... 18 1141 50.747012
14 sigmoid 147 False 0.568471 0.001517 34 30 elu 8 F_fit 14 DONE 52.141820 NaN NaN NaN ["F_fit"] 18 1141 53.529244
15 softplus 141 False 0.562377 0.003627 34 30 elu 8 0.8524590134620667 15 DONE 55.616748 [0.8003625869750977, 0.7473790049552917, 0.699... [0.5667710900306702, 0.530837893486023, 0.5010... [0.6487603187561035, 0.6611570119857788, 0.681... [0.7704917788505554, 0.7704917788505554, 0.770... 18 1141 57.045611
16 softplus 144 False 0.562273 0.002747 34 30 elu 8 0.8524590134620667 16 DONE 59.651056 [0.6615883111953735, 0.644455075263977, 0.6060... [0.4805201292037964, 0.46684712171554565, 0.45... [0.6735537052154541, 0.6694214940071106, 0.706... [0.7704917788505554, 0.7868852615356445, 0.786... 18 1141 61.075928
17 softplus 193 False 0.560263 0.002264 34 31 elu 8 0.8196721076965332 17 DONE 62.720634 [1.322947382926941, 1.3005346059799194, 1.1319... [1.1749550104141235, 1.0512288808822632, 0.940... [0.3801652789115906, 0.3512396812438965, 0.400... [0.2295081913471222, 0.2295081913471222, 0.245... 18 1179 64.146442
18 sigmoid 146 False 0.561713 0.002848 32 29 elu 8 0.8032786846160889 18 DONE 66.090119 [1.0363850593566895, 0.9314866662025452, 0.875... [0.9298886060714722, 0.8662550449371338, 0.808... [0.3181818127632141, 0.3595041334629059, 0.409... [0.2295081913471222, 0.2295081913471222, 0.229... 18 1103 67.498575
19 softplus 137 False 0.562651 0.001992 34 30 elu 8 0.8360655903816223 19 DONE 69.450445 [0.8851305842399597, 0.7541695833206177, 0.730... [0.6303464770317078, 0.5886359810829163, 0.552... [0.5, 0.557851254940033, 0.5702479481697083, 0... [0.6393442749977112, 0.688524603843689, 0.7540... 18 1141 70.869987
20 softplus 143 False 0.562130 0.001625 34 30 elu 8 0.7868852615356445 20 DONE 73.525754 [0.7884242534637451, 0.72602379322052, 0.67113... [0.5665600299835205, 0.5440786480903625, 0.524... [0.6859503984451294, 0.7231404781341553, 0.677... [0.7704917788505554, 0.7704917788505554, 0.770... 18 1141 75.315221
21 softplus 161 True 0.560928 0.003245 34 30 selu 10 0.8360655903816223 21 DONE 76.920507 [0.7835374474525452, 0.7212239503860474, 0.639... [0.514182984828949, 0.4651140868663788, 0.4428... [0.5330578684806824, 0.6239669322967529, 0.677... [0.8032786846160889, 0.7704917788505554, 0.770... 18 1431 78.415666
22 softplus 142 False 0.562011 0.002741 24 21 elu 8 0.8196721076965332 22 DONE 80.673716 [0.871577262878418, 0.7561327815055847, 0.7247... [0.6371726989746094, 0.584244966506958, 0.5413... [0.4834710657596588, 0.5785123705863953, 0.595... [0.6065573692321777, 0.7377049326896667, 0.754... 18 799 82.031128
23 tanh 141 False 0.562200 0.002750 56 30 elu 8 0.8196721076965332 23 DONE 84.483893 [0.7138877511024475, 0.6605839729309082, 0.624... [0.6050224304199219, 0.5486882328987122, 0.506... [0.5826446413993835, 0.6528925895690918, 0.698... [0.688524603843689, 0.7540983557701111, 0.7377... 18 1141 86.070269
24 softplus 148 False 0.569910 0.002346 35 30 elu 8 0.8524590134620667 24 DONE 87.728148 [0.7038779854774475, 0.70468670129776, 0.63516... [0.443899005651474, 0.43171101808547974, 0.421... [0.6487603187561035, 0.6074380278587341, 0.694... [0.7868852615356445, 0.8032786846160889, 0.803... 18 1141 89.166137
25 softplus 149 False 0.570141 0.001867 39 30 elu 8 F_fit 25 DONE 91.996638 NaN NaN NaN ["F_fit"] 18 1141 93.473902
26 softplus 151 False 0.571299 0.002355 37 30 elu 8 0.868852436542511 26 DONE 95.618133 [0.9386454820632935, 0.8260695934295654, 0.804... [0.7142730951309204, 0.6488375067710876, 0.594... [0.4628099203109741, 0.5165289044380188, 0.491... [0.4754098355770111, 0.6393442749977112, 0.737... 18 1141 97.060111
27 softplus 154 False 0.571641 0.002241 40 30 elu 8 0.8524590134620667 27 DONE 99.360630 [0.7653211355209351, 0.7396792769432068, 0.655... [0.5599848031997681, 0.5429885387420654, 0.527... [0.64462810754776, 0.6528925895690918, 0.69834... [0.7704917788505554, 0.7704917788505554, 0.770... 18 1141 101.235053
28 softplus 151 False 0.570702 0.002098 38 29 elu 8 0.868852436542511 28 DONE 103.152042 [0.8922500014305115, 0.9002847075462341, 0.840... [0.7290818095207214, 0.6712472438812256, 0.621... [0.5082644820213318, 0.4793388545513153, 0.512... [0.4754098355770111, 0.5901639461517334, 0.688... 18 1103 104.613654
29 softplus 151 False 0.571097 0.001946 37 29 elu 8 0.8196721076965332 29 DONE 107.597612 [1.4136673212051392, 1.2668572664260864, 1.105... [1.1506351232528687, 1.0501872301101685, 0.959... [0.3512396812438965, 0.3677685856819153, 0.417... [0.2295081913471222, 0.26229506731033325, 0.22... 18 1103 109.052432
30 softplus 149 False 0.570417 0.002060 39 29 elu 8 0.8196721076965332 30 DONE 111.246263 [0.9358372688293457, 0.9006103873252869, 0.906... [0.7770434617996216, 0.7139856219291687, 0.657... [0.4793388545513153, 0.5123966932296753, 0.466... [0.3442623019218445, 0.5245901346206665, 0.590... 18 1103 112.719204
31 softplus 151 True 0.573680 0.002096 37 36 selu 8 0.8360655903816223 31 DONE 114.547893 [2.4386861324310303, 2.143130302429199, 1.8441... [2.1191811561584473, 1.703428030014038, 1.3390... [0.3512396812438965, 0.3677685856819153, 0.347... [0.2295081913471222, 0.2295081913471222, 0.213... 18 1637 116.050490
32 softplus 152 False 0.567930 0.002125 38 41 elu 8 0.8032786846160889 32 DONE 117.710405 [1.1407198905944824, 1.1017353534698486, 0.989... [0.8478963971138, 0.7879840731620789, 0.732605... [0.71074378490448, 0.71074378490448, 0.7024793... [0.7704917788505554, 0.7704917788505554, 0.770... 18 1559 119.158029
33 softplus 151 True 0.600000 0.002942 38 26 elu 30 0.8524590134620667 33 DONE 120.565808 [1.1980469226837158, 0.9109521508216858, 0.743... [0.8629004955291748, 0.6327856779098511, 0.524... [0.3595041334629059, 0.4586776793003082, 0.541... [0.2295081913471222, 0.7213114500045776, 0.786... 18 1803 122.087813
34 softplus 13 False 0.571321 0.002379 37 30 elu 8 0.8196721076965332 34 DONE 124.607431 [0.7201182842254639, 0.6420775055885315, 0.508... [0.49458569288253784, 0.4283734858036041, 0.40... [0.6157024502754211, 0.6983470916748047, 0.739... [0.7868852615356445, 0.7704917788505554, 0.803... 18 1141 126.484019
35 softplus 163 True 0.597093 0.001868 52 30 elu 37 0.8196721076965332 35 DONE 128.893857 [1.0655272006988525, 0.815341055393219, 0.6936... [0.837419331073761, 0.6606911420822144, 0.5658... [0.3181818127632141, 0.42975205183029175, 0.57... [0.24590164422988892, 0.6393442749977112, 0.78... 18 2295 130.914070
36 softplus 151 False 0.528450 0.002494 37 29 elu 8 0.8196721076965332 36 DONE 133.153525 [0.746720552444458, 0.7041527628898621, 0.6633... [0.5352126359939575, 0.5036073327064514, 0.478... [0.6074380278587341, 0.6157024502754211, 0.661... [0.7540983557701111, 0.7868852615356445, 0.786... 18 1103 134.597995
37 softplus 150 False 0.569545 0.002119 38 31 elu 8 0.8196721076965332 37 DONE 137.581175 [0.8011699914932251, 0.7314833998680115, 0.690... [0.5537089705467224, 0.5326763987541199, 0.514... [0.557851254940033, 0.6363636255264282, 0.6404... [0.7704917788505554, 0.7704917788505554, 0.770... 18 1179 139.039871
38 softplus 152 True 0.597162 0.002118 38 19 elu 118 0.8032786846160889 38 DONE 141.612782 [0.6528862714767456, 0.612349808216095, 0.5584... [0.515251874923706, 0.4898030161857605, 0.4741... [0.6033057570457458, 0.702479362487793, 0.7148... [0.7704917788505554, 0.7704917788505554, 0.770... 18 3182 143.174587
39 softplus 153 False 0.593458 0.002099 38 13 elu 8 0.7540983557701111 39 DONE 145.678543 [1.0391321182250977, 1.0277173519134521, 0.980... [0.9314367771148682, 0.8901909589767456, 0.851... [0.41322314739227295, 0.35537189245224, 0.4090... [0.24590164422988892, 0.26229506731033325, 0.2... 18 495 147.133688
40 softplus 151 False 0.595459 0.002107 38 26 elu 8 0.8196721076965332 40 DONE 149.485027 [0.7215235829353333, 0.6772743463516235, 0.686... [0.495942085981369, 0.4821834862232208, 0.4688... [0.6570248007774353, 0.6528925895690918, 0.619... [0.8032786846160889, 0.8032786846160889, 0.803... 18 989 150.936354
41 softplus 155 True 0.584819 0.002361 37 30 elu 63 0.8360655903816223 41 DONE 152.818419 [0.7417583465576172, 0.624437153339386, 0.6375... [0.5465797781944275, 0.5226823091506958, 0.501... [0.557851254940033, 0.6776859760284424, 0.7107... [0.7704917788505554, 0.7704917788505554, 0.770... 18 3127 154.318564
42 softplus 144 True 0.599080 0.003167 38 26 elu 124 0.8196721076965332 42 DONE 156.350498 [0.7083154916763306, 0.6766270399093628, 0.616... [0.5326813459396362, 0.5164983868598938, 0.473... [0.557851254940033, 0.6942148804664612, 0.7107... [0.7704917788505554, 0.7704917788505554, 0.770... 18 4435 157.875385
43 softplus 153 False 0.562895 0.009250 41 30 elu 8 0.8196721076965332 43 DONE 159.693906 [1.2732993364334106, 0.9017221927642822, 0.743... [0.8885003328323364, 0.6285920143127441, 0.500... [0.3388429880142212, 0.42975205183029175, 0.54... [0.24590164422988892, 0.7049180269241333, 0.83... 18 1141 161.611814
44 softplus 149 False 0.569925 0.002351 42 30 elu 8 0.8196721076965332 44 DONE 164.472124 [0.7111098766326904, 0.7342808246612549, 0.696... [0.5153035521507263, 0.49886760115623474, 0.48... [0.6652892827987671, 0.6404958963394165, 0.669... [0.7704917788505554, 0.7704917788505554, 0.770... 18 1141 165.953773
45 softplus 151 False 0.574937 0.002082 39 30 elu 8 0.8360655903816223 45 DONE 169.389578 [0.7992091774940491, 0.7631123661994934, 0.693... [0.5029991865158081, 0.48447418212890625, 0.46... [0.6074380278587341, 0.6322314143180847, 0.661... [0.7704917788505554, 0.7704917788505554, 0.770... 18 1141 170.948393
46 softplus 152 False 0.581180 0.002565 29 30 elu 8 0.8032786846160889 46 DONE 172.951565 [0.9612069725990295, 0.9028563499450684, 0.846... [0.7866339683532715, 0.7080065608024597, 0.640... [0.43801653385162354, 0.4958677589893341, 0.5,... [0.32786884903907776, 0.4754098355770111, 0.57... 18 1141 174.347412
47 softplus 152 False 0.579499 0.002451 37 30 elu 8 0.8360655903816223 47 DONE 177.674756 [1.149562954902649, 0.9906631112098694, 0.9415... [0.8440399765968323, 0.77037513256073, 0.70880... [0.41735535860061646, 0.4958677589893341, 0.48... [0.37704917788505554, 0.4754098355770111, 0.55... 18 1141 179.122212
48 softplus 161 True 0.599999 0.003029 38 23 elu 30 0.8360655903816223 48 DONE 182.594063 [0.7441322207450867, 0.6444833874702454, 0.656... [0.5382938981056213, 0.485158771276474, 0.4663... [0.5330578684806824, 0.6652892827987671, 0.690... [0.7704917788505554, 0.7704917788505554, 0.770... 18 1602 184.113826
49 softplus 151 False 0.572341 0.002377 32 30 elu 8 0.8360655903816223 49 DONE 187.173539 [0.6830765008926392, 0.7289465665817261, 0.654... [0.5346873998641968, 0.5149286389350891, 0.497... [0.6652892827987671, 0.6652892827987671, 0.648... [0.7704917788505554, 0.7704917788505554, 0.770... 18 1141 188.601617
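Before extracting the best configuration, a quick breakdown by the p:dense_2 column shows how models with and without the second layer compare; failed evaluations (objective "F_fit") are removed first. A minimal sketch:

from deephyper.analysis.hpo import filter_failed_objectives

ok, _ = filter_failed_objectives(results_with_condition)
# Objectives may be stored as strings; cast before aggregating.
ok["objective"] = ok["objective"].astype(float)
print(ok.groupby("p:dense_2")["objective"].agg(["count", "mean", "max"]))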

Finally, let us print out the best configuration found in this conditional search space.

[37]:
from deephyper.analysis.hpo import filter_failed_objectives

results_with_condition_without_failure, _ = filter_failed_objectives(results_with_condition)
i_max = results_with_condition_without_failure.objective.argmax()
best_job = results_with_condition_without_failure.iloc[i_max].to_dict()


print(f"The default configuration has an accuracy of {objective_default:.3f}. \n"
      f"The best configuration found by DeepHyper has an accuracy {results_with_condition_without_failure['objective'].iloc[i_max]:.3f}, \n"
      f"discovered after {results_with_condition_without_failure['m:timestamp_gather'].iloc[i_max]:.2f} secondes of search.\n")

best_job
The default configuration has an accuracy of 0.836.
The best configuration found by DeepHyper has an accuracy of 0.869,
discovered after 97.06 seconds of search.

[37]:
{'p:activation': 'softplus',
 'p:batch_size': 151,
 'p:dense_2': False,
 'p:dropout_rate': 0.5712985151526,
 'p:learning_rate': 0.0023551435679,
 'p:num_epochs': 37,
 'p:units': 30,
 'p:dense_2:activation': 'elu',
 'p:dense_2:units': 8,
 'objective': 0.868852436542511,
 'job_id': 26,
 'job_status': 'DONE',
 'm:timestamp_submit': 95.61813306808472,
 'm:loss': '[0.9386454820632935, 0.8260695934295654, 0.8045966029167175, 0.794924259185791, 0.7922928929328918, 0.71309894323349, 0.674532949924469, 0.6240278482437134, 0.5885002017021179, 0.5997695922851562, 0.5670955181121826, 0.614417552947998, 0.5098029375076294, 0.511737585067749, 0.5198138952255249, 0.5427704453468323, 0.5118103623390198, 0.5001246929168701, 0.48275113105773926, 0.4626578986644745, 0.4975070357322693, 0.4643486738204956, 0.42703598737716675, 0.4019463360309601, 0.44776079058647156, 0.4671458303928375, 0.43884947896003723, 0.4610165059566498, 0.4526631236076355, 0.3922064006328583, 0.434793084859848, 0.41919174790382385, 0.4202392101287842, 0.41499119997024536, 0.38881367444992065, 0.4379543960094452, 0.41741785407066345]',
 'm:val_loss': '[0.7142730951309204, 0.6488375067710876, 0.5947321653366089, 0.550585150718689, 0.5155931115150452, 0.4885355830192566, 0.46746018528938293, 0.45108214020729065, 0.4376329481601715, 0.42609837651252747, 0.4160616993904114, 0.4073067009449005, 0.39975255727767944, 0.39315667748451233, 0.3873841166496277, 0.3822612166404724, 0.3777984380722046, 0.37391534447669983, 0.3705607056617737, 0.3678196668624878, 0.36563098430633545, 0.3638599216938019, 0.3621633052825928, 0.3607686161994934, 0.3599274456501007, 0.3591628968715668, 0.3585789203643799, 0.3581162095069885, 0.3575083911418915, 0.35662978887557983, 0.355956494808197, 0.3556397557258606, 0.3556240200996399, 0.3558986783027649, 0.3559959828853607, 0.35617774724960327, 0.35643190145492554]',
 'm:accuracy': '[0.4628099203109741, 0.5165289044380188, 0.4917355477809906, 0.5247933864593506, 0.557851254940033, 0.6322314143180847, 0.6115702390670776, 0.6983470916748047, 0.7066115736961365, 0.7148760557174683, 0.7272727489471436, 0.702479362487793, 0.7438016533851624, 0.7644628286361694, 0.7644628286361694, 0.7190082669258118, 0.7355371713638306, 0.7561983466148376, 0.7685950398445129, 0.7851239442825317, 0.7809917330741882, 0.7603305578231812, 0.7851239442825317, 0.8099173307418823, 0.8181818127632141, 0.7685950398445129, 0.7933884263038635, 0.7727272510528564, 0.7479338645935059, 0.8223140239715576, 0.8016529083251953, 0.7933884263038635, 0.8057851195335388, 0.8057851195335388, 0.8223140239715576, 0.8140496015548706, 0.7933884263038635]',
 'm:val_accuracy': '[0.4754098355770111, 0.6393442749977112, 0.7377049326896667, 0.8032786846160889, 0.7868852615356445, 0.7868852615356445, 0.7868852615356445, 0.7868852615356445, 0.7868852615356445, 0.7868852615356445, 0.8032786846160889, 0.8032786846160889, 0.8032786846160889, 0.7868852615356445, 0.7704917788505554, 0.7704917788505554, 0.7868852615356445, 0.7868852615356445, 0.7540983557701111, 0.7704917788505554, 0.7868852615356445, 0.8032786846160889, 0.8032786846160889, 0.8196721076965332, 0.8360655903816223, 0.8360655903816223, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.8524590134620667, 0.868852436542511, 0.868852436542511, 0.868852436542511, 0.868852436542511]',
 'm:num_parameters': 18,
 'm:num_parameters_train': 1141,
 'm:timestamp_gather': 97.0601110458374}
[38]:
plt.figure()
plt.plot(json.loads(metadata_default["val_accuracy"]), color="skyblue", label="Default (val)")
plt.plot(json.loads(metadata_default["accuracy"]), color="skyblue", linestyle="--", label="Default (train)")
plt.plot(json.loads(best_job["m:val_accuracy"]), color="coral", linewidth=2, label="Best Job (val)")
plt.plot(json.loads(best_job["m:accuracy"]), color="coral", linestyle="--", linewidth=2, label="Best Job (train)")
plt.legend()
plt.ylim(0.5, 0.9)
plt.grid()
plt.show()
[Figure: training and validation accuracy curves, default vs. best configuration from the conditional search]