# coax

*Plug-n-play Reinforcement Learning in Python with Gymnasium and JAX*

Coax is a modular Reinforcement Learning (RL) Python package for solving Gymnasium (formerly OpenAI Gym) environments with JAX-based function approximators.

## Install

Coax is built on top of JAX, but it doesn’t have an explicit dependence on the `jax` Python package. The reason is that your version of `jaxlib` will depend on your CUDA version.

To install `jax`, please have a look at the instructions: https://github.com/google/jax#installation

Once `jax` and `jaxlib` are installed, you can install **coax** simply by running:

```
$ pip install coax
```

Alternatively, to install **coax** from the latest `main` branch on GitHub:

```
$ pip install git+https://github.com/coax-dev/coax.git@main
```

## Introduction

Here’s a short video that explains some design choices for **coax**.

## RL concepts, not agents

The primary thing that sets coax apart from other packages is that it is designed to align with the core RL concepts, not with the high-level concept of an agent. This makes coax more modular and user-friendly for RL researchers and practitioners.

## You’re in control

Other RL frameworks often hide structure that you (the RL practitioner) are interested in. Most
notably, the **neural network architecture** of the function approximators is often hidden from you.
In coax, the network architecture takes center stage. You are in charge of defining your own
forward-pass function.

Another bit of structure that other RL frameworks hide from you is the main **training loop**. This
makes it hard to take an algorithm from paper to code. The design of coax is agnostic to the
details of your training loop. You decide how and when you update your function approximators.

To illustrate these points, below we include a full working example that trains a simple Q-learning agent in coax.

## Example

We’ll implement a simple q-learning agent on the non-slippery variant of the *FrozenLake*
environment, in which the agent must learn to navigate from the start state **S** to the goal state
**G** without hitting the holes **H** (see the grid below).

| S | F | F | F |
|---|---|---|---|
| F | H | F | H |
| F | F | F | H |
| H | F | F | G |
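To make the grid concrete, here is a minimal pure-Python sketch of the non-slippery dynamics, written for illustration only (it is not Gymnasium’s implementation). It assumes Gymnasium’s FrozenLake action encoding (0 = left, 1 = down, 2 = right, 3 = up) and numbers the 16 states row by row, so **S** is state 0 and **G** is state 15. The names `GRID`, `MOVES` and `step` are hypothetical.

```python
# Hypothetical sketch of non-slippery 4x4 FrozenLake dynamics (not Gymnasium's code).
GRID = ["SFFF", "FHFH", "FFFH", "HFFG"]
N_ROWS, N_COLS = len(GRID), len(GRID[0])

# Assumed action encoding, matching Gymnasium's FrozenLake: 0=left, 1=down, 2=right, 3=up
MOVES = {0: (0, -1), 1: (1, 0), 2: (0, 1), 3: (-1, 0)}


def step(state, action):
    """Return (next_state, reward, terminated) for a deterministic move."""
    row, col = divmod(state, N_COLS)  # states are numbered row by row
    d_row, d_col = MOVES[action]
    # moving off the edge of the grid leaves the agent where it is
    row = min(max(row + d_row, 0), N_ROWS - 1)
    col = min(max(col + d_col, 0), N_COLS - 1)
    next_state = row * N_COLS + col
    tile = GRID[row][col]
    reward = 1.0 if tile == "G" else 0.0  # only reaching the goal is rewarded
    terminated = tile in "HG"             # holes and the goal end the episode
    return next_state, reward, terminated
```

With these dynamics, the sequence down, down, right, right, down, right reaches **G** in six steps, which is the shortest possible path from **S**.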

We start by defining our q-function. In coax, this is done by specifying a forward-pass function:

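```
import gymnasium
import jax.numpy as jnp
import coax
import haiku as hk

# non-slippery FrozenLake, wrapped to log episode statistics
env = gymnasium.make('FrozenLakeNonSlippery-v0')
env = coax.wrappers.TrainMonitor(env)


def func(S, is_training):
    # a single linear layer with one output per discrete action
    values = hk.Linear(env.action_space.n, w_init=jnp.zeros)
    return values(S)  # shape: (batch_size, num_actions)
```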

Note that if the action space is discrete, there are generally two ways of modeling a q-function:

\[
\begin{aligned}
(s,a) &\ \mapsto\ q(s,a)\in\mathbb{R}           &\qquad &(\text{type 1}) \\
s     &\ \mapsto\ q(s,\cdot)\in\mathbb{R}^n     &\qquad &(\text{type 2})
\end{aligned}
\]

where \(n\) is the number of discrete actions. Type-1 q-functions may be defined for any action space, whereas type-2 q-functions are specific to discrete actions. Coax accommodates both types of q-functions. In this example, we’re using a type-2 q-function.
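For discrete actions the two types carry the same information. The sketch below converts one into the other; it uses plain Python lists instead of JAX arrays, and `q_type2`, `type2_to_type1` and `type1_to_type2` are hypothetical helpers, not part of coax.

```python
N_ACTIONS = 4  # FrozenLake has four discrete actions


def q_type2(s):
    """Hypothetical type-2 q-function: returns one value per action."""
    return [0.1 * s + a for a in range(N_ACTIONS)]


def type2_to_type1(q2):
    """Wrap a type-2 q-function so it can be queried as q(s, a)."""
    def q1(s, a):
        return q2(s)[a]
    return q1


def type1_to_type2(q1, n):
    """Evaluate a type-1 q-function on all n discrete actions at once."""
    def q2(s):
        return [q1(s, a) for a in range(n)]
    return q2
```

A type-2 q-function returns all \(n\) values in a single forward pass, which is convenient when an algorithm needs \(\max_{a'} q(s',a')\), as Q-learning does.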

Now that we’ve defined our forward-pass function, we can create a q-function:

```
q = coax.Q(func, env)
# example input
s = env.observation_space.sample()
a = env.action_space.sample()
# example usage
q(s, a) # 0.
q(s) # array([0., 0., 0., 0.])
```

A function approximator \(q_\theta(s,a)\) holds a collection of model parameters (weights), denoted \(\theta\). These parameters are included in the q-function instance as:

```
q.params
# frozendict({
# 'linear': frozendict({
# 'w': DeviceArray(shape=(16, 4), dtype=float32),
# 'b': DeviceArray(shape=(4,), dtype=float32),
# }),
# })
```
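The (16, 4) weight shape is consistent with the 16 discrete FrozenLake states being fed to `func` as one-hot vectors; we assume that here. Under this assumption, the sketch below mimics the forward pass with plain Python lists. The `params` layout mirrors the output above, but `one_hot` and `linear_q` are hypothetical helpers, not coax functions.

```python
N_STATES, N_ACTIONS = 16, 4

# Hypothetical stand-in for q.params with the same layout and shapes:
# w has shape (16, 4) and b has shape (4,), both zero-initialized.
params = {
    "linear": {
        "w": [[0.0] * N_ACTIONS for _ in range(N_STATES)],
        "b": [0.0] * N_ACTIONS,
    }
}


def one_hot(s, n):
    """Encode the discrete state s as a length-n one-hot vector (assumed encoding)."""
    return [1.0 if i == s else 0.0 for i in range(n)]


def linear_q(params, s):
    """Type-2 output q(s, .) = one_hot(s) @ w + b, computed with plain lists."""
    w, b = params["linear"]["w"], params["linear"]["b"]
    x = one_hot(s, len(w))
    return [sum(x[i] * w[i][j] for i in range(len(w))) + b[j] for j in range(len(b))]
```

Because the input is one-hot, the output is simply row \(s\) of `w` plus `b`, so on this environment the linear model behaves like a lookup table of q-values.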

These `q.params` are used internally when we call the function, e.g. `q(s,a)`. The next
step is to create a policy, i.e. a function that maps states to actions. We’ll use a simple
value-based policy:

```
# derive policy from q-function
pi = coax.EpsilonGreedy(q, epsilon=1.0) # we'll scale down epsilon later
# sample action
a = pi(s)
```
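For intuition, here is a minimal sketch of epsilon-greedy action selection over a list of q-values. It is a hypothetical helper, not coax’s `EpsilonGreedy` implementation, and it breaks ties by taking the first maximal action.

```python
import random


def epsilon_greedy(q_values, epsilon, rng):
    """Pick a uniformly random action with probability epsilon, else a greedy one."""
    if rng.random() < epsilon:
        return rng.randrange(len(q_values))
    # max() returns the first maximal index, so ties go to the lowest action
    return max(range(len(q_values)), key=lambda a: q_values[a])
```

With `epsilon=1.0` every action is random, as in the call above. Multiplying \(\epsilon\) by 0.99 once per episode, as the training loop further down does, brings it from 1.0 to \(0.99^{500}\approx 0.0066\) after 500 episodes.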

The action `a` is an integer \(a\in\{0,1,2,3\}\), representing a single action. Now that
we have our policy, we can start doing episode roll-outs:

```
s, info = env.reset()
for t in range(env.spec.max_episode_steps):
a = pi(s)
s_next, r, done, truncated, info = env.step(a)
# this is where we should update our q-function
...
if done:
break
s = s_next
```
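Gymnasium’s `step` returns two separate flags: `done` (called `terminated` in the Gymnasium docs) marks a true terminal state, such as a hole or the goal, while `truncated` marks an episode that was cut off early, for example by a time limit. The hypothetical toy environment below (not part of Gymnasium or coax) shows a roll-out ending either way.

```python
class ToyEnv:
    """Hypothetical env following Gymnasium's step convention:
    (observation, reward, terminated, truncated, info)."""

    def __init__(self, goal, max_steps):
        self.goal, self.max_steps = goal, max_steps

    def reset(self):
        self.pos, self.t = 0, 0
        return self.pos, {}

    def step(self, action):
        self.pos += action  # action 1 moves one cell right, action 0 stays put
        self.t += 1
        terminated = self.pos >= self.goal  # reached the goal: a true terminal state
        truncated = not terminated and self.t >= self.max_steps  # ran out of time
        return self.pos, float(terminated), terminated, truncated, {}


def rollout(env, policy):
    """Run one episode and return its (s, a, r, terminated, truncated) tuples."""
    s, info = env.reset()
    transitions = []
    for t in range(env.max_steps):
        a = policy(s)
        s_next, r, terminated, truncated, info = env.step(a)
        transitions.append((s, a, r, terminated, truncated))
        if terminated or truncated:
            break
        s = s_next
    return transitions
```

A roll-out loop should stop on either flag. Gymnasium keeps them apart because the last state of a truncated episode is not actually terminal.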

Of course, we can’t expect our policy to do very well, because it hasn’t been able to learn anything
from the reward signal `r`. To do that, we need to create two more objects: a **tracer** and
an **updater**. A *tracer* takes raw transition data and turns it into transition data that can then
be used by the *updater* to update the function approximator `q`. In the example below, we see
how this works in practice.

```
from optax import adam
# tracer and updater
tracer = coax.reward_tracing.NStep(n=1, gamma=0.9)
qlearning = coax.td_learning.QLearning(q, optimizer=adam(0.02))
for ep in range(500):
pi.epsilon *= 0.99 # reduce exploration over time
s, info = env.reset()
for t in range(env.spec.max_episode_steps):
a = pi(s)
s_next, r, done, truncated, info = env.step(a)
# trace and update
tracer.add(s, a, r, done or truncated)
while tracer:
transition_batch = tracer.pop()
qlearning.update(transition_batch)
if done or truncated:
break
s = s_next
# [TrainMonitor|INFO] ep: 1, T: 21, G: 0, avg_G: 0, t: 20, dt: 33.436ms
# [TrainMonitor|INFO] ep: 2, T: 42, G: 0, avg_G: 0, t: 20, dt: 2.504ms
# [TrainMonitor|INFO] ep: 3, T: 58, G: 0, avg_G: 0, t: 15, dt: 2.654ms
# [TrainMonitor|INFO] ep: 4, T: 72, G: 0, avg_G: 0, t: 13, dt: 2.670ms
# [TrainMonitor|INFO] ep: 5, T: 83, G: 0, avg_G: 0, t: 10, dt: 2.565ms
# ...
# [TrainMonitor|INFO] ep: 105, T: 1,020, G: 0, avg_G: 0.0868, t: 5, dt: 3.088ms
# [TrainMonitor|INFO] ep: 106, T: 1,023, G: 0, avg_G: 0.0781, t: 2, dt: 3.154ms
# [TrainMonitor|INFO] ep: 107, T: 1,035, G: 1, avg_G: 0.17, t: 11, dt: 3.401ms
# [TrainMonitor|INFO] ep: 108, T: 1,044, G: 0, avg_G: 0.153, t: 8, dt: 2.432ms
# [TrainMonitor|INFO] ep: 109, T: 1,057, G: 1, avg_G: 0.238, t: 12, dt: 2.439ms
# [TrainMonitor|INFO] ep: 110, T: 1,065, G: 1, avg_G: 0.314, t: 7, dt: 2.428ms
# ...
# [TrainMonitor|INFO] ep: 495, T: 4,096, G: 1, avg_G: 1, t: 6, dt: 2.572ms
# [TrainMonitor|INFO] ep: 496, T: 4,103, G: 1, avg_G: 1, t: 6, dt: 2.611ms
# [TrainMonitor|INFO] ep: 497, T: 4,110, G: 1, avg_G: 1, t: 6, dt: 2.601ms
# [TrainMonitor|INFO] ep: 498, T: 4,117, G: 1, avg_G: 1, t: 6, dt: 2.571ms
# [TrainMonitor|INFO] ep: 499, T: 4,124, G: 1, avg_G: 1, t: 6, dt: 2.611ms
```
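To see what the tracer and updater do conceptually, here is a hypothetical, tabular sketch in plain Python; it is not coax’s implementation. `NStepTracer` mirrors the `add`/`pop`/truthiness usage pattern from the loop above and emits `(s, a, G, discount, s_next)` tuples, where \(G=\sum_{k=0}^{m-1}\gamma^k r_{t+k}\) sums up to \(n\) rewards and the discount is \(\gamma^n\), or 0 once the episode has ended. `q_learning_update` then moves \(q(s,a)\) toward the target \(G + \text{discount}\cdot\max_{a'} q(s_\text{next}, a')\) with step size `lr`.

```python
from collections import deque


class NStepTracer:
    """Hypothetical n-step tracer (illustration only, not coax's NStep)."""

    def __init__(self, n, gamma):
        self.n, self.gamma = n, gamma
        self._buffer = deque()  # pending (s, a, r) steps of the current episode
        self._ready = deque()   # finished (s, a, G, discount, s_next) transitions

    def add(self, s, a, r, done):
        # `s` is the state reached by the previous step, so it is the bootstrap
        # state for the oldest pending step once n rewards have accumulated
        if len(self._buffer) == self.n:
            self._emit(s_next=s, discount=self.gamma ** self.n)
        self._buffer.append((s, a, r))
        if done:
            # episode over: flush the remaining steps without bootstrapping
            while self._buffer:
                self._emit(s_next=None, discount=0.0)

    def _emit(self, s_next, discount):
        s0, a0, _ = self._buffer[0]
        G = sum(self.gamma ** k * r for k, (_, _, r) in enumerate(self._buffer))
        self._ready.append((s0, a0, G, discount, s_next))
        self._buffer.popleft()

    def __bool__(self):
        return bool(self._ready)

    def pop(self):
        return self._ready.popleft()


def q_learning_update(Q, transition, lr):
    """One tabular Q-learning step on a list-of-lists table Q[s][a]."""
    s, a, G, discount, s_next = transition
    bootstrap = max(Q[s_next]) if s_next is not None else 0.0
    target = G + discount * bootstrap
    Q[s][a] += lr * (target - Q[s][a])
```

In the coax example, `qlearning.update` plays the role of `q_learning_update` for the function approximator `q`, taking optimizer steps with `adam` instead of editing a table.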

## Getting Started

Have a look at the Getting Started page to train your first RL agent.

If this ain’t your first rodeo, head over to the examples listed here.