
PPO on Pong

In this example we build an even more sophisticated agent known as Proximal Policy Optimization (PPO), introduced by Schulman et al. (2017).

Animated GIF: the agent beating Atari 2600 Pong after a few hundred episodes of training.

You’ll solve the Pong environment, in which the agent learns to beat its opponent at the famous Atari video game.

This is the first example in which we use a policy-based method (as opposed to a value-based one). It introduces the concepts of a policy objective with its own optimizer (here, PPO-clip) and a policy regularizer (here, an entropy bonus); a sketch of the clipped objective is given below.
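
To make the policy objective concrete, here is an illustrative sketch of the per-transition clipped surrogate objective that PPO maximizes. It is not part of the script below; the names ratio, adv and epsilon are ours, and epsilon=0.2 is merely a typical value. In the script itself, the PPOClip updater performs this computation internally from the stored log-propensities and the advantage estimates.

import jax.numpy as jnp

def clipped_surrogate(ratio, adv, epsilon=0.2):
    # ratio: pi_new(a|s) / pi_old(a|s), reconstructed from the stored log-propensities
    # adv:   advantage estimate of taking action a in state s
    # clipping removes the incentive to push the new policy far away from the old one
    return jnp.minimum(ratio * adv, jnp.clip(ratio, 1 - epsilon, 1 + epsilon) * adv)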

Just as before, you may either hit the Google Colab button or download and run the script on your local machine.


ppo.py

Open in Google Colab
import os

# set some env vars
os.environ.setdefault('JAX_PLATFORM_NAME', 'gpu')     # tell JAX to use GPU
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.1'  # don't use all gpu mem
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'              # tell XLA to be quiet

import gymnasium
import jax
import coax
import haiku as hk
import jax.numpy as jnp
from optax import adam


# the name of this script
name = 'ppo'

# env with preprocessing
env = gymnasium.make('PongNoFrameskip-v4', render_mode='rgb_array')
env = gymnasium.wrappers.AtariPreprocessing(env)
env = coax.wrappers.FrameStacking(env, num_frames=3)
env = gymnasium.wrappers.TimeLimit(env, max_episode_steps=108000 // 3)
env = coax.wrappers.TrainMonitor(env, name=name, tensorboard_dir=f"./data/tensorboard/{name}")


def shared(S, is_training):
    seq = hk.Sequential([
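        # diff_transform maps the stacked frames to finite differences, making motion visible to the convnet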
        coax.utils.diff_transform,
        hk.Conv2D(16, kernel_shape=8, stride=4), jax.nn.relu,
        hk.Conv2D(32, kernel_shape=4, stride=2), jax.nn.relu,
        hk.Flatten(),
    ])
    X = jnp.stack(S, axis=-1) / 255.  # stack frames
    return seq(X)


def func_pi(S, is_training):
    logits = hk.Sequential((
        hk.Linear(256), jax.nn.relu,
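        # zero-initialized output weights give a uniform policy at the start of training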
        hk.Linear(env.action_space.n, w_init=jnp.zeros),
    ))
    X = shared(S, is_training)
    return {'logits': logits(X)}


def func_v(S, is_training):
    value = hk.Sequential((
        hk.Linear(256), jax.nn.relu,
        hk.Linear(1, w_init=jnp.zeros), jnp.ravel
    ))
    X = shared(S, is_training)
    return value(X)


# function approximators
pi = coax.Policy(func_pi, env)
v = coax.V(func_v, env)

# target networks
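# pi_behavior collects the experience and plays the role of the 'old' policy in the PPO
# probability ratio; both copies are soft-updated towards pi and v after each learning round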
pi_behavior = pi.copy()
v_targ = v.copy()

# policy regularizer (avoid premature exploitation)
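# (adds an entropy bonus, scaled by beta, to the policy objective)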
entropy = coax.regularizers.EntropyRegularizer(pi, beta=0.001)

# updaters
simpletd = coax.td_learning.SimpleTD(v, v_targ, optimizer=adam(3e-4))
ppo_clip = coax.policy_objectives.PPOClip(pi, regularizer=entropy, optimizer=adam(3e-4))

# reward tracer and replay buffer
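# the n-step tracer turns raw transitions into 5-step bootstrapped transitions; the small
# buffer acts as a rollout store that is filled and then flushed on every learning round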
tracer = coax.reward_tracing.NStep(n=5, gamma=0.99)
buffer = coax.experience_replay.SimpleReplayBuffer(capacity=256)


# run episodes
while env.T < 3000000:
    s, info = env.reset()

    for t in range(env.spec.max_episode_steps):
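        # act with the behavior policy; the log-propensity is stored for the PPO probability ratio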
        a, logp = pi_behavior(s, return_logp=True)
        s_next, r, done, truncated, info = env.step(a)

        # trace rewards and add transition to replay buffer
        tracer.add(s, a, r, done or truncated, logp)
        while tracer:
            buffer.add(tracer.pop())

        # learn
        if len(buffer) >= buffer.capacity:
            num_batches = int(4 * buffer.capacity / 32)  # 4 epochs per round
            for _ in range(num_batches):
                transition_batch = buffer.sample(32)
                metrics_v, td_error = simpletd.update(transition_batch, return_td_error=True)
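                # the TD error doubles as the advantage estimate for the PPO-clip update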
                metrics_pi = ppo_clip.update(transition_batch, td_error)
                env.record_metrics(metrics_v)
                env.record_metrics(metrics_pi)

            buffer.clear()

            # sync target networks
            pi_behavior.soft_update(pi, tau=0.1)
            v_targ.soft_update(v, tau=0.1)

        if done or truncated:
            break

        s = s_next

    # generate an animated GIF to see what's going on
    if env.period(name='generate_gif', T_period=10000) and env.T > 50000:
        T = env.T - env.T % 10000  # round to 10000s
        coax.utils.generate_gif(
            env=env, policy=pi, resize_to=(320, 420),
            filepath=f"./data/gifs/{name}/T{T:08d}.gif")