# Value Functions

• coax.V – A state value function $$v_\theta(s)$$.

• coax.Q – A state-action value function $$q_\theta(s,a)$$.

• coax.StochasticV – A state-value function $$v(s)$$, represented by a stochastic function $$\mathbb{P}_\theta(G_t|S_t=s)$$.

• coax.StochasticQ – A q-function $$q(s,a)$$, represented by a stochastic function $$\mathbb{P}_\theta(G_t|S_t=s,A_t=a)$$.

• coax.SuccessorStateQ – A state-action value function $$q(s,a)=r(s,a)+\gamma\mathop{\mathbb{E}}_{s'\sim p(.|s,a)}v(s')$$.

There are two kinds of value functions: state value functions $$v(s)$$ and state-action value functions (or q-functions) $$q(s,a)$$. The state value function gives the expected (discounted) return from state $$s$$, defined as:

$v(s)\ =\ \mathbb{E}_t\left\{ R_t + \gamma\,R_{t+1} + \gamma^2 R_{t+2} + \dots \,\Big|\, S_t=s \right\}$

The operator $$\mathbb{E}_t$$ takes the expectation value over all transitions (indexed by $$t$$). The $$v(s)$$ function is implemented by the coax.V class. The state-action value is defined in a similar way:

$q(s,a)\ =\ \mathbb{E}_t\left\{ R_t + \gamma\,R_{t+1} + \gamma^2 R_{t+2} + \dots \,\Big|\, S_t=s, A_t=a \right\}$

This is implemented by the coax.Q class.
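To illustrate the definition, the following pure-Python sketch (not part of coax) computes the discounted return $$G_t$$ for a finite reward sequence. It then averages that return over a few episodes that start in the same state, which gives a simple Monte Carlo estimate of $$v(s)$$. The helper names discounted_return and monte_carlo_value are hypothetical, and so are the reward sequences.

```python
def discounted_return(rewards, gamma):
    """Return G_t = R_t + gamma*R_{t+1} + gamma^2*R_{t+2} + ... for a finite episode."""
    g = 0.0
    # Work backwards using the recursion G_t = R_t + gamma * G_{t+1}.
    for r in reversed(rewards):
        g = r + gamma * g
    return g


def monte_carlo_value(episodes, gamma):
    """Estimate v(s) by averaging the returns of episodes that all start in state s."""
    returns = [discounted_return(rewards, gamma) for rewards in episodes]
    return sum(returns) / len(returns)


# Hypothetical reward sequences R_t, R_{t+1}, ... observed from the same start state s.
episodes = [[1.0, 1.0, 1.0], [1.0, 0.0]]
print(discounted_return(episodes[0], gamma=0.9))  # approximately 1 + 0.9 + 0.81 = 2.71
print(monte_carlo_value(episodes, gamma=0.9))     # approximately (2.71 + 1.0) / 2 = 1.855
```

The same estimate applies to $$q(s,a)$$ if you keep only the episodes whose first action is $$a$$.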

## v(s)

In this example we see how to construct a valid state value function $$v(s)$$. We’ll start by creating some example data, which allows us to inspect the correct input/output format.

```python
import coax
import gymnasium

env = gymnasium.make('CartPole-v0')
data = coax.V.example_data(env)

print(data)
# ExampleData(
#   inputs=Inputs(
#     args=ArgsType2(
#       S=array(shape=(1, 4), dtype=float32)
#       is_training=True)
#     static_argnums=(1,))
#   output=array(shape=(1,), dtype=float32))
```


From this we may define our Haiku-style forward-pass function:

```python
import jax
import jax.numpy as jnp
import haiku as hk

def func(S, is_training):
    # is_training is unused here, but it is part of the required signature
    seq = hk.Sequential((
        hk.Linear(8), jax.nn.relu,
        hk.Linear(8), jax.nn.relu,
        hk.Linear(8), jax.nn.relu,
        # zero-initialized output layer, so v(s) starts out at 0.0
        hk.Linear(1, w_init=jnp.zeros), jnp.ravel
    ))
    return seq(S)

v = coax.V(func, env)

# example usage
s = env.observation_space.sample()
print(v(s))  # 0.0
```


## q(s, a)

In this example we see how to construct a valid state-action value function $$q(s,a)$$. Let’s create some example data again.

```python
import coax
import gymnasium

env = gymnasium.make('CartPole-v0')
data = coax.Q.example_data(env)

print(data.type1)
# ExampleData(
#   inputs=Inputs(
#     args=ArgsType1(
#       S=array(shape=(1, 4), dtype=float32)
#       A=array(shape=(1, 2), dtype=float32)
#       is_training=True)
#     static_argnums=(2,))
#   output=array(shape=(1,), dtype=float32))

print(data.type2)
# ExampleData(
#   inputs=Inputs(
#     args=ArgsType2(
#       S=array(shape=(1, 4), dtype=float32)
#       is_training=True)
#     static_argnums=(1,))
#   output=array(shape=(1, 2), dtype=float32))
```


Note that there are two ways of modeling a q-function:

$\begin{split}(s,a) &\ \mapsto\ q(s,a)\in\mathbb{R} &\qquad &(\text{type 1}) \\ s &\ \mapsto\ q(s,.)\in\mathbb{R}^n &\qquad &(\text{type 2})\end{split}$

where $$n$$ is the number of discrete actions. Note that type-2 q-functions are only well-defined for discrete action spaces, whereas type-1 q-functions may be defined for any action space.
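For a discrete action space, the two types carry the same information. The plain-Python sketch below shows the conversion in both directions, using a made-up toy model in which a scalar stands in for the state. A type-1 value is read off a type-2 output by indexing. A type-2 output is assembled from a type-1 model by evaluating all $$n$$ actions. This is only a conceptual illustration: the helper names are hypothetical, and coax does this bookkeeping internally in its own batched way.

```python
N_ACTIONS = 2  # n, the number of discrete actions (e.g. CartPole)


def q_type2(s):
    """Toy type-2 model: s -> q(s,.), a list with one value per discrete action."""
    return [0.5 * s, -0.5 * s]


def as_type1(q2):
    """Type-1 view of a type-2 model: q(s, a) is component a of q(s,.)."""
    return lambda s, a: q2(s)[a]


def as_type2(q1, n):
    """Type-2 view of a type-1 model: evaluate q(s, a) for every action a in {0, ..., n-1}."""
    return lambda s: [q1(s, a) for a in range(n)]


q1 = as_type1(q_type2)
q2 = as_type2(q1, N_ACTIONS)
print(q1(2.0, 1))  # -1.0
print(q2(2.0))     # [1.0, -1.0]
```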

Let’s first define our type-1 forward-pass function:

```python
import jax
import jax.numpy as jnp
import haiku as hk

def func_type1(S, A, is_training):
    """ (s,a) -> q(s,a) """
    seq = hk.Sequential((
        hk.Linear(8), jax.nn.relu,
        hk.Linear(8), jax.nn.relu,
        hk.Linear(8), jax.nn.relu,
        hk.Linear(1, w_init=jnp.zeros), jnp.ravel
    ))
    # feed the state and the (preprocessed) action into the network together
    X = jnp.concatenate((S, A), axis=-1)
    return seq(X)

q = coax.Q(func_type1, env)

# example usage
s = env.observation_space.sample()
a = env.action_space.sample()
print(q(s, a))  # 0.0
print(q(s))     # array([0., 0.])
```


Alternatively, a type-2 forward-pass function might be:

```python
def func_type2(S, is_training):
    """ s -> q(s,.) """
    seq = hk.Sequential((
        hk.Linear(8), jax.nn.relu,
        hk.Linear(8), jax.nn.relu,
        hk.Linear(8), jax.nn.relu,
        # one output per discrete action
        hk.Linear(env.action_space.n, w_init=jnp.zeros)
    ))
    return seq(S)

q = coax.Q(func_type2, env)

# example usage
s = env.observation_space.sample()
a = env.action_space.sample()
print(q(s, a))  # 0.0
print(q(s))     # array([0., 0.])
```


If something goes wrong and you’d like to debug the forward-pass function, here’s an example of what coax.Q.__init__ runs under the hood:

```python
rngs = hk.PRNGSequence(42)
transformed = hk.transform_with_state(func_type2)
params, function_state = transformed.init(next(rngs), *data.type2.inputs.args)
output, function_state = transformed.apply(params, function_state, next(rngs), *data.type2.inputs.args)
```


## Object Reference

class coax.V(func, env, observation_preprocessor=None, value_transform=None, random_seed=None)

A state value function $$v_\theta(s)$$.

Parameters:
• func (function) – A Haiku-style function that specifies the forward pass. The function signature must be the same as in the v(s) example above.

• env (gymnasium.Env) – The gymnasium-style environment. This is used to validate the input/output structure of func.

• observation_preprocessor (function, optional) – Turns a single observation into a batch of observations in a form that is convenient for feeding into func. If left unspecified, this defaults to default_preprocessor(env.observation_space).

• value_transform (ValueTransform or pair of funcs, optional) –

If provided, the target for the underlying function approximator is transformed such that:

$\tilde{v}_\theta(S_t)\ \approx\ f(G_t)$

This means that calling the function involves undoing this transformation:

$v(s)\ =\ f^{-1}(\tilde{v}_\theta(s))$

Here, $$f$$ and $$f^{-1}$$ are given by value_transform.transform_func and value_transform.inverse_func, respectively. Note that a ValueTransform is just a glorified pair of functions, i.e. passing value_transform=(func, inverse_func) works just as well.

• random_seed (int, optional) – Seed for pseudo-random number generators.
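A value transform is just a pair of functions $$(f, f^{-1})$$, so one can be written with the standard library. The signed logarithmic compression below is chosen purely for illustration and is an assumption, not necessarily a transform that coax ships. The sketch shows that the approximator targets $$f(G_t)$$, while calling the value function reports $$f^{-1}$$ of the approximator's output.

```python
import math


def transform_func(x):
    """f(x) = sign(x) * log(1 + |x|): compresses large returns (illustrative choice)."""
    return math.copysign(math.log1p(abs(x)), x)


def inverse_func(y):
    """f^{-1}(y) = sign(y) * (exp(|y|) - 1): undoes the compression."""
    return math.copysign(math.expm1(abs(y)), y)


# The pair form described above: value_transform=(func, inverse_func).
value_transform = (transform_func, inverse_func)
f, f_inv = value_transform

G = 100.0           # a raw return G_t
v_tilde = f(G)      # the approximator is trained toward f(G_t), about 4.6
v = f_inv(v_tilde)  # evaluating v(s) undoes the transform, recovering about 100
print(v_tilde, v)
```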

__call__(s)

Evaluate the value function on a state observation $$s$$.

Parameters:

s (state observation) – A single state observation $$s$$.

Returns:

v (ndarray, shape: ()) – The estimated expected value associated with the input state observation s.

copy(deep=False)

Create a copy of the current instance.

Parameters:

deep (bool, optional) – Whether the copy should be a deep copy.

Returns:

copy – A copy of the current instance (a deep copy if deep=True).

classmethod example_data(env, observation_preprocessor=None, batch_size=1, random_seed=None)

A small utility function that generates example input and output data. These may be useful for writing and debugging your own custom function approximators.

soft_update(other, tau)

Synchronize the current instance with other through exponential smoothing:

$\theta\ \leftarrow\ \theta + \tau\, (\theta_\text{new} - \theta)$
Parameters:
• other – A separate copy of the current object. This object holds the new parameters $$\theta_\text{new}$$.

• tau (float between 0 and 1) – If we set $$\tau=1$$ we do a hard update. If we pick a smaller value, we do a smooth update.
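The update rule can be sketched on a flat dict of scalar parameters. This assumes the parameters can be updated entry by entry. The real soft_update method works on the object's own parameter structure, so the function below is a hypothetical stand-in and not coax's implementation.

```python
def soft_update(theta, theta_new, tau):
    """Exponential smoothing: theta <- theta + tau * (theta_new - theta), entry by entry."""
    if not 0.0 <= tau <= 1.0:
        raise ValueError("tau must lie between 0 and 1")
    return {name: theta[name] + tau * (theta_new[name] - theta[name]) for name in theta}


theta = {"w": 0.0, "b": 1.0}      # current parameters
theta_new = {"w": 2.0, "b": 3.0}  # parameters held by `other`

print(soft_update(theta, theta_new, tau=1.0))  # {'w': 2.0, 'b': 3.0}: a hard update
print(soft_update(theta, theta_new, tau=0.1))  # each entry moves 10% of the way toward theta_new
```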

property function

The function approximator itself, defined as a JIT-compiled pure function. This function may be called directly as:

output, function_state = obj.function(obj.params, obj.function_state, obj.rng, *inputs)

property function_state

The state of the function approximator, see haiku.transform_with_state().

property params

The parameters (weights) of the function approximator.

class coax.Q(func, env, observation_preprocessor=None, action_preprocessor=None, value_transform=None, random_seed=None)

A state-action value function $$q_\theta(s,a)$$.

Parameters:
• func (function) – A Haiku-style function that specifies the forward pass. The function signature must be the same as in the q(s, a) examples above.

• env (gymnasium.Env) – The gymnasium-style environment. This is used to validate the input/output structure of func.

• observation_preprocessor (function, optional) – Turns a single observation into a batch of observations in a form that is convenient for feeding into func. If left unspecified, this defaults to default_preprocessor(env.observation_space).

• action_preprocessor (function, optional) – Turns a single action into a batch of actions in a form that is convenient for feeding into func. If left unspecified, this defaults to default_preprocessor(env.action_space).

• value_transform (ValueTransform or pair of funcs, optional) –

If provided, the target for the underlying function approximator is transformed such that:

$\tilde{q}_\theta(S_t, A_t)\ \approx\ f(G_t)$

This means that calling the function involves undoing this transformation:

$q(s, a)\ =\ f^{-1}(\tilde{q}_\theta(s, a))$

Here, $$f$$ and $$f^{-1}$$ are given by value_transform.transform_func and value_transform.inverse_func, respectively. Note that a ValueTransform is just a glorified pair of functions, i.e. passing value_transform=(func, inverse_func) works just as well.

• random_seed (int, optional) – Seed for pseudo-random number generators.
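To make "a batch of observations/actions" concrete, here is a hypothetical pair of preprocessors in plain Python. They produce the shapes shown in the CartPole example data: S with shape (1, 4) and A with shape (1, 2). One-hot encoding of the discrete action is an assumption inferred from that (1, 2) float32 action array. The extra n argument and both function bodies are illustrative and do not reproduce coax's default_preprocessor.

```python
def observation_preprocessor(s):
    """Turn a single observation (a sequence of numbers) into a batch of size 1."""
    return [[float(x) for x in s]]


def action_preprocessor(a, n):
    """Turn a single discrete action a in {0, ..., n-1} into a batch of one one-hot vector."""
    if not 0 <= a < n:
        raise ValueError("action out of range")
    onehot = [0.0] * n
    onehot[a] = 1.0
    return [onehot]


S = observation_preprocessor([0.01, -0.02, 0.03, 0.04])  # shape (1, 4)
A = action_preprocessor(1, n=2)                          # shape (1, 2)
print(A)  # [[0.0, 1.0]]
```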

__call__(s, a=None)

Evaluate the state-action function on a state observation $$s$$ or on a state-action pair $$(s, a)$$.

Parameters:
• s (state observation) – A single state observation $$s$$.

• a (action) – A single action $$a$$.

Returns:

q_sa or q_s (ndarray) – If a is provided, this returns a scalar representing $$q(s,a)\in\mathbb{R}$$. Otherwise it returns a vector representing $$q(s,.)\in\mathbb{R}^n$$, where $$n$$ is the number of discrete actions; this second form is only available for discrete action spaces.

copy(deep=False)

Create a copy of the current instance.

Parameters:

deep (bool, optional) – Whether the copy should be a deep copy.

Returns:

copy – A copy of the current instance (a deep copy if deep=True).

classmethod example_data(env, observation_preprocessor=None, action_preprocessor=None, batch_size=1, random_seed=None)

A small utility function that generates example input and output data. These may be useful for writing and debugging your own custom function approximators.

soft_update(other, tau)

Synchronize the current instance with other through exponential smoothing:

$\theta\ \leftarrow\ \theta + \tau\, (\theta_\text{new} - \theta)$
Parameters:
• other – A separate copy of the current object. This object holds the new parameters $$\theta_\text{new}$$.

• tau (float between 0 and 1) – If we set $$\tau=1$$ we do a hard update. If we pick a smaller value, we do a smooth update.

property function

The function approximator itself, defined as a JIT-compiled pure function. This function may be called directly as:

output, function_state = obj.function(obj.params, obj.function_state, obj.rng, *inputs)

property function_state

The state of the function approximator, see haiku.transform_with_state().

property function_type1

Same as function, except that it ensures a type-1 function signature, regardless of the underlying modeltype.

property function_type2

Same as function, except that it ensures a type-2 function signature, regardless of the underlying modeltype.

property modeltype

Specifier for how the q-function is modeled, i.e.

$\begin{split}(s,a) &\mapsto q(s,a)\in\mathbb{R} &\qquad (\text{modeltype} &= 1) \\ s &\mapsto q(s,.)\in\mathbb{R}^n &\qquad (\text{modeltype} &= 2)\end{split}$

Note that modeltype=2 is only well-defined if the action space is Discrete, in which case $$n$$ is the number of discrete actions.

property params

The parameters (weights) of the function approximator.

class coax.StochasticV(func, env, value_range, num_bins=51, observation_preprocessor=None, value_transform=None, random_seed=None)

A state-value function $$v(s)$$, represented by a stochastic function $$\mathbb{P}_\theta(G_t|S_t=s)$$.

Parameters:
• func (function) – A Haiku-style function that specifies the forward pass.

• env (gymnasium.Env) – The gymnasium-style environment. This is used to validate the input/output structure of func.

• value_range (tuple of floats) – A pair of floats (min_value, max_value).

• num_bins (int, optional) – The space of returns is discretized into num_bins equal-sized bins. We use the default setting of 51 as suggested in the Distributional RL paper.

• observation_preprocessor (function, optional) – Turns a single observation into a batch of observations in a form that is convenient for feeding into func. If left unspecified, this defaults to default_preprocessor(env.observation_space).

• value_transform (ValueTransform or pair of funcs, optional) –

If provided, the target for the underlying function approximator is transformed:

$\tilde{G}_t\ =\ f(G_t)$

This means that calling the function involves undoing this transformation using its inverse $$f^{-1}$$. The functions $$f$$ and $$f^{-1}$$ are given by value_transform.transform_func and value_transform.inverse_func, respectively. Note that a ValueTransform is just a glorified pair of functions, i.e. passing value_transform=(func, inverse_func) works just as well.

• random_seed (int, optional) – Seed for pseudo-random number generators.
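One way to picture value_range and num_bins is as num_bins equally spaced support points spanning (min_value, max_value), as in the C51 algorithm from the Distributional RL paper. The sketch below builds such a support in plain Python. It assumes the bins are equally spaced atoms that include both endpoints, which this section does not confirm about coax's internals. The helper name value_support is hypothetical.

```python
def value_support(value_range, num_bins=51):
    """Return num_bins equally spaced values z_0, ..., z_{num_bins-1} from min_value to max_value."""
    min_value, max_value = value_range
    if num_bins < 2 or not min_value < max_value:
        raise ValueError("need num_bins >= 2 and min_value < max_value")
    width = max_value - min_value
    # z_i = min_value + i * (max_value - min_value) / (num_bins - 1)
    return [min_value + width * i / (num_bins - 1) for i in range(num_bins)]


z = value_support((-10.0, 10.0), num_bins=51)
print(len(z), z[0], z[25], z[-1])  # 51 -10.0 0.0 10.0
```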

__call__(s, return_logp=False)

Sample a value.

Parameters:
• s (state observation) – A single state observation $$s$$.

• return_logp (bool, optional) – Whether to return the log-propensity associated with the sampled output value.

Returns:

• value (float) – A single value sampled for the state observation $$s$$.

• logp (non-positive float, optional) – The log-propensity associated with the sampled output value. This is only returned if we set return_logp=True.
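Sampling with return_logp=True can be sketched for a categorical distribution over a fixed set of values. The sketch rests on three assumptions: the distribution is categorical over the bins (C51-style), its probabilities are already normalized, and Python's random.Random stands in for the JAX random number generator that coax actually uses. The helper name sample_value is hypothetical. The log-propensity is the log-probability of the drawn outcome, which is why it is never positive.

```python
import math
import random


def sample_value(support, probs, rng, return_logp=False):
    """Draw one value from a categorical distribution with P(G = support[i]) = probs[i]."""
    # Sample an index rather than a value, so duplicate support values cannot confuse logp.
    i = rng.choices(range(len(support)), weights=probs, k=1)[0]
    if not return_logp:
        return support[i]
    # log P(sampled outcome) <= 0, because each probability is at most 1
    return support[i], math.log(probs[i])


support = [-1.0, 0.0, 1.0]
probs = [0.2, 0.3, 0.5]
rng = random.Random(42)
value, logp = sample_value(support, probs, rng, return_logp=True)
print(value, logp)
```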

copy(deep=False)

Create a copy of the current instance.

Parameters:

deep (bool, optional) – Whether the copy should be a deep copy.

Returns:

copy – A copy of the current instance (a deep copy if deep=True).

dist_params(s)

Get the parameters of the underlying (conditional) probability distribution.

Parameters:

s (state observation) – A single state observation $$s$$.

Returns:

dist_params (dict) – A single dist-params dict describing the distribution associated with the state observation $$s$$.

classmethod example_data(env, value_range, num_bins=51, observation_preprocessor=None, value_transform=None, batch_size=1, random_seed=None)

A small utility function that generates example input and output data. These may be useful for writing and debugging your own custom function approximators.

mean(s)

Get the mean value.

Parameters:

s (state observation) – A single state observation $$s$$.

Returns:

value (float) – A single value associated with the state observation $$s$$.

mode(s)

Get the most probable value.

Parameters:

s (state observation) – A single state observation $$s$$.

Returns:

value (float) – A single value associated with the state observation $$s$$.
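For a categorical distribution over bins, the mean and mode reduce to a probability-weighted average and an argmax. The sketch assumes the distribution is parameterized by unnormalized logits that a softmax turns into probabilities. That parameterization, the support values and all helper names are illustrative assumptions, not coax's dist_params format.

```python
import math


def softmax(logits):
    """Turn unnormalized logits into probabilities (shifted by the max for stability)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def dist_mean(support, probs):
    """Mean value: sum_i probs[i] * support[i]."""
    return sum(p * z for p, z in zip(probs, support))


def dist_mode(support, probs):
    """Most probable value: the support point with the largest probability."""
    best = max(range(len(probs)), key=lambda i: probs[i])
    return support[best]


support = [-1.0, 0.0, 1.0]
probs = softmax([0.0, 1.0, 2.0])
print(dist_mean(support, probs))  # about 0.575
print(dist_mode(support, probs))  # 1.0
```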

soft_update(other, tau)

Synchronize the current instance with other through exponential smoothing:

$\theta\ \leftarrow\ \theta + \tau\, (\theta_\text{new} - \theta)$
Parameters:
• other – A separate copy of the current object. This object holds the new parameters $$\theta_\text{new}$$.

• tau (float between 0 and 1) – If we set $$\tau=1$$ we do a hard update. If we pick a smaller value, we do a smooth update.

property function

The function approximator itself, defined as a JIT-compiled pure function. This function may be called directly as:

output, function_state = obj.function(obj.params, obj.function_state, obj.rng, *inputs)

property function_state

The state of the function approximator, see haiku.transform_with_state().

property mean_func

The function that is used for getting the mean of the distribution, defined as a JIT-compiled pure function. This function may be called directly as:

output = obj.mean_func(obj.params, obj.function_state, obj.rng, *inputs)

property mode_func

The function that is used for getting the mode of the distribution, defined as a JIT-compiled pure function. This function may be called directly as:

output = obj.mode_func(obj.params, obj.function_state, obj.rng, *inputs)

property params

The parameters (weights) of the function approximator.

property sample_func

The function that is used for drawing random samples from the underlying probability distribution (proba_dist), defined as a JIT-compiled pure function. This function may be called directly as:

output = obj.sample_func(obj.params, obj.function_state, obj.rng, *inputs)

class coax.StochasticQ(func, env, value_range=None, num_bins=51, observation_preprocessor=None, action_preprocessor=None, value_transform=None, random_seed=None)

A q-function $$q(s,a)$$, represented by a stochastic function $$\mathbb{P}_\theta(G_t|S_t=s,A_t=a)$$.

Parameters:
• func (function) – A Haiku-style function that specifies the forward pass.

• env (gymnasium.Env) – The gymnasium-style environment. This is used to validate the input/output structure of func.

• value_range (tuple of floats, optional) – A pair of floats (min_value, max_value). If no value_range is given, num_bins instead sets the number of quantiles used to represent the return distribution, as in IQN or QR-DQN.

• num_bins (int, optional) –

If value_range is given: the space of returns is discretized into num_bins equal-sized bins. We use the default setting of 51 as suggested in the Distributional RL paper.

Otherwise: num_bins sets the number of quantile fractions of the return distribution, as in IQN or QR-DQN.

• observation_preprocessor (function, optional) – Turns a single observation into a batch of observations in a form that is convenient for feeding into func. If left unspecified, this defaults to default_preprocessor(env.observation_space).

• action_preprocessor (function, optional) – Turns a single action into a batch of actions in a form that is convenient for feeding into func. If left unspecified, this defaults to default_preprocessor(env.action_space).

• value_transform (ValueTransform or pair of funcs, optional) –

If provided, the target for the underlying function approximator is transformed:

$\tilde{G}_t\ =\ f(G_t)$

This means that calling the function involves undoing this transformation using its inverse $$f^{-1}$$. The functions $$f$$ and $$f^{-1}$$ are given by value_transform.transform_func and value_transform.inverse_func, respectively. Note that a ValueTransform is just a glorified pair of functions, i.e. passing value_transform=(func, inverse_func) works just as well.

• random_seed (int, optional) – Seed for pseudo-random number generators.
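When no value_range is given, the distribution is represented by num_bins quantiles instead of fixed bins. The sketch below shows two things. First, it computes the fixed quantile midpoints $$\hat\tau_i=(2i+1)/(2N)$$ used in QR-DQN. Second, it shows that equally weighted quantile values average to the mean of the distribution. IQN instead samples the fractions at random, and this section does not say which scheme coax uses, so treat this as an illustration of the idea only. Both helper names are hypothetical.

```python
def quantile_fractions(num_bins):
    """QR-DQN-style midpoints tau_hat_i = (2i + 1) / (2N) for i = 0, ..., N-1."""
    return [(2 * i + 1) / (2 * num_bins) for i in range(num_bins)]


def quantile_mean(quantile_values):
    """Each quantile carries probability 1/N, so the mean is the plain average."""
    return sum(quantile_values) / len(quantile_values)


print(quantile_fractions(4))                 # [0.125, 0.375, 0.625, 0.875]
print(quantile_mean([-1.0, 0.0, 1.0, 4.0]))  # 1.0
```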

__call__(s, a=None, return_logp=False)

Sample a value.

Parameters:
• s (state observation) – A single state observation $$s$$.

• a (action, optional) – A single action $$a$$. This is required if the action space is non-discrete.

• return_logp (bool, optional) – Whether to return the log-propensity associated with the sampled output value.

Returns:

• value (float or list thereof) – Depending on whether a is provided, this either returns a single value or a list of $$n$$ values, one for each discrete action.

• logp (non-positive float or list thereof, optional) – The log-propensity associated with the sampled output value. This is only returned if we set return_logp=True. Depending on whether a is provided, this is either a single float or a list of $$n$$ floats, one for each discrete action.

copy(deep=False)

Create a copy of the current instance.

Parameters:

deep (bool, optional) – Whether the copy should be a deep copy.

Returns:

copy – A copy of the current instance (a deep copy if deep=True).

dist_params(s, a=None)

Get the parameters of the underlying (conditional) probability distribution.

Parameters:
• s (state observation) – A single state observation $$s$$.

• a (action, optional) – A single action $$a$$. This is required if the action space is non-discrete.

Returns:

dist_params (dict or list of dicts) – Depending on whether a is provided, this either returns a single dist-params dict or a list of $$n$$ such dicts, one for each discrete action.

classmethod example_data(env, value_range, num_bins=51, observation_preprocessor=None, action_preprocessor=None, value_transform=None, batch_size=1, random_seed=None)

A small utility function that generates example input and output data. These may be useful for writing and debugging your own custom function approximators.

mean(s, a=None)

Get the mean value.

Parameters:
• s (state observation) – A single state observation $$s$$.

• a (action, optional) – A single action $$a$$. This is required if the action space is non-discrete.

Returns:

value (float or list thereof) – Depending on whether a is provided, this either returns a single value or a list of $$n$$ values, one for each discrete action.

mode(s, a=None)

Get the most probable value.

Parameters:
• s (state observation) – A single state observation $$s$$.

• a (action, optional) – A single action $$a$$. This is required if the action space is non-discrete.

Returns:

value (float or list thereof) – Depending on whether a is provided, this either returns a single value or a list of $$n$$ values, one for each discrete action.

soft_update(other, tau)

Synchronize the current instance with other through exponential smoothing:

$\theta\ \leftarrow\ \theta + \tau\, (\theta_\text{new} - \theta)$
Parameters:
• other – A separate copy of the current object. This object holds the new parameters $$\theta_\text{new}$$.

• tau (float between 0 and 1) – If we set $$\tau=1$$ we do a hard update. If we pick a smaller value, we do a smooth update.

property function

The function approximator itself, defined as a JIT-compiled pure function. This function may be called directly as:

output, function_state = obj.function(obj.params, obj.function_state, obj.rng, *inputs)

property function_state

The state of the function approximator, see haiku.transform_with_state().

property function_type1

Same as function, except that it ensures a type-1 function signature, regardless of the underlying modeltype.

property function_type2

Same as function, except that it ensures a type-2 function signature, regardless of the underlying modeltype.

property mean_func_type1

The function that is used for computing the mean, defined as a JIT-compiled pure function. This function may be called directly as:

output = obj.mean_func_type1(obj.params, obj.function_state, obj.rng, S, A)

property mean_func_type2

The function that is used for computing the mean, defined as a JIT-compiled pure function. This function may be called directly as:

output = obj.mean_func_type2(obj.params, obj.function_state, obj.rng, S)

property mode_func_type1

The function that is used for computing the mode, defined as a JIT-compiled pure function. This function may be called directly as:

output = obj.mode_func_type1(obj.params, obj.function_state, obj.rng, S, A)

property mode_func_type2

The function that is used for computing the mode, defined as a JIT-compiled pure function. This function may be called directly as:

output = obj.mode_func_type2(obj.params, obj.function_state, obj.rng, S)

property modeltype

Specifier for how the q-function is modeled, i.e.

$\begin{split}(s,a) &\mapsto \mathbb{P}_\theta(G_t|S_t=s,A_t=a) &\qquad (\text{modeltype} &= 1) \\ s &\mapsto \mathbb{P}_\theta(G_t|S_t=s,A_t=.) &\qquad (\text{modeltype} &= 2)\end{split}$

Note that modeltype=2 is only well-defined if the action space is Discrete, in which case the model outputs one distribution for each of the $$n$$ discrete actions.

property params

The parameters (weights) of the function approximator.

property sample_func_type1

The function that is used for generating random samples, defined as a JIT-compiled pure function. This function may be called directly as:

output = obj.sample_func_type1(obj.params, obj.function_state, obj.rng, S, A)

property sample_func_type2

The function that is used for generating random samples, defined as a JIT-compiled pure function. This function may be called directly as:

output = obj.sample_func_type2(obj.params, obj.function_state, obj.rng, S)

class coax.SuccessorStateQ(v, p, r, gamma=0.9)

A state-action value function $$q(s,a)=r(s,a)+\gamma\mathop{\mathbb{E}}_{s'\sim p(.|s,a)}v(s')$$.

A word of caution: If you use custom observation/action pre-/post-processors, please make sure that all three function approximators v, p and r use the same ones.

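Parameters:
• v – The state value function $$v(s')$$ that is evaluated at the successor states $$s'$$.

• p – The transition model $$p(s'|s,a)$$ over successor states.

• r – The reward function $$r(s,a)$$.

• gamma (float, optional) – The discount factor $$\gamma$$.
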
__call__(s, a=None)

Evaluate the state-action function on a state observation $$s$$ or on a state-action pair $$(s, a)$$.

Parameters:
• s (state observation) – A single state observation $$s$$.

• a (action) – A single action $$a$$.

Returns:

q_sa or q_s (ndarray) – If a is provided, this returns a scalar representing $$q(s,a)\in\mathbb{R}$$. Otherwise it returns a vector representing $$q(s,.)\in\mathbb{R}^n$$, where $$n$$ is the number of discrete actions; this second form is only available for discrete action spaces.
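For intuition, the defining formula $$q(s,a)=r(s,a)+\gamma\mathop{\mathbb{E}}_{s'\sim p(.|s,a)}v(s')$$ can be evaluated by hand in a tiny tabular setting. The sketch assumes a finite state space in which $$v$$, $$p$$ and $$r$$ are plain lookup tables. In coax these are function approximators, so the dicts and the helper name successor_state_q are hypothetical stand-ins.

```python
def successor_state_q(s, a, v, p, r, gamma=0.9):
    """q(s,a) = r(s,a) + gamma * sum over s' of p(s'|s,a) * v(s'), for a finite state space."""
    expected_v = sum(prob * v[s_next] for s_next, prob in p[(s, a)].items())
    return r[(s, a)] + gamma * expected_v


# Hypothetical 2-state, 1-action toy problem.
v = {"A": 1.0, "B": 3.0}              # state values v(s')
p = {("A", 0): {"A": 0.5, "B": 0.5}}  # transition model p(s'|s,a)
r = {("A", 0): 2.0}                   # reward function r(s,a)
print(successor_state_q("A", 0, v, p, r, gamma=0.5))  # 2.0 + 0.5 * 2.0 = 3.0
```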