Source code for coax._core.stochastic_reward_function

from .stochastic_q import StochasticQ


__all__ = (
    'StochasticRewardFunction',
)


[docs]class StochasticRewardFunction(StochasticQ): r""" A stochastic reward function :math:`p_\theta(r|s,a)`. Parameters ---------- func : function A Haiku-style function that specifies the forward pass. env : gymnasium.Env The gymnasium-style environment. This is used to validate the input/output structure of ``func``. value_range : tuple of floats, optional A pair of floats :code:`(min_value, max_value)`. If left unspecifed, this defaults to :code:`env.reward_range`. num_bins : int, optional The space of rewards is discretized in :code:`num_bins` equal sized bins. We use the default setting of 51 as suggested in the `Distributional RL <https://arxiv.org/abs/1707.06887>`_ paper. observation_preprocessor : function, optional Turns a single observation into a batch of observations in a form that is convenient for feeding into :code:`func`. If left unspecified, this defaults to :func:`default_preprocessor(env.observation_space) <coax.utils.default_preprocessor>`. action_preprocessor : function, optional Turns a single action into a batch of actions in a form that is convenient for feeding into :code:`func`. If left unspecified, this defaults :func:`default_preprocessor(env.action_space) <coax.utils.default_preprocessor>`. value_transform : ValueTransform or pair of funcs, optional If provided, the target for the underlying function approximator is transformed: .. math:: \tilde{G}_t\ =\ f(G_t) This means that calling the function involves undoing this transformation using its inverse :math:`f^{-1}`. The functions :math:`f` and :math:`f^{-1}` are given by ``value_transform.transform_func`` and ``value_transform.inverse_func``, respectively. Note that a ValueTransform is just a glorified pair of functions, i.e. passing ``value_transform=(func, inverse_func)`` works just as well. random_seed : int, optional Seed for pseudo-random number generators. """ def __init__( self, func, env, value_range=None, num_bins=51, observation_preprocessor=None, action_preprocessor=None, value_transform=None, random_seed=None): super().__init__( func, env, value_range=(value_range or env.reward_range), num_bins=51, observation_preprocessor=None, action_preprocessor=None, value_transform=None, random_seed=None)