Dynamics Models

coax.TransitionModel

A deterministic transition function \(s'_\theta(s,a)\).

coax.RewardFunction

A deterministic reward function \(r_\theta(s,a)\).

coax.StochasticTransitionModel

A stochastic transition model \(p_\theta(s'|s,a)\).

coax.StochasticRewardFunction

A stochastic reward function \(p_\theta(r|s,a)\).


Model-based methods make use of models that estimate the dynamics of transitions in a Markov decision process. Coax offers two types of such models: a transition model \(p(s'|s,a)\) and a reward function \(r(s,a)\), where \(s'\) is a successor state and \(r(s,a)\) is an immediate reward. Both are conditioned on taking action \(a\) from state \(s\).

Coax allows you to define your own dynamics models with a function approximator, similar to how we define value functions and policies. A dynamics model may be represented by either a deterministic or a stochastic function approximator. In the stochastic case, the forward-pass function returns distribution parameters \(\varphi\) that depend on the input state-action pair, i.e. \(\varphi_\theta(s,a)\). A common case is where the observation space is a Box, in which case the distribution parameters are those of a Gaussian, \(\varphi_\theta(s,a)=(\mu_\theta(s,a), \Sigma_\theta(s,a))\).
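For instance, a type-1 stochastic forward-pass function for a Box observation space might look like the sketch below. This is only a minimal illustration: it assumes the CartPole setup used in the examples further down, and it assumes that the default proba_dist interprets the output as Gaussian parameters with keys 'mu' and 'logvar' (a diagonal parametrization of \((\mu_\theta, \Sigma_\theta)\)).

import coax
import gymnasium
import haiku as hk
import jax
import jax.numpy as jnp

env = gymnasium.make('CartPole-v0')

def func_stochastic(S, A, is_training):
    """ (s,a) -> dist params of p(s'|s,a) """
    body = hk.Sequential((hk.Linear(8), jax.nn.relu))
    mu = hk.Linear(env.observation_space.shape[0], w_init=jnp.zeros)
    logvar = hk.Linear(env.observation_space.shape[0], w_init=jnp.zeros)
    X = body(jax.vmap(jnp.kron)(S, A))  # combine state and one-hot action
    return {'mu': mu(X), 'logvar': logvar(X)}

p_stochastic = coax.StochasticTransitionModel(func_stochastic, env)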

Transition Models

In this example we see how to construct a deterministic transition model \(s'_\theta(s,a)\). Note that the construction of a stochastic transition model is very similar to the construction of a coax.Policy, see Policies.

Let’s create some example data.

import coax
import gymnasium

env = gymnasium.make('CartPole-v0')
data = coax.TransitionModel.example_data(env)

print(data.type1)
# ExampleData(
#   inputs=Inputs(
#     args=ArgsType1(
#       S=array(shape=(1, 4), dtype=float32)
#       A=array(shape=(1, 2), dtype=float32)
#       is_training=True)
#     static_argnums=(2,))
#   output=array(shape=(1, 4), dtype=float32))

print(data.type2)
# ExampleData(
#   inputs=Inputs(
#     args=ArgsType2(
#       S=array(shape=(1, 4), dtype=float32)
#       is_training=True)
#     static_argnums=(1,))
#   output=array(shape=(1, 2, 4), dtype=float32))

Note that, similar to q-functions, there are two ways of handling a discrete action space:

\[\begin{split}(s,a) &\ \mapsto\ p(s'|s,a) &\qquad &(\text{type 1}) \\ s &\ \mapsto\ p(s'|s,.) &\qquad &(\text{type 2})\end{split}\]

A type-2 model essentially returns a vector of distributions of size \(n\), which is the number of discrete actions. Note that type-2 models are only well-defined for discrete action spaces, whereas type-1 models may be defined for any action space.

Let’s first define our type-1 forward-pass function:

import jax
import jax.numpy as jnp
import haiku as hk
from numpy import prod

def func_type1(S, A, is_training):
    """ (s,a) -> s'(s,a) """
    output_shape = env.observation_space.shape
    dS = hk.Sequential((
        hk.Linear(8), jax.nn.relu,
        hk.Linear(8), jax.nn.relu,
        hk.Linear(8), jax.nn.relu,
        hk.Linear(prod(output_shape), w_init=jnp.zeros),
        hk.Reshape(output_shape),
    ))
    X = jax.vmap(jnp.kron)(S, A)  # combine state and one-hot action
    return S + dS(X)              # predict the next state as a residual update


p = coax.TransitionModel(func_type1, env)

# example usage
s, info = env.reset()
a = env.action_space.sample()

print(s)        # [ 0.008, 0.021, -0.037, 0.032]
print(p(s, a))  # [-0.015, 0.067, -0.035, 0.029]
print(p(s))     # [[-0.012, 0.064, -0.039, 0.041], [ 0.022, 0.048, -0.039, 0.027]]

Alternatively, a type-2 forward-pass function might be:

def func_type2(S, is_training):
    """ s -> s'(s,.) """
    output_shape = (env.action_space.n, *env.observation_space.shape)
    dS = hk.Sequential((
        hk.Linear(8), jax.nn.relu,
        hk.Linear(8), jax.nn.relu,
        hk.Linear(8), jax.nn.relu,
        hk.Linear(prod(output_shape), w_init=jnp.zeros),
        hk.Reshape(output_shape),
    ))
    return S + dS(S)  # (batch, 4) broadcasts against (batch, n, 4)


p = coax.TransitionModel(func_type2, env)

# example usage
s, info = env.reset()
a = env.action_space.sample()

print(s)        # [ 0.004,  0.041,  0.043, -0.015]
print(p(s, a))  # [-0.024,  0.067,  0.042,  0.011]
print(p(s))     # [[-0.014, -0.102,  0.041, -0.052], [0.007, -0.065, 0.044, 0.102]]

If something goes wrong and you’d like to debug the forward-pass function, here’s an example of what the constructor runs under the hood:

rngs = hk.PRNGSequence(42)
transformed = hk.transform_with_state(func_type2)
params, function_state = transformed.init(next(rngs), *data.type2.inputs.args)
output, function_state = transformed.apply(params, function_state, next(rngs), *data.type2.inputs.args)

Reward Functions

The coax.RewardFunction and coax.StochasticRewardFunction are essentially aliases of coax.Q and coax.StochasticQ, respectively. Have a look at the Value Functions page for more details.
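For concreteness, here is a minimal sketch of a type-1 forward pass for a deterministic reward function, continuing with the imports and env from the examples above. It mirrors the q-function examples; the network sizes are arbitrary.

def func_r(S, A, is_training):
    """ (s,a) -> r(s,a) """
    value = hk.Sequential((
        hk.Linear(8), jax.nn.relu,
        hk.Linear(1, w_init=jnp.zeros), jnp.ravel,  # scalar reward per example
    ))
    X = jax.vmap(jnp.kron)(S, A)  # combine state and one-hot action
    return value(X)

r = coax.RewardFunction(func_r, env)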

Object Reference

class coax.TransitionModel(func, env, observation_preprocessor=None, observation_postprocessor=None, action_preprocessor=None, random_seed=None)[source]

A deterministic transition function \(s'_\theta(s,a)\).

Parameters:
  • func (function) – A Haiku-style function that specifies the forward pass. The function signature must be the same as the example below.

  • env (gymnasium.Env) – The gymnasium-style environment. This is used to validate the input/output structure of func.

  • observation_preprocessor (function, optional) – Turns a single observation into a batch of observations in a form that is convenient for feeding into func. If left unspecified, this defaults to proba_dist.preprocess_variate. The reason why the default is not coax.utils.default_preprocessor() is that we prefer consistency with coax.StochasticTransitionModel.

  • observation_postprocessor (function, optional) – Takes a batch of generated observations and makes sure that they are compatible with the original observation_space. If left unspecified, this defaults to proba_dist.postprocess_variate.

  • action_preprocessor (function, optional) – Turns a single action into a batch of actions in a form that is convenient for feeding into func. If left unspecified, this defaults to default_preprocessor(env.action_space).

  • random_seed (int, optional) – Seed for pseudo-random number generators.

__call__(s, a=None)[source]

Evaluate the transition model on a state observation \(s\) or on a state-action pair \((s, a)\).

Parameters:
  • s (state observation) – A single state observation \(s\).

  • a (action) – A single action \(a\).

Returns:

s_next (ndarray) – Depending on whether a is provided, this either returns a single next-state prediction \(s'(s,a)\) or an array of \(n\) next-state predictions \(s'(s,.)\), one for each discrete action. Naturally, the latter only applies for discrete action spaces.

copy(deep=False)

Create a copy of the current instance.

Parameters:

deep (bool, optional) – Whether the copy should be a deep copy.

Returns:

copy – A copy of the current instance; a deep copy if deep=True, otherwise a shallow copy.

classmethod example_data(env, observation_preprocessor=None, action_preprocessor=None, batch_size=1, random_seed=None)[source]

A small utility function that generates example input and output data. These may be useful for writing and debugging your own custom function approximators.

soft_update(other, tau)

Synchronize the current instance with other through exponential smoothing:

\[\theta\ \leftarrow\ \theta + \tau\, (\theta_\text{new} - \theta)\]

Parameters:
  • other – A separate copy of the current object. This object will hold the new parameters \(\theta_\text{new}\).

  • tau (float between 0 and 1, optional) – If we set \(\tau=1\) we do a hard update. If we pick a smaller value, we do a smooth update.
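For example, a target-network style update might look like this (a usage sketch, reusing the transition model p from the examples above; the value of tau is arbitrary):

p_targ = p.copy(deep=True)      # target model, holds the old parameters
p_targ.soft_update(p, tau=0.1)  # move the target 10% towards the new parameters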

property function

The function approximator itself, defined as a JIT-compiled pure function. This function may be called directly as:

output, function_state = obj.function(obj.params, obj.function_state, obj.rng, *inputs)

property function_state

The state of the function approximator, see haiku.transform_with_state().

property function_type1

Same as function, except that it ensures a type-1 function signature, regardless of the underlying modeltype.

property function_type2

Same as function, except that it ensures a type-2 function signature, regardless of the underlying modeltype.

property modeltype

Specifier for how the transition function is modeled, i.e.

\[\begin{split}(s,a) &\mapsto s'(s,a) &\qquad (\text{modeltype} &= 1) \\ s &\mapsto s'(s,.) &\qquad (\text{modeltype} &= 2)\end{split}\]

Note that modeltype=2 is only well-defined if the action space is Discrete; in that case the model outputs one next-state prediction for each of the \(n\) discrete actions.

property params

The parameters (weights) of the function approximator.

class coax.RewardFunction(func, env, observation_preprocessor=None, action_preprocessor=None, value_transform=None, random_seed=None)[source]

A deterministic reward function \(r_\theta(s,a)\).

Parameters:
  • func (function) – A Haiku-style function that specifies the forward pass. The function signature must be the same as the example below.

  • env (gymnasium.Env) – The gymnasium-style environment. This is used to validate the input/output structure of func.

  • observation_preprocessor (function, optional) – Turns a single observation into a batch of observations in a form that is convenient for feeding into func. If left unspecified, this defaults to default_preprocessor(env.observation_space).

  • action_preprocessor (function, optional) – Turns a single action into a batch of actions in a form that is convenient for feeding into func. If left unspecified, this defaults to default_preprocessor(env.action_space).

  • value_transform (ValueTransform or pair of funcs, optional) –

    If provided, the target for the underlying function approximator is transformed such that:

    \[\tilde{q}_\theta(S_t, A_t)\ \approx\ f(G_t)\]

    This means that calling the function involves undoing this transformation:

    \[q(s, a)\ =\ f^{-1}(\tilde{q}_\theta(s, a))\]

    Here, \(f\) and \(f^{-1}\) are given by value_transform.transform_func and value_transform.inverse_func, respectively. Note that a ValueTransform is just a glorified pair of functions, i.e. passing value_transform=(func, inverse_func) works just as well. A minimal usage sketch is given after this parameter list.

  • random_seed (int, optional) – Seed for pseudo-random number generators.
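As promised, here is a minimal value-transform sketch. The log-scaling pair is a hypothetical choice, and func_r is the reward-function forward pass sketched earlier on this page.

import jax.numpy as jnp

transform_func = lambda x: jnp.sign(x) * jnp.log1p(jnp.abs(x))  # f
inverse_func = lambda y: jnp.sign(y) * jnp.expm1(jnp.abs(y))    # f^{-1}

r = coax.RewardFunction(func_r, env, value_transform=(transform_func, inverse_func))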

__call__(s, a=None)

Evaluate the state-action function on a state observation \(s\) or on a state-action pair \((s, a)\).

Parameters:
  • s (state observation) – A single state observation \(s\).

  • a (action) – A single action \(a\).

Returns:

q_sa or q_s (ndarray) – Depending on whether a is provided, this either returns a scalar representing \(q(s,a)\in\mathbb{R}\) or a vector representing \(q(s,.)\in\mathbb{R}^n\), where \(n\) is the number of discrete actions. Naturally, this only applies for discrete action spaces.

copy(deep=False)

Create a copy of the current instance.

Parameters:

deep (bool, optional) – Whether the copy should be a deep copy.

Returns:

copy – A copy of the current instance; a deep copy if deep=True, otherwise a shallow copy.

classmethod example_data(env, observation_preprocessor=None, action_preprocessor=None, batch_size=1, random_seed=None)

A small utility function that generates example input and output data. These may be useful for writing and debugging your own custom function approximators.

soft_update(other, tau)

Synchronize the current instance with other through exponential smoothing:

\[\theta\ \leftarrow\ \theta + \tau\, (\theta_\text{new} - \theta)\]

Parameters:
  • other – A separate copy of the current object. This object will hold the new parameters \(\theta_\text{new}\).

  • tau (float between 0 and 1, optional) – If we set \(\tau=1\) we do a hard update. If we pick a smaller value, we do a smooth update.

property function

The function approximator itself, defined as a JIT-compiled pure function. This function may be called directly as:

output, function_state = obj.function(obj.params, obj.function_state, obj.rng, *inputs)

property function_state

The state of the function approximator, see haiku.transform_with_state().

property function_type1

Same as function, except that it ensures a type-1 function signature, regardless of the underlying modeltype.

property function_type2

Same as function, except that it ensures a type-2 function signature, regardless of the underlying modeltype.

property modeltype

Specifier for how the q-function is modeled, i.e.

\[\begin{split}(s,a) &\mapsto q(s,a)\in\mathbb{R} &\qquad (\text{modeltype} &= 1) \\ s &\mapsto q(s,.)\in\mathbb{R}^n &\qquad (\text{modeltype} &= 2)\end{split}\]

Note that modeltype=2 is only well-defined if the action space is Discrete. Namely, \(n\) is the number of discrete actions.

property params

The parameters (weights) of the function approximator.

class coax.StochasticTransitionModel(func, env, observation_preprocessor=None, action_preprocessor=None, proba_dist=None, random_seed=None)[source]

A stochastic transition model \(p_\theta(s'|s,a)\). Here, \(s'\) is the successor state, given that we take action \(a\) from state \(s\).

Parameters:
  • func (function) – A Haiku-style function that specifies the forward pass.

  • env (gymnasium.Env) – The gymnasium-style environment. This is used to validate the input/output structure of func.

  • observation_preprocessor (function, optional) – Turns a single observation into a batch of observations in a form that is convenient for feeding into func. If left unspecified, this defaults to proba_dist.preprocess_variate.

  • action_preprocessor (function, optional) – Turns a single action into a batch of actions in a form that is convenient for feeding into func. If left unspecified, this defaults to default_preprocessor(env.action_space).

  • proba_dist (ProbaDist, optional) –

    A probability distribution that is used to interpret the output of func. Check out the coax.proba_dists module for available options.

    If left unspecified, this defaults to:

    proba_dist = coax.proba_dists.ProbaDist(observation_space)
    

  • random_seed (int, optional) – Seed for pseudo-random number generators.

__call__(s, a=None, return_logp=False)[source]

Sample a successor state \(s'\) from the dynamics model \(p(s'|s,a)\).

Parameters:
  • s (state observation) – A single state observation \(s\).

  • a (action, optional) – A single action \(a\). This is required if the action space is non-discrete.

  • return_logp (bool, optional) – Whether to return the log-propensity \(\log p(s'|s,a)\).

Returns:

  • s_next (state observation or list thereof) – Depending on whether a is provided, this either returns a single next-state \(s'\) or a list of \(n\) next-states, one for each discrete action.

  • logp (non-positive float or list thereof, optional) – The log-propensity \(\log p(s'|s,a)\). This is only returned if we set return_logp=True. Depending on whether a is provided, this is either a single float or a list of \(n\) floats, one for each discrete action.
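A short usage sketch (assuming p_stochastic is a coax.StochasticTransitionModel, such as the Gaussian sketch near the top of this page; the sampled values would of course vary):

s, info = env.reset()
a = env.action_space.sample()

s_next = p_stochastic(s, a)                           # one sampled next state
s_next, logp = p_stochastic(s, a, return_logp=True)   # sample with log-propensity
s_next_all = p_stochastic(s)                          # one sample per discrete action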

copy(deep=False)

Create a copy of the current instance.

Parameters:

deep (bool, optional) – Whether the copy should be a deep copy.

Returns:

copy – A copy of the current instance; a deep copy if deep=True, otherwise a shallow copy.

dist_params(s, a=None)[source]

Get the parameters of the conditional probability distribution \(p_\theta(s'|s,a)\).

Parameters:
  • s (state observation) – A single state observation \(s\).

  • a (action, optional) – A single action \(a\). This is required if the action space is non-discrete.

Returns:

dist_params (dict or list of dicts) – Depending on whether a is provided, this either returns a single dist-params dict or a list of \(n\) such dicts, one for each discrete action.
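For example (the dict keys below assume the default Gaussian parametrization and are therefore an assumption):

params = p_stochastic.dist_params(s, a)
print(params['mu'], params['logvar'])  # parameters of p(s'|s,a)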

classmethod example_data(env, action_preprocessor=None, proba_dist=None, batch_size=1, random_seed=None)[source]

A small utility function that generates example input and output data. These may be useful for writing and debugging your own custom function approximators.

mean(s, a=None)[source]

Get the mean successor state \(s'\) according to the dynamics model, \(s'=\mathbb{E}_{p_\theta}\left[s'\,|\,s,a\right]\).

Parameters:
  • s (state observation) – A single state observation \(s\).

  • a (action, optional) – A single action \(a\). This is required if the action space is non-discrete.

Returns:

s_next (state observation or list thereof) – Depending on whether a is provided, this either returns a single next-state \(s'\) or a list of \(n\) next-states, one for each discrete action.

mode(s, a=None)[source]

Get the most probable successor state \(s'\) according to the dynamics model, \(s'=\arg\max_{s'}p_\theta(s'|s,a)\).

Parameters:
  • s (state observation) – A single state observation \(s\).

  • a (action, optional) – A single action \(a\). This is required if the action space is non-discrete.

Returns:

s_next (state observation or list thereof) – Depending on whether a is provided, this either returns a single next-state \(s'\) or a list of \(n\) next-states, one for each discrete action.
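A short sketch contrasting the two point estimates (same assumptions as above):

s_next_mean = p_stochastic.mean(s, a)  # expected next state under p(s'|s,a)
s_next_mode = p_stochastic.mode(s, a)  # most probable next state under p(s'|s,a)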

soft_update(other, tau)

Synchronize the current instance with other through exponential smoothing:

\[\theta\ \leftarrow\ \theta + \tau\, (\theta_\text{new} - \theta)\]

Parameters:
  • other – A separate copy of the current object. This object will hold the new parameters \(\theta_\text{new}\).

  • tau (float between 0 and 1, optional) – If we set \(\tau=1\) we do a hard update. If we pick a smaller value, we do a smooth update.

property function

The function approximator itself, defined as a JIT-compiled pure function. This function may be called directly as:

output, function_state = obj.function(obj.params, obj.function_state, obj.rng, *inputs)

property function_state

The state of the function approximator, see haiku.transform_with_state().

property function_type1

Same as function, except that it ensures a type-1 function signature, regardless of the underlying modeltype.

property function_type2

Same as function, except that it ensures a type-2 function signature, regardless of the underlying modeltype.

property mean_func_type1

The function that is used for computing the mean, defined as a JIT-compiled pure function. This function may be called directly as:

output = obj.mean_func_type1(obj.params, obj.function_state, obj.rng, S, A)

property mean_func_type2

The function that is used for computing the mean, defined as a JIT-compiled pure function. This function may be called directly as:

output = obj.mean_func_type2(obj.params, obj.function_state, obj.rng, S)

property mode_func_type1

The function that is used for computing the mode, defined as a JIT-compiled pure function. This function may be called directly as:

output = obj.mode_func_type1(obj.params, obj.function_state, obj.rng, S, A)

property mode_func_type2

The function that is used for computing the mode, defined as a JIT-compiled pure function. This function may be called directly as:

output = obj.mode_func_type2(obj.params, obj.function_state, obj.rng, S)

property modeltype

Specifier for how the dynamics model is implemented, i.e.

\[\begin{split}(s,a) &\mapsto p(s'|s,a) &\qquad (\text{modeltype} &= 1) \\ s &\mapsto p(s'|s,.) &\qquad (\text{modeltype} &= 2)\end{split}\]

Note that modeltype=2 is only well-defined if the action space is Discrete; in that case the model returns one distribution for each of the \(n\) discrete actions.

property params

The parameters (weights) of the function approximator.

property sample_func_type1

The function that is used for generating random samples, defined as a JIT-compiled pure function. This function may be called directly as:

output = obj.sample_func_type1(obj.params, obj.function_state, obj.rng, S, A)

property sample_func_type2

The function that is used for generating random samples, defined as a JIT-compiled pure function. This function may be called directly as:

output = obj.sample_func_type2(obj.params, obj.function_state, obj.rng, S)

class coax.StochasticRewardFunction(func, env, value_range=None, num_bins=51, observation_preprocessor=None, action_preprocessor=None, value_transform=None, random_seed=None)[source]

A stochastic reward function \(p_\theta(r|s,a)\).

Parameters:
  • func (function) – A Haiku-style function that specifies the forward pass.

  • env (gymnasium.Env) – The gymnasium-style environment. This is used to validate the input/output structure of func.

  • value_range (tuple of floats, optional) – A pair of floats (min_value, max_value). If left unspecified, this defaults to env.reward_range.

  • num_bins (int, optional) – The space of rewards is discretized into num_bins equally sized bins. We use the default setting of 51 as suggested in the Distributional RL paper. A construction sketch is given after this parameter list.

  • observation_preprocessor (function, optional) – Turns a single observation into a batch of observations in a form that is convenient for feeding into func. If left unspecified, this defaults to default_preprocessor(env.observation_space).

  • action_preprocessor (function, optional) – Turns a single action into a batch of actions in a form that is convenient for feeding into func. If left unspecified, this defaults to default_preprocessor(env.action_space).

  • value_transform (ValueTransform or pair of funcs, optional) –

    If provided, the target for the underlying function approximator is transformed:

    \[\tilde{G}_t\ =\ f(G_t)\]

    This means that calling the function involves undoing this transformation using its inverse \(f^{-1}\). The functions \(f\) and \(f^{-1}\) are given by value_transform.transform_func and value_transform.inverse_func, respectively. Note that a ValueTransform is just a glorified pair of functions, i.e. passing value_transform=(func, inverse_func) works just as well.

  • random_seed (int, optional) – Seed for pseudo-random number generators.
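As promised, here is a minimal construction sketch that puts these parameters together. It assumes the CartPole env from the examples above, and it assumes the forward pass returns categorical logits over the reward bins under the key 'logits'; both the key and the chosen value_range are illustrative assumptions.

def func_r_dist(S, A, is_training):
    """ (s,a) -> logits of p(r|s,a) """
    logits = hk.Sequential((
        hk.Linear(8), jax.nn.relu,
        hk.Linear(51, w_init=jnp.zeros),  # one logit per reward bin
    ))
    X = jax.vmap(jnp.kron)(S, A)  # combine state and one-hot action
    return {'logits': logits(X)}

r = coax.StochasticRewardFunction(func_r_dist, env, value_range=(0., 1.), num_bins=51)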

__call__(s, a=None, return_logp=False)

Sample a value.

Parameters:
  • s (state observation) – A single state observation \(s\).

  • a (action, optional) – A single action \(a\). This is required if the action space is non-discrete.

  • return_logp (bool, optional) – Whether to return the log-propensity associated with the sampled output value.

Returns:

  • value (float or list thereof) – Depending on whether a is provided, this either returns a single value or a list of \(n\) values, one for each discrete action.

  • logp (non-positive float or list thereof, optional) – The log-propensity associated with the sampled output value. This is only returned if we set return_logp=True. Depending on whether a is provided, this is either a single float or a list of \(n\) floats, one for each discrete action.

copy(deep=False)

Create a copy of the current instance.

Parameters:

deep (bool, optional) – Whether the copy should be a deep copy.

Returns:

copy – A copy of the current instance; a deep copy if deep=True, otherwise a shallow copy.

dist_params(s, a=None)

Get the parameters of the underlying (conditional) probability distribution.

Parameters:
  • s (state observation) – A single state observation \(s\).

  • a (action, optional) – A single action \(a\). This is required if the action space is non-discrete.

Returns:

dist_params (dict or list of dicts) – Depending on whether a is provided, this either returns a single dist-params dict or a list of \(n\) such dicts, one for each discrete action.

classmethod example_data(env, value_range, num_bins=51, observation_preprocessor=None, action_preprocessor=None, value_transform=None, batch_size=1, random_seed=None)

A small utility function that generates example input and output data. These may be useful for writing and debugging your own custom function approximators.

mean(s, a=None)

Get the mean value.

Parameters:
  • s (state observation) – A single state observation \(s\).

  • a (action, optional) – A single action \(a\). This is required if the action space is non-discrete.

Returns:

value (float or list thereof) – Depending on whether a is provided, this either returns a single value or a list of \(n\) values, one for each discrete action.

mode(s, a=None)

Get the most probable value.

Parameters:
  • s (state observation) – A single state observation \(s\).

  • a (action, optional) – A single action \(a\). This is required if the action space is non-discrete.

Returns:

value (float or list thereof) – Depending on whether a is provided, this either returns a single value or a list of \(n\) values, one for each discrete action.

soft_update(other, tau)

Synchronize the current instance with other through exponential smoothing:

\[\theta\ \leftarrow\ \theta + \tau\, (\theta_\text{new} - \theta)\]

Parameters:
  • other – A separate copy of the current object. This object will hold the new parameters \(\theta_\text{new}\).

  • tau (float between 0 and 1, optional) – If we set \(\tau=1\) we do a hard update. If we pick a smaller value, we do a smooth update.

property function

The function approximator itself, defined as a JIT-compiled pure function. This function may be called directly as:

output, function_state = obj.function(obj.params, obj.function_state, obj.rng, *inputs)

property function_state

The state of the function approximator, see haiku.transform_with_state().

property function_type1

Same as function, except that it ensures a type-1 function signature, regardless of the underlying modeltype.

property function_type2

Same as function, except that it ensures a type-2 function signature, regardless of the underlying modeltype.

property mean_func_type1

The function that is used for computing the mean, defined as a JIT-compiled pure function. This function may be called directly as:

output = obj.mean_func_type1(obj.params, obj.function_state, obj.rng, S, A)

property mean_func_type2

The function that is used for computing the mean, defined as a JIT-compiled pure function. This function may be called directly as:

output = obj.mean_func_type2(obj.params, obj.function_state, obj.rng, S)

property mode_func_type1

The function that is used for computing the mode, defined as a JIT-compiled pure function. This function may be called directly as:

output = obj.mode_func_type1(obj.params, obj.function_state, obj.rng, S, A)

property mode_func_type2

The function that is used for computing the mode, defined as a JIT-compiled pure function. This function may be called directly as:

output = obj.mode_func_type2(obj.params, obj.function_state, obj.rng, S)

property modeltype

Specifier for how the reward model is implemented, i.e.

\[\begin{split}(s,a) &\mapsto p(r|s,a) &\qquad (\text{modeltype} &= 1) \\ s &\mapsto p(r|s,.) &\qquad (\text{modeltype} &= 2)\end{split}\]

Note that modeltype=2 is only well-defined if the action space is Discrete; in that case the model returns one reward distribution for each of the \(n\) discrete actions.

property params

The parameters (weights) of the function approximator.

property sample_func_type1

The function that is used for generating random samples, defined as a JIT-compiled pure function. This function may be called directly as:

output = obj.sample_func_type1(obj.params, obj.function_state, obj.rng, S, A)

property sample_func_type2

The function that is used for generating random samples, defined as a JIT-compiled pure function. This function may be called directly as:

output = obj.sample_func_type2(obj.params, obj.function_state, obj.rng, S)