# Source code for coax._core.stochastic_v

import numbers

from gymnasium.spaces import Box

from ..utils import default_preprocessor
from ..proba_dists import DiscretizedIntervalDist
from ..value_transforms import ValueTransform
from .base_stochastic_func_type2 import BaseStochasticFuncType2


__all__ = (
    'StochasticV',
)


class StochasticV(BaseStochasticFuncType2):
    r"""

    A state-value function :math:`v(s)`, represented by a stochastic function
    :math:`\mathbb{P}_\theta(G_t|S_t=s)`.

    Parameters
    ----------
    func : function

        A Haiku-style function that specifies the forward pass.

    env : gymnasium.Env

        The gymnasium-style environment. This is used to validate the input/output structure of
        ``func``.

    value_range : tuple of floats

        A pair of real numbers :code:`(min_value, max_value)` with
        :code:`min_value < max_value`, expressed in the *untransformed* value space. If a
        ``value_transform`` is given, its forward function is applied to both bounds before the
        interval is discretized. A :class:`TypeError` is raised if the pair is not valid.

    num_bins : int, optional

        The interval given by ``value_range`` is discretized in :code:`num_bins` equal sized bins.
        We use the default setting of 51 as suggested in the `Distributional RL
        <https://arxiv.org/abs/1707.06887>`_ paper.

    observation_preprocessor : function, optional

        Turns a single observation into a batch of observations in a form that is convenient for
        feeding into :code:`func`. If left unspecified, this defaults to
        :func:`default_preprocessor(env.observation_space) <coax.utils.default_preprocessor>`.

    value_transform : ValueTransform or pair of funcs, optional

        If provided, the target for the underlying function approximator is transformed:

        .. math::

            \tilde{G}_t\ =\ f(G_t)

        This means that calling the function involves undoing this transformation using its inverse
        :math:`f^{-1}`. The functions :math:`f` and :math:`f^{-1}` are given by
        ``value_transform.transform_func`` and ``value_transform.inverse_func``, respectively. Note
        that a ValueTransform is just a glorified pair of functions, i.e. passing
        ``value_transform=(func, inverse_func)`` works just as well. If left unspecified, an
        identity transform is used.

    random_seed : int, optional

        Seed for pseudo-random number generators.

    """

    def __init__(
            self, func, env, value_range, num_bins=51, observation_preprocessor=None,
            value_transform=None, random_seed=None):

        # These attributes are assigned before the base-class initializer runs, exactly as in the
        # original implementation.
        self.value_transform = value_transform
        self.value_range = self._check_value_range(value_range)

        # The raw ``value_transform`` argument (possibly None or a plain pair) is what determines
        # whether the bounds get transformed.
        proba_dist = self._get_proba_dist(self.value_range, value_transform, num_bins)

        # set defaults
        if observation_preprocessor is None:
            observation_preprocessor = default_preprocessor(env.observation_space)
        if self.value_transform is None:
            self.value_transform = ValueTransform(lambda x: x, lambda x: x)
        if not isinstance(self.value_transform, ValueTransform):
            # a plain (transform_func, inverse_func) pair was passed
            self.value_transform = ValueTransform(*value_transform)

        super().__init__(
            func=func,
            observation_space=env.observation_space,
            action_space=env.action_space,
            observation_preprocessor=observation_preprocessor,
            proba_dist=proba_dist,
            random_seed=random_seed)

    @property
    def num_bins(self):
        r""" The number of bins, read from ``self.proba_dist.space.n``. """
        return self.proba_dist.space.n

    @classmethod
    def example_data(
            cls, env, value_range, num_bins=51, observation_preprocessor=None,
            value_transform=None, batch_size=1, random_seed=None):
        r"""

        Build the probability distribution in the same way as the constructor does and delegate to
        the base class' ``example_data``.

        Parameters
        ----------
        env : gymnasium.Env

            The gymnasium-style environment.

        value_range : tuple of floats

            A pair :code:`(min_value, max_value)`; validated in the same way as in the
            constructor.

        num_bins : int, optional

            The number of bins used to discretize the value range.

        observation_preprocessor : function, optional

            If left unspecified, this defaults to
            :code:`default_preprocessor(env.observation_space)`.

        value_transform : ValueTransform or pair of funcs, optional

            If provided, its forward function is applied to the bounds of ``value_range``.

        batch_size : int, optional

            Passed on to the base class' ``example_data``.

        random_seed : int, optional

            Passed on to the base class' ``example_data``.

        Returns
        -------
        example_data

            Whatever the base class' ``example_data`` returns for these settings.

        """
        value_range = cls._check_value_range(value_range)
        proba_dist = cls._get_proba_dist(value_range, value_transform, num_bins)

        if observation_preprocessor is None:
            observation_preprocessor = default_preprocessor(env.observation_space)

        return super().example_data(
            env=env,
            observation_preprocessor=observation_preprocessor,
            proba_dist=proba_dist,
            batch_size=batch_size,
            random_seed=random_seed)

    def __call__(self, s, return_logp=False):
        r"""

        Sample a value.

        Parameters
        ----------
        s : state observation

            A single state observation :math:`s`.

        return_logp : bool, optional

            Whether to return the log-propensity associated with the sampled output value.

        Returns
        -------
        value : float

            A single value associated with the state observation :math:`s`.

        logp : non-positive float, optional

            The log-propensity associated with the sampled output value. This is only returned if
            we set ``return_logp=True``.

        """
        return super().__call__(s, return_logp=return_logp)

    def mean(self, s):
        r"""

        Get the mean value.

        Parameters
        ----------
        s : state observation

            A single state observation :math:`s`.

        Returns
        -------
        value : float

            A single value associated with the state observation :math:`s`.

        """
        return super().mean(s)

    def mode(self, s):
        r"""

        Get the most probable value.

        Parameters
        ----------
        s : state observation

            A single state observation :math:`s`.

        Returns
        -------
        value : float

            A single value associated with the state observation :math:`s`.

        """
        return super().mode(s)

    def dist_params(self, s):
        r"""

        Get the parameters of the underlying (conditional) probability distribution.

        Parameters
        ----------
        s : state observation

            A single state observation :math:`s`.

        Returns
        -------
        dist_params : dict

            The distribution parameters associated with the state observation :math:`s`.

        """
        return super().dist_params(s)

    @staticmethod
    def _get_proba_dist(value_range, value_transform, num_bins):
        r"""

        Construct the discretized distribution over the (transformed) value interval.

        If ``value_transform`` is not None, it is unpacked as a pair whose first element is the
        forward function; that function is applied to both bounds of ``value_range``.

        """
        low, high = value_range
        if value_transform is not None:
            f, _ = value_transform
            # Coerce to plain floats so that Box always receives Python scalars, even when the
            # transform returns an array-like scalar type.
            low, high = float(f(low)), float(f(high))
        reward_space = Box(low, high, shape=())
        return DiscretizedIntervalDist(reward_space, num_bins)

    @staticmethod
    def _check_value_range(value_range):
        r"""

        Validate ``value_range`` and return it as a pair of floats.

        The input must be a tuple or list of exactly two real numbers (instances of
        :class:`numbers.Real`, which includes ``int`` and ``float``) with ``low < high``.

        Raises
        ------
        TypeError

            If ``value_range`` does not satisfy the conditions above. Note that this is also the
            exception raised when the bounds are not strictly increasing.

        """
        is_valid = (
            isinstance(value_range, (tuple, list))
            and len(value_range) == 2
            and isinstance(value_range[0], numbers.Real)
            and isinstance(value_range[1], numbers.Real)
            and value_range[0] < value_range[1])
        if not is_valid:
            raise TypeError("value_range is not a valid pair tuple of floats: (low, high)")
        return float(value_range[0]), float(value_range[1])