Neural Architecture Search and Deep Ensemble with Uncertainty Quantification for Regression (PyTorch)#

Author(s): Romain Egele, Brett Eiffert.

In this tutorial, you will learn how to perform Neural Architecture Search (NAS) and use it to construct a diverse deep ensemble with disentangled aleatoric and epistemic uncertainty.

NAS is the idea of automatically optimizing the architecture of deep neural networks to solve a given task. Here, we will use hyperparameter optimization (HPO) algorithms to guide the NAS process.

Specifically, in this tutorial you will learn how to:

  1. Define a customizable PyTorch module that exposes neural architecture hyperparameters.

  2. Define constraints on the neural architecture hyperparameters to reduce redundancies and improve efficiency of the optimization.

This tutorial will provide a hands-on approach to leveraging NAS for robust regression models with well-calibrated uncertainty estimates.
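
As background on what "disentangled" means here: for an ensemble of \(K\) models, each predicting a Gaussian \(\mathcal{N}(\mu_k(x), \sigma_k(x))\), a common mixture-based decomposition (stated here as a reminder, not as the tutorial's exact formulas) splits the predictive variance into an aleatoric and an epistemic part:

\[\sigma^2_\text{aleatoric}(x) = \frac{1}{K}\sum_{k=1}^{K} \sigma_k^2(x), \qquad \sigma^2_\text{epistemic}(x) = \frac{1}{K}\sum_{k=1}^{K} \left(\mu_k(x) - \bar{\mu}(x)\right)^2, \qquad \bar{\mu}(x) = \frac{1}{K}\sum_{k=1}^{K}\mu_k(x).\]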

Installation and imports#

Installing dependencies with pip is recommended. It requires Python >= 3.10.

%%bash
pip install "deephyper[ray,torch]"
Code (Import statements)
import json
import os
import pathlib

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from sklearn.model_selection import train_test_split
from tqdm import tqdm

WIDTH_PLOTS = 8
HEIGHT_PLOTS = WIDTH_PLOTS / 1.618

Synthetic data generation#

We generate synthetic data from a 1D scalar function \(Y = f(X) + \epsilon(X)\), where \(X,Y\) are random variables with support \(\mathbb{R}\).

The training data are drawn uniformly from \(X \sim U([-30,-15] \cup [15,30])\) with:

\[f(x) = \cos(x/2) + 2 \cdot \sin(x/10) + x/100\]

and \(\epsilon(X) \sim \mathcal{N}(0, \sigma(X))\) with:

  • \(\sigma(x) = 0.5\) if \(x \in [-30,-15]\)

  • \(\sigma(x) = 1.0\) if \(x \in [15,30]\)

Code (Loading synthetic data)
def load_data(
    development_size=500,
    test_size=200,
    random_state=42,
    x_min=-50,
    x_max=50,
):
    rs = np.random.RandomState(random_state)

    def f(x):
        return np.cos(x / 2) + 2 * np.sin(x / 10) + x / 100

    x_1 = rs.uniform(low=-30, high=-15.0, size=development_size // 2)
    eps_1 = rs.normal(loc=0.0, scale=0.5, size=development_size // 2)
    y_1 = f(x_1) + eps_1

    x_2 = rs.uniform(low=15.0, high=30.0, size=development_size // 2)
    eps_2 = rs.normal(loc=0.0, scale=1.0, size=development_size // 2)
    y_2 = f(x_2) + eps_2

    x = np.concatenate([x_1, x_2], axis=0)
    y = np.concatenate([y_1, y_2], axis=0)

    test_X = np.linspace(x_min, x_max, test_size)
    test_y = f(test_X)

    x = x.reshape(-1, 1)
    y = y.reshape(-1, 1)

    train_X, valid_X, train_y, valid_y = train_test_split(
        x, y, test_size=0.33, random_state=random_state
    )

    test_X = test_X.reshape(-1, 1)
    test_y = test_y.reshape(-1, 1)

    return (train_X, train_y), (valid_X, valid_y), (test_X, test_y)


(train_X, train_y), (valid_X, valid_y), (test_X, test_y) = load_data()

y_mu, y_std = np.mean(train_y), np.std(train_y)

x_lim, y_lim = 50, 7
_ = plt.figure(figsize=(WIDTH_PLOTS, HEIGHT_PLOTS))
_ = plt.scatter(train_X, train_y, s=5, label="Training")
_ = plt.scatter(valid_X, valid_y, s=5, label="Validation")
_ = plt.plot(test_X, test_y, linestyle="--", color="gray", label="Test")
_ = plt.fill_between([-30, -15], [-y_lim, -y_lim], [y_lim, y_lim], color="gray", alpha=0.25)
_ = plt.fill_between([15, 30], [-y_lim, -y_lim], [y_lim, y_lim], color="gray", alpha=0.25)
_ = plt.xlim(-x_lim, x_lim)
_ = plt.ylim(-y_lim, y_lim)
_ = plt.legend()
_ = plt.xlabel(r"$x$")
_ = plt.ylabel(r"$f(x)$")
_ = plt.grid(which="both", linestyle=":")
(Figure: training and validation samples, the test function \(f(x)\), and the shaded training regions.)

Configurable neural network with uncertainty#

We define a configurable PyTorch module so that we can explore:

  • the number of layers

  • the number of units per layer

  • the activation function per layer

  • the dropout rate

  • the output layer

The output of this module will be a Gaussian distribution \(\mathcal{N}(\mu_\theta(x), \sigma_\theta(x))\), where \(\theta\) represents the concatenation of the weights and the hyperparameters of our model.

The uncertainty \(\sigma_\theta(x)\) estimated by the network is an estimator of \(\sqrt{\mathbb{V}[Y|X=x]}\), the conditional standard deviation of \(Y\), and therefore captures the aleatoric uncertainty (a.k.a. intrinsic noise).

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


class DeepNormalRegressor(nn.Module):
    def __init__(
        self,
        n_inputs,
        layers,
        n_units_mean=64,
        n_units_std=64,
        std_offset=1e-3,
        softplus_factor=0.05,
        loc=0,
        scale=1.0,
    ):
        super().__init__()

        layers_ = []
        prev_n_units = n_inputs
        for n_units, activation, dropout_rate in layers:
            linear_layer = nn.Linear(prev_n_units, n_units)
            if activation == "relu":
                activation_layer = nn.ReLU()
            elif activation == "sigmoid":
                activation_layer = nn.Sigmoid()
            elif activation == "tanh":
                activation_layer = nn.Tanh()
            elif activation == "swish":
                activation_layer = nn.SiLU()
            elif activation == "mish":
                activation_layer = nn.Mish()
            elif activation == "gelu":
                activation_layer = nn.GELU()
            elif activation == "silu":
                activation_layer = nn.SiLU()
            dropout_layer = nn.Dropout(dropout_rate)

            layers_.extend([linear_layer, activation_layer, dropout_layer])

            prev_n_units = n_units

        # Shared parameters
        self.shared_layer = nn.Sequential(
            *layers_,
        )

        # Mean parameters
        self.mean_layer = nn.Sequential(
            nn.Linear(prev_n_units, n_units_mean),
            nn.ReLU(),
            nn.Linear(n_units_mean, 1),
        )

        # Standard deviation parameters
        self.std_layer = nn.Sequential(
            nn.Linear(prev_n_units, n_units_std),
            nn.ReLU(),
            nn.Linear(n_units_std, 1),
            nn.Softplus(beta=1.0, threshold=20.0),  # enforces positivity
        )

        self.std_offset = std_offset
        self.softplus_factor = softplus_factor
        self.loc = loc
        self.scale = scale

    def forward(self, x):
        # Shared embedding
        shared = self.shared_layer(x)

        # Parametrization of the mean
        mu = self.mean_layer(shared) + self.loc

        # Parametrization of the standard deviation
        sigma = self.std_offset + self.std_layer(self.softplus_factor * shared) * self.scale

        return torch.distributions.Normal(mu, sigma)
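
As a quick sanity check (a minimal sketch, not part of the tutorial's workflow), the module can be instantiated directly to inspect the distribution it returns; each layer tuple is (units, activation, dropout_rate):

toy_model = DeepNormalRegressor(n_inputs=1, layers=[(64, "relu", 0.0), (64, "tanh", 0.1)])
toy_x = torch.linspace(-1.0, 1.0, 5).reshape(-1, 1)
toy_dist = toy_model(toy_x)
# toy_dist is a torch.distributions.Normal with one (mu, sigma) pair per input
print(toy_dist.mean.shape, toy_dist.stddev.shape)   # torch.Size([5, 1]) torch.Size([5, 1])
print(toy_dist.log_prob(torch.zeros(5, 1)).shape)   # used by the NLL loss defined below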

Hyperparameter search space#

We define the hyperparameter space that includes both neural architecture and training hyperparameters.

Since we do not have a good heuristic for choosing training hyperparameters given a neural architecture, it is important to optimize them jointly with the neural architecture hyperparameters, as the two can interact strongly.

In the definition of the hyperparameter space, we add constraints using ConfigSpace.GreaterThanCondition to represent when a hyperparameter is active. In this example, “active” means that the hyperparameter actually influences the execution of the code that trains the model.
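
To see how such a condition behaves in isolation, here is a small ConfigSpace-only sketch (independent of DeepHyper; the exact ConfigSpace API may differ slightly between versions):

from ConfigSpace import ConfigurationSpace, GreaterThanCondition
from ConfigSpace.hyperparameters import UniformIntegerHyperparameter

cs = ConfigurationSpace(seed=42)
num_layers_hp = UniformIntegerHyperparameter("num_layers", 3, 8)
layer_3_units_hp = UniformIntegerHyperparameter("layer_3_units", 16, 512)
cs.add_hyperparameters([num_layers_hp, layer_3_units_hp])
# layer_3_units is only active (present in sampled configurations) when num_layers > 3
cs.add_condition(GreaterThanCondition(layer_3_units_hp, num_layers_hp, 3))

for config in cs.sample_configuration(5):
    print(dict(config))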

from ConfigSpace import GreaterThanCondition
from deephyper.hpo import HpProblem


def create_hpo_problem(min_num_layers=3, max_num_layers=8, max_num_units=512):
    problem = HpProblem()

    # Neural Architecture Hyperparameters
    num_layers = problem.add_hyperparameter((min_num_layers, max_num_layers), "num_layers", default_value=5)

    conditions = []
    for i in range(max_num_layers):

        # Adding the hyperparameters that impact each layer of the model
        layer_i_units = problem.add_hyperparameter((16, max_num_units), f"layer_{i}_units", default_value=max_num_units)
        layer_i_activation = problem.add_hyperparameter(
            ["relu", "sigmoid", "tanh", "swish", "mish", "gelu", "silu"],
            f"layer_{i}_activation",
            default_value="relu",
        )
        layer_i_dropout_rate = problem.add_hyperparameter(
            (0.0, 0.25), f"layer_{i}_dropout_rate", default_value=0.0
        )

        # Adding the constraints to define when these hyperparameters are active
        if i + 1 > min_num_layers:
            conditions.extend(
                [
                    GreaterThanCondition(layer_i_units, num_layers, i),
                    GreaterThanCondition(layer_i_activation, num_layers, i),
                    GreaterThanCondition(layer_i_dropout_rate, num_layers, i),
                ]
            )

    problem.add_conditions(conditions)

    # Hyperparameters of the output layers
    problem.add_hyperparameter((16, max_num_units), "n_units_mean", default_value=max_num_units)
    problem.add_hyperparameter((16, max_num_units), "n_units_std", default_value=max_num_units)
    problem.add_hyperparameter((1e-8, 1e-2, "log-uniform"), "std_offset", default_value=1e-3)
    problem.add_hyperparameter((0.01, 1.0), "softplus_factor", default_value=0.05)

    # Training Hyperparameters
    problem.add_hyperparameter((1e-5, 1e-1, "log-uniform"), "learning_rate", default_value=2e-3)
    problem.add_hyperparameter((8, 256, "log-uniform"), "batch_size", default_value=32)
    problem.add_hyperparameter((0.01, 0.99), "lr_scheduler_factor", default_value=0.1)
    problem.add_hyperparameter((10, 100), "lr_scheduler_patience", default_value=20)

    return problem

problem = create_hpo_problem()
problem
Configuration space object:
  Hyperparameters:
    batch_size, Type: UniformInteger, Range: [8, 256], Default: 32, on log-scale
    layer_0_activation, Type: Categorical, Choices: {relu, sigmoid, tanh, swish, mish, gelu, silu}, Default: relu
    layer_0_dropout_rate, Type: UniformFloat, Range: [0.0, 0.25], Default: 0.0
    layer_0_units, Type: UniformInteger, Range: [16, 512], Default: 512
    layer_1_activation, Type: Categorical, Choices: {relu, sigmoid, tanh, swish, mish, gelu, silu}, Default: relu
    layer_1_dropout_rate, Type: UniformFloat, Range: [0.0, 0.25], Default: 0.0
    layer_1_units, Type: UniformInteger, Range: [16, 512], Default: 512
    layer_2_activation, Type: Categorical, Choices: {relu, sigmoid, tanh, swish, mish, gelu, silu}, Default: relu
    layer_2_dropout_rate, Type: UniformFloat, Range: [0.0, 0.25], Default: 0.0
    layer_2_units, Type: UniformInteger, Range: [16, 512], Default: 512
    layer_3_activation, Type: Categorical, Choices: {relu, sigmoid, tanh, swish, mish, gelu, silu}, Default: relu
    layer_3_dropout_rate, Type: UniformFloat, Range: [0.0, 0.25], Default: 0.0
    layer_3_units, Type: UniformInteger, Range: [16, 512], Default: 512
    layer_4_activation, Type: Categorical, Choices: {relu, sigmoid, tanh, swish, mish, gelu, silu}, Default: relu
    layer_4_dropout_rate, Type: UniformFloat, Range: [0.0, 0.25], Default: 0.0
    layer_4_units, Type: UniformInteger, Range: [16, 512], Default: 512
    layer_5_activation, Type: Categorical, Choices: {relu, sigmoid, tanh, swish, mish, gelu, silu}, Default: relu
    layer_5_dropout_rate, Type: UniformFloat, Range: [0.0, 0.25], Default: 0.0
    layer_5_units, Type: UniformInteger, Range: [16, 512], Default: 512
    layer_6_activation, Type: Categorical, Choices: {relu, sigmoid, tanh, swish, mish, gelu, silu}, Default: relu
    layer_6_dropout_rate, Type: UniformFloat, Range: [0.0, 0.25], Default: 0.0
    layer_6_units, Type: UniformInteger, Range: [16, 512], Default: 512
    layer_7_activation, Type: Categorical, Choices: {relu, sigmoid, tanh, swish, mish, gelu, silu}, Default: relu
    layer_7_dropout_rate, Type: UniformFloat, Range: [0.0, 0.25], Default: 0.0
    layer_7_units, Type: UniformInteger, Range: [16, 512], Default: 512
    learning_rate, Type: UniformFloat, Range: [1e-05, 0.1], Default: 0.002, on log-scale
    lr_scheduler_factor, Type: UniformFloat, Range: [0.01, 0.99], Default: 0.1
    lr_scheduler_patience, Type: UniformInteger, Range: [10, 100], Default: 20
    n_units_mean, Type: UniformInteger, Range: [16, 512], Default: 512
    n_units_std, Type: UniformInteger, Range: [16, 512], Default: 512
    num_layers, Type: UniformInteger, Range: [3, 8], Default: 5
    softplus_factor, Type: UniformFloat, Range: [0.01, 1.0], Default: 0.05
    std_offset, Type: UniformFloat, Range: [1e-08, 0.01], Default: 0.001, on log-scale
  Conditions:
    layer_3_activation | num_layers > 3
    layer_3_dropout_rate | num_layers > 3
    layer_3_units | num_layers > 3
    layer_4_activation | num_layers > 4
    layer_4_dropout_rate | num_layers > 4
    layer_4_units | num_layers > 4
    layer_5_activation | num_layers > 5
    layer_5_dropout_rate | num_layers > 5
    layer_5_units | num_layers > 5
    layer_6_activation | num_layers > 6
    layer_6_dropout_rate | num_layers > 6
    layer_6_units | num_layers > 6
    layer_7_activation | num_layers > 7
    layer_7_dropout_rate | num_layers > 7
    layer_7_units | num_layers > 7

Loss and Metric#

For the loss we will use the Gaussian negative log-likelihood to evaluate the quality of the predicted distribution \(\mathcal{N}(\mu_\theta(x), \sigma_\theta(x))\) with the formula:

\[L_\text{NLL}(x, y;\theta) = \frac{1}{2}\left(\log\left(\sigma_\theta^{2}(x)\right) + \frac{\left(y-\mu_{\theta}(x)\right)^{2}}{\sigma_{\theta}^{2}(x)}\right) + \text{cst}\]

As a complementary metric, we use the squared error to evaluate the quality of the mean predictions \(\mu_\theta(x)\):

\[L_\text{SE}(x, y;\theta) = (\mu_\theta(x)-y)^2\]
def nll(y, rv_y):
    """Negative log likelihood for Pytorch distribution.

    Args:
        y: true data.
        rv_y: learned (predicted) probability distribution.
    """
    return -rv_y.log_prob(y)


def squared_error(y_true, rv_y):
    """Squared error for Pytorch distribution.

    Args:
        y: true data.
        rv_y: learned (predicted) probability distribution.
    """
    y_pred = rv_y.mean
    return (y_true - y_pred) ** 2
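
As a quick check that nll matches the closed-form expression above (a small illustrative snippet; the constant term is \(\frac{1}{2}\log(2\pi)\)):

import math

mu, sigma, y = torch.tensor(0.5), torch.tensor(2.0), torch.tensor(1.0)
rv = torch.distributions.Normal(mu, sigma)
lhs = nll(y, rv)
rhs = 0.5 * (torch.log(sigma**2) + (y - mu) ** 2 / sigma**2) + 0.5 * math.log(2 * math.pi)
print(torch.allclose(lhs, rhs))  # expected: True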

Training loop#

In our training loop, we make sure to collect training and validation learning curves for better analysis.

We also add a mechanism to checkpoint weights of the model based on the best observed validation loss.

Finally, we add an early stopping mechanism to save computing resources.

Code (Training loop)
def train_one_step(model, optimizer, x_batch, y_batch):
    model.train()
    optimizer.zero_grad()
    y_dist = model(x_batch)

    loss = torch.mean(nll(y_batch, y_dist))
    mse = torch.mean(squared_error(y_batch, y_dist))

    loss.backward()
    optimizer.step()

    return loss, mse


def train(
    job,
    model,
    optimizer,
    x_train,
    x_val,
    y_train,
    y_val,
    n_epochs,
    batch_size,
    scheduler=None,
    patience=200,
    progressbar=True,
):
    data_train = DataLoader(TensorDataset(x_train, y_train), batch_size=batch_size, shuffle=True)

    checkpointed_state_dict = model.state_dict()
    checkpointed_val_loss = np.inf

    train_loss, val_loss = [], []
    train_mse, val_mse = [], []

    tqdm_bar = tqdm(total=n_epochs, disable=not progressbar)

    for epoch in range(n_epochs):
        batch_losses_t, batch_losses_v, batch_mse_t, batch_mse_v = [], [], [], []

        for batch_x, batch_y in data_train:
            b_train_loss, b_train_mse = train_one_step(model, optimizer, batch_x, batch_y)

            model.eval()
            with torch.no_grad():
                y_dist = model(x_val)
                b_val_loss = torch.mean(nll(y_val, y_dist))
                b_val_mse = torch.mean(squared_error(y_val, y_dist))

            batch_losses_t.append(to_numpy(b_train_loss))
            batch_mse_t.append(to_numpy(b_train_mse))
            batch_losses_v.append(to_numpy(b_val_loss))
            batch_mse_v.append(to_numpy(b_val_mse))

        train_loss.append(np.mean(batch_losses_t))
        val_loss.append(np.mean(batch_losses_v))
        train_mse.append(np.mean(batch_mse_t))
        val_mse.append(np.mean(batch_mse_v))

        if scheduler is not None:
            scheduler.step(val_loss[-1])

        tqdm_bar.update(1)
        tqdm_bar.set_postfix(
            {
                "train_loss": f"{train_loss[-1]:.3f}",
                "val_loss": f"{val_loss[-1]:.3f}",
                "train_mse": f"{train_mse[-1]:.3f}",
                "val_mse": f"{val_mse[-1]:.3f}",
            }
        )

        # Checkpoint weights if they improve
        if val_loss[-1] < checkpointed_val_loss:
            checkpointed_val_loss = val_loss[-1]
            checkpointed_state_dict = model.state_dict()

        # Early discarding
        job.record(budget=epoch+1, objective=-val_loss[-1])
        if job.stopped():
            break

        if len(val_loss) > (patience + 1) and val_loss[-patience - 1] < min(val_loss[-patience:]):
            break

    # Reload the best weights
    model.load_state_dict(checkpointed_state_dict)

    return train_loss, val_loss, train_mse, val_mse

Run time#

import multiprocessing

dtype = torch.float32
if torch.cuda.is_available():
    device = "cuda"
    device_count = 1
else:
    device = "cpu"
    device_count = multiprocessing.cpu_count()

print(f"Runtime with {device=}, {device_count=}, {dtype=}")
Runtime with device='cpu', device_count=10, dtype=torch.float32
Code (Conversion utility functions)
def to_torch(array):
    return torch.from_numpy(array).to(device=device, dtype=dtype)

def to_numpy(tensor):
    return tensor.detach().cpu().numpy()

Evaluation function#

We start by defining a function that creates the PyTorch module from a dictionary of hyperparameters.

def create_model(parameters: dict, y_mu=0, y_std=1):
    num_layers = parameters["num_layers"]
    torch_module = DeepNormalRegressor(
        n_inputs=1,
        layers=[
            (
                parameters[f"layer_{i}_units"],
                parameters[f"layer_{i}_activation"],
                parameters[f"layer_{i}_dropout_rate"],
            )
            for i in range(num_layers)
        ],
        n_units_mean=parameters["n_units_mean"],
        n_units_std=parameters["n_units_std"],
        std_offset=parameters["std_offset"],
        softplus_factor=parameters["softplus_factor"],
        loc=y_mu,
        scale=y_std,
    ).to(device=device, dtype=dtype)
    return torch_module

The evaluation function (often called the run-function in DeepHyper) is the function that receives the suggested parameters as input through job.parameters and returns an "objective" that we want to maximize.

max_n_epochs = 1_000


def run(job, model_checkpoint_dir=".", verbose=False):
    (x, y), (vx, vy), (tx, ty) = load_data()

    # Create the model based on neural architecture hyperparameters
    model = create_model(job.parameters, y_mu, y_std)

    if verbose:
        print(model)

    # Initialize training loop based on training hyperparameters
    optimizer = torch.optim.Adam(model.parameters(), lr=job.parameters["learning_rate"])
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        factor=job.parameters["lr_scheduler_factor"],
        patience=job.parameters["lr_scheduler_patience"],
    )

    x, vx, tx = to_torch(x), to_torch(vx), to_torch(tx)
    y, vy, ty = to_torch(y), to_torch(vy), to_torch(ty)

    try:
        train_losses, val_losses, train_mse, val_mse = train(
            job,
            model,
            optimizer,
            x,
            vx,
            y,
            vy,
            n_epochs=max_n_epochs,
            batch_size=job.parameters["batch_size"],
            scheduler=scheduler,
            progressbar=verbose,
        )
    except Exception:
        return "F_fit"

    ty_pred = model(tx)
    test_loss = to_numpy(torch.mean(nll(ty, ty_pred)))
    test_mse = to_numpy(torch.mean(squared_error(ty, ty_pred)))

    # Saving the model's state (i.e., weights)
    torch.save(model.state_dict(), os.path.join(model_checkpoint_dir, f"model_{job.id}.pt"))

    return {
        "objective": -val_losses[-1],
        "metadata": {
            "train_loss": train_losses,
            "val_loss": val_losses,
            "train_mse": train_mse,
            "val_mse": val_mse,
            "test_loss": test_loss,
            "test_mse": test_mse,
            "budget": len(val_losses),
        },
    }
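
For reference, this run-function is designed to be passed to a DeepHyper search. Below is a hedged sketch of how that typically looks (not the tutorial's exact search setup; Evaluator and CBO are standard DeepHyper classes, but argument defaults may differ across versions):

from deephyper.evaluator import Evaluator
from deephyper.hpo import CBO

# Run evaluations in parallel threads, one worker per available CPU (or a single GPU worker).
evaluator = Evaluator.create(run, method="thread", method_kwargs={"num_workers": device_count})

# Centralized Bayesian Optimization over the problem defined earlier.
search = CBO(problem, evaluator, random_state=42)
results = search.search(max_evals=50)  # returns a pandas DataFrame of evaluated configurations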

Evaluation of the baseline#

We evaluate the default hyperparameter configuration, which we call the “baseline”, using the evaluation function defined above. This also allows us to test the evaluation function.

from deephyper.evaluator import RunningJob

baseline_dir = "nas_baseline_regression"

def evaluate_baseline(problem):
    model_checkpoint_dir = os.path.join(baseline_dir, "models")
    pathlib.Path(model_checkpoint_dir).mkdir(parents=True, exist_ok=True)

    default_parameters = problem.default_configuration
    print(f"{default_parameters=}\n")

    result = run(
        RunningJob(parameters=default_parameters),
        model_checkpoint_dir=model_checkpoint_dir,
        verbose=True,
    )
    return result

baseline_results = evaluate_baseline(problem)
default_parameters={'batch_size': 32, 'layer_0_activation': 'relu', 'layer_0_dropout_rate': 0.0, 'layer_0_units': 512, 'layer_1_activation': 'relu', 'layer_1_dropout_rate': 0.0, 'layer_1_units': 512, 'layer_2_activation': 'relu', 'layer_2_dropout_rate': 0.0, 'layer_2_units': 512, 'learning_rate': 0.002, 'lr_scheduler_factor': 0.1, 'lr_scheduler_patience': 20, 'n_units_mean': 512, 'n_units_std': 512, 'num_layers': 5, 'softplus_factor': 0.05, 'std_offset': 0.001, 'layer_3_activation': 'relu', 'layer_3_dropout_rate': 0.0, 'layer_3_units': 512, 'layer_4_activation': 'relu', 'layer_4_dropout_rate': 0.0, 'layer_4_units': 512, 'layer_5_activation': 'relu', 'layer_5_dropout_rate': 0.0, 'layer_5_units': 16, 'layer_6_activation': 'relu', 'layer_6_dropout_rate': 0.0, 'layer_6_units': 16, 'layer_7_activation': 'relu', 'layer_7_dropout_rate': 0.0, 'layer_7_units': 16}

DeepNormalRegressor(
  (shared_layer): Sequential(
    (0): Linear(in_features=1, out_features=512, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.0, inplace=False)
    (3): Linear(in_features=512, out_features=512, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.0, inplace=False)
    (6): Linear(in_features=512, out_features=512, bias=True)
    (7): ReLU()
    (8): Dropout(p=0.0, inplace=False)
    (9): Linear(in_features=512, out_features=512, bias=True)
    (10): ReLU()
    (11): Dropout(p=0.0, inplace=False)
    (12): Linear(in_features=512, out_features=512, bias=True)
    (13): ReLU()
    (14): Dropout(p=0.0, inplace=False)
  )
  (mean_layer): Sequential(
    (0): Linear(in_features=512, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=1, bias=True)
  )
  (std_layer): Sequential(
    (0): Linear(in_features=512, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=1, bias=True)
    (3): Softplus(beta=1.0, threshold=20.0)
  )
)

  0%|          | 0/1000 [00:00<?, ?it/s]
  0%|          | 1/1000 [00:00<01:15, 13.31it/s, train_loss=3.878, val_loss=3.780, train_mse=11.845, val_mse=11.397]
  ...
 22%|██▏       | 215/1000 [00:13<00:47, 16.37it/s, train_loss=1.522, val_loss=1.631, train_mse=1.229, val_mse=1.493]
 22%|██▏       | 216/1000 [00:13<00:48, 16.14it/s, train_loss=1.522, val_loss=1.631, train_mse=1.229, val_mse=1.493]
 22%|██▏       | 216/1000 [00:13<00:48, 16.14it/s, train_loss=1.508, val_loss=1.631, train_mse=1.196, val_mse=1.493]
 22%|██▏       | 217/1000 [00:13<00:48, 16.14it/s, train_loss=1.530, val_loss=1.631, train_mse=1.248, val_mse=1.493]
 22%|██▏       | 218/1000 [00:13<00:47, 16.36it/s, train_loss=1.530, val_loss=1.631, train_mse=1.248, val_mse=1.493]
 22%|██▏       | 218/1000 [00:13<00:47, 16.36it/s, train_loss=1.504, val_loss=1.631, train_mse=1.186, val_mse=1.493]
 22%|██▏       | 219/1000 [00:13<00:47, 16.36it/s, train_loss=1.521, val_loss=1.631, train_mse=1.226, val_mse=1.493]
 22%|██▏       | 220/1000 [00:13<00:48, 16.11it/s, train_loss=1.521, val_loss=1.631, train_mse=1.226, val_mse=1.493]
 22%|██▏       | 220/1000 [00:13<00:48, 16.11it/s, train_loss=1.518, val_loss=1.631, train_mse=1.219, val_mse=1.493]
 22%|██▏       | 221/1000 [00:13<00:48, 16.11it/s, train_loss=1.510, val_loss=1.631, train_mse=1.198, val_mse=1.493]
 22%|██▏       | 222/1000 [00:13<00:47, 16.41it/s, train_loss=1.510, val_loss=1.631, train_mse=1.198, val_mse=1.493]
 22%|██▏       | 222/1000 [00:13<00:47, 16.41it/s, train_loss=1.502, val_loss=1.631, train_mse=1.179, val_mse=1.493]
 22%|██▏       | 223/1000 [00:13<00:47, 16.41it/s, train_loss=1.514, val_loss=1.631, train_mse=1.210, val_mse=1.493]
 22%|██▏       | 224/1000 [00:13<00:47, 16.45it/s, train_loss=1.514, val_loss=1.631, train_mse=1.210, val_mse=1.493]
 22%|██▏       | 224/1000 [00:13<00:47, 16.45it/s, train_loss=1.515, val_loss=1.631, train_mse=1.211, val_mse=1.493]
 22%|██▎       | 225/1000 [00:13<00:47, 16.45it/s, train_loss=1.513, val_loss=1.631, train_mse=1.208, val_mse=1.493]
 23%|██▎       | 226/1000 [00:14<00:47, 16.21it/s, train_loss=1.513, val_loss=1.631, train_mse=1.208, val_mse=1.493]
 23%|██▎       | 226/1000 [00:14<00:47, 16.21it/s, train_loss=1.518, val_loss=1.631, train_mse=1.218, val_mse=1.493]
 23%|██▎       | 227/1000 [00:14<00:47, 16.21it/s, train_loss=1.517, val_loss=1.631, train_mse=1.217, val_mse=1.493]
 23%|██▎       | 228/1000 [00:14<00:47, 16.29it/s, train_loss=1.517, val_loss=1.631, train_mse=1.217, val_mse=1.493]
 23%|██▎       | 228/1000 [00:14<00:47, 16.29it/s, train_loss=1.522, val_loss=1.631, train_mse=1.228, val_mse=1.493]
 23%|██▎       | 229/1000 [00:14<00:47, 16.29it/s, train_loss=1.527, val_loss=1.631, train_mse=1.241, val_mse=1.493]
 23%|██▎       | 230/1000 [00:14<00:49, 15.52it/s, train_loss=1.527, val_loss=1.631, train_mse=1.241, val_mse=1.493]
 23%|██▎       | 230/1000 [00:14<00:49, 15.52it/s, train_loss=1.532, val_loss=1.631, train_mse=1.253, val_mse=1.493]
 23%|██▎       | 231/1000 [00:14<00:49, 15.52it/s, train_loss=1.508, val_loss=1.631, train_mse=1.196, val_mse=1.493]
 23%|██▎       | 232/1000 [00:14<00:48, 15.94it/s, train_loss=1.508, val_loss=1.631, train_mse=1.196, val_mse=1.493]
 23%|██▎       | 232/1000 [00:14<00:48, 15.94it/s, train_loss=1.532, val_loss=1.631, train_mse=1.253, val_mse=1.493]
 23%|██▎       | 233/1000 [00:14<00:48, 15.94it/s, train_loss=1.519, val_loss=1.631, train_mse=1.221, val_mse=1.493]
 23%|██▎       | 234/1000 [00:14<00:47, 15.98it/s, train_loss=1.519, val_loss=1.631, train_mse=1.221, val_mse=1.493]
 23%|██▎       | 234/1000 [00:14<00:47, 15.98it/s, train_loss=1.520, val_loss=1.631, train_mse=1.223, val_mse=1.493]
 24%|██▎       | 235/1000 [00:14<00:47, 15.98it/s, train_loss=1.519, val_loss=1.631, train_mse=1.222, val_mse=1.493]
 24%|██▎       | 236/1000 [00:14<00:47, 15.99it/s, train_loss=1.519, val_loss=1.631, train_mse=1.222, val_mse=1.493]
 24%|██▎       | 236/1000 [00:14<00:47, 15.99it/s, train_loss=1.526, val_loss=1.631, train_mse=1.238, val_mse=1.493]
 24%|██▎       | 237/1000 [00:14<00:47, 15.99it/s, train_loss=1.514, val_loss=1.631, train_mse=1.209, val_mse=1.493]
 24%|██▍       | 238/1000 [00:14<00:51, 14.89it/s, train_loss=1.514, val_loss=1.631, train_mse=1.209, val_mse=1.493]
 24%|██▍       | 238/1000 [00:14<00:51, 14.89it/s, train_loss=1.505, val_loss=1.631, train_mse=1.187, val_mse=1.493]
 24%|██▍       | 239/1000 [00:14<00:51, 14.89it/s, train_loss=1.522, val_loss=1.631, train_mse=1.230, val_mse=1.493]
 24%|██▍       | 239/1000 [00:14<00:47, 16.08it/s, train_loss=1.522, val_loss=1.631, train_mse=1.230, val_mse=1.493]

Then, we look at the learning curves of our baseline model returned by the evaluation function.

These curves display a good learning behaviour:

  • The training and validation losses follow each other closely and decrease steadily.

  • A clear convergence plateau is reached at the end of training.

Code (Make learning curves plot)
_ = plt.figure(figsize=(WIDTH_PLOTS, HEIGHT_PLOTS))

x_values = np.arange(1, len(baseline_results["metadata"]["train_loss"]) + 1)
_ = plt.plot(
    x_values,
    baseline_results["metadata"]["train_loss"],
    label="Training",
)
_ = plt.plot(
    x_values,
    baseline_results["metadata"]["val_loss"],
    label="Validation",
)

_ = plt.xlim(x_values.min(), x_values.max())
_ = plt.grid(which="both", linestyle=":")
_ = plt.legend()
_ = plt.xlabel("Epochs")
_ = plt.ylabel("NLL")
[Figure: learning curves (NLL vs. epochs) of the baseline model]

In addition, we look at the predictions by reloading the checkpointed weights.

We first need to recreate the torch module and then update its state with the checkpointed weights, as sketched below.
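
The following is a minimal sketch of this step. The names baseline_parameters and baseline_weights_path are hypothetical placeholders for the baseline hyperparameters and checkpoint path produced earlier; create_model, to_torch, to_numpy, device and dtype are the helpers defined earlier in this tutorial.

import torch

# Recreate the module with the same hyperparameters as the baseline run.
baseline_module = create_model(baseline_parameters, y_mu, y_std).to(device=device, dtype=dtype)

# Update its state from the checkpointed weights.
baseline_module.load_state_dict(torch.load(baseline_weights_path, weights_only=True))
baseline_module.eval()

# The module outputs a Normal distribution: extract its mean and standard deviation.
with torch.no_grad():
    y_dist = baseline_module(to_torch(test_X))
    y_pred_mean = to_numpy(y_dist.loc)
    y_pred_std = to_numpy(y_dist.scale)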

Code (Make prediction plot)
_ = plt.figure(figsize=(WIDTH_PLOTS, HEIGHT_PLOTS))
_ = plt.scatter(train_X, train_y, s=5, label="Training")
_ = plt.scatter(valid_X, valid_y, s=5, label="Validation")
_ = plt.plot(test_X, test_y, linestyle="--", color="gray", label="Test")

_ = plt.plot(test_X, y_pred_mean, label=r"$\mu(x)$")
kappa = 1.96
_ = plt.fill_between(
    test_X.reshape(-1),
    (y_pred_mean - kappa * y_pred_std).reshape(-1),
    (y_pred_mean + kappa * y_pred_std).reshape(-1),
    alpha=0.25,
    label=r"$\sigma_\text{al}(x)$",
)

_ = plt.fill_between([-30, -15], [-y_lim, -y_lim], [y_lim, y_lim], color="gray", alpha=0.15)
_ = plt.fill_between([15, 30], [-y_lim, -y_lim], [y_lim, y_lim], color="gray", alpha=0.15)
_ = plt.xlim(-x_lim, x_lim)
_ = plt.ylim(-y_lim, y_lim)
_ = plt.legend(ncols=2)
_ = plt.xlabel(r"$x$")
_ = plt.ylabel(r"$f(x)$")
_ = plt.grid(which="both", linestyle=":")
[Figure: baseline model predictions with the aleatoric uncertainty band]

Analysis of the results#

We will now look at the results of the search globally, in terms of the evolution of the objective and of the workers' activity.

from deephyper.analysis.hpo import plot_search_trajectory_single_objective_hpo
from deephyper.analysis.hpo import plot_worker_utilization


fig, axes = plt.subplots(
    nrows=2,
    ncols=1,
    sharex=True,
    figsize=(WIDTH_PLOTS, HEIGHT_PLOTS),
)

_ = plot_search_trajectory_single_objective_hpo(
    hpo_results,
    mode="min",
    x_units="seconds",
    ax=axes[0],
)
axes[0].set_yscale("log")

_ = plot_worker_utilization(
    hpo_results,
    profile_type="submit/gather",
    ax=axes[1],
)

plt.tight_layout()
[Figure: search trajectory of the objective (top) and worker utilization (bottom)]

Then, we split the results into successful and failed evaluations, if there are any; a minimal way to do this is sketched below.
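
A minimal sketch of this split, assuming the DeepHyper convention that failed evaluations are reported with a non-numeric "objective" (e.g., a string starting with "F"):

# Separate successful from failed evaluations.
mask_failed = hpo_results["objective"].apply(
    lambda o: isinstance(o, str) and o.startswith("F")
)
hpo_results_failed = hpo_results[mask_failed]
hpo_results_success = hpo_results[~mask_failed].copy()
hpo_results_success["objective"] = hpo_results_success["objective"].astype(float)

print(f"{len(hpo_results_success)} successful / {len(hpo_results_failed)} failed evaluations")
hpo_results_success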

p:batch_size p:layer_0_activation p:layer_0_dropout_rate p:layer_0_units p:layer_1_activation p:layer_1_dropout_rate p:layer_1_units p:layer_2_activation p:layer_2_dropout_rate p:layer_2_units p:learning_rate p:lr_scheduler_factor p:lr_scheduler_patience p:n_units_mean p:n_units_std p:num_layers p:softplus_factor p:std_offset p:layer_3_activation p:layer_3_dropout_rate p:layer_3_units p:layer_4_activation p:layer_4_dropout_rate p:layer_4_units p:layer_5_activation p:layer_5_dropout_rate p:layer_5_units p:layer_6_activation p:layer_6_dropout_rate p:layer_6_units p:layer_7_activation p:layer_7_dropout_rate p:layer_7_units objective job_id job_status m:timestamp_submit m:train_loss m:val_loss m:train_mse m:val_mse m:test_loss m:test_mse m:budget m:timestamp_gather
0 68 silu 0.122704 105 mish 0.171507 130 silu 0.025040 258 0.000051 0.726719 31 29 320 5 0.071061 9.339449e-03 tanh 0.039217 149 gelu 0.192898 228 relu 0.000000 16 relu 0.00000 16 relu 0.000000 16 -1.623808 5 DONE 3.079311 [2.2356727, 2.1990643, 2.1611142, 2.1306512, 2... [2.3715773, 2.321428, 2.2811942, 2.244391, 2.2... [3.729117, 3.6063304, 3.473826, 3.3686516, 3.2... [4.2356567, 4.062816, 3.9246833, 3.79775, 3.67... 2.311146 3.122602e+00 247.0 13.158018
1 70 sigmoid 0.181926 474 relu 0.206947 102 gelu 0.077514 102 0.025985 0.865468 44 34 72 6 0.295450 3.219338e-06 tanh 0.013806 31 sigmoid 0.084893 133 swish 0.041425 342 relu 0.00000 16 relu 0.000000 16 -1.632331 2 DONE 3.075146 [6.5537024, 1.9579012, 1.665855, 1.6295078, 1.... [6.8209295, 1.9405267, 1.7315395, 1.684643, 1.... [13.900175, 2.9138274, 1.5408843, 1.4303554, 1... [14.717955, 2.8753903, 1.622345, 1.6306633, 1.... 2.433584 3.431710e+00 288.0 15.570797
2 97 mish 0.092652 353 mish 0.160574 209 gelu 0.187758 473 0.011654 0.285402 67 138 352 4 0.587813 4.738140e-08 swish 0.231107 320 relu 0.000000 16 relu 0.000000 16 relu 0.00000 16 relu 0.000000 16 -10.482303 8 DONE 3.083357 [4936380000000000.0, 5.681423, 6.664331, 45.25... [1811503100000000.0, 6.081249, 6.5736256, 6.43... [76.39607, 1464.7188, 3686.473, 2281.423, 3065... [162.81638, 1842.9956, 2423.849, 409.41083, 35... 10.246502 1.667279e+09 202.0 15.596288
3 119 tanh 0.001329 69 swish 0.184957 295 relu 0.176468 307 0.000262 0.205413 91 225 133 3 0.745321 5.758780e-08 relu 0.000000 16 relu 0.000000 16 relu 0.000000 16 relu 0.00000 16 relu 0.000000 16 -1.137323 1 DONE 3.073541 [2.2547195, 2.105179, 1.9712442, 1.85241, 1.76... [2.3378735, 2.1741226, 2.0298746, 1.9063791, 1... [3.66675, 3.2829812, 2.8488493, 2.375151, 1.89... [4.025764, 3.5948925, 3.1349583, 2.635012, 2.1... 5.430150 2.168292e+00 1000.0 21.298497
4 212 swish 0.080304 108 swish 0.133348 180 swish 0.214755 499 0.000073 0.744227 60 184 462 8 0.236448 1.947248e-05 mish 0.094552 398 relu 0.020214 69 silu 0.165548 130 mish 0.24868 403 gelu 0.175789 333 -2.191429 10 DONE 15.566050 [2.2363968, 2.2094727, 2.1878111, 2.2310803, 2... [2.36148, 2.3554432, 2.348463, 2.3400981, 2.32... [3.7908025, 3.6970637, 3.6223884, 3.7993572, 3... [4.2729263, 4.2616057, 4.2462792, 4.2252903, 4... 3.207318 2.229017e+00 340.0 29.968976
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
248 12 tanh 0.165450 455 relu 0.017045 352 swish 0.154026 236 0.000677 0.900386 75 156 97 3 0.654945 2.088938e-07 relu 0.000000 16 relu 0.000000 16 relu 0.000000 16 relu 0.00000 16 relu 0.000000 16 -1.192449 239 DONE 1165.378879 [1.6727325, 1.5800053, 1.6157894, 1.5432746, 1... [1.7266064, 1.6549356, 1.7140119, 1.637512, 1.... [1.6016694, 1.3569081, 1.4258856, 1.2578446, 1... [1.813346, 1.590621, 1.7395469, 1.5146513, 1.6... 5.439753 2.908574e+00 760.0 1245.544874
249 79 tanh 0.161465 364 silu 0.056837 350 sigmoid 0.192261 415 0.000176 0.738764 65 510 105 3 0.468523 2.400839e-04 relu 0.000000 16 relu 0.000000 16 relu 0.000000 16 relu 0.00000 16 relu 0.000000 16 -1.203743 250 DONE 1212.161439 [2.122525, 1.9694458, 1.8289667, 1.7220303, 1.... [2.1825612, 1.9890454, 1.8536297, 1.7302277, 1... [3.5389302, 3.000673, 2.0841455, 1.5233757, 1.... [3.9003167, 3.1071992, 2.2470388, 1.6122707, 1... 4.303766 3.376671e+00 815.0 1245.564055
250 10 tanh 0.241329 460 tanh 0.017031 492 silu 0.212373 431 0.000131 0.949106 96 453 127 4 0.477518 1.958822e-07 relu 0.116858 126 relu 0.000000 16 relu 0.000000 16 relu 0.00000 16 relu 0.000000 16 -1.187596 240 DONE 1165.380704 [1.8427774, 1.5654259, 1.54907, 1.5718877, 1.5... [1.91599, 1.630422, 1.6431626, 1.6635764, 1.63... [2.2706072, 1.2972956, 1.2960027, 1.3628039, 1... [2.5700068, 1.5153091, 1.5366507, 1.5885208, 1... 7.368615 2.003667e+00 989.0 1313.699958
251 10 tanh 0.070972 486 silu 0.023657 278 gelu 0.046773 225 0.000455 0.443645 89 78 34 3 0.587185 4.961249e-04 relu 0.000000 16 relu 0.000000 16 relu 0.000000 16 relu 0.00000 16 relu 0.000000 16 -1.159050 252 DONE 1245.439769 [1.6650637, 1.5602826, 1.5515006, 1.5409552, 1... [1.7371452, 1.649412, 1.6291426, 1.6426994, 1.... [1.6652527, 1.2852505, 1.2749716, 1.2613335, 1... [1.898218, 1.5504943, 1.506757, 1.5249182, 1.4... 1632.896000 5.692229e+00 849.0 1313.717623
252 29 tanh 0.066118 491 silu 0.022694 310 gelu 0.141575 376 0.000453 0.294295 92 95 55 3 0.044767 2.045635e-04 relu 0.000000 16 relu 0.000000 16 relu 0.000000 16 relu 0.00000 16 relu 0.000000 16 -1.122970 251 DONE 1231.533599 [1.7851361, 1.5758463, 1.5461717, 1.541927, 1.... [1.8419315, 1.6515197, 1.6262579, 1.624602, 1.... [2.0054357, 1.3500684, 1.2656512, 1.2646219, 1... [2.185182, 1.5942222, 1.5158734, 1.5109658, 1.... 5.322379 3.880587e+00 890.0 1313.733935

251 rows × 45 columns



We look at the learning curves of the best model and observe improvements in both training and validation loss:

Code (Make learning curves plot)
x_values = np.arange(1, len(baseline_results["metadata"]["train_loss"]) + 1)
x_min, x_max = x_values.min(), x_values.max()
_ = plt.figure(figsize=(WIDTH_PLOTS, HEIGHT_PLOTS))
_ = plt.plot(
    x_values,
    baseline_results["metadata"]["train_loss"],
    linestyle=":",
    label="Baseline Training",
)
_ = plt.plot(
    x_values,
    baseline_results["metadata"]["val_loss"],
    linestyle=":",
    label="Baseline Validation",
)

i_max = hpo_results["objective"].argmax()
train_loss = json.loads(hpo_results.iloc[i_max]["m:train_loss"])
val_loss = json.loads(hpo_results.iloc[i_max]["m:val_loss"])
x_values = np.arange(1, len(train_loss) + 1)
x_max = max(x_max, x_values.max())
_ = plt.plot(
    x_values,
    train_loss,
    alpha=0.8,
    linestyle="--",
    label="Best Training",
)
_ = plt.plot(
    x_values,
    val_loss,
    alpha=0.8,
    linestyle="--",
    label="Best Validation",
)
_ = plt.xlim(x_min, x_max)
_ = plt.grid(which="both", linestyle=":")
_ = plt.legend()
_ = plt.xlabel("Epochs")
_ = plt.ylabel("NLL")
[Figure: learning curves (NLL vs. epochs) of the baseline model and of the best model found by the search]

Finally, we look at the predictions of this best model and observe that it fits the data much better than the baseline over the observed range. Its checkpoint can be reloaded as sketched below.
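
A minimal sketch of reloading the best checkpoint. The "model_0.<job_id>.pt" file name under os.path.join(hpo_dir, "models") is an assumption based on the checkpointing convention used later in this tutorial; the helpers parameters_from_row, create_model, to_torch and to_numpy are defined earlier.

# Identify the best evaluation and rebuild its model.
i_max = hpo_results["objective"].argmax()
best_row = hpo_results.iloc[i_max]

best_module = create_model(parameters_from_row(best_row), y_mu, y_std).to(device=device, dtype=dtype)
best_weights_path = os.path.join(hpo_dir, "models", f"model_0.{int(best_row['job_id'])}.pt")  # assumed naming
best_module.load_state_dict(torch.load(best_weights_path, weights_only=True))
best_module.eval()

# Predict mean and standard deviation on the test inputs.
with torch.no_grad():
    y_dist = best_module(to_torch(test_X))
    y_pred_mean = to_numpy(y_dist.loc)
    y_pred_std = to_numpy(y_dist.scale)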

Code (Make prediction plot)
kappa = 1.96
_ = plt.figure(figsize=(WIDTH_PLOTS, HEIGHT_PLOTS))
_ = plt.scatter(train_X, train_y, s=5, label="Training")
_ = plt.scatter(valid_X, valid_y, s=5, label="Validation")
_ = plt.plot(test_X, test_y, linestyle="--", color="gray", label="Test")

_ = plt.plot(test_X, y_pred_mean, label=r"$\mu(x)$")
_ = plt.fill_between(
    test_X.reshape(-1),
    (y_pred_mean - kappa * y_pred_std).reshape(-1),
    (y_pred_mean + kappa * y_pred_std).reshape(-1),
    alpha=0.25,
    label=r"$\sigma_\text{al}(x)$",
)
_ = plt.fill_between([-30, -15], [-y_lim, -y_lim], [y_lim, y_lim], color="gray", alpha=0.25)
_ = plt.fill_between([15, 30], [-y_lim, -y_lim], [y_lim, y_lim], color="gray", alpha=0.25)
_ = plt.xlim(-x_lim, x_lim)
_ = plt.ylim(-y_lim, y_lim)
_ = plt.legend(ncols=2)
_ = plt.xlabel(r"$x$")
_ = plt.ylabel(r"$f(x)$")
_ = plt.grid(which="both", linestyle=":")
[Figure: best model predictions with the aleatoric uncertainty band]

Deep ensemble#

After running the neural architecture search, we have a library of checkpointed models available. In this section, you will learn how to combine these models into an ensemble that improves accuracy and provides disentangled uncertainty quantification.

We start by importing classes from deephyper.predictor and deephyper.ensemble.

The deephyper.predictor module includes subclasses of deephyper.predictor.Predictor to wrap predictive models ready for inference. In our case, we will use deephyper.predictor.torch.TorchPredictor. The deephyper.ensemble module includes modular components to build an ensemble of predictive models. The ensemble module is organized around loss functions, aggregation functions and selection algorithms. The implementation of these functions is based on Numpy. In this example, we start by wrapping our torch module within a subclass of deephyper.predictor.torch.TorchPredictor that we call NormalTorchPredictor. This predictor class is used to make a torch module compatible with our Numpy-based implementation for ensembles.

The pre_process_inputs method maps a NumPy array to a Torch tensor. The post_process_predictions method maps a Torch tensor back to a NumPy array and formats the prediction as a dictionary with "loc" (the predictive mean) and "scale" (the predictive standard deviation), as required by our aggregation function MixedNormalAggregator.

from deephyper.ensemble import EnsemblePredictor
from deephyper.ensemble.aggregator import MixedNormalAggregator
from deephyper.ensemble.loss import NormalNegLogLikelihood
from deephyper.ensemble.selector import GreedySelector, TopKSelector
from deephyper.predictor.torch import TorchPredictor

class NormalTorchPredictor(TorchPredictor):
    def __init__(self, torch_module):
        super().__init__(torch_module.to(device=device, dtype=dtype))

    def pre_process_inputs(self, X):
        return to_torch(X)

    def post_process_predictions(self, y):
        return {
            "loc": to_numpy(y.loc),
            "scale": to_numpy(y.scale),
        }

After defining the predictor, we load the checkpointed models to collect their predictions into y_predictors. These predictions are the inputs of our loss, aggregation and selection functions. We also collect the job ids of the checkpointed models into job_id_predictors.

model_checkpoint_dir = os.path.join(hpo_dir, "models")

y_predictors = []
job_id_predictors = []

for file_name in tqdm(os.listdir(model_checkpoint_dir)):
    if not file_name.endswith(".pt"):
        continue

    weights_path = os.path.join(model_checkpoint_dir, file_name)
    job_id = int(file_name[6:-3].split(".")[-1])

    row = hpo_results[hpo_results["job_id"] == job_id]
    if len(row) == 0:
        continue
    assert len(row) == 1

    row = row.iloc[0]
    parameters = parameters_from_row(row)
    torch_module = create_model(parameters, y_mu, y_std)
    try:
        torch_module.load_state_dict(torch.load(weights_path, weights_only=True))
    except RuntimeError:
        continue

    predictor = NormalTorchPredictor(torch_module)
    y_pred = predictor.predict(valid_X)
    y_predictors.append(y_pred)
    job_id_predictors.append(job_id)
100%|██████████| 251/251 [00:01<00:00, 201.93it/s]

Ensemble selection#

This is where the ensemble selection logic happens. We use the deephyper.ensemble.selector.GreedySelector or deephyper.ensemble.selector.TopKSelector class. The top-k selection picks the top-k models according to the given loss_func and weights them equally in the ensemble. The greedy selection iteratively adds to the current ensemble the checkpointed models that improve it.

The aggregator is the logic that combines a set of predictors into a single predictor to form the ensemble's prediction. In our case, we use the deephyper.ensemble.aggregator.MixedNormalAggregator, which approximates a mixture of normal distributions (each normal distribution being the output of a checkpointed model) with a single normal distribution; the sketch below illustrates the underlying moment matching.
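
For intuition, the moment matching performed by such an aggregator can be written in a few lines of NumPy. This is a sketch of the math, not DeepHyper's implementation:

def mixture_to_normal(locs, scales, weights):
    """Approximate a weighted mixture of normals by a single normal via moment matching."""
    locs, scales, weights = np.asarray(locs), np.asarray(scales), np.asarray(weights)
    loc = np.sum(weights * locs)  # E[Y] = sum_i w_i * mu_i
    second_moment = np.sum(weights * (scales**2 + locs**2))
    scale = np.sqrt(second_moment - loc**2)  # V[Y] = sum_i w_i * (sigma_i^2 + mu_i^2) - mu^2
    return loc, scale

# Example: two equally weighted members.
print(mixture_to_normal(locs=[0.0, 1.0], scales=[0.5, 0.5], weights=[0.5, 0.5]))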

To switch between top-k and greedy selection, simply comment/uncomment the corresponding code. This step is fast to compute.

k = 50

# Top-K Selection
# selector = TopKSelector(
#     loss_func=NormalNegLogLikelihood(),
#     k=k,
# )

# Greedy Selection
selector = GreedySelector(
    loss_func=NormalNegLogLikelihood(),
    aggregator=MixedNormalAggregator(),
    k=k,
    max_it=k,
    k_init=3,
    early_stopping=True,
    with_replacement=True,
    bagging=True,
    verbose=True,
)

selected_predictors_indexes, selected_predictors_weights = selector.select(
    valid_y,
    y_predictors,
)

print(f"{selected_predictors_indexes=}")
print(f"{selected_predictors_weights=}")

selected_predictors_job_ids = np.array(job_id_predictors)[selected_predictors_indexes]
selected_predictors_job_ids

print(f"{selected_predictors_job_ids=}")
Ensemble initialized with [142, 14, 92] with loss [1.1185382057666016, 1.1227946721185649, 1.1288688523104407]
Step 1, ensemble is [142, 14, 92, 142], new member 142 with loss 1.120765690069929
Step 2, ensemble is [142, 14, 92, 142, 134], new member 134 with loss 1.1192680669085386
Step 3, ensemble selection stopped
After 3 steps, the final ensemble is [ 14  92 134 142] with weights [0.2 0.2 0.2 0.4]
selected_predictors_indexes=[14, 92, 134, 142]
selected_predictors_weights=[0.2, 0.2, 0.2, 0.4]
selected_predictors_job_ids=array([251, 158,  32, 199])

Evaluation of the ensemble#

Now that we have a set of predictors with their corresponding weights in the ensemble, we can look at the predictions. For this, we use the deephyper.ensemble.EnsemblePredictor class. This class can use the deephyper.evaluator.Evaluator to parallelize the inference of ensemble members. We need to give it the list of predictors, the weights, and the aggregator; a sketch is given after the equation below. For inference, we set decomposed_scale=True for the deephyper.ensemble.aggregator.MixedNormalAggregator because we want to predict disentangled epistemic and aleatoric uncertainty using the law of total variance:

\[V_Y[Y|X=x] = \underbrace{E_\Theta\left[V_Y[Y|X=x;\Theta]\right]}_\text{Aleatoric Uncertainty} + \underbrace{V_\Theta\left[E_Y[Y|X=x;\Theta]\right]}_\text{Epistemic Uncertainty}\]

where \(\Theta\) is the random variable that represents a concatenation of weights and hyperparameters, \(Y\) is the random variable representing a target prediction, and \(X\) is the random variable representing an observed input.
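
Below is a minimal sketch of this step. The exact constructor signatures (in particular where decomposed_scale is passed) and the "model_0.<job_id>.pt" file naming are assumptions to check against the DeepHyper documentation; everything else reuses objects defined earlier in this tutorial.

# Reload the selected checkpointed models as predictors.
predictors = []
for job_id in selected_predictors_job_ids:
    row = hpo_results[hpo_results["job_id"] == job_id].iloc[0]
    torch_module = create_model(parameters_from_row(row), y_mu, y_std)
    weights_path = os.path.join(model_checkpoint_dir, f"model_0.{int(job_id)}.pt")  # assumed naming
    torch_module.load_state_dict(torch.load(weights_path, weights_only=True))
    predictors.append(NormalTorchPredictor(torch_module))

# Combine them into an ensemble that returns decomposed uncertainties.
ensemble = EnsemblePredictor(
    predictors=predictors,
    weights=selected_predictors_weights,
    aggregator=MixedNormalAggregator(decomposed_scale=True),
)
y_pred = ensemble.predict(test_X)  # dict with "loc", "scale_aleatoric" and "scale_epistemic"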

In the visualization, we can first observe that the mean prediction is close to the true function.

Then, to visualize both uncertainties together, we plot the variances stacked on top of each other. The goal is to observe the epistemic component vanishing in areas where we observed data.

Code (Make uncertainty plot)
_ = plt.figure(figsize=(WIDTH_PLOTS, HEIGHT_PLOTS))
_ = plt.scatter(train_X, train_y, s=5, label="Training")
_ = plt.scatter(valid_X, valid_y, s=5, label="Validation")
_ = plt.plot(test_X, test_y, linestyle="--", color="gray", label="Test")
_ = plt.plot(test_X, y_pred["loc"], label=r"$\mu(x)$")
_ = plt.fill_between(
    test_X.reshape(-1),
    (y_pred["loc"] - y_pred["scale_aleatoric"]**2).reshape(-1),
    (y_pred["loc"] + y_pred["scale_aleatoric"]**2).reshape(-1),
    alpha=0.25,
    label=r"$\sigma_\text{al}^2(x)$",
)
_ = plt.fill_between(
    test_X.reshape(-1),
    (y_pred["loc"] - y_pred["scale_aleatoric"]**2).reshape(-1),
    (y_pred["loc"] - y_pred["scale_aleatoric"]**2 - y_pred["scale_epistemic"]**2).reshape(-1),
    alpha=0.25,
    color="red",
    label=r"$\sigma_\text{ep}^2(x)$",
)
_ = plt.fill_between(
    test_X.reshape(-1),
    (y_pred["loc"] + y_pred["scale_aleatoric"]**2).reshape(-1),
    (y_pred["loc"] + y_pred["scale_aleatoric"]**2 + y_pred["scale_epistemic"]**2).reshape(-1),
    alpha=0.25,
    color="red",
)
_ = plt.fill_between([-30, -15], [-y_lim, -y_lim], [y_lim, y_lim], color="gray", alpha=0.25)
_ = plt.fill_between([15, 30], [-y_lim, -y_lim], [y_lim, y_lim], color="gray", alpha=0.25)
_ = plt.xlim(-x_lim, x_lim)
_ = plt.ylim(-y_lim, y_lim)
_ = plt.legend(ncols=2)
_ = plt.xlabel(r"$x$")
_ = plt.ylabel(r"$f(x)$")
_ = plt.grid(which="both", linestyle=":")
[Figure: ensemble prediction with stacked aleatoric and epistemic variance bands]

Aleatoric Uncertainty#

Now, if we isolate the aleatoric uncertainty, we observe that it is estimated reasonably well: smaller on the left-hand region and larger on the right-hand region, consistent with the noise used to generate the data. A quick numerical check is sketched below.
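
A quick numerical check (a sketch using the data-generating noise levels, 0.5 on the left region and 1.0 on the right region):

x = test_X.reshape(-1)
scale_al = y_pred["scale_aleatoric"].reshape(-1)

left = (x >= -30) & (x <= -15)
right = (x >= 15) & (x <= 30)

print(f"Left  region: mean predicted sigma_al = {scale_al[left].mean():.2f} (true: 0.5)")
print(f"Right region: mean predicted sigma_al = {scale_al[right].mean():.2f} (true: 1.0)")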

Code (Make aleatoric uncertainty plot)
kappa = 1.96
_ = plt.figure(figsize=(WIDTH_PLOTS, HEIGHT_PLOTS))
_ = plt.scatter(train_X, train_y, s=5, label="Training")
_ = plt.scatter(valid_X, valid_y, s=5, label="Validation")
_ = plt.plot(test_X, test_y, linestyle="--", color="gray", label="Test")
_ = plt.plot(test_X, y_pred["loc"], label=r"$\mu(x)$")
_ = plt.fill_between(
    test_X.reshape(-1),
    (y_pred["loc"] - kappa * y_pred["scale_aleatoric"]).reshape(-1),
    (y_pred["loc"] + kappa * y_pred["scale_aleatoric"]).reshape(-1),
    alpha=0.25,
    label=r"$\sigma_\text{al}(x)$",
)
_ = plt.fill_between([-30, -15], [-y_lim, -y_lim], [y_lim, y_lim], color="gray", alpha=0.25)
_ = plt.fill_between([15, 30], [-y_lim, -y_lim], [y_lim, y_lim], color="gray", alpha=0.25)
_ = plt.xlim(-x_lim, x_lim)
_ = plt.ylim(-y_lim, y_lim)
_ = plt.legend(ncols=2)
_ = plt.xlabel(r"$x$")
_ = plt.ylabel(r"$f(x)$")
_ = plt.grid(which="both", linestyle=":")
[Figure: ensemble prediction with the aleatoric uncertainty band]

Epistemic uncertainty#

Finally, if we isolate the epistemic uncertainty, we observe that it vanishes in the grey areas where we observed data and grows in areas where we did not have data. A quick numerical check is sketched below.
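
A quick numerical check (a sketch): the epistemic standard deviation should be much smaller inside the training regions than outside them.

x = test_X.reshape(-1)
scale_ep = y_pred["scale_epistemic"].reshape(-1)

inside = ((x >= -30) & (x <= -15)) | ((x >= 15) & (x <= 30))
print(f"Mean sigma_ep inside the training regions : {scale_ep[inside].mean():.3f}")
print(f"Mean sigma_ep outside the training regions: {scale_ep[~inside].mean():.3f}")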

Code (Make epistemic uncertainty plot)
kappa = 1.96
_ = plt.figure(figsize=(WIDTH_PLOTS, HEIGHT_PLOTS))
_ = plt.scatter(train_X, train_y, s=5, label="Training")
_ = plt.scatter(valid_X, valid_y, s=5, label="Validation")
_ = plt.plot(test_X, test_y, linestyle="--", color="gray", label="Test")
_ = plt.plot(test_X, y_pred["loc"], label=r"$\mu(x)$")
_ = plt.fill_between(
    test_X.reshape(-1),
    (y_pred["loc"] - kappa * y_pred["scale_epistemic"]).reshape(-1),
    (y_pred["loc"] + kappa * y_pred["scale_epistemic"]).reshape(-1),
    alpha=0.25,
    color="red",
    label=r"$\sigma_\text{ep}(x)$",
)
_ = plt.fill_between([-30, -15], [-y_lim, -y_lim], [y_lim, y_lim], color="gray", alpha=0.25)
_ = plt.fill_between([15, 30], [-y_lim, -y_lim], [y_lim, y_lim], color="gray", alpha=0.25)
_ = plt.xlim(-x_lim, x_lim)
_ = plt.ylim(-y_lim, y_lim)
_ = plt.legend(ncols=2)
_ = plt.xlabel(r"$x$")
_ = plt.ylabel(r"$f(x)$")
_ = plt.grid(which="both", linestyle=":")
[Figure: ensemble prediction with the epistemic uncertainty band]

Total running time of the script: (0 minutes 18.016 seconds)
