Advantage Actor-Critic (A2C)

Advantage Actor-Critic (A2C) is probably the simplest actor-critic method. Instead of using a q-function as its critic, it uses the fact that the advantage function can be interpreted as the expected value of the TD error. To see this, use the definition of the q-function to express the advantage function as:

\[\mathcal{A}(s,a)\ =\ q(s,a) - v(s)\ =\ \mathbb{E}_t \left\{G_t - v(s)\,|\, S_t=s, A_t=a\right\}\]

Then, we replace \(G_t\) by our bootstrapped estimate:

\[G_t\ \approx\ G^{(n)}_t\ =\ R^{(n)}_t + I^{(n)}_t\,v(S_{t+n})\]

where

\[\begin{split}R^{(n)}_t\ =\ \sum_{k=0}^{n-1}\gamma^kR_{t+k}\ , \qquad I^{(n)}_t\ =\ \left\{\begin{matrix} 0 & \text{if $S_{t+n}$ is a terminal state} \\ \gamma^n & \text{otherwise} \end{matrix}\right.\end{split}\]
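
For concreteness, the following is a minimal NumPy sketch (not part of the coax API) that computes the bootstrapped return \(G^{(n)}_t\) and the resulting TD error for a single time step. It assumes we already collected the rewards \(R_t,\dots,R_{t+n-1}\), the value estimates \(v(S_t)\) and \(v(S_{t+n})\), and a flag telling us whether \(S_{t+n}\) is terminal; the function name n_step_td_error is purely illustrative.

import numpy as np


def n_step_td_error(rewards, v_s, v_s_next, done, gamma=0.9):
    """TD error for one time step, using an n-step bootstrapped return.

    rewards:   array of shape (n,) holding R_t, ..., R_{t+n-1}
    v_s:       scalar estimate v(S_t)
    v_s_next:  scalar estimate v(S_{t+n})
    done:      True if S_{t+n} is a terminal state
    """
    n = len(rewards)
    discounts = gamma ** np.arange(n)      # 1, gamma, ..., gamma^(n-1)
    R_n = np.sum(discounts * rewards)      # partial return R_t^(n)
    I_n = 0.0 if done else gamma ** n      # bootstrap factor I_t^(n)
    G_n = R_n + I_n * v_s_next             # bootstrapped return G_t^(n)
    return G_n - v_s                       # TD error


# example: a 5-step return with gamma=0.9
td_err = n_step_td_error(
    rewards=np.array([1., 0., 0., 1., 0.]), v_s=2.3, v_s_next=2.5, done=False)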

The parametrized policy \(\pi_\theta(a|s)\) is updated using the following policy gradients:

\[\begin{split}g(\theta;S_t,A_t)\ &=\ \mathcal{A}(S_t,A_t)\,\nabla_\theta \log\pi_\theta(A_t|S_t) \\ &\approx\ \left(G^{(n)}_t - v(S_t)\right)\, \nabla_\theta \log\pi_\theta(A_t|S_t)\end{split}\]

The prefactor \(G^{(n)}_t - v(S_t)\) is known as the TD error.
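
To make the update rule concrete, here is a hedged JAX sketch of the gradient estimate \(g(\theta;S_t,A_t)\) for a discrete-action policy. It is not the coax implementation; logits_fn is a hypothetical function mapping parameters and a single observation to unnormalized action logits, and the TD error is treated as a constant (no gradient flows through it).

import jax


def policy_gradient(theta, s, a, td_error, logits_fn):
    """Score-function gradient, weighted by the TD error.

    theta:     policy parameters (any pytree)
    s:         a single observation
    a:         the action that was taken (integer index)
    td_error:  the prefactor G_t^(n) - v(S_t), treated as a constant
    logits_fn: hypothetical function mapping (theta, s) -> action logits
    """
    def log_pi(params):
        logits = logits_fn(params, s)
        return jax.nn.log_softmax(logits)[a]       # log pi_theta(a|s)

    grad_log_pi = jax.grad(log_pi)(theta)          # nabla_theta log pi_theta(a|s)
    # scale every leaf of the gradient pytree by the TD error
    return jax.tree_util.tree_map(lambda g: td_error * g, grad_log_pi)

In the example script below, this weighting is what happens when the TD error returned by the value-function update is passed as the advantage estimate to the policy updater.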

For more details, see section 13.5 of Sutton & Barto.


a2c.py

import gymnasium
import coax
import optax
import haiku as hk


# pick environment
env = gymnasium.make(...)
env = coax.wrappers.TrainMonitor(env)


def torso(S, is_training):
    # custom haiku function for the shared preprocessor
    with hk.experimental.name_scope('torso'):
        net = hk.Sequential([...])
    return net(S)


def func_v(S, is_training):
    # custom haiku function
    X = torso(S, is_training)
    with hk.experimental.name_scope('v'):
        value = hk.Sequential([...])
    return value(X)  # output shape: (batch_size,)


def func_pi(S, is_training):
    # custom haiku function (for discrete actions in this example)
    X = torso(S, is_training)
    with hk.experimental.name_scope('pi'):
        logits = hk.Sequential([...])
    return {'logits': logits(X)}  # logits shape: (batch_size, num_actions)


# function approximators
v = coax.V(func_v, env)
pi = coax.Policy(func_pi, env)


# specify how to update policy and value function
vanilla_pg = coax.policy_objectives.VanillaPG(pi, optimizer=optax.adam(0.001))
simple_td = coax.td_learning.SimpleTD(v, optimizer=optax.adam(0.002))


# specify how to trace the transitions
tracer = coax.reward_tracing.NStep(n=5, gamma=0.9)
buffer = coax.experience_replay.SimpleReplayBuffer(capacity=256)


for ep in range(100):
    s, info = env.reset()

    for t in range(env.spec.max_episode_steps):
        a, logp = pi(s, return_logp=True)
        s_next, r, done, truncated, info = env.step(a)

        # add transition to buffer
        # N.B. vanilla-pg doesn't use logp but we include it to make it easy to
        # swap in another policy updater that does require it, e.g. ppo-clip
        tracer.add(s, a, r, done or truncated, logp)
        while tracer:
            buffer.add(tracer.pop())

        # update
        if len(buffer) == buffer.capacity:
            for _ in range(4 * buffer.capacity // 32):  # ~4 passes
                transition_batch = buffer.sample(batch_size=32)
                metrics_v, td_error = simple_td.update(transition_batch, return_td_error=True)
                metrics_pi = vanilla_pg.update(transition_batch, td_error)
                env.record_metrics(metrics_v)
                env.record_metrics(metrics_pi)

                # optional: sync shared parameters (this is not always optimal)
                pi.params, v.params = coax.utils.sync_shared_params(pi.params, v.params)

            buffer.clear()

        if done or truncated:
            break

        s = s_next
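

Once training has finished, the replay buffer and the updaters are no longer needed. The snippet below is a small usage sketch, assuming that pi.mode(s) returns the most probable action under the current policy parameters (coax policies expose this alongside the stochastic call used during training).

# roll out one greedy episode with the trained policy
s, info = env.reset()
for t in range(env.spec.max_episode_steps):
    a = pi.mode(s)                       # most probable action, no sampling
    s, r, done, truncated, info = env.step(a)
    if done or truncated:
        break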