Proximal Policy Optimization (PPO)

Consider the following importance-weighted off-policy objective:

\[J(\theta)\ =\ \mathbb{E}_t \left\{ \rho_\theta(S_t, A_t)\,\mathcal{A}(S_t, A_t) \right\} \tag{1}\]

where \(\mathcal{A}(s, a)\) is the advantage function and \(\rho_\theta(s, a)\) is the probability ratio, defined as:

\[\rho_\theta(s, a)\ =\ \frac{\pi_\theta(a|s)}{\pi_{\theta_\text{targ}}(a|s)}\]

The parameters \(\theta_\text{targ}\) are the weights of the behavior policy, i.e. the policy that generated the data, which is to say \(A_t\sim\pi_{\theta_\text{targ}}(\cdot|S_t)\) in Eq. (1).
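
In practice, Eq. (1) is estimated by a sample average over transitions collected with the behavior policy, and the ratio is usually computed from log-probabilities, \(\rho_\theta = \exp(\log\pi_\theta - \log\pi_{\theta_\text{targ}})\), which avoids dividing small numbers. A minimal plain-Python sketch, assuming the log-probabilities of the sampled actions under both policies are already available (the stub further down records the behavior log-probability as logp); the function names and batch values are hypothetical:

```python
import math


def probability_ratio(logp_new, logp_targ):
    """rho = pi_theta(a|s) / pi_targ(a|s), computed in log space."""
    return math.exp(logp_new - logp_targ)


def importance_weighted_objective(logp_new, logp_targ, advantages):
    """Sample-mean estimate of Eq. (1): average of rho * A over a batch."""
    terms = [
        probability_ratio(lp_new, lp_targ) * adv
        for lp_new, lp_targ, adv in zip(logp_new, logp_targ, advantages)
    ]
    return sum(terms) / len(terms)


# Hypothetical batch: log-probs of the sampled actions under both policies.
logp_targ = [math.log(0.50), math.log(0.25), math.log(0.25)]
logp_new = [math.log(0.60), math.log(0.20), math.log(0.25)]
advantages = [1.0, -2.0, 0.5]

# Ratios are 1.2, 0.8 and 1.0, so the estimate is (1.2 - 1.6 + 0.5) / 3.
print(importance_weighted_objective(logp_new, logp_targ, advantages))
```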

Importance sampling and outliers

The use of probability ratios like \(\rho_\theta(s, a)\) is known as importance sampling, which allows us to create unbiased estimates from out-of-distribution (off-policy) samples. A big problem with importance sampling is that the probability ratios are unbounded from above, which often leads to overestimation or underestimation: a large ratio inflates the estimate when \(\mathcal{A}(s,a)>0\) and deflates it when \(\mathcal{A}(s,a)<0\).
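
A small exact calculation illustrates both points. With two hypothetical policies over three actions in a single state (all numbers made up for illustration), reweighting by the ratio recovers the on-policy expectation exactly, yet an action that is rare under the behavior policy receives a ratio of about 30, so a single-sample estimate can land far from the mean:

```python
# Two hypothetical policies over three actions in a single state.
pi_behavior = {"left": 0.70, "stay": 0.29, "right": 0.01}  # collects the data
pi_current = {"left": 0.40, "stay": 0.30, "right": 0.30}   # policy being evaluated
advantage = {"left": -1.0, "stay": 0.0, "right": 2.0}

# Exact on-policy expectation E_{a ~ pi_current}[A(a)].
on_policy = sum(pi_current[a] * advantage[a] for a in advantage)

# Exact importance-weighted expectation E_{a ~ pi_behavior}[rho(a) * A(a)].
ratios = {a: pi_current[a] / pi_behavior[a] for a in advantage}
off_policy = sum(pi_behavior[a] * ratios[a] * advantage[a] for a in advantage)

# Single-sample estimates rho(a) * A(a): unbiased on average, but widely spread.
single_sample = {a: ratios[a] * advantage[a] for a in advantage}

print(on_policy, off_policy)  # equal up to floating-point rounding
print(ratios["right"])        # rare under the behavior policy, so rho is about 30
print(single_sample)          # "right" contributes about 60, versus a mean of about 0.2
```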

Mitigating overestimation and the PPO-clip objective

The Proximal Policy Optimization (PPO) algorithm mitigates the problem of overestimation while leaving underestimation uncorrected. It does so by clipping the probability ratio and taking the smaller of the clipped and unclipped estimates:

\[J(\theta; s,a)\ =\ \min\Big( \rho_\theta(s,a)\,\mathcal{A}(s,a)\,,\ \bar{\rho}_\theta(s,a)\,\mathcal{A}(s,a) \Big) \tag{2}\]

where we introduced the clipped probability ratio:

\[\bar{\rho}_\theta(s,a)\ =\ \text{clip}(\rho_\theta(s,a), 1-\epsilon, 1+\epsilon)\]
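
The per-sample objective of Eq. (2) is short enough to write down directly. A plain-Python sketch; the value \(\epsilon = 0.2\) is an illustrative choice (the text does not fix one), and the function names are hypothetical. The four calls cover the four regimes: positive or negative advantage, combined with a ratio above or below the clipping band.

```python
def clip(x, low, high):
    """Clamp x to the interval [low, high]."""
    return max(low, min(x, high))


def ppo_clip_objective(rho, adv, epsilon=0.2):
    """Per-sample PPO-clip objective J(theta; s, a) of Eq. (2).

    rho:     probability ratio pi_theta(a|s) / pi_targ(a|s)
    adv:     advantage estimate A(s, a)
    epsilon: half-width of the clipping band (0.2 is only illustrative)
    """
    unclipped = rho * adv
    clipped = clip(rho, 1.0 - epsilon, 1.0 + epsilon) * adv
    return min(unclipped, clipped)


print(ppo_clip_objective(1.5, +1.0))  # A > 0, rho too large: overestimate capped at 1.2
print(ppo_clip_objective(0.5, +1.0))  # A > 0, rho too small: underestimate 0.5 kept
print(ppo_clip_objective(1.5, -1.0))  # A < 0, rho too large: underestimate -1.5 kept
print(ppo_clip_objective(0.5, -1.0))  # A < 0, rho too small: overestimate capped at -0.8
```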

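On its own, the clipped estimate \(\bar{\rho}_\theta(s,a)\,\mathcal{A}(s,a)\) would remove both overestimation and underestimation. Taking the minimum of the unclipped and clipped estimates ensures that only overestimation is corrected and underestimation is left alone. One reason to do this is that underestimation is harmless, since it only makes the objective more pessimistic. A more important reason is that it preserves a path towards higher values of the expected advantage: in the underestimation regime the unclipped term is selected, so its gradient does not vanish and still pushes the ratio in the direction that increases \(\rho_\theta(s,a)\,\mathcal{A}(s,a)\). Put differently, not correcting for underestimation keeps the objective concave in the probability ratio \(\rho_\theta(s,a)\).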
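
To see why the minimum matters, it helps to view Eq. (2) as a function of \(\rho\) alone. For \(\mathcal{A}\geq 0\) it reduces to \(\min(\rho\mathcal{A}, (1+\epsilon)\mathcal{A})\), and for \(\mathcal{A}<0\) to \(\min(\rho\mathcal{A}, (1-\epsilon)\mathcal{A})\): in both cases a minimum of two affine functions of \(\rho\), hence concave. The sketch below is plain Python with \(\epsilon = 0.2\) chosen for illustration and finite differences standing in for automatic differentiation. It measures the slope in each regime and spot-checks concavity on a grid, contrasting it with the clipped estimate alone.

```python
def ppo_clip_objective(rho, adv, epsilon=0.2):
    """Per-sample PPO-clip objective of Eq. (2)."""
    rho_clipped = max(1.0 - epsilon, min(rho, 1.0 + epsilon))
    return min(rho * adv, rho_clipped * adv)


def slope_wrt_rho(f, rho, h=1e-6):
    """Central finite-difference estimate of df/d(rho)."""
    return (f(rho + h) - f(rho - h)) / (2 * h)


def is_midpoint_concave(f, grid, tol=1e-12):
    """Spot-check f((x + y) / 2) >= (f(x) + f(y)) / 2 for all pairs on a grid."""
    return all(
        f((x + y) / 2) >= (f(x) + f(y)) / 2 - tol
        for x in grid
        for y in grid
    )


grid = [i / 20 for i in range(61)]  # ratios 0.0, 0.05, ..., 3.0

for adv in (+1.0, -1.0):
    J = lambda rho, adv=adv: ppo_clip_objective(rho, adv)
    # Underestimation regime: unclipped term active, slope equals A (non-zero).
    # Overestimation regime: clipped term active, slope is zero.
    under, over = (0.5, 1.5) if adv > 0 else (1.5, 0.5)
    print(adv, slope_wrt_rho(J, under), slope_wrt_rho(J, over),
          is_midpoint_concave(J, grid))

# For contrast, the clipped estimate on its own is not concave when A > 0.
clipped_only = lambda rho: max(0.8, min(rho, 1.2)) * 1.0
print(is_midpoint_concave(clipped_only, grid))
```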

Off-policy data collection

A very nice property of the clipped surrogate objective is that it allows for slightly more off-policy updates than the vanilla policy gradient. Moreover, it does this in a way that is compatible with our ordinary first-order optimization techniques.

In other words, the PPO-clip objective allows our behavior policy to differ slightly from the current policy that's being updated. This makes it more suitable for parallelization than the standard REINFORCE-style policy objective, which is much more sensitive to off-policy deviations.
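
The stub further down keeps the behavior policy slightly behind the current one by calling pi_behavior.soft_update(pi, tau=0.1) after every round of updates. Assuming this is the usual exponential-smoothing (Polyak) rule \(\theta_\text{targ} \leftarrow \theta_\text{targ} + \tau\,(\theta - \theta_\text{targ})\), the sketch below applies it to a hypothetical flat list of weights. It is an illustration, not coax's own implementation, and it holds \(\theta\) fixed only to make the lag easy to read off.

```python
def soft_update(theta_targ, theta, tau):
    """Exponential smoothing: move each target weight a fraction tau toward theta."""
    return [t + tau * (w - t) for t, w in zip(theta_targ, theta)]


theta = [1.0, -2.0]      # hypothetical current-policy weights (held fixed here)
theta_targ = [0.0, 0.0]  # hypothetical behavior-policy weights
for _ in range(10):
    theta_targ = soft_update(theta_targ, theta, tau=0.1)

print(theta_targ)  # approximately [0.651, -1.303]
```

With \(\tau = 0.1\), each call shrinks the remaining gap by a factor of 0.9, so after 10 calls the behavior weights have covered 1 - 0.9^10, roughly 65%, of the distance to a fixed target.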

Further reading

This stub uses the same advantage actor-critic style setup as in Advantage Actor-Critic (A2C).

For more details on the PPO-clip objective, see the PPO paper. For the coax implementation, have a look at coax.policy_objectives.PPOClip.


ppo.py

import gymnasium
import coax
import optax
import haiku as hk


# pick environment
env = gymnasium.make(...)
env = coax.wrappers.TrainMonitor(env)


def func_v(S, is_training):
    # custom haiku function
    value = hk.Sequential([...])
    return value(S)  # output shape: (batch_size,)


def func_pi(S, is_training):
    # custom haiku function (for discrete actions in this example)
    logits = hk.Sequential([...])
    return {'logits': logits(S)}  # logits shape: (batch_size, num_actions)


# function approximators
v = coax.V(func_v, env)
pi = coax.Policy(func_pi, env)


# behavior policy: a slow-moving average of pi (its weights play the role of theta_targ)
pi_behavior = pi.copy()


# specify how to update policy and value function
ppo_clip = coax.policy_objectives.PPOClip(pi, optimizer=optax.adam(0.001))
simple_td = coax.td_learning.SimpleTD(v, optimizer=optax.adam(0.001))


# specify how to trace the transitions
tracer = coax.reward_tracing.NStep(n=5, gamma=0.9)
buffer = coax.experience_replay.SimpleReplayBuffer(capacity=256)


for ep in range(100):
    s, info = env.reset()

    for t in range(env.spec.max_episode_steps):
        a, logp = pi_behavior(s, return_logp=True)  # logp under the behavior policy, needed for the ratio
        s_next, r, done, truncated, info = env.step(a)

        # add transition to buffer
        tracer.add(s, a, r, done or truncated, logp)
        while tracer:
            buffer.add(tracer.pop())

        # update
        if len(buffer) == buffer.capacity:
            for _ in range(4 * buffer.capacity // 32):  # ~4 passes
                transition_batch = buffer.sample(batch_size=32)
                metrics_v, td_error = simple_td.update(transition_batch, return_td_error=True)
                metrics_pi = ppo_clip.update(transition_batch, td_error)
                env.record_metrics(metrics_v)
                env.record_metrics(metrics_pi)

            buffer.clear()
            pi_behavior.soft_update(pi, tau=0.1)

        if done or truncated:
            break

        s = s_next
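
In the loop above, the TD error returned by simple_td.update is passed to ppo_clip.update as the advantage estimate. Assuming this TD error is the n-step target minus the current value estimate, \(\delta_t = \sum_{k=0}^{n-1}\gamma^k R_{t+k} + \gamma^n v(S_{t+n}) - v(S_t)\), with no bootstrap term once the episode has ended, a hypothetical helper (not part of coax) makes the calculation explicit for the tracer settings used here (n=5, gamma=0.9):

```python
def n_step_td_error(rewards, v_start, v_bootstrap, gamma, terminated=False):
    """n-step TD error, used here as an advantage estimate.

    rewards:     rewards R_t, ..., R_{t+n-1} (fewer if the episode ended early)
    v_start:     value estimate v(S_t)
    v_bootstrap: value estimate v(S_{t+n}); ignored when terminated is True
    gamma:       discount factor
    """
    n = len(rewards)
    target = sum(gamma ** k * r for k, r in enumerate(rewards))
    if not terminated:
        target += gamma ** n * v_bootstrap
    return target - v_start


# Hypothetical numbers with the stub's tracer settings (n=5, gamma=0.9).
delta = n_step_td_error([1.0] * 5, v_start=4.0, v_bootstrap=4.0, gamma=0.9)
print(delta)  # positive: the action did better than v(S_t) predicted
```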