Source code for coax.proba_dists._composite

from enum import Enum

import gymnasium
import numpy as onp
import haiku as hk
import jax

from ..utils import jit
from ._base import BaseProbaDist
from ._categorical import CategoricalDist
from ._normal import NormalDist


# Public API of this module: only the composite distribution is exported.
__all__ = ('ProbaDist',)


class StructureType(Enum):
    """ Tag describing how a :class:`ProbaDist` stores its sub-distributions. """
    # a single, non-composite distribution object
    LEAF = 0
    # a list of nested distributions, addressed by integer index
    LIST = 1
    # a dict of nested distributions, addressed by key
    DICT = 2


class ProbaDist(BaseProbaDist):
    r"""

    A composite probability distribution. This consists of a nested structure, whose leaves are
    either :class:`coax.proba_dists.CategoricalDist` or :class:`coax.proba_dists.NormalDist`
    instances.

    The structure is derived from the space as follows:

    * ``Discrete`` becomes a single :class:`coax.proba_dists.CategoricalDist` leaf;
    * ``Box`` becomes a single :class:`coax.proba_dists.NormalDist` leaf;
    * ``MultiDiscrete``, ``MultiBinary`` and ``Tuple`` become a list of nested
      :class:`ProbaDist` instances;
    * ``Dict`` becomes a dict of nested :class:`ProbaDist` instances.

    Parameters
    ----------
    space : gymnasium.Space

        The gymnasium-style space that specifies the domain of the distribution. This may be any
        space included in the :mod:`gymnasium.spaces` module.

    Raises
    ------
    TypeError

        If ``space`` is not one of the space types listed above.

    """
    __slots__ = BaseProbaDist.__slots__ + ('_structure', '_structure_type')

    def __init__(self, space):
        super().__init__(space)

        # Build the (possibly nested) structure of sub-distributions from the space.
        if isinstance(self.space, gymnasium.spaces.Discrete):
            self._structure_type = StructureType.LEAF
            self._structure = CategoricalDist(space)
        elif isinstance(self.space, gymnasium.spaces.Box):
            self._structure_type = StructureType.LEAF
            self._structure = NormalDist(space)
        elif isinstance(self.space, gymnasium.spaces.MultiDiscrete):
            # one categorical sub-distribution per entry of nvec
            self._structure_type = StructureType.LIST
            self._structure = [self.__class__(gymnasium.spaces.Discrete(n)) for n in space.nvec]
        elif isinstance(self.space, gymnasium.spaces.MultiBinary):
            # each binary component is modelled as a two-class categorical
            self._structure_type = StructureType.LIST
            self._structure = [self.__class__(gymnasium.spaces.Discrete(2)) for _ in range(space.n)]
        elif isinstance(self.space, gymnasium.spaces.Tuple):
            self._structure_type = StructureType.LIST
            self._structure = [self.__class__(sp) for sp in space.spaces]
        elif isinstance(self.space, gymnasium.spaces.Dict):
            self._structure_type = StructureType.DICT
            self._structure = {k: self.__class__(sp) for k, sp in space.spaces.items()}
        else:
            raise TypeError(f"unsupported space: {space.__class__.__name__}")

        # The functions below all dispatch on the structure type: a leaf delegates to the
        # underlying distribution, a list / dict recurses into each nested distribution using
        # the matching entry of dist_params.

        def sample(dist_params, rng):
            if self._structure_type == StructureType.LEAF:
                return self._structure.sample(dist_params, rng)

            # every nested distribution draws with its own key taken from the sequence
            rngs = hk.PRNGSequence(rng)
            if self._structure_type == StructureType.LIST:
                return [
                    dist.sample(dist_params[i], next(rngs))
                    for i, dist in enumerate(self._structure)]

            if self._structure_type == StructureType.DICT:
                return {
                    k: dist.sample(dist_params[k], next(rngs))
                    for k, dist in self._structure.items()}

            raise AssertionError(f"bad structure_type: {self._structure_type}")

        def mean(dist_params):
            if self._structure_type == StructureType.LEAF:
                return self._structure.mean(dist_params)

            if self._structure_type == StructureType.LIST:
                return [
                    dist.mean(dist_params[i])
                    for i, dist in enumerate(self._structure)]

            if self._structure_type == StructureType.DICT:
                return {
                    k: dist.mean(dist_params[k])
                    for k, dist in self._structure.items()}

            raise AssertionError(f"bad structure_type: {self._structure_type}")

        def mode(dist_params):
            if self._structure_type == StructureType.LEAF:
                return self._structure.mode(dist_params)

            if self._structure_type == StructureType.LIST:
                return [
                    dist.mode(dist_params[i])
                    for i, dist in enumerate(self._structure)]

            if self._structure_type == StructureType.DICT:
                return {
                    k: dist.mode(dist_params[k])
                    for k, dist in self._structure.items()}

            raise AssertionError(f"bad structure_type: {self._structure_type}")

        def log_proba(dist_params, X):
            if self._structure_type == StructureType.LEAF:
                return self._structure.log_proba(dist_params, X)

            # the composite log-probability is the sum over the nested distributions
            if self._structure_type == StructureType.LIST:
                return sum(
                    dist.log_proba(dist_params[i], X[i])
                    for i, dist in enumerate(self._structure))

            if self._structure_type == StructureType.DICT:
                return sum(
                    dist.log_proba(dist_params[k], X[k])
                    for k, dist in self._structure.items())

            # fail loudly instead of silently returning None, like the sibling functions do
            raise AssertionError(f"bad structure_type: {self._structure_type}")

        def entropy(dist_params):
            if self._structure_type == StructureType.LEAF:
                return self._structure.entropy(dist_params)

            if self._structure_type == StructureType.LIST:
                return sum(
                    dist.entropy(dist_params[i])
                    for i, dist in enumerate(self._structure))

            if self._structure_type == StructureType.DICT:
                return sum(
                    dist.entropy(dist_params[k])
                    for k, dist in self._structure.items())

            raise AssertionError(f"bad structure_type: {self._structure_type}")

        def cross_entropy(dist_params_p, dist_params_q):
            if self._structure_type == StructureType.LEAF:
                return self._structure.cross_entropy(dist_params_p, dist_params_q)

            if self._structure_type == StructureType.LIST:
                return sum(
                    dist.cross_entropy(dist_params_p[i], dist_params_q[i])
                    for i, dist in enumerate(self._structure))

            if self._structure_type == StructureType.DICT:
                return sum(
                    dist.cross_entropy(dist_params_p[k], dist_params_q[k])
                    for k, dist in self._structure.items())

            raise AssertionError(f"bad structure_type: {self._structure_type}")

        def kl_divergence(dist_params_p, dist_params_q):
            if self._structure_type == StructureType.LEAF:
                return self._structure.kl_divergence(dist_params_p, dist_params_q)

            if self._structure_type == StructureType.LIST:
                return sum(
                    dist.kl_divergence(dist_params_p[i], dist_params_q[i])
                    for i, dist in enumerate(self._structure))

            if self._structure_type == StructureType.DICT:
                return sum(
                    dist.kl_divergence(dist_params_p[k], dist_params_q[k])
                    for k, dist in self._structure.items())

            raise AssertionError(f"bad structure_type: {self._structure_type}")

        def affine_transform(dist_params, scale, shift, value_transform=None):
            if value_transform is None:
                # Mirror the structure with None entries so it can be indexed like dist_params.
                # jax.tree_util.tree_map is used because the top-level jax.tree_map alias is
                # deprecated and no longer available in recent JAX releases.
                value_transform = jax.tree_util.tree_map(lambda _: None, self._structure)

            if self._structure_type == StructureType.LEAF:
                return self._structure.affine_transform(dist_params, scale, shift, value_transform)

            if self._structure_type == StructureType.LIST:
                assert len(dist_params) == len(scale) == len(shift) == len(self._structure)
                assert len(value_transform) == len(self._structure)
                return [
                    dist.affine_transform(dist_params[i], scale[i], shift[i], value_transform[i])
                    for i, dist in enumerate(self._structure)]

            if self._structure_type == StructureType.DICT:
                assert set(dist_params) == set(scale) == set(shift) == set(self._structure)
                assert set(value_transform) == set(self._structure)
                return {
                    k: dist.affine_transform(dist_params[k], scale[k], shift[k], value_transform[k])
                    for k, dist in self._structure.items()}

            raise AssertionError(f"bad structure_type: {self._structure_type}")

        self._sample_func = jit(sample)
        self._mean_func = jit(mean)
        self._mode_func = jit(mode)
        self._log_proba_func = jit(log_proba)
        self._entropy_func = jit(entropy)
        self._cross_entropy_func = jit(cross_entropy)
        self._kl_divergence_func = jit(kl_divergence)
        self._affine_transform_func = jit(affine_transform)

    @property
    def hyperparams(self):
        # nested hyperparams follow the structure: leaf value, tuple (for lists) or dict
        if self._structure_type == StructureType.LEAF:
            return self._structure.hyperparams

        if self._structure_type == StructureType.LIST:
            return tuple(dist.hyperparams for dist in self._structure)

        if self._structure_type == StructureType.DICT:
            return {k: dist.hyperparams for k, dist in self._structure.items()}

        raise AssertionError(f"bad structure_type: {self._structure_type}")

    @property
    def default_priors(self):
        # nested default priors follow the structure: leaf value, tuple (for lists) or dict
        if self._structure_type == StructureType.LEAF:
            return self._structure.default_priors

        if self._structure_type == StructureType.LIST:
            return tuple(dist.default_priors for dist in self._structure)

        if self._structure_type == StructureType.DICT:
            return {k: dist.default_priors for k, dist in self._structure.items()}

        raise AssertionError(f"bad structure_type: {self._structure_type}")

    def postprocess_variate(self, rng, X, index=0, batch_mode=False):
        rngs = hk.PRNGSequence(rng)

        if self._structure_type == StructureType.LEAF:
            return self._structure.postprocess_variate(
                next(rngs), X, index=index, batch_mode=batch_mode)

        if isinstance(self.space, (gymnasium.spaces.MultiDiscrete, gymnasium.spaces.MultiBinary)):
            assert self._structure_type == StructureType.LIST
            # the per-component results are stacked along the last axis into a single array
            return onp.stack([
                dist.postprocess_variate(next(rngs), X[i], index=index, batch_mode=batch_mode)
                for i, dist in enumerate(self._structure)], axis=-1)

        if isinstance(self.space, gymnasium.spaces.Tuple):
            assert self._structure_type == StructureType.LIST
            return tuple(
                dist.postprocess_variate(next(rngs), X[i], index=index, batch_mode=batch_mode)
                for i, dist in enumerate(self._structure))

        if isinstance(self.space, gymnasium.spaces.Dict):
            assert self._structure_type == StructureType.DICT
            return {
                k: dist.postprocess_variate(next(rngs), X[k], index=index, batch_mode=batch_mode)
                for k, dist in self._structure.items()}

        raise AssertionError(
            f"postprocess_variate not implemented for space: {self.space.__class__.__name__}; "
            "please send us a bug report / feature request")

    def preprocess_variate(self, rng, X):
        rngs = hk.PRNGSequence(rng)

        if self._structure_type == StructureType.LEAF:
            return self._structure.preprocess_variate(next(rngs), X)

        if isinstance(self.space, (gymnasium.spaces.MultiDiscrete, gymnasium.spaces.MultiBinary)):
            assert self._structure_type == StructureType.LIST
            # here the components are sliced off the last axis of X, one per sub-distribution
            return [
                dist.preprocess_variate(next(rngs), X[..., i])
                for i, dist in enumerate(self._structure)]

        if self._structure_type == StructureType.LIST:
            return [
                dist.preprocess_variate(next(rngs), X[i])
                for i, dist in enumerate(self._structure)]

        if self._structure_type == StructureType.DICT:
            return {
                k: dist.preprocess_variate(next(rngs), X[k])
                for k, dist in self._structure.items()}

        raise AssertionError(f"bad structure_type: {self._structure_type}")