# Source code for coax.proba_dists._categorical

import jax
import jax.numpy as jnp
from gymnasium.spaces import Discrete

from ..utils import argmax, jit
from ._base import BaseProbaDist


# Public names exported by this module.
__all__ = ('CategoricalDist',)


class CategoricalDist(BaseProbaDist):
    r"""

    A differentiable categorical distribution.

    The input ``dist_params`` to each of the functions is expected to be of the form:

    .. code:: python

        dist_params = {'logits': array([...])}

    which represent the (conditional) distribution parameters. The ``logits``, denoted
    :math:`z\in\mathbb{R}^n`, are related to the categorical distribution parameters
    :math:`p\in\Delta^n` via a softmax:

    .. math::

        p_k\ =\ \text{softmax}_k(z)\ =\ \frac{\text{e}^{z_k}}{\sum_j\text{e}^{z_j}}

    Parameters
    ----------
    space : gymnasium.spaces.Discrete

        The gymnasium-style space that specifies the domain of the distribution.

    gumbel_softmax_tau : positive float, optional

        The parameter :math:`\tau` specifies the sharpness of the Gumbel-softmax sampling (see
        :func:`sample` method below). A good value for :math:`\tau` balances the trade-off between
        getting proper deterministic variates (i.e. one-hot vectors) versus getting smooth
        differentiable variates.

    Raises
    ------
    TypeError

        If ``space`` is not a ``Discrete`` space.

    ValueError

        If ``gumbel_softmax_tau`` is not strictly positive.

    """
    __slots__ = (*BaseProbaDist.__slots__, '_gumbel_softmax_tau')

    def __init__(self, space, gumbel_softmax_tau=0.2):
        if not isinstance(space, Discrete):
            raise TypeError(f"{self.__class__.__name__} can only be defined over Discrete spaces")

        # tau divides the logits in `sample` and `mode`; zero, negative or NaN values would yield
        # NaN or inverted variates, so reject them up front (``not x > 0`` also catches NaN).
        if not gumbel_softmax_tau > 0:
            raise ValueError(f"gumbel_softmax_tau must be positive, got: {gumbel_softmax_tau}")

        super().__init__(space)
        self._gumbel_softmax_tau = gumbel_softmax_tau

        def check_shape(x, name):
            # every array handled below must be a batch of shape (batch_size, space.n)
            if not isinstance(x, jnp.ndarray):
                raise TypeError(f"expected an jax.numpy.ndarray, got: {type(x)}")
            if not (x.ndim == 2 and x.shape[1] == space.n):
                raise ValueError(f"expected {name}.shape: (?, {space.n}), got: {x.shape}")
            return x

        def safe_log_terms(weights, log_terms):
            # Apply the convention 0 * log(0) = 0. Without it, a class with zero weight and a
            # non-finite log-term (e.g. a logit of -inf used to mask a class) contributes
            # 0 * inf = NaN to the weighted sums below. Only those entries are replaced, so values
            # and gradients for finite inputs are untouched.
            is_vacuous = (weights == 0) & ~jnp.isfinite(log_terms)
            return jnp.where(is_vacuous, 0., log_terms)

        def sample(dist_params, rng):
            logits = check_shape(dist_params['logits'], 'logits')
            logp = jax.nn.log_softmax(logits)
            u = jax.random.uniform(rng, logp.shape)
            g = -jnp.log(-jnp.log(u))  # g ~ Gumbel(0,1)
            return jax.nn.softmax((g + logp) / self.gumbel_softmax_tau)

        def mean(dist_params):
            logits = check_shape(dist_params['logits'], 'logits')
            return jax.nn.softmax(logits)

        def mode(dist_params):
            logits = check_shape(dist_params['logits'], 'logits')
            logp = jax.nn.log_softmax(logits)
            return jax.nn.softmax(logp / self.gumbel_softmax_tau)

        def log_proba(dist_params, X):
            X = check_shape(X, 'X')
            logits = check_shape(dist_params['logits'], 'logits')
            logp = jax.nn.log_softmax(logits)
            return jnp.einsum('ij,ij->i', X, safe_log_terms(X, logp))

        def entropy(dist_params):
            logits = check_shape(dist_params['logits'], 'logits')
            logp = jax.nn.log_softmax(logits)
            p = jnp.exp(logp)
            return jnp.einsum('ij,ij->i', p, safe_log_terms(p, -logp))

        def cross_entropy(dist_params_p, dist_params_q):
            logits_p = check_shape(dist_params_p['logits'], 'logits_p')
            logits_q = check_shape(dist_params_q['logits'], 'logits_q')
            p = jax.nn.softmax(logits_p)
            logq = jax.nn.log_softmax(logits_q)
            return jnp.einsum('ij,ij->i', p, safe_log_terms(p, -logq))

        def kl_divergence(dist_params_p, dist_params_q):
            logits_p = check_shape(dist_params_p['logits'], 'logits_p')
            logits_q = check_shape(dist_params_q['logits'], 'logits_q')
            logp = jax.nn.log_softmax(logits_p)
            logq = jax.nn.log_softmax(logits_q)
            p = jnp.exp(logp)
            return jnp.einsum('ij,ij->i', p, safe_log_terms(p, logp - logq))

        def affine_transform_func(dist_params, scale, shift, value_transform=None):
            raise NotImplementedError("affine_transform is ill-defined on categorical variables")

        self._sample_func = jit(sample)
        self._mean_func = jit(mean)
        self._mode_func = jit(mode)
        self._log_proba_func = jit(log_proba)
        self._entropy_func = jit(entropy)
        self._cross_entropy_func = jit(cross_entropy)
        self._kl_divergence_func = jit(kl_divergence)
        # not jitted: it raises unconditionally
        self._affine_transform_func = affine_transform_func

    @property
    def hyperparams(self):
        """ Dict holding the ``gumbel_softmax_tau`` value of this distribution. """
        return {'gumbel_softmax_tau': self.gumbel_softmax_tau}

    @property
    def gumbel_softmax_tau(self):
        r""" The sharpness parameter :math:`\tau` used by :attr:`sample` and :attr:`mode`. """
        return self._gumbel_softmax_tau

    @property
    def default_priors(self):
        """ Distribution parameters with all-zero logits, as a batch of size one. """
        return {'logits': jnp.zeros((1, self.space.n))}

    def preprocess_variate(self, rng, X):
        """

        One-hot encode integer variates.

        Parameters
        ----------
        rng : PRNGKey

            Not used by this implementation.

        X : int or 1d array of ints

            A single variate or a flat batch of variates.

        Returns
        -------
        X : ndarray

            The one-hot encoded variates, reshaped to ``(-1, space.n)``.

        """
        X = jnp.asarray(X)
        assert X.ndim <= 1, f"unexpected X.shape: {X.shape}"
        assert jnp.issubdtype(X.dtype, jnp.integer), f"expected an integer dtype, got {X.dtype}"
        return jax.nn.one_hot(X, self.space.n).reshape(-1, self.space.n)

    def postprocess_variate(self, rng, X, index=0, batch_mode=False):
        """

        Turn a batch of (almost-)one-hot encoded variates back into class indices.

        Parameters
        ----------
        rng : PRNGKey

            A key that is passed on to the ``argmax`` helper.

        X : ndarray

            A batch of variates of shape ``(batch_size, space.n)``.

        index : int, optional

            Which batch entry to return when ``batch_mode=False``.

        batch_mode : bool, optional

            Whether to return the whole batch of indices instead of a single ``int``.

        Returns
        -------
        x : int or ndarray

            The index at position ``index`` as a Python ``int``, or the full batch of indices if
            ``batch_mode=True``.

        """
        assert X.ndim == 2
        assert X.shape[1] == self.space.n
        X = argmax(rng, X, axis=1)
        return X if batch_mode else int(X[index])

    @property
    def sample(self):
        r"""

        JIT-compiled function that generates differentiable variates using Gumbel-softmax sampling.
        :math:`x\sim\text{Cat}(p)` is implemented as

        .. math::

            u_k\ &\sim\ \text{Unif}(0, 1) \\
            g_k\ &=\ -\log(-\log(u_k)) \\
            x_k\ &=\ \text{softmax}_k\left(
                \frac{g_k + \log p_k}{\tau} \right)

        Parameters
        ----------
        dist_params : pytree with ndarray leaves

            A batch of distribution parameters.

        rng : PRNGKey

            A key for seeding the pseudo-random number generator.

        Returns
        -------
        X : ndarray

            A batch of variates :math:`x\sim\text{Cat}(p)`. In order to ensure differentiability of
            the variates this is not an integer, but instead an *almost*-one-hot encoded version
            thereof.

            For example, instead of sampling :math:`x=2` from a 4-class categorical distribution,
            Gumbel-softmax will return a vector like :math:`x=[0.05, 0.02, 0.86, 0.07]`. The latter
            representation can be viewed as an *almost*-one-hot encoded version of the former.

        """
        return self._sample_func

    @property
    def mean(self):
        r"""

        JIT-compiled function that generates differentiable means of the distribution. Strictly
        speaking, the mean of a categorical variable is not well defined. We opt for returning the
        raw probabilities: :math:`\text{mean}_k=p_k`.

        Parameters
        ----------
        dist_params : pytree with ndarray leaves

            A batch of distribution parameters.

        Returns
        -------
        X : ndarray

            A batch of would-be variates :math:`x\sim\text{Cat}(p)`. In contrast to the output of
            other methods, these aren't true variates because they are not *almost*-one-hot
            encoded.

        """
        return self._mean_func

    @property
    def mode(self):
        r"""

        JIT-compiled function that generates differentiable modes of the distribution, for which we
        use a similar trick as in Gumbel-softmax sampling:

        .. math::

            \text{mode}_k\ =\ \text{softmax}_k\left( \frac{\log p_k}{\tau} \right)

        Parameters
        ----------
        dist_params : pytree with ndarray leaves

            A batch of distribution parameters.

        Returns
        -------
        X : ndarray

            A batch of modes. In order to ensure differentiability this is not an integer, but
            instead an *almost*-one-hot encoded version thereof.

            For example, instead of returning :math:`x=2` for a 4-class categorical distribution,
            this will return a vector like :math:`x=(0.05, 0.02, 0.86, 0.07)`. The latter
            representation can be viewed as an *almost*-one-hot encoded version of the former.

        """
        return self._mode_func

    @property
    def log_proba(self):
        r"""

        JIT-compiled function that evaluates log-probabilities.

        Entries of ``X`` that are exactly zero contribute nothing, even when the corresponding
        class has zero probability (:math:`0\cdot\log 0 = 0`).

        Parameters
        ----------
        dist_params : pytree with ndarray leaves

            A batch of distribution parameters.

        X : ndarray

            A batch of variates, e.g. a batch of actions :math:`a` collected from experience.

        Returns
        -------
        logP : ndarray of floats

            A batch of log-probabilities associated with the provided variates.

        """
        return self._log_proba_func

    @property
    def entropy(self):
        r"""

        JIT-compiled function that computes the entropy of the distribution.

        .. math::

            H\ =\ -\sum_k p_k \log p_k

        Classes with :math:`p_k=0` contribute zero (:math:`0\cdot\log 0 = 0`).

        Parameters
        ----------
        dist_params : pytree with ndarray leaves

            A batch of distribution parameters.

        Returns
        -------
        H : ndarray of floats

            A batch of entropy values.

        """
        return self._entropy_func

    @property
    def cross_entropy(self):
        r"""

        JIT-compiled function that computes the cross-entropy of a categorical distribution
        :math:`q` relative to another categorical distribution :math:`p`:

        .. math::

            \text{CE}[p,q]\ =\ -\sum_k p_k \log q_k

        Classes with :math:`p_k=0` contribute zero, also when :math:`q_k=0`.

        Parameters
        ----------
        dist_params_p : pytree with ndarray leaves

            The distribution parameters of the *base* distribution :math:`p`.

        dist_params_q : pytree with ndarray leaves

            The distribution parameters of the *auxiliary* distribution :math:`q`.

        Returns
        -------
        CE : ndarray of floats

            A batch of cross-entropy values, one per batch entry.

        """
        return self._cross_entropy_func

    @property
    def kl_divergence(self):
        r"""

        JIT-compiled function that computes the Kullback-Leibler divergence of a categorical
        distribution :math:`q` relative to another categorical distribution :math:`p`:

        .. math::

            \text{KL}[p,q]\ =\ -\sum_k p_k \left(\log q_k -\log p_k\right)

        Classes with :math:`p_k=0` contribute zero, also when :math:`q_k=0`.

        Parameters
        ----------
        dist_params_p : pytree with ndarray leaves

            The distribution parameters of the *base* distribution :math:`p`.

        dist_params_q : pytree with ndarray leaves

            The distribution parameters of the *auxiliary* distribution :math:`q`.

        Returns
        -------
        KL : ndarray of floats

            A batch of KL-divergence values, one per batch entry.

        """
        return self._kl_divergence_func