# Source code for deephyper.stopper.lce._bayesian_regression

import sys
from functools import partial

# Temporary workaround: https://github.com/pyro-ppl/numpyro/issues/2051#issuecomment-3110627625
import jax
import jax.experimental.pjit
from jax.extend.core.primitives import jit_p

jax.experimental.pjit.pjit_p = jit_p

import jax.numpy as jnp  # noqa: E402
import numpy as np  # noqa: E402
import numpyro  # noqa: E402
import numpyro.distributions as dist  # noqa: E402
from numpyro.infer import MCMC, NUTS, BarkerMH  # noqa: E402
from scipy.optimize import least_squares  # noqa: E402
from sklearn.base import BaseEstimator, RegressorMixin  # noqa: E402
from sklearn.utils import check_random_state  # noqa: E402
from sklearn.utils.validation import check_array, check_is_fitted, check_X_y  # noqa: E402

# Names of the parametric learning-curve models defined in this module.
# Each entry ``name`` has a matching module-level function ``f_<name>(z, rho)``;
# the integer suffix is the number of parameters expected in ``rho``.
LC_MODELS = [
    "lin2",
    "pow3",
    "mmf4",
    "vapor3",
    "logloglin2",
    "hill3",
    "logpow3",
    "pow4",
    "exp4",
    "janoschek4",
    "weibull4",
    "ilog2",
    "arctan3",
]


# Learning curves models
@jax.jit
def f_lin2(z, rho):
    """Linear learning curve with 2 parameters: ``rho[0] + rho[1] * z``.

    Args:
        z: Input step(s), scalar or array.
        rho: Parameters, ``rho[0]`` is the intercept and ``rho[1]`` the slope.

    Returns:
        The model value(s) at ``z``.
    """
    intercept = rho[0]
    slope = rho[1]
    return slope * z + intercept


@jax.jit
def f_pow3(z, rho):
    """Power-law learning curve with 3 parameters.

    Computes ``rho[0] - rho[1] * z ** rho[2]``.

    Args:
        z: Input step(s), scalar or array.
        rho: Parameters ``(offset, scale, exponent)``.

    Returns:
        The model value(s) at ``z``.
    """
    offset = rho[0]
    scale = rho[1]
    exponent = rho[2]
    power_term = z**exponent
    return offset - scale * power_term


@jax.jit
def f_mmf4(z, rho):
    """MMF learning curve with 4 parameters.

    Computes ``(rho[0] * rho[1] + rho[2] * z ** rho[3]) / (rho[1] + z ** rho[3])``.

    Args:
        z: Input step(s), scalar or array.
        rho: The 4 model parameters.

    Returns:
        The model value(s) at ``z``.
    """
    z_pow = jnp.power(z, rho[3])
    numerator = rho[0] * rho[1] + rho[2] * z_pow
    denominator = rho[1] + z_pow
    return numerator / denominator


@jax.jit
def f_vapor3(z, rho):
    """Vapor-pressure learning curve with 3 parameters.

    Computes ``rho[0] + rho[1] / z + rho[2] * log(z)``.

    Args:
        z: Input step(s), scalar or array. Must be strictly positive for the
            result to be finite.
        rho: The 3 model parameters.

    Returns:
        The model value(s) at ``z``.
    """
    # ``jnp.log`` (not ``np.log``): inside ``jax.jit`` the input is a tracer and
    # NumPy functions cannot convert tracers to concrete arrays.
    return rho[0] + rho[1] / z + rho[2] * jnp.log(z)


@jax.jit
def f_logloglin2(z, rho):
    """Log-log-linear learning curve with 2 parameters.

    Computes ``log(rho[0] * log(z) + rho[1])``.

    Args:
        z: Input step(s), scalar or array.
        rho: Parameters ``(slope, intercept)`` of the inner linear term.

    Returns:
        The model value(s) at ``z``.
    """
    log_z = jnp.log(z)
    inner = rho[0] * log_z + rho[1]
    return jnp.log(inner)


