Source code for coax.proba_dists._squashed_normal

import jax
import jax.numpy as jnp
import numpy as onp

from ._base import BaseProbaDist
from ._normal import NormalDist


class SquashedNormalDist(BaseProbaDist):
    r"""

    A differentiable squashed normal distribution.

    The input ``dist_params`` to each of the functions is expected to be of the form:

    .. code:: python

        dist_params = {'mu': array([...]), 'logvar': array([...])}

    which represent the (conditional) distribution parameters. Here, ``mu`` is the mean
    :math:`\mu` and ``logvar`` is the log-variance :math:`\log(\sigma^2)`.

    Parameters
    ----------
    space : gymnasium.spaces.Box

        The gymnasium-style space that specifies the domain of the distribution.

    clip_logvar : pair of floats, optional

        The range of values to allow for the log-variance of the distribution.

    """
    def __init__(self, space, clip_logvar=None):
        super().__init__(space)

        self._normal_dist = NormalDist(space=space, clip_logvar=clip_logvar)
        self._scale = (space.high - space.low) / 2.0
        self._offset = (space.high + space.low) / 2.0

        def sample(dist_params, rng):
            X = self._normal_dist.sample(dist_params, rng)
            return jnp.tanh(X) * self._scale + self._offset

        def mean(dist_params):
            mu = self._normal_dist.mean(dist_params)
            return jnp.tanh(mu) * self._scale + self._offset

        def mode(dist_params):
            return mean(dist_params)

        arctanh_eps = 1e-7  # avoid arctanh(1) = arctanh(-1) = inf

        def log_proba(dist_params, X):
            X = jnp.arctanh(jnp.clip(X, a_min=-1.0 + arctanh_eps, a_max=1.0 - arctanh_eps))
            logp = self._normal_dist.log_proba(dist_params, X)
            return logp - jnp.sum(2 * (jnp.log(2) - X - jnp.log(1 + jnp.exp(-2 * X))), axis=-1)

        self._sample_func = jax.jit(sample)
        self._mean_func = jax.jit(mean)
        self._mode_func = jax.jit(mode)
        self._log_proba_func = jax.jit(log_proba)
        self._affine_transform_func = self._normal_dist.affine_transform

    @property
    def default_priors(self):
        return self._normal_dist.default_priors
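
    # Derivation sketch for the change-of-variables term in ``log_proba`` above (comments added
    # for exposition). For x = tanh(u), the density picks up a per-dimension correction
    # -log(1 - tanh(u)^2). Using
    #
    #     1 - tanh(u)^2  =  4 * exp(-2u) / (1 + exp(-2u))^2,
    #
    # the per-dimension log-Jacobian is
    #
    #     log(1 - tanh(u)^2)  =  2 * (log(2) - u - log(1 + exp(-2u))),
    #
    # which is exactly the quantity subtracted (summed over the event axis) in the closure. This
    # form stays finite for large |u|, where evaluating log(1 - tanh(u)^2) directly would
    # underflow to log(0).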

    def preprocess_variate(self, rng, X):
        X = jnp.asarray(X, dtype=self.space.dtype)   # ensure ndarray
        X = jnp.reshape(X, (-1, *self.space.shape))  # ensure batch axis
        return X

    def postprocess_variate(self, rng, X, index=0, batch_mode=False):
        X = jnp.asarray(X, dtype=self.space.dtype)   # ensure ndarray
        X = jnp.reshape(X, (-1, *self.space.shape))  # ensure correct shape
        return X if batch_mode else onp.asanyarray(X[index])

    @property
    def sample(self):
        r"""

        JIT-compiled function that generates differentiable variates using the reparametrization
        trick, i.e. :math:`x\sim\tanh(\mathcal{N}(\mu,\sigma^2))` is implemented as

        .. math::

            \varepsilon\ &\sim\ \mathcal{N}(0,1) \\
            x\ &=\ \tanh(\mu + \sigma\,\varepsilon)

        Parameters
        ----------
        dist_params : pytree with ndarray leaves

            A batch of distribution parameters.

        rng : PRNGKey

            A key for seeding the pseudo-random number generator.

        Returns
        -------
        X : ndarray

            A batch of differentiable variates.

        """
        return self._sample_func

    @property
    def mean(self):
        r"""

        JIT-compiled function that generates differentiable means of the distribution, in this
        case simply :math:`\tanh(\mu)`, rescaled to the bounds of the space.

        Parameters
        ----------
        dist_params : pytree with ndarray leaves

            A batch of distribution parameters.

        Returns
        -------
        X : ndarray

            A batch of differentiable variates.

        """
        return self._mean_func

    @property
    def mode(self):
        r"""

        JIT-compiled function that generates differentiable modes of the distribution, which for
        this squashed normal distribution coincides with the :attr:`mean`.

        Parameters
        ----------
        dist_params : pytree with ndarray leaves

            A batch of distribution parameters.

        Returns
        -------
        X : ndarray

            A batch of differentiable variates.

        """
        return self._mode_func

    @property
    def log_proba(self):
        r"""

        JIT-compiled function that evaluates log-probabilities.

        Parameters
        ----------
        dist_params : pytree with ndarray leaves

            A batch of distribution parameters.

        X : ndarray

            A batch of variates, e.g. a batch of actions :math:`a` collected from experience.

        Returns
        -------
        logP : ndarray of floats

            A batch of log-probabilities associated with the provided variates.

        """
        return self._log_proba_func

    @property
    def entropy(self):
        r"""

        JIT-compiled function that computes the entropy of the underlying (unsquashed) normal
        distribution:

        .. math::

            H\ =\ -\mathbb{E}_p \log p
            \ =\ \frac12\left( \log(2\pi\sigma^2) + 1\right)

        Parameters
        ----------
        dist_params : pytree with ndarray leaves

            A batch of distribution parameters.

        Returns
        -------
        H : ndarray of floats

            A batch of entropy values.

        """
        return self._normal_dist.entropy

    @property
    def cross_entropy(self):
        r"""

        JIT-compiled function that computes the cross-entropy of a distribution :math:`q` relative
        to another distribution :math:`p`:

        .. math::

            \text{CE}[p,q]\ =\ -\mathbb{E}_p \log q
            \ =\ \frac12\left( \log(2\pi\sigma_q^2)
                + \frac{(\mu_p-\mu_q)^2+\sigma_p^2}{\sigma_q^2} \right)

        Parameters
        ----------
        dist_params_p : pytree with ndarray leaves

            The distribution parameters of the *base* distribution :math:`p`.

        dist_params_q : pytree with ndarray leaves

            The distribution parameters of the *auxiliary* distribution :math:`q`.

        """
        return self._normal_dist.cross_entropy

    @property
    def kl_divergence(self):
        r"""

        JIT-compiled function that computes the Kullback-Leibler divergence of a distribution
        :math:`q` relative to another distribution :math:`p`:

        .. math::

            \text{KL}[p,q]\ =\ -\mathbb{E}_p \left(\log q - \log p\right)
            \ =\ \frac12\left(
                \log(\sigma_q^2) - \log(\sigma_p^2)
                + \frac{(\mu_p-\mu_q)^2+\sigma_p^2}{\sigma_q^2} - 1
            \right)

        Parameters
        ----------
        dist_params_p : pytree with ndarray leaves

            The distribution parameters of the *base* distribution :math:`p`.

        dist_params_q : pytree with ndarray leaves

            The distribution parameters of the *auxiliary* distribution :math:`q`.

        """
        return self._normal_dist.kl_divergence
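

# ----------------------------------------------------------------------------
# Usage sketch (added for exposition, not part of the library source). A
# minimal, hedged example of drawing squashed-normal variates and evaluating
# their log-probabilities; the Box bounds, shapes, and RNG seed below are
# illustrative assumptions. With low=-1 and high=1 the affine rescaling is the
# identity, so sampled variates lie in (-1, 1).
# ----------------------------------------------------------------------------
if __name__ == '__main__':
    import gymnasium

    space = gymnasium.spaces.Box(low=-1.0, high=1.0, shape=(3,))
    dist = SquashedNormalDist(space)

    # a single batch of distribution parameters: mu = 0, sigma^2 = 1
    dist_params = {'mu': jnp.zeros((1, 3)), 'logvar': jnp.zeros((1, 3))}

    rng = jax.random.PRNGKey(42)
    x = dist.sample(dist_params, rng)      # differentiable variates in (-1, 1)
    logp = dist.log_proba(dist_params, x)  # log-probabilities, shape (1,)
    h = dist.entropy(dist_params)          # entropy of the underlying normal

    print(x, logp, h)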