coax.proba_dists.ProbaDist

A composite probability distribution.

coax.proba_dists.CategoricalDist

A differentiable categorical distribution.

coax.proba_dists.NormalDist

A differentiable normal distribution.

coax.proba_dists.DiscretizedIntervalDist

A categorical distribution over a discretized interval.

coax.proba_dists.EmpiricalQuantileDist

coax.proba_dists.SquashedNormalDist

A differentiable squashed normal distribution.


Probability Distributions

This is a collection of differentiable probability distributions used throughout the package.

Object Reference

class coax.proba_dists.ProbaDist(space)

A composite probability distribution. It consists of a nested structure whose leaves are either coax.proba_dists.CategoricalDist or coax.proba_dists.NormalDist instances.

Parameters:

space (gymnasium.Space) – The gymnasium-style space that specifies the domain of the distribution. This may be any space included in the gymnasium.spaces module.

postprocess_variate(rng, X, index=0, batch_mode=False)

The post-processor specific to variates drawn from this distribution.

This method provides the interface between differentiable, batched variates (i.e. the outputs of sample() and mode()) and the provided gymnasium space.

Parameters:
  • rng (PRNGKey) – A key for seeding the pseudo-random number generator.

  • X (raw variates) – A batch of raw variates, i.e. in the same format as the outputs of sample() and mode().

  • index (int, optional) – The index to pick out from the batch. Note that this only applies if batch_mode=False.

  • batch_mode (bool, optional) – Whether to return a batch or a single instance.

Returns:

x or X (clean variate) – A single clean variate or a batch thereof (if batch_mode=True). A variate is called clean if it is an instance of the gymnasium-style space, i.e. it satisfies x in self.space.

preprocess_variate(rng, X)

The pre-processor that ensures an instance of the space is processed into the same structure as variates drawn from this distribution, i.e. the outputs of sample() and mode().

Parameters:
  • rng (PRNGKey) – A key for seeding the pseudo-random number generator.

  • X (clean variates) – A batch of clean variates, i.e. instances of the gymnasium-style space.

Returns:

X (raw variates) – A batch of raw variates, i.e. in the same format as the outputs of sample() and mode().

property affine_transform

Transform the distribution \(\mathcal{D}\to\mathcal{D}'\) in such a way that its associated variables \(X\sim\mathcal{D}\) and \(X'\sim\mathcal{D}'\) are related via an affine transformation:

\[X' = X\times\text{scale} + \text{shift}\]
Parameters:
  • dist_params (pytree with ndarray leaves) – The distribution parameters of the original distribution \(\mathcal{D}\).

  • scale (float or ndarray) – The multiplicative factor of the affine transformation.

  • shift (float or ndarray) – The additive shift of the affine transformation.

  • value_transform (ValueTransform, optional) –

    The transform to apply to the values before the affine transform, i.e.

    \[X' = f\bigl(f^{-1}(X)\times\text{scale} + \text{shift}\bigr)\]

Returns:

dist_params (pytree with ndarray leaves) – The distribution parameters of the transformed distribution \(\mathcal{D}'\).

property cross_entropy

JIT-compiled function that computes the cross-entropy of a distribution \(q\) relative to another distribution \(p\):

\[\text{CE}[p,q]\ =\ -\mathbb{E}_p \log q\]
Parameters:
  • dist_params_p (pytree with ndarray leaves) – The distribution parameters of the base distribution \(p\).

  • dist_params_q (pytree with ndarray leaves) – The distribution parameters of the auxiliary distribution \(q\).

property default_priors

The default distribution parameters.

property dist_params_structure

The tree structure of the distribution parameters.

property entropy

JIT-compiled function that computes the entropy of the distribution.

\[H\ =\ -\mathbb{E}_p \log p\]
Parameters:

dist_params (pytree with ndarray leaves) – A batch of distribution parameters.

Returns:

H (ndarray of floats) – A batch of entropy values.

property hyperparams

The distribution hyperparameters.

property kl_divergence

JIT-compiled function that computes the Kullback-Leibler divergence of a distribution \(q\) relative to another distribution \(p\):

\[\text{KL}[p,q]\ = -\mathbb{E}_p \left(\log q -\log p\right)\]
Parameters:
  • dist_params_p (pytree with ndarray leaves) – The distribution parameters of the base distribution \(p\).

  • dist_params_q (pytree with ndarray leaves) – The distribution parameters of the auxiliary distribution \(q\).

property log_proba

JIT-compiled function that evaluates log-probabilities.

Parameters:
  • dist_params (pytree with ndarray leaves) – A batch of distribution parameters.

  • X (ndarray) – A batch of variates, e.g. a batch of actions \(a\) collected from experience.

Returns:

logP (ndarray of floats) – A batch of log-probabilities associated with the provided variates.

property mean

JIT-compiled function that generates differentiable means of the distribution.

Parameters:

dist_params (pytree with ndarray leaves) – A batch of distribution parameters.

Returns:

X (ndarray) – A batch of differentiable variates.

property mode

JIT-compiled function that generates differentiable modes of the distribution.

Parameters:

dist_params (pytree with ndarray leaves) – A batch of distribution parameters.

Returns:

X (ndarray) – A batch of differentiable variates.

property sample

JIT-compiled function that generates differentiable variates.

Parameters:
  • dist_params (pytree with ndarray leaves) – A batch of distribution parameters.

  • rng (PRNGKey) – A key for seeding the pseudo-random number generator.

Returns:

X (ndarray) – A batch of differentiable variates.

property space

The gymnasium-style space that specifies the domain of the distribution.

class coax.proba_dists.CategoricalDist(space, gumbel_softmax_tau=0.2)

A differentiable categorical distribution.

The input dist_params to each of the functions is expected to be of the form:

dist_params = {'logits': array([...])}

which represent the (conditional) distribution parameters. The logits, denoted \(z\in\mathbb{R}^n\), are related to the categorical distribution parameters \(p\in\Delta^n\) via a softmax:

\[p_k\ =\ \text{softmax}_k(z)\ =\ \frac{\text{e}^{z_k}}{\sum_j\text{e}^{z_j}}\]
Parameters:
  • space (gymnasium.spaces.Discrete) – The gymnasium-style space that specifies the domain of the distribution.

  • gumbel_softmax_tau (positive float, optional) – The parameter \(\tau\) specifies the sharpness of the Gumbel-softmax sampling (see the sample() method below). A good value for \(\tau\) balances the trade-off between nearly discrete variates (almost-one-hot vectors, obtained for small \(\tau\)) and smooth differentiable variates (obtained for larger \(\tau\)).
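To make the logits-to-probabilities relation concrete, here is a plain-Python sketch for a single, unbatched logits vector. It is illustrative only and not coax's JAX implementation; the helper name softmax and the example logits are hypothetical. Subtracting the largest logit before exponentiating leaves \(p\) unchanged but avoids overflow.

```python
import math


def softmax(z):
    """Map logits z to probabilities p_k = exp(z_k) / sum_j exp(z_j)."""
    z_max = max(z)  # shifting by the max leaves p unchanged but prevents overflow in exp
    exps = [math.exp(zk - z_max) for zk in z]
    total = sum(exps)
    return [e / total for e in exps]


dist_params = {'logits': [1.0, 2.0, 0.5, -1.0]}  # one (unbatched) set of logits
p = softmax(dist_params['logits'])
```

Because softmax is invariant to adding a constant to every logit, the logits are only determined up to such a shift.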

postprocess_variate(rng, X, index=0, batch_mode=False)

The post-processor specific to variates drawn from this distribution.

This method provides the interface between differentiable, batched variates (i.e. the outputs of sample() and mode()) and the provided gymnasium space.

Parameters:
  • rng (PRNGKey) – A key for seeding the pseudo-random number generator.

  • X (raw variates) – A batch of raw variates, i.e. in the same format as the outputs of sample() and mode().

  • index (int, optional) – The index to pick out from the batch. Note that this only applies if batch_mode=False.

  • batch_mode (bool, optional) – Whether to return a batch or a single instance.

Returns:

x or X (clean variate) – A single clean variate or a batch thereof (if batch_mode=True). A variate is called clean if it is an instance of the gymnasium-style space, i.e. it satisfies x in self.space.

preprocess_variate(rng, X)

The pre-processor that ensures an instance of the space is processed into the same structure as variates drawn from this distribution, i.e. the outputs of sample() and mode().

Parameters:
  • rng (PRNGKey) – A key for seeding the pseudo-random number generator.

  • X (clean variates) – A batch of clean variates, i.e. instances of the gymnasium-style space.

Returns:

X (raw variates) – A batch of raw variates, i.e. in the same format as the outputs of sample() and mode().

property affine_transform

Transform the distribution \(\mathcal{D}\to\mathcal{D}'\) in such a way that its associated variables \(X\sim\mathcal{D}\) and \(X'\sim\mathcal{D}'\) are related via an affine transformation:

\[X' = X\times\text{scale} + \text{shift}\]
Parameters:
  • dist_params (pytree with ndarray leaves) – The distribution parameters of the original distribution \(\mathcal{D}\).

  • scale (float or ndarray) – The multiplicative factor of the affine transformation.

  • shift (float or ndarray) – The additive shift of the affine transformation.

  • value_transform (ValueTransform, optional) –

    The transform to apply to the values before the affine transform, i.e.

    \[X' = f\bigl(f^{-1}(X)\times\text{scale} + \text{shift}\bigr)\]

Returns:

dist_params (pytree with ndarray leaves) – The distribution parameters of the transformed distribution \(\mathcal{D}'\).

property cross_entropy

JIT-compiled function that computes the cross-entropy of a categorical distribution \(q\) relative to another categorical distribution \(p\):

\[\text{CE}[p,q]\ =\ -\sum_k p_k \log q_k\]
Parameters:
  • dist_params_p (pytree with ndarray leaves) – The distribution parameters of the base distribution \(p\).

  • dist_params_q (pytree with ndarray leaves) – The distribution parameters of the auxiliary distribution \(q\).
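The categorical cross-entropy can be sketched from logits with a numerically stable log-softmax. This is an unbatched, plain-Python illustration under the assumption that both distributions are given by logits of equal length; the helper names are hypothetical and this is not coax's code.

```python
import math


def log_softmax(z):
    """Numerically stable log p_k = z_k - log(sum_j exp(z_j))."""
    z_max = max(z)
    log_norm = z_max + math.log(sum(math.exp(zk - z_max) for zk in z))
    return [zk - log_norm for zk in z]


def categorical_cross_entropy(logits_p, logits_q):
    """CE[p, q] = -sum_k p_k log q_k for one (unbatched) pair of logit vectors."""
    log_p = log_softmax(logits_p)
    log_q = log_softmax(logits_q)
    return -sum(math.exp(lp) * lq for lp, lq in zip(log_p, log_q))


# Relative to a uniform q over 3 classes, CE[p, q] = log(3) for any p.
ce_uniform = categorical_cross_entropy([2.0, 0.0, -1.0], [0.0, 0.0, 0.0])
```

By Gibbs' inequality, \(\text{CE}[p,q]\) is smallest when \(q=p\), in which case it equals the entropy of \(p\).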

property default_priors

The default distribution parameters.

property dist_params_structure

The tree structure of the distribution parameters.

property entropy

JIT-compiled function that computes the entropy of the distribution.

\[H\ =\ -\sum_k p_k \log p_k\]
Parameters:

dist_params (pytree with ndarray leaves) – A batch of distribution parameters.

Returns:

H (ndarray of floats) – A batch of entropy values.
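A minimal sketch of the entropy formula for one logits vector, in plain Python rather than JAX (the helper name is hypothetical). The two example inputs illustrate the extremes: uniform logits give the maximal value \(\log n\), and one dominant logit gives an entropy close to zero.

```python
import math


def categorical_entropy(logits):
    """H = -sum_k p_k log p_k, computed from logits via a stable log-softmax."""
    z_max = max(logits)
    log_norm = z_max + math.log(sum(math.exp(z - z_max) for z in logits))
    log_p = [z - log_norm for z in logits]
    return -sum(math.exp(lp) * lp for lp in log_p)


h_uniform = categorical_entropy([0.0, 0.0, 0.0, 0.0])  # maximal entropy: log(4)
h_peaked = categorical_entropy([50.0, 0.0, 0.0, 0.0])  # nearly deterministic: close to 0
```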

property hyperparams

The distribution hyperparameters.

property kl_divergence

JIT-compiled function that computes the Kullback-Leibler divergence of a categorical distribution \(q\) relative to another categorical distribution \(p\):

\[\text{KL}[p,q]\ =\ -\sum_k p_k \left(\log q_k -\log p_k\right)\]
Parameters:
  • dist_params_p (pytree with ndarray leaves) – The distribution parameters of the base distribution \(p\).

  • dist_params_q (pytree with ndarray leaves) – The distribution parameters of the auxiliary distribution \(q\).
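The KL divergence formula can be sketched the same way (plain Python, unbatched, hypothetical helper names). The example also shows two properties worth keeping in mind: \(\text{KL}[p,p]=0\), and the divergence is not symmetric in \(p\) and \(q\).

```python
import math


def log_softmax(z):
    """Numerically stable log-softmax."""
    z_max = max(z)
    log_norm = z_max + math.log(sum(math.exp(zk - z_max) for zk in z))
    return [zk - log_norm for zk in z]


def categorical_kl(logits_p, logits_q):
    """KL[p, q] = -sum_k p_k (log q_k - log p_k)."""
    log_p = log_softmax(logits_p)
    log_q = log_softmax(logits_q)
    return -sum(math.exp(lp) * (lq - lp) for lp, lq in zip(log_p, log_q))


kl_pq = categorical_kl([1.0, 0.0, -2.0], [0.0, 0.5, 0.5])
kl_qp = categorical_kl([0.0, 0.5, 0.5], [1.0, 0.0, -2.0])  # arguments swapped
kl_pp = categorical_kl([1.0, 0.0, -2.0], [1.0, 0.0, -2.0])  # identical distributions
```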

property log_proba

JIT-compiled function that evaluates log-probabilities.

Parameters:
  • dist_params (pytree with ndarray leaves) – A batch of distribution parameters.

  • X (ndarray) – A batch of variates, e.g. a batch of actions \(a\) collected from experience.

Returns:

logP (ndarray of floats) – A batch of log-probabilities associated with the provided variates.

property mean

JIT-compiled function that generates differentiable means of the distribution. Strictly speaking, the mean of a categorical variable is not well defined, so we opt for returning the raw probabilities: \(\text{mean}_k=p_k\).

Parameters:

dist_params (pytree with ndarray leaves) – A batch of distribution parameters.

Returns:

X (ndarray) – A batch of would-be variates \(x\sim\text{Cat}(p)\). In contrast to the output of other methods, these aren’t true variates because they are not almost-one-hot encoded.

property mode

JIT-compiled function that generates differentiable modes of the distribution, for which we use a trick similar to Gumbel-softmax sampling:

\[\text{mode}_k\ =\ \text{softmax}_k\left( \frac{\log p_k}{\tau} \right)\]
Parameters:

dist_params (pytree with ndarray leaves) – A batch of distribution parameters.

Returns:

X (ndarray) – A batch of differentiable modes of \(\text{Cat}(p)\). In order to ensure differentiability, each mode is not an integer, but instead an almost-one-hot encoded version thereof.

For example, instead of returning the mode \(x=2\) of a 4-class categorical distribution, this method returns a vector like \(x=[0.05, 0.02, 0.86, 0.07]\). The latter representation can be viewed as an almost-one-hot encoded version of the former.
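The soft mode \(\text{softmax}_k(\log p_k/\tau)\) can be sketched in plain Python for a single probability vector (hypothetical helper, not coax's code). With \(\tau=1\) it returns \(p\) unchanged; as \(\tau\) shrinks, the mass concentrates on the most probable class. The sketch assumes every \(p_k>0\) so that the logarithm is defined.

```python
import math


def soft_mode(p, tau):
    """Differentiable 'mode': softmax_k(log p_k / tau); requires every p_k > 0."""
    scaled = [math.log(pk) / tau for pk in p]
    m = max(scaled)  # stabilise the softmax
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


p = [0.1, 0.2, 0.6, 0.1]
mode_tau1 = soft_mode(p, tau=1.0)   # tau = 1 recovers p itself
mode_tau02 = soft_mode(p, tau=0.2)  # smaller tau: most of the mass moves to index 2
```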

property sample

JIT-compiled function that generates differentiable variates using Gumbel-softmax sampling. \(x\sim\text{Cat}(p)\) is implemented as

\[\begin{split}u_k\ &\sim\ \text{Unif}(0, 1) \\ g_k\ &=\ -\log(-\log(u_k)) \\ x_k\ &=\ \text{softmax}_k\left( \frac{g_k + \log p_k}{\tau} \right)\end{split}\]
Parameters:
  • dist_params (pytree with ndarray leaves) – A batch of distribution parameters.

  • rng (PRNGKey) – A key for seeding the pseudo-random number generator.

Returns:

X (ndarray) – A batch of variates \(x\sim\text{Cat}(p)\). In order to ensure differentiability of the variates this is not an integer, but instead an almost-one-hot encoded version thereof.

For example, instead of sampling \(x=2\) from a 4-class categorical distribution, Gumbel-softmax will return a vector like \(x=[0.05, 0.02, 0.86, 0.07]\). The latter representation can be viewed as an almost-one-hot encoded version of the former.
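The three sampling steps above can be sketched in plain Python. Here Python's random.Random stands in for a JAX PRNGKey, and the helper name is hypothetical; this illustrates the Gumbel-softmax technique, not coax's implementation. The usage loop checks that the argmax of the soft sample is distributed as \(\text{Cat}(p)\).

```python
import math
import random


def gumbel_softmax_sample(p, tau, rng):
    """One draw x_k = softmax_k((g_k + log p_k) / tau) with g_k = -log(-log u_k)."""
    scores = []
    for pk in p:
        u = rng.random()
        while u == 0.0:  # random() lies in [0, 1); log(0) would fail
            u = rng.random()
        g = -math.log(-math.log(u))  # standard Gumbel noise
        scores.append((g + math.log(pk)) / tau)
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


rng = random.Random(0)  # stands in for a JAX PRNGKey in this sketch
p = [0.1, 0.2, 0.6, 0.1]
counts = [0] * len(p)
for _ in range(10_000):
    x = gumbel_softmax_sample(p, tau=0.2, rng=rng)
    counts[max(range(len(x)), key=x.__getitem__)] += 1  # argmax of the soft sample
```

Dividing by \(\tau>0\) does not change which entry is largest, so the argmax follows the Gumbel-max trick and is distributed exactly as \(\text{Cat}(p)\); \(\tau\) only controls how close each soft sample is to one-hot.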

property space

The gymnasium-style space that specifies the domain of the distribution.

class coax.proba_dists.NormalDist(space, clip_box=(-256.0, 256.0), clip_reals=(-30.0, 30.0), clip_logvar=(-20.0, 20.0))

A differentiable normal distribution.

The input dist_params to each of the functions is expected to be of the form:

dist_params = {'mu': array([...]), 'logvar': array([...])}

which represent the (conditional) distribution parameters. Here, mu is the mean \(\mu\) and logvar is the log-variance \(\log(\sigma^2)\).

Parameters:
  • space (gymnasium.spaces.Box) – The gymnasium-style space that specifies the domain of the distribution.

  • clip_box (pair of floats, optional) – The range of values to allow for clean (compact) variates. This is mainly to ensure reasonable values when one or more dimensions of the Box space have very large ranges, while in reality only a small part of that range is occupied.

  • clip_reals (pair of floats, optional) – The range of values to allow for the raw (decompactified) real-valued variates used internally. This range is set for numerical stability: the postprocess_variate method compactifies the reals to a closed interval (the Box) by applying a logistic sigmoid, and a finite clip_reals range ensures that the sigmoid doesn’t fully saturate.

  • clip_logvar (pair of floats, optional) – The range of values to allow for the log-variance of the distribution.
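The role of clip_box and clip_reals can be illustrated with a one-dimensional sketch. The exact mapping is an assumption: this page only says that a logistic sigmoid compactifies the reals to the Box, so the rescaling low + (high - low) * sigmoid(x) and the clipping order below are hypothetical, as are the helper names compactify and decompactify.

```python
import math


def compactify(x_raw, low, high, clip_reals=(-30.0, 30.0), clip_box=(-256.0, 256.0)):
    """Hypothetical raw-real -> [low, high] map via a logistic sigmoid."""
    low, high = max(low, clip_box[0]), min(high, clip_box[1])  # tame huge Box bounds (assumed)
    x = min(max(x_raw, clip_reals[0]), clip_reals[1])  # keep the sigmoid away from saturation
    s = 1.0 / (1.0 + math.exp(-x))
    return low + (high - low) * s


def decompactify(x_clean, low, high, clip_reals=(-30.0, 30.0), clip_box=(-256.0, 256.0)):
    """Inverse map (logit of the rescaled value); valid for x_clean strictly inside the bounds."""
    low, high = max(low, clip_box[0]), min(high, clip_box[1])
    s = (x_clean - low) / (high - low)
    x = math.log(s) - math.log(1.0 - s)
    return min(max(x, clip_reals[0]), clip_reals[1])


x_clean = compactify(1.3, low=-2.0, high=5.0)
x_raw = decompactify(x_clean, low=-2.0, high=5.0)  # recovers 1.3 up to rounding
```

Without the clip_reals bound, a very large raw value would push the sigmoid to exactly 1.0 in floating point, and the logit in the inverse map would diverge.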

postprocess_variate(rng, X, index=0, batch_mode=False)

The post-processor specific to variates drawn from this distribution.

This method provides the interface between differentiable, batched variates (i.e. the outputs of sample() and mode()) and the provided gymnasium space.

Parameters:
  • rng (PRNGKey) – A key for seeding the pseudo-random number generator.

  • X (raw variates) – A batch of raw variates, i.e. in the same format as the outputs of sample() and mode().

  • index (int, optional) – The index to pick out from the batch. Note that this only applies if batch_mode=False.

  • batch_mode (bool, optional) – Whether to return a batch or a single instance.

Returns:

x or X (clean variate) – A single clean variate or a batch thereof (if batch_mode=True). A variate is called clean if it is an instance of the gymnasium-style space, i.e. it satisfies x in self.space.

preprocess_variate(rng, X)

The pre-processor that ensures an instance of the space is processed into the same structure as variates drawn from this distribution, i.e. the outputs of sample() and mode().

Parameters:
  • rng (PRNGKey) – A key for seeding the pseudo-random number generator.

  • X (clean variates) – A batch of clean variates, i.e. instances of the gymnasium-style space.

Returns:

X (raw variates) – A batch of raw variates, i.e. in the same format as the outputs of sample() and mode().

property affine_transform

Transform the distribution \(\mathcal{D}\to\mathcal{D}'\) in such a way that its associated variables \(X\sim\mathcal{D}\) and \(X'\sim\mathcal{D}'\) are related via an affine transformation:

\[X' = X\times\text{scale} + \text{shift}\]
Parameters:
  • dist_params (pytree with ndarray leaves) – The distribution parameters of the original distribution \(\mathcal{D}\).

  • scale (float or ndarray) – The multiplicative factor of the affine transformation.

  • shift (float or ndarray) – The additive shift of the affine transformation.

  • value_transform (ValueTransform, optional) –

    The transform to apply to the values before the affine transform, i.e.

    \[X' = f\bigl(f^{-1}(X)\times\text{scale} + \text{shift}\bigr)\]

Returns:

dist_params (pytree with ndarray leaves) – The distribution parameters of the transformed distribution \(\mathcal{D}'\).
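For a normal distribution without a value_transform, the affine transformation has a closed form: if \(X\sim\mathcal{N}(\mu,\sigma^2)\), then \(X\times\text{scale}+\text{shift}\sim\mathcal{N}(\mu\times\text{scale}+\text{shift},\ \text{scale}^2\sigma^2)\). The sketch below applies this identity to the mu/logvar parameterization. The helper name is hypothetical, it assumes a nonzero scale, and it is not necessarily how coax implements affine_transform.

```python
import math


def affine_transform_normal(dist_params, scale, shift):
    """Hypothetical helper: X' = X * scale + shift for X ~ N(mu, sigma^2), scale != 0."""
    return {
        'mu': dist_params['mu'] * scale + shift,  # E[X'] = mu * scale + shift
        'logvar': dist_params['logvar'] + 2.0 * math.log(abs(scale)),  # Var[X'] = scale^2 sigma^2
    }


params = {'mu': 1.0, 'logvar': math.log(4.0)}  # sigma = 2
new_params = affine_transform_normal(params, scale=3.0, shift=-1.0)
# new_params describes N(2, 36)
```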

property cross_entropy

JIT-compiled function that computes the cross-entropy of a normal distribution \(q\) relative to another normal distribution \(p\):

\[\text{CE}[p,q]\ =\ -\mathbb{E}_p \log q \ =\ \frac12\left( \log(2\pi\sigma_q^2) + \frac{(\mu_p-\mu_q)^2+\sigma_p^2}{\sigma_q^2} \right)\]
Parameters:
  • dist_params_p (pytree with ndarray leaves) – The distribution parameters of the base distribution \(p\).

  • dist_params_q (pytree with ndarray leaves) – The distribution parameters of the auxiliary distribution \(q\).
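One way to sanity-check the closed form above is to compare it with a Monte Carlo estimate of \(-\mathbb{E}_p\log q\). The sketch below does this for univariate normals in plain Python; the helper names and parameter values are hypothetical.

```python
import math
import random


def normal_cross_entropy(mu_p, logvar_p, mu_q, logvar_q):
    """Closed-form CE[p, q] = -E_p log q for univariate normals p and q."""
    var_p, var_q = math.exp(logvar_p), math.exp(logvar_q)
    return 0.5 * (math.log(2 * math.pi * var_q) + ((mu_p - mu_q) ** 2 + var_p) / var_q)


def normal_log_density(x, mu, logvar):
    """log N(x; mu, exp(logvar))."""
    return -0.5 * (math.log(2 * math.pi) + logvar + (x - mu) ** 2 / math.exp(logvar))


mu_p, logvar_p = 0.5, math.log(0.5)
mu_q, logvar_q = -1.0, math.log(2.0)
exact = normal_cross_entropy(mu_p, logvar_p, mu_q, logvar_q)

# Monte Carlo estimate of -E_p log q: average -log q(x) over samples x ~ p.
rng = random.Random(42)
sigma_p = math.exp(0.5 * logvar_p)
n = 100_000
mc = -sum(normal_log_density(rng.gauss(mu_p, sigma_p), mu_q, logvar_q) for _ in range(n)) / n
```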

property default_priors

The default distribution parameters.

property dist_params_structure

The tree structure of the distribution parameters.

property entropy

JIT-compiled function that computes the entropy of the distribution.

\[H\ =\ -\mathbb{E}_p \log p \ =\ \frac12\left( \log(2\pi\sigma^2) + 1\right)\]
Parameters:

dist_params (pytree with ndarray leaves) – A batch of distribution parameters.

Returns:

H (ndarray of floats) – A batch of entropy values.
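The closed-form entropy can be checked against a direct numerical evaluation of \(-\int p\log p\,dx\). The sketch uses a midpoint rule over \(\mu\pm 12\sigma\) (an arbitrary but ample truncation); helper names are hypothetical. Note that \(H\) depends only on \(\sigma\), not on \(\mu\).

```python
import math


def normal_entropy(logvar):
    """H = 0.5 * (log(2 pi sigma^2) + 1); independent of mu."""
    return 0.5 * (math.log(2 * math.pi) + logvar + 1.0)


def numerical_entropy(mu, sigma, n=20_000, width=12.0):
    """Midpoint-rule estimate of -integral p(x) log p(x) dx over mu +/- width * sigma."""
    a = mu - width * sigma
    dx = 2.0 * width * sigma / n
    h = 0.0
    for i in range(n):
        x = a + (i + 0.5) * dx
        log_p = -0.5 * math.log(2 * math.pi * sigma ** 2) - (x - mu) ** 2 / (2 * sigma ** 2)
        h -= math.exp(log_p) * log_p * dx
    return h


h_closed = normal_entropy(math.log(1.5 ** 2))
h_numeric = numerical_entropy(mu=0.3, sigma=1.5)
```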

property hyperparams

The distribution hyperparameters.

property kl_divergence

JIT-compiled function that computes the Kullback-Leibler divergence of a normal distribution \(q\) relative to another normal distribution \(p\):

\[\text{KL}[p,q]\ = -\mathbb{E}_p \left(\log q -\log p\right) \ =\ \frac12\left( \log(\sigma_q^2) - \log(\sigma_p^2) + \frac{(\mu_p-\mu_q)^2+\sigma_p^2}{\sigma_q^2} - 1 \right)\]
Parameters:
  • dist_params_p (pytree with ndarray leaves) – The distribution parameters of the base distribution \(p\).

  • dist_params_q (pytree with ndarray leaves) – The distribution parameters of the auxiliary distribution \(q\).
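The three closed forms for normal distributions are consistent with the identity \(\text{KL}[p,q]=\text{CE}[p,q]-H[p]\). The plain-Python sketch below (hypothetical helper names, univariate case) checks this numerically, together with \(\text{KL}[p,p]=0\).

```python
import math


def normal_kl(mu_p, logvar_p, mu_q, logvar_q):
    """Closed-form KL[p, q] for univariate normals."""
    var_p, var_q = math.exp(logvar_p), math.exp(logvar_q)
    return 0.5 * (logvar_q - logvar_p + ((mu_p - mu_q) ** 2 + var_p) / var_q - 1.0)


def normal_cross_entropy(mu_p, logvar_p, mu_q, logvar_q):
    var_p, var_q = math.exp(logvar_p), math.exp(logvar_q)
    return 0.5 * (math.log(2 * math.pi) + logvar_q + ((mu_p - mu_q) ** 2 + var_p) / var_q)


def normal_entropy(logvar):
    return 0.5 * (math.log(2 * math.pi) + logvar + 1.0)


args = (0.5, math.log(0.5), -1.0, math.log(2.0))  # (mu_p, logvar_p, mu_q, logvar_q)
kl = normal_kl(*args)
# KL[p, q] = CE[p, q] - H[p]
gap = normal_cross_entropy(*args) - normal_entropy(args[1])
```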

property log_proba

JIT-compiled function that evaluates log-probabilities.

Parameters:
  • dist_params (pytree with ndarray leaves) – A batch of distribution parameters.

  • X (ndarray) – A batch of variates, e.g. a batch of actions \(a\) collected from experience.

Returns:

logP (ndarray of floats) – A batch of log-probabilities associated with the provided variates.
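As a sketch of what a normal log-probability looks like for a single multi-dimensional variate, the example below assumes independent dimensions (a diagonal covariance), so the per-dimension log-densities add up. Whether coax reduces over dimensions in exactly this way is an assumption here, and the helper name is hypothetical.

```python
import math


def normal_log_proba(x, mu, logvar):
    """log N(x; mu, diag(sigma^2)): sum of independent per-dimension log-densities."""
    return sum(
        -0.5 * (math.log(2 * math.pi) + lv + (xi - mi) ** 2 / math.exp(lv))
        for xi, mi, lv in zip(x, mu, logvar)
    )


# Standard normal in 1D, evaluated at its mean: -0.5 * log(2 pi).
lp_1d = normal_log_proba([0.0], [0.0], [0.0])
# Two independent dimensions: the log-densities add up.
lp_2d = normal_log_proba([0.0, 1.0], [0.0, 0.0], [0.0, 0.0])
```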

property mean

JIT-compiled function that generates differentiable means of the distribution, in this case simply \(\mu\).

Parameters:

dist_params (pytree with ndarray leaves) – A batch of distribution parameters.

Returns:

X (ndarray) – A batch of differentiable variates.

property mode

JIT-compiled function that generates differentiable modes of the distribution, which for a normal distribution coincide with the mean.

Parameters:

dist_params (pytree with ndarray leaves) – A batch of distribution parameters.

Returns:

X (ndarray) – A batch of differentiable variates.

property sample

JIT-compiled function that generates differentiable variates using the reparametrization trick, i.e. \(x\sim\mathcal{N}(\mu,\sigma^2)\) is implemented as

\[\begin{split}\varepsilon\ &\sim\ \mathcal{N}(0,1) \\ x\ &=\ \mu + \sigma\,\varepsilon\end{split}\]
Parameters:
  • dist_params (pytree with ndarray leaves) – A batch of distribution parameters.

  • rng (PRNGKey) – A key for seeding the pseudo-random number generator.

Returns:

X (ndarray) – A batch of differentiable variates.
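The reparametrization trick can be illustrated in plain Python, with random.Random standing in for a JAX PRNGKey and a hypothetical helper name. Because the noise \(\varepsilon\) is drawn independently of \(\mu\) and \(\log\sigma^2\), the sample is a deterministic, differentiable function of those parameters once \(\varepsilon\) is fixed.

```python
import math
import random


def reparam_sample(mu, logvar, rng):
    """x = mu + sigma * eps with eps ~ N(0, 1) and sigma = exp(logvar / 2)."""
    eps = rng.gauss(0.0, 1.0)  # the noise does not depend on (mu, logvar)
    return mu + math.exp(0.5 * logvar) * eps


rng = random.Random(0)
mu, logvar = 2.0, math.log(0.25)  # sigma = 0.5
xs = [reparam_sample(mu, logvar, rng) for _ in range(50_000)]
mean = sum(xs) / len(xs)
var = sum((x - mean) ** 2 for x in xs) / (len(xs) - 1)

# For a fixed eps: dx/dmu = 1 and dx/dlogvar = 0.5 * sigma * eps,
# so gradients with respect to the distribution parameters flow through the sample.
```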

property space

The gymnasium-style space that specifies the domain of the distribution.

class coax.proba_dists.DiscretizedIntervalDist(space, num_bins=20, gumbel_softmax_tau=0.2)

A categorical distribution over a discretized interval.

The input dist_params to each of the functions is expected to be of the form:

dist_params = {'logits': array([...])}

which represent the (conditional) distribution parameters. The logits, denoted \(z\in\mathbb{R}^n\), are related to the categorical distribution parameters \(p\in\Delta^n\) via a softmax:

\[p_k\ =\ \text{softmax}_k(z)\ =\ \frac{\text{e}^{z_k}}{\sum_j\text{e}^{z_j}}\]
Parameters:
  • space (gymnasium.spaces.Box) – The gymnasium-style space that specifies the domain of the distribution. The shape of the Box must have prod(shape) == 1, i.e. a single interval.

  • num_bins (int, optional) – The number of equal-sized bins used in the discretization.

  • gumbel_softmax_tau (positive float, optional) – The parameter \(\tau\) specifies the sharpness of the Gumbel-softmax sampling (see the sample() method below). A good value for \(\tau\) balances the trade-off between nearly discrete variates (almost-one-hot vectors, obtained for small \(\tau\)) and smooth differentiable variates (obtained for larger \(\tau\)).
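To picture the discretization, the sketch below splits an interval [low, high] into num_bins equal-sized bins and represents each bin by its center. Representing bins by their centers, and the helper names bin_centers and expected_value, are assumptions for illustration only; this page does not specify how coax maps bins back to values.

```python
def bin_centers(low, high, num_bins=20):
    """Centers of num_bins equal-sized bins covering [low, high] (hypothetical helper)."""
    width = (high - low) / num_bins
    return [low + (k + 0.5) * width for k in range(num_bins)]


def expected_value(p, low, high):
    """Mean of the discretized variable if each bin is represented by its center (assumption)."""
    centers = bin_centers(low, high, len(p))
    return sum(pk * ck for pk, ck in zip(p, centers))


centers = bin_centers(-1.0, 1.0, num_bins=4)  # bins of width 0.5
ev_uniform = expected_value([0.25, 0.25, 0.25, 0.25], -1.0, 1.0)
```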

postprocess_variate(rng, X, index=0, batch_mode=False)

The post-processor specific to variates drawn from this distribution.

This method provides the interface between differentiable, batched variates (i.e. the outputs of sample() and mode()) and the provided gymnasium space.

Parameters:
  • rng (PRNGKey) – A key for seeding the pseudo-random number generator.

  • X (raw variates) – A batch of raw variates, i.e. in the same format as the outputs of sample() and mode().

  • index (int, optional) – The index to pick out from the batch. Note that this only applies if batch_mode=False.

  • batch_mode (bool, optional) – Whether to return a batch or a single instance.

Returns:

x or X (clean variate) – A single clean variate or a batch thereof (if batch_mode=True). A variate is called clean if it is an instance of the gymnasium-style space, i.e. it satisfies x in self.space.

preprocess_variate(rng, X)

The pre-processor that ensures an instance of the space is processed into the same structure as variates drawn from this distribution, i.e. the outputs of sample() and mode().

Parameters:
  • rng (PRNGKey) – A key for seeding the pseudo-random number generator.

  • X (clean variates) – A batch of clean variates, i.e. instances of the gymnasium-style space.

Returns:

X (raw variates) – A batch of raw variates, i.e. in the same format as the outputs of sample() and mode().

property affine_transform

Transform the distribution \(\mathcal{D}\to\mathcal{D}'\) in such a way that its associated variables \(X\sim\mathcal{D}\) and \(X'\sim\mathcal{D}'\) are related via an affine transformation:

\[X' = X\times\text{scale} + \text{shift}\]
Parameters:
  • dist_params (pytree with ndarray leaves) – The distribution parameters of the original distribution \(\mathcal{D}\).

  • scale (float or ndarray) – The multiplicative factor of the affine transformation.

  • shift (float or ndarray) – The additive shift of the affine transformation.

  • value_transform (ValueTransform, optional) –

    The transform to apply to the values before the affine transform, i.e.

    \[X' = f\bigl(f^{-1}(X)\times\text{scale} + \text{shift}\bigr)\]

Returns:

dist_params (pytree with ndarray leaves) – The distribution parameters of the transformed distribution \(\mathcal{D}'\).

property cross_entropy

JIT-compiled function that computes the cross-entropy of a categorical distribution \(q\) relative to another categorical distribution \(p\):

\[\text{CE}[p,q]\ =\ -\sum_k p_k \log q_k\]
Parameters:
  • dist_params_p (pytree with ndarray leaves) – The distribution parameters of the base distribution \(p\).

  • dist_params_q (pytree with ndarray leaves) – The distribution parameters of the auxiliary distribution \(q\).

property default_priors

The default distribution parameters.

property dist_params_structure

The tree structure of the distribution parameters.

property entropy

JIT-compiled function that computes the entropy of the distribution.

\[H\ =\ -\sum_k p_k \log p_k\]
Parameters:

dist_params (pytree with ndarray leaves) – A batch of distribution parameters.

Returns:

H (ndarray of floats) – A batch of entropy values.

property hyperparams

The distribution hyperparameters.

property kl_divergence

JIT-compiled function that computes the Kullback-Leibler divergence of a categorical distribution \(q\) relative to another categorical distribution \(p\):

\[\text{KL}[p,q]\ =\ -\sum_k p_k \left(\log q_k -\log p_k\right)\]
Parameters:
  • dist_params_p (pytree with ndarray leaves) – The distribution parameters of the base distribution \(p\).

  • dist_params_q (pytree with ndarray leaves) – The distribution parameters of the auxiliary distribution \(q\).

property log_proba

JIT-compiled function that evaluates log-probabilities.

Parameters:
  • dist_params (pytree with ndarray leaves) – A batch of distribution parameters.

  • X (ndarray) – A batch of variates, e.g. a batch of actions \(a\) collected from experience.

Returns:

logP (ndarray of floats) – A batch of log-probabilities associated with the provided variates.

property mean

JIT-compiled function that generates differentiable means of the distribution. Strictly speaking, the mean of a categorical variable is not well defined, so we opt for returning the raw probabilities: \(\text{mean}_k=p_k\).

Parameters:

dist_params (pytree with ndarray leaves) – A batch of distribution parameters.

Returns:

X (ndarray) – A batch of would-be variates \(x\sim\text{Cat}(p)\). In contrast to the output of other methods, these aren’t true variates because they are not almost-one-hot encoded.

property mode

JIT-compiled function that generates differentiable modes of the distribution, for which we use a trick similar to Gumbel-softmax sampling:

\[\text{mode}_k\ =\ \text{softmax}_k\left( \frac{\log p_k}{\tau} \right)\]
Parameters:

dist_params (pytree with ndarray leaves) – A batch of distribution parameters.

Returns:

X (ndarray) – A batch of differentiable modes of \(\text{Cat}(p)\). In order to ensure differentiability, each mode is not an integer, but instead an almost-one-hot encoded version thereof.

For example, instead of returning the mode \(x=2\) of a 4-class categorical distribution, this method returns a vector like \(x=[0.05, 0.02, 0.86, 0.07]\). The latter representation can be viewed as an almost-one-hot encoded version of the former.

property sample

JIT-compiled function that generates differentiable variates using Gumbel-softmax sampling. \(x\sim\text{Cat}(p)\) is implemented as

\[\begin{split}u_k\ &\sim\ \text{Unif}(0, 1) \\ g_k\ &=\ -\log(-\log(u_k)) \\ x_k\ &=\ \text{softmax}_k\left( \frac{g_k + \log p_k}{\tau} \right)\end{split}\]
Parameters:
  • dist_params (pytree with ndarray leaves) – A batch of distribution parameters.

  • rng (PRNGKey) – A key for seeding the pseudo-random number generator.

Returns:

X (ndarray) – A batch of variates \(x\sim\text{Cat}(p)\). In order to ensure differentiability of the variates this is not an integer, but instead an almost-one-hot encoded version thereof.

For example, instead of sampling \(x=2\) from a 4-class categorical distribution, Gumbel-softmax will return a vector like \(x=[0.05, 0.02, 0.86, 0.07]\). The latter representation can be viewed as an almost-one-hot encoded version of the former.

property space

The gymnasium-style space that specifies the domain of the distribution.

class coax.proba_dists.EmpiricalQuantileDist(num_quantiles)

postprocess_variate(rng, X, index=0, batch_mode=False)

The post-processor specific to variates drawn from this distribution.

This method provides the interface between differentiable, batched variates (i.e. the outputs of sample() and mode()) and the provided gymnasium space.

Parameters:
  • rng (PRNGKey) – A key for seeding the pseudo-random number generator.

  • X (raw variates) – A batch of raw variates, i.e. in the same format as the outputs of sample() and mode().

  • index (int, optional) – The index to pick out from the batch. Note that this only applies if batch_mode=False.

  • batch_mode (bool, optional) – Whether to return a batch or a single instance.

Returns:

x or X (clean variate) – A single clean variate or a batch thereof (if batch_mode=True). A variate is called clean if it is an instance of the gymnasium-style space, i.e. it satisfies x in self.space.

preprocess_variate(rng, X)

The pre-processor that ensures an instance of the space is processed into the same structure as variates drawn from this distribution, i.e. the outputs of sample() and mode().

Parameters:
  • rng (PRNGKey) – A key for seeding the pseudo-random number generator.

  • X (clean variates) – A batch of clean variates, i.e. instances of the gymnasium-style space.

Returns:

X (raw variates) – A batch of raw variates, i.e. in the same format as the outputs of sample() and mode().

property affine_transform

Transform the distribution \(\mathcal{D}\to\mathcal{D}'\) in such a way that its associated variables \(X\sim\mathcal{D}\) and \(X'\sim\mathcal{D}'\) are related via an affine transformation:

\[X' = X\times\text{scale} + \text{shift}\]
Parameters:
  • dist_params (pytree with ndarray leaves) – The distribution parameters of the original distribution \(\mathcal{D}\).

  • scale (float or ndarray) – The multiplicative factor of the affine transformation.

  • shift (float or ndarray) – The additive shift of the affine transformation.

  • value_transform (ValueTransform, optional) –

    The transform to apply to the values before the affine transform, i.e.

    \[X' = f\bigl(f^{-1}(X)\times\text{scale} + \text{shift}\bigr)\]

Returns:

dist_params (pytree with ndarray leaves) – The distribution parameters of the transformed distribution \(\mathcal{D}'\).

property cross_entropy

JIT-compiled function that computes the cross-entropy of a distribution \(q\) relative to another distribution \(p\):

\[\text{CE}[p,q]\ =\ -\mathbb{E}_p \log q\]
Parameters:
  • dist_params_p (pytree with ndarray leaves) – The distribution parameters of the base distribution \(p\).

  • dist_params_q (pytree with ndarray leaves) – The distribution parameters of the auxiliary distribution \(q\).

property default_priors

The default distribution parameters.

property dist_params_structure

The tree structure of the distribution parameters.

property entropy

JIT-compiled function that computes the entropy of the distribution.

\[H\ =\ -\mathbb{E}_p \log p\]
Parameters:

dist_params (pytree with ndarray leaves) – A batch of distribution parameters.

Returns:

H (ndarray of floats) – A batch of entropy values.

property hyperparams

The distribution hyperparameters.

property kl_divergence

JIT-compiled function that computes the Kullback-Leibler divergence of a distribution \(q\) relative to another distribution \(p\):

\[\text{KL}[p,q]\ = -\mathbb{E}_p \left(\log q -\log p\right)\]
Parameters:
  • dist_params_p (pytree with ndarray leaves) – The distribution parameters of the base distribution \(p\).

  • dist_params_q (pytree with ndarray leaves) – The distribution parameters of the auxiliary distribution \(q\).

property log_proba

JIT-compiled function that evaluates log-probabilities.

Parameters:
  • dist_params (pytree with ndarray leaves) – A batch of distribution parameters.

  • X (ndarray) – A batch of variates, e.g. a batch of actions \(a\) collected from experience.

Returns:

logP (ndarray of floats) – A batch of log-probabilities associated with the provided variates.

property mean

JIT-compiled function that generates differentiable means of the distribution.

Parameters:

dist_params (pytree with ndarray leaves) – A batch of distribution parameters.

Returns:

X (ndarray) – A batch of differentiable variates.

property mode

JIT-compiled function that generates differentiable modes of the distribution.

Parameters:

dist_params (pytree with ndarray leaves) – A batch of distribution parameters.

Returns:

X (ndarray) – A batch of differentiable variates.

property sample

JIT-compiled function that generates differentiable variates.

Parameters:
  • dist_params (pytree with ndarray leaves) – A batch of distribution parameters.

  • rng (PRNGKey) – A key for seeding the pseudo-random number generator.

Returns:

X (ndarray) – A batch of differentiable variates.

property space

The gymnasium-style space that specifies the domain of the distribution.

class coax.proba_dists.SquashedNormalDist(space, clip_logvar=None)[source]

A differentiable squashed normal distribution.

The input dist_params to each of the functions is expected to be of the form:

dist_params = {'mu': array([...]), 'logvar': array([...])}

which represent the (conditional) distribution parameters. Here, mu is the mean \(\mu\) and logvar is the log-variance \(\log(\sigma^2)\) of the underlying normal distribution, i.e. before the \(\tanh\) squashing is applied.

Parameters:
  • space (gymnasium.spaces.Box) – The gymnasium-style space that specifies the domain of the distribution.

  • clip_logvar (pair of floats, optional) – The range of values to allow for the log-variance of the distribution.
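The sketch below shows how a dist_params dict of this form maps to per-dimension standard deviations \(\sigma = \exp(\tfrac12\log\sigma^2)\). It assumes that clip_logvar acts as a simple lower/upper clamp on the log-variance. Both the helper name and that clamping behaviour are assumptions, and the range (-10, 2) is arbitrary.

```python
import math

def stddev(dist_params, clip_logvar=None):
    """Hypothetical helper: per-dimension sigma from {'mu': [...], 'logvar': [...]}."""
    logvar = dist_params['logvar']
    if clip_logvar is not None:
        lo, hi = clip_logvar
        # Assumed behaviour: clamp log(sigma^2) into [lo, hi] before exponentiating.
        logvar = [min(max(v, lo), hi) for v in logvar]
    # sigma = sqrt(exp(logvar)) = exp(logvar / 2)
    return [math.exp(v / 2) for v in logvar]

dist_params = {'mu': [0.0, 0.5], 'logvar': [-30.0, 0.0]}
print(stddev(dist_params))                           # first sigma is tiny, exp(-15)
print(stddev(dist_params, clip_logvar=(-10.0, 2.0)))  # first sigma floored at exp(-5)
```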

postprocess_variate(rng, X, index=0, batch_mode=False)[source]

The post-processor specific to variates drawn from this distribution.

This method provides the interface between differentiable, batched variates (i.e. outputs of sample() and mode()) and the provided gymnasium space.

Parameters:
  • rng (PRNGKey) – A key for seeding the pseudo-random number generator.

  • X (raw variates) – A batch of raw variates, i.e. in the same format as the outputs of sample() and mode().

  • index (int, optional) – The index to pick out from the batch. Note that this only applies if batch_mode=False.

  • batch_mode (bool, optional) – Whether to return a batch or a single instance.

Returns:

x or X (clean variate) – A single clean variate or a batch thereof (if batch_mode=True). A variate is called clean if it is an instance of the gymnasium-style space, i.e. it satisfies x in self.space.
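The index and batch_mode semantics can be sketched in plain Python. Because \(\tanh\) produces values in \([-1, 1]\), this sketch *assumes* that post-processing is an affine rescaling onto the bounds of the Box space. That mapping is an assumption and may differ from what coax actually does (for example, it may also clip). The function is hypothetical, and the rng argument is dropped because this sketch is deterministic.

```python
def postprocess_variate(X, low, high, index=0, batch_mode=False):
    """Hypothetical sketch: map raw variates in [-1, 1] onto the box [low, high]."""
    def rescale(x):
        # Assumed mapping: -1 -> low, +1 -> high, linear in between (per dimension).
        return [lo + (xi + 1.0) * (hi - lo) / 2.0 for xi, lo, hi in zip(x, low, high)]
    if batch_mode:
        return [rescale(x) for x in X]
    # Without batch_mode, pick out a single instance from the batch.
    return rescale(X[index])

X = [[-1.0, 0.0], [1.0, 0.5]]        # a batch of two raw 2-d variates
low, high = [0.0, -2.0], [10.0, 2.0]  # bounds of a 2-d Box
print(postprocess_variate(X, low, high))           # [0.0, 0.0]
print(postprocess_variate(X, low, high, index=1))  # [10.0, 1.0]
print(postprocess_variate(X, low, high, batch_mode=True))
```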

preprocess_variate(rng, X)[source]

The pre-processor to ensure that an instance of the space is processed into the same structure as variates drawn from this distribution, i.e. outputs of sample() and mode().

Parameters:
  • rng (PRNGKey) – A key for seeding the pseudo-random number generator.

  • X (clean variates) – A batch of clean variates, i.e. instances of the gymnasium-style space.

Returns:

X (raw variates) – A batch of raw variates, i.e. in the same format as the outputs of sample() and mode().
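Under the same assumption of an affine Box mapping (not confirmed by this page), pre-processing is the inverse map from \([\text{low}, \text{high}]\) back to \([-1, 1]\). The hypothetical sketch below processes a batch. It omits rng because nothing random happens.

```python
def preprocess_variate(X, low, high):
    """Hypothetical sketch: map a batch of points in the box [low, high] back to [-1, 1]."""
    # Inverse of the assumed rescaling x = low + (u + 1) * (high - low) / 2.
    return [[2.0 * (xi - lo) / (hi - lo) - 1.0 for xi, lo, hi in zip(x, low, high)] for x in X]

low, high = [0.0, -2.0], [10.0, 2.0]
print(preprocess_variate([[0.0, 0.0], [10.0, 1.0]], low, high))  # [[-1.0, 0.0], [1.0, 0.5]]
```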

property affine_transform

Transform the distribution \(\mathcal{D}\to\mathcal{D}'\) in such a way that its associated variables \(X\sim\mathcal{D}\) and \(X'\sim\mathcal{D}'\) are related via an affine transformation:

\[X' = X\times\text{scale} + \text{shift}\]
Parameters:
  • dist_params (pytree with ndarray leaves) – The distribution parameters of the original distribution \(\mathcal{D}\).

  • scale (float or ndarray) – The multiplicative factor of the affine transformation.

  • shift (float or ndarray) – The additive shift of the affine transformation.

  • value_transform (ValueTransform, optional) –

    The transform to apply to the values before the affine transform, i.e.

    \[X' = f\bigl(f^{-1}(X)\times\text{scale} + \text{shift}\bigr)\]

Returns:

dist_params (pytree with ndarray leaves) – The distribution parameters of the transformed distribution \(\mathcal{D}'\).
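The role of value_transform can be seen on individual values, independently of how the method updates dist_params. The sketch below applies \(X' = f\bigl(f^{-1}(X)\times\text{scale} + \text{shift}\bigr)\) with a hypothetical signed-log transform \(f(x) = \operatorname{sign}(x)\log(1+|x|)\). This choice of \(f\) is an assumption for illustration and is not necessarily one of coax's ValueTransform classes.

```python
import math

def f(x):
    """Hypothetical value transform: signed log squashing."""
    return math.copysign(math.log1p(abs(x)), x)

def f_inv(y):
    """Inverse of f."""
    return math.copysign(math.expm1(abs(y)), y)

def transformed_affine(x, scale, shift):
    """X' = f(f^{-1}(X) * scale + shift): the affine map acts in untransformed space."""
    return f(f_inv(x) * scale + shift)

x = f(3.0)  # a value stored in transformed space
print(f_inv(transformed_affine(x, 2.0, 1.0)))  # ≈ 7.0, i.e. 3 * 2 + 1
```

With scale=1 and shift=0 the map reduces to the identity, as expected.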

property cross_entropy

JIT-compiled function that computes the cross-entropy of a distribution \(q\) relative to another distribution \(p\):

\[\text{CE}[p,q]\ =\ -\mathbb{E}_p \log q \ =\ \frac12\left( \log(2\pi\sigma_q^2) + \frac{(\mu_p-\mu_q)^2+\sigma_p^2}{\sigma_q^2} \right)\]
Parameters:
  • dist_params_p (pytree with ndarray leaves) – The distribution parameters of the base distribution \(p\).

  • dist_params_q (pytree with ndarray leaves) – The distribution parameters of the auxiliary distribution \(q\).
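The closed form above is written in terms of the normal parameters \(\mu\) and \(\sigma^2\). The minimal sketch below compares it with a direct midpoint-rule evaluation of \(-\int p(x)\log q(x)\,dx\) for two one-dimensional normals. All helper names are hypothetical, and the parameter values are arbitrary.

```python
import math

def normal_logpdf(x, mu, var):
    return -0.5 * (math.log(2 * math.pi * var) + (x - mu) ** 2 / var)

def cross_entropy_closed_form(mu_p, var_p, mu_q, var_q):
    # CE[p, q] = 1/2 (log(2 pi var_q) + ((mu_p - mu_q)^2 + var_p) / var_q)
    return 0.5 * (math.log(2 * math.pi * var_q) + ((mu_p - mu_q) ** 2 + var_p) / var_q)

def cross_entropy_numeric(mu_p, var_p, mu_q, var_q, n=20000, width=12.0):
    # Midpoint rule for -integral p(x) log q(x) dx over mu_p +/- width * sigma_p.
    sigma_p = math.sqrt(var_p)
    a = mu_p - width * sigma_p
    h = 2 * width * sigma_p / n
    total = 0.0
    for i in range(n):
        x = a + (i + 0.5) * h
        total -= math.exp(normal_logpdf(x, mu_p, var_p)) * normal_logpdf(x, mu_q, var_q) * h
    return total

args = (0.3, 0.25, -0.1, 0.64)  # mu_p, var_p, mu_q, var_q
print(cross_entropy_closed_form(*args), cross_entropy_numeric(*args))  # should agree closely
```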

property default_priors

The default distribution parameters.

property dist_params_structure

The tree structure of the distribution parameters.

property entropy

JIT-compiled function that computes the entropy of the distribution.

\[H\ =\ -\mathbb{E}_p \log p \ =\ \frac12\left( \log(2\pi\sigma^2) + 1\right)\]
Parameters:

dist_params (pytree with ndarray leaves) – A batch of distribution parameters.

Returns:

H (ndarray of floats) – A batch of entropy values.
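Because dist_params stores the log-variance, the formula above can be evaluated without exponentiating, since \(\log(2\pi\sigma^2) = \log(2\pi) + \log\sigma^2\). Here is a batched, pure-Python sketch with a hypothetical helper name:

```python
import math

def normal_entropy(dist_params):
    """Hypothetical helper: batched H = 1/2 (log(2 pi sigma^2) + 1) from the log-variance."""
    return [0.5 * (math.log(2 * math.pi) + lv + 1.0) for lv in dist_params['logvar']]

dist_params = {'mu': [0.0, 5.0], 'logvar': [0.0, 2.0]}
print(normal_entropy(dist_params))  # ≈ [1.4189, 2.4189]; mu has no effect
```

Each unit increase in logvar adds exactly 1/2 to the entropy, and the mean plays no role.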

property hyperparams

The distribution hyperparameters.

property kl_divergence

JIT-compiled function that computes the Kullback-Leibler divergence of a distribution \(q\) relative to another distribution \(p\):

\[\text{KL}[p,q]\ = -\mathbb{E}_p \left(\log q -\log p\right) \ =\ \frac12\left( \log(\sigma_q^2) - \log(\sigma_p^2) + \frac{(\mu_p-\mu_q)^2+\sigma_p^2}{\sigma_q^2} - 1 \right)\]
Parameters:
  • dist_params_p (pytree with ndarray leaves) – The distribution parameters of the base distribution \(p\).

  • dist_params_q (pytree with ndarray leaves) – The distribution parameters of the auxiliary distribution \(q\).
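The KL divergence is invariant under invertible, differentiable transformations such as \(\tanh\): the Jacobian terms cancel. The Gaussian closed form therefore also gives the divergence between the squashed distributions. The sketch below uses hypothetical one-dimensional helpers in log-variance form. It checks the closed form against \(\text{CE}[p,q] - H[p]\) and shows that the divergence vanishes for \(p = q\).

```python
import math

def normal_kl(mu_p, logvar_p, mu_q, logvar_q):
    """KL[p, q] = 1/2 (log var_q - log var_p + ((mu_p - mu_q)^2 + var_p) / var_q - 1)."""
    var_p, var_q = math.exp(logvar_p), math.exp(logvar_q)
    return 0.5 * (logvar_q - logvar_p + ((mu_p - mu_q) ** 2 + var_p) / var_q - 1.0)

def normal_cross_entropy(mu_p, logvar_p, mu_q, logvar_q):
    var_p, var_q = math.exp(logvar_p), math.exp(logvar_q)
    return 0.5 * (math.log(2 * math.pi) + logvar_q + ((mu_p - mu_q) ** 2 + var_p) / var_q)

def normal_entropy(logvar):
    return 0.5 * (math.log(2 * math.pi) + logvar + 1.0)

p, q = (0.3, math.log(0.25)), (-0.1, math.log(0.64))  # (mu, logvar) pairs
print(normal_kl(*p, *p))                     # 0.0
print(normal_kl(*p, *q), normal_kl(*q, *p))  # both positive, and not equal
```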

property log_proba

JIT-compiled function that evaluates log-probabilities.

Parameters:
  • dist_params (pytree with ndarray leaves) – A batch of distribution parameters.

  • X (ndarray) – A batch of variates, e.g. a batch of actions \(a\) collected from experience.

Returns:

logP (ndarray of floats) – A batch of log-probabilities associated with the provided variates.
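For \(x = \tanh(y)\) with \(y\sim\mathcal{N}(\mu,\sigma^2)\), the standard change-of-variables formula gives \(\log p(x) = \log\mathcal{N}(\operatorname{artanh} x;\mu,\sigma^2) - \log(1-x^2)\) for \(|x|<1\), summed over dimensions in the multivariate case. The one-dimensional sketch below uses a hypothetical helper. Any clipping near \(\pm1\) in the actual implementation is not modelled here. As a sanity check, the density should integrate to approximately one over \((-1, 1)\).

```python
import math

def squashed_normal_log_proba(x, mu, logvar):
    """Hypothetical helper: log p(x) for x = tanh(y), y ~ N(mu, sigma^2), |x| < 1."""
    y = math.atanh(x)
    log_normal = -0.5 * (math.log(2 * math.pi) + logvar + (y - mu) ** 2 / math.exp(logvar))
    # dy/dx = 1 / (1 - x^2), hence the -log(1 - x^2) Jacobian correction.
    return log_normal - math.log1p(-x * x)

# Midpoint-rule integral of the density over (-1, 1).
n = 20000
h = 2.0 / n
total = sum(
    math.exp(squashed_normal_log_proba(-1.0 + (i + 0.5) * h, 0.3, math.log(0.25))) * h
    for i in range(n)
)
print(total)  # ≈ 1.0
```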

property mean

JIT-compiled function that generates differentiable means of the distribution, in this case simply \(\tanh(\mu)\).

Parameters:

dist_params (pytree with ndarray leaves) – A batch of distribution parameters.

Returns:

X (ndarray) – A batch of differentiable variates.
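Because \(\tanh\) is monotonic, \(\tanh(\mu)\) is exactly the median of the squashed distribution: \(P(\tanh(Y)\le\tanh(\mu)) = P(Y\le\mu) = \tfrac12\). In general it is not the expectation \(\mathbb{E}[\tanh(Y)]\). The Monte Carlo sketch below illustrates both facts. The helper name and the parameter values are illustrative assumptions.

```python
import math
import random

def squashed_mean(dist_params):
    """Hypothetical helper: batched tanh(mu), the quantity documented above."""
    return [math.tanh(m) for m in dist_params['mu']]

mu, sigma = 0.8, 1.5
rng = random.Random(0)
samples = [math.tanh(mu + sigma * rng.gauss(0.0, 1.0)) for _ in range(20000)]
m = squashed_mean({'mu': [mu], 'logvar': [2 * math.log(sigma)]})[0]
frac_below = sum(s <= m for s in samples) / len(samples)
sample_mean = sum(samples) / len(samples)
print(m, frac_below, sample_mean)  # frac_below ≈ 0.5, while sample_mean is noticeably below m
```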

property mode

JIT-compiled function that generates differentiable modes of the distribution, which, as for a normal distribution, are taken to be the same as the mean.

Parameters:

dist_params (pytree with ndarray leaves) – A batch of distribution parameters.

Returns:

X (ndarray) – A batch of differentiable variates.

property sample

JIT-compiled function that generates differentiable variates using the reparametrization trick, i.e. \(x\sim\tanh(\mathcal{N}(\mu,\sigma^2))\) is implemented as

\[\begin{split}\varepsilon\ &\sim\ \mathcal{N}(0,1) \\ x\ &=\ \tanh(\mu + \sigma\,\varepsilon)\end{split}\]
Parameters:
  • dist_params (pytree with ndarray leaves) – A batch of distribution parameters.

  • rng (PRNGKey) – A key for seeding the pseudo-random number generator.

Returns:

X (ndarray) – A batch of differentiable variates.
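The reparametrization above can be sketched in plain Python. The noise \(\varepsilon\) is drawn independently of the parameters, so \(x\) is a deterministic, differentiable function of \(\mu\) and \(\sigma\), with \(\partial x/\partial\mu = 1 - x^2\). The hypothetical sample helper below also returns \(\varepsilon\) so that the derivative can be checked by finite differences. It uses Python's random module in place of a JAX PRNGKey.

```python
import math
import random

def sample(dist_params, rng):
    """Hypothetical helper: x = tanh(mu + sigma * eps), eps ~ N(0, 1), per dimension."""
    eps = [rng.gauss(0.0, 1.0) for _ in dist_params['mu']]
    x = [math.tanh(m + math.exp(lv / 2) * e)
         for m, lv, e in zip(dist_params['mu'], dist_params['logvar'], eps)]
    return x, eps

def dx_dmu(x):
    # With eps held fixed, d/dmu tanh(mu + sigma * eps) = 1 - tanh(...)^2 = 1 - x^2.
    return [1.0 - xi * xi for xi in x]

rng = random.Random(42)
params = {'mu': [0.2, -0.5], 'logvar': [0.0, math.log(0.04)]}
x, eps = sample(params, rng)

# Finite-difference check of dx/dmu in the first dimension, reusing the same eps.
d = 1e-6
x_shift = math.tanh(params['mu'][0] + d + math.exp(params['logvar'][0] / 2) * eps[0])
print(dx_dmu(x)[0], (x_shift - x[0]) / d)  # the two numbers should agree closely
```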

property space

The gymnasium-style space that specifies the domain of the distribution.