3. From Neural Architecture Search to Automated Deep Ensemble with Uncertainty Quantification

3.1. Imports and GPU Detection

Warning

By design, asyncio does not allow nested event loops. Jupyter uses Tornado, which already starts an event loop, so the following patch is required to run this tutorial.

[1]:
!pip install nest_asyncio

import nest_asyncio
nest_asyncio.apply()
Requirement already satisfied: nest_asyncio in /Users/romainegele/opt/anaconda3/envs/dh-dev/lib/python3.8/site-packages (1.5.1)
[2]:
import json
import os
import pathlib
import shutil

# Set these before importing TensorFlow so they take effect in this process
# (a "!export" in Jupyter only affects a transient subshell, not the kernel)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
os.environ["TF_XLA_FLAGS"] = "--tf_xla_enable_xla_devices"

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tqdm import tqdm

Note

The TF_CPP_MIN_LOG_LEVEL variable can be used to avoid logging TensorFlow DEBUG, INFO, and WARNING statements.

Note

The following can be used to detect if GPU devices are available on the current host.

[3]:
available_gpus = tf.config.list_physical_devices("GPU")
n_gpus = len(available_gpus)
if n_gpus > 1:
    n_gpus -= 1
is_gpu_available = n_gpus > 0

if is_gpu_available:
    print(f"{n_gpus} GPU{'s are' if n_gpus > 1 else ' is'} available.")
else:
    print("No GPU available")
No GPU available

3.2. Start Ray

We launch the Ray runtime depending on the detected local resources: if GPUs are detected, 1 worker is started per GPU; otherwise, only 1 worker is started. You can start more workers by setting num_cpus to a value greater than 1.

Warning

In the case of GPUs, it is important to follow this scheme to avoid having multiple processes (Ray workers and the current process) lock the same GPU.

[4]:
import ray

if not ray.is_initialized():
    if is_gpu_available:
        ray.init(num_cpus=n_gpus, num_gpus=n_gpus, log_to_driver=False)
    else:
        ray.init(num_cpus=4, log_to_driver=False)
2021-10-29 09:50:19,777 INFO services.py:1263 -- View the Ray dashboard at http://127.0.0.1:8265

3.3. A Synthetic Dataset

Now, we will define our artificial dataset based on a sine curve. We first generate data for a training set (used for estimation) and a testing set (used to evaluate the final performance). The training set is then subdivided into a new training set (used to estimate the neural network weights) and a validation set (used to estimate the neural network hyperparameters and architecture). The data are generated from the following function:

\[y = f(x) = 2 \cdot \sin(x) + \epsilon\]

The training data are generated in the range \([-30, -20]\) with \(\epsilon \sim \mathcal{N}(0, 0.25)\) and in the range \([20, 30]\) with \(\epsilon \sim \mathcal{N}(0, 1)\). The corresponding code for the training data is:

[5]:
def load_data_train_test(random_state=42):
    rs = np.random.RandomState(random_state)

    train_size = 400
    f = lambda x: 2 * np.sin(x)  # a simple sine function

    x_1 = rs.uniform(low=-30, high=-20.0, size=train_size//2)
    eps_1 = rs.normal(loc=0.0, scale=0.5, size=train_size//2)
    y_1 = f(x_1) + eps_1

    x_2 = rs.uniform(low=20.0, high=30.0, size=train_size//2)
    eps_2 = rs.normal(loc=0.0, scale=1.0, size=train_size//2)
    y_2 = f(x_2) + eps_2

    x = np.concatenate([x_1, x_2], axis=0)
    y = np.concatenate([y_1, y_2], axis=0)

    x_tst = np.linspace(-40.0, 40.0, 200)
    y_tst = f(x_tst)

    x = x.reshape(-1, 1)
    y = y.reshape(-1, 1)

    x_tst = x_tst.reshape(-1, 1)
    y_tst = y_tst.reshape(-1, 1)

    return (x, y), (x_tst, y_tst)

Then the code to split the training data into a new training set and a validation set corresponds to:

[6]:
from sklearn.model_selection import train_test_split

def load_data_train_valid(verbose=0, random_state=42):

    (x, y), _ = load_data_train_test(random_state=random_state)

    train_X, valid_X, train_y, valid_y = train_test_split(
        x, y, test_size=0.33, random_state=random_state
    )

    if verbose:
        print(f'train_X shape: {np.shape(train_X)}')
        print(f'train_y shape: {np.shape(train_y)}')
        print(f'valid_X shape: {np.shape(valid_X)}')
        print(f'valid_y shape: {np.shape(valid_y)}')
    return (train_X, train_y), (valid_X, valid_y)


(x, y), (vx, vy) = load_data_train_valid(verbose=1)
_, (tx , ty) = load_data_train_test()
train_X shape: (268, 1)
train_y shape: (268, 1)
valid_X shape: (132, 1)
valid_y shape: (132, 1)

Note

While it is possible to factorize the two previous functions into one, the DeepHyper interface requires a function which returns (train_inputs, train_outputs), (valid_inputs, valid_outputs).

We can give a visualization of this data:

[7]:
width = 8
height = width/1.618
plt.figure(figsize=(width, height))

plt.plot(tx.reshape(-1), ty.reshape(-1), "ko--", label="test", alpha=0.5)
plt.plot(x.reshape(-1), y.reshape(-1), "bo", label="train", alpha=0.8)
plt.plot(vx.reshape(-1), vy.reshape(-1), "ro", label="valid", alpha=0.8)

plt.ylabel("$y = f(x)$", fontsize=12)
plt.xlabel("$x$", fontsize=12)

plt.xlim(-40, 40)
plt.legend(loc="upper center", ncol=3, fontsize=12)

plt.show()
../../../../_images/tutorials_tutorials_notebooks_07_NAS_with_Ensemble_and_UQ_tutorial_07_12_0.png

3.4. Scaling the Data

It is important to apply standard scaling to the input/output data to obtain faster convergence during training.

[8]:
from sklearn.preprocessing import StandardScaler


scaler_x = StandardScaler()
s_x = scaler_x.fit_transform(x)
s_vx = scaler_x.transform(vx)
s_tx = scaler_x.transform(tx)

scaler_y = StandardScaler()
s_y = scaler_y.fit_transform(y)
s_vy = scaler_y.transform(vy)
s_ty = scaler_y.transform(ty)

3.5. Baseline Neural Network

Let us define a baseline neural network based on a regular multi-layer perceptron architecture, which learns the mean estimate by minimizing the mean squared error.

[9]:
input_ = tf.keras.layers.Input(shape=(1,))
out = tf.keras.layers.Dense(200, activation="relu")(input_)
out = tf.keras.layers.Dense(200, activation="relu")(out)
output = tf.keras.layers.Dense(1)(out)
model = tf.keras.Model(input_, output)

optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
model.compile(optimizer, "mse")

history = model.fit(s_x, s_y, epochs=200, batch_size=4, validation_data=(s_vx, s_vy), verbose=1).history
2021-10-29 09:50:22.693837: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2021-10-29 09:50:22.883748: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:185] None of the MLIR Optimization Passes are enabled (registered 2)
Epoch 1/200
67/67 [==============================] - 0s 2ms/step - loss: 1.0169 - val_loss: 1.0942
Epoch 2/200
67/67 [==============================] - 0s 1ms/step - loss: 1.0078 - val_loss: 1.0726
Epoch 3/200
67/67 [==============================] - 0s 1ms/step - loss: 1.0140 - val_loss: 1.0725
Epoch 4/200
67/67 [==============================] - 0s 1ms/step - loss: 1.0076 - val_loss: 1.0752
Epoch 5/200
67/67 [==============================] - 0s 1ms/step - loss: 1.0135 - val_loss: 1.0751
Epoch 6/200
67/67 [==============================] - 0s 1ms/step - loss: 1.0033 - val_loss: 1.0728
Epoch 7/200
67/67 [==============================] - 0s 1ms/step - loss: 1.0090 - val_loss: 1.0686
Epoch 8/200
67/67 [==============================] - 0s 1ms/step - loss: 1.0057 - val_loss: 1.0755
Epoch 9/200
67/67 [==============================] - 0s 1ms/step - loss: 1.0043 - val_loss: 1.0792
Epoch 10/200
67/67 [==============================] - 0s 983us/step - loss: 0.9990 - val_loss: 1.0653
Epoch 11/200
67/67 [==============================] - 0s 979us/step - loss: 0.9965 - val_loss: 1.0673
Epoch 12/200
67/67 [==============================] - 0s 1ms/step - loss: 1.0042 - val_loss: 1.0841
Epoch 13/200
67/67 [==============================] - 0s 978us/step - loss: 1.0021 - val_loss: 1.0820
Epoch 14/200
67/67 [==============================] - 0s 954us/step - loss: 0.9933 - val_loss: 1.0681
Epoch 15/200
67/67 [==============================] - 0s 983us/step - loss: 0.9909 - val_loss: 1.0637
Epoch 16/200
67/67 [==============================] - 0s 1ms/step - loss: 0.9938 - val_loss: 1.0758
Epoch 17/200
67/67 [==============================] - 0s 1ms/step - loss: 0.9867 - val_loss: 1.0602
Epoch 18/200
67/67 [==============================] - 0s 1ms/step - loss: 0.9924 - val_loss: 1.0765
Epoch 19/200
67/67 [==============================] - 0s 1ms/step - loss: 0.9904 - val_loss: 1.0639
Epoch 20/200
67/67 [==============================] - 0s 1ms/step - loss: 0.9866 - val_loss: 1.0646
Epoch 21/200
67/67 [==============================] - 0s 976us/step - loss: 0.9899 - val_loss: 1.0568
Epoch 22/200
67/67 [==============================] - 0s 978us/step - loss: 0.9923 - val_loss: 1.0539
Epoch 23/200
67/67 [==============================] - 0s 1ms/step - loss: 0.9816 - val_loss: 1.0522
Epoch 24/200
67/67 [==============================] - 0s 1ms/step - loss: 0.9879 - val_loss: 1.0554
Epoch 25/200
67/67 [==============================] - 0s 1ms/step - loss: 0.9838 - val_loss: 1.0544
Epoch 26/200
67/67 [==============================] - 0s 1ms/step - loss: 0.9842 - val_loss: 1.0493
Epoch 27/200
67/67 [==============================] - 0s 1ms/step - loss: 0.9759 - val_loss: 1.0493
Epoch 28/200
67/67 [==============================] - 0s 1000us/step - loss: 0.9762 - val_loss: 1.0560
Epoch 29/200
67/67 [==============================] - 0s 998us/step - loss: 0.9679 - val_loss: 1.0518
Epoch 30/200
67/67 [==============================] - 0s 974us/step - loss: 0.9672 - val_loss: 1.0697
Epoch 31/200
67/67 [==============================] - 0s 975us/step - loss: 0.9707 - val_loss: 1.0412
Epoch 32/200
67/67 [==============================] - 0s 972us/step - loss: 0.9794 - val_loss: 1.0419
Epoch 33/200
67/67 [==============================] - 0s 1ms/step - loss: 0.9591 - val_loss: 1.0904
Epoch 34/200
67/67 [==============================] - 0s 1ms/step - loss: 0.9841 - val_loss: 1.0759
Epoch 35/200
67/67 [==============================] - 0s 1ms/step - loss: 0.9938 - val_loss: 1.0487
Epoch 36/200
67/67 [==============================] - 0s 1ms/step - loss: 0.9663 - val_loss: 1.0302
Epoch 37/200
67/67 [==============================] - 0s 1ms/step - loss: 0.9658 - val_loss: 1.0306
Epoch 38/200
67/67 [==============================] - 0s 949us/step - loss: 0.9590 - val_loss: 1.0279
Epoch 39/200
67/67 [==============================] - 0s 953us/step - loss: 0.9685 - val_loss: 1.0316
Epoch 40/200
67/67 [==============================] - 0s 987us/step - loss: 0.9588 - val_loss: 1.0292
Epoch 41/200
67/67 [==============================] - 0s 1ms/step - loss: 0.9611 - val_loss: 1.0291
Epoch 42/200
67/67 [==============================] - 0s 977us/step - loss: 0.9628 - val_loss: 1.0264
Epoch 43/200
67/67 [==============================] - 0s 1ms/step - loss: 0.9693 - val_loss: 1.0456
Epoch 44/200
67/67 [==============================] - 0s 987us/step - loss: 0.9751 - val_loss: 1.0591
Epoch 45/200
67/67 [==============================] - 0s 1ms/step - loss: 0.9631 - val_loss: 1.0340
Epoch 46/200
67/67 [==============================] - 0s 1ms/step - loss: 0.9527 - val_loss: 1.0374
Epoch 47/200
67/67 [==============================] - 0s 1ms/step - loss: 0.9536 - val_loss: 1.0383
Epoch 48/200
67/67 [==============================] - 0s 1ms/step - loss: 0.9444 - val_loss: 1.0532
Epoch 49/200
67/67 [==============================] - 0s 992us/step - loss: 0.9501 - val_loss: 1.0730
Epoch 50/200
67/67 [==============================] - 0s 1ms/step - loss: 0.9636 - val_loss: 1.0226
Epoch 51/200
67/67 [==============================] - 0s 1ms/step - loss: 0.9568 - val_loss: 1.0206
Epoch 52/200
67/67 [==============================] - 0s 1ms/step - loss: 0.9574 - val_loss: 1.0259
Epoch 53/200
67/67 [==============================] - 0s 1ms/step - loss: 0.9536 - val_loss: 1.0180
Epoch 54/200
67/67 [==============================] - 0s 978us/step - loss: 0.9489 - val_loss: 1.0312
Epoch 55/200
67/67 [==============================] - 0s 1ms/step - loss: 0.9583 - val_loss: 1.0231
Epoch 56/200
67/67 [==============================] - 0s 1ms/step - loss: 0.9579 - val_loss: 1.0308
Epoch 57/200
67/67 [==============================] - 0s 1ms/step - loss: 0.9357 - val_loss: 1.0114
Epoch 58/200
67/67 [==============================] - 0s 1ms/step - loss: 0.9359 - val_loss: 1.0088
Epoch 59/200
67/67 [==============================] - 0s 1ms/step - loss: 0.9506 - val_loss: 1.0162
Epoch 60/200
67/67 [==============================] - 0s 1ms/step - loss: 0.9439 - val_loss: 1.0261
Epoch 61/200
67/67 [==============================] - 0s 980us/step - loss: 0.9348 - val_loss: 0.9999
Epoch 62/200
67/67 [==============================] - 0s 1ms/step - loss: 0.9246 - val_loss: 1.0156
Epoch 63/200
67/67 [==============================] - 0s 1ms/step - loss: 0.9329 - val_loss: 1.0313
Epoch 64/200
67/67 [==============================] - 0s 1ms/step - loss: 0.9284 - val_loss: 1.0160
Epoch 65/200
67/67 [==============================] - 0s 1ms/step - loss: 0.9368 - val_loss: 1.0075
Epoch 66/200
67/67 [==============================] - 0s 1ms/step - loss: 0.9350 - val_loss: 1.0501
Epoch 67/200
67/67 [==============================] - 0s 983us/step - loss: 0.9063 - val_loss: 0.9868
Epoch 68/200
67/67 [==============================] - 0s 1ms/step - loss: 0.9094 - val_loss: 1.0086
Epoch 69/200
67/67 [==============================] - 0s 1ms/step - loss: 0.9156 - val_loss: 0.9821
Epoch 70/200
67/67 [==============================] - 0s 1ms/step - loss: 0.9170 - val_loss: 0.9796
Epoch 71/200
67/67 [==============================] - 0s 1ms/step - loss: 0.8983 - val_loss: 1.0076
Epoch 72/200
67/67 [==============================] - 0s 963us/step - loss: 0.9002 - val_loss: 0.9780
Epoch 73/200
67/67 [==============================] - 0s 961us/step - loss: 0.9014 - val_loss: 0.9714
Epoch 74/200
67/67 [==============================] - 0s 997us/step - loss: 0.9275 - val_loss: 0.9764
Epoch 75/200
67/67 [==============================] - 0s 1ms/step - loss: 0.9021 - val_loss: 0.9771
Epoch 76/200
67/67 [==============================] - 0s 1ms/step - loss: 0.8960 - val_loss: 0.9624
Epoch 77/200
67/67 [==============================] - 0s 1ms/step - loss: 0.8892 - val_loss: 0.9557
Epoch 78/200
67/67 [==============================] - 0s 1ms/step - loss: 0.8993 - val_loss: 0.9742
Epoch 79/200
67/67 [==============================] - 0s 1ms/step - loss: 0.8901 - val_loss: 0.9723
Epoch 80/200
67/67 [==============================] - 0s 1ms/step - loss: 0.8841 - val_loss: 0.9761
Epoch 81/200
67/67 [==============================] - 0s 975us/step - loss: 0.8781 - val_loss: 0.9588
Epoch 82/200
67/67 [==============================] - 0s 999us/step - loss: 0.8716 - val_loss: 0.9567
Epoch 83/200
67/67 [==============================] - 0s 1ms/step - loss: 0.8672 - val_loss: 1.0171
Epoch 84/200
67/67 [==============================] - 0s 1ms/step - loss: 0.8902 - val_loss: 0.9796
Epoch 85/200
67/67 [==============================] - 0s 1ms/step - loss: 0.8592 - val_loss: 0.9406
Epoch 86/200
67/67 [==============================] - 0s 1ms/step - loss: 0.8634 - val_loss: 0.9317
Epoch 87/200
67/67 [==============================] - 0s 975us/step - loss: 0.8779 - val_loss: 0.9268
Epoch 88/200
67/67 [==============================] - 0s 979us/step - loss: 0.8654 - val_loss: 0.9470
Epoch 89/200
67/67 [==============================] - 0s 987us/step - loss: 0.8590 - val_loss: 0.9294
Epoch 90/200
67/67 [==============================] - 0s 966us/step - loss: 0.8655 - val_loss: 0.9483
Epoch 91/200
67/67 [==============================] - 0s 1ms/step - loss: 0.8339 - val_loss: 0.9287
Epoch 92/200
67/67 [==============================] - 0s 1ms/step - loss: 0.8432 - val_loss: 1.0219
Epoch 93/200
67/67 [==============================] - 0s 1ms/step - loss: 0.8387 - val_loss: 0.9278
Epoch 94/200
67/67 [==============================] - 0s 1ms/step - loss: 0.8374 - val_loss: 0.9091
Epoch 95/200
67/67 [==============================] - 0s 1ms/step - loss: 0.8430 - val_loss: 0.9218
Epoch 96/200
67/67 [==============================] - 0s 1ms/step - loss: 0.8255 - val_loss: 0.9224
Epoch 97/200
67/67 [==============================] - 0s 1ms/step - loss: 0.8209 - val_loss: 0.9122
Epoch 98/200
67/67 [==============================] - 0s 992us/step - loss: 0.8351 - val_loss: 0.9074
Epoch 99/200
67/67 [==============================] - 0s 1ms/step - loss: 0.8269 - val_loss: 0.8986
Epoch 100/200
67/67 [==============================] - 0s 1ms/step - loss: 0.8146 - val_loss: 0.8949
Epoch 101/200
67/67 [==============================] - 0s 1ms/step - loss: 0.8276 - val_loss: 0.8887
Epoch 102/200
67/67 [==============================] - 0s 1ms/step - loss: 0.8130 - val_loss: 1.0022
Epoch 103/200
67/67 [==============================] - 0s 1ms/step - loss: 0.8429 - val_loss: 0.8888
Epoch 104/200
67/67 [==============================] - 0s 997us/step - loss: 0.7991 - val_loss: 0.8799
Epoch 105/200
67/67 [==============================] - 0s 955us/step - loss: 0.7970 - val_loss: 0.8792
Epoch 106/200
67/67 [==============================] - 0s 964us/step - loss: 0.8021 - val_loss: 0.8720
Epoch 107/200
67/67 [==============================] - 0s 958us/step - loss: 0.8006 - val_loss: 0.8885
Epoch 108/200
67/67 [==============================] - 0s 982us/step - loss: 0.8247 - val_loss: 0.8667
Epoch 109/200
67/67 [==============================] - 0s 961us/step - loss: 0.7904 - val_loss: 0.8808
Epoch 110/200
67/67 [==============================] - 0s 965us/step - loss: 0.7780 - val_loss: 0.8692
Epoch 111/200
67/67 [==============================] - 0s 940us/step - loss: 0.7683 - val_loss: 0.8502
Epoch 112/200
67/67 [==============================] - 0s 985us/step - loss: 0.7553 - val_loss: 0.8536
Epoch 113/200
67/67 [==============================] - 0s 954us/step - loss: 0.7705 - val_loss: 0.8386
Epoch 114/200
67/67 [==============================] - 0s 983us/step - loss: 0.7647 - val_loss: 0.8461
Epoch 115/200
67/67 [==============================] - 0s 968us/step - loss: 0.7656 - val_loss: 0.8337
Epoch 116/200
67/67 [==============================] - 0s 995us/step - loss: 0.7556 - val_loss: 0.8904
Epoch 117/200
67/67 [==============================] - 0s 1ms/step - loss: 0.7494 - val_loss: 0.8408
Epoch 118/200
67/67 [==============================] - 0s 1ms/step - loss: 0.7577 - val_loss: 0.8125
Epoch 119/200
67/67 [==============================] - 0s 958us/step - loss: 0.7342 - val_loss: 0.8105
Epoch 120/200
67/67 [==============================] - 0s 963us/step - loss: 0.7360 - val_loss: 0.8148
Epoch 121/200
67/67 [==============================] - 0s 940us/step - loss: 0.7290 - val_loss: 0.8008
Epoch 122/200
67/67 [==============================] - 0s 1ms/step - loss: 0.7302 - val_loss: 0.8171
Epoch 123/200
67/67 [==============================] - 0s 982us/step - loss: 0.7215 - val_loss: 0.8351
Epoch 124/200
67/67 [==============================] - 0s 989us/step - loss: 0.7359 - val_loss: 0.7932
Epoch 125/200
67/67 [==============================] - 0s 989us/step - loss: 0.7037 - val_loss: 0.7774
Epoch 126/200
67/67 [==============================] - 0s 989us/step - loss: 0.7176 - val_loss: 0.7834
Epoch 127/200
67/67 [==============================] - 0s 1ms/step - loss: 0.6945 - val_loss: 0.7771
Epoch 128/200
67/67 [==============================] - 0s 955us/step - loss: 0.6890 - val_loss: 0.7646
Epoch 129/200
67/67 [==============================] - 0s 972us/step - loss: 0.6951 - val_loss: 0.7545
Epoch 130/200
67/67 [==============================] - 0s 995us/step - loss: 0.6971 - val_loss: 0.7485
Epoch 131/200
67/67 [==============================] - 0s 969us/step - loss: 0.6753 - val_loss: 0.7523
Epoch 132/200
67/67 [==============================] - 0s 996us/step - loss: 0.6562 - val_loss: 0.7713
Epoch 133/200
67/67 [==============================] - 0s 1ms/step - loss: 0.6542 - val_loss: 0.7239
Epoch 134/200
67/67 [==============================] - 0s 988us/step - loss: 0.6551 - val_loss: 0.7296
Epoch 135/200
67/67 [==============================] - 0s 986us/step - loss: 0.6428 - val_loss: 0.7149
Epoch 136/200
67/67 [==============================] - 0s 950us/step - loss: 0.6363 - val_loss: 0.7704
Epoch 137/200
67/67 [==============================] - 0s 985us/step - loss: 0.6381 - val_loss: 0.7178
Epoch 138/200
67/67 [==============================] - 0s 1ms/step - loss: 0.6345 - val_loss: 0.6933
Epoch 139/200
67/67 [==============================] - 0s 1ms/step - loss: 0.6079 - val_loss: 0.6830
Epoch 140/200
67/67 [==============================] - 0s 1ms/step - loss: 0.6263 - val_loss: 0.6878
Epoch 141/200
67/67 [==============================] - 0s 948us/step - loss: 0.6010 - val_loss: 0.7053
Epoch 142/200
67/67 [==============================] - 0s 1ms/step - loss: 0.5947 - val_loss: 0.6959
Epoch 143/200
67/67 [==============================] - 0s 980us/step - loss: 0.5906 - val_loss: 0.6964
Epoch 144/200
67/67 [==============================] - 0s 972us/step - loss: 0.5758 - val_loss: 0.6495
Epoch 145/200
67/67 [==============================] - 0s 998us/step - loss: 0.5524 - val_loss: 0.6835
Epoch 146/200
67/67 [==============================] - 0s 956us/step - loss: 0.5748 - val_loss: 0.6207
Epoch 147/200
67/67 [==============================] - 0s 1ms/step - loss: 0.5466 - val_loss: 0.6315
Epoch 148/200
67/67 [==============================] - 0s 1ms/step - loss: 0.5237 - val_loss: 0.6414
Epoch 149/200
67/67 [==============================] - 0s 968us/step - loss: 0.5347 - val_loss: 0.5952
Epoch 150/200
67/67 [==============================] - 0s 944us/step - loss: 0.5601 - val_loss: 0.5972
Epoch 151/200
67/67 [==============================] - 0s 991us/step - loss: 0.5363 - val_loss: 0.6263
Epoch 152/200
67/67 [==============================] - 0s 962us/step - loss: 0.5202 - val_loss: 0.5857
Epoch 153/200
67/67 [==============================] - 0s 982us/step - loss: 0.5138 - val_loss: 0.5581
Epoch 154/200
67/67 [==============================] - 0s 962us/step - loss: 0.4919 - val_loss: 0.5909
Epoch 155/200
67/67 [==============================] - 0s 1ms/step - loss: 0.4716 - val_loss: 0.5965
Epoch 156/200
67/67 [==============================] - 0s 977us/step - loss: 0.4606 - val_loss: 0.5349
Epoch 157/200
67/67 [==============================] - 0s 1ms/step - loss: 0.4782 - val_loss: 0.5741
Epoch 158/200
67/67 [==============================] - 0s 976us/step - loss: 0.4719 - val_loss: 0.5271
Epoch 159/200
67/67 [==============================] - 0s 949us/step - loss: 0.4365 - val_loss: 0.5051
Epoch 160/200
67/67 [==============================] - 0s 980us/step - loss: 0.4526 - val_loss: 0.5139
Epoch 161/200
67/67 [==============================] - 0s 991us/step - loss: 0.4320 - val_loss: 0.4930
Epoch 162/200
67/67 [==============================] - 0s 1ms/step - loss: 0.4385 - val_loss: 0.4843
Epoch 163/200
67/67 [==============================] - 0s 993us/step - loss: 0.4122 - val_loss: 0.5005
Epoch 164/200
67/67 [==============================] - 0s 957us/step - loss: 0.4187 - val_loss: 0.4799
Epoch 165/200
67/67 [==============================] - 0s 947us/step - loss: 0.3914 - val_loss: 0.5181
Epoch 166/200
67/67 [==============================] - 0s 947us/step - loss: 0.3934 - val_loss: 0.4505
Epoch 167/200
67/67 [==============================] - 0s 956us/step - loss: 0.3746 - val_loss: 0.4901
Epoch 168/200
67/67 [==============================] - 0s 970us/step - loss: 0.3614 - val_loss: 0.4266
Epoch 169/200
67/67 [==============================] - 0s 950us/step - loss: 0.3540 - val_loss: 0.4490
Epoch 170/200
67/67 [==============================] - 0s 979us/step - loss: 0.3522 - val_loss: 0.4580
Epoch 171/200
67/67 [==============================] - 0s 1ms/step - loss: 0.3365 - val_loss: 0.4231
Epoch 172/200
67/67 [==============================] - 0s 1ms/step - loss: 0.3457 - val_loss: 0.4861
Epoch 173/200
67/67 [==============================] - 0s 968us/step - loss: 0.3228 - val_loss: 0.4155
Epoch 174/200
67/67 [==============================] - 0s 1ms/step - loss: 0.3142 - val_loss: 0.3756
Epoch 175/200
67/67 [==============================] - 0s 1ms/step - loss: 0.3016 - val_loss: 0.3775
Epoch 176/200
67/67 [==============================] - 0s 1ms/step - loss: 0.3326 - val_loss: 0.3581
Epoch 177/200
67/67 [==============================] - 0s 1ms/step - loss: 0.2821 - val_loss: 0.3453
Epoch 178/200
67/67 [==============================] - 0s 1ms/step - loss: 0.2799 - val_loss: 0.3860
Epoch 179/200
67/67 [==============================] - 0s 1ms/step - loss: 0.2971 - val_loss: 0.3752
Epoch 180/200
67/67 [==============================] - 0s 1ms/step - loss: 0.2735 - val_loss: 0.3607
Epoch 181/200
67/67 [==============================] - 0s 993us/step - loss: 0.2597 - val_loss: 0.3483
Epoch 182/200
67/67 [==============================] - 0s 984us/step - loss: 0.2624 - val_loss: 0.3335
Epoch 183/200
67/67 [==============================] - 0s 980us/step - loss: 0.2635 - val_loss: 0.3855
Epoch 184/200
67/67 [==============================] - 0s 1ms/step - loss: 0.2916 - val_loss: 0.3407
Epoch 185/200
67/67 [==============================] - 0s 1ms/step - loss: 0.2528 - val_loss: 0.3558
Epoch 186/200
67/67 [==============================] - 0s 973us/step - loss: 0.2552 - val_loss: 0.3594
Epoch 187/200
67/67 [==============================] - 0s 954us/step - loss: 0.2488 - val_loss: 0.3351
Epoch 188/200
67/67 [==============================] - 0s 968us/step - loss: 0.2346 - val_loss: 0.3204
Epoch 189/200
67/67 [==============================] - 0s 981us/step - loss: 0.2642 - val_loss: 0.3498
Epoch 190/200
67/67 [==============================] - 0s 953us/step - loss: 0.2680 - val_loss: 0.4117
Epoch 191/200
67/67 [==============================] - 0s 963us/step - loss: 0.2432 - val_loss: 0.3660
Epoch 192/200
67/67 [==============================] - 0s 976us/step - loss: 0.2501 - val_loss: 0.3176
Epoch 193/200
67/67 [==============================] - 0s 973us/step - loss: 0.2492 - val_loss: 0.3547
Epoch 194/200
67/67 [==============================] - 0s 942us/step - loss: 0.2362 - val_loss: 0.3234
Epoch 195/200
67/67 [==============================] - 0s 981us/step - loss: 0.2428 - val_loss: 0.3182
Epoch 196/200
67/67 [==============================] - 0s 995us/step - loss: 0.2313 - val_loss: 0.3205
Epoch 197/200
67/67 [==============================] - 0s 942us/step - loss: 0.2164 - val_loss: 0.3392
Epoch 198/200
67/67 [==============================] - 0s 996us/step - loss: 0.2222 - val_loss: 0.3222
Epoch 199/200
67/67 [==============================] - 0s 1ms/step - loss: 0.2188 - val_loss: 0.3677
Epoch 200/200
67/67 [==============================] - 0s 957us/step - loss: 0.2224 - val_loss: 0.3099

We can visualize our learning curves to make sure the training and validation losses decrease correctly.

[10]:
width = 8
height = width/1.618
plt.figure(figsize=(width, height))

plt.plot(history["loss"], label="training")
plt.plot(history["val_loss"], label="validation")

plt.xlabel("Epochs")
plt.ylabel("MSE")

plt.legend()

plt.show()
../../../../_images/tutorials_tutorials_notebooks_07_NAS_with_Ensemble_and_UQ_tutorial_07_18_0.png

Also, let us look at the predictions on the test set after inverting the scaling of the predicted variables.

[11]:
pred_s_ty = model(s_tx)
pred_ty = scaler_y.inverse_transform(pred_s_ty)

width = 8
height = width/1.618
plt.figure(figsize=(width, height))

plt.plot(tx, ty, label="truth")
plt.plot(tx, pred_ty, label="pred")

y_lim = 10
plt.fill_between([-30, -20], [-y_lim, -y_lim], [y_lim, y_lim], color="grey", alpha=0.5)
plt.fill_between([20, 30], [-y_lim, -y_lim], [y_lim, y_lim], color="grey", alpha=0.5)

plt.legend()
plt.ylim(-y_lim, y_lim)

plt.show()
../../../../_images/tutorials_tutorials_notebooks_07_NAS_with_Ensemble_and_UQ_tutorial_07_20_0.png

3.6. Define the Neural Architecture Search Space

The neural architecture search space is composed of discrete decision variables. For each decision variable, we choose among a list of possible operations to perform (e.g., fully connected, ReLU). To define this search space, it is necessary to use two classes:

  • KSearchSpace (for Keras Search Space): represents a directed acyclic graph (DAG) in which each node corresponds to a chosen operation; the graph as a whole represents the possible neural networks that can be created.

  • SpaceFactory: is a utility class used to pack the logic of a search space definition and share it with others.

Then, inside a KSearchSpace we will have two types of nodes:

  • VariableNode: corresponds to a discrete decision variable and is used to define a list of possible operations.

  • ConstantNode: corresponds to a fixed operation in the search space (e.g., inputs/outputs).

Finally, it is possible to reuse any tf.keras.layers to define a KSearchSpace. However, it is important to wrap each layer in an operation to perform lazy memory allocation of tensors.

[12]:
import collections

from deephyper.nas import KSearchSpace

# Decision variables are represented by nodes in a graph
from deephyper.nas.node import ConstantNode, VariableNode

# The "operation" creates a wrapper around Keras layers avoid allocating
# memory each time a new layer is defined in the search space
# For Skip/Residual connections we use "Zero", "Connect" and "AddByProjecting"
from deephyper.nas.operation import operation, Zero, Connect, AddByProjecting, Identity

Dense = operation(tf.keras.layers.Dense)

# Possible activation functions
ACTIVATIONS = [
    tf.keras.activations.elu,
    tf.keras.activations.gelu,
    tf.keras.activations.hard_sigmoid,
    tf.keras.activations.linear,
    tf.keras.activations.relu,
    tf.keras.activations.selu,
    tf.keras.activations.sigmoid,
    tf.keras.activations.softplus,
    tf.keras.activations.softsign,
    tf.keras.activations.swish,
    tf.keras.activations.tanh,
]

We implement the __init__ constructor and the build method of RegressionSpace, a subclass of KSearchSpace. The __init__ method interface is:

def __init__(self, input_shape, output_shape, **kwargs):
    ...

for the build method the interface is:

def build(self):
    ...
    return self

where:

  • input_shape corresponds to a tuple or a list of tuples indicating the shapes of the input tensors.

  • output_shape corresponds to the same for the output tensors.

  • **kwargs denotes that any other keyword argument can be defined by the user.

[13]:
class RegressionSpace(KSearchSpace):

    def __init__(self, input_shape, output_shape, seed=None, num_layers=3):
        super().__init__(input_shape, output_shape, seed=seed)

        self.num_layers = num_layers

    def build(self):

        # After creating a KSearchSpace, the nodes corresponding to the inputs are directly accessible
        out_sub_graph = self.build_sub_graph(self.input_nodes[0], self.num_layers)

        output = ConstantNode(op=Dense(self.output_shape[0]))
        self.connect(out_sub_graph, output)

        return self

    def build_sub_graph(self, input_node, num_layers=3):

        # Look for skip connections within a window of the 3 previous nodes
        anchor_points = collections.deque([input_node], maxlen=3)

        prev_node = input_node

        for _ in range(num_layers):

            # Create a variable node to list possible "Dense" layers
            dense = VariableNode()

            # Add the possible operations to the dense node
            self.add_dense_to_(dense)

            # Connect the previous node to the dense node
            self.connect(prev_node, dense)

            # Create a constant node to merge all input connections
            merge = ConstantNode()
            merge.set_op(
                AddByProjecting(self, [dense], activation="relu")
            )

            for node in anchor_points:

                # Create a variable node for each possible connection
                skipco = VariableNode()

                skipco.add_op(Zero()) # corresponds to no connection
                skipco.add_op(Connect(self, node)) # corresponds to (node => skipco)

                # Connect the (skipco => merge)
                self.connect(skipco, merge)


            # Prepare the next iteration
            prev_node = merge
            anchor_points.append(prev_node)

        return prev_node

    def add_dense_to_(self, node):

        # We add the "Identity" operation to allow the choice "doing nothing"
        node.add_op(Identity())

        step = 16
        for units in range(step, step * 16 + 1, step):
            for activation in ACTIVATIONS:
                node.add_op(Dense(units=units, activation=activation))

Let us visualize a few randomly sampled neural architectures from this search space.

[14]:
import matplotlib.image as mpimg
from tensorflow.keras.utils import plot_model

shapes = dict(input_shape=(1,), output_shape=(1,))
space = RegressionSpace(**shapes).build()


images = []
plt.figure(figsize=(15,15))
for i in range(4):

    model = space.sample()
    plt.subplot(2,2,i+1)
    plot_model(model, "random_model.png", show_shapes=False, show_layer_names=False)
    image = mpimg.imread("random_model.png")
    plt.imshow(image)
    plt.axis('off')

plt.show()
../../../../_images/tutorials_tutorials_notebooks_07_NAS_with_Ensemble_and_UQ_tutorial_07_26_0.png

3.7. Define the Neural Architecture Optimization Problem

In order to define a neural architecture search problem, we have to use the NaProblem class. This class gives the user access to different methods to customize the training settings of the neural networks.

[15]:
from deephyper.problem import NaProblem


def stdscaler():
    return StandardScaler()


problem = NaProblem()

# Bind a function which returns (train_input, train_output), (valid_input, valid_output)
problem.load_data(load_data_train_valid)

# Bind a function which returns a scikit-learn preprocessor (with fit, fit_transform, inverse_transform, etc.)
problem.preprocessing(stdscaler)

# Bind a function which returns a search space and give some arguments for the `build` method
problem.search_space(RegressionSpace, num_layers=3)

# Define a set of fixed hyperparameters for all trained neural networks
problem.hyperparameters(
    batch_size=4,
    learning_rate=1e-3,
    optimizer="adam",
    num_epochs=200,
    callbacks=dict(
        EarlyStopping=dict(monitor="val_loss", mode="min", verbose=0, patience=30)
    ),
)

# Define the loss to minimize
problem.loss("mse")

# Define complementary metrics
problem.metrics([])

# Define the maximized objective. Here we take the negative of the validation loss.
problem.objective("-val_loss")

problem
[15]:
Problem is:
    - search space   : __main__.RegressionSpace
    - data loading   : __main__.load_data_train_valid
    - preprocessing  : __main__.stdscaler
    - hyperparameters:
        * verbose: 0
        * batch_size: 4
        * learning_rate: 0.001
        * optimizer: adam
        * num_epochs: 200
        * callbacks: {'EarlyStopping': {'monitor': 'val_loss', 'mode': 'min', 'verbose': 0, 'patience': 30}}
    - loss           : mse
    - metrics        :
    - objective      : -val_loss

Tip

Adding an EarlyStopping(...) callback is a good idea to stop the training of your model as soon as it stops improving.

...
EarlyStopping=dict(monitor="val_loss", mode="min", verbose=0, patience=30)
...

3.8. Define the Evaluator Object

The Evaluator object is responsible for defining the backend used to distribute the function evaluations in DeepHyper.

[16]:
from deephyper.evaluator import Evaluator
from deephyper.evaluator.callback import LoggerCallback


def get_evaluator(run_function):

    # Default arguments for Ray: 1 worker and 1 CPU per evaluation
    method_kwargs = {
        "num_cpus": 1,
        "num_cpus_per_task": 1,
        "callbacks": [LoggerCallback()] # To interactively follow the finished evaluations,
    }

    # If GPU devices are detected then it will create 'n_gpus' workers
    # and use 1 worker for each evaluation
    if is_gpu_available:
        method_kwargs["num_cpus"] = n_gpus
        method_kwargs["num_gpus"] = n_gpus
        method_kwargs["num_cpus_per_task"] = 1
        method_kwargs["num_gpus_per_task"] = 1

    evaluator = Evaluator.create(
        run_function,
        method="ray",
        method_kwargs=method_kwargs
    )
    print(f"Created new evaluator with {evaluator.num_workers} worker{'s' if evaluator.num_workers > 1 else ''} and config: {method_kwargs}", )

    return evaluator

For neural architecture search, a standard training pipeline is provided by the run_base_trainer function from the deephyper.nas.run module.

[17]:
from deephyper.nas.run import run_base_trainer
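
As a rough sketch of how these pieces fit together (the search-execution cells are not shown in this section), the evaluator could consume run_base_trainer and drive a search over the problem. The RegularizedEvolution class and the max_evals budget below are assumptions based on DeepHyper's NAS API, not cells from this notebook:

from deephyper.search.nas import RegularizedEvolution

# Hypothetical wiring (not executed above): the evaluator distributes
# run_base_trainer, and the search proposes architectures from `problem`
evaluator = get_evaluator(run_base_trainer)
search = RegularizedEvolution(problem, evaluator)
results = search.search(max_evals=30)  # illustrative budget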

3.10. Adding Uncertainty Quantification to the Baseline Neural Network

To add uncertainty estimates, we use the TensorFlow Probability library, which is fully compatible with the neural architecture search API because it is accessible through Keras layers.

[24]:
import tensorflow_probability as tfp
tfd = tfp.distributions

Then, instead of minimizing the mean squared error, we will minimize the negative log-likelihood based on the learned probability distribution \(p(y|\mathbf{x};\theta)\), where \(\theta\) represents a neural network (architecture, training hyperparameters, weights).

[25]:
def nll(y, rv_y):
    """Negative log likelihood for Tensorflow probability.

    Args:
        y: true data.
        rv_y: learned (predicted) probability distribution.
    """
    return -rv_y.log_prob(y)
[26]:
input_ = tf.keras.layers.Input(shape=(1,))
out = tf.keras.layers.Dense(200, activation="relu")(input_)
out = tf.keras.layers.Dense(200, activation="relu")(out)

# For each predicted variable (1) we need the mean and variance estimate
out = tf.keras.layers.Dense(1*2)(out)

# We feed these estimates to output a Normal distribution for each predicted variable
output = tfp.layers.DistributionLambda(
            lambda t: tfd.Normal(
                loc=t[..., :1],
                scale=1e-3 + tf.math.softplus(0.05 * t[..., 1:]), # positive constraint on the standard dev.
            )
        )(out)

model_uq = tf.keras.Model(input_, output)

optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
model_uq.compile(optimizer, loss=nll)

history = model_uq.fit(s_x, s_y, epochs=200, batch_size=4, validation_data=(s_vx, s_vy), verbose=1)
2021-10-29 09:52:19.047916: W tensorflow/python/util/util.cc:348] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
Epoch 1/200
67/67 [==============================] - 1s 3ms/step - loss: 1.5990 - val_loss: 1.6310
Epoch 2/200
67/67 [==============================] - 0s 1ms/step - loss: 1.5394 - val_loss: 1.6104
Epoch 3/200
67/67 [==============================] - 0s 1ms/step - loss: 1.5388 - val_loss: 1.5367
Epoch 4/200
67/67 [==============================] - 0s 1ms/step - loss: 1.4749 - val_loss: 1.4912
Epoch 5/200
67/67 [==============================] - 0s 1ms/step - loss: 1.4394 - val_loss: 1.4594
Epoch 6/200
67/67 [==============================] - 0s 1ms/step - loss: 1.4269 - val_loss: 1.4986
Epoch 7/200
67/67 [==============================] - 0s 1ms/step - loss: 1.4323 - val_loss: 1.4595
Epoch 8/200
67/67 [==============================] - 0s 1ms/step - loss: 1.4439 - val_loss: 1.4499
Epoch 9/200
67/67 [==============================] - 0s 1ms/step - loss: 1.4328 - val_loss: 1.4599
Epoch 10/200
67/67 [==============================] - 0s 1ms/step - loss: 1.4319 - val_loss: 1.4645
Epoch 11/200
67/67 [==============================] - 0s 1ms/step - loss: 1.4408 - val_loss: 1.4611
Epoch 12/200
67/67 [==============================] - 0s 1ms/step - loss: 1.4271 - val_loss: 1.4619
Epoch 13/200
67/67 [==============================] - 0s 1ms/step - loss: 1.4330 - val_loss: 1.4647
Epoch 14/200
67/67 [==============================] - 0s 1ms/step - loss: 1.4358 - val_loss: 1.4543
Epoch 15/200
67/67 [==============================] - 0s 1ms/step - loss: 1.4292 - val_loss: 1.4683
Epoch 16/200
67/67 [==============================] - 0s 1ms/step - loss: 1.4261 - val_loss: 1.4576
Epoch 17/200
67/67 [==============================] - 0s 1ms/step - loss: 1.4212 - val_loss: 1.4592
Epoch 18/200
67/67 [==============================] - 0s 1ms/step - loss: 1.4255 - val_loss: 1.4464
Epoch 19/200
67/67 [==============================] - 0s 1ms/step - loss: 1.4227 - val_loss: 1.4587
Epoch 20/200
67/67 [==============================] - 0s 1ms/step - loss: 1.4344 - val_loss: 1.4787
Epoch 21/200
67/67 [==============================] - 0s 1ms/step - loss: 1.4225 - val_loss: 1.4488
Epoch 22/200
67/67 [==============================] - 0s 1ms/step - loss: 1.4261 - val_loss: 1.4586
Epoch 23/200
67/67 [==============================] - 0s 1ms/step - loss: 1.4207 - val_loss: 1.4644
Epoch 24/200
67/67 [==============================] - 0s 1ms/step - loss: 1.4250 - val_loss: 1.4559
Epoch 25/200
67/67 [==============================] - 0s 1ms/step - loss: 1.4205 - val_loss: 1.4476
Epoch 26/200
67/67 [==============================] - 0s 1ms/step - loss: 1.4239 - val_loss: 1.4473
Epoch 27/200
67/67 [==============================] - 0s 1ms/step - loss: 1.4243 - val_loss: 1.4515
Epoch 28/200
67/67 [==============================] - 0s 1ms/step - loss: 1.4159 - val_loss: 1.4653
Epoch 29/200
67/67 [==============================] - 0s 1ms/step - loss: 1.4223 - val_loss: 1.4616
Epoch 30/200
67/67 [==============================] - 0s 1ms/step - loss: 1.4179 - val_loss: 1.4603
Epoch 31/200
67/67 [==============================] - 0s 1ms/step - loss: 1.4149 - val_loss: 1.4649
Epoch 32/200
67/67 [==============================] - 0s 1ms/step - loss: 1.4168 - val_loss: 1.4495
Epoch 33/200
67/67 [==============================] - 0s 1ms/step - loss: 1.4175 - val_loss: 1.4619
Epoch 34/200
67/67 [==============================] - 0s 1ms/step - loss: 1.4142 - val_loss: 1.4438
Epoch 35/200
67/67 [==============================] - 0s 1ms/step - loss: 1.4205 - val_loss: 1.4454
Epoch 36/200
67/67 [==============================] - 0s 1ms/step - loss: 1.4138 - val_loss: 1.4684
Epoch 37/200
67/67 [==============================] - 0s 1ms/step - loss: 1.4140 - val_loss: 1.4498
Epoch 38/200
67/67 [==============================] - 0s 1ms/step - loss: 1.4218 - val_loss: 1.4761
Epoch 39/200
67/67 [==============================] - 0s 1ms/step - loss: 1.4108 - val_loss: 1.4470
Epoch 40/200
67/67 [==============================] - 0s 1ms/step - loss: 1.4073 - val_loss: 1.4839
Epoch 41/200
67/67 [==============================] - 0s 1ms/step - loss: 1.4196 - val_loss: 1.4428
Epoch 42/200
67/67 [==============================] - 0s 1ms/step - loss: 1.4118 - val_loss: 1.4692
Epoch 43/200
67/67 [==============================] - 0s 1ms/step - loss: 1.4093 - val_loss: 1.4411
Epoch 44/200
67/67 [==============================] - 0s 1ms/step - loss: 1.4117 - val_loss: 1.4406
Epoch 45/200
67/67 [==============================] - 0s 1ms/step - loss: 1.4043 - val_loss: 1.4663
Epoch 46/200
67/67 [==============================] - 0s 1ms/step - loss: 1.4070 - val_loss: 1.4440
Epoch 47/200
67/67 [==============================] - 0s 1ms/step - loss: 1.4138 - val_loss: 1.4390
Epoch 48/200
67/67 [==============================] - 0s 1ms/step - loss: 1.4015 - val_loss: 1.4458
Epoch 49/200
67/67 [==============================] - 0s 1ms/step - loss: 1.4250 - val_loss: 1.4409
Epoch 50/200
67/67 [==============================] - 0s 1ms/step - loss: 1.4057 - val_loss: 1.4760
Epoch 51/200
67/67 [==============================] - 0s 1ms/step - loss: 1.4159 - val_loss: 1.4411
Epoch 52/200
67/67 [==============================] - 0s 1ms/step - loss: 1.4171 - val_loss: 1.4459
Epoch 53/200
67/67 [==============================] - 0s 1ms/step - loss: 1.4053 - val_loss: 1.4432
Epoch 54/200
67/67 [==============================] - 0s 1ms/step - loss: 1.4073 - val_loss: 1.4457
Epoch 55/200
67/67 [==============================] - 0s 1ms/step - loss: 1.4009 - val_loss: 1.4514
Epoch 56/200
67/67 [==============================] - 0s 1ms/step - loss: 1.4130 - val_loss: 1.4528
Epoch 57/200
67/67 [==============================] - 0s 1ms/step - loss: 1.4085 - val_loss: 1.4500
Epoch 58/200
67/67 [==============================] - 0s 1ms/step - loss: 1.4027 - val_loss: 1.4337
Epoch 59/200
67/67 [==============================] - 0s 1ms/step - loss: 1.3984 - val_loss: 1.4413
Epoch 60/200
67/67 [==============================] - 0s 1ms/step - loss: 1.3967 - val_loss: 1.4368
Epoch 61/200
67/67 [==============================] - 0s 1ms/step - loss: 1.4150 - val_loss: 1.4343
Epoch 62/200
67/67 [==============================] - 0s 1ms/step - loss: 1.4024 - val_loss: 1.4350
Epoch 63/200
67/67 [==============================] - 0s 1ms/step - loss: 1.3946 - val_loss: 1.4354
Epoch 64/200
67/67 [==============================] - 0s 1ms/step - loss: 1.4000 - val_loss: 1.4699
Epoch 65/200
67/67 [==============================] - 0s 1ms/step - loss: 1.3969 - val_loss: 1.4397
Epoch 66/200
67/67 [==============================] - 0s 1ms/step - loss: 1.3967 - val_loss: 1.4463
Epoch 67/200
67/67 [==============================] - 0s 1ms/step - loss: 1.4031 - val_loss: 1.4372
Epoch 68/200
67/67 [==============================] - 0s 1ms/step - loss: 1.3926 - val_loss: 1.4449
Epoch 69/200
67/67 [==============================] - 0s 1ms/step - loss: 1.3969 - val_loss: 1.4316
Epoch 70/200
67/67 [==============================] - 0s 1ms/step - loss: 1.3933 - val_loss: 1.4698
Epoch 71/200
67/67 [==============================] - 0s 1ms/step - loss: 1.4074 - val_loss: 1.4536
Epoch 72/200
67/67 [==============================] - 0s 1ms/step - loss: 1.4019 - val_loss: 1.4372
Epoch 73/200
67/67 [==============================] - 0s 1ms/step - loss: 1.3983 - val_loss: 1.4479
Epoch 74/200
67/67 [==============================] - 0s 1ms/step - loss: 1.4011 - val_loss: 1.4339
Epoch 75/200
67/67 [==============================] - 0s 1ms/step - loss: 1.3933 - val_loss: 1.4380
Epoch 76/200
67/67 [==============================] - 0s 1ms/step - loss: 1.3880 - val_loss: 1.4647
Epoch 77/200
67/67 [==============================] - 0s 1ms/step - loss: 1.4035 - val_loss: 1.4459
Epoch 78/200
67/67 [==============================] - 0s 1ms/step - loss: 1.4007 - val_loss: 1.4488
Epoch 79/200
67/67 [==============================] - 0s 1ms/step - loss: 1.3965 - val_loss: 1.4428
Epoch 80/200
67/67 [==============================] - 0s 1ms/step - loss: 1.3993 - val_loss: 1.4470
Epoch 81/200
67/67 [==============================] - 0s 1ms/step - loss: 1.3910 - val_loss: 1.4333
Epoch 82/200
67/67 [==============================] - 0s 1ms/step - loss: 1.3961 - val_loss: 1.4271
Epoch 83/200
67/67 [==============================] - 0s 1ms/step - loss: 1.3961 - val_loss: 1.4376
Epoch 84/200
67/67 [==============================] - 0s 1ms/step - loss: 1.3891 - val_loss: 1.4405
Epoch 85/200
67/67 [==============================] - 0s 1ms/step - loss: 1.3871 - val_loss: 1.4329
Epoch 86/200
67/67 [==============================] - 0s 1ms/step - loss: 1.3925 - val_loss: 1.4633
Epoch 87/200
67/67 [==============================] - 0s 1ms/step - loss: 1.3965 - val_loss: 1.4732
Epoch 88/200
67/67 [==============================] - 0s 1ms/step - loss: 1.3995 - val_loss: 1.4752
Epoch 89/200
67/67 [==============================] - 0s 1ms/step - loss: 1.3913 - val_loss: 1.4565
Epoch 90/200
67/67 [==============================] - 0s 1ms/step - loss: 1.3820 - val_loss: 1.4258
Epoch 91/200
67/67 [==============================] - 0s 1ms/step - loss: 1.3882 - val_loss: 1.4231
Epoch 92/200
67/67 [==============================] - 0s 1ms/step - loss: 1.3918 - val_loss: 1.4166
Epoch 93/200
67/67 [==============================] - 0s 1ms/step - loss: 1.3813 - val_loss: 1.4213
Epoch 94/200
67/67 [==============================] - 0s 1ms/step - loss: 1.3826 - val_loss: 1.4187
Epoch 95/200
67/67 [==============================] - 0s 1ms/step - loss: 1.3791 - val_loss: 1.4164
Epoch 96/200
67/67 [==============================] - 0s 1ms/step - loss: 1.3826 - val_loss: 1.4247
Epoch 97/200
67/67 [==============================] - 0s 1ms/step - loss: 1.3754 - val_loss: 1.4138
Epoch 98/200
67/67 [==============================] - 0s 1ms/step - loss: 1.3703 - val_loss: 1.4227
Epoch 99/200
67/67 [==============================] - 0s 1ms/step - loss: 1.3711 - val_loss: 1.4443
Epoch 100/200
67/67 [==============================] - 0s 1ms/step - loss: 1.3803 - val_loss: 1.4070
Epoch 101/200
67/67 [==============================] - 0s 1ms/step - loss: 1.3797 - val_loss: 1.4129
Epoch 102/200
67/67 [==============================] - 0s 1ms/step - loss: 1.3638 - val_loss: 1.4096
Epoch 103/200
67/67 [==============================] - 0s 1ms/step - loss: 1.3676 - val_loss: 1.4213
Epoch 104/200
67/67 [==============================] - 0s 1ms/step - loss: 1.3591 - val_loss: 1.4065
Epoch 105/200
67/67 [==============================] - 0s 1ms/step - loss: 1.3688 - val_loss: 1.4185
Epoch 106/200
67/67 [==============================] - 0s 1ms/step - loss: 1.3521 - val_loss: 1.3930
Epoch 107/200
67/67 [==============================] - 0s 1ms/step - loss: 1.3566 - val_loss: 1.4023
Epoch 108/200
67/67 [==============================] - 0s 1ms/step - loss: 1.3483 - val_loss: 1.3949
Epoch 109/200
67/67 [==============================] - 0s 1ms/step - loss: 1.3483 - val_loss: 1.4045
Epoch 110/200
67/67 [==============================] - 0s 1ms/step - loss: 1.3433 - val_loss: 1.4223
Epoch 111/200
67/67 [==============================] - 0s 1ms/step - loss: 1.3444 - val_loss: 1.3978
Epoch 112/200
67/67 [==============================] - 0s 1ms/step - loss: 1.3394 - val_loss: 1.4009
Epoch 113/200
67/67 [==============================] - 0s 1ms/step - loss: 1.3459 - val_loss: 1.3916
Epoch 114/200
67/67 [==============================] - 0s 1ms/step - loss: 1.3344 - val_loss: 1.3711
Epoch 115/200
67/67 [==============================] - 0s 1ms/step - loss: 1.3366 - val_loss: 1.4002
Epoch 116/200
67/67 [==============================] - 0s 1ms/step - loss: 1.3199 - val_loss: 1.3647
Epoch 117/200
67/67 [==============================] - 0s 1ms/step - loss: 1.3245 - val_loss: 1.3759
Epoch 118/200
67/67 [==============================] - 0s 1ms/step - loss: 1.3239 - val_loss: 1.3595
Epoch 119/200
67/67 [==============================] - 0s 1ms/step - loss: 1.3140 - val_loss: 1.3631
Epoch 120/200
67/67 [==============================] - 0s 1ms/step - loss: 1.3050 - val_loss: 1.3855
Epoch 121/200
67/67 [==============================] - 0s 1ms/step - loss: 1.3309 - val_loss: 1.3643
Epoch 122/200
67/67 [==============================] - 0s 1ms/step - loss: 1.3050 - val_loss: 1.3535
Epoch 123/200
67/67 [==============================] - 0s 1ms/step - loss: 1.3104 - val_loss: 1.3497
Epoch 124/200
67/67 [==============================] - 0s 1ms/step - loss: 1.2973 - val_loss: 1.3304
Epoch 125/200
67/67 [==============================] - 0s 1ms/step - loss: 1.2944 - val_loss: 1.3448
Epoch 126/200
67/67 [==============================] - 0s 1ms/step - loss: 1.3199 - val_loss: 1.3595
Epoch 127/200
67/67 [==============================] - 0s 1ms/step - loss: 1.2815 - val_loss: 1.3211
Epoch 128/200
67/67 [==============================] - 0s 1ms/step - loss: 1.2841 - val_loss: 1.3954
Epoch 129/200
67/67 [==============================] - 0s 1ms/step - loss: 1.2834 - val_loss: 1.3440
Epoch 130/200
67/67 [==============================] - 0s 1ms/step - loss: 1.2753 - val_loss: 1.3197
Epoch 131/200
67/67 [==============================] - 0s 1ms/step - loss: 1.2694 - val_loss: 1.3059
Epoch 132/200
67/67 [==============================] - 0s 1ms/step - loss: 1.2450 - val_loss: 1.3094
Epoch 133/200
67/67 [==============================] - 0s 1ms/step - loss: 1.2424 - val_loss: 1.2674
Epoch 134/200
67/67 [==============================] - 0s 1ms/step - loss: 1.2554 - val_loss: 1.2714
Epoch 135/200
67/67 [==============================] - 0s 1ms/step - loss: 1.2358 - val_loss: 1.3275
Epoch 136/200
67/67 [==============================] - 0s 1ms/step - loss: 1.2185 - val_loss: 1.2633
Epoch 137/200
67/67 [==============================] - 0s 1ms/step - loss: 1.2277 - val_loss: 1.2389
Epoch 138/200
67/67 [==============================] - 0s 1ms/step - loss: 1.2141 - val_loss: 1.2603
Epoch 139/200
67/67 [==============================] - 0s 1ms/step - loss: 1.1824 - val_loss: 1.2820
Epoch 140/200
67/67 [==============================] - 0s 1ms/step - loss: 1.1883 - val_loss: 1.2078
Epoch 141/200
67/67 [==============================] - 0s 1ms/step - loss: 1.1811 - val_loss: 1.3008
Epoch 142/200
67/67 [==============================] - 0s 1ms/step - loss: 1.1649 - val_loss: 1.2258
Epoch 143/200
67/67 [==============================] - 0s 1ms/step - loss: 1.1830 - val_loss: 1.1801
Epoch 144/200
67/67 [==============================] - 0s 1ms/step - loss: 1.1320 - val_loss: 1.1932
Epoch 145/200
67/67 [==============================] - 0s 1ms/step - loss: 1.1369 - val_loss: 1.1730
Epoch 146/200
67/67 [==============================] - 0s 1ms/step - loss: 1.1522 - val_loss: 1.1883
Epoch 147/200
67/67 [==============================] - 0s 1ms/step - loss: 1.1198 - val_loss: 1.1289
Epoch 148/200
67/67 [==============================] - 0s 1ms/step - loss: 1.1170 - val_loss: 1.1695
Epoch 149/200
67/67 [==============================] - 0s 1ms/step - loss: 1.1040 - val_loss: 1.1062
Epoch 150/200
67/67 [==============================] - 0s 1ms/step - loss: 1.0694 - val_loss: 1.1602
Epoch 151/200
67/67 [==============================] - 0s 1ms/step - loss: 1.0973 - val_loss: 1.1509
Epoch 152/200
67/67 [==============================] - 0s 1ms/step - loss: 1.0618 - val_loss: 1.0855
Epoch 153/200
67/67 [==============================] - 0s 1ms/step - loss: 1.0469 - val_loss: 1.0444
Epoch 154/200
67/67 [==============================] - 0s 1ms/step - loss: 1.0495 - val_loss: 1.0442
Epoch 155/200
67/67 [==============================] - 0s 1ms/step - loss: 1.0002 - val_loss: 1.1552
Epoch 156/200
67/67 [==============================] - 0s 1ms/step - loss: 0.9914 - val_loss: 1.0057
Epoch 157/200
67/67 [==============================] - 0s 1ms/step - loss: 0.9686 - val_loss: 1.0287
Epoch 158/200
67/67 [==============================] - 0s 1ms/step - loss: 1.1138 - val_loss: 1.1875
Epoch 159/200
67/67 [==============================] - 0s 1ms/step - loss: 0.9824 - val_loss: 1.0634
Epoch 160/200
67/67 [==============================] - 0s 1ms/step - loss: 0.9392 - val_loss: 0.9828
Epoch 161/200
67/67 [==============================] - 0s 1ms/step - loss: 0.9241 - val_loss: 1.0409
Epoch 162/200
67/67 [==============================] - 0s 1ms/step - loss: 0.9242 - val_loss: 0.9957
Epoch 163/200
67/67 [==============================] - 0s 1ms/step - loss: 0.9510 - val_loss: 1.0923
Epoch 164/200
67/67 [==============================] - 0s 1ms/step - loss: 0.8898 - val_loss: 1.0067
Epoch 165/200
67/67 [==============================] - 0s 1ms/step - loss: 0.9533 - val_loss: 0.8918
Epoch 166/200
67/67 [==============================] - 0s 1ms/step - loss: 0.9179 - val_loss: 0.9627
Epoch 167/200
67/67 [==============================] - 0s 1ms/step - loss: 0.8661 - val_loss: 0.9034
Epoch 168/200
67/67 [==============================] - 0s 1ms/step - loss: 0.8696 - val_loss: 1.1602
Epoch 169/200
67/67 [==============================] - 0s 1ms/step - loss: 0.9438 - val_loss: 0.8882
Epoch 170/200
67/67 [==============================] - 0s 1ms/step - loss: 0.8712 - val_loss: 0.8547
Epoch 171/200
67/67 [==============================] - 0s 1ms/step - loss: 0.8884 - val_loss: 0.8886
Epoch 172/200
67/67 [==============================] - 0s 1ms/step - loss: 0.8917 - val_loss: 0.8640
Epoch 173/200
67/67 [==============================] - 0s 1ms/step - loss: 0.8299 - val_loss: 0.8614
Epoch 174/200
67/67 [==============================] - 0s 1ms/step - loss: 0.9755 - val_loss: 1.4353
Epoch 175/200
67/67 [==============================] - 0s 1ms/step - loss: 0.8967 - val_loss: 0.9240
Epoch 176/200
67/67 [==============================] - 0s 1ms/step - loss: 0.8329 - val_loss: 1.0151
Epoch 177/200
67/67 [==============================] - 0s 1ms/step - loss: 0.8825 - val_loss: 0.8877
Epoch 178/200
67/67 [==============================] - 0s 1ms/step - loss: 0.7946 - val_loss: 0.8850
Epoch 179/200
67/67 [==============================] - 0s 1ms/step - loss: 0.8628 - val_loss: 1.0449
Epoch 180/200
67/67 [==============================] - 0s 1ms/step - loss: 0.8027 - val_loss: 0.8317
Epoch 181/200
67/67 [==============================] - 0s 1ms/step - loss: 0.7829 - val_loss: 0.8131
Epoch 182/200
67/67 [==============================] - 0s 1ms/step - loss: 0.8097 - val_loss: 0.8427
Epoch 183/200
67/67 [==============================] - 0s 1ms/step - loss: 0.8298 - val_loss: 0.8044
Epoch 184/200
67/67 [==============================] - 0s 1ms/step - loss: 0.8350 - val_loss: 0.8432
Epoch 185/200
67/67 [==============================] - 0s 1ms/step - loss: 0.7747 - val_loss: 0.8587
Epoch 186/200
67/67 [==============================] - 0s 1ms/step - loss: 0.8348 - val_loss: 0.9781
Epoch 187/200
67/67 [==============================] - 0s 1ms/step - loss: 0.8056 - val_loss: 1.0371
Epoch 188/200
67/67 [==============================] - 0s 1ms/step - loss: 0.8461 - val_loss: 0.8138
Epoch 189/200
67/67 [==============================] - 0s 1ms/step - loss: 0.7835 - val_loss: 0.9100
Epoch 190/200
67/67 [==============================] - 0s 1ms/step - loss: 0.7852 - val_loss: 0.9815
Epoch 191/200
67/67 [==============================] - 0s 1ms/step - loss: 0.7624 - val_loss: 0.9432
Epoch 192/200
67/67 [==============================] - 0s 1ms/step - loss: 0.8102 - val_loss: 0.9019
Epoch 193/200
67/67 [==============================] - 0s 1ms/step - loss: 0.8075 - val_loss: 0.7932
Epoch 194/200
67/67 [==============================] - 0s 1ms/step - loss: 0.7962 - val_loss: 0.8306
Epoch 195/200
67/67 [==============================] - 0s 1ms/step - loss: 0.8372 - val_loss: 0.9382
Epoch 196/200
67/67 [==============================] - 0s 1ms/step - loss: 0.7559 - val_loss: 0.7840
Epoch 197/200
67/67 [==============================] - 0s 1ms/step - loss: 0.7268 - val_loss: 0.8268
Epoch 198/200
67/67 [==============================] - 0s 1ms/step - loss: 0.7763 - val_loss: 0.8148
Epoch 199/200
67/67 [==============================] - 0s 1ms/step - loss: 0.7768 - val_loss: 0.9372
Epoch 200/200
67/67 [==============================] - 0s 1ms/step - loss: 0.7356 - val_loss: 0.8664
[27]:
width = 8
height = width/1.618
plt.figure(figsize=(width, height))

plt.plot(history.history["loss"], label="training")
plt.plot(history.history["val_loss"], label="validation")

plt.xlabel("Epochs")
plt.ylabel("NLL")

plt.legend()

plt.show()
[Figure: training and validation loss (NLL) per epoch]

Let us visualize the learned uncertainty estimates.

[28]:
pred_s_ty = model_uq(s_tx)

pred_ty_mean = pred_s_ty.loc.numpy() * scaler_y.scale_ + scaler_y.mean_  # invert the standard scaling
pred_ty_var = np.square(pred_s_ty.scale.numpy()) * scaler_y.var_

width = 8
height = width/1.618
plt.figure(figsize=(width, height))

plt.plot(tx, ty, label="truth")
plt.plot(tx, pred_ty_mean, label="$\mu$")
plt.fill_between(
    tx.reshape(-1),
    (pred_ty_mean - pred_ty_var).reshape(-1),
    (pred_ty_mean + pred_ty_var).reshape(-1),
    color="orange",
    alpha=0.5,
    label="$\sigma^2$"
)

y_lim = 10
plt.fill_between([-30, -20], [-y_lim, -y_lim], [y_lim, y_lim], color="grey", alpha=0.5)
plt.fill_between([20, 30], [-y_lim, -y_lim], [y_lim, y_lim], color="grey", alpha=0.5)

plt.legend()
plt.ylim(-y_lim, y_lim)

plt.show()
[Figure: predicted mean and variance of the single model over the test range; grey bands mark the training-data regions]

The learned mean estimate appears to be worse than when minimizing the mean squared error loss. Also, we can see that the variance estimates are not meaningful in areas missing data (white background) and do not properly capture the noise in areas with data (grey background).

3.11. Ensemble of Neural Networks With Random Initialization

The uncertainty estimate of a single neural network corresponds to aleatoric uncertainty (i.e., the intrinsic noise of the data). To estimate the epistemic uncertainty, composed of estimation uncertainty (e.g., from the optimization algorithm) and model uncertainty (e.g., from the hypothesis space of models), we need to quantify the variation of predictions across different models and estimations. One of the most basic methods to do this is to keep the neural network architecture and training hyperparameters fixed, and re-train the model multiple times from different random weight initializations.

[29]:
def generate_model(model_id):

    # Model
    input_ = tf.keras.layers.Input(shape=(1,))
    out = tf.keras.layers.Dense(200, activation="relu")(input_)
    out = tf.keras.layers.Dense(200, activation="relu")(out)
    out = tf.keras.layers.Dense(2)(out) # 1 unit for the mean, 1 unit for the scale
    output = tfp.layers.DistributionLambda(
                lambda t: tfd.Normal(
                    loc=t[..., :1],
                    scale=1e-3 + tf.math.softplus(0.05 * t[..., 1:]),
                )
            )(out)
    model_uq = tf.keras.Model(input_, output)


    model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=os.path.join("models_random_init", f"{model_id}.h5"),
        monitor='val_loss',
        verbose=0,
        save_best_only=True,
        save_weights_only=False,
        mode='min',
        save_freq='epoch'
    )

    optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
    model_uq.compile(optimizer, loss=nll)

    history = model_uq.fit(s_x, s_y,
                           epochs=200,
                           batch_size=4,
                           validation_data=(s_vx, s_vy),
                           verbose=0,
                           callbacks=[model_checkpoint_callback]
                          ).history
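    # The ModelCheckpoint callback already saved the best weights to disk;
    # here we return the last epoch's validation loss as the model's score.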
    return history["val_loss"][-1]

if is_gpu_available:
    generate_model = ray.remote(num_cpus=1, num_gpus=1)(generate_model)
else:
    generate_model = ray.remote(num_cpus=1)(generate_model)

We generate n_models models from different random weight initializations. The computation can be distributed over different GPUs or CPU cores by using Ray .remote(...) calls.

[30]:
if os.path.exists("models_random_init"):
    shutil.rmtree("models_random_init")
pathlib.Path("models_random_init").mkdir(parents=False, exist_ok=False)

n_models = 10
scores = ray.get([generate_model.remote(model_id) for model_id in range(n_models)])

print(scores)
[0.8486694693565369, 0.8497845530509949, 0.882605791091919, 0.8592196702957153, 0.9243467450141907, 0.8486694693565369, 1.1459598541259766, 1.2993806600570679, 0.9240729212760925, 0.8383921980857849]

The UQBaggingEnsembleRegressor provides different strategies to build an ensemble of neural networks from a library of saved models. The computation is distributed with Ray (for the inference and the ranking of ensemble members).

[31]:
from deephyper.ensemble import UQBaggingEnsembleRegressor

ensemble = UQBaggingEnsembleRegressor(
    model_dir="models_random_init",
    loss=nll,  # default is nll
    size=5,
    verbose=True,
    ray_address="auto",
    num_cpus=1,
    num_gpus=1 if is_gpu_available else None,
    selection="topk",
)
[32]:
# Follow the Scikit-Learn fit/predict interface
ensemble.fit(s_vx, s_vy)

print(f"Selected members are: ", ensemble.members_files)
Selected members are:  ['9.h5', '3.h5', '5.h5', '0.h5', '1.h5']

We can visualize the uncertainty estimate of such an ensemble.

[33]:
pred_s_ty, pred_s_ty_aleatoric_var, pred_s_ty_epistemic_var = ensemble.predict_var_decomposition(s_tx)

pred_ty_mean = pred_s_ty.loc.numpy() * scaler_y.scale_ + scaler_y.mean_  # invert the standard scaling
pred_ty_aleatoric_var = pred_s_ty_aleatoric_var * scaler_y.var_
pred_ty_epistemic_var = pred_s_ty_epistemic_var * scaler_y.var_

width = 8
height = width/1.618
plt.figure(figsize=(width, height))

plt.plot(tx, ty, label="truth")
plt.plot(tx, pred_ty_mean, label="$\mu$")
plt.fill_between(
    tx.reshape(-1),
    (pred_ty_mean - pred_ty_aleatoric_var).reshape(-1),
    (pred_ty_mean + pred_ty_aleatoric_var).reshape(-1),
    color="yellow",
    alpha=0.5,
    label="$E[\sigma^2]$: aleatoric"
)
plt.fill_between(
    tx.reshape(-1),
    (pred_ty_mean - pred_ty_aleatoric_var).reshape(-1),
    (pred_ty_mean - pred_ty_aleatoric_var - pred_ty_epistemic_var).reshape(-1),
    color="orange",
    alpha=0.5,
    label="$V[\mu]$: epistemic"
)
plt.fill_between(
    tx.reshape(-1),
    (pred_ty_mean + pred_ty_aleatoric_var).reshape(-1),
    (pred_ty_mean + pred_ty_aleatoric_var + pred_ty_epistemic_var).reshape(-1),
    color="orange",
    alpha=0.5,
#     label="$V[\mu]$: epistemic"
)

y_lim = 10
plt.fill_between([-30, -20], [-y_lim, -y_lim], [y_lim, y_lim], color="grey", alpha=0.5)
plt.fill_between([20, 30], [-y_lim, -y_lim], [y_lim, y_lim], color="grey", alpha=0.5)

plt.legend()
plt.ylim(-y_lim, y_lim)

plt.show()
[Figure: ensemble prediction with aleatoric and epistemic variance bands; grey bands mark the training-data regions]

By using the Law of Total Variance we can decompose the predicted uncertainty into its aleatoric and epistemic components. With random initialization alone, we can see that the epistemic uncertainty is almost null everywhere and is not informative in areas missing data (white background).
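As a reminder, if each of the \(m\) ensemble members predicts a Gaussian \(\mathcal{N}(\mu_i(x), \sigma_i^2(x))\), the decomposition used in the plot above is (with the expectation and variance taken over the members \(i = 1, \dots, m\)):

\[V[y|x] = \underbrace{E[\sigma_i^2(x)]}_{\text{aleatoric}} + \underbrace{V[\mu_i(x)]}_{\text{epistemic}}\]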

3.12. AutoDEUQ: Automated Deep Ensemble with Uncertainty Quantification

AutoDEUQ is an algorithm in 2 steps:

1. joint hyperparameter and neural architecture search to generate a catalog of models;
2. building of an ensemble from this catalog.

To this end, we start by slightly editing the previous RegressionFactory, adding the DistributionLambda layer as output.

[34]:
DistributionLambda = operation(tfp.layers.DistributionLambda)
[35]:
class RegressionUQSpace(KSearchSpace):

    def __init__(self, input_shape, output_shape, seed=None, num_layers=3):
        super().__init__(input_shape, output_shape, seed=seed)

        self.num_layers = num_layers

    def build(self):

        out_sub_graph = self.build_sub_graph(self.input_nodes[0], self.num_layers)

        output_dim = self.output_shape[0]
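        # 2 * output_dim units: output_dim for the mean, output_dim for the scale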
        output_dense = ConstantNode(op=Dense(output_dim*2))
        self.connect(out_sub_graph, output_dense)


        output_dist = ConstantNode(
            op=DistributionLambda(
                lambda t: tfd.Normal(
                    loc=t[..., :output_dim],
                    scale=1e-3 + tf.math.softplus(0.05 * t[..., output_dim:]),
                )
            )
        )
        self.connect(output_dense, output_dist)

        return self

    def build_sub_graph(self, input_node, num_layers=3):


        # Consider skip connections only within a window of the 3 previous nodes
        anchor_points = collections.deque([input_node], maxlen=3)

        prev_node = input_node

        for _ in range(num_layers):

            # Create a variable node to list possible "Dense" layers
            dense = VariableNode()

            # Add the possible operations to the dense node
            self.add_dense_to_(dense)

            # Connect the previous node to the dense node
            self.connect(prev_node, dense)

            # Create a constant node to merge all input connections
            merge = ConstantNode()
            merge.set_op(
                AddByProjecting(self, [dense], activation="relu")
            )

            for node in anchor_points:

                # Create a variable node for each possible connection
                skipco = VariableNode()

                skipco.add_op(Zero()) # corresponds to no connection
                skipco.add_op(Connect(self, node)) # corresponds to (node => skipco)

                # Connect the (skipco => merge)
                self.connect(skipco, merge)


            # ! for next iter
            prev_node = merge
            anchor_points.append(prev_node)

        return prev_node

    def add_dense_to_(self, node):

        # We add the "Identity" operation to allow the choice "doing nothing"
        node.add_op(Identity())

        step = 16
        for units in range(step, step * 16 + 1, step):
            for activation in ACTIVATIONS:
                node.add_op(Dense(units=units, activation=activation))

For joint hyperparameter and neural architecture search, it is possible to use problem.add_hyperparameter(...) to define variable hyperparameters in the NAS Problem.

[36]:
problem_uq = NaProblem()

problem_uq.load_data(load_data_train_valid)

problem_uq.preprocessing(stdscaler)

problem_uq.search_space(RegressionUQSpace, num_layers=3)

problem_uq.hyperparameters(
    batch_size=problem_uq.add_hyperparameter((1, 32), "batch_size"),
    learning_rate=problem_uq.add_hyperparameter(
        (1e-4, 0.1, "log-uniform"),
        "learning_rate",
    ),
    optimizer=problem_uq.add_hyperparameter(
        ["sgd", "rmsprop", "adagrad", "adam", "adadelta", "adamax", "nadam"],
        "optimizer",
    ),
    patience_ReduceLROnPlateau=problem_uq.add_hyperparameter(
        (10, 20), "patience_ReduceLROnPlateau"
    ),
    patience_EarlyStopping=problem_uq.add_hyperparameter(
        (20, 30), "patience_EarlyStopping"
    ),
    num_epochs=200,
    callbacks=dict(
        ReduceLROnPlateau=dict(monitor="val_loss", mode="min", verbose=0, patience=5),
        EarlyStopping=dict(monitor="val_loss", mode="min", verbose=0, patience=10),
        # We save trained models in neural architecture search
        ModelCheckpoint=dict(
            monitor="val_loss",
            mode="min",
            save_best_only=True,
            verbose=0,
            filepath="model.h5",
            save_weights_only=False,
        ),
    ),
)

problem_uq.loss(nll)

problem_uq.metrics([])

# The objective is maximized so we take the negative of the validation loss
# where the loss is minimized
problem_uq.objective("-val_loss")

problem_uq
[36]:
Problem is:
    - search space   : __main__.RegressionUQSpace
    - data loading   : __main__.load_data_train_valid
    - preprocessing  : __main__.stdscaler
    - hyperparameters:
        * verbose: 0
        * batch_size: batch_size, Type: UniformInteger, Range: [1, 32], Default: 16
        * learning_rate: learning_rate, Type: UniformFloat, Range: [0.0001, 0.1], Default: 0.0031622777, on log-scale
        * optimizer: optimizer, Type: Categorical, Choices: {sgd, rmsprop, adagrad, adam, adadelta, adamax, nadam}, Default: sgd
        * patience_ReduceLROnPlateau: patience_ReduceLROnPlateau, Type: UniformInteger, Range: [10, 20], Default: 15
        * patience_EarlyStopping: patience_EarlyStopping, Type: UniformInteger, Range: [20, 30], Default: 25
        * num_epochs: 200
        * callbacks: {'ReduceLROnPlateau': {'monitor': 'val_loss', 'mode': 'min', 'verbose': 0, 'patience': 5}, 'EarlyStopping': {'monitor': 'val_loss', 'mode': 'min', 'verbose': 0, 'patience': 10}, 'ModelCheckpoint': {'monitor': 'val_loss', 'mode': 'min', 'save_best_only': True, 'verbose': 0, 'filepath': 'model.h5', 'save_weights_only': False}}
    - loss           : <function nll at 0x7fb720a828b0>
    - metrics        :
    - objective      : -val_loss
[37]:
results_uq = {}

The max_evals has to be greater than or equal to 400 to start obtaining a good UQ estimate.

[40]:
from deephyper.search.nas import AgEBO


if os.path.exists("agebo_search"):
    shutil.rmtree("agebo_search")

# "n_jobs" is the number of processes used to refresh the state of the surrogate model used in AgEBO
agebo_search = AgEBO(problem_uq, get_evaluator(run_base_trainer), log_dir="agebo_search", n_jobs=4)

results_uq["agebo"] = agebo_search.search(max_evals=500)
Created new evaluator with 4 workers and config: {'num_cpus': 1, 'num_cpus_per_task': 1, 'callbacks': [<deephyper.evaluator.callback.LoggerCallback object at 0x7fb71be53cd0>]}
{'acq_optimizer': 'sampling', 'acq_optimizer_kwargs': {'n_points': 10000, 'filter_duplicated': False}, 'dimensions': Configuration space object:
  Hyperparameters:
    batch_size, Type: UniformInteger, Range: [1, 32], Default: 16
    learning_rate, Type: UniformFloat, Range: [0.0001, 0.1], Default: 0.0031622777, on log-scale
    optimizer, Type: Categorical, Choices: {sgd, rmsprop, adagrad, adam, adadelta, adamax, nadam}, Default: sgd
    patience_EarlyStopping, Type: UniformInteger, Range: [20, 30], Default: 25
    patience_ReduceLROnPlateau, Type: UniformInteger, Range: [10, 20], Default: 15
, 'base_estimator': RandomForestRegressor(n_jobs=4, random_state=2147483648), 'acq_func': 'LCB', 'acq_func_kwargs': {'xi': 1e-06, 'kappa': 0.001}, 'n_initial_points': 4, 'random_state': 2147483648}
[00001] -- best objective: -1.41546 -- received objective: -1.41546
[00002] -- best objective: -1.41546 -- received objective: -1.45327
[00003] -- best objective: -1.41546 -- received objective: -1.56062
[00004] -- best objective: -1.41546 -- received objective: -1.61397
[00005] -- best objective: -1.41546 -- received objective: -1.43990
[00006] -- best objective: -1.41546 -- received objective: -1.79926
[00007] -- best objective: -1.41546 -- received objective: -1.45409
[00008] -- best objective: -1.41546 -- received objective: -1.45558
[00009] -- best objective: -1.41546 -- received objective: -1.44694
[00010] -- best objective: -1.41546 -- received objective: -1.42703
[00011] -- best objective: -0.81093 -- received objective: -0.81093
[00012] -- best objective: -0.81093 -- received objective: -1.43291
[00013] -- best objective: -0.81093 -- received objective: -1.43775
[00014] -- best objective: -0.81093 -- received objective: -1.43076
[00015] -- best objective: -0.81093 -- received objective: -1.43275
[00016] -- best objective: -0.81093 -- received objective: -1.43637
[00017] -- best objective: -0.81093 -- received objective: -1.43495
[00018] -- best objective: -0.81093 -- received objective: -1.30267
[00019] -- best objective: -0.81093 -- received objective: -1.57078
[00020] -- best objective: -0.81093 -- received objective: -1.38905
[00021] -- best objective: -0.81093 -- received objective: -1.50893
[00022] -- best objective: -0.81093 -- received objective: -1.43510
[00023] -- best objective: -0.81093 -- received objective: -1.45480
[00024] -- best objective: -0.81093 -- received objective: -1.21582
[00025] -- best objective: -0.81093 -- received objective: -0.91147
[00026] -- best objective: -0.81093 -- received objective: -5.48971
[00027] -- best objective: -0.81093 -- received objective: -1.44399
[00028] -- best objective: -0.81093 -- received objective: -1.07231
[00029] -- best objective: -0.81093 -- received objective: -1.44112
[00030] -- best objective: -0.81093 -- received objective: -1.45602
[00031] -- best objective: -0.81093 -- received objective: -1.00021
[00032] -- best objective: -0.81093 -- received objective: -1.46971
[00033] -- best objective: -0.81093 -- received objective: -1.45325
[00034] -- best objective: -0.81093 -- received objective: -1.45220
[00035] -- best objective: -0.81093 -- received objective: -1.45069
[00036] -- best objective: -0.81093 -- received objective: -1.42688
[00037] -- best objective: -0.81093 -- received objective: -1.43362
[00038] -- best objective: -0.81093 -- received objective: -1.42800
[00039] -- best objective: -0.81093 -- received objective: -1.43210
[00040] -- best objective: -0.81093 -- received objective: -1.36331
[00041] -- best objective: -0.81093 -- received objective: -1.43566
[00042] -- best objective: -0.81093 -- received objective: -1.45410
[00043] -- best objective: -0.81093 -- received objective: -1.39384
[00044] -- best objective: -0.81093 -- received objective: -1.42321
[00045] -- best objective: -0.81093 -- received objective: -1.42897
[00046] -- best objective: -0.81093 -- received objective: -1.44908
[00047] -- best objective: -0.81093 -- received objective: -1.47396
[00048] -- best objective: -0.81093 -- received objective: -1.43180
[00049] -- best objective: -0.81093 -- received objective: -1.43014
[00050] -- best objective: -0.81093 -- received objective: -1.45732
[00051] -- best objective: -0.81093 -- received objective: -2.72800
[00052] -- best objective: -0.81093 -- received objective: -1.43124
[00053] -- best objective: -0.81093 -- received objective: -1.45690
[00054] -- best objective: -0.81093 -- received objective: -1.45252
[00055] -- best objective: -0.81093 -- received objective: -1.42997
[00056] -- best objective: -0.81093 -- received objective: -1.44718
[00057] -- best objective: -0.81093 -- received objective: -2.74708
[00058] -- best objective: -0.81093 -- received objective: -1.45145
[00059] -- best objective: -0.81093 -- received objective: -1.44711
[00060] -- best objective: -0.81093 -- received objective: -1.41510
[00061] -- best objective: -0.81093 -- received objective: -1.45046
[00062] -- best objective: -0.81093 -- received objective: -1.45949
[00063] -- best objective: -0.81093 -- received objective: -1.45289
[00064] -- best objective: -0.81093 -- received objective: -1.41439
[00065] -- best objective: -0.81093 -- received objective: -3.31526
[00066] -- best objective: -0.81093 -- received objective: -1.45728
[00067] -- best objective: -0.81093 -- received objective: -4.07340
[00068] -- best objective: -0.81093 -- received objective: -1.45544
[00069] -- best objective: -0.81093 -- received objective: -1.44247
[00070] -- best objective: -0.81093 -- received objective: -1.45364
[00071] -- best objective: -0.81093 -- received objective: -1.42052
[00072] -- best objective: -0.81093 -- received objective: -1.43449
[00073] -- best objective: -0.71931 -- received objective: -0.71931
[00074] -- best objective: -0.71931 -- received objective: -1.42743
[00075] -- best objective: -0.71931 -- received objective: -1.45668
[00076] -- best objective: -0.71931 -- received objective: -1.45094
[00077] -- best objective: -0.71931 -- received objective: -1.45678
[00078] -- best objective: -0.71931 -- received objective: -1.45237
[00079] -- best objective: -0.71931 -- received objective: -1.43688
[00080] -- best objective: -0.71931 -- received objective: -1.44689
[00081] -- best objective: -0.71931 -- received objective: -1.45454
[00082] -- best objective: -0.71931 -- received objective: -1.45504
[00083] -- best objective: -0.71931 -- received objective: -1.45478
[00084] -- best objective: -0.71931 -- received objective: -1.46439
[00085] -- best objective: -0.71931 -- received objective: -1.44304
[00086] -- best objective: -0.71931 -- received objective: -1.44873
[00087] -- best objective: -0.71931 -- received objective: -1.45605
[00088] -- best objective: -0.71931 -- received objective: -1.44172
[00089] -- best objective: -0.71931 -- received objective: -1.45658
[00090] -- best objective: -0.71931 -- received objective: -1.42659
[00091] -- best objective: -0.71931 -- received objective: -1.43165
[00092] -- best objective: -0.71931 -- received objective: -1.42960
[00093] -- best objective: -0.71931 -- received objective: -1.39466
[00094] -- best objective: -0.71931 -- received objective: -1.44494
[00095] -- best objective: -0.71931 -- received objective: -1.42991
[00096] -- best objective: -0.71931 -- received objective: -1.45781
[00097] -- best objective: -0.71931 -- received objective: -1.45678
[00098] -- best objective: -0.71931 -- received objective: -1.55824
[00099] -- best objective: -0.71931 -- received objective: -1.45536
[00100] -- best objective: -0.71931 -- received objective: -1.42004
[00101] -- best objective: -0.71931 -- received objective: -0.80708
[00102] -- best objective: -0.71931 -- received objective: -1.41861
[00103] -- best objective: -0.71931 -- received objective: -1.43940
[00104] -- best objective: -0.71931 -- received objective: -1.43955
[00105] -- best objective: -0.71931 -- received objective: -1.45978
[00106] -- best objective: -0.71931 -- received objective: -1.43033
[00107] -- best objective: -0.71931 -- received objective: -1.43704
[00108] -- best objective: -0.71931 -- received objective: -1.43700
[00109] -- best objective: -0.71931 -- received objective: -1.42785
[00110] -- best objective: -0.71931 -- received objective: -1.43107
[00111] -- best objective: -0.71931 -- received objective: -1.45524
[00112] -- best objective: -0.71931 -- received objective: -1.44868
[00113] -- best objective: -0.71931 -- received objective: -1.43329
[00114] -- best objective: -0.71931 -- received objective: -1.45350
[00115] -- best objective: -0.71931 -- received objective: -1.44557
[00116] -- best objective: -0.71931 -- received objective: -0.75683
[00117] -- best objective: -0.71931 -- received objective: -1.43503
[00118] -- best objective: -0.71931 -- received objective: -1.44292
[00119] -- best objective: -0.71931 -- received objective: -1.42997
[00120] -- best objective: -0.71931 -- received objective: -1.43535
[00121] -- best objective: -0.71931 -- received objective: -1.39890
[00122] -- best objective: -0.71931 -- received objective: -1.46928
[00123] -- best objective: -0.71931 -- received objective: -1.43725
[00124] -- best objective: -0.71931 -- received objective: -1.45492
[00125] -- best objective: -0.71931 -- received objective: -1.45191
[00126] -- best objective: -0.71931 -- received objective: -1.42815
[00127] -- best objective: -0.71931 -- received objective: -1.44565
[00128] -- best objective: -0.71931 -- received objective: -1.42904
[00129] -- best objective: -0.71931 -- received objective: -1.45950
[00130] -- best objective: -0.71931 -- received objective: -1.43312
[00131] -- best objective: -0.71931 -- received objective: -1.41984
[00132] -- best objective: -0.71931 -- received objective: -1.45176
[00133] -- best objective: -0.71931 -- received objective: -1.45983
[00134] -- best objective: -0.71931 -- received objective: -1.43363
[00135] -- best objective: -0.71931 -- received objective: -1.44105
[00136] -- best objective: -0.71931 -- received objective: -1.41815
[00137] -- best objective: -0.71931 -- received objective: -1.45694
[00138] -- best objective: -0.71931 -- received objective: -1.45668
[00139] -- best objective: -0.71931 -- received objective: -1.45239
[00140] -- best objective: -0.71931 -- received objective: -1.20071
[00141] -- best objective: -0.71931 -- received objective: -1.43030
[00142] -- best objective: -0.71931 -- received objective: -1.43514
[00143] -- best objective: -0.71931 -- received objective: -1.39529
[00144] -- best objective: -0.71931 -- received objective: -1.42177
[00145] -- best objective: -0.71931 -- received objective: -1.43993
[00146] -- best objective: -0.71931 -- received objective: -1.44254
[00147] -- best objective: -0.71931 -- received objective: -0.87992
[00148] -- best objective: -0.71931 -- received objective: -1.45687
[00149] -- best objective: -0.71931 -- received objective: -1.45024
[00150] -- best objective: -0.71931 -- received objective: -1.42521
[00151] -- best objective: -0.71931 -- received objective: -1.45721
[00152] -- best objective: -0.71931 -- received objective: -1.45105
[00153] -- best objective: -0.71931 -- received objective: -1.48770
[00154] -- best objective: -0.71931 -- received objective: -1.45838
[00155] -- best objective: -0.71931 -- received objective: -1.44815
[00156] -- best objective: -0.71931 -- received objective: -1.45398
[00157] -- best objective: -0.71931 -- received objective: -1.44200
[00158] -- best objective: -0.71931 -- received objective: -1.36608
[00159] -- best objective: -0.71931 -- received objective: -1.40234
[00160] -- best objective: -0.71931 -- received objective: -1.43278
[00161] -- best objective: -0.71931 -- received objective: -1.36176
[00162] -- best objective: -0.71931 -- received objective: -1.46017
[00163] -- best objective: -0.71931 -- received objective: -1.04931
[00164] -- best objective: -0.71931 -- received objective: -1.45232
[00165] -- best objective: -0.71931 -- received objective: -1.45659
[00166] -- best objective: -0.71931 -- received objective: -1.45074
[00167] -- best objective: -0.71931 -- received objective: -1.41671
[00168] -- best objective: -0.71931 -- received objective: -1.44916
[00169] -- best objective: -0.71931 -- received objective: -1.41268
[00170] -- best objective: -0.71931 -- received objective: -1.42258
[00171] -- best objective: -0.71931 -- received objective: -1.80510
[00172] -- best objective: -0.71931 -- received objective: -1.37938
[00173] -- best objective: -0.71931 -- received objective: -1.42922
[00174] -- best objective: -0.71931 -- received objective: -1.42787
[00175] -- best objective: -0.71931 -- received objective: -1.44523
[00176] -- best objective: -0.71931 -- received objective: -1.32807
[00177] -- best objective: -0.71931 -- received objective: -1.45407
[00178] -- best objective: -0.71931 -- received objective: -1.43923
[00179] -- best objective: -0.71931 -- received objective: -1.42698
[00180] -- best objective: -0.71931 -- received objective: -0.90297
[00181] -- best objective: -0.71931 -- received objective: -1.41060
[00182] -- best objective: -0.71931 -- received objective: -1.41926
[00183] -- best objective: -0.71931 -- received objective: -1.38359
[00184] -- best objective: -0.71931 -- received objective: -1.31420
[00185] -- best objective: -0.71931 -- received objective: -1.43918
[00186] -- best objective: -0.71931 -- received objective: -1.45792
[00187] -- best objective: -0.71931 -- received objective: -1.42656
[00188] -- best objective: -0.71931 -- received objective: -1.44442
[00189] -- best objective: -0.71931 -- received objective: -1.44507
[00190] -- best objective: -0.71931 -- received objective: -1.44427
[00191] -- best objective: -0.71931 -- received objective: -1.44745
[00192] -- best objective: -0.71931 -- received objective: -1.46342
[00193] -- best objective: -0.71931 -- received objective: -1.42576
[00194] -- best objective: -0.71931 -- received objective: -1.43336
[00195] -- best objective: -0.71931 -- received objective: -1.42682
[00196] -- best objective: -0.71931 -- received objective: -3.14193
[00197] -- best objective: -0.71931 -- received objective: -1.42550
[00198] -- best objective: -0.71931 -- received objective: -1.43409
[00199] -- best objective: -0.71931 -- received objective: -1.44634
[00200] -- best objective: -0.71931 -- received objective: -1.44997
[00201] -- best objective: -0.71931 -- received objective: -1.61185
[00202] -- best objective: -0.71931 -- received objective: -1.56791
[00203] -- best objective: -0.71931 -- received objective: -1.44982
[00204] -- best objective: -0.71931 -- received objective: -1.42505
[00205] -- best objective: -0.71931 -- received objective: -1.43631
[00206] -- best objective: -0.71931 -- received objective: -0.84930
[00207] -- best objective: -0.71931 -- received objective: -1.45358
[00208] -- best objective: -0.71931 -- received objective: -1.45878
[00209] -- best objective: -0.71931 -- received objective: -1.43476
[00210] -- best objective: -0.71931 -- received objective: -1.44281
[00211] -- best objective: -0.71931 -- received objective: -1.12588
[00212] -- best objective: -0.71931 -- received objective: -1.42940
[00213] -- best objective: -0.71931 -- received objective: -1.44740
[00214] -- best objective: -0.71931 -- received objective: -1.42644
[00215] -- best objective: -0.71931 -- received objective: -1.45220
[00216] -- best objective: -0.71931 -- received objective: -0.92525
[00217] -- best objective: -0.71931 -- received objective: -1.44177
[00218] -- best objective: -0.71931 -- received objective: -1.44430
[00219] -- best objective: -0.71931 -- received objective: -1.44824
[00220] -- best objective: -0.71931 -- received objective: -1.45207
[00221] -- best objective: -0.71931 -- received objective: -1.51004
[00222] -- best objective: -0.71931 -- received objective: -1.45485
[00223] -- best objective: -0.71931 -- received objective: -1.42609
[00224] -- best objective: -0.71931 -- received objective: -0.93720
[00225] -- best objective: -0.71931 -- received objective: -1.43300
[00226] -- best objective: -0.71931 -- received objective: -1.43693
[00227] -- best objective: -0.71931 -- received objective: -1.46001
[00228] -- best objective: -0.71931 -- received objective: -1.41308
[00229] -- best objective: -0.71931 -- received objective: -1.41629
[00230] -- best objective: -0.71931 -- received objective: -1.34070
[00231] -- best objective: -0.71931 -- received objective: -1.41198
[00232] -- best objective: -0.71931 -- received objective: -1.43235
[00233] -- best objective: -0.71931 -- received objective: -1.41764
[00234] -- best objective: -0.71931 -- received objective: -1.47284
[00235] -- best objective: -0.71931 -- received objective: -1.39747
[00236] -- best objective: -0.71931 -- received objective: -1.45206
[00237] -- best objective: -0.71931 -- received objective: -1.42272
[00238] -- best objective: -0.71931 -- received objective: -0.85472
[00239] -- best objective: -0.71931 -- received objective: -2.61976
[00240] -- best objective: -0.71931 -- received objective: -1.45459
[00241] -- best objective: -0.71931 -- received objective: -1.46222
[00242] -- best objective: -0.71931 -- received objective: -0.91470
[00243] -- best objective: -0.71931 -- received objective: -1.45774
[00244] -- best objective: -0.71931 -- received objective: -1.42460
[00245] -- best objective: -0.71931 -- received objective: -0.89142
[00246] -- best objective: -0.71931 -- received objective: -1.42632
[00247] -- best objective: -0.71931 -- received objective: -1.45606
[00248] -- best objective: -0.71931 -- received objective: -1.42153
[00249] -- best objective: -0.71931 -- received objective: -1.43402
[00250] -- best objective: -0.71931 -- received objective: -1.45294
[00251] -- best objective: -0.71931 -- received objective: -1.42485
[00252] -- best objective: -0.71931 -- received objective: -1.44038
[00253] -- best objective: -0.71931 -- received objective: -1.45735
[00254] -- best objective: -0.71931 -- received objective: -1.44466
[00255] -- best objective: -0.71931 -- received objective: -1.43610
[00256] -- best objective: -0.71931 -- received objective: -1.40491
[00257] -- best objective: -0.71931 -- received objective: -1.44926
[00258] -- best objective: -0.71931 -- received objective: -1.45612
[00259] -- best objective: -0.71931 -- received objective: -1.39901
[00260] -- best objective: -0.71931 -- received objective: -1.45547
[00261] -- best objective: -0.71931 -- received objective: -1.41932
[00262] -- best objective: -0.71931 -- received objective: -1.42832
[00263] -- best objective: -0.71931 -- received objective: -1.44120
[00264] -- best objective: -0.71931 -- received objective: -1.43465
[00265] -- best objective: -0.71931 -- received objective: -1.43254
[00266] -- best objective: -0.71931 -- received objective: -1.45280
[00267] -- best objective: -0.71931 -- received objective: -1.46333
[00268] -- best objective: -0.71931 -- received objective: -1.45654
[00269] -- best objective: -0.71931 -- received objective: -1.41420
[00270] -- best objective: -0.71931 -- received objective: -1.45523
[00271] -- best objective: -0.71931 -- received objective: -1.44700
[00272] -- best objective: -0.71931 -- received objective: -1.45743
[00273] -- best objective: -0.71931 -- received objective: -1.45239
[00274] -- best objective: -0.71931 -- received objective: -1.21812
[00275] -- best objective: -0.71931 -- received objective: -1.43688
[00276] -- best objective: -0.71931 -- received objective: -1.44832
[00277] -- best objective: -0.71931 -- received objective: -1.44360
[00278] -- best objective: -0.71931 -- received objective: -1.45681
[00279] -- best objective: -0.71931 -- received objective: -1.40239
[00280] -- best objective: -0.71931 -- received objective: -1.35913
[00281] -- best objective: -0.71931 -- received objective: -1.33926
[00282] -- best objective: -0.71931 -- received objective: -1.43923
[00283] -- best objective: -0.71931 -- received objective: -1.47597
[00284] -- best objective: -0.71931 -- received objective: -1.17367
[00285] -- best objective: -0.71931 -- received objective: -1.38899
[00286] -- best objective: -0.71931 -- received objective: -1.43378
[00287] -- best objective: -0.70193 -- received objective: -0.70193
[00288] -- best objective: -0.70193 -- received objective: -1.43677
[00289] -- best objective: -0.70193 -- received objective: -1.41945
[00290] -- best objective: -0.70193 -- received objective: -1.45242
[00291] -- best objective: -0.70193 -- received objective: -1.45307
[00292] -- best objective: -0.70193 -- received objective: -1.46069
[00293] -- best objective: -0.70193 -- received objective: -0.78253
[00294] -- best objective: -0.70193 -- received objective: -1.42921
[00295] -- best objective: -0.70193 -- received objective: -1.43373
[00296] -- best objective: -0.70193 -- received objective: -1.44761
[00297] -- best objective: -0.70193 -- received objective: -1.43483
[00298] -- best objective: -0.70193 -- received objective: -1.57855
[00299] -- best objective: -0.70193 -- received objective: -1.44155
[00300] -- best objective: -0.70193 -- received objective: -0.78217
[00301] -- best objective: -0.70193 -- received objective: -1.44862
[00302] -- best objective: -0.70193 -- received objective: -1.45383
[00303] -- best objective: -0.70193 -- received objective: -1.41816
[00304] -- best objective: -0.70193 -- received objective: -1.44653
[00305] -- best objective: -0.70193 -- received objective: -1.45448
[00306] -- best objective: -0.70193 -- received objective: -1.44932
[00307] -- best objective: -0.70193 -- received objective: -1.45843
[00308] -- best objective: -0.70193 -- received objective: -1.45335
[00309] -- best objective: -0.70193 -- received objective: -1.44072
[00310] -- best objective: -0.70193 -- received objective: -1.45692
[00311] -- best objective: -0.70193 -- received objective: -1.45291
[00312] -- best objective: -0.70193 -- received objective: -1.44020
[00313] -- best objective: -0.70193 -- received objective: -1.43500
[00314] -- best objective: -0.70193 -- received objective: -1.43182
[00315] -- best objective: -0.70193 -- received objective: -1.05928
[00316] -- best objective: -0.70193 -- received objective: -1.44430
[00317] -- best objective: -0.70193 -- received objective: -1.42412
[00318] -- best objective: -0.70193 -- received objective: -1.44319
[00319] -- best objective: -0.70193 -- received objective: -1.44925
[00320] -- best objective: -0.70193 -- received objective: -1.45564
[00321] -- best objective: -0.70193 -- received objective: -1.45421
[00322] -- best objective: -0.70193 -- received objective: -1.45736
[00323] -- best objective: -0.70193 -- received objective: -1.45604
[00324] -- best objective: -0.70193 -- received objective: -1.45594
[00325] -- best objective: -0.70193 -- received objective: -1.45602
[00326] -- best objective: -0.70193 -- received objective: -1.45143
[00327] -- best objective: -0.70193 -- received objective: -3.49919
[00328] -- best objective: -0.70193 -- received objective: -1.40792
[00329] -- best objective: -0.70193 -- received objective: -1.44140
[00330] -- best objective: -0.70193 -- received objective: -1.44091
[00331] -- best objective: -0.70193 -- received objective: -1.44601
[00332] -- best objective: -0.70193 -- received objective: -1.42981
[00333] -- best objective: -0.70193 -- received objective: -1.42551
[00334] -- best objective: -0.70193 -- received objective: -1.42944
[00335] -- best objective: -0.70193 -- received objective: -1.45000
[00336] -- best objective: -0.70193 -- received objective: -1.44603
[00337] -- best objective: -0.70193 -- received objective: -1.26101
[00338] -- best objective: -0.70193 -- received objective: -1.46027
[00339] -- best objective: -0.70193 -- received objective: -0.96610
[00340] -- best objective: -0.70193 -- received objective: -1.45264
[00341] -- best objective: -0.70193 -- received objective: -1.42552
[00342] -- best objective: -0.70193 -- received objective: -1.42399
[00343] -- best objective: -0.70193 -- received objective: -1.45277
[00344] -- best objective: -0.70193 -- received objective: -1.45425
[00345] -- best objective: -0.70193 -- received objective: -1.43214
[00346] -- best objective: -0.70193 -- received objective: -1.45659
[00347] -- best objective: -0.70193 -- received objective: -1.42178
[00348] -- best objective: -0.70193 -- received objective: -0.83681
[00349] -- best objective: -0.70193 -- received objective: -1.44048
[00350] -- best objective: -0.70193 -- received objective: -1.43159
[00351] -- best objective: -0.70193 -- received objective: -1.45593
[00352] -- best objective: -0.70193 -- received objective: -0.95960
[00353] -- best objective: -0.70193 -- received objective: -1.45756
[00354] -- best objective: -0.70193 -- received objective: -3.91287
[00355] -- best objective: -0.70193 -- received objective: -1.45406
[00356] -- best objective: -0.70193 -- received objective: -1.45432
[00357] -- best objective: -0.70193 -- received objective: -1.42765
[00358] -- best objective: -0.70193 -- received objective: -1.43022
[00359] -- best objective: -0.70193 -- received objective: -1.45372
[00360] -- best objective: -0.70193 -- received objective: -1.45701
[00361] -- best objective: -0.70193 -- received objective: -1.30770
[00362] -- best objective: -0.70193 -- received objective: -1.44747
[00363] -- best objective: -0.70193 -- received objective: -1.43278
[00364] -- best objective: -0.70193 -- received objective: -1.33041
[00365] -- best objective: -0.70193 -- received objective: -1.37751
[00366] -- best objective: -0.70193 -- received objective: -1.44355
[00367] -- best objective: -0.70193 -- received objective: -1.46236
[00368] -- best objective: -0.70193 -- received objective: -1.49842
[00369] -- best objective: -0.70193 -- received objective: -1.51039
[00370] -- best objective: -0.70193 -- received objective: -1.43416
[00371] -- best objective: -0.70193 -- received objective: -1.43585
[00372] -- best objective: -0.70193 -- received objective: -1.44355
[00373] -- best objective: -0.70193 -- received objective: -1.42863
[00374] -- best objective: -0.70193 -- received objective: -1.45477
[00375] -- best objective: -0.70193 -- received objective: -0.79066
[00376] -- best objective: -0.70193 -- received objective: -0.89664
[00377] -- best objective: -0.70193 -- received objective: -1.43924
[00378] -- best objective: -0.70193 -- received objective: -1.42741
[00379] -- best objective: -0.70193 -- received objective: -1.43395
[00380] -- best objective: -0.70193 -- received objective: -1.44974
[00381] -- best objective: -0.70193 -- received objective: -1.43227
[00382] -- best objective: -0.70193 -- received objective: -1.42909
[00383] -- best objective: -0.70193 -- received objective: -1.44563
[00384] -- best objective: -0.70193 -- received objective: -1.45783
[00385] -- best objective: -0.70193 -- received objective: -1.20835
[00386] -- best objective: -0.70193 -- received objective: -1.45022
[00387] -- best objective: -0.70193 -- received objective: -1.46588
[00388] -- best objective: -0.70193 -- received objective: -1.42809
[00389] -- best objective: -0.70193 -- received objective: -1.42034
[00390] -- best objective: -0.70193 -- received objective: -1.43729
[00391] -- best objective: -0.70193 -- received objective: -1.43845
[00392] -- best objective: -0.70193 -- received objective: -1.45045
[00393] -- best objective: -0.70193 -- received objective: -1.45183
[00394] -- best objective: -0.70193 -- received objective: -340282346638528859811704183484516925440.00000
[00395] -- best objective: -0.70193 -- received objective: -1.45705
[00396] -- best objective: -0.70193 -- received objective: -1.22951
[00397] -- best objective: -0.70193 -- received objective: -1.30033
[00398] -- best objective: -0.70193 -- received objective: -1.41722
[00399] -- best objective: -0.70193 -- received objective: -1.44119
[00400] -- best objective: -0.70193 -- received objective: -1.43586
[00401] -- best objective: -0.70193 -- received objective: -1.44859
[00402] -- best objective: -0.70193 -- received objective: -1.42498
[00403] -- best objective: -0.70193 -- received objective: -1.42504
[00404] -- best objective: -0.70193 -- received objective: -1.43216
[00405] -- best objective: -0.70193 -- received objective: -1.42574
[00406] -- best objective: -0.70193 -- received objective: -1.45998
[00407] -- best objective: -0.70193 -- received objective: -1.44175
[00408] -- best objective: -0.70193 -- received objective: -1.41786
[00409] -- best objective: -0.70193 -- received objective: -1.41467
[00410] -- best objective: -0.70193 -- received objective: -1.43241
[00411] -- best objective: -0.70193 -- received objective: -1.44931
[00412] -- best objective: -0.70193 -- received objective: -1.47552
[00413] -- best objective: -0.70193 -- received objective: -1.42608
[00414] -- best objective: -0.70193 -- received objective: -1.45435
[00415] -- best objective: -0.70193 -- received objective: -0.73556
[00416] -- best objective: -0.70193 -- received objective: -1.45478
[00417] -- best objective: -0.70193 -- received objective: -1.47345
[00418] -- best objective: -0.70193 -- received objective: -1.41285
[00419] -- best objective: -0.70193 -- received objective: -1.44405
[00420] -- best objective: -0.70193 -- received objective: -1.45730
[00421] -- best objective: -0.70193 -- received objective: -1.45453
[00422] -- best objective: -0.70193 -- received objective: -1.45661
[00423] -- best objective: -0.70193 -- received objective: -1.43419
[00424] -- best objective: -0.70193 -- received objective: -1.39651
[00425] -- best objective: -0.70193 -- received objective: -1.43756
[00426] -- best objective: -0.70193 -- received objective: -1.43867
[00427] -- best objective: -0.70193 -- received objective: -1.45141
[00428] -- best objective: -0.70193 -- received objective: -1.44158
[00429] -- best objective: -0.70193 -- received objective: -1.44977
[00430] -- best objective: -0.70193 -- received objective: -1.45497
[00431] -- best objective: -0.70193 -- received objective: -1.21795
[00432] -- best objective: -0.70193 -- received objective: -1.31898
[00433] -- best objective: -0.70193 -- received objective: -1.43267
[00434] -- best objective: -0.70193 -- received objective: -1.45077
[00435] -- best objective: -0.70193 -- received objective: -1.44004
[00436] -- best objective: -0.70193 -- received objective: -1.44921
[00437] -- best objective: -0.70193 -- received objective: -1.43402
[00438] -- best objective: -0.70193 -- received objective: -1.45731
[00439] -- best objective: -0.70193 -- received objective: -1.43075
[00440] -- best objective: -0.70193 -- received objective: -1.45725
[00441] -- best objective: -0.70193 -- received objective: -1.45750
[00442] -- best objective: -0.70193 -- received objective: -1.45466
[00443] -- best objective: -0.70193 -- received objective: -1.45316
[00444] -- best objective: -0.70193 -- received objective: -1.41686
[00445] -- best objective: -0.70193 -- received objective: -1.46630
[00446] -- best objective: -0.70193 -- received objective: -1.22989
[00447] -- best objective: -0.70193 -- received objective: -1.39547
[00448] -- best objective: -0.67502 -- received objective: -0.67502
[00449] -- best objective: -0.67502 -- received objective: -1.39434
[00450] -- best objective: -0.67502 -- received objective: -1.45108
[00451] -- best objective: -0.67502 -- received objective: -1.45273
[00452] -- best objective: -0.67502 -- received objective: -1.45454
[00453] -- best objective: -0.67502 -- received objective: -1.38490
[00454] -- best objective: -0.67502 -- received objective: -1.44552
[00455] -- best objective: -0.67502 -- received objective: -1.42979
[00456] -- best objective: -0.67502 -- received objective: -1.43964
[00457] -- best objective: -0.67502 -- received objective: -1.44500
[00458] -- best objective: -0.67502 -- received objective: -1.45289
[00459] -- best objective: -0.67502 -- received objective: -1.42089
[00460] -- best objective: -0.67502 -- received objective: -1.44563
[00461] -- best objective: -0.67502 -- received objective: -1.39744
[00462] -- best objective: -0.67502 -- received objective: -0.70283
[00463] -- best objective: -0.67502 -- received objective: -0.73002
[00464] -- best objective: -0.67502 -- received objective: -1.42785
[00465] -- best objective: -0.67502 -- received objective: -1.42936
[00466] -- best objective: -0.67502 -- received objective: -1.44163
[00467] -- best objective: -0.67502 -- received objective: -1.44762
[00468] -- best objective: -0.67502 -- received objective: -1.42812
[00469] -- best objective: -0.67502 -- received objective: -1.45961
[00470] -- best objective: -0.67502 -- received objective: -1.40973
[00471] -- best objective: -0.67502 -- received objective: -1.42947
[00472] -- best objective: -0.67502 -- received objective: -1.36007
[00473] -- best objective: -0.67502 -- received objective: -1.45462
[00474] -- best objective: -0.67502 -- received objective: -1.45631
[00475] -- best objective: -0.67502 -- received objective: -1.45667
[00476] -- best objective: -0.67502 -- received objective: -1.42858
[00477] -- best objective: -0.67502 -- received objective: -1.38607
[00478] -- best objective: -0.67502 -- received objective: -1.34590
[00479] -- best objective: -0.67502 -- received objective: -1.41468
[00480] -- best objective: -0.67502 -- received objective: -1.43850
[00481] -- best objective: -0.67502 -- received objective: -1.46302
[00482] -- best objective: -0.67502 -- received objective: -1.45317
[00483] -- best objective: -0.67502 -- received objective: -1.45557
[00484] -- best objective: -0.67502 -- received objective: -1.43622
[00485] -- best objective: -0.67502 -- received objective: -1.45013
[00486] -- best objective: -0.67502 -- received objective: -1.40959
[00487] -- best objective: -0.67502 -- received objective: -1.43481
[00488] -- best objective: -0.67502 -- received objective: -1.45249
[00489] -- best objective: -0.67502 -- received objective: -1.39193
[00490] -- best objective: -0.67502 -- received objective: -1.86431
[00491] -- best objective: -0.67502 -- received objective: -1.41392
[00492] -- best objective: -0.67502 -- received objective: -1.42682
[00493] -- best objective: -0.67502 -- received objective: -1.43078
[00494] -- best objective: -0.67502 -- received objective: -1.43148
[00495] -- best objective: -0.67502 -- received objective: -1.45150
[00496] -- best objective: -0.67502 -- received objective: -1.43097
[00497] -- best objective: -0.67502 -- received objective: -1.09619
[00498] -- best objective: -0.67502 -- received objective: -1.41682
[00499] -- best objective: -0.67502 -- received objective: -0.72077
[00500] -- best objective: -0.67502 -- received objective: -1.44452
[41]:
results_uq["agebo"]
[41]:
arch_seq batch_size learning_rate optimizer patience_EarlyStopping patience_ReduceLROnPlateau id objective elapsed_sec duration
0 [109, 0, 45, 0, 0, 83, 1, 0, 0] 20 0.062255 adamax 27 18 2 -1.415458 7.840414 6.168005
1 [69, 1, 69, 0, 0, 94, 1, 1, 0] 5 0.000993 adamax 23 19 1 -1.453274 8.828585 7.156255
2 [39, 1, 95, 1, 0, 54, 1, 0, 0] 13 0.006326 adadelta 24 18 4 -1.560622 17.567594 15.895050
3 [131, 0, 61, 1, 1, 15, 1, 1, 0] 6 0.025969 adadelta 24 14 3 -1.613966 23.575635 21.903158
4 [91, 1, 50, 1, 0, 127, 0, 0, 1] 30 0.098383 adadelta 23 18 7 -1.439902 24.859868 6.820644
... ... ... ... ... ... ... ... ... ... ...
494 [29, 1, 34, 0, 1, 70, 0, 0, 0] 29 0.089742 adagrad 30 14 497 -1.451497 928.070709 3.169032
495 [87, 1, 34, 0, 1, 31, 1, 0, 0] 27 0.083251 adamax 26 19 499 -1.430971 933.139551 3.969134
496 [87, 1, 129, 1, 1, 31, 0, 0, 0] 23 0.084330 adamax 29 20 498 -1.096194 935.875403 9.437265
497 [6, 1, 41, 0, 1, 116, 1, 1, 0] 27 0.088917 adamax 30 15 501 -1.416822 940.937758 3.764933
498 [87, 1, 20, 1, 1, 31, 0, 0, 0] 21 0.086940 adamax 25 17 500 -0.720768 942.180480 7.843822

499 rows × 10 columns
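Since the objective is the negative validation loss (which is maximized), the best configurations found during the search can be inspected, for instance, with a standard pandas call (not part of the original notebook):

results_uq["agebo"].nlargest(5, "objective")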

Different ensemble sizes and two selection strategies can be experimented with when building ensembles:

* caruana: a greedy selection from Caruana et al., based on the loss function (sketched below).
* topk: selection of the top-k (i.e., size) best models based on the loss function.
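Below is a minimal sketch of the greedy selection idea, not the DeepHyper implementation: it assumes the candidates' predictions on the validation set are gathered in a list candidate_preds (a hypothetical name) and uses a simple mean squared error for illustration. Members are selected with replacement, which is why the same model file can appear several times in members_files in the output below.

import numpy as np

def greedy_caruana_selection(candidate_preds, y_true, loss_fn, size=5):
    # Iteratively add (with replacement) the candidate whose inclusion
    # most decreases the loss of the averaged ensemble prediction.
    selected = []
    for _ in range(size):
        best_idx, best_loss = None, float("inf")
        for idx, pred in enumerate(candidate_preds):
            ensemble_pred = np.mean(selected + [pred], axis=0)
            loss = loss_fn(y_true, ensemble_pred)
            if loss < best_loss:
                best_idx, best_loss = idx, loss
        selected.append(candidate_preds[best_idx])
    return selected

# For illustration, a simple mean squared error loss:
mse = lambda y, p: float(np.mean((y - p) ** 2))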

[42]:
ensemble = UQBaggingEnsembleRegressor(
    model_dir="agebo_search/save/model",
    loss=nll,  # default is nll
    size=5,
    verbose=True,
    ray_address="auto",
    num_cpus=1,
    num_gpus=1 if is_gpu_available else None,
    selection="caruana",
)

ensemble.fit(s_vx, s_vy)

print(f"Selected {len(ensemble.members_files)} members are: ", ensemble.members_files)
Selected 8 members are:  ['430.h5', '491.h5', '500.h5', '430.h5', '397.h5', '491.h5', '430.h5', '11.h5']
[43]:
pred_s_ty, pred_s_ty_aleatoric_var, pred_s_ty_epistemic_var = ensemble.predict_var_decomposition(s_tx)

pred_ty_mean = pred_s_ty.loc.numpy() * scaler_y.scale_ + scaler_y.mean_  # invert the standard scaling
pred_ty_aleatoric_var = pred_s_ty_aleatoric_var * scaler_y.var_
pred_ty_epistemic_var = pred_s_ty_epistemic_var * scaler_y.var_

width = 8
height = width/1.618
plt.figure(figsize=(width, height))

plt.plot(tx, ty, label="truth")
plt.plot(tx, pred_ty_mean, label="$\mu$")
plt.fill_between(
    tx.reshape(-1),
    (pred_ty_mean - pred_ty_aleatoric_var).reshape(-1),
    (pred_ty_mean + pred_ty_aleatoric_var).reshape(-1),
    color="yellow",
    alpha=0.5,
    label="$E[\sigma^2]$: aleatoric"
)
plt.fill_between(
    tx.reshape(-1),
    (pred_ty_mean - pred_ty_aleatoric_var).reshape(-1),
    (pred_ty_mean - pred_ty_aleatoric_var - pred_ty_epistemic_var).reshape(-1),
    color="orange",
    alpha=0.5,
    label="$V[\mu]$: epistemic"
)
plt.fill_between(
    tx.reshape(-1),
    (pred_ty_mean + pred_ty_aleatoric_var).reshape(-1),
    (pred_ty_mean + pred_ty_aleatoric_var + pred_ty_epistemic_var).reshape(-1),
    color="orange",
    alpha=0.5,
)

y_lim = 10
plt.fill_between([-30, -20], [-y_lim, -y_lim], [y_lim, y_lim], color="grey", alpha=0.5)
plt.fill_between([20, 30], [-y_lim, -y_lim], [y_lim, y_lim], color="grey", alpha=0.5)

plt.legend()
plt.ylim(-y_lim, y_lim)

plt.show()
[Figure: AutoDEUQ ensemble prediction with aleatoric and epistemic variance bands]

We notice that adding more diversity to the catalog of models can lead to much better epistemic uncertainty estimates.

3.13. Random Forest

What happens if we apply the same technique (i.e., the law of total variance) to the trees used to build a random forest regressor?

[94]:
import warnings

from skopt.learning import RandomForestRegressor

with warnings.catch_warnings():
    warnings.simplefilter("ignore")

    rf_model = RandomForestRegressor(criterion="mse")

    rf_model.fit(s_x, s_y)
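As a rough sketch (not exactly what skopt computes internally), the between-tree component of the forest's predictive variance, the analogue of the epistemic term \(V[\mu]\) above, can be read directly from the fitted trees; skopt's return_std additionally accounts for the within-tree (leaf impurity) variance. The names rf_mean and rf_between_tree_var below are illustrative:

# Per-tree predictions on the scaled test inputs: shape (n_trees, n_points).
tree_preds = np.stack([tree.predict(s_tx) for tree in rf_model.estimators_])

rf_mean = tree_preds.mean(axis=0)             # ensemble mean over trees
rf_between_tree_var = tree_preds.var(axis=0)  # epistemic analogue: variance of per-tree means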
[95]:
pred_s_ty, pred_s_ty_std = rf_model.predict(s_tx, return_std=True)
pred_s_ty_var = pred_s_ty_std ** 2
[96]:
pred_ty_mean = pred_s_ty * scaler_y.scale_ + scaler_y.mean_  # invert the standard scaling
pred_ty_var = pred_s_ty_var * scaler_y.var_

width = 8
height = width/1.618
plt.figure(figsize=(width, height))

plt.plot(tx, ty, label="truth")
plt.plot(tx, pred_ty_mean, label="$\mu$")
plt.fill_between(
    tx.reshape(-1),
    (pred_ty_mean - pred_ty_var).reshape(-1),
    (pred_ty_mean + pred_ty_var).reshape(-1),
    color="yellow",
    alpha=0.5,
    label="$\sigma^2$"
)

y_lim = 10
plt.fill_between([-30, -20], [-y_lim, -y_lim], [y_lim, y_lim], color="grey", alpha=0.5)
plt.fill_between([20, 30], [-y_lim, -y_lim], [y_lim, y_lim], color="grey", alpha=0.5)

plt.legend()
plt.ylim(-y_lim, y_lim)

plt.show()
[Figure: random forest mean and variance estimate over the test range]

We can see that the uncertainty estimate is really poor in this case.
