Source code for coax.value_losses._losses

import jax
import jax.numpy as jnp


__all__ = (
    'mse',
    'huber',
    'logloss',
    'logloss_sign',
    'quantile_huber',
)


def mse(y_true, y_pred, w=None):
    r"""

    Ordinary mean-squared error loss function.

    .. math::

        L\ =\ \frac12(\hat{y} - y)^2

    .. image:: /_static/img/mse.svg
        :alt: Mean-Squared Error loss
        :width: 320px
        :align: center

    Parameters
    ----------
    y_true : ndarray

        The target :math:`y\in\mathbb{R}`.

    y_pred : ndarray

        The predicted output :math:`\hat{y}\in\mathbb{R}`.

    w : ndarray, optional

        Sample weights.

    Returns
    -------
    loss : scalar ndarray

        The loss averaged over the batch.

    """
    loss = 0.5 * jnp.square(y_pred - y_true)
    return _mean_with_weights(loss, w)
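A minimal usage sketch (not part of the module source), checking the formula on made-up numbers; it assumes the public import path coax.value_losses.mse:

import jax.numpy as jnp
from coax.value_losses import mse

y_true = jnp.array([1.0, 2.0, 3.0])
y_pred = jnp.array([1.5, 2.0, 1.0])
print(mse(y_true, y_pred))  # 0.5 * mean([0.25, 0.0, 4.0]) = 0.7083...

# sample weights rescale the per-sample terms before averaging
print(mse(y_true, y_pred, w=jnp.array([0.0, 1.0, 1.0])))  # mean([0.0, 0.0, 2.0]) = 0.6667...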
def huber(y_true, y_pred, w=None, delta=1.0):
    r"""

    `Huber <https://en.wikipedia.org/wiki/Huber_loss>`_ loss function.

    .. math::

        L\ =\ \left\{\begin{matrix}
            \frac12(\hat{y} - y)^2
                &\quad:\ |\hat{y} - y|\leq\delta \\
            \delta\,|\hat{y} - y| - \frac{\delta^2}{2}
                &\quad:\ |\hat{y} - y| > \delta
        \end{matrix}\right.

    .. image:: /_static/img/huber.svg
        :alt: Huber loss
        :width: 320px
        :align: center

    Parameters
    ----------
    y_true : ndarray

        The target :math:`y\in\mathbb{R}`.

    y_pred : ndarray

        The predicted output :math:`\hat{y}\in\mathbb{R}`.

    w : ndarray, optional

        Sample weights.

    delta : float, optional

        The scale of the quadratic-to-linear transition.

    Returns
    -------
    loss : scalar ndarray

        The loss averaged over the batch.

    """
    err = jnp.abs(y_pred - y_true)
    err_clipped = jnp.minimum(err, delta)
    loss = 0.5 * jnp.square(err_clipped) + delta * (err - err_clipped)
    return _mean_with_weights(loss, w)
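A quick sketch (illustrative, import path assumed) showing the two regimes: errors within delta are penalized quadratically, larger errors only linearly:

import jax.numpy as jnp
from coax.value_losses import huber

y_true = jnp.array([0.0, 0.0])
y_pred = jnp.array([0.5, 3.0])
# per-sample terms with delta=1.0:
#   |err| = 0.5 <= delta  ->  0.5 * 0.5**2             = 0.125  (quadratic)
#   |err| = 3.0 >  delta  ->  1.0 * 3.0 - 0.5 * 1.0**2 = 2.5    (linear)
print(huber(y_true, y_pred, delta=1.0))  # mean -> 1.3125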
def logloss(y_true, y_pred, w=None):
    r"""

    Logistic loss function for binary classification, where the target is
    `y_true` = :math:`y\in\{0,1\}` and the model output is a probability
    `y_pred` = :math:`\hat{y}\in[0,1]`:

    .. math::

        L\ =\ -y\log(\hat{y}) - (1 - y)\log(1 - \hat{y})

    Parameters
    ----------
    y_true : ndarray

        The binary target, encoded as :math:`y\in\{0,1\}`.

    y_pred : (ndarray of) float

        The predicted output, represented by a probability
        :math:`\hat{y}\in[0,1]`.

    w : ndarray, optional

        Sample weights.

    Returns
    -------
    loss : scalar ndarray

        The loss averaged over the batch.

    """
    loss = -y_true * jnp.log(y_pred) - (1. - y_true) * jnp.log(1. - y_pred)
    return _mean_with_weights(loss, w)
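A small sketch on made-up probabilities (import path assumed); note that this generic form evaluates jnp.log directly on y_pred, so predictions of exactly 0 or 1 produce infinities:

import jax.numpy as jnp
from coax.value_losses import logloss

y_true = jnp.array([1.0, 0.0])
y_pred = jnp.array([0.9, 0.2])
# terms: -log(0.9) = 0.1054 and -log(1 - 0.2) = 0.2231
print(logloss(y_true, y_pred))  # mean -> 0.1643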
def logloss_sign(y_true_sign, logits, w=None):
    r"""

    Logistic loss function specific to the case in which the target is a sign
    :math:`y\in\{-1,1\}` and the model output is a logit
    :math:`\hat{z}\in\mathbb{R}`.

    .. math::

        L\ =\ \log(1 + \exp(-y\,\hat{z}))

    This version tends to be more numerically stable than the generic
    implementation, because it avoids having to map the predicted logit to a
    probability.

    Parameters
    ----------
    y_true_sign : ndarray

        The binary target, encoded as :math:`y=\pm1`.

    logits : ndarray

        The predicted output, represented by a logit
        :math:`\hat{z}\in\mathbb{R}`.

    w : ndarray, optional

        Sample weights.

    Returns
    -------
    loss : scalar ndarray

        The loss averaged over the batch.

    """
    loss = jnp.log(1.0 + jnp.exp(-y_true_sign * logits))
    return _mean_with_weights(loss, w)
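A sketch (import paths assumed) verifying that this matches the generic logloss once the logits are squashed through a sigmoid and the targets are mapped from {-1, 1} to {0, 1}:

import jax
import jax.numpy as jnp
from coax.value_losses import logloss, logloss_sign

y_sign = jnp.array([1.0, -1.0])  # targets encoded as +/- 1
logits = jnp.array([2.0, -1.0])
print(logloss_sign(y_sign, logits))          # ~0.2201

y01 = (y_sign + 1.0) / 2.0                   # {-1, 1} -> {0, 1}
print(logloss(y01, jax.nn.sigmoid(logits)))  # same value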
def _mean_with_weights(loss, w):
    if w is not None:
        assert w.ndim == 1
        assert loss.ndim >= 1
        assert loss.shape[0] == w.shape[0]
        loss = jax.vmap(jnp.multiply)(w, loss)
    return jnp.mean(loss)
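An illustrative check of what the jax.vmap(jnp.multiply) call does: it pairs w[i] with loss[i] along the leading axis, so each per-sample loss row is scaled by its scalar weight, equivalent to broadcasting w over the trailing axes:

import jax
import jax.numpy as jnp

loss = jnp.array([[1.0, 2.0], [3.0, 4.0]])
w = jnp.array([1.0, 0.5])
a = jax.vmap(jnp.multiply)(w, loss)  # row i scaled by w[i]
b = w[:, None] * loss                # same result via broadcasting
print(jnp.allclose(a, b))            # True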
def quantile_huber(y_true, y_pred, quantiles, w=None, delta=1.0):
    r"""

    `Quantile Huber <https://arxiv.org/abs/1806.06923>`_ loss function.

    .. math::

        \delta_{ij} &= y_j - \hat{y}_i\\
        \rho^\kappa_\tau(\delta_{ij}) &= |\tau - \mathbb{I}{\{ \delta_{ij} < 0 \}}|
            \ \frac{\mathcal{L}_\kappa(\delta_{ij})}{\kappa},\ \quad \text{with}\\
        \mathcal{L}_\kappa(\delta_{ij}) &= \begin{cases}
            \frac{1}{2} \delta_{ij}^2,\quad
                \ &\text{if } |\delta_{ij}| \le \kappa\\
            \kappa (|\delta_{ij}| - \frac{1}{2}\kappa),\quad
                \ &\text{otherwise}
        \end{cases}

    Parameters
    ----------
    y_true : ndarray

        The target :math:`y\in\mathbb{R}^2`.

    y_pred : ndarray

        The predicted output :math:`\hat{y}\in\mathbb{R}^2`.

    quantiles : ndarray

        The quantiles of the prediction :math:`\tau\in\mathbb{R}^2`.

    w : ndarray, optional

        Sample weights.

    delta : float, optional

        The scale of the quadratic-to-linear transition.

    Returns
    -------
    loss : scalar ndarray

        The loss averaged over the batch.

    """
    y_pred = y_pred[..., None]
    y_true = y_true[..., None, :]
    quantiles = quantiles[..., None]
    td_error = y_true - y_pred
    td_error_abs = jnp.abs(td_error)
    err_clipped = jnp.minimum(td_error_abs, delta)
    elementwise_huber_loss = 0.5 * jnp.square(err_clipped) + delta * (td_error_abs - err_clipped)
    elementwise_quantile_huber_loss = jnp.abs(
        quantiles - (td_error < 0)) * elementwise_huber_loss / delta
    quantile_huber_loss = elementwise_quantile_huber_loss.sum(axis=-1)
    return _mean_with_weights(quantile_huber_loss, w=w)
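A shape-level sketch (hypothetical batch, import path assumed): y_pred holds one predicted value per quantile fraction, y_true holds target samples, and the broadcasting in the function body pairs every prediction i with every target j:

import jax.numpy as jnp
from coax.value_losses import quantile_huber

batch, num_quantiles = 4, 8
y_pred = jnp.linspace(-1.0, 1.0, batch * num_quantiles).reshape(batch, num_quantiles)
y_true = jnp.zeros((batch, num_quantiles))  # target samples (same count here, for simplicity)
quantiles = jnp.tile((jnp.arange(num_quantiles) + 0.5) / num_quantiles, (batch, 1))
print(quantile_huber(y_true, y_pred, quantiles))  # scalar, averaged over the batch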