1. Hyperparameter search for text classification (PyTorch)

In this tutorial, we show how to apply hyperparameter optimization to a text classification example from the PyTorch documentation.

Reference: This tutorial is based on materials from the PyTorch documentation: Text classification with the torchtext library.

[1]:
!pip3 install deephyper
!pip3 install ray
!pip3 install torch torchtext torchdata
In the run shown here, all packages were already installed (Python 3.9, deephyper 0.3.4, ray 1.12.1, torch 1.11.0, torchtext 0.12.0, torchdata 0.3.0).

1.1. Imports

[2]:
import ray
import json
import pandas as pd
from functools import partial

import torch

from torchtext.data.utils import get_tokenizer
from torchtext.data.functional import to_map_style_dataset
from torchtext.vocab import build_vocab_from_iterator

from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split

from torch import nn

Note

The following cell detects whether CUDA devices are available on the current host, so that this notebook automatically adapts its parallel execution to the locally available resources. This detection only covers the local host: it does not account for additional compute nodes if several are requested.

[3]:
is_gpu_available = torch.cuda.is_available()
n_gpus = torch.cuda.device_count()

1.2. The dataset

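The torchtext library provides a few raw dataset iterators, which yield the raw text strings. For example, the AG_NEWS dataset iterators yield the raw data as a tuple of label and text. The dataset has four labels (1: World, 2: Sports, 3: Business, 4: Sci/Tech). The load_data function below splits the AG_NEWS training set into a training part (a fraction train_ratio of the samples) and a validation part (the remainder), and also returns the test set.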

[4]:
from torchtext.datasets import AG_NEWS

def load_data(train_ratio):
    train_iter, test_iter = AG_NEWS()
    train_dataset = to_map_style_dataset(train_iter)
    test_dataset = to_map_style_dataset(test_iter)
    num_train = int(len(train_dataset) * train_ratio)
    split_train, split_valid = \
        random_split(train_dataset, [num_train, len(train_dataset) - num_train])

    return split_train, split_valid, test_dataset
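To make the effect of train_ratio concrete, the sketch below reproduces the size rule of load_data with plain Python lists. split_indices is a hypothetical helper, not part of torch, and the figure of 120,000 training samples for AG_NEWS is stated here as an assumption.

```python
import random


def split_indices(n_samples, train_ratio, seed=None):
    """Hypothetical helper: split range(n_samples) into two disjoint random subsets.

    Uses the same size rule as load_data: the first subset gets
    int(n_samples * train_ratio) indices, the second gets the rest.
    """
    num_train = int(n_samples * train_ratio)
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)  # random permutation of all indices
    return indices[:num_train], indices[num_train:]


# Assumption: the AG_NEWS training split contains 120,000 samples.
for ratio in (0.3, 0.95):
    train_idx, valid_idx = split_indices(120_000, ratio, seed=0)
    print(f"train_ratio={ratio}: {len(train_idx)} train / {len(valid_idx)} validation")
```

Under that assumption, a ratio of 0.3 gives 36,000 training and 84,000 validation samples, while 0.95 gives 114,000 and 6,000. Since load_data calls random_split without a generator, every call draws a different split; passing generator=torch.Generator().manual_seed(0) to random_split is one way to make it reproducible.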

1.3. Preprocessing pipelines and batch generation

Here is an example of typical NLP data processing with a tokenizer and a vocabulary. The first step is to build a vocabulary from the raw training dataset. Here we use the built-in factory function build_vocab_from_iterator, which accepts an iterator that yields a list or an iterator of tokens. Users can also pass any special symbols to be added to the vocabulary; here, "<unk>" is added and set as the default index, so that out-of-vocabulary tokens map to it.

The vocabulary block converts a list of tokens into integers.

vocab(['here', 'is', 'an', 'example'])
>>> [475, 21, 30, 5286]

The text pipeline converts a text string into a list of integers based on the lookup table defined in the vocabulary. The label pipeline converts the label into an integer, subtracting 1 so that class indices start at 0. For example,

text_pipeline('here is the an example')
>>> [475, 21, 2, 30, 5286]
label_pipeline('10')
>>> 9
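The following pure-Python sketch mimics the vocabulary and the two pipelines so the mapping can be inspected without torchtext. The helper names are hypothetical, the tokenizer is a simplified stand-in for basic_english (it only lowercases and splits on whitespace, whereas basic_english also separates punctuation), and the ordering rule (special symbols first, then tokens by decreasing frequency, ties broken alphabetically) is an assumption, so the exact integers differ from the torchtext ones shown above.

```python
from collections import Counter


def simple_tokenizer(text):
    # Simplified stand-in for get_tokenizer('basic_english')
    return text.lower().split()


def build_vocab(texts, specials=("<unk>",)):
    """Hypothetical helper: specials first, then tokens by decreasing
    frequency (ties broken alphabetically). Returns a token -> index dict."""
    counter = Counter(tok for text in texts for tok in simple_tokenizer(text))
    ordered = sorted(counter.items(), key=lambda kv: (-kv[1], kv[0]))
    itos = list(specials) + [tok for tok, _ in ordered if tok not in specials]
    return {tok: idx for idx, tok in enumerate(itos)}


def make_pipelines(stoi, unk="<unk>"):
    default_index = stoi[unk]  # plays the role of vocab.set_default_index(vocab["<unk>"])

    def text_pipeline(text):
        # Unknown tokens fall back to the default index
        return [stoi.get(tok, default_index) for tok in simple_tokenizer(text)]

    def label_pipeline(label):
        return int(label) - 1  # AG_NEWS labels 1..4 -> class indices 0..3

    return text_pipeline, label_pipeline


stoi = build_vocab(["Here is an example", "here is another"])
text_pipeline, label_pipeline = make_pipelines(stoi)
print(text_pipeline("here is an unseen example"))  # [1, 2, 3, 0, 5]
print(label_pipeline("3"))  # 2
```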
[5]:
train_iter = AG_NEWS(split='train')
num_class = 4

tokenizer = get_tokenizer('basic_english')

def yield_tokens(data_iter):
    for _, text in data_iter:
        yield tokenizer(text)

vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])
vocab_size = len(vocab)

text_pipeline = lambda x: vocab(tokenizer(x))
label_pipeline = lambda x: int(x) - 1


def collate_batch(batch, device):
    label_list, text_list, offsets = [], [], [0]
    for (_label, _text) in batch:
        label_list.append(label_pipeline(_label))
        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
        text_list.append(processed_text)
        offsets.append(processed_text.size(0))
    label_list = torch.tensor(label_list, dtype=torch.int64)
    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
    text_list = torch.cat(text_list)
    return label_list.to(device), text_list.to(device), offsets.to(device)
Note

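The collate_fn function works on a batch of samples generated by the DataLoader: its input is a list of batch_size samples, which it processes with the data processing pipelines declared above. Here, collate_batch concatenates all token sequences of a batch into a single 1-D tensor and records in offsets the start index of each sequence, which is the input format expected by nn.EmbeddingBag. It also takes an extra device argument, which the run-function binds with functools.partial so that the DataLoader can call it with the batch alone.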
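A minimal sketch of the offsets computation, using plain lists instead of tensors; flatten_with_offsets is a hypothetical helper that follows the same steps as collate_batch (the token IDs are arbitrary).

```python
from itertools import accumulate


def flatten_with_offsets(token_id_lists):
    """Hypothetical helper mirroring the text/offsets logic of collate_batch
    with plain lists instead of tensors."""
    if not token_id_lists:
        return [], []
    # All sequences concatenated into one flat list (torch.cat in collate_batch)
    flat = [tok for seq in token_id_lists for tok in seq]
    lengths = [len(seq) for seq in token_id_lists]
    # Start index of each sequence: cumulative sum of the preceding lengths,
    # i.e. the cumsum of [0, len_0, ..., len_{n-2}] as in collate_batch
    offsets = list(accumulate([0] + lengths[:-1]))
    return flat, offsets


flat, offsets = flatten_with_offsets([[475, 21, 30], [5286, 2], [7, 8, 9, 10]])
print(flat)     # [475, 21, 30, 5286, 2, 7, 8, 9, 10]
print(offsets)  # [0, 3, 5]
```

Sequence i then occupies flat[offsets[i]:offsets[i + 1]], the last one running to the end of flat, which is how nn.EmbeddingBag splits the flat input back into individual texts.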

1.4. Define the model

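The model consists of an nn.EmbeddingBag layer followed by a linear layer for classification. With its default mode ("mean"), nn.EmbeddingBag averages the embeddings of the tokens of each text (delimited by offsets) into a single vector, so no padding is needed.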

[6]:
class TextClassificationModel(nn.Module):

    def __init__(self, vocab_size, embed_dim, num_class):
        super(TextClassificationModel, self).__init__()
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()

    def forward(self, text, offsets):
        embedded = self.embedding(text, offsets)
        return self.fc(embedded)
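To see what the forward pass computes, here is a pure-Python sketch of nn.EmbeddingBag in its default mode ("mean") followed by nn.Linear, on toy sizes. It is an illustration with hypothetical helper names and hand-picked weights, not PyTorch's implementation.

```python
def embedding_bag_mean(weight, flat_ids, offsets):
    """Pure-Python sketch of nn.EmbeddingBag(mode="mean") with 1-D input + offsets.

    weight: list of embedding rows (vocab_size x embed_dim).
    Bag i covers flat_ids[offsets[i]:offsets[i + 1]] (the last bag runs to the end).
    """
    embed_dim = len(weight[0])
    bounds = list(offsets) + [len(flat_ids)]
    bags = []
    for start, end in zip(bounds[:-1], bounds[1:]):
        rows = [weight[i] for i in flat_ids[start:end]]
        if not rows:  # an empty bag yields a zero vector
            bags.append([0.0] * embed_dim)
            continue
        bags.append([sum(row[d] for row in rows) / len(rows) for d in range(embed_dim)])
    return bags


def linear(x, w, b):
    """y = W x + b, with W stored as (out_features x in_features) like nn.Linear.weight."""
    return [sum(wi * xi for wi, xi in zip(row, x)) + bi for row, bi in zip(w, b)]


# Toy sizes: vocab_size=3, embed_dim=2, num_class=2 (values chosen for illustration)
weight = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
bags = embedding_bag_mean(weight, flat_ids=[0, 1, 2], offsets=[0, 2])
logits = [linear(v, w=[[1.0, -1.0], [0.5, 0.5]], b=[0.0, 0.1]) for v in bags]
print(bags)  # [[0.5, 0.5], [2.0, 2.0]]
print(logits)
```

The first bag averages rows 0 and 1 of the embedding table, the second contains only row 2; each averaged vector is then mapped to num_class logits, and the argmax over the logits gives the predicted class, as in evaluate.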

1.5. Define functions to train the model and evaluate results

[7]:
def train(model, criterion, optimizer, dataloader):
    model.train()

    for _, (label, text, offsets) in enumerate(dataloader):
        optimizer.zero_grad()
        predicted_label = model(text, offsets)
        loss = criterion(predicted_label, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        optimizer.step()

def evaluate(model, dataloader):
    model.eval()
    total_acc, total_count = 0, 0

    with torch.no_grad():
        for _, (label, text, offsets) in enumerate(dataloader):
            predicted_label = model(text, offsets)
            total_acc += (predicted_label.argmax(1) == label).sum().item()
            total_count += label.size(0)
    return total_acc/total_count
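train clips the gradients to a global L2 norm of 0.1 before each SGD step. The sketch below shows, on a flat list of numbers, what that clipping does; it is a simplified re-implementation (the helper name is hypothetical) that assumes the default L2 norm and the small epsilon PyTorch adds to the norm.

```python
import math


def clip_grad_norm(grads, max_norm, eps=1e-6):
    """Pure-Python sketch of torch.nn.utils.clip_grad_norm_ with the default L2 norm.

    grads: flat list of all gradient entries (the norm is computed over all
    parameters jointly). Returns (clipped_grads, total_norm before clipping).
    """
    total_norm = math.sqrt(sum(g * g for g in grads))
    clip_coef = max_norm / (total_norm + eps)
    if clip_coef < 1.0:
        # Rescale every entry by the same factor, preserving the direction
        grads = [g * clip_coef for g in grads]
    return grads, total_norm


clipped, norm = clip_grad_norm([3.0, 4.0], max_norm=0.1)
print(norm, clipped)  # norm is 5.0; clipped is approximately [0.06, 0.08]
```

Because max_norm=0.1 is small, many gradients are likely rescaled, which bounds each plain-SGD update to a norm of at most learning_rate × 0.1.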

1.6. Define the run-function

The run-function defines how the objective that we want to maximize is computed. It takes a config dictionary as input and returns a scalar value to maximize. The config contains one sampled value for each hyperparameter that we want to tune. In this example, we search over:

  • num_epochs (default value: 10)

  • batch_size (default value: 64)

  • learning_rate (default value: 5)

A hyperparameter value can be accessed in the dictionary through its key, for example config["batch_size"].
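Before plugging in the real training loop, it can help to check the run-function contract with a cheap stand-in. The sketch below is a hypothetical toy objective (unrelated to AG_NEWS, with an invented formula) that reads the same three keys from config and returns a single float to maximize.

```python
import math


def toy_run(config: dict) -> float:
    """Hypothetical stand-in run-function with the same contract as run:
    read hyperparameters by key, return one float to maximize."""
    num_epochs = int(config["num_epochs"])
    batch_size = int(config["batch_size"])
    learning_rate = float(config["learning_rate"])
    # Invented score, highest near learning_rate=5 and batch_size=64 and
    # growing slowly with num_epochs; it does NOT model the real accuracy.
    lr_penalty = (math.log(learning_rate) - math.log(5.0)) ** 2
    bs_penalty = (math.log2(batch_size) - 6.0) ** 2 / 10  # log2(64) = 6
    return 0.01 * num_epochs - lr_penalty - bs_penalty


default_config = {"num_epochs": 10, "batch_size": 64, "learning_rate": 5.0}
print(toy_run(default_config))
```

Such a stand-in evaluates in microseconds, which makes it convenient for checking the problem definition and the evaluator setup before launching the costly search.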

[8]:
def get_run(train_ratio=0.95):
    def run(config: dict):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        embed_dim = 64  # fixed embedding size (not tuned here)

        # Bind the device so that the DataLoader can call collate_fn(batch)
        collate_fn = partial(collate_batch, device=device)
        split_train, split_valid, _ = load_data(train_ratio)
        batch_size = int(config["batch_size"])
        train_dataloader = DataLoader(split_train, batch_size=batch_size,
                                      shuffle=True, collate_fn=collate_fn)
        # Shuffling is not needed for evaluation
        valid_dataloader = DataLoader(split_valid, batch_size=batch_size,
                                      shuffle=False, collate_fn=collate_fn)

        model = TextClassificationModel(vocab_size, embed_dim, num_class).to(device)

        criterion = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.SGD(model.parameters(), lr=config["learning_rate"])

        for _ in range(int(config["num_epochs"])):
            train(model, criterion, optimizer, train_dataloader)

        # Objective to maximize: accuracy on the validation split
        accu_valid = evaluate(model, valid_dataloader)
        return accu_valid
    return run

We create two versions of the run-function: a quicker one for the search, which trains on only 30% of the training data (the remaining 70% is used for validation), and another one for performance evaluation, which uses a standard 95%/5% training/validation split.

[9]:
quick_run = get_run(train_ratio=0.3)
perf_run = get_run(train_ratio=0.95)
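get_run is a factory: each call returns a new run function that remembers its own train_ratio through a closure. The sketch below shows the same pattern on a toy objective (all names are hypothetical), next to an equivalent built with functools.partial.

```python
from functools import partial


def get_toy_run(train_ratio=0.95):
    """Hypothetical factory with the same shape as get_run: the inner run
    function captures train_ratio from the enclosing scope (a closure)."""
    def run(config: dict) -> float:
        # Placeholder objective; only the captured train_ratio matters here
        return train_ratio * float(config["learning_rate"])
    return run


def toy_run_with_ratio(config: dict, train_ratio: float) -> float:
    return train_ratio * float(config["learning_rate"])


quick_closure = get_toy_run(train_ratio=0.3)
# Equivalent single-argument callable built with functools.partial
quick_partial = partial(toy_run_with_ratio, train_ratio=0.3)
print(quick_closure({"learning_rate": 2.0}), quick_partial({"learning_rate": 2.0}))
```

Both produce single-argument callables, which is what the evaluator expects. One difference: the standard pickle module cannot serialize a nested function such as the inner run, whereas a partial of a module-level function can be pickled; Ray relies on cloudpickle, which handles closures, so either form should work with the Ray backend used here.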
Note

The objective maximized by DeepHyper is the scalar value returned by the run-function.

In this tutorial it corresponds to the validation accuracy of the model after training.

1.7. Define the hyperparameter optimization problem

Hyperparameter ranges are defined using the following syntax:

  • Discrete integer ranges are generated from a tuple (lower: int, upper: int)

  • Continuous parameters are generated from a tuple (lower: float, upper: float)

  • Categorical or nonordinal hyperparameter ranges can be given as a list of possible values [val1, val2, ...]

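For numerical ranges, an optional third element sets the prior distribution; for example, (8, 512, "log-uniform") samples batch_size uniformly in log-space. We also provide a default value for each hyperparameter; together, these values form the default configuration, which uses the hyperparameters of the PyTorch tutorial.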
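The sketch below illustrates what these range conventions mean when a configuration is drawn. It is a hypothetical sampler written for illustration, not DeepHyper's code (HpProblem builds a ConfigSpace configuration space, as the printout below shows); in particular, rounding log-uniform integers to the nearest value is an assumption.

```python
import math
import random


def sample(spec, rng):
    """Hypothetical sampler illustrating the range conventions (not DeepHyper's code).

    (int, int)                    -> uniform integer in [lower, upper]
    (float, float)                -> uniform float in [lower, upper]
    (lower, upper, "log-uniform") -> uniform in log-space
    [v1, v2, ...]                 -> categorical choice
    """
    if isinstance(spec, list):
        return rng.choice(spec)
    lower, upper = spec[0], spec[1]
    is_int = isinstance(lower, int) and isinstance(upper, int)
    log_scale = len(spec) > 2 and spec[2] == "log-uniform"
    if not log_scale:
        return rng.randint(lower, upper) if is_int else rng.uniform(lower, upper)
    value = math.exp(rng.uniform(math.log(lower), math.log(upper)))
    if is_int:
        value = round(value)
    # Clip to guard against floating-point round-off at the bounds
    return min(max(value, lower), upper)


rng = random.Random(0)
space = {
    "num_epochs": (5, 20),
    "batch_size": (8, 512, "log-uniform"),
    "learning_rate": (0.1, 10, "log-uniform"),
}
print({name: sample(spec, rng) for name, spec in space.items()})
```

On a log scale, batch_size values fall below and above the geometric mean sqrt(8 × 512) = 64 equally often (before rounding), so small batch sizes are explored as much as large ones.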

[10]:
from deephyper.problem import HpProblem

problem = HpProblem()

# Discrete hyperparameter (sampled with uniform prior)
problem.add_hyperparameter((5, 20), "num_epochs", default_value=10)

# Discrete and real hyperparameters (sampled with a log-uniform prior)
problem.add_hyperparameter((8, 512, "log-uniform"), "batch_size", default_value=64)
problem.add_hyperparameter((0.1, 10, "log-uniform"), "learning_rate", default_value=5)

problem
[10]:
Configuration space object:
  Hyperparameters:
    batch_size, Type: UniformInteger, Range: [8, 512], Default: 64, on log-scale
    learning_rate, Type: UniformFloat, Range: [0.1, 10.0], Default: 5.0, on log-scale
    num_epochs, Type: UniformInteger, Range: [5, 20], Default: 10

1.8. Evaluate a default configuration

We evaluate the performance of the default configuration, i.e., the hyperparameters used in the PyTorch tutorial, with perf_run.

[11]:
# We launch the Ray run-time and execute the `run` function
# with the default configuration

if is_gpu_available:
    if not(ray.is_initialized()):
        ray.init(num_cpus=n_gpus, num_gpus=n_gpus, log_to_driver=False)

    run_default = ray.remote(num_cpus=1, num_gpus=1)(perf_run)
    objective_default = ray.get(run_default.remote(problem.default_configuration))
else:
    if not(ray.is_initialized()):
        ray.init(num_cpus=1, log_to_driver=False)
    run_default = perf_run
    objective_default = run_default(problem.default_configuration)

print(f"Accuracy Default Configuration:  {objective_default:.3f}")
Accuracy Default Configuration:  0.902

1.9. Define the evaluator object

The Evaluator object lets you change the parallelization backend used by DeepHyper.
It is a standalone object that schedules the execution of remote tasks. Every evaluator needs a run_function to be instantiated.
The method keyword then selects the backend (e.g., "ray"), and method_kwargs holds the keyword arguments of the chosen backend:
evaluator = Evaluator.create(run_function, method, method_kwargs)

Once the evaluator is created, evaluator.num_workers gives the number of available parallel workers.

Finally, tasks are submitted to and collected from the evaluator through the following interface:

configs = [...]
evaluator.submit(configs)
...
tasks_done = evaluator.gather("BATCH", size=1)  # Asynchronous: wait until at least 1 task is done
tasks_done = evaluator.gather("ALL")  # Batch-synchronous: wait until all submitted tasks are done
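To illustrate the submit/gather pattern, here is a small stand-alone evaluator built on concurrent.futures. It is a hypothetical sketch, not DeepHyper's implementation: a local thread pool stands in for the Ray backend, and the method names follow DeepHyper's Evaluator (submit, and gather with "ALL" or "BATCH").

```python
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait


class MiniEvaluator:
    """Hypothetical, simplified evaluator showing the submit/gather pattern.

    Not DeepHyper's implementation: a local thread pool stands in for the backend.
    """

    def __init__(self, run_function, num_workers=1):
        self.run_function = run_function
        self.num_workers = num_workers
        self._executor = ThreadPoolExecutor(max_workers=num_workers)
        self._pending = {}  # maps each running future to its config

    def submit(self, configs):
        # Schedule every configuration; returns immediately
        for config in configs:
            future = self._executor.submit(self.run_function, config)
            self._pending[future] = config

    def gather(self, mode, size=1):
        """Collect finished tasks as (config, objective) pairs.

        mode "ALL" blocks until every pending task is done; mode "BATCH"
        blocks until at least `size` tasks are done (fewer if fewer are pending).
        """
        if mode == "ALL":
            done, _ = wait(list(self._pending))
        elif mode == "BATCH":
            size = min(size, len(self._pending))
            done = set()
            while len(done) < size:
                remaining = [f for f in self._pending if f not in done]
                finished, _ = wait(remaining, return_when=FIRST_COMPLETED)
                done |= finished
        else:
            raise ValueError(f"unknown mode: {mode!r}")
        return [(self._pending.pop(f), f.result()) for f in done]

    def close(self):
        self._executor.shutdown(wait=True)


evaluator = MiniEvaluator(lambda config: config["x"] ** 2, num_workers=2)
evaluator.submit([{"x": x} for x in range(4)])
first = evaluator.gather("BATCH", size=1)  # asynchronous style
rest = evaluator.gather("ALL")             # batch-synchronous style
evaluator.close()
print(len(first) + len(rest))  # 4
```

Collecting in "BATCH" mode returns as soon as at least one task finishes, which lets a search propose a new configuration immediately (asynchronous), while "ALL" waits for the whole batch (batch-synchronous).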

Warning

Each Evaluator saves its own state; therefore, it is crucial to create a new evaluator when launching a fresh search.

[12]:
from deephyper.evaluator import Evaluator
from deephyper.evaluator.callback import TqdmCallback

def get_evaluator(run_function):
    # Default arguments for Ray: 1 CPU in total and 1 CPU per evaluation, i.e., a single worker
    method_kwargs = {
        "num_cpus": 1,
        "num_cpus_per_task": 1,
        "callbacks": [TqdmCallback()]
    }

    # If GPU devices are detected then it will create 'n_gpus' workers
    # and use 1 worker for each evaluation
    if is_gpu_available:
        method_kwargs["num_cpus"] = n_gpus
        method_kwargs["num_gpus"] = n_gpus
        method_kwargs["num_cpus_per_task"] = 1
        method_kwargs["num_gpus_per_task"] = 1

    evaluator = Evaluator.create(
        run_function,
        method="ray",
        method_kwargs=method_kwargs
    )
    print(f"Created new evaluator with {evaluator.num_workers} worker{'s' if evaluator.num_workers > 1 else ''} and config: {method_kwargs}")

    return evaluator

evaluator_1 = get_evaluator(quick_run)
Created new evaluator with 1 worker and config: {'num_cpus': 1, 'num_cpus_per_task': 1, 'callbacks': [<deephyper.evaluator.callback.TqdmCallback object at 0x17d271e80>]}

1.10. Define and run the Centralized Bayesian Optimization search (CBO)

We create the CBO using the problem and evaluator defined above.

[13]:
from deephyper.search.hps import CBO
# Uncomment the following line to show the arguments of CBO.
# CBO?
[14]:
# Instantiate the search with the problem and a specific evaluator
search = CBO(problem, evaluator_1)

Note

All DeepHyper search algorithms have two stopping criteria:

  • max_evals (int): the maximum number of evaluations to perform. Defaults to -1, meaning no limit.

  • timeout (int): a time budget (in seconds) after which the search stops. Defaults to None, meaning no time limit.
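The two stopping criteria can be sketched as a simple sequential loop. random_search below is a hypothetical illustration (random sampling stands in for the Bayesian optimization of CBO), and checking the time budget only between evaluations is a simplification of this sketch.

```python
import random
import time


def random_search(run_function, sample_config, max_evals=-1, timeout=None, seed=0):
    """Hypothetical sequential search honoring max_evals (-1 = no limit) and
    timeout in seconds (None = no limit). Returns a list of (config, objective)."""
    if max_evals == -1 and timeout is None:
        raise ValueError("set max_evals or timeout, otherwise the loop never stops")
    rng = random.Random(seed)
    start = time.monotonic()
    history = []
    while True:
        if max_evals != -1 and len(history) >= max_evals:
            break
        # The time budget is only checked between evaluations in this sketch
        if timeout is not None and time.monotonic() - start >= timeout:
            break
        config = sample_config(rng)
        history.append((config, run_function(config)))
    return history


history = random_search(
    run_function=lambda config: -(config["x"] - 1.0) ** 2,  # toy objective, maximum at x = 1
    sample_config=lambda rng: {"x": rng.uniform(-3.0, 3.0)},
    max_evals=30,
)
best_config, best_objective = max(history, key=lambda item: item[1])
print(len(history), best_config, best_objective)
```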

[15]:
results = search.search(max_evals=30)
100%|██████████| 30/30 [07:08<00:00, 11.39s/it, objective=0.892]

The returned results object is a pandas DataFrame whose columns are the hyperparameters and the information stored by the evaluator:

  • job_id is a unique identifier corresponding to the order of creation of tasks

  • objective is the value returned by the run-function

  • timestamp_submit is the time (in seconds) when the hyperparameter configuration was submitted by the Evaluator relative to the creation of the evaluator.

  • timestamp_gather is the time (in seconds) when the hyperparameter configuration was collected by the Evaluator relative to the creation of the evaluator.

[16]:
results
[16]:
batch_size learning_rate num_epochs job_id objective timestamp_submit timestamp_gather
0 13 0.396163 10 1 0.876310 0.167126 21.861172
1 23 0.176352 13 2 0.836810 21.893403 46.585527
2 85 0.615317 9 3 0.849250 46.603550 62.330578
3 9 4.647854 15 4 0.888524 62.348272 96.357171
4 22 1.862895 7 5 0.892131 96.375062 111.058422
5 13 8.390420 7 6 0.892155 111.077491 127.247439
6 56 5.709700 7 7 0.888726 127.265297 140.431804
7 47 1.073937 17 8 0.891607 140.449749 168.683609
8 116 0.356077 8 9 0.757940 168.787206 183.260858
9 39 3.127017 7 10 0.891643 183.278615 196.925297
10 512 0.348708 6 11 0.478179 197.162700 207.922961
11 512 3.726319 17 12 0.879036 208.166723 232.997138
12 512 8.765370 6 13 0.864667 233.320723 244.213859
13 512 1.910815 5 14 0.737810 244.466962 254.037378
14 512 0.313972 6 15 0.471179 254.293386 265.038085
15 512 0.390254 17 16 0.659024 265.296111 291.446844
16 512 0.314117 6 17 0.463964 291.793700 302.993824
17 512 0.406197 6 18 0.495940 303.264247 314.411603
18 512 0.320737 6 19 0.476452 314.679307 325.618635
19 512 0.310132 6 20 0.455762 325.959387 336.946107
20 512 0.277178 6 21 0.438250 337.212328 348.157166
21 512 0.276917 6 22 0.449714 348.422956 359.220447
22 512 0.273978 6 23 0.453071 359.485135 370.697273
23 512 0.298666 6 24 0.450143 371.048602 382.304554
24 512 0.280337 6 25 0.463952 382.571858 393.747633
25 512 0.292463 6 26 0.469250 394.017767 405.152517
26 512 0.273828 6 27 0.459488 405.424290 416.641872
27 512 0.306901 6 28 0.461964 416.986107 428.063347
28 512 0.281089 6 29 0.458131 428.335923 439.447637
29 512 0.278833 6 30 0.436702 439.720846 450.642993
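A common way to read such a table is to compute the duration of each evaluation and the best objective found so far. The sketch below does this in plain Python on the first six rows copied from the table above; with pandas, the same running maximum can be obtained with results["objective"].cummax().

```python
from itertools import accumulate

# First six rows of the results table above:
# (job_id, objective, timestamp_submit, timestamp_gather)
rows = [
    (1, 0.876310, 0.167126, 21.861172),
    (2, 0.836810, 21.893403, 46.585527),
    (3, 0.849250, 46.603550, 62.330578),
    (4, 0.888524, 62.348272, 96.357171),
    (5, 0.892131, 96.375062, 111.058422),
    (6, 0.892155, 111.077491, 127.247439),
]

# Wall-clock duration of each evaluation, in seconds
durations = [gather - submit for _, _, submit, gather in rows]
# Best objective found so far after each evaluation (running maximum)
best_so_far = list(accumulate((objective for _, objective, _, _ in rows), max))
# job_id of the best evaluation among these rows
best_job = max(rows, key=lambda row: row[1])[0]

print([round(d, 1) for d in durations])
print(best_so_far)
print(best_job)
```

The running maximum ends at 0.892155 for job_id 6, which is also the best evaluation of the whole search and the configuration selected in the next section.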

1.11. Evaluate the best configuration

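Now that the search is over, let us print the best configuration found during this run. The accuracies measured during the search come from quick_run, which trains on only 30% of the training data, so they are not directly comparable with the accuracy of the default configuration, which was measured with perf_run. We therefore re-evaluate the best configuration with perf_run, i.e., with the same 95%/5% training/validation split.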

[17]:
i_max = results.objective.argmax()
# Drop the last three columns (objective, timestamp_submit, timestamp_gather);
# job_id is kept but ignored by the run-function
best_config = results.iloc[i_max][:-3].to_dict()
# Strip the "p:" prefix that some DeepHyper versions add to hyperparameter columns
best_config = {(k[2:] if k.startswith("p:") else k): v for k, v in best_config.items()}

print(f"The default configuration has an accuracy of {objective_default:.3f}.\n"
      f"The best configuration found by DeepHyper has an accuracy of {results['objective'].iloc[i_max]:.3f},\n"
      f"finished after {results['timestamp_gather'].iloc[i_max]:.2f} seconds of search.\n")

print(json.dumps(best_config, indent=4))
The default configuration has an accuracy of 0.902.
The best configuration found by DeepHyper has an accuracy of 0.892,
finished after 127.25 seconds of search.

{
    "batch_size": 13.0,
    "learning_rate": 8.39041977280772,
    "num_epochs": 7.0,
    "job_id": 6.0
}
[18]:
objective_best = perf_run(best_config)
print(f"Accuracy Best Configuration:  {objective_best:.3f}")
Accuracy Best Configuration:  0.914