@jax.jit
def f_hill3(z, rho):
    """Hill-type learning curve with 3 parameters.

    Computes ``ymax * z ** eta / (kappa * eta + z ** eta)`` where
    ``(ymax, eta, kappa) = rho``.

    Args:
        z: Input step(s), scalar or array.
        rho: Exactly 3 parameters ``(ymax, eta, kappa)``.

    Returns:
        The model value(s) at ``z``.
    """
    # NOTE(review): the denominator uses ``kappa * eta``; the textbook Hill
    # curve uses ``kappa ** eta`` - confirm which one is intended.
    ymax, eta, kappa = rho
    z_eta = z**eta
    return ymax * z_eta / (kappa * eta + z_eta)


@jax.jit
def f_logpow3(z, rho):
    """Log-power learning curve with 3 parameters.

    Computes ``rho[0] / (1 + (z / exp(rho[1])) ** rho[2])``.

    Args:
        z: Input step(s), scalar or array.
        rho: The 3 model parameters.

    Returns:
        The model value(s) at ``z``.
    """
    scaled_z = z / jnp.exp(rho[1])
    denominator = 1 + scaled_z ** rho[2]
    return rho[0] / denominator


@jax.jit
def f_pow4(z, rho):
    """Power-law learning curve with 4 parameters.

    Computes ``rho[2] - (rho[0] * z + rho[1]) ** (-rho[3])``.

    Args:
        z: Input step(s), scalar or array.
        rho: The 4 model parameters.

    Returns:
        The model value(s) at ``z``.
    """
    base = rho[0] * z + rho[1]
    decay = base ** (-rho[3])
    return rho[2] - decay


@jax.jit
def f_exp4(z, rho):
    """Exponential learning curve with 4 parameters.

    Computes ``rho[2] - exp(-rho[0] * z ** rho[3] + rho[1])``.

    Args:
        z: Input step(s), scalar or array.
        rho: The 4 model parameters.

    Returns:
        The model value(s) at ``z``.
    """
    z_pow = z ** rho[3]
    exponent = -rho[0] * z_pow + rho[1]
    return rho[2] - jnp.exp(exponent)


@jax.jit
def f_janoschek4(z, rho):
    """Janoschek learning curve with 4 parameters.

    Computes ``rho[0] - (rho[0] - rho[1]) * exp(-rho[2] * z ** rho[3])``.

    Args:
        z: Input step(s), scalar or array.
        rho: The 4 model parameters.

    Returns:
        The model value(s) at ``z``.
    """
    gap = rho[0] - rho[1]
    decay = jnp.exp(-rho[2] * (z ** rho[3]))
    return rho[0] - gap * decay


@jax.jit
def f_weibull4(z, rho):
    """Weibull learning curve with 4 parameters.

    Computes ``rho[0] - (rho[0] - rho[1]) * exp(-((rho[2] * z) ** rho[3]))``.

    Args:
        z: Input step(s), scalar or array.
        rho: The 4 model parameters.

    Returns:
        The model value(s) at ``z``.
    """
    gap = rho[0] - rho[1]
    scaled_pow = (rho[2] * z) ** rho[3]
    return rho[0] - gap * jnp.exp(-scaled_pow)


@jax.jit
def f_ilog2(z, rho):
    """Inverse-log learning curve with 2 parameters.

    Computes ``rho[1] - rho[0] / log(z + 1)``.

    Args:
        z: Input step(s), scalar or array.
        rho: The 2 model parameters.

    Returns:
        The model value(s) at ``z``.
    """
    log_shifted = jnp.log(z + 1)
    correction = rho[0] / log_shifted
    return rho[1] - correction


@jax.jit
def f_arctan3(z, rho):
    """Arctangent learning curve with 3 parameters.

    Computes ``2 / pi * arctan(rho[0] * pi / 2 * z + rho[1]) - rho[2]``.

    Args:
        z: Input step(s), scalar or array.
        rho: The 3 model parameters.

    Returns:
        The model value(s) at ``z``.
    """
    slope = rho[0] * jnp.pi / 2
    angle = jnp.arctan(slope * z + rho[1])
    return 2 / jnp.pi * angle - rho[2]


