Source code for coax._core.policy

from ..utils import default_preprocessor
from ..proba_dists import ProbaDist
from .base_stochastic_func_type2 import BaseStochasticFuncType2


[docs]class Policy(BaseStochasticFuncType2): r""" A parametrized policy :math:`\pi_\theta(a|s)`. Parameters ---------- func : function A Haiku-style function that specifies the forward pass. env : gymnasium.Env The gymnasium-style environment. This is used to validate the input/output structure of ``func``. observation_preprocessor : function, optional Turns a single observation into a batch of observations in a form that is convenient for feeding into :code:`func`. If left unspecified, this defaults to :func:`default_preprocessor(env.observation_space) <coax.utils.default_preprocessor>`. proba_dist : ProbaDist, optional A probability distribution that is used to interpret the output of :code:`func <coax.Policy.func>`. Check out the :mod:`coax.proba_dists` module for available options. If left unspecified, this defaults to: .. code:: python proba_dist = coax.proba_dists.ProbaDist(action_space) random_seed : int, optional Seed for pseudo-random number generators. """ def __init__(self, func, env, observation_preprocessor=None, proba_dist=None, random_seed=None): # defaults if observation_preprocessor is None: observation_preprocessor = default_preprocessor(env.observation_space) if proba_dist is None: proba_dist = ProbaDist(env.action_space) super().__init__( func=func, observation_space=env.observation_space, action_space=env.action_space, observation_preprocessor=observation_preprocessor, proba_dist=proba_dist, random_seed=random_seed)
[docs] def __call__(self, s, return_logp=False): r""" Sample an action :math:`a\sim\pi_\theta(.|s)`. Parameters ---------- s : state observation A single state observation :math:`s`. return_logp : bool, optional Whether to return the log-propensity :math:`\log\pi(a|s)`. Returns ------- a : action A single action :math:`a`. logp : float, optional The log-propensity :math:`\log\pi_\theta(a|s)`. This is only returned if we set ``return_logp=True``. """ return super().__call__(s, return_logp=return_logp)
[docs] def mean(self, s): r""" Get the mean of the distribution :math:`\pi_\theta(.|s)`. Note that if the actions are discrete, this returns the :attr:`mode` instead. Parameters ---------- s : state observation A single state observation :math:`s`. Returns ------- a : action A single action :math:`a`. """ return super().mean(s)
[docs] def mode(self, s): r""" Sample a greedy action :math:`a=\arg\max_a\pi_\theta(a|s)`. Parameters ---------- s : state observation A single state observation :math:`s`. Returns ------- a : action A single action :math:`a`. """ return super().mode(s)
[docs] def dist_params(self, s): r""" Get the conditional distribution parameters of :math:`\pi_\theta(.|s)`. Parameters ---------- s : state observation A single state observation :math:`s`. Returns ------- dist_params : Params The distribution parameters of :math:`\pi_\theta(.|s)`. """ return super().dist_params(s)
[docs] @classmethod def example_data( cls, env, observation_preprocessor=None, proba_dist=None, batch_size=1, random_seed=None): # defaults if observation_preprocessor is None: observation_preprocessor = default_preprocessor(env.observation_space) if proba_dist is None: proba_dist = ProbaDist(env.action_space) return super().example_data( env=env, observation_preprocessor=observation_preprocessor, proba_dist=proba_dist, batch_size=batch_size, random_seed=random_seed)