Source code for coax._core.v

from inspect import signature
from collections import namedtuple

import jax
import jax.numpy as jnp
import numpy as onp
import haiku as hk
from gymnasium.spaces import Space

from ..utils import safe_sample, default_preprocessor
from ..value_transforms import ValueTransform
from .base_func import BaseFunc, ExampleData, Inputs, ArgsType2


__all__ = (
    'V',
)


[docs]class V(BaseFunc): r""" A state value function :math:`v_\theta(s)`. Parameters ---------- func : function A Haiku-style function that specifies the forward pass. The function signature must be the same as the example below. env : gymnasium.Env The gymnasium-style environment. This is used to validate the input/output structure of ``func``. observation_preprocessor : function, optional Turns a single observation into a batch of observations in a form that is convenient for feeding into :code:`func`. If left unspecified, this defaults to :func:`default_preprocessor(env.observation_space) <coax.utils.default_preprocessor>`. value_transform : ValueTransform or pair of funcs, optional If provided, the target for the underlying function approximator is transformed such that: .. math:: \tilde{v}_\theta(S_t)\ \approx\ f(G_t) This means that calling the function involves undoing this transformation: .. math:: v(s)\ =\ f^{-1}(\tilde{v}_\theta(s)) Here, :math:`f` and :math:`f^{-1}` are given by ``value_transform.transform_func`` and ``value_transform.inverse_func``, respectively. Note that a ValueTransform is just a glorified pair of functions, i.e. passing ``value_transform=(func, inverse_func)`` works just as well. random_seed : int, optional Seed for pseudo-random number generators. """ def __init__( self, func, env, observation_preprocessor=None, value_transform=None, random_seed=None): self.observation_preprocessor = observation_preprocessor self.value_transform = value_transform # defaults if self.observation_preprocessor is None: self.observation_preprocessor = default_preprocessor(env.observation_space) if self.value_transform is None: self.value_transform = ValueTransform(lambda x: x, lambda x: x) if not isinstance(self.value_transform, ValueTransform): self.value_transform = ValueTransform(*value_transform) super().__init__( func=func, observation_space=env.observation_space, action_space=None, random_seed=random_seed)
[docs] def __call__(self, s): r""" Evaluate the value function on a state observation :math:`s`. Parameters ---------- s : state observation A single state observation :math:`s`. Returns ------- v : ndarray, shape: () The estimated expected value associated with the input state observation ``s``. """ S = self.observation_preprocessor(self.rng, s) V, _ = self.function(self.params, self.function_state, self.rng, S, False) V = self.value_transform.inverse_func(V) return onp.asarray(V[0])
[docs] @classmethod def example_data(cls, env, observation_preprocessor=None, batch_size=1, random_seed=None): if not isinstance(env.observation_space, Space): raise TypeError( "env.observation_space must be derived from gymnasium.Space, " f"got: {type(env.observation_space)}") if observation_preprocessor is None: observation_preprocessor = default_preprocessor(env.observation_space) rnd = onp.random.RandomState(random_seed) rngs = hk.PRNGSequence(rnd.randint(jnp.iinfo('int32').max)) # input: state observations S = [safe_sample(env.observation_space, rnd) for _ in range(batch_size)] S = [observation_preprocessor(next(rngs), s) for s in S] S = jax.tree_map(lambda *x: jnp.concatenate(x, axis=0), *S) return ExampleData( inputs=Inputs(args=ArgsType2(S=S, is_training=True), static_argnums=(1,)), output=jnp.asarray(rnd.randn(batch_size)), )
def _check_signature(self, func): if tuple(signature(func).parameters) != ('S', 'is_training'): sig = ', '.join(signature(func).parameters) raise TypeError( f"func has bad signature; expected: func(S, is_training), got: func({sig})") # example inputs Env = namedtuple('Env', ('observation_space',)) return self.example_data( env=Env(self.observation_space), observation_preprocessor=self.observation_preprocessor, batch_size=1, random_seed=self.random_seed, ) def _check_output(self, actual, expected): if not isinstance(actual, jnp.ndarray): raise TypeError( f"func has bad return type; expected jnp.ndarray, got {actual.__class__.__name__}") if not jnp.issubdtype(actual.dtype, jnp.floating): raise TypeError( "func has bad return dtype; expected a subdtype of jnp.floating, " f"got dtype={actual.dtype}") if actual.shape != expected.shape: raise TypeError( f"func has bad return shape, expected: {expected.shape}, got: {actual.shape}")