Policies

coax.Policy

A parametrized policy \(\pi_\theta(a|s)\).

coax.EpsilonGreedy

Create an \(\epsilon\)-greedy policy, given a q-function.

coax.BoltzmannPolicy

Derive a Boltzmann policy from a q-function.

coax.RandomPolicy

A simple random policy.


There are generally two distinct ways of constructing a policy \(\pi(a|s)\). One method uses a function approximator to parametrize a state-action value function \(q_\theta(s,a)\) and then derives a policy from this q-function. The other method uses a function approximator to parametrize the policy directly, i.e. \(\pi(a|s)=\pi_\theta(a|s)\). The methods are called value-based methods and policy gradient methods, respectively.
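For instance, the two approaches look roughly as follows in coax (a minimal sketch; func_q and func_pi stand for Haiku-style forward-pass functions of the kind discussed further below):

q = coax.Q(func_q, env)                  # value-based: learn a q-function ...
pi = coax.EpsilonGreedy(q, epsilon=0.1)  # ... and derive a policy from it

pi = coax.Policy(func_pi, env)           # policy gradient: parametrize the policy directly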

A policy in coax is a function that maps state observations to actions. The example below shows how to use a policy in a simple episode roll-out.

env = gymnasium.make(...)

s, info = env.reset()
for t in range(max_episode_steps):
    a = pi(s)
    s_next, r, terminated, truncated, info = env.step(a)

    if terminated or truncated:
        break

    s = s_next

Some algorithms require us to collect the log-propensities along with the sampled actions. For this reason, policies have the optional return_logp flag:

a, logp = pi(s, return_logp=True)

The log-propensity represents \(\log\pi(a|s)\). For discrete actions this is a non-positive number: a stochastic policy returns logp<0, whereas a deterministic policy returns logp=0. For continuous actions, logp is a log-density and may therefore also be positive.
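For example, exponentiating the log-propensity recovers the probability with which the action was sampled (a quick sanity check; the number in the comment is illustrative):

import numpy as np

a, logp = pi(s, return_logp=True)
print(np.exp(logp))  # the propensity pi(a|s), e.g. 0.95 for the greedy action of an
                     # epsilon-greedy policy with epsilon=0.1 and two discrete actions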

As an aside, we note that coax policies have two more methods:

a = pi.mode(s)                   # same as pi(s), except 'sampling' greedily
dist_params = pi.dist_params(s)  # distribution parameters, conditioned on s
print(dist_params)               # in this example: categorical dist with n=3
# {'logits': array([-0.5711,  1.0513,  0.0012])}

Random policy

Before we turn to value-based and parametrized policies, let’s start with the simplest possible policy, namely coax.RandomPolicy. This policy doesn’t require any function approximator; it simply calls env.action_space.sample(). This can be useful for creating simple benchmarks.

pi = coax.RandomPolicy(env)
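For example, a quick random-policy baseline might look like this (a sketch, using the gymnasium episode loop from above):

# roll out one episode with uniformly random actions to get a baseline return
s, info = env.reset()
G = 0.0  # episode return (sum of rewards)
while True:
    a = pi(s)
    s, r, terminated, truncated, info = env.step(a)
    G += r
    if terminated or truncated:
        break

print(G)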

Value-based policies

Value-based policies are defined indirectly, via a q-function. Examples of value-based policies are coax.EpsilonGreedy (see example below) and coax.BoltzmannPolicy.

pi = coax.EpsilonGreedy(q, epsilon=0.1)
pi = coax.BoltzmannPolicy(q, temperature=0.02)

Note that the hyperparameters epsilon and temperature may be updated at any time, e.g.

pi.epsilon *= 0.99  # e.g. at the start of each episode
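The same goes for the Boltzmann temperature; an annealing schedule with a floor might look like this (illustrative values):

pi.temperature = max(0.02, pi.temperature * 0.99)  # anneal towards greedy sampling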

Parametrized policies

Now that we’ve discussed value-based policies, let’s start our discussion of parametrized (learnable) policies. We provide three examples:

  1. Discrete actions (categorical dist)

  2. Continuous actions (normal dist)

  3. Composite actions

Discrete actions

A common action space is Discrete. As an example, we’ll take the CartPole environment. To get started, let’s generate some example data so that we know the correct input/output format for our forward-pass function.

import coax
import gymnasium
import jax
import jax.numpy as jnp
import haiku as hk

env = gymnasium.make('CartPole-v0')

data = coax.Policy.example_data(env)

print(data)
# ExampleData(
#   inputs=Inputs(
#     args=ArgsType2(
#       S=array(shape=(1, 4), dtype=float32)
#       is_training=True)
#     static_argnums=(1,))
#   output={
#     'logits': array(shape=(1, 2), dtype=float32)})

Now, our task is to write a Haiku-style forward-pass function that generates this output given the input. To be clear, our task is not to recreate the exact values; the example data is only there to give us an idea of the structure (shapes, dtypes, etc.).

def func(S, is_training):
    logits = hk.Sequential((
        hk.Flatten(),
        hk.Linear(8), jax.nn.relu,
        hk.Linear(8), jax.nn.relu,
        hk.Linear(8), jax.nn.relu,
        hk.Linear(env.action_space.n, w_init=jnp.zeros)
    ))
    return {'logits': logits(S)}


pi = coax.Policy(func, env)

# example usage
s = env.observation_space.sample()
a = pi(s)
print(a)  # 0 or 1

If something goes wrong and you’d like to debug the forward-pass function, here’s an example of what coax.Policy.__init__ runs under the hood:

rngs = hk.PRNGSequence(42)
transformed = hk.transform_with_state(func)
params, function_state = transformed.init(next(rngs), *data.inputs.args)
output, function_state = transformed.apply(params, function_state, next(rngs), *data.inputs.args)
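To inspect what came out of the initialization, it can be handy to print the parameter shapes and the output structure, e.g.:

print(jax.tree_util.tree_map(lambda x: x.shape, params))  # shapes of all weights and biases
print(output)  # should have the same structure as data.output, i.e. {'logits': ...}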

Continuous actions

Besides discrete actions, we might wish to build an agent compatible with continuous actions. Here’s an example of how to create a valid policy function approximator for the Pendulum environment:

import coax
import gymnasium
import jax
import jax.numpy as jnp
import haiku as hk
from math import prod

env = gymnasium.make('Pendulum-v1')

def func(S, is_training):
    shared = hk.Sequential((
        hk.Flatten(),
        hk.Linear(8), jax.nn.relu,
        hk.Linear(8), jax.nn.relu,
    ))
    mu = hk.Sequential((
        shared,
        hk.Linear(8), jax.nn.relu,
        hk.Linear(prod(env.action_space.shape), w_init=jnp.zeros),
        hk.Reshape(env.action_space.shape),
    ))
    logvar = hk.Sequential((
        shared,
        hk.Linear(8), jax.nn.relu,
        hk.Linear(prod(env.action_space.shape), w_init=jnp.zeros),
        hk.Reshape(env.action_space.shape),
    ))
    return {'mu': mu(S), 'logvar': logvar(S)}


pi = coax.Policy(func, env)

# example usage
s = env.observation_space.sample()
a = pi(s)

print(a)
# array([0.39267802], dtype=float32)

Note that if you’re ever unsure what the correct input/output format is, you can always generate some example data using the coax.Policy.example_data() helper (see the example above).
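For instance, for the Pendulum environment the example output contains the 'mu' and 'logvar' entries we produced above (a sketch; exact shapes and values depend on the environment):

data = coax.Policy.example_data(env)
print(data.output)
# roughly: {'mu': array(shape=(1, 1), dtype=float32),
#           'logvar': array(shape=(1, 1), dtype=float32)}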

Composite actions

The coax package supports all action spaces that are supported by the gymnasium.spaces API.

To illustrate the flexibility of the coax framework, here’s an example of a composite action space:

from collections import namedtuple
from gymnasium.spaces import Dict, Tuple, Box, Discrete, MultiDiscrete

DummyEnv = namedtuple('DummyEnv', ('observation_space', 'action_space'))
env = DummyEnv(
    Box(low=0, high=1, shape=(7,)),
    Dict({
        'foo': MultiDiscrete([4, 5]),
        'bar': Tuple((Box(low=0, high=1, shape=(2, 3)),)),
    }))

data = coax.Policy.example_data(env)
print(data.output)
# {'foo': ({'logits': DeviceArray([[-1.29,  0.34,  1.57,  1.88]], dtype=float32)},
#          {'logits': DeviceArray([[-0.11, -0.35, -0.57,  2.51, 1.78]], dtype=float32)}),
#  'bar': ({'logvar': DeviceArray([[[-0.11,  1.23,  0.12],
#                                   [-0.35,  0.46,  0.73]]], dtype=float32),
#           'mu': DeviceArray([[[-0.35, -0.37, -0.67],
#                               [-0.44, -0.71,  0.45]]], dtype=float32)},)}

Thus, if we ensure that our forward-pass function outputs this format, we can sample actions in precisely the same way as we’ve done before. For example, here’s a compatible forward-pass function:

def func(S, is_training):
    return {
        'foo': ({'logits': hk.Linear(4)(S)},
                {'logits': hk.Linear(5)(S)}),
        'bar': ({'mu': hk.Linear(6)(S).reshape(-1, 2, 3),
                 'logvar': hk.Linear(6)(S).reshape(-1, 2, 3)},),
    }

pi = coax.Policy(func, env)

# example usage:
s = env.observation_space.sample()
a, logp = pi(s, return_logp=True)
assert a in env.action_space

print(logp)  # -8.647176
print(a)
# {'foo': array([2, 4]),
#  'bar': (array([[0.18, 0.57, 0.38],
#                 [0.81, 0.21, 0.67]], dtype=float32),)}

Object Reference

class coax.Policy(func, env, observation_preprocessor=None, proba_dist=None, random_seed=None)[source]

A parametrized policy \(\pi_\theta(a|s)\).

Parameters:
  • func (function) – A Haiku-style function that specifies the forward pass.

  • env (gymnasium.Env) – The gymnasium-style environment. This is used to validate the input/output structure of func.

  • observation_preprocessor (function, optional) – Turns a single observation into a batch of observations in a form that is convenient for feeding into func. If left unspecified, this defaults to default_preprocessor(env.observation_space).

  • proba_dist (ProbaDist, optional) –

    A probability distribution that is used to interpret the output of func. Check out the coax.proba_dists module for available options.

    If left unspecified, this defaults to:

    proba_dist = coax.proba_dists.ProbaDist(action_space)
    

  • random_seed (int, optional) – Seed for pseudo-random number generators.

__call__(s, return_logp=False)[source]

Sample an action \(a\sim\pi_\theta(.|s)\).

Parameters:
  • s (state observation) – A single state observation \(s\).

  • return_logp (bool, optional) – Whether to return the log-propensity \(\log\pi(a|s)\).

Returns:

  • a (action) – A single action \(a\).

  • logp (float, optional) – The log-propensity \(\log\pi_\theta(a|s)\). This is only returned if we set return_logp=True.

copy(deep=False)

Create a copy of the current instance.

Parameters:

deep (bool, optional) – Whether the copy should be a deep copy.

Returns:

copy – A deep copy of the current instance.

dist_params(s)[source]

Get the conditional distribution parameters of \(\pi_\theta(.|s)\).

Parameters:

s (state observation) – A single state observation \(s\).

Returns:

dist_params (Params) – The distribution parameters of \(\pi_\theta(.|s)\).

classmethod example_data(env, observation_preprocessor=None, proba_dist=None, batch_size=1, random_seed=None)[source]

A small utility function that generates example input and output data. These may be useful for writing and debugging your own custom function approximators.

mean(s)[source]

Get the mean of the distribution \(\pi_\theta(.|s)\).

Note that if the actions are discrete, this returns the mode instead.

Parameters:

s (state observation) – A single state observation \(s\).

Returns:

a (action) – A single action \(a\).

mode(s)[source]

Sample a greedy action \(a=\arg\max_a\pi_\theta(a|s)\).

Parameters:

s (state observation) – A single state observation \(s\).

Returns:

a (action) – A single action \(a\).

soft_update(other, tau)

Synchronize the current instance with other through exponential smoothing:

\[\theta\ \leftarrow\ \theta + \tau\, (\theta_\text{new} - \theta)\]
Parameters:
  • other – A separate copy of the current object. This object will hold the new parameters \(\theta_\text{new}\).

  • tau (float between 0 and 1, optional) – If we set \(\tau=1\) we do a hard update. If we pick a smaller value, we do a smooth update.
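A typical use is keeping a slowly-moving target policy in sync with the learned one, e.g. (illustrative value for tau):

pi_targ = pi.copy()  # separate copy that holds the smoothed parameters

# ... then, after each training step:
pi_targ.soft_update(pi, tau=0.005)  # tau=1 would amount to a hard copy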

property function

The function approximator itself, defined as a JIT-compiled pure function. This function may be called directly as:

output, function_state = obj.function(obj.params, obj.function_state, obj.rng, *inputs)
property function_state

The state of the function approximator, see haiku.transform_with_state().

property mean_func

The function that is used for getting the mean of the distribution, defined as a JIT-compiled pure function. This function may be called directly as:

output = obj.mean_func(obj.params, obj.function_state, obj.rng, *inputs)
property mode_func

The function that is used for getting the mode of the distribution, defined as a JIT-compiled pure function. This function may be called directly as:

output = obj.mode_func(obj.params, obj.function_state, obj.rng, *inputs)
property params

The parameters (weights) of the function approximator.

property sample_func

The function that is used for sampling random actions from the underlying proba_dist, defined as a JIT-compiled pure function. This function may be called directly as:

output = obj.sample_func(obj.params, obj.function_state, obj.rng, *inputs)
class coax.EpsilonGreedy(q, epsilon=0.1)[source]

Create an \(\epsilon\)-greedy policy, given a q-function.

This policy samples actions \(a\sim\pi_q(.|s)\) according to the following rule:

\[\begin{split}u &\sim \text{Uniform([0, 1])} \\ a_\text{rand} &\sim \text{Uniform}(\text{actions}) \\ a\ &=\ \left\{\begin{matrix} a_\text{rand} & \text{ if } u < \epsilon \\ \arg\max_{a'} q(s,a') & \text{ otherwise } \end{matrix}\right.\end{split}\]
Parameters:
  • q (Q) – A state-action value function.

  • epsilon (float between 0 and 1, optional) – The probability of sampling an action uniformly at random (as opposed to sampling greedily).
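As a plain-NumPy illustration of this sampling rule (not how coax implements it internally):

import numpy as np

def epsilon_greedy_sample(q_values, epsilon, rng=np.random):
    # q_values: array of q(s, a) for each discrete action a
    if rng.uniform() < epsilon:
        return rng.randint(len(q_values))  # a_rand ~ Uniform(actions)
    return int(np.argmax(q_values))        # argmax_a q(s, a)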

__call__(s, return_logp=False)

Sample an action \(a\sim\pi_q(.|s)\).

Parameters:
  • s (state observation) – A single state observation \(s\).

  • return_logp (bool, optional) – Whether to return the log-propensity \(\log\pi_q(a|s)\).

Returns:

  • a (action) – A single action \(a\).

  • logp (float, optional) – The log-propensity \(\log\pi_q(a|s)\). This is only returned if we set return_logp=True.

dist_params(s)

Get the conditional distribution parameters of \(\pi_q(.|s)\).

Parameters:

s (state observation) – A single state observation \(s\).

Returns:

dist_params (Params) – The distribution parameters of \(\pi_q(.|s)\).

mean(s)

Get the mean of the distribution \(\pi_q(.|s)\).

Note that if the actions are discrete, this returns the mode instead.

Parameters:

s (state observation) – A single state observation \(s\).

Returns:

a (action) – A single action \(a\).

mode(s)

Sample a greedy action \(a=\arg\max_a\pi_q(a|s)\).

Parameters:

s (state observation) – A single state observation \(s\).

Returns:

a (action) – A single action \(a\).

property function

The function approximator itself, defined as a JIT-compiled pure function. This function may be called directly as:

output, function_state = obj.function(obj.params, obj.function_state, obj.rng, *inputs)
property function_state

The state of the function approximator, see haiku.transform_with_state().

property mean_func

The function that is used for getting the mean of the distribution, defined as a JIT-compiled pure function. This function may be called directly as:

output = obj.mean_func(obj.params, obj.function_state, obj.rng, *inputs)
property mode_func

The function that is used for getting the mode of the distribution, defined as a JIT-compiled pure function. This function may be called directly as:

output = obj.mode_func(obj.params, obj.function_state, obj.rng, *inputs)
property params

The parameters (weights) of the function approximator.

property sample_func

The function that is used for sampling random actions from the underlying proba_dist, defined as a JIT-compiled pure function. This function may be called directly as:

output = obj.sample_func(obj.params, obj.function_state, obj.rng, *inputs)
class coax.BoltzmannPolicy(q, temperature=0.02)[source]

Derive a Boltzmann policy from a q-function.

This policy samples actions \(a\sim\pi_q(.|s)\) according to the following rule:

\[\begin{split}p &= \text{softmax}(q(s,.) / \tau) \\ a &\sim \text{Cat}(p)\end{split}\]

Note that this policy is only well-defined for discrete action spaces. Also, it’s worth noting that if the q-function has a non-trivial value transform \(f(.)\) (e.g. coax.value_transforms.LogTransform), we feed in the transformed estimate as our logits, i.e.

\[p = \text{softmax}(f(q(s,.)) / \tau)\]
Parameters:
  • q (Q) – A state-action value function.

  • temperature (positive float, optional) – The Boltzmann temperature \(\tau>0\) sets the sharpness of the categorical distribution. Picking a small value for \(\tau\) results in greedy sampling, while large values result in uniform sampling.
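As a plain-NumPy illustration of this sampling rule (again, not coax’s internal implementation):

import numpy as np

def boltzmann_sample(q_values, temperature, rng=np.random):
    logits = np.asarray(q_values) / temperature  # q(s, .) / tau
    p = np.exp(logits - logits.max())            # softmax (numerically stable)
    p /= p.sum()
    return rng.choice(len(p), p=p)               # a ~ Cat(p)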

__call__(s, return_logp=False)

Sample an action \(a\sim\pi_q(.|s)\).

Parameters:
  • s (state observation) – A single state observation \(s\).

  • return_logp (bool, optional) – Whether to return the log-propensity \(\log\pi_q(a|s)\).

Returns:

  • a (action) – A single action \(a\).

  • logp (float, optional) – The log-propensity \(\log\pi_q(a|s)\). This is only returned if we set return_logp=True.

dist_params(s)

Get the conditional distribution parameters of \(\pi_q(.|s)\).

Parameters:

s (state observation) – A single state observation \(s\).

Returns:

dist_params (Params) – The distribution parameters of \(\pi_q(.|s)\).

mean(s)

Get the mean of the distribution \(\pi_q(.|s)\).

Note that if the actions are discrete, this returns the mode instead.

Parameters:

s (state observation) – A single state observation \(s\).

Returns:

a (action) – A single action \(a\).

mode(s)

Sample a greedy action \(a=\arg\max_a\pi_q(a|s)\).

Parameters:

s (state observation) – A single state observation \(s\).

Returns:

a (action) – A single action \(a\).

property function

The function approximator itself, defined as a JIT-compiled pure function. This function may be called directly as:

output, function_state = obj.function(obj.params, obj.function_state, obj.rng, *inputs)
property function_state

The state of the function approximator, see haiku.transform_with_state().

property mean_func

The function that is used for getting the mean of the distribution, defined as a JIT-compiled pure function. This function may be called directly as:

output = obj.mean_func(obj.params, obj.function_state, obj.rng, *inputs)
property mode_func

The function that is used for getting the mode of the distribution, defined as a JIT-compiled pure function. This function may be called directly as:

output = obj.mode_func(obj.params, obj.function_state, obj.rng, *inputs)
property params

The parameters (weights) of the function approximator.

property sample_func

The function that is used for sampling random actions from the underlying proba_dist, defined as a JIT-compiled pure function. This function may be called directly as:

output = obj.sample_func(obj.params, obj.function_state, obj.rng, *inputs)
class coax.RandomPolicy(env, random_seed=None)[source]

A simple random policy.

Parameters:
  • env (gymnasium.Env) – The gymnasium-style environment. This is only used to get the env.action_space.

  • random_seed (int, optional) – Sets the random state to get reproducible results.

__call__(s, return_logp=False)[source]

Sample an action \(a\sim\pi_\theta(.|s)\).

Parameters:
  • s (state observation) – A single state observation \(s\).

  • return_logp (bool, optional) – Whether to return the log-propensity \(\log\pi(a|s)\).

Returns:

  • a (action) – A single action \(a\).

  • logp (float, optional) – The log-propensity \(\log\pi_\theta(a|s)\). This is only returned if we set return_logp=True.

dist_params(s)[source]

Get the conditional distribution parameters of \(\pi_\theta(.|s)\).

Parameters:

s (state observation) – A single state observation \(s\).

Returns:

dist_params (Params) – The distribution parameters of \(\pi_\theta(.|s)\).

mode(s)[source]

Sample a greedy action \(a=\arg\max_a\pi_\theta(a|s)\).

Parameters:

s (state observation) – A single state observation \(s\).

Returns:

a (action) – A single action \(a\).