# Source code for coax._core.random_policy

import gymnasium
import jax.numpy as jnp
import numpy as onp

from ..utils import docstring
from .policy import Policy


# Names exported by ``from <this module> import *``.
__all__ = ('RandomPolicy',)


class RandomPolicy:
    r"""

    A simple random policy.

    Parameters
    ----------
    env : gymnasium.Env

        The gymnasium-style environment. This is only used to get the :code:`env.action_space`.

    random_seed : int, optional

        Sets the random state to get reproducible results.

    """
    def __init__(self, env, random_seed=None):
        if not isinstance(env.action_space, gymnasium.Space):
            raise TypeError(
                f"env.action_space must be a gymnasium.Space, got: {type(env.action_space)}")

        # NOTE: this keeps a reference to the env's own action-space object (no copy),
        # so seeding it below also seeds env.action_space.
        self.action_space = env.action_space
        self.action_space.seed(random_seed)
        self.random_seed = random_seed

    # The method docstring is presumably supplied by the @docstring decorator
    # from Policy.__call__ (project helper, not visible here).
    @docstring(Policy.__call__)
    def __call__(self, s, return_logp=False):
        # The observation ``s`` is ignored: actions are drawn uniformly from the space.
        a = self.action_space.sample()
        if not return_logp:
            return a

        if isinstance(self.action_space, gymnasium.spaces.Discrete):
            # Uniform over n discrete actions => log(1/n). The original referenced
            # ``self.num_actions``, which is never set and raised AttributeError.
            logp = -onp.log(self.action_space.n)
            return a, logp

        if isinstance(self.action_space, gymnasium.spaces.Box):
            # Uniform density over the box is 1 / prod(high - low). For unbounded
            # dimensions the size is inf, so logp evaluates to -inf.
            sizes = self.action_space.high - self.action_space.low
            logp = -onp.sum(onp.log(sizes))  # log(prod(1/sizes))
            return a, logp

        raise NotImplementedError(
            "the log-propensity of a 'uniform' distribution over a "
            f"{self.action_space.__class__.__name__} space is not yet implemented; "
            "please submit a feature request")

    # Docstring presumably supplied by the @docstring decorator from Policy.mode.
    @docstring(Policy.mode)
    def mode(self, s):
        # A uniform distribution has no unique mode, so this just returns a random sample.
        return self(s, return_logp=False)

    # Docstring presumably supplied by the @docstring decorator from Policy.dist_params.
    @docstring(Policy.dist_params)
    def dist_params(self, s):
        if isinstance(self.action_space, gymnasium.spaces.Discrete):
            # Equal logits for every action give a uniform categorical distribution.
            return {'logits': jnp.zeros(self.action_space.n)}

        if isinstance(self.action_space, gymnasium.spaces.Box):
            # Zero mean with a very large log-variance (e^15). Presumably this is a
            # near-flat Gaussian standing in for "uniform"; confirm against the
            # consumer of these params.
            return {
                'mu': jnp.zeros(self.action_space.shape),
                'logvar': 15 * jnp.ones(self.action_space.shape)}

        raise NotImplementedError(
            "the dist_params of a 'uniform' distribution over a "
            f"{self.action_space.__class__.__name__} space is not yet implemented; "
            "please submit a feature request")