CartPole with a Model-Based Agent

In this notebook we solve the CartPole-v0 environment using a model-based agent. The agent uses function approximators for both a state-value function \(v(s)\) and a dynamics model \(p(s'|s,a)\). Since the CartPole observation covers the full phase space of the dynamics, this agent is able to learn the task within the first episode.

The dynamics model is used in a rather simple way: we only use it to define a single-step look-ahead q-function:

\[q(s,a)\ =\ r(s,a) + \gamma\,\mathop{\mathbb{E}}_{s'\sim p_\theta(\cdot|s,a)} v_\theta(s')\]

where \(\gamma\) is the discount factor.

This composite q-function is implemented by coax.SuccessorStateQ. Note that the reward function for the CartPole environment is simply \(r(s,a)=1\) at each time step, so we don’t need to learn it; it is hard-coded instead.

If training is successful, the result should look like this:

CartPole environment solved.

import coax
import gymnasium
import jax.numpy as jnp
import haiku as hk
import optax
from coax.value_losses import mse

# the name of this script; reused in the TensorBoard log directory and the output GIF path
name = 'model_based'

# the cart-pole MDP, wrapped in a TrainMonitor that tracks episode statistics
# (env.T, env.ep, env.avg_G are read by the training loop) and logs to TensorBoard;
# 'rgb_array' rendering is presumably what the GIF generation at the end relies on
env = coax.wrappers.TrainMonitor(
    gymnasium.make('CartPole-v0', render_mode='rgb_array'),
    name=name,
    tensorboard_dir=f"./data/tensorboard/{name}",
)

def func_v(S, is_training):
    """State-value function v(s): a learned potential minus a fixed kinetic term.

    The kinetic term is the squared value of column 3 of ``S`` (angular velocity).
    The potential squares columns 0-2 and feeds them through a single zero-initialized
    linear unit, flattened to one value per row.
    """
    kinetic = jnp.square(S[:, 3])  # kinetic term is angular velocity squared
    potential_net = hk.Sequential((jnp.square, hk.Linear(1, w_init=jnp.zeros), jnp.ravel))
    return potential_net(S[:, :3]) - kinetic

def func_p(S, A, is_training):
    """Dynamics model: predict s' as the current state plus an action-dependent increment.

    The increment is a linear function of ``A`` only (presumably the one-hot encoded
    action); it does not depend on ``S``. Its output width of 4 must match the state
    width so it can be added to ``S``. The weights start at zero, and Haiku's bias
    defaults to zero, so the initial model predicts s' = s.
    """
    increment_layer = hk.Linear(4, w_init=jnp.zeros)
    increment = increment_layer(A)
    return S + increment

def func_r(S, A, is_training):
    """Reward function: a constant reward of 1 for every row of the batch (nothing is learned)."""
    batch_size = S.shape[0]
    return jnp.ones((batch_size,))  # CartPole yields r=1 at every time step

# function approximators
# The dynamics model p is built first; v and r reuse its observation preprocessor,
# so all three receive identically preprocessed states.
p = coax.TransitionModel(func_p, env)
v = coax.V(func_v, env, observation_preprocessor=p.observation_preprocessor)
r = coax.RewardFunction(func_r, env, observation_preprocessor=p.observation_preprocessor)

# composite objects
# q(s,a) = r(s,a) + gamma * E_{s'~p(.|s,a)} v(s'): a one-step look-ahead assembled from v, p and r
q = coax.SuccessorStateQ(v, p, r, gamma=0.9)
pi = coax.EpsilonGreedy(q, epsilon=0.)  # no exploration: always act greedily w.r.t. q

# reward tracer
# 1-step tracer; shares q's discount factor so the TD targets agree with the look-ahead
tracer = coax.reward_tracing.NStep(n=1, gamma=q.gamma)

# updaters
# value function: Adam, with gradients accumulated and applied once every 16 update calls
adam = optax.chain(optax.apply_every(k=16), optax.adam(1e-4))
simple_td = coax.td_learning.SimpleTD(v, loss_function=mse, optimizer=adam)

# dynamics model: plain SGD with Nesterov momentum
sgd = optax.sgd(1e-3, momentum=0.9, nesterov=True)
model_updater = coax.model_updaters.ModelUpdater(p, optimizer=sgd)

# Train until 100k environment steps have elapsed, or until the early-stopping criterion is met.
while env.T < 100000:
    s, info = env.reset()

    for _ in range(env.spec.max_episode_steps):
        a = pi(s)
        # named `reward` (not `r`) so the module-level RewardFunction `r` is not clobbered
        s_next, reward, done, truncated, info = env.step(a)

        # feed the transition to the tracer; truncation also ends the episode for tracing
        tracer.add(s, a, reward, done or truncated)

        # learn from every completed transition: TD update for v, supervised update for p
        while tracer:
            transition_batch = tracer.pop()
            env.record_metrics(simple_td.update(transition_batch))
            env.record_metrics(model_updater.update(transition_batch))

        if done or truncated:
            break

        s = s_next

    # early stopping: after at least 5 episodes, stop once the average return beats the threshold
    if env.ep >= 5 and env.avg_G > env.spec.reward_threshold:
        break

# run env one more time to render: roll out the trained policy and save the episode as a GIF
gif_filepath = f"./data/{name}.gif"
coax.utils.generate_gif(env, policy=pi, filepath=gif_filepath, duration=25)