# Utility to estimate parameters of learning curve model
# The combination of "partial" and "static_argnums" is necessary
# with the "f" lambda function passed as argument
@partial(jax.jit, static_argnums=(1,))
def residual_least_square(rho, f, z, y):
    """Residual for least squares.

    Args:
        rho: Current parameters of the learning-curve model.
        f: Learning-curve model with interface ``f(z, rho)``. It is a static
            (hashable) argument of the jitted function.
        z: 1-D array of inputs. Entries equal to ``0`` are treated as padding.
        y: 1-D array of targets with the same shape as ``z``.

    Returns:
        The array ``f(z, rho) - y`` where the prediction is replaced by ``0``
        on padded entries (``z == 0``).
    """
    y_pred = f(z, rho)
    # Callers pass fixed-size, zero-padded buffers to avoid re-compilation.
    # The prediction on padded slots (z == 0, y == 0) is forced to 0 so these
    # slots yield a zero residual and do not influence the fit, even when the
    # model is undefined at z == 0 (e.g. log(0)). The previous condition
    # (``y_pred == 0.0``) zeroed *every* prediction, making the residual
    # independent of ``rho``.
    y_pred = jnp.where(z != 0.0, y_pred, 0.0)
    return y_pred - y


@partial(jax.jit, static_argnums=(1,))
def jac_residual_least_square(rho, f, z, y):
    """Jacobian of the residual for least squares.

    Args:
        rho: Current parameters of the learning-curve model.
        f: Learning-curve model with interface ``f(z, rho)`` (static argument).
        z: 1-D array of inputs.
        y: 1-D array of targets.

    Returns:
        The forward-mode Jacobian of ``residual_least_square`` with respect to
        ``rho``.
    """
    # Differentiate only with respect to the first positional argument (rho).
    jacobian_fn = jax.jacfwd(residual_least_square, argnums=0)
    return jacobian_fn(rho, f, z, y)


def fit_learning_curve_model_least_square(
    z_train,
    y_train,
    f_model,
    f_model_nparams,
    max_trials_ls_fit=10,
    random_state=None,
    verbose=0,
):
    """Fit a learning-curve model by least squares with random restarts.

    The learning curve model is assumed to be modeled by 'f' with interface f(z, rho).

    Args:
        z_train: 1-D array of inputs.
        y_train: 1-D array of targets.
        f_model: The learning-curve model ``f(z, rho)``.
        f_model_nparams (int): Number of parameters in ``rho``.
        max_trials_ls_fit (int, optional): Maximum number of random restarts.
            Defaults to ``10``.
        random_state: Seed or random state passed to ``check_random_state``.
        verbose (int, optional): If truthy, print the progress of each trial.

    Returns:
        np.ndarray: The parameters of the trial with the lowest mean squared
        residual.

    Raises:
        ValueError: If no trial produced a usable fit.
    """
    random_state = check_random_state(random_state)

    # The choice of solver does not depend on the trial, compute it once.
    method = "lm" if len(z_train) >= f_model_nparams else "trf"

    results = []
    mse_hist = []
    last_error = None

    for i in range(max_trials_ls_fit):
        if verbose:
            print(f"Least-Square fit - trial {i + 1}/{max_trials_ls_fit}: ", end="")

        # Random restart: a fresh standard-normal starting point for each trial.
        rho_init = random_state.randn(f_model_nparams)

        try:
            res_lsq = least_squares(
                residual_least_square,
                rho_init,
                args=(f_model, z_train, y_train),
                method=method,
                jac=jac_residual_least_square,
            )
        except ValueError as exc:
            # The solver rejected this starting point; try another one.
            last_error = exc
            if verbose:
                print(f"failed ({exc})")
            continue

        mse_res_lsq = np.mean(res_lsq.fun**2)
        mse_hist.append(mse_res_lsq)
        results.append(res_lsq.x)

        if verbose:
            print(f"mse={mse_res_lsq:.3f}")

        # Early stopping once the fit is numerically exact.
        if mse_res_lsq < 1e-8:
            break

    # Without this guard ``np.nanargmin`` fails with an unrelated message when
    # every trial failed or produced a NaN error.
    if len(mse_hist) == 0 or np.all(np.isnan(mse_hist)):
        raise ValueError(
            f"Least-squares fit of the learning curve model failed for all "
            f"{max_trials_ls_fit} trial(s)."
        ) from last_error

    i_best = np.nanargmin(mse_hist)
    return results[i_best]


