Source code for coax._core.transition_model

from inspect import signature
from collections import namedtuple

import jax
import jax.numpy as jnp
import numpy as onp
import haiku as hk
from gymnasium.spaces import Space, Discrete

from ..utils import safe_sample, batch_to_single, default_preprocessor
from ..proba_dists import ProbaDist
from .base_func import BaseFunc, ExampleData, Inputs, ArgsType1, ArgsType2, ModelTypes


__all__ = (
    'TransitionModel',
)


class TransitionModel(BaseFunc):
    r"""

    A deterministic transition function :math:`s'_\theta(s,a)`.

    Parameters
    ----------
    func : function

        A Haiku-style function that specifies the forward pass. The function signature must be
        the same as in the usage sketch at the bottom of this module.

    env : gymnasium.Env

        The gymnasium-style environment. This is used to validate the input/output structure of
        ``func``.

    observation_preprocessor : function, optional

        Turns a single observation into a batch of observations in a form that is convenient for
        feeding into :code:`func`. If left unspecified, this defaults to
        :attr:`proba_dist.preprocess_variate <coax.proba_dists.ProbaDist.preprocess_variate>`.
        The reason why the default is not :func:`coax.utils.default_preprocessor` is that we
        prefer consistency with :class:`coax.StochasticTransitionModel`.

    observation_postprocessor : function, optional

        Takes a batch of generated observations and makes sure that they are compatible with the
        original :code:`observation_space`. If left unspecified, this defaults to
        :attr:`proba_dist.postprocess_variate <coax.proba_dists.ProbaDist.postprocess_variate>`.

    action_preprocessor : function, optional

        Turns a single action into a batch of actions in a form that is convenient for feeding
        into :code:`func`. If left unspecified, this defaults to
        :func:`default_preprocessor(env.action_space) <coax.utils.default_preprocessor>`.

    random_seed : int, optional

        Seed for pseudo-random number generators.

    """
    def __init__(
            self, func, env, observation_preprocessor=None, observation_postprocessor=None,
            action_preprocessor=None, random_seed=None):

        self.observation_preprocessor = observation_preprocessor
        self.observation_postprocessor = observation_postprocessor
        self.action_preprocessor = action_preprocessor

        # defaults
        if self.observation_preprocessor is None:
            self.observation_preprocessor = ProbaDist(env.observation_space).preprocess_variate
        if self.observation_postprocessor is None:
            self.observation_postprocessor = ProbaDist(env.observation_space).postprocess_variate
        if self.action_preprocessor is None:
            self.action_preprocessor = default_preprocessor(env.action_space)

        super().__init__(
            func,
            observation_space=env.observation_space,
            action_space=env.action_space,
            random_seed=random_seed)
    def __call__(self, s, a=None):
        r"""

        Evaluate the transition model on a state observation :math:`s` or on a state-action pair
        :math:`(s, a)`.

        Parameters
        ----------
        s : state observation

            A single state observation :math:`s`.

        a : action, optional

            A single action :math:`a`.

        Returns
        -------
        s_next : state observation or list thereof

            Depending on whether :code:`a` is provided, this either returns a single next-state
            prediction :math:`s'(s,a)` or a list of :math:`n` next-state predictions
            :math:`s'(s,.)`, one for each discrete action. Naturally, the latter only applies to
            discrete action spaces. See the call sketch at the bottom of this module for an
            example.

        """
        S = self.observation_preprocessor(self.rng, s)
        if a is None:
            S_next, _ = self.function_type2(self.params, self.function_state, self.rng, S, False)
            S_next = batch_to_single(S_next)  # (batch, num_actions, *) -> (num_actions, *)
            n = self.action_space.n
            s_next = [self.observation_postprocessor(self.rng, S_next, index=i) for i in range(n)]
        else:
            A = self.action_preprocessor(self.rng, a)
            S_next, _ = self.function_type1(
                self.params, self.function_state, self.rng, S, A, False)
            s_next = self.observation_postprocessor(self.rng, S_next)
        return s_next
    @property
    def function_type1(self):
        r"""

        Same as :attr:`function`, except that it ensures a type-1 function signature, regardless
        of the underlying :attr:`modeltype`.

        """
        if self.modeltype == 1:
            return self.function

        assert isinstance(self.action_space, Discrete)

        def project(A):
            assert A.ndim == 2, f"bad shape: {A.shape}"
            assert A.shape[1] == self.action_space.n, f"bad shape: {A.shape}"

            def func(leaf):  # noqa: E306
                assert isinstance(leaf, jnp.ndarray), f"leaf must be ndarray, got: {type(leaf)}"
                assert leaf.ndim >= 2, f"bad shape: {leaf.shape}"
                assert leaf.shape[0] == A.shape[0], \
                    f"batch_size (axis=0) mismatch: leaf.shape: {leaf.shape}, A.shape: {A.shape}"
                assert leaf.shape[1] == A.shape[1], \
                    f"num_actions (axis=1) mismatch: leaf.shape: {leaf.shape}, A.shape: {A.shape}"
                return jax.vmap(jnp.dot)(jnp.moveaxis(leaf, 1, -1), A)

            return func

        def type1_func(type2_params, type2_state, rng, S, A, is_training):
            S_next, state_new = self.function(type2_params, type2_state, rng, S, is_training)
            S_next = jax.tree_map(project(A), S_next)
            return S_next, state_new

        return type1_func

    @property
    def function_type2(self):
        r"""

        Same as :attr:`function`, except that it ensures a type-2 function signature, regardless
        of the underlying :attr:`modeltype`. The repeat/tile trick used here is illustrated
        standalone at the bottom of this module.

        """
        if self.modeltype == 2:
            return self.function

        if not isinstance(self.action_space, Discrete):
            raise ValueError(
                "input 'A' is required for a type-1 dynamics model "
                "when the action space is non-Discrete")

        n = self.action_space.n

        def reshape(leaf):
            # reshape from (batch * num_actions, *shape) -> (batch, num_actions, *shape)
            assert isinstance(leaf, jnp.ndarray), f"all leaves must be ndarray, got: {type(leaf)}"
            assert leaf.ndim >= 1, f"bad shape: {leaf.shape}"
            assert leaf.shape[0] % n == 0, \
                f"first axis size must be a multiple of num_actions, got shape: {leaf.shape}"
            leaf = jnp.reshape(leaf, (-1, n, *leaf.shape[1:]))  # (batch, num_actions, *shape)
            return leaf

        def type2_func(type1_params, type1_state, rng, S, is_training):
            rngs = hk.PRNGSequence(rng)
            batch_size = jax.tree_util.tree_leaves(S)[0].shape[0]

            # example: let S = [7, 2, 5, 8] and num_actions = 3, then
            #   S_rep = [7, 7, 7, 2, 2, 2, 5, 5, 5, 8, 8, 8]  # repeated
            #   A_rep = [0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2]  # tiled
            S_rep = jax.tree_map(lambda x: jnp.repeat(x, n, axis=0), S)
            A_rep = jnp.tile(jnp.arange(n), batch_size)
            A_rep = self.action_preprocessor(next(rngs), A_rep)  # one-hot encoding

            # evaluate on replicas => output shape: (batch * num_actions, *shape)
            S_next_rep, state_new = self.function(
                type1_params, type1_state, next(rngs), S_rep, A_rep, is_training)
            S_next = jax.tree_map(reshape, S_next_rep)

            return S_next, state_new

        return type2_func

    @property
    def modeltype(self):
        r"""

        Specifier for how the transition function is modeled, i.e.

        .. math::

            (s,a)   &\mapsto s'(s,a)    &\qquad (\text{modeltype} &= 1) \\
            s       &\mapsto s'(s,.)    &\qquad (\text{modeltype} &= 2)

        Note that modeltype=2 is only well-defined if the action space is
        :class:`Discrete <gymnasium.spaces.Discrete>`, in which case the model maps a state to
        one next-state prediction for each of the :math:`n` discrete actions.

        """
        return self._modeltype
    @classmethod
    def example_data(
            cls, env, observation_preprocessor=None, action_preprocessor=None,
            batch_size=1, random_seed=None):

        if not isinstance(env.observation_space, Space):
            raise TypeError(
                "env.observation_space must be derived from gymnasium.Space, "
                f"got: {type(env.observation_space)}")
        if not isinstance(env.action_space, Space):
            raise TypeError(
                "env.action_space must be derived from gymnasium.Space, "
                f"got: {type(env.action_space)}")

        if observation_preprocessor is None:
            observation_preprocessor = ProbaDist(env.observation_space).preprocess_variate
        if action_preprocessor is None:
            action_preprocessor = default_preprocessor(env.action_space)

        rnd = onp.random.RandomState(random_seed)
        rngs = hk.PRNGSequence(rnd.randint(jnp.iinfo('int32').max))

        # input: state observations
        S = [safe_sample(env.observation_space, rnd) for _ in range(batch_size)]
        S = [observation_preprocessor(next(rngs), s) for s in S]
        S = jax.tree_map(lambda *x: jnp.concatenate(x, axis=0), *S)

        # input: actions
        A = [safe_sample(env.action_space, rnd) for _ in range(batch_size)]
        A = [action_preprocessor(next(rngs), a) for a in A]
        A = jax.tree_map(lambda *x: jnp.concatenate(x, axis=0), *A)

        # output: type-1
        S_next_type1 = jax.tree_map(lambda x: jnp.asarray(rnd.randn(batch_size, *x.shape[1:])), S)
        q1_data = ExampleData(
            inputs=Inputs(args=ArgsType1(S=S, A=A, is_training=True), static_argnums=(2,)),
            output=S_next_type1)

        if not isinstance(env.action_space, Discrete):
            return ModelTypes(type1=q1_data, type2=None)

        # output: type-2 (only if actions are discrete)
        S_next_type2 = jax.tree_map(
            lambda x: jnp.asarray(rnd.randn(batch_size, env.action_space.n, *x.shape[1:])), S)
        q2_data = ExampleData(
            inputs=Inputs(args=ArgsType2(S=S, is_training=True), static_argnums=(1,)),
            output=S_next_type2)

        return ModelTypes(type1=q1_data, type2=q2_data)
    def _check_signature(self, func):
        sig_type1 = ('S', 'A', 'is_training')
        sig_type2 = ('S', 'is_training')
        sig = tuple(signature(func).parameters)

        if sig not in (sig_type1, sig_type2):
            sig = ', '.join(sig)
            alt = ' or func(S, is_training)' if isinstance(self.action_space, Discrete) else ''
            raise TypeError(
                f"func has bad signature; expected: func(S, A, is_training){alt}, "
                f"got: func({sig})")
        if sig == sig_type2 and not isinstance(self.action_space, Discrete):
            raise TypeError("type-2 models are only well-defined for Discrete action spaces")

        Env = namedtuple('Env', ('observation_space', 'action_space'))
        example_data_per_modeltype = self.example_data(
            env=Env(self.observation_space, self.action_space),
            action_preprocessor=self.action_preprocessor,
            batch_size=1,
            random_seed=self.random_seed)

        if sig == sig_type1:
            self._modeltype = 1
            example_data = example_data_per_modeltype.type1
        else:
            self._modeltype = 2
            example_data = example_data_per_modeltype.type2

        return example_data

    def _check_output(self, actual, expected):
        expected_leaves, expected_structure = jax.tree_util.tree_flatten(expected)
        actual_leaves, actual_structure = jax.tree_util.tree_flatten(actual)
        assert all(isinstance(x, jnp.ndarray) for x in expected_leaves), "bad example_data"

        if actual_structure != expected_structure:
            raise TypeError(
                f"func has bad return tree_structure, expected: {expected_structure}, "
                f"got: {actual_structure}")

        if not all(isinstance(x, jnp.ndarray) for x in actual_leaves):
            bad_types = tuple(type(x) for x in actual_leaves if not isinstance(x, jnp.ndarray))
            raise TypeError(
                "all output leaves must be of type: jax.numpy.ndarray, "
                f"found leaves of type: {bad_types}")

        if not all(a.shape == b.shape for a, b in zip(actual_leaves, expected_leaves)):
            shapes_tree = jax.tree_map(
                lambda a, b: f"{a.shape} {'!=' if a.shape != b.shape else '=='} {b.shape}",
                actual, expected)
            raise TypeError(f"found leaves with unexpected shapes: {shapes_tree}")
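

# -----------------------------------------------------------------------------
# Usage sketches. Everything below is illustrative only and not part of the
# module above. The snippets assume gymnasium's 'FrozenLake-v1' as a small
# environment with Discrete observations and actions; all `_example_*` names
# are hypothetical helpers introduced here for demonstration.
# -----------------------------------------------------------------------------

def _example_transition_model():
    # A minimal sketch of wiring up a TransitionModel. The Haiku-style forward
    # pass below uses the type-1 signature func(S, A, is_training): it maps a
    # batch of (preprocessed) state-action pairs to a batch of (preprocessed)
    # next-state observations. For FrozenLake, both S and A arrive one-hot
    # encoded, so their per-row outer product (jnp.kron) enumerates all
    # state-action combinations.
    import coax
    import gymnasium

    env = gymnasium.make('FrozenLake-v1', is_slippery=False)

    def func(S, A, is_training):
        seq = hk.Sequential((
            hk.Linear(20), jax.nn.relu,
            hk.Linear(env.observation_space.n),  # scores over next states
        ))
        X = jax.vmap(jnp.kron)(S, A)  # one-hot outer product per batch row
        return seq(X)

    p = coax.TransitionModel(func, env, random_seed=13)
    return p, env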
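
def _example_call():
    # Calling the model, as described in __call__ above: with a state-action
    # pair you get a single next-state prediction; with only a state (and a
    # Discrete action space) you get one prediction per action.
    p, env = _example_transition_model()
    s, info = env.reset()
    a = env.action_space.sample()
    s_next = p(s, a)            # single predicted next state
    s_next_per_action = p(s)    # list of env.action_space.n predictions
    return s_next, s_next_per_action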
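
def _example_repeat_tile_trick():
    # A standalone, plain-JAX illustration of how `function_type2` evaluates a
    # type-1 model for all actions at once: states are repeated, action ids
    # are tiled, and the flat output is folded back to (batch, num_actions).
    # The reverse direction (`function_type1` on a type-2 model) then selects
    # one action per row via a dot product with a one-hot action encoding.
    S = jnp.array([7, 2, 5, 8])                    # batch of 4 integer states
    n = 3                                          # num_actions
    S_rep = jnp.repeat(S, n, axis=0)               # [7 7 7 2 2 2 5 5 5 8 8 8]
    A_rep = jnp.tile(jnp.arange(n), S.shape[0])    # [0 1 2 0 1 2 0 1 2 0 1 2]

    # stand-in for a model output on the replicas, folded to shape (batch, n)
    out = (S_rep * 10 + A_rep).reshape(-1, n)      # e.g. row 0 is [70 71 72]

    # one-hot projection: picks out[i, a_i] for each row i
    A_onehot = jax.nn.one_hot(jnp.array([0, 2, 1, 0]), n)
    picked = jax.vmap(jnp.dot)(out.astype(A_onehot.dtype), A_onehot)  # [70. 22. 51. 80.]
    return S_rep, A_rep, out, picked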
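
def _example_example_data():
    # Peeking at the example inputs/outputs that example_data generates and
    # that _check_signature uses to validate a forward pass. The shapes shown
    # are what one would expect for FrozenLake-v1 (16 one-hot observations,
    # 4 actions); treat them as indicative rather than guaranteed.
    import coax
    import gymnasium

    env = gymnasium.make('FrozenLake-v1', is_slippery=False)
    data = coax.TransitionModel.example_data(env, batch_size=2)
    print(data.type1.inputs.args.S.shape)   # e.g. (2, 16): one-hot observations
    print(data.type1.inputs.args.A.shape)   # e.g. (2, 4):  one-hot actions
    print(data.type1.output.shape)          # e.g. (2, 16)
    print(data.type2.output.shape)          # e.g. (2, 4, 16): one per action
    return data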