Source code for deephyper.nas.metrics

"""This module provides different metric functions. A metric can be defined by a keyword (str) or a callable. If it is a keyword it has to be available in ``tensorflow.keras`` or in ``deephyper.metrics``. The metric functions available in ``deephyper.metrics`` are:

* Sparse Perplexity: ``sparse_perplexity``
* R2: ``r2``
* AUC ROC: ``auroc``
* AUC Precision-Recall: ``aucpr``
"""
import functools
from collections import OrderedDict

import tensorflow as tf
from deephyper.core.utils import load_attr

def r2(y_true, y_pred):
    """Return the coefficient of determination averaged over the outputs.

    The residual and total sums of squares are reduced along axis 0, giving
    one score per remaining entry; the mean of those scores is returned.

    Args:
        y_true: tensor of target values.
        y_pred: tensor of predicted values.

    Returns:
        The mean of the per-output R2 scores.
    """
    residual_sq_sum = tf.math.reduce_sum(tf.math.square(y_true - y_pred), axis=0)
    target_mean = tf.math.reduce_mean(y_true, axis=0)
    total_sq_sum = tf.math.reduce_sum(tf.math.square(y_true - target_mean), axis=0)
    # epsilon keeps the ratio finite when the total sum of squares is zero
    per_output_score = 1 - residual_sq_sum / (
        total_sq_sum + tf.keras.backend.epsilon()
    )
    return tf.math.reduce_mean(per_output_score)
def mae(y_true, y_pred):
    """Return ``tf.keras.metrics.mean_absolute_error(y_true, y_pred)``.

    Args:
        y_true: tensor of target values.
        y_pred: tensor of predicted values.
    """
    absolute_error = tf.keras.metrics.mean_absolute_error(y_true, y_pred)
    return absolute_error
def mse(y_true, y_pred):
    """Return ``tf.keras.metrics.mean_squared_error(y_true, y_pred)``.

    Args:
        y_true: tensor of target values.
        y_pred: tensor of predicted values.
    """
    squared_error = tf.keras.metrics.mean_squared_error(y_true, y_pred)
    return squared_error
def rmse(y_true, y_pred):
    """Return the root mean squared error over all entries.

    Args:
        y_true: tensor of target values.
        y_pred: tensor of predicted values.

    Returns:
        The square root of the mean (taken over every entry) of the squared
        differences between ``y_pred`` and ``y_true``.
    """
    squared_diff = tf.math.square(y_pred - y_true)
    mean_squared_diff = tf.math.reduce_mean(squared_diff)
    return tf.math.sqrt(mean_squared_diff)
def acc(y_true, y_pred):
    """Return ``tf.keras.metrics.categorical_accuracy(y_true, y_pred)``.

    Args:
        y_true: tensor of target values.
        y_pred: tensor of predicted values.
    """
    accuracy = tf.keras.metrics.categorical_accuracy(y_true, y_pred)
    return accuracy
def sparse_perplexity(y_true, y_pred):
    """Return the perplexity derived from the sparse categorical cross-entropy.

    The cross-entropy returned by
    ``tf.keras.losses.sparse_categorical_crossentropy`` is expressed with the
    natural logarithm, so the matching perplexity is ``exp(cross_entropy)``.
    Raising 2 to that value (as was done previously) mixes the two bases and
    does not yield a perplexity.

    Args:
        y_true: tensor of integer class labels.
        y_pred: tensor of predicted class probabilities.

    Returns:
        A tensor with the same shape as the cross-entropy, holding
        ``exp(cross_entropy)`` element-wise.
    """
    cross_entropy = tf.keras.losses.sparse_categorical_crossentropy(y_true, y_pred)
    # base e, to match the natural-log cross-entropy computed above
    perplexity = tf.math.exp(cross_entropy)
    return perplexity
def to_tfp(metric_func):
    """Convert a regular tensorflow-keras metric for tensorflow probability.

    With tensorflow probability the output of the model is a distribution, so
    the returned function evaluates ``metric_func`` on ``y_pred.mean()``
    instead of on ``y_pred`` itself.

    Args:
        metric_func (callable): A regular tensorflow-keras metric taking
            ``(y_true, y_pred)``.

    Returns:
        callable: a function taking ``(y_true, y_pred)`` whose ``__name__`` is
        the name of ``metric_func`` prefixed with ``tfp_``.
    """
    # Callable instances and ``functools.partial`` objects have no ``__name__``;
    # fall back to their class name instead of raising AttributeError.
    base_name = getattr(metric_func, "__name__", None) or type(metric_func).__name__

    @functools.wraps(metric_func)
    def wrapper(y_true, y_pred):
        return metric_func(y_true, y_pred.mean())

    # ``functools.wraps`` copied the original name; override it with the prefix.
    wrapper.__name__ = f"tfp_{base_name}"
    return wrapper
# Variants of some metrics for TensorFlow Probability, where the output of the
# model is a distribution.
tfp_r2 = to_tfp(r2)
tfp_mae = to_tfp(mae)
tfp_mse = to_tfp(mse)
tfp_rmse = to_tfp(rmse)

# Keyword -> metric function taking (y_true, y_pred). Long and short keywords
# map to the same function.
metrics_func = OrderedDict(
    [
        ("mean_absolute_error", mae),
        ("mae", mae),
        ("r2", r2),
        ("mean_squared_error", mse),
        ("mse", mse),
        ("root_mean_squared_error", rmse),
        ("rmse", rmse),
        ("accuracy", acc),
        ("acc", acc),
        ("sparse_perplexity", sparse_perplexity),
        ("tfp_r2", tfp_r2),
        ("tfp_mse", tfp_mse),
        ("tfp_mae", tfp_mae),
        ("tfp_rmse", tfp_rmse),
    ]
)

# Keyword -> zero-argument factory that builds a new ``tf.keras.metrics.AUC``
# object each time it is called.
metrics_obj = OrderedDict(
    [
        ("auroc", lambda: tf.keras.metrics.AUC(name="auroc", curve="ROC")),
        ("aucpr", lambda: tf.keras.metrics.AUC(name="aucpr", curve="PR")),
    ]
)
def selectMetric(name: str):
    """Return the metric defined by name.

    Args:
        name (str): a string referenced in DeepHyper, one referenced in keras
            or an attribute name to import. A callable is returned unchanged.

    Returns:
        str or callable: a string, supposing it is referenced in the keras
        framework, or a callable taking (y_true, y_pred) as inputs and
        returning a tensor. For keywords registered in ``metrics_obj`` the
        result of calling the registered factory is returned.
    """
    if callable(name):
        return name

    registered_func = metrics_func.get(name)
    registered_factory = metrics_obj.get(name)
    if registered_func is None and registered_factory is None:
        # Not a DeepHyper keyword: try to import it as an attribute, and
        # otherwise hand the name back, supposing it is referenced in keras
        # metrics.
        try:
            return load_attr(name)
        except Exception:
            return name

    if name in metrics_func:
        return metrics_func[name]
    return metrics_obj[name]()