def prob_model(
    z,
    y,
    f=None,
    rho_mu_prior=None,
    rho_sigma_prior=1.0,
    y_sigma_prior=1.0,
    num_obs=None,
):
    """NumPyro model of a noisy learning curve.

    Sample sites are ``"rho"`` (model parameters), ``"sigma"`` (observation
    noise) and ``"obs"`` (observed targets).

    Args:
        z: Array of inputs; only ``z[:num_obs]`` is used.
        y: Array of targets; only ``y[:num_obs]`` is observed.
        f: The learning-curve model with interface ``f(z, rho)``.
        rho_mu_prior: Mean of the Normal prior on ``rho``.
        rho_sigma_prior: Scale of the Normal prior on ``rho``.
        y_sigma_prior: Argument of the Exponential prior on the noise scale.
        num_obs: Number of leading entries of ``z`` and ``y`` to use
            (``None`` uses all of them).
    """
    rho_prior = dist.Normal(rho_mu_prior, rho_sigma_prior)
    rho = numpyro.sample("rho", rho_prior)

    # Observation noise level.
    noise_prior = dist.Exponential(y_sigma_prior)
    noise_scale = numpyro.sample("sigma", noise_prior)

    curve = f(z[:num_obs], rho)
    likelihood = dist.Normal(curve, noise_scale)
    numpyro.sample("obs", likelihood, obs=y[:num_obs])


@partial(jax.jit, static_argnums=(0,))
def predict_moments_from_posterior(f, X, posterior_samples):
    """Mean and standard deviation of the model over posterior samples.

    Args:
        f: The learning-curve model with interface ``f(z, rho)`` (static
            argument of the jitted function).
        X: Array of inputs, shared by all samples.
        posterior_samples: Array whose leading axis indexes the samples of
            ``rho``.

    Returns:
        tuple: ``(mean, std)`` of the predictions, reduced over the sample axis.
    """
    # One predicted curve per posterior draw; X is broadcast to every draw.
    curves = jax.vmap(lambda rho: f(X, rho))(posterior_samples)
    return jnp.mean(curves, axis=0), jnp.std(curves, axis=0)


