Source code for coax.proba_dists._normal

import warnings

import jax
import jax.numpy as jnp
import numpy as onp
from gymnasium.spaces import Box

from ..utils import clipped_logit, jit
from ._base import BaseProbaDist


__all__ = (
    'NormalDist',
)


class NormalDist(BaseProbaDist):
    r"""

    A differentiable normal distribution.

    The input ``dist_params`` to each of the functions is expected to be of the form:

    .. code:: python

        dist_params = {'mu': array([...]), 'logvar': array([...])}

    which represent the (conditional) distribution parameters. Here, ``mu`` is the mean
    :math:`\mu` and ``logvar`` is the log-variance :math:`\log(\sigma^2)`.

    Parameters
    ----------
    space : gymnasium.spaces.Box

        The gymnasium-style space that specifies the domain of the distribution.

    clip_box : pair of floats, optional

        The range of values to allow for *clean* (compact) variates. This is mainly to ensure
        reasonable values when one or more dimensions of the Box space have very large ranges,
        while in reality only a small part of that range is occupied.

    clip_reals : pair of floats, optional

        The range of values to allow for *raw* (decompactified) variates, the *reals*, used
        internally. This range is set for numeric stability. Namely, the
        :attr:`postprocess_variate` method compactifies the reals to a closed interval (Box) by
        applying a logistic sigmoid. Setting a finite range for :code:`clip_reals` ensures that
        the sigmoid doesn't fully saturate.

    clip_logvar : pair of floats, optional

        The range of values to allow for the log-variance of the distribution.

    """
    def __init__(
            self, space,
            clip_box=(-256., 256.),
            clip_reals=(-30., 30.),
            clip_logvar=(-20., 20.)):

        if not isinstance(space, Box):
            raise TypeError(f"{self.__class__.__name__} can only be defined over Box spaces")

        super().__init__(space)
        self.clip_box = clip_box
        self.clip_reals = clip_reals
        self.clip_logvar = clip_logvar

        self._low = onp.maximum(onp.expand_dims(self.space.low, axis=0), self.clip_box[0])
        self._high = onp.minimum(onp.expand_dims(self.space.high, axis=0), self.clip_box[1])
        onp.testing.assert_array_less(
            self._low, self._high,
            "Box clipping resulted in inconsistent boundaries: "
            f"low={self._low}, high={self._high}; please specify proper clipping values, "
            "e.g. NormalDist(space, clip_box=(-1000., 1000.))")
        if onp.any(self._low > self.space.low) or onp.any(self._high < self.space.high):
            with onp.printoptions(precision=1):
                warnings.warn(
                    f"one or more dimensions of Box(low={self.space.low}, "
                    f"high={self.space.high}) will be clipped to "
                    f"Box(low={self._low[0]}, high={self._high[0]})")

        log_2pi = onp.asarray(1.8378770664093453)  # abbreviation: log(2 * pi)

        def check_shape(x, name, flatten):
            if not isinstance(x, jnp.ndarray):
                raise TypeError(f"expected a jax.numpy.ndarray, got: {type(x)}")
            if not (x.ndim == len(space.shape) + 1 and x.shape[1:] == space.shape):
                expected = ', '.join(f'{i:d}' for i in space.shape)
                raise ValueError(f"expected {name}.shape: (?, {expected}), got: {x.shape}")
            if flatten:
                x = x.reshape(x.shape[0], -1)  # batch-flatten
            if name.startswith("logvar"):
                x = jnp.clip(x, *self.clip_logvar)
            return x

        def sample(dist_params, rng):
            mu = check_shape(dist_params['mu'], name='mu', flatten=True)
            logvar = check_shape(dist_params['logvar'], name='logvar', flatten=True)
            X = mu + jnp.exp(logvar / 2) * jax.random.normal(rng, mu.shape)
            return X.reshape(-1, *self.space.shape)

        def mean(dist_params):
            mu = check_shape(dist_params['mu'], name='mu', flatten=False)
            return mu

        def mode(dist_params):
            return mean(dist_params)

        def log_proba(dist_params, X):
            X = check_shape(X, name='X', flatten=True)
            mu = check_shape(dist_params['mu'], name='mu', flatten=True)
            logvar = check_shape(dist_params['logvar'], name='logvar', flatten=True)
            n = logvar.shape[-1]
            logdetvar = jnp.sum(logvar, axis=-1)  # log(det(M)) = tr(log(M))
            quadratic = jnp.einsum('ij,ij->i', jnp.square(X - mu), jnp.exp(-logvar))
            logp = -0.5 * (n * log_2pi + logdetvar + quadratic)
            return logp

        def entropy(dist_params):
            logvar = check_shape(dist_params['logvar'], name='logvar', flatten=True)
            assert logvar.ndim == 2  # check if flattened
            logdetvar = jnp.sum(logvar, axis=-1)  # log(det(M)) = tr(log(M))
            n = logvar.shape[-1]
            return 0.5 * (n * log_2pi + logdetvar + n)

        def cross_entropy(dist_params_p, dist_params_q):
            mu1 = check_shape(dist_params_p['mu'], name='mu_p', flatten=True)
            mu2 = check_shape(dist_params_q['mu'], name='mu_q', flatten=True)
            logvar1 = check_shape(dist_params_p['logvar'], name='logvar_p', flatten=True)
            logvar2 = check_shape(dist_params_q['logvar'], name='logvar_q', flatten=True)
            n = mu1.shape[-1]
            assert n == mu2.shape[-1] == logvar1.shape[-1] == logvar2.shape[-1]
            var1 = jnp.exp(logvar1)
            var2_inv = jnp.exp(-logvar2)
            logdetvar2 = jnp.sum(logvar2, axis=-1)  # log(det(M)) = tr(log(M))
            quadratic = jnp.einsum('ij,ij->i', var1 + jnp.square(mu1 - mu2), var2_inv)
            return 0.5 * (n * log_2pi + logdetvar2 + quadratic)

        def kl_divergence(dist_params_p, dist_params_q):
            mu1 = check_shape(dist_params_p['mu'], name='mu_p', flatten=True)
            mu2 = check_shape(dist_params_q['mu'], name='mu_q', flatten=True)
            logvar1 = check_shape(dist_params_p['logvar'], name='logvar_p', flatten=True)
            logvar2 = check_shape(dist_params_q['logvar'], name='logvar_q', flatten=True)
            n = mu1.shape[-1]
            assert n == mu2.shape[-1] == logvar1.shape[-1] == logvar2.shape[-1]
            var1 = jnp.exp(logvar1)
            var2_inv = jnp.exp(-logvar2)
            logdetvar1 = jnp.sum(logvar1, axis=-1)  # log(det(M)) = tr(log(M))
            logdetvar2 = jnp.sum(logvar2, axis=-1)  # log(det(M)) = tr(log(M))
            quadratic = jnp.einsum('ij,ij->i', var1 + jnp.square(mu1 - mu2), var2_inv)
            return 0.5 * (logdetvar2 - logdetvar1 + quadratic - n)

        def affine_transform_func(dist_params, scale, shift, value_transform=None):
            if value_transform is None:
                f = f_inv = lambda x: x
            else:
                f, f_inv = value_transform
            mu = check_shape(dist_params['mu'], name='mu', flatten=False)
            logvar = check_shape(dist_params['logvar'], name='logvar', flatten=False)
            var_new = f(f_inv(jnp.exp(logvar)) * jnp.square(scale))
            return {'mu': f(f_inv(mu) + shift), 'logvar': jnp.log(var_new)}

        self._sample_func = jit(sample)
        self._mean_func = jit(mean)
        self._mode_func = jit(mode)
        self._log_proba_func = jit(log_proba)
        self._entropy_func = jit(entropy)
        self._cross_entropy_func = jit(cross_entropy)
        self._kl_divergence_func = jit(kl_divergence)
        self._affine_transform_func = jit(affine_transform_func, static_argnums=(3,))

    @property
    def default_priors(self):
        shape = (1, *self.space.shape)  # include batch axis
        return {'mu': jnp.zeros(shape), 'logvar': jnp.zeros(shape)}
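
    # The two variate-processing methods below translate between the compact Box domain and the
    # unbounded "reals" on which the normal distribution itself lives: ``preprocess_variate`` maps
    # Box values to the reals with a clipped logit, while ``postprocess_variate`` applies the
    # inverse logistic sigmoid (after clipping to ``clip_reals`` for numerical stability).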

    def preprocess_variate(self, rng, X):
        X = jnp.asarray(X, dtype=self.space.dtype)                     # ensure ndarray
        X = jnp.reshape(X, (-1, *self.space.shape))                    # ensure batch axis
        X = jnp.clip(X, self._low, self._high)                         # clip to be safe
        X = clipped_logit((X - self._low) / (self._high - self._low))  # closed intervals -> reals
        return X

    def postprocess_variate(self, rng, X, index=0, batch_mode=False):
        X = jnp.asarray(X, dtype=self.space.dtype)                     # ensure ndarray
        X = jnp.reshape(X, (-1, *self.space.shape))                    # ensure correct shape
        X = jnp.clip(X, *self.clip_reals)                              # clip for stability
        X = self._low + (self._high - self._low) * jax.nn.sigmoid(X)   # reals -> closed interval
        return X if batch_mode else onp.asanyarray(X[index])
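
    # The properties below merely expose the jit-compiled closures constructed in ``__init__``;
    # their docstrings document the user-facing API of this distribution.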

    @property
    def sample(self):
        r"""

        JIT-compiled function that generates differentiable variates using the reparametrization
        trick, i.e. :math:`x\sim\mathcal{N}(\mu,\sigma^2)` is implemented as

        .. math::

            \varepsilon\ &\sim\ \mathcal{N}(0,1) \\
            x\ &=\ \mu + \sigma\,\varepsilon

        Parameters
        ----------
        dist_params : pytree with ndarray leaves

            A batch of distribution parameters.

        rng : PRNGKey

            A key for seeding the pseudo-random number generator.

        Returns
        -------
        X : ndarray

            A batch of differentiable variates.

        """
        return self._sample_func

    @property
    def mean(self):
        r"""

        JIT-compiled function that generates differentiable means of the distribution, in this
        case simply :math:`\mu`.

        Parameters
        ----------
        dist_params : pytree with ndarray leaves

            A batch of distribution parameters.

        Returns
        -------
        X : ndarray

            A batch of differentiable variates.

        """
        return self._mean_func

    @property
    def mode(self):
        r"""

        JIT-compiled function that generates differentiable modes of the distribution, which for
        a normal distribution is the same as the :attr:`mean`.

        Parameters
        ----------
        dist_params : pytree with ndarray leaves

            A batch of distribution parameters.

        Returns
        -------
        X : ndarray

            A batch of differentiable variates.

        """
        return self._mode_func

    @property
    def log_proba(self):
        r"""

        JIT-compiled function that evaluates log-probabilities.

        Parameters
        ----------
        dist_params : pytree with ndarray leaves

            A batch of distribution parameters.

        X : ndarray

            A batch of variates, e.g. a batch of actions :math:`a` collected from experience.

        Returns
        -------
        logP : ndarray of floats

            A batch of log-probabilities associated with the provided variates.

        """
        return self._log_proba_func

    @property
    def entropy(self):
        r"""

        JIT-compiled function that computes the entropy of the distribution.

        .. math::

            H\ =\ -\mathbb{E}_p \log p
            \ =\ \frac12\left( \log(2\pi\sigma^2) + 1 \right)

        Parameters
        ----------
        dist_params : pytree with ndarray leaves

            A batch of distribution parameters.

        Returns
        -------
        H : ndarray of floats

            A batch of entropy values.

        """
        return self._entropy_func

    @property
    def cross_entropy(self):
        r"""

        JIT-compiled function that computes the cross-entropy of a normal distribution
        :math:`q` relative to another normal distribution :math:`p`:

        .. math::

            \text{CE}[p,q]\ =\ -\mathbb{E}_p \log q
            \ =\ \frac12\left(
                \log(2\pi\sigma_q^2)
                + \frac{(\mu_p-\mu_q)^2+\sigma_p^2}{\sigma_q^2}
            \right)

        Parameters
        ----------
        dist_params_p : pytree with ndarray leaves

            The distribution parameters of the *base* distribution :math:`p`.

        dist_params_q : pytree with ndarray leaves

            The distribution parameters of the *auxiliary* distribution :math:`q`.

        """
        return self._cross_entropy_func

    @property
    def kl_divergence(self):
        r"""

        JIT-compiled function that computes the Kullback-Leibler divergence of a normal
        distribution :math:`q` relative to another normal distribution :math:`p`:

        .. math::

            \text{KL}[p,q]\ =\ -\mathbb{E}_p \left(\log q - \log p\right)
            \ =\ \frac12\left(
                \log(\sigma_q^2) - \log(\sigma_p^2)
                + \frac{(\mu_p-\mu_q)^2+\sigma_p^2}{\sigma_q^2}
                - 1
            \right)

        Parameters
        ----------
        dist_params_p : pytree with ndarray leaves

            The distribution parameters of the *base* distribution :math:`p`.

        dist_params_q : pytree with ndarray leaves

            The distribution parameters of the *auxiliary* distribution :math:`q`.

        """
        return self._kl_divergence_func
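

if __name__ == '__main__':
    # Minimal usage sketch, not part of the upstream module: it exercises the API defined above
    # on a toy Box space. The space bounds/shape, PRNG seed and parameter values are illustrative
    # assumptions, not values prescribed by coax.
    space = Box(low=-1.0, high=1.0, shape=(3,))
    dist = NormalDist(space)
    rng = jax.random.PRNGKey(13)

    # standard-normal priors over the flattened reals, with a leading batch axis of size 1
    dist_params = dist.default_priors

    # differentiable sampling via the reparametrization trick; variates live in the reals
    X = dist.sample(dist_params, rng)              # shape: (1, 3)
    logp = dist.log_proba(dist_params, X)          # shape: (1,)
    H = dist.entropy(dist_params)                  # shape: (1,)

    # KL divergence between the default priors and a shifted, narrower distribution
    other = {'mu': jnp.full((1, 3), 0.5), 'logvar': jnp.full((1, 3), -1.0)}
    kl = dist.kl_divergence(dist_params, other)    # shape: (1,)

    # squash the raw variates into the Box and map the resulting action back to the reals
    a = dist.postprocess_variate(rng, X)           # single action, shape: (3,)
    X_roundtrip = dist.preprocess_variate(rng, a)  # shape: (1, 3), approximately X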