Probability Distributions
This is a collection of differentiable probability distributions used throughout the package.
Object Reference
 class coax.proba_dists.ProbaDist(space)[source]¶
A composite probability distribution. This consists of a nested structure, whose leaves are either coax.proba_dists.CategoricalDist or coax.proba_dists.NormalDist instances.
 Parameters:
space (gymnasium.Space) – The gymnasium-style space that specifies the domain of the distribution. This may be any space included in the gymnasium.spaces module.
 postprocess_variate(rng, X, index=0, batch_mode=False)[source]¶
The postprocessor specific to variates drawn from this distribution.
This method provides the interface between differentiable, batched variates, i.e. outputs of sample() and mode(), and the provided gymnasium space.
 Parameters:
rng (PRNGKey) – A key for seeding the pseudorandom number generator.
X (raw variates) – A batch of raw variates, i.e. in the same format as the outputs of sample() and mode().
index (int, optional) – The index to pick out from the batch. Note that this only applies if batch_mode=False.
batch_mode (bool, optional) – Whether to return a batch or a single instance.
 Returns:
x or X (clean variate) – A single clean variate or a batch thereof (if batch_mode=True). A variate is called clean if it is an instance of the gymnasium-style space, i.e. it satisfies x in self.space.
 preprocess_variate(rng, X)[source]¶
The preprocessor to ensure that an instance of the space is processed into the same structure as variates drawn from this distribution, i.e. outputs of sample() and mode().
 property affine_transform¶
Transform the distribution \(\mathcal{D}\to\mathcal{D}'\) in such a way that its associated variables \(X\sim\mathcal{D}\) and \(X'\sim\mathcal{D}'\) are related via an affine transformation:
\[X' = X\times\text{scale} + \text{shift}\]
 Parameters:
dist_params (pytree with ndarray leaves) – The distribution parameters of the original distribution \(\mathcal{D}\).
scale (float or ndarray) – The multiplicative factor of the affine transformation.
shift (float or ndarray) – The additive shift of the affine transformation.
value_transform (ValueTransform, optional) – The transform to apply to the values before the affine transform, i.e.
\[X' = f\bigl(f^{-1}(X)\times\text{scale} + \text{shift}\bigr)\]
 Returns:
dist_params (pytree with ndarray leaves) – The distribution parameters of the transformed distribution \(\mathcal{D}'\).
 property cross_entropy¶
JIT-compiled function that computes the cross-entropy of a distribution \(q\) relative to another distribution \(p\):
\[\text{CE}[p,q]\ =\ -\mathbb{E}_p \log q\]
 Parameters:
dist_params_p (pytree with ndarray leaves) – The distribution parameters of the base distribution \(p\).
dist_params_q (pytree with ndarray leaves) – The distribution parameters of the auxiliary distribution \(q\).
 property default_priors¶
The default distribution parameters.
 property dist_params_structure¶
The tree structure of the distribution parameters.
 property entropy¶
JIT-compiled function that computes the entropy of the distribution.
\[H\ =\ -\mathbb{E}_p \log p\]
 Parameters:
dist_params (pytree with ndarray leaves) – A batch of distribution parameters.
 Returns:
H (ndarray of floats) – A batch of entropy values.
 property hyperparams¶
The distribution hyperparameters.
 property kl_divergence¶
JIT-compiled function that computes the Kullback-Leibler divergence of a distribution \(q\) relative to another distribution \(p\):
\[\text{KL}[p,q]\ =\ -\mathbb{E}_p \left(\log q - \log p\right)\]
 Parameters:
dist_params_p (pytree with ndarray leaves) – The distribution parameters of the base distribution \(p\).
dist_params_q (pytree with ndarray leaves) – The distribution parameters of the auxiliary distribution \(q\).
 property log_proba¶
JIT-compiled function that evaluates log-probabilities.
 Parameters:
dist_params (pytree with ndarray leaves) – A batch of distribution parameters.
X (ndarray) – A batch of variates, e.g. a batch of actions \(a\) collected from experience.
 Returns:
logP (ndarray of floats) – A batch of log-probabilities associated with the provided variates.
 property mean¶
JIT-compiled function that generates differentiable means of the distribution.
 Parameters:
dist_params (pytree with ndarray leaves) – A batch of distribution parameters.
 Returns:
X (ndarray) – A batch of differentiable variates.
 property mode¶
JIT-compiled function that generates differentiable modes of the distribution.
 Parameters:
dist_params (pytree with ndarray leaves) – A batch of distribution parameters.
 Returns:
X (ndarray) – A batch of differentiable variates.
 property sample¶
JIT-compiled function that generates differentiable variates.
 Parameters:
dist_params (pytree with ndarray leaves) – A batch of distribution parameters.
rng (PRNGKey) – A key for seeding the pseudorandom number generator.
 Returns:
X (ndarray) – A batch of differentiable variates.
 property space¶
The gymnasium-style space that specifies the domain of the distribution.
 class coax.proba_dists.CategoricalDist(space, gumbel_softmax_tau=0.2)[source]¶
A differentiable categorical distribution.
The input dist_params to each of the functions is expected to be of the form:
dist_params = {'logits': array([...])}
which represents the (conditional) distribution parameters. The logits, denoted \(z\in\mathbb{R}^n\), are related to the categorical distribution parameters \(p\in\Delta^n\) via a softmax:
\[p_k\ =\ \text{softmax}_k(z)\ =\ \frac{\text{e}^{z_k}}{\sum_j\text{e}^{z_j}}\]
 Parameters:
space (gymnasium.spaces.Discrete) – The gymnasium-style space that specifies the domain of the distribution.
gumbel_softmax_tau (positive float, optional) – The parameter \(\tau\) specifies the sharpness of the Gumbel-softmax sampling (see the sample() method below). A good value for \(\tau\) balances the trade-off between getting proper deterministic variates (i.e. one-hot vectors) versus getting smooth differentiable variates.
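The relation between logits and probabilities can be illustrated in plain Python. This is only a sketch of the softmax formula itself, not the library's JAX implementation; the helper name softmax and the example logits are made up for illustration.

```python
import math

def softmax(logits):
    """Map logits z to probabilities p_k = exp(z_k) / sum_j exp(z_j)."""
    z_max = max(logits)  # subtracting the max leaves the result unchanged but avoids overflow
    exps = [math.exp(z - z_max) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical distribution parameters for a 3-class distribution.
dist_params = {'logits': [2.0, 1.0, 0.1]}
p = softmax(dist_params['logits'])
```

The resulting p is a point on the probability simplex: all entries are positive and they sum to one.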
 postprocess_variate(rng, X, index=0, batch_mode=False)[source]¶
The postprocessor specific to variates drawn from this distribution.
This method provides the interface between differentiable, batched variates, i.e. outputs of sample() and mode(), and the provided gymnasium space.
 Parameters:
rng (PRNGKey) – A key for seeding the pseudorandom number generator.
X (raw variates) – A batch of raw variates, i.e. in the same format as the outputs of sample() and mode().
index (int, optional) – The index to pick out from the batch. Note that this only applies if batch_mode=False.
batch_mode (bool, optional) – Whether to return a batch or a single instance.
 Returns:
x or X (clean variate) – A single clean variate or a batch thereof (if batch_mode=True). A variate is called clean if it is an instance of the gymnasium-style space, i.e. it satisfies x in self.space.
 preprocess_variate(rng, X)[source]¶
The preprocessor to ensure that an instance of the space is processed into the same structure as variates drawn from this distribution, i.e. outputs of sample() and mode().
 property affine_transform¶
Transform the distribution \(\mathcal{D}\to\mathcal{D}'\) in such a way that its associated variables \(X\sim\mathcal{D}\) and \(X'\sim\mathcal{D}'\) are related via an affine transformation:
\[X' = X\times\text{scale} + \text{shift}\]
 Parameters:
dist_params (pytree with ndarray leaves) – The distribution parameters of the original distribution \(\mathcal{D}\).
scale (float or ndarray) – The multiplicative factor of the affine transformation.
shift (float or ndarray) – The additive shift of the affine transformation.
value_transform (ValueTransform, optional) – The transform to apply to the values before the affine transform, i.e.
\[X' = f\bigl(f^{-1}(X)\times\text{scale} + \text{shift}\bigr)\]
 Returns:
dist_params (pytree with ndarray leaves) – The distribution parameters of the transformed distribution \(\mathcal{D}'\).
 property cross_entropy¶
JIT-compiled function that computes the cross-entropy of a categorical distribution \(q\) relative to another categorical distribution \(p\):
\[\text{CE}[p,q]\ =\ -\sum_k p_k \log q_k\]
 Parameters:
dist_params_p (pytree with ndarray leaves) – The distribution parameters of the base distribution \(p\).
dist_params_q (pytree with ndarray leaves) – The distribution parameters of the auxiliary distribution \(q\).
 property default_priors¶
The default distribution parameters.
 property dist_params_structure¶
The tree structure of the distribution parameters.
 property entropy¶
JIT-compiled function that computes the entropy of the distribution.
\[H\ =\ -\sum_k p_k \log p_k\]
 Parameters:
dist_params (pytree with ndarray leaves) – A batch of distribution parameters.
 Returns:
H (ndarray of floats) – A batch of entropy values.
 property hyperparams¶
The distribution hyperparameters.
 property kl_divergence¶
JIT-compiled function that computes the Kullback-Leibler divergence of a categorical distribution \(q\) relative to another categorical distribution \(p\):
\[\text{KL}[p,q]\ =\ -\sum_k p_k \left(\log q_k - \log p_k\right)\]
 Parameters:
dist_params_p (pytree with ndarray leaves) – The distribution parameters of the base distribution \(p\).
dist_params_q (pytree with ndarray leaves) – The distribution parameters of the auxiliary distribution \(q\).
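The categorical entropy, cross-entropy and KL divergence are related by KL[p,q] = CE[p,q] - H[p]. The following plain-Python sketch evaluates the three sums directly from probabilities; the function names are illustrative and are not the library's JIT-compiled properties.

```python
import math

def softmax(logits):
    z_max = max(logits)
    exps = [math.exp(z - z_max) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(p, q):
    """CE[p, q] = -sum_k p_k log q_k."""
    return -sum(pk * math.log(qk) for pk, qk in zip(p, q))

def entropy(p):
    """H[p] = -sum_k p_k log p_k, i.e. the cross-entropy of p with itself."""
    return cross_entropy(p, p)

def kl_divergence(p, q):
    """KL[p, q] = -sum_k p_k (log q_k - log p_k)."""
    return -sum(pk * (math.log(qk) - math.log(pk)) for pk, qk in zip(p, q))

p = softmax([1.0, 0.0, -1.0])   # base distribution
q = softmax([0.0, 0.0, 0.0])    # uniform auxiliary distribution
```

Because the probabilities come from a softmax they are strictly positive, so the logarithms are always finite.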
 property log_proba¶
JIT-compiled function that evaluates log-probabilities.
 Parameters:
dist_params (pytree with ndarray leaves) – A batch of distribution parameters.
X (ndarray) – A batch of variates, e.g. a batch of actions \(a\) collected from experience.
 Returns:
logP (ndarray of floats) – A batch of log-probabilities associated with the provided variates.
 property mean¶
JIT-compiled function that generates differentiable means of the distribution. Strictly speaking, the mean of a categorical variable is not well defined. We opt for returning the raw probabilities: \(\text{mean}_k=p_k\).
 Parameters:
dist_params (pytree with ndarray leaves) – A batch of distribution parameters.
 Returns:
X (ndarray) – A batch of would-be variates \(x\sim\text{Cat}(p)\). In contrast to the output of other methods, these aren’t true variates because they are not almost-one-hot encoded.
 property mode¶
JIT-compiled function that generates differentiable modes of the distribution, for which we use a similar trick as in Gumbel-softmax sampling:
\[\text{mode}_k\ =\ \text{softmax}_k\left( \frac{\log p_k}{\tau} \right)\]
 Parameters:
dist_params (pytree with ndarray leaves) – A batch of distribution parameters.
 Returns:
X (ndarray) – A batch of variates \(x\sim\text{Cat}(p)\). In order to ensure differentiability of the variates this is not an integer, but instead an almost-one-hot encoded version thereof.
For example, instead of sampling \(x=2\) from a 4-class categorical distribution, Gumbel-softmax will return a vector like \(x=(0.05, 0.02, 0.86, 0.07)\). The latter representation can be viewed as an almost-one-hot encoded version of the former.
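The mode formula is a softmax sharpened by the temperature \(\tau\). A minimal plain-Python sketch, with the illustrative name soft_mode and the class default of 0.2 for tau, might look as follows; it is not the library's implementation.

```python
import math

def softmax(logits):
    z_max = max(logits)
    exps = [math.exp(z - z_max) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def soft_mode(logits, tau=0.2):
    """mode_k = softmax_k(log p_k / tau): sharper as tau gets smaller."""
    p = softmax(logits)
    return softmax([math.log(pk) / tau for pk in p])

logits = [0.0, 1.0, 3.0, 0.5]
x = soft_mode(logits)
```

The largest entry of x sits at the same index as the largest probability, but it is much closer to one than the probability itself, which is what makes the result almost one-hot.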
 property sample¶
JIT-compiled function that generates differentiable variates using Gumbel-softmax sampling. \(x\sim\text{Cat}(p)\) is implemented as
\[\begin{split}u_k\ &\sim\ \text{Unif}(0, 1) \\ g_k\ &=\ -\log(-\log(u_k)) \\ x_k\ &=\ \text{softmax}_k\left( \frac{g_k + \log p_k}{\tau} \right)\end{split}\]
 Parameters:
dist_params (pytree with ndarray leaves) – A batch of distribution parameters.
rng (PRNGKey) – A key for seeding the pseudorandom number generator.
 Returns:
X (ndarray) – A batch of variates \(x\sim\text{Cat}(p)\). In order to ensure differentiability of the variates this is not an integer, but instead an almost-one-hot encoded version thereof.
For example, instead of sampling \(x=2\) from a 4-class categorical distribution, Gumbel-softmax will return a vector like \(x=[0.05, 0.02, 0.86, 0.07]\). The latter representation can be viewed as an almost-one-hot encoded version of the former.
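Gumbel-softmax sampling can be sketched in plain Python using the standard Gumbel noise \(g_k = -\log(-\log u_k)\). The function name, the clamping of u away from 0 and 1, and the use of Python's random module in place of a PRNGKey are assumptions made for this illustration only.

```python
import math
import random

def softmax(logits):
    z_max = max(logits)
    exps = [math.exp(z - z_max) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def gumbel_softmax_sample(logits, tau=0.2, rng=random):
    """Draw one almost-one-hot variate x ~ Cat(softmax(logits))."""
    p = softmax(logits)
    noisy = []
    for pk in p:
        # Clamp u so that both logarithms below stay finite.
        u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
        g = -math.log(-math.log(u))              # Gumbel(0, 1) noise
        noisy.append((g + math.log(pk)) / tau)
    return softmax(noisy)

rng = random.Random(0)
x = gumbel_softmax_sample([0.0, 0.0, 2.0, 0.0], tau=0.2, rng=rng)
```

Taking the argmax of many such draws reproduces the class frequencies of the underlying categorical distribution, while each individual draw remains a smooth function of the logits.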
 property space¶
The gymnasium-style space that specifies the domain of the distribution.
 class coax.proba_dists.NormalDist(space, clip_box=(-256.0, 256.0), clip_reals=(-30.0, 30.0), clip_logvar=(-20.0, 20.0))[source]¶
A differentiable normal distribution.
The input dist_params to each of the functions is expected to be of the form:
dist_params = {'mu': array([...]), 'logvar': array([...])}
which represents the (conditional) distribution parameters. Here, mu is the mean \(\mu\) and logvar is the log-variance \(\log(\sigma^2)\).
 Parameters:
space (gymnasium.spaces.Box) – The gymnasium-style space that specifies the domain of the distribution.
clip_box (pair of floats, optional) – The range of values to allow for clean (compact) variates. This is mainly to ensure reasonable values when one or more dimensions of the Box space have very large ranges, while in reality only a small part of that range is occupied.
clip_reals (pair of floats, optional) – The range of values to allow for raw (decompactified) variates, the reals, used internally. This range is set for numeric stability. Namely, the postprocess_variate method compactifies the reals to a closed interval (Box) by applying a logistic sigmoid. Setting a finite range for clip_reals ensures that the sigmoid doesn’t fully saturate.
clip_logvar (pair of floats, optional) – The range of values to allow for the log-variance of the distribution.
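The role of clip_reals can be illustrated with a small plain-Python sketch. It assumes that compactification is a logistic sigmoid rescaled to the interval [low, high] and that the clipping range is (-30, 30); the exact transformation used by the library is not spelled out here, and the name compactify is hypothetical.

```python
import math

def compactify(x_raw, low, high, clip_reals=(-30.0, 30.0)):
    """Map a real number onto the closed interval [low, high]."""
    lo, hi = clip_reals
    x_clipped = min(max(x_raw, lo), hi)        # keep the sigmoid input in a safe range
    s = 1.0 / (1.0 + math.exp(-x_clipped))     # logistic sigmoid, strictly inside (0, 1)
    return low + (high - low) * s

x_mid = compactify(0.0, low=-2.0, high=2.0)
x_big = compactify(1e6, low=-2.0, high=2.0)
x_small = compactify(-1e6, low=-2.0, high=2.0)
```

Without the clipping step, the call with -1e6 would raise an OverflowError in math.exp, which is the kind of numerical problem a finite range is meant to prevent.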
 postprocess_variate(rng, X, index=0, batch_mode=False)[source]¶
The postprocessor specific to variates drawn from this distribution.
This method provides the interface between differentiable, batched variates, i.e. outputs of sample() and mode(), and the provided gymnasium space.
 Parameters:
rng (PRNGKey) – A key for seeding the pseudorandom number generator.
X (raw variates) – A batch of raw variates, i.e. in the same format as the outputs of sample() and mode().
index (int, optional) – The index to pick out from the batch. Note that this only applies if batch_mode=False.
batch_mode (bool, optional) – Whether to return a batch or a single instance.
 Returns:
x or X (clean variate) – A single clean variate or a batch thereof (if batch_mode=True). A variate is called clean if it is an instance of the gymnasium-style space, i.e. it satisfies x in self.space.
 preprocess_variate(rng, X)[source]¶
The preprocessor to ensure that an instance of the space is processed into the same structure as variates drawn from this distribution, i.e. outputs of sample() and mode().
 property affine_transform¶
Transform the distribution \(\mathcal{D}\to\mathcal{D}'\) in such a way that its associated variables \(X\sim\mathcal{D}\) and \(X'\sim\mathcal{D}'\) are related via an affine transformation:
\[X' = X\times\text{scale} + \text{shift}\]
 Parameters:
dist_params (pytree with ndarray leaves) – The distribution parameters of the original distribution \(\mathcal{D}\).
scale (float or ndarray) – The multiplicative factor of the affine transformation.
shift (float or ndarray) – The additive shift of the affine transformation.
value_transform (ValueTransform, optional) – The transform to apply to the values before the affine transform, i.e.
\[X' = f\bigl(f^{-1}(X)\times\text{scale} + \text{shift}\bigr)\]
 Returns:
dist_params (pytree with ndarray leaves) – The distribution parameters of the transformed distribution \(\mathcal{D}'\).
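For a normal distribution the affine transformation has a simple closed form: if \(X\sim\mathcal{N}(\mu,\sigma^2)\), then \(X\times\text{scale}+\text{shift}\) is normal with mean \(\mu\times\text{scale}+\text{shift}\) and variance \(\sigma^2\times\text{scale}^2\). The scalar sketch below applies this to the mu/logvar parameters; it ignores value_transform and is not the library's implementation.

```python
import math

def affine_transform_normal(dist_params, scale, shift):
    """Return the parameters of scale * X + shift for X ~ N(mu, exp(logvar))."""
    return {
        'mu': dist_params['mu'] * scale + shift,
        # Var[scale * X] = scale^2 * Var[X], so logvar shifts by 2 * log|scale|.
        'logvar': dist_params['logvar'] + 2.0 * math.log(abs(scale)),
    }

new_params = affine_transform_normal({'mu': 1.0, 'logvar': 0.0}, scale=2.0, shift=3.0)
```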
 property cross_entropy¶
JIT-compiled function that computes the cross-entropy of a distribution \(q\) relative to another distribution \(p\):
\[\text{CE}[p,q]\ =\ -\mathbb{E}_p \log q \ =\ \frac12\left( \log(2\pi\sigma_q^2) + \frac{(\mu_p-\mu_q)^2+\sigma_p^2}{\sigma_q^2} \right)\]
 Parameters:
dist_params_p (pytree with ndarray leaves) – The distribution parameters of the base distribution \(p\).
dist_params_q (pytree with ndarray leaves) – The distribution parameters of the auxiliary distribution \(q\).
 property default_priors¶
The default distribution parameters.
 property dist_params_structure¶
The tree structure of the distribution parameters.
 property entropy¶
JIT-compiled function that computes the entropy of the distribution.
\[H\ =\ -\mathbb{E}_p \log p \ =\ \frac12\left( \log(2\pi\sigma^2) + 1\right)\]
 Parameters:
dist_params (pytree with ndarray leaves) – A batch of distribution parameters.
 Returns:
H (ndarray of floats) – A batch of entropy values.
 property hyperparams¶
The distribution hyperparameters.
 property kl_divergence¶
JIT-compiled function that computes the Kullback-Leibler divergence of a distribution \(q\) relative to another distribution \(p\):
\[\text{KL}[p,q]\ =\ -\mathbb{E}_p \left(\log q - \log p\right) \ =\ \frac12\left( \log(\sigma_q^2) - \log(\sigma_p^2) + \frac{(\mu_p-\mu_q)^2+\sigma_p^2}{\sigma_q^2} - 1 \right)\]
 Parameters:
dist_params_p (pytree with ndarray leaves) – The distribution parameters of the base distribution \(p\).
dist_params_q (pytree with ndarray leaves) – The distribution parameters of the auxiliary distribution \(q\).
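The closed-form expressions for the normal entropy, cross-entropy and KL divergence satisfy KL[p,q] = CE[p,q] - H[p]. The following scalar sketch in plain Python checks that relation; the function names are illustrative and the code is not the library's JIT-compiled implementation.

```python
import math

def normal_cross_entropy(p, q):
    """CE[p, q] = 0.5 * (log(2 pi var_q) + ((mu_p - mu_q)^2 + var_p) / var_q)."""
    var_p, var_q = math.exp(p['logvar']), math.exp(q['logvar'])
    return 0.5 * (math.log(2.0 * math.pi * var_q)
                  + ((p['mu'] - q['mu']) ** 2 + var_p) / var_q)

def normal_entropy(p):
    """H[p] = 0.5 * (log(2 pi var_p) + 1)."""
    return 0.5 * (math.log(2.0 * math.pi * math.exp(p['logvar'])) + 1.0)

def normal_kl_divergence(p, q):
    """KL[p, q] = 0.5 * (logvar_q - logvar_p + ((mu_p - mu_q)^2 + var_p) / var_q - 1)."""
    var_p, var_q = math.exp(p['logvar']), math.exp(q['logvar'])
    return 0.5 * (q['logvar'] - p['logvar']
                  + ((p['mu'] - q['mu']) ** 2 + var_p) / var_q - 1.0)

p = {'mu': 0.5, 'logvar': -1.0}
q = {'mu': -0.2, 'logvar': 0.3}
```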
 property log_proba¶
JIT-compiled function that evaluates log-probabilities.
 Parameters:
dist_params (pytree with ndarray leaves) – A batch of distribution parameters.
X (ndarray) – A batch of variates, e.g. a batch of actions \(a\) collected from experience.
 Returns:
logP (ndarray of floats) – A batch of log-probabilities associated with the provided variates.
 property mean¶
JIT-compiled function that generates differentiable means of the distribution, in this case simply \(\mu\).
 Parameters:
dist_params (pytree with ndarray leaves) – A batch of distribution parameters.
 Returns:
X (ndarray) – A batch of differentiable variates.
 property mode¶
JIT-compiled function that generates differentiable modes of the distribution, which for a normal distribution is the same as the mean.
 Parameters:
dist_params (pytree with ndarray leaves) – A batch of distribution parameters.
 Returns:
X (ndarray) – A batch of differentiable variates.
 property sample¶
JIT-compiled function that generates differentiable variates using the reparametrization trick, i.e. \(x\sim\mathcal{N}(\mu,\sigma^2)\) is implemented as
\[\begin{split}\varepsilon\ &\sim\ \mathcal{N}(0,1) \\ x\ &=\ \mu + \sigma\,\varepsilon\end{split}\]
 Parameters:
dist_params (pytree with ndarray leaves) – A batch of distribution parameters.
rng (PRNGKey) – A key for seeding the pseudorandom number generator.
 Returns:
X (ndarray) – A batch of differentiable variates.
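The reparametrization trick separates the noise from the parameters, so that the variate is a deterministic, differentiable function of mu and logvar once the noise is drawn. A scalar plain-Python sketch, using random.Random as a stand-in for a PRNGKey (an assumption of this example), is:

```python
import math
import random

def sample_normal(dist_params, rng):
    """Draw x = mu + sigma * eps with eps ~ N(0, 1)."""
    eps = rng.gauss(0.0, 1.0)                       # parameter-free noise
    sigma = math.exp(0.5 * dist_params['logvar'])   # sigma = sqrt(exp(logvar))
    return dist_params['mu'] + sigma * eps

rng = random.Random(42)
params = {'mu': 1.5, 'logvar': math.log(0.25)}
xs = [sample_normal(params, rng) for _ in range(20000)]
```

With these parameters the draws scatter around 1.5 with a variance of about 0.25.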
 property space¶
The gymnasium-style space that specifies the domain of the distribution.
 class coax.proba_dists.DiscretizedIntervalDist(space, num_bins=20, gumbel_softmax_tau=0.2)[source]¶
A categorical distribution over a discretized interval.
The input dist_params to each of the functions is expected to be of the form:
dist_params = {'logits': array([...])}
which represents the (conditional) distribution parameters. The logits, denoted \(z\in\mathbb{R}^n\), are related to the categorical distribution parameters \(p\in\Delta^n\) via a softmax:
\[p_k\ =\ \text{softmax}_k(z)\ =\ \frac{\text{e}^{z_k}}{\sum_j\text{e}^{z_j}}\]
 Parameters:
space (gymnasium.spaces.Box) – The gymnasium-style space that specifies the domain of the distribution. The shape of the Box must have prod(shape) == 1, i.e. a single interval.
num_bins (int, optional) – The number of equal-sized bins used in the discretization.
gumbel_softmax_tau (positive float, optional) – The parameter \(\tau\) specifies the sharpness of the Gumbel-softmax sampling (see the sample() method below). A good value for \(\tau\) balances the trade-off between getting proper deterministic variates (i.e. one-hot vectors) versus getting smooth differentiable variates.
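Discretizing an interval into equal-sized bins can be sketched in plain Python as below. The choice to represent each bin by its midpoint, and the helper names value_to_bin and bin_to_value, are assumptions of this example; the library may map between bins and values differently.

```python
def value_to_bin(x, low, high, num_bins=20):
    """Return the index of the equal-sized bin of [low, high] that contains x."""
    frac = (x - low) / (high - low)
    k = int(frac * num_bins)
    return min(max(k, 0), num_bins - 1)   # the upper edge belongs to the last bin

def bin_to_value(k, low, high, num_bins=20):
    """Return the midpoint of bin k (an assumed representative value)."""
    width = (high - low) / num_bins
    return low + (k + 0.5) * width

k = value_to_bin(0.33, low=-1.0, high=1.0)
x = bin_to_value(k, low=-1.0, high=1.0)
```

With 20 bins on [-1, 1] each bin is 0.1 wide, so a value and the midpoint of its bin never differ by more than 0.05.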
 postprocess_variate(rng, X, index=0, batch_mode=False)[source]¶
The postprocessor specific to variates drawn from this distribution.
This method provides the interface between differentiable, batched variates, i.e. outputs of sample() and mode(), and the provided gymnasium space.
 Parameters:
rng (PRNGKey) – A key for seeding the pseudorandom number generator.
X (raw variates) – A batch of raw variates, i.e. in the same format as the outputs of sample() and mode().
index (int, optional) – The index to pick out from the batch. Note that this only applies if batch_mode=False.
batch_mode (bool, optional) – Whether to return a batch or a single instance.
 Returns:
x or X (clean variate) – A single clean variate or a batch thereof (if batch_mode=True). A variate is called clean if it is an instance of the gymnasium-style space, i.e. it satisfies x in self.space.
 preprocess_variate(rng, X)[source]¶
The preprocessor to ensure that an instance of the space is processed into the same structure as variates drawn from this distribution, i.e. outputs of sample() and mode().
 property affine_transform¶
Transform the distribution \(\mathcal{D}\to\mathcal{D}'\) in such a way that its associated variables \(X\sim\mathcal{D}\) and \(X'\sim\mathcal{D}'\) are related via an affine transformation:
\[X' = X\times\text{scale} + \text{shift}\]
 Parameters:
dist_params (pytree with ndarray leaves) – The distribution parameters of the original distribution \(\mathcal{D}\).
scale (float or ndarray) – The multiplicative factor of the affine transformation.
shift (float or ndarray) – The additive shift of the affine transformation.
value_transform (ValueTransform, optional) – The transform to apply to the values before the affine transform, i.e.
\[X' = f\bigl(f^{-1}(X)\times\text{scale} + \text{shift}\bigr)\]
 Returns:
dist_params (pytree with ndarray leaves) – The distribution parameters of the transformed distribution \(\mathcal{D}'\).
 property cross_entropy¶
JIT-compiled function that computes the cross-entropy of a categorical distribution \(q\) relative to another categorical distribution \(p\):
\[\text{CE}[p,q]\ =\ -\sum_k p_k \log q_k\]
 Parameters:
dist_params_p (pytree with ndarray leaves) – The distribution parameters of the base distribution \(p\).
dist_params_q (pytree with ndarray leaves) – The distribution parameters of the auxiliary distribution \(q\).
 property default_priors¶
The default distribution parameters.
 property dist_params_structure¶
The tree structure of the distribution parameters.
 property entropy¶
JIT-compiled function that computes the entropy of the distribution.
\[H\ =\ -\sum_k p_k \log p_k\]
 Parameters:
dist_params (pytree with ndarray leaves) – A batch of distribution parameters.
 Returns:
H (ndarray of floats) – A batch of entropy values.
 property hyperparams¶
The distribution hyperparameters.
 property kl_divergence¶
JIT-compiled function that computes the Kullback-Leibler divergence of a categorical distribution \(q\) relative to another categorical distribution \(p\):
\[\text{KL}[p,q]\ =\ -\sum_k p_k \left(\log q_k - \log p_k\right)\]
 Parameters:
dist_params_p (pytree with ndarray leaves) – The distribution parameters of the base distribution \(p\).
dist_params_q (pytree with ndarray leaves) – The distribution parameters of the auxiliary distribution \(q\).
 property log_proba¶
JIT-compiled function that evaluates log-probabilities.
 Parameters:
dist_params (pytree with ndarray leaves) – A batch of distribution parameters.
X (ndarray) – A batch of variates, e.g. a batch of actions \(a\) collected from experience.
 Returns:
logP (ndarray of floats) – A batch of log-probabilities associated with the provided variates.
 property mean¶
JIT-compiled function that generates differentiable means of the distribution. Strictly speaking, the mean of a categorical variable is not well defined. We opt for returning the raw probabilities: \(\text{mean}_k=p_k\).
 Parameters:
dist_params (pytree with ndarray leaves) – A batch of distribution parameters.
 Returns:
X (ndarray) – A batch of would-be variates \(x\sim\text{Cat}(p)\). In contrast to the output of other methods, these aren’t true variates because they are not almost-one-hot encoded.
 property mode¶
JIT-compiled function that generates differentiable modes of the distribution, for which we use a similar trick as in Gumbel-softmax sampling:
\[\text{mode}_k\ =\ \text{softmax}_k\left( \frac{\log p_k}{\tau} \right)\]
 Parameters:
dist_params (pytree with ndarray leaves) – A batch of distribution parameters.
 Returns:
X (ndarray) – A batch of variates \(x\sim\text{Cat}(p)\). In order to ensure differentiability of the variates this is not an integer, but instead an almost-one-hot encoded version thereof.
For example, instead of sampling \(x=2\) from a 4-class categorical distribution, Gumbel-softmax will return a vector like \(x=(0.05, 0.02, 0.86, 0.07)\). The latter representation can be viewed as an almost-one-hot encoded version of the former.
 property sample¶
JIT-compiled function that generates differentiable variates using Gumbel-softmax sampling. \(x\sim\text{Cat}(p)\) is implemented as
\[\begin{split}u_k\ &\sim\ \text{Unif}(0, 1) \\ g_k\ &=\ -\log(-\log(u_k)) \\ x_k\ &=\ \text{softmax}_k\left( \frac{g_k + \log p_k}{\tau} \right)\end{split}\]
 Parameters:
dist_params (pytree with ndarray leaves) – A batch of distribution parameters.
rng (PRNGKey) – A key for seeding the pseudorandom number generator.
 Returns:
X (ndarray) – A batch of variates \(x\sim\text{Cat}(p)\). In order to ensure differentiability of the variates this is not an integer, but instead an almost-one-hot encoded version thereof.
For example, instead of sampling \(x=2\) from a 4-class categorical distribution, Gumbel-softmax will return a vector like \(x=[0.05, 0.02, 0.86, 0.07]\). The latter representation can be viewed as an almost-one-hot encoded version of the former.
 property space¶
The gymnasium-style space that specifies the domain of the distribution.
 class coax.proba_dists.EmpiricalQuantileDist(num_quantiles)[source]¶
 postprocess_variate(rng, X, index=0, batch_mode=False)¶
The postprocessor specific to variates drawn from this distribution.
This method provides the interface between differentiable, batched variates, i.e. outputs of sample() and mode(), and the provided gymnasium space.
 Parameters:
rng (PRNGKey) – A key for seeding the pseudorandom number generator.
X (raw variates) – A batch of raw variates, i.e. in the same format as the outputs of sample() and mode().
index (int, optional) – The index to pick out from the batch. Note that this only applies if batch_mode=False.
batch_mode (bool, optional) – Whether to return a batch or a single instance.
 Returns:
x or X (clean variate) – A single clean variate or a batch thereof (if batch_mode=True). A variate is called clean if it is an instance of the gymnasium-style space, i.e. it satisfies x in self.space.
 preprocess_variate(rng, X)¶
The preprocessor to ensure that an instance of the space is processed into the same structure as variates drawn from this distribution, i.e. outputs of sample() and mode().
 property affine_transform¶
Transform the distribution \(\mathcal{D}\to\mathcal{D}'\) in such a way that its associated variables \(X\sim\mathcal{D}\) and \(X'\sim\mathcal{D}'\) are related via an affine transformation:
\[X' = X\times\text{scale} + \text{shift}\]
 Parameters:
dist_params (pytree with ndarray leaves) – The distribution parameters of the original distribution \(\mathcal{D}\).
scale (float or ndarray) – The multiplicative factor of the affine transformation.
shift (float or ndarray) – The additive shift of the affine transformation.
value_transform (ValueTransform, optional) – The transform to apply to the values before the affine transform, i.e.
\[X' = f\bigl(f^{-1}(X)\times\text{scale} + \text{shift}\bigr)\]
 Returns:
dist_params (pytree with ndarray leaves) – The distribution parameters of the transformed distribution \(\mathcal{D}'\).
 property cross_entropy¶
JIT-compiled function that computes the cross-entropy of a distribution \(q\) relative to another distribution \(p\):
\[\text{CE}[p,q]\ =\ -\mathbb{E}_p \log q\]
 Parameters:
dist_params_p (pytree with ndarray leaves) – The distribution parameters of the base distribution \(p\).
dist_params_q (pytree with ndarray leaves) – The distribution parameters of the auxiliary distribution \(q\).
 property default_priors¶
The default distribution parameters.
 property dist_params_structure¶
The tree structure of the distribution parameters.
 property entropy¶
JIT-compiled function that computes the entropy of the distribution.
\[H\ =\ -\mathbb{E}_p \log p\]
 Parameters:
dist_params (pytree with ndarray leaves) – A batch of distribution parameters.
 Returns:
H (ndarray of floats) – A batch of entropy values.
 property hyperparams¶
The distribution hyperparameters.
 property kl_divergence¶
JIT-compiled function that computes the Kullback-Leibler divergence of a distribution \(q\) relative to another distribution \(p\):
\[\text{KL}[p,q]\ =\ -\mathbb{E}_p \left(\log q - \log p\right)\]
 Parameters:
dist_params_p (pytree with ndarray leaves) – The distribution parameters of the base distribution \(p\).
dist_params_q (pytree with ndarray leaves) – The distribution parameters of the auxiliary distribution \(q\).
 property log_proba¶
JIT-compiled function that evaluates log-probabilities.
 Parameters:
dist_params (pytree with ndarray leaves) – A batch of distribution parameters.
X (ndarray) – A batch of variates, e.g. a batch of actions \(a\) collected from experience.
 Returns:
logP (ndarray of floats) – A batch of log-probabilities associated with the provided variates.
 property mean¶
JIT-compiled function that generates differentiable means of the distribution.
 Parameters:
dist_params (pytree with ndarray leaves) – A batch of distribution parameters.
 Returns:
X (ndarray) – A batch of differentiable variates.
 property mode¶
JIT-compiled function that generates differentiable modes of the distribution.
 Parameters:
dist_params (pytree with ndarray leaves) – A batch of distribution parameters.
 Returns:
X (ndarray) – A batch of differentiable variates.
 property sample¶
JIT-compiled function that generates differentiable variates.
 Parameters:
dist_params (pytree with ndarray leaves) – A batch of distribution parameters.
rng (PRNGKey) – A key for seeding the pseudorandom number generator.
 Returns:
X (ndarray) – A batch of differentiable variates.
 property space¶
The gymnasium-style space that specifies the domain of the distribution.
 class coax.proba_dists.SquashedNormalDist(space, clip_logvar=None)[source]¶
A differentiable squashed normal distribution.
The input dist_params to each of the functions is expected to be of the form:
dist_params = {'mu': array([...]), 'logvar': array([...])}
which represents the (conditional) distribution parameters. Here, mu is the mean \(\mu\) and logvar is the log-variance \(\log(\sigma^2)\).
 Parameters:
space (gymnasium.spaces.Box) – The gymnasium-style space that specifies the domain of the distribution.
clip_logvar (pair of floats, optional) – The range of values to allow for the log-variance of the distribution.
 postprocess_variate(rng, X, index=0, batch_mode=False)[source]¶
The postprocessor specific to variates drawn from this distribution.
This method provides the interface between differentiable, batched variates, i.e. outputs of sample() and mode(), and the provided gymnasium space.
 Parameters:
rng (PRNGKey) – A key for seeding the pseudorandom number generator.
X (raw variates) – A batch of raw variates, i.e. in the same format as the outputs of sample() and mode().
index (int, optional) – The index to pick out from the batch. Note that this only applies if batch_mode=False.
batch_mode (bool, optional) – Whether to return a batch or a single instance.
 Returns:
x or X (clean variate) – A single clean variate or a batch thereof (if batch_mode=True). A variate is called clean if it is an instance of the gymnasium-style space, i.e. it satisfies x in self.space.
 preprocess_variate(rng, X)[source]¶
The preprocessor to ensure that an instance of the space is processed into the same structure as variates drawn from this distribution, i.e. outputs of sample() and mode().
 property affine_transform¶
Transform the distribution \(\mathcal{D}\to\mathcal{D}'\) in such a way that its associated variables \(X\sim\mathcal{D}\) and \(X'\sim\mathcal{D}'\) are related via an affine transformation:
\[X' = X\times\text{scale} + \text{shift}\]
 Parameters:
dist_params (pytree with ndarray leaves) – The distribution parameters of the original distribution \(\mathcal{D}\).
scale (float or ndarray) – The multiplicative factor of the affine transformation.
shift (float or ndarray) – The additive shift of the affine transformation.
value_transform (ValueTransform, optional) – The transform to apply to the values before the affine transform, i.e.
\[X' = f\bigl(f^{-1}(X)\times\text{scale} + \text{shift}\bigr)\]
 Returns:
dist_params (pytree with ndarray leaves) – The distribution parameters of the transformed distribution \(\mathcal{D}'\).
 property cross_entropy¶
JIT-compiled function that computes the cross-entropy of a distribution \(q\) relative to another distribution \(p\):
\[\text{CE}[p,q]\ =\ -\mathbb{E}_p \log q \ =\ \frac12\left( \log(2\pi\sigma_q^2) + \frac{(\mu_p-\mu_q)^2+\sigma_p^2}{\sigma_q^2} \right)\]
 Parameters:
dist_params_p (pytree with ndarray leaves) – The distribution parameters of the base distribution \(p\).
dist_params_q (pytree with ndarray leaves) – The distribution parameters of the auxiliary distribution \(q\).
 property default_priors¶
The default distribution parameters.
 property dist_params_structure¶
The tree structure of the distribution parameters.
 property entropy¶
JIT-compiled function that computes the entropy of the distribution.
\[H\ =\ -\mathbb{E}_p \log p \ =\ \frac12\left( \log(2\pi\sigma^2) + 1\right)\]
 Parameters:
dist_params (pytree with ndarray leaves) – A batch of distribution parameters.
 Returns:
H (ndarray of floats) – A batch of entropy values.
 property hyperparams¶
The distribution hyperparameters.
 property kl_divergence¶
JIT-compiled function that computes the Kullback-Leibler divergence of a distribution \(q\) relative to another distribution \(p\):
\[\text{KL}[p,q]\ =\ -\mathbb{E}_p \left(\log q - \log p\right) \ =\ \frac12\left( \log(\sigma_q^2) - \log(\sigma_p^2) + \frac{(\mu_p-\mu_q)^2+\sigma_p^2}{\sigma_q^2} - 1 \right)\]
 Parameters:
dist_params_p (pytree with ndarray leaves) – The distribution parameters of the base distribution \(p\).
dist_params_q (pytree with ndarray leaves) – The distribution parameters of the auxiliary distribution \(q\).
 property log_proba¶
JIT-compiled function that evaluates log-probabilities.
 Parameters:
dist_params (pytree with ndarray leaves) – A batch of distribution parameters.
X (ndarray) – A batch of variates, e.g. a batch of actions \(a\) collected from experience.
 Returns:
logP (ndarray of floats) – A batch of log-probabilities associated with the provided variates.
 property mean¶
JIT-compiled function that generates differentiable means of the distribution, in this case simply \(\tanh(\mu)\).
 Parameters:
dist_params (pytree with ndarray leaves) – A batch of distribution parameters.
 Returns:
X (ndarray) – A batch of differentiable variates.
 property mode¶
JIT-compiled function that generates differentiable modes of the distribution, which for a normal distribution is the same as the mean.
 Parameters:
dist_params (pytree with ndarray leaves) – A batch of distribution parameters.
 Returns:
X (ndarray) – A batch of differentiable variates.
 property sample¶
JIT-compiled function that generates differentiable variates using the reparametrization trick, i.e. \(x\sim\tanh(\mathcal{N}(\mu,\sigma^2))\) is implemented as
\[\begin{split}\varepsilon\ &\sim\ \mathcal{N}(0,1) \\ x\ &=\ \tanh(\mu + \sigma\,\varepsilon)\end{split}\]
 Parameters:
dist_params (pytree with ndarray leaves) – A batch of distribution parameters.
rng (PRNGKey) – A key for seeding the pseudorandom number generator.
 Returns:
X (ndarray) – A batch of differentiable variates.
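Squashing the reparametrized normal variate through tanh keeps every sample inside the interval [-1, 1]. The scalar plain-Python sketch below uses random.Random in place of a PRNGKey, which is an assumption of this example rather than the library's interface.

```python
import math
import random

def sample_squashed_normal(dist_params, rng):
    """Draw x = tanh(mu + sigma * eps) with eps ~ N(0, 1)."""
    eps = rng.gauss(0.0, 1.0)
    sigma = math.exp(0.5 * dist_params['logvar'])
    return math.tanh(dist_params['mu'] + sigma * eps)

rng = random.Random(7)
params = {'mu': 0.8, 'logvar': 1.0}
xs = [sample_squashed_normal(params, rng) for _ in range(1000)]
```

As the variance goes to zero the samples collapse onto \(\tanh(\mu)\), matching the value reported by the mean property.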
 property space¶
The gymnasium-style space that specifies the domain of the distribution.