Plug-n-play Reinforcement Learning in Python with OpenAI Gym and JAX

[Figure: CartPole-v0 environment]

Coax is a modular Reinforcement Learning (RL) Python package for solving OpenAI Gym environments with JAX-based function approximators.


Coax is built on top of JAX, but it doesn’t have an explicit dependence on the jax Python package. The reason is that the right version of jaxlib depends on your CUDA version.

To install coax and JAX together, pick the jaxlib build that matches your setup, in particular your CUDA version. Without CUDA, the command is:

      $ pip install --upgrade jaxlib jax coax

Alternatively, you could build jaxlib from source by following the build guide in the JAX documentation.


Here’s a short video that explains some design choices for coax.

RL concepts, not agents

The primary thing that sets coax apart from other packages is that it is designed to align with the core RL concepts, not with the high-level concept of an agent. This makes coax more modular and user-friendly for RL researchers and practitioners.

You’re in control

Other RL frameworks often hide structure that you (the RL practitioner) are interested in. Most notably, the neural network architecture of the function approximators is often hidden from you. In coax, the network architecture takes center stage: you are in charge of defining your own forward-pass function.

Another bit of structure that other RL frameworks hide from you is the main training loop, which makes it hard to take an algorithm from paper to code. The design of coax is agnostic to the details of your training loop. You decide how and when you update your function approximators.

To illustrate these points, we include below a full working example that trains a simple Q-learning agent in coax.


We’ll implement a simple Q-learning agent on the non-slippery variant of the FrozenLake environment. In this environment the agent must learn to navigate from the start state S to the goal state G without falling into any of the holes H. F marks a frozen (walkable) tile. The grid is shown below.

S  F  F  F
F  H  F  H
F  F  F  H
H  F  F  G

We start by defining our q-function. In coax, this is done by specifying a forward-pass function:

import gym
import coax
import haiku as hk
import jax.numpy as jnp

env = gym.make('FrozenLakeNonSlippery-v0')
env = coax.wrappers.TrainMonitor(env)

def func(S, is_training):
    # a single linear layer with one output per action (type-2 q-function)
    values = hk.Linear(env.action_space.n, w_init=jnp.zeros)
    return values(S)  # shape: (batch_size, num_actions)

Note that if the action space is discrete, there are generally two ways of modeling a q-function:

\[\begin{aligned}(s,a) &\ \mapsto\ q(s,a)\in\mathbb{R} &\qquad &(\text{type 1}) \\ s &\ \mapsto\ q(s,\cdot)\in\mathbb{R}^n &\qquad &(\text{type 2})\end{aligned}\]

where \(n\) is the number of discrete actions. Type-1 q-functions may be defined for any action space, whereas type-2 q-functions are specific to discrete actions. Coax accommodates both types of q-functions. In this example, we’re using a type-2 q-function.
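When actions are discrete, both types describe the same values, so each can be derived from the other. The pure-Python sketch below makes that concrete. The helper names `type2_to_type1` and `type1_to_type2` are hypothetical illustrations and are not part of coax, which (as stated above) accepts either type directly.

```python
# Illustration only: how type-1 and type-2 q-functions relate for n discrete actions.
# The helper names below are hypothetical; they are not part of coax.
from typing import Callable, List

Type1 = Callable[[int, int], float]   # (s, a) -> q(s, a)
Type2 = Callable[[int], List[float]]  # s -> [q(s, 0), ..., q(s, n - 1)]


def type2_to_type1(q2: Type2) -> Type1:
    # a type-2 function already returns every action value, so just index into it
    return lambda s, a: q2(s)[a]


def type1_to_type2(q1: Type1, n: int) -> Type2:
    # with discrete actions, evaluate the type-1 function once per action
    return lambda s: [q1(s, a) for a in range(n)]


# toy type-2 q-function with 4 actions; the values are arbitrary, for illustration
def q2(s: int) -> List[float]:
    return [float(10 * s + a) for a in range(4)]


q1 = type2_to_type1(q2)
print(q1(2, 3))                             # 23.0
print(type1_to_type2(q1, n=4)(2) == q2(2))  # True
```

Going from type 1 to type 2 costs n evaluations per state. A type-2 model returns all n values in one forward pass.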

Now that we’ve defined our forward-pass function, we can create a q-function:

q = coax.Q(func, env)

# example input
s = env.observation_space.sample()
a = env.action_space.sample()

# example usage
q(s, a)  # 0.
q(s)     # array([0., 0., 0., 0.])

A function approximator \(q_\theta(s,a)\) holds a collection of model parameters (weights), denoted \(\theta\). These parameters are stored on the q-function instance as q.params:

q.params
# frozendict({
#   'linear': frozendict({
#      'w': DeviceArray(shape=(16, 4), dtype=float32),
#      'b': DeviceArray(shape=(4,), dtype=float32),
#    }),
# })
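The shapes above hint at what a call like q(s) computes. The weight matrix w maps 16 inputs to 4 outputs, which is consistent with the 16 discrete FrozenLake states being one-hot encoded before they reach the linear layer. That encoding is an assumption inferred from the shapes, not something stated here. Under that assumption, the type-2 output is \(q(s,\cdot) = \mathrm{onehot}(s)\,w + b\), i.e. row s of w plus the bias. The pure-Python sketch below mirrors that arithmetic without JAX or haiku. It is an illustration, not coax's internals.

```python
# Sketch of the linear type-2 forward pass, ASSUMING each of the 16 FrozenLake
# states is one-hot encoded before it reaches the linear layer.
from typing import List

NUM_STATES, NUM_ACTIONS = 16, 4


def one_hot(s: int, n: int) -> List[float]:
    return [1.0 if i == s else 0.0 for i in range(n)]


def linear_q(s: int, w: List[List[float]], b: List[float]) -> List[float]:
    """Return q(s, .) = one_hot(s) @ w + b as a list of NUM_ACTIONS values."""
    x = one_hot(s, NUM_STATES)
    # with a one-hot x, this sum reduces to w[s][a] + b[a]
    return [sum(x[i] * w[i][a] for i in range(NUM_STATES)) + b[a]
            for a in range(NUM_ACTIONS)]


# zero-initialized weights (like w_init=jnp.zeros) and zero biases
w = [[0.0] * NUM_ACTIONS for _ in range(NUM_STATES)]
b = [0.0] * NUM_ACTIONS
print(linear_q(5, w, b))  # [0.0, 0.0, 0.0, 0.0]

# with hand-picked values, only row s of w (plus b) affects q(s, .)
w[3][2], b[2] = 1.5, 0.5
print(linear_q(3, w, b))  # [0.0, 0.0, 2.0, 0.0]
print(linear_q(4, w, b))  # [0.0, 0.0, 0.5, 0.0]
```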

These q.params are used internally when we call the function, e.g. q(s,a). The next step is to create a policy, i.e. a function that maps states to actions. We’ll use a simple value-based policy:

# derive policy from q-function
pi = coax.EpsilonGreedy(q, epsilon=1.0)  # we'll scale down epsilon later

# sample action
a = pi(s)
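For intuition, an epsilon-greedy policy picks a uniformly random action with probability epsilon and otherwise the action with the largest q-value. Setting epsilon=1.0 therefore starts out fully exploratory. The plain-Python sketch below implements that rule for a list of type-2 q-values. The function name `epsilon_greedy` is hypothetical, and coax.EpsilonGreedy may differ in details such as tie-breaking.

```python
# Hypothetical helper illustrating the epsilon-greedy rule; not coax.EpsilonGreedy.
import random
from typing import List


def epsilon_greedy(q_values: List[float], epsilon: float, rng: random.Random) -> int:
    """Pick an action index: random with probability epsilon, greedy otherwise."""
    if rng.random() < epsilon:
        return rng.randrange(len(q_values))  # explore: uniformly random action
    # exploit: action with the largest q-value (the first one wins ties)
    return max(range(len(q_values)), key=lambda a: q_values[a])


rng = random.Random(0)
q_values = [0.0, 0.3, 0.1, 0.2]
print(epsilon_greedy(q_values, epsilon=0.0, rng=rng))  # 1
actions = {epsilon_greedy(q_values, epsilon=1.0, rng=rng) for _ in range(1000)}
print(sorted(actions))  # [0, 1, 2, 3]
```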

The action a is an integer \(a\in\{0,1,2,3\}\), representing a single action. Now that we have our policy, we can start doing episode roll-outs:

s = env.reset()

for t in range(env.spec.max_episode_steps):
    a = pi(s)
    s_next, r, done, info = env.step(a)

    # this is where we should update our q-function

    if done:
        break

    s = s_next
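In the non-slippery variant, every env.step call in this loop moves the agent deterministically. The sketch below is a hypothetical pure-Python model of those dynamics, not gym's implementation. It rests on these assumptions:

- the standard 4x4 map;
- gym's action encoding (0: left, 1: down, 2: right, 3: up);
- states numbered row by row from 0 (S) to 15 (G);
- a reward of 1 only for reaching G;
- walking into the edge of the grid leaves the agent in place.

```python
# Hypothetical model of non-slippery FrozenLake dynamics (not gym's source).
# Assumes the standard 4x4 map and gym's action encoding.
from typing import Tuple

MAP = ["SFFF",
       "FHFH",
       "FFFH",
       "HFFG"]
NROW, NCOL = len(MAP), len(MAP[0])
MOVES = {0: (0, -1),   # left
         1: (1, 0),    # down
         2: (0, 1),    # right
         3: (-1, 0)}   # up


def step(s: int, a: int) -> Tuple[int, float, bool]:
    """Return (s_next, r, done) for state index s and action a."""
    row, col = divmod(s, NCOL)
    drow, dcol = MOVES[a]
    # clip to the grid: walking into an edge leaves the agent where it is
    row = min(max(row + drow, 0), NROW - 1)
    col = min(max(col + dcol, 0), NCOL - 1)
    tile = MAP[row][col]
    r = 1.0 if tile == "G" else 0.0
    done = tile in ("G", "H")  # episodes end in a hole or at the goal
    return row * NCOL + col, r, done


# one shortest route from S (state 0) to G (state 15): down, down, right, right, down, right
s, total, done = 0, 0.0, False
for a in [1, 1, 2, 2, 1, 2]:
    s, r, done = step(s, a)
    total += r
print(s, total, done)  # 15 1.0 True
```

S and G are 6 moves apart on the grid, so no route can be shorter than 6 steps.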

Of course, we can’t expect our policy to do very well, because it hasn’t been able to learn anything from the reward signal r. To do that, we need to create two more objects: a tracer and an updater. A tracer takes raw transition data and turns it into transition data that the updater can readily use to update our function approximator. The coax training loop below shows how this works in practice.
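A raw step (s, a, r, done) only becomes a usable transition once the later state needed for bootstrapping is known. The class below is a minimal pure-Python sketch of that bookkeeping for general n. It emits tuples (s, a, G, discount, s_next), where \(G = r_t + \gamma r_{t+1} + \dots + \gamma^{n-1} r_{t+n-1}\). The discount is \(\gamma^n\), or 0 when the episode ended before bootstrapping was possible. This is not coax.reward_tracing.NStep: the class name `NStepSketch`, the tuple layout and the flush-at-episode-end behaviour are assumptions chosen to mirror the add/pop pattern of the coax loop.

```python
# Simplified n-step reward tracer (hypothetical; not coax.reward_tracing.NStep).
from collections import deque
from typing import Any, Deque, Optional, Tuple

# (s, a, n-step return G, bootstrap discount, s_next); s_next is None after a terminal step
Transition = Tuple[Any, int, float, float, Optional[Any]]


class NStepSketch:
    def __init__(self, n: int, gamma: float):
        self.n, self.gamma = n, gamma
        self._buffer: Deque[Tuple[Any, int, float]] = deque()  # pending (s, a, r)
        self._ready: Deque[Transition] = deque()               # finished transitions

    def add(self, s: Any, a: int, r: float, done: bool) -> None:
        # the new state s is the bootstrap state for the step taken n steps earlier
        if len(self._buffer) == self.n:
            self._emit(s_next=s)
        self._buffer.append((s, a, r))
        if done:
            # no bootstrapping past the end of an episode: flush with discount 0
            while self._buffer:
                self._emit(s_next=None)

    def _emit(self, s_next: Optional[Any]) -> None:
        s0, a0, _ = self._buffer[0]
        G = sum(self.gamma ** k * r for k, (_, _, r) in enumerate(self._buffer))
        discount = 0.0 if s_next is None else self.gamma ** len(self._buffer)
        self._ready.append((s0, a0, G, discount, s_next))
        self._buffer.popleft()

    def __bool__(self) -> bool:
        return bool(self._ready)

    def pop(self) -> Transition:
        return self._ready.popleft()


# made-up three-step episode, just to show the bookkeeping
tracer = NStepSketch(n=1, gamma=0.9)
for s, a, r, done in [("s0", 1, 0.0, False), ("s1", 2, 0.0, False), ("s2", 2, 1.0, True)]:
    tracer.add(s, a, r, done)
while tracer:
    s, a, G, discount, s_next = tracer.pop()
    print(s, a, G, discount, s_next)
# s0 1 0.0 0.9 s1
# s1 2 0.0 0.9 s2
# s2 2 1.0 0.0 None
```

An updater such as Q-learning can then move q(s, a) toward G + discount * max_a' q(s_next, a'). The zero discount on the final transition removes the bootstrap term at the end of the episode.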

from optax import adam

# tracer and updater
tracer = coax.reward_tracing.NStep(n=1, gamma=0.9)
qlearning = coax.td_learning.QLearning(q, optimizer=adam(0.02))

for ep in range(500):
    pi.epsilon *= 0.99  # reduce exploration over time
    s = env.reset()

    for t in range(env.spec.max_episode_steps):
        a = pi(s)
        s_next, r, done, info = env.step(a)

        # trace and update
        tracer.add(s, a, r, done)
        while tracer:
            transition_batch = tracer.pop()
            qlearning.update(transition_batch)

        if done:
            break

        s = s_next

# [TrainMonitor|INFO] ep: 1,   T: 21,  G: 0,   avg_G: 0,   t: 20,  dt: 33.436ms
# [TrainMonitor|INFO] ep: 2,   T: 42,  G: 0,   avg_G: 0,   t: 20,  dt: 2.504ms
# [TrainMonitor|INFO] ep: 3,   T: 58,  G: 0,   avg_G: 0,   t: 15,  dt: 2.654ms
# [TrainMonitor|INFO] ep: 4,   T: 72,  G: 0,   avg_G: 0,   t: 13,  dt: 2.670ms
# [TrainMonitor|INFO] ep: 5,   T: 83,  G: 0,   avg_G: 0,   t: 10,  dt: 2.565ms
# ...
# [TrainMonitor|INFO] ep: 105, T: 1,020,   G: 0,   avg_G: 0.0868,  t: 5,   dt: 3.088ms
# [TrainMonitor|INFO] ep: 106, T: 1,023,   G: 0,   avg_G: 0.0781,  t: 2,   dt: 3.154ms
# [TrainMonitor|INFO] ep: 107, T: 1,035,   G: 1,   avg_G: 0.17,    t: 11,  dt: 3.401ms
# [TrainMonitor|INFO] ep: 108, T: 1,044,   G: 0,   avg_G: 0.153,   t: 8,   dt: 2.432ms
# [TrainMonitor|INFO] ep: 109, T: 1,057,   G: 1,   avg_G: 0.238,   t: 12,  dt: 2.439ms
# [TrainMonitor|INFO] ep: 110, T: 1,065,   G: 1,   avg_G: 0.314,   t: 7,   dt: 2.428ms
# ...
# [TrainMonitor|INFO] ep: 495, T: 4,096,   G: 1,   avg_G: 1,   t: 6,   dt: 2.572ms
# [TrainMonitor|INFO] ep: 496, T: 4,103,   G: 1,   avg_G: 1,   t: 6,   dt: 2.611ms
# [TrainMonitor|INFO] ep: 497, T: 4,110,   G: 1,   avg_G: 1,   t: 6,   dt: 2.601ms
# [TrainMonitor|INFO] ep: 498, T: 4,117,   G: 1,   avg_G: 1,   t: 6,   dt: 2.571ms
# [TrainMonitor|INFO] ep: 499, T: 4,124,   G: 1,   avg_G: 1,   t: 6,   dt: 2.611ms
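Non-slippery FrozenLake has only 16 states and deterministic moves, so the recipe above can also be sketched with a lookup table instead of hk.Linear. That recipe is epsilon-greedy exploration, 1-step targets with gamma=0.9, and epsilon multiplied by 0.99 per episode, so it ends at 0.99^500 ≈ 0.0066. The pure-Python code below is such a stand-in, not coax code. It assumes the standard 4x4 map and gym's action encoding, and caps episodes at 100 steps (an assumption). It also swaps the Adam-trained linear model for a plain tabular update q(s,a) += alpha * (target - q(s,a)) with a hypothetical step size alpha=0.5. Its `greedy_rollout` helper reports the return and episode length of the greedy policy, which you can compare with the G and t columns of the log above.

```python
# Tabular Q-learning on a HYPOTHETICAL model of non-slippery FrozenLake (pure Python).
# Mirrors the structure of the coax loop, but uses a lookup table and a fixed
# step size alpha (an assumption) instead of a linear model trained with Adam.
import random
from typing import List, Tuple

MAP = ["SFFF", "FHFH", "FFFH", "HFFG"]
NUM_STATES, NUM_ACTIONS = 16, 4
MOVES = [(0, -1), (1, 0), (0, 1), (-1, 0)]  # assumed gym encoding: left, down, right, up


def step(s: int, a: int) -> Tuple[int, float, bool]:
    row, col = divmod(s, 4)
    row = min(max(row + MOVES[a][0], 0), 3)
    col = min(max(col + MOVES[a][1], 0), 3)
    tile = MAP[row][col]
    return row * 4 + col, float(tile == "G"), tile in ("G", "H")


def greedy(q_row: List[float]) -> int:
    return max(range(NUM_ACTIONS), key=lambda a: q_row[a])


def train(num_episodes=500, gamma=0.9, alpha=0.5, max_steps=100, seed=0):
    rng = random.Random(seed)
    q = [[0.0] * NUM_ACTIONS for _ in range(NUM_STATES)]
    epsilon = 1.0
    for _ in range(num_episodes):
        epsilon *= 0.99  # reduce exploration over time, as in the coax loop
        s = 0
        for _ in range(max_steps):
            a = rng.randrange(NUM_ACTIONS) if rng.random() < epsilon else greedy(q[s])
            s_next, r, done = step(s, a)
            # 1-step Q-learning target; no bootstrap term after a terminal step
            target = r + (0.0 if done else gamma * max(q[s_next]))
            q[s][a] += alpha * (target - q[s][a])
            if done:
                break
            s = s_next
    return q


def greedy_rollout(q: List[List[float]], max_steps: int = 100) -> Tuple[float, int]:
    """Follow the greedy policy from S; return (return G, episode length t)."""
    s, G = 0, 0.0
    for t in range(1, max_steps + 1):
        s, r, done = step(s, greedy(q[s]))
        G += r
        if done:
            return G, t
    return G, max_steps


q = train()
print(greedy_rollout(q))
```

Because the rewards are non-negative and the table starts at zero, the learned values never exceed the optimal ones. For example, no entry of q[0] can exceed 0.9^5, the discounted value of a 6-step route to G.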

Getting Started

Have a look at the Getting Started page to train your first RL agent.

If this ain’t your first rodeo, head over to the examples listed here.