class BayesianLearningCurveRegressor(BaseEstimator, RegressorMixin):
    """Probabilistic model for learning curve regression.

    Args:
        f_model (callable, optional): The model function to use. Defaults to
            ``f_pow3`` for a Power-Law with 3 parameters.
        f_model_nparams (int, optional): The number of parameters of the model.
            Defaults to ``3``.
        max_trials_ls_fit (int, optional): The number of least-square fits that
            should be tried. Defaults to ``10``.
        mcmc_kernel (str, optional): The MCMC kernel to use. It should be a
            string in the following list: ``["NUTS", "BarkerMH"]``. Defaults to
            ``"NUTS"``.
        mcmc_num_chains (int, optional): The number of MCMC chains. Defaults to
            ``1``.
        mcmc_num_warmup (int, optional): The number of warmup steps in MCMC.
            Defaults to ``200``.
        mcmc_num_samples (int, optional): The number of samples in MCMC.
            Defaults to ``1_000``.
        random_state (int, optional): A random state. Defaults to ``None``.
        verbose (int, optional): Whether or not to use the verbose mode.
            Defaults to ``0`` to deactivate it.
        batch_size (int, optional): The expected maximum length of the X, y
            arrays (used in the ``fit(X, y)`` method) in order to preallocate
            memory and compile the code only once. Defaults to ``1_000``.
        min_max_scaling (bool, optional): Whether or not to use min-max scaling
            in [0,1] for ``y`` values. Defaults to ``False``.
    """

    def __init__(
        self,
        f_model=f_pow3,
        f_model_nparams=3,
        max_trials_ls_fit=10,
        mcmc_kernel="NUTS",
        mcmc_num_chains=1,
        mcmc_num_warmup=200,
        mcmc_num_samples=1_000,
        random_state=None,
        verbose=0,
        batch_size=1_000,
        min_max_scaling=False,
    ):
        self.f_model = f_model
        self.f_model_nparams = f_model_nparams
        self.mcmc_kernel = mcmc_kernel
        self.mcmc_num_chains = mcmc_num_chains
        self.mcmc_num_warmup = mcmc_num_warmup
        self.mcmc_num_samples = mcmc_num_samples
        self.max_trials_ls_fit = max_trials_ls_fit
        self.random_state = check_random_state(random_state)
        self.verbose = verbose
        self.rho_mu_prior_ = np.zeros((self.f_model_nparams,))
        self.batch_size = batch_size
        # Fixed-size buffers so the shapes seen by JAX never change.
        self.X_ = np.zeros((self.batch_size,))
        self.y_ = np.zeros((self.batch_size,))
        self.min_max_scaling = min_max_scaling

    def fit(self, X, y, update_prior=True):
        """Fit the model.

        Args:
            X (np.ndarray): A 1-D array of inputs.
            y (np.ndarray): A 1-D array of targets.
            update_prior (bool, optional): A boolean indicating if the prior
                distribution should be updated using least-squares before
                running the Bayesian inference. Defaults to ``True``.

        Returns:
            BayesianLearningCurveRegressor: The fitted estimator (``self``).

        Raises:
            ValueError: if input arguments are invalid.
        """
        X, y = check_X_y(X, y, ensure_2d=False)

        # !Trick for performance to avoid performing JIT again and again
        # !This will fix the shape of inputs of the model for numpyro
        # !see https://github.com/pyro-ppl/numpyro/issues/441
        num_samples = len(X)
        if num_samples > self.batch_size:
            raise ValueError(
                f"Received {num_samples} samples but batch_size={self.batch_size}; "
                "increase batch_size to fit this many observations."
            )

        self.X_[:num_samples] = X[:]
        self.y_[:num_samples] = y[:]
        self.X_[num_samples:] = 0.0
        self.y_[num_samples:] = 0.0

        # Number of valid (non-padded) observations for the current fit. It is
        # read by the probabilistic model each time the model is traced, so it
        # must be refreshed on every call and not frozen in a closure.
        self.num_obs_ = num_samples

        # Min-Max Scaling
        if not self.min_max_scaling:
            self.y_min_ = 0
            self.y_max_ = 1
        else:
            self.y_min_ = self.y_[:num_samples].min()
            self.y_max_ = self.y_[:num_samples].max()
            if abs(self.y_min_ - self.y_max_) <= 1e-8:
                # avoid division by zero
                self.y_max_ = self.y_min_ + 1
            self.y_[:num_samples] = (self.y_[:num_samples] - self.y_min_) / (
                self.y_max_ - self.y_min_
            )

        if update_prior:
            self.rho_mu_prior_[:] = fit_learning_curve_model_least_square(
                self.X_,
                self.y_,
                f_model=self.f_model,
                f_model_nparams=self.f_model_nparams,
                max_trials_ls_fit=self.max_trials_ls_fit,
                random_state=self.random_state,
                verbose=self.verbose,
            )[:]

            if self.verbose:
                print(f"rho_mu_prior: {self.rho_mu_prior_}")

        # The kernel and the sampler are created once and reused by later fits.
        if not hasattr(self, "kernel_"):
            kernel_classes = {"NUTS": NUTS, "BarkerMH": BarkerMH}
            if self.mcmc_kernel not in kernel_classes:
                raise ValueError(f"Unknown MCMC kernel: {self.mcmc_kernel}")

            self.kernel_ = kernel_classes[self.mcmc_kernel](
                # ``num_obs`` is looked up on ``self`` at call time; capturing
                # the local ``num_samples`` here would keep the value of the
                # very first fit for all subsequent fits.
                model=lambda z, y: prob_model(
                    z,
                    y,
                    f=self.f_model,
                    rho_mu_prior=self.rho_mu_prior_,
                    num_obs=self.num_obs_,
                ),
                target_accept_prob=0.8,
                step_size=0.05,
            )

            self.mcmc_ = MCMC(
                self.kernel_,
                num_chains=self.mcmc_num_chains,
                num_warmup=self.mcmc_num_warmup,
                num_samples=self.mcmc_num_samples,
                progress_bar=self.verbose,
            )

        seed = self.random_state.randint(low=0, high=np.iinfo(np.int32).max)
        rng_key = jax.random.PRNGKey(seed)
        self.mcmc_.run(rng_key, z=self.X_, y=self.y_)

        if self.verbose:
            self.mcmc_.print_summary()

        return self

    def predict(self, X, return_std=True):
        """Predict the mean and standard deviation of the model.

        Args:
            X (np.ndarray): A 1-D array of inputs.
            return_std (bool, optional): A boolean indicating if the
                standard-deviation representing uncertainty in the prediction
                should be returned. Defaults to ``True``.

        Returns:
            Tuple[np.ndarray, np.ndarray]: The mean prediction with shape
            ``(len(X),)`` and the standard deviation with shape ``(len(X),)``
            if ``return_std`` is ``True``. Only the mean otherwise.
        """
        posterior_samples = self.predict_posterior_samples(X)

        mean_mu = jnp.mean(posterior_samples, axis=0)

        if return_std:
            std_mu = jnp.std(posterior_samples, axis=0)
            return mean_mu, std_mu

        return mean_mu

    def predict_posterior_samples(self, X):
        """Predict the posterior samples of the model.

        Args:
            X (np.ndarray): a 1-D array of inputs.

        Returns:
            np.ndarray: A 2-D array of shape (n_samples, len(X)) where
            n_samples is the number of samples and len(X) is the length of the
            input array.
        """
        # Check if fit has been called. The attribute is named explicitly
        # because ``__init__`` already defines trailing-underscore attributes
        # (``X_``, ``y_``, ``rho_mu_prior_``), which would make the generic
        # check succeed on an unfitted estimator.
        check_is_fitted(self, "mcmc_")

        # Input validation
        X = check_array(X, ensure_2d=False)

        posterior_samples = self.mcmc_.get_samples()
        vf_model = jax.vmap(self.f_model, in_axes=(None, 0))
        posterior_mu = vf_model(X, posterior_samples["rho"])

        # Inverse Min-Max Scaling
        posterior_mu = posterior_mu * (self.y_max_ - self.y_min_) + self.y_min_

        return posterior_mu

    def prob(self, X, condition):
        """Compute the approximate probability of P(cond(m(X_i), y_i)).

        Where m is the current fitted model and cond a condition.

        Args:
            X (np.array): An array of inputs.
            condition (callable): A function defining the condition to test.
                It is applied to the array of posterior predictions.

        Returns:
            array: an array of shape X.
        """
        posterior_mu = self.predict_posterior_samples(X)

        # Fraction of posterior samples satisfying the condition.
        prob = jnp.mean(condition(posterior_mu), axis=0)

        return prob

    @staticmethod
    def get_parametrics_model_func(name):
        """Return the function of the learning curve model given its name.

        Should be one of ``["lin2", "pow3", "mmf4", "vapor3", "logloglin2",
        "hill3", "logpow3", "pow4", "exp4", "janoschek4", "weibull4", "ilog2",
        "arctan3"]`` where the integer suffix indicates the number of
        parameters of the model.

        Args:
            name (str): The name of the learning curve model.

        Returns:
            callable: A function with signature ``f(x, rho)`` of the learning
            curve model where ``x`` is a possible input of the model and
            ``rho`` is a 1-D array for the parameters of the model with length
            equal to the number of parameters of the model (e.g., it is of
            length ``3`` for ``"pow3"``).
        """
        return getattr(sys.modules[__name__], f"f_{name}")