Source code for coax.proba_dists._empirical_quantile

import chex
import jax
import jax.numpy as jnp
from gymnasium.spaces import Box

from ..utils import jit, isscalar
from ._base import BaseProbaDist


# Names exported by a star-import of this module.
__all__ = ('EmpiricalQuantileDist',)


class EmpiricalQuantileDist(BaseProbaDist):
    r"""
    Probability distribution represented by a fixed number of empirical quantile values.

    Distribution parameters are passed around as a dict. The functions in this class read the
    entries ``'values'`` and (for :attr:`affine_transform`) ``'quantile_fractions'``; see
    :attr:`default_priors` for an example of the expected layout.

    Parameters
    ----------
    num_quantiles : int

        The number of quantile values per distribution. It is used as the shape of the
        unbounded :class:`Box` space handed to the base class and as the size of the last axis
        of ``dist_params['values']``.

    """

    def __init__(self, num_quantiles):
        self.num_quantiles = num_quantiles
        super().__init__(Box(low=-jnp.inf, high=jnp.inf, shape=[num_quantiles]))

        def check_shape(x, name):
            # NOTE: despite its name, this helper only validates the *type* of x; shapes are
            # checked separately (via chex) where needed.
            if not isinstance(x, jnp.ndarray):
                raise TypeError(f"expected {name} to be a jax.numpy.ndarray, got: {type(x)}")
            return x

        def mean(dist_params):
            # average of the quantile values along the last axis
            values = check_shape(dist_params['values'], 'values')
            return jnp.mean(values, axis=-1)

        def sample(dist_params, rng):
            # bootstrapping: resample with replacement *within* each set of quantile values,
            # i.e. along the last axis. Using jax.random.choice on the full array would select
            # along axis 0 instead, which mixes entries from different batch rows.
            values = check_shape(dist_params['values'], 'values')
            idx = jax.random.randint(rng, values.shape, 0, values.shape[-1])
            return jnp.take_along_axis(values, idx, axis=-1)

        def log_proba(dist_params, X):
            X = check_shape(X, 'X')
            values = check_shape(dist_params['values'], 'values')
            # NOTE(review): X gets a new leading axis and values a new trailing axis before the
            # comparison; for batched 2d inputs these shapes only broadcast when the batch size
            # and the number of quantiles are compatible -- TODO confirm the intended layout.
            occurrences = jnp.mean(X[None, ...] == values[..., None], axis=-1)
            # entries without any match give log(0) = -inf
            return jnp.log(occurrences)

        def affine_transform(dist_params, scale, shift, value_transform=None):
            # values must be rank 2; scale and shift must each be rank 0 or rank 1
            chex.assert_rank([dist_params['values'], scale, shift], [2, {0, 1}, {0, 1}])
            values = check_shape(dist_params['values'], 'values')
            quantile_fractions = check_shape(
                dist_params['quantile_fractions'], 'quantile_fractions')
            batch_size = values.shape[0]

            # expand scalars explicitly, because a single element cannot be reshaped to
            # (batch_size, 1) below
            if isscalar(scale):
                scale = jnp.full(shape=(batch_size, 1), fill_value=jnp.squeeze(scale))
            if isscalar(shift):
                shift = jnp.full(shape=(batch_size, 1), fill_value=jnp.squeeze(shift))

            # column vectors, so that they broadcast against the quantile axis of values
            scale = jnp.reshape(scale, (batch_size, 1))
            shift = jnp.reshape(shift, (batch_size, 1))

            chex.assert_shape(values, (batch_size, self.num_quantiles))
            chex.assert_shape([scale, shift], (batch_size, 1))

            if value_transform is None:
                f = f_inv = lambda x: x
            else:
                f, f_inv = value_transform

            # apply the affine map in the space reached through f_inv, then map back with f;
            # the quantile fractions are passed through unchanged
            return {'values': f(shift + scale * f_inv(values)),
                    'quantile_fractions': quantile_fractions}

        self._sample_func = jit(sample)
        self._mean_func = jit(mean)
        self._log_proba_func = jit(log_proba)
        # value_transform (positional index 3) is a pair of callables, hence a static argument
        self._affine_transform_func = jit(affine_transform, static_argnums=(3,))

    @property
    def default_priors(self):
        r""" Example distribution parameters with batch size 1: all-zero ``'values'`` and
        all-one ``'quantile_fractions'``, each of shape ``(1, num_quantiles)``. """
        return {'values': jnp.zeros((1, self.num_quantiles)),
                'quantile_fractions': jnp.ones((1, self.num_quantiles,))}

    @property
    def sample(self):
        r"""

        JIT-compiled function that draws bootstrap samples from the quantile values.

        Parameters
        ----------
        dist_params : dict

            Distribution parameters; only ``dist_params['values']`` is read.

        rng : PRNGKey

            A key for seeding the pseudo-random number generator.

        Returns
        -------
        X : ndarray

            An array of the same shape as ``dist_params['values']`` in which every entry is
            drawn, with replacement, from the values that share its position along all but
            the last axis.

        """
        return self._sample_func

    @property
    def mean(self):
        r"""

        JIT-compiled function that averages ``dist_params['values']`` over the last axis.

        Parameters
        ----------
        dist_params : dict

            Distribution parameters; only ``dist_params['values']`` is read.

        Returns
        -------
        X : ndarray

            The mean of the quantile values, with the last axis reduced.

        """
        return self._mean_func

    @property
    def log_proba(self):
        r"""

        JIT-compiled function that returns the logarithm of the fraction of matches between
        ``X`` and ``dist_params['values']``.

        Parameters
        ----------
        dist_params : dict

            Distribution parameters; only ``dist_params['values']`` is read.

        X : ndarray

            The variates to compare against the quantile values.

        Returns
        -------
        logP : ndarray

            The log of the match frequencies; ``-inf`` where there is no match.

        """
        return self._log_proba_func

    @property
    def affine_transform(self):
        r"""

        JIT-compiled function that applies an affine transformation to the quantile values.

        Parameters
        ----------
        dist_params : dict

            Distribution parameters with entries ``'values'`` of shape
            ``(batch_size, num_quantiles)`` and ``'quantile_fractions'``.

        scale : scalar or ndarray of rank 0 or 1

            Multiplicative factor; it is brought into shape ``(batch_size, 1)``.

        shift : scalar or ndarray of rank 0 or 1

            Additive term; it is brought into shape ``(batch_size, 1)``.

        value_transform : pair of callables, optional

            A pair ``(f, f_inv)``. When given, the result is
            ``f(shift + scale * f_inv(values))``; otherwise both are the identity. This
            argument is static under JIT compilation.

        Returns
        -------
        dist_params : dict

            A dict with the transformed ``'values'`` and the unchanged
            ``'quantile_fractions'``.

        """
        return self._affine_transform_func