Advantage Actor-Critic (A2C) is arguably the simplest actor-critic method. Instead of using a q-function as its critic, it uses the fact that the advantage function can be interpreted as the expectation value of the TD error. To see this, use the definition of the q-function to express the advantage function as:

$\mathcal{A}(s,a)\ =\ q(s,a) - v(s)\ =\ \mathbb{E}_t \left\{G_t - v(s)\,|\, S_t=s, A_t=a\right\}$
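Since $$v(s)$$ is the policy-weighted average of $$q(s,\cdot)$$, the advantages of the individual actions always average out to zero under the policy. A tiny numeric check of this identity, using made-up tabular values (none of these numbers come from coax):

```python
import numpy as np

# hypothetical q-values for one state with three actions
q = np.array([1.0, 2.0, 4.0])
# made-up policy probabilities over those actions
pi = np.array([0.2, 0.3, 0.5])

v = np.dot(pi, q)      # state value = policy-weighted q-value
advantage = q - v      # advantage of each action

# the policy-weighted advantage is zero by construction
assert abs(np.dot(pi, advantage)) < 1e-9
```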

Then, we replace $$G_t$$ by our bootstrapped estimate:

$G_t\ \approx\ G^{(n)}_t\ =\ R^{(n)}_t + I^{(n)}_t\,v(S_{t+n})$

where

$\begin{split}R^{(n)}_t\ =\ \sum_{k=0}^{n-1}\gamma^kR_{t+k}\ , \qquad I^{(n)}_t\ =\ \left\{\begin{matrix} 0 & \text{if } S_{t+n} \text{ is a terminal state} \\ \gamma^n & \text{otherwise} \end{matrix}\right.\end{split}$
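The bootstrapped n-step target can be written out in a few lines of plain Python. This is an illustrative sketch, not coax's implementation (coax computes these targets inside its reward tracers); the rewards and value below are made up:

```python
def n_step_target(rewards, v_next, done, gamma=0.9, n=5):
    """Compute G_t^(n) = R_t^(n) + I_t^(n) * v(S_{t+n}).

    `rewards` holds R_t, ..., R_{t+n-1}; `v_next` is v(S_{t+n});
    `done` says whether S_{t+n} is terminal, in which case I_t^(n) = 0.
    """
    rewards = rewards[:n]
    Rn = sum(gamma ** k * r for k, r in enumerate(rewards))   # discounted partial return
    In = 0.0 if done else gamma ** len(rewards)               # bootstrap factor
    return Rn + In * v_next

# made-up transition data: R_t^(3) = 1 + 0 + 0.81*2 = 2.62, plus 0.9^3 * 10
print(n_step_target([1.0, 0.0, 2.0], v_next=10.0, done=False, gamma=0.9, n=3))
```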

The parametrized policy $$\pi_\theta(a|s)$$ is updated using the following policy gradients:

$\begin{split}g(\theta;S_t,A_t)\ &=\ \mathcal{A}(S_t,A_t)\,\nabla_\theta \log\pi_\theta(A_t|S_t) \\ &\approx\ \left(G^{(n)}_t - v(S_t)\right)\, \nabla_\theta \log\pi_\theta(A_t|S_t)\end{split}$

The prefactor $$G^{(n)}_t - v(S_t)$$ is known as the TD error.
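The gradient estimate above can be sketched directly in JAX. This is a minimal illustration assuming a linear softmax policy; the feature vector, parameters and TD error are made up, and none of the names below are part of coax:

```python
import jax
import jax.numpy as jnp

def logpi(theta, s, a):
    # log pi_theta(a|s) for an illustrative linear softmax policy
    logits = s @ theta                      # shape: (num_actions,)
    return jax.nn.log_softmax(logits)[a]

theta = jnp.zeros((4, 2))                   # 4 state features, 2 actions
s = jnp.array([1.0, 0.0, -1.0, 0.5])        # made-up state features
a = 1                                       # sampled action
td_error = 0.7                              # stands in for G_t^(n) - v(S_t)

# g(theta; S_t, A_t) = (G_t^(n) - v(S_t)) * grad log pi_theta(A_t|S_t)
g = td_error * jax.grad(logpi)(theta, s, a)
print(g.shape)  # same shape as theta
```

Note that the TD error only scales the score function $$\nabla_\theta \log\pi_\theta(A_t|S_t)$$; actions that did better than expected get their log-probability pushed up, and vice versa.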

For more details, see section 13.5 of Sutton & Barto.

a2c.py

import gymnasium
import coax
import optax
import haiku as hk

# pick environment
env = gymnasium.make(...)
env = coax.wrappers.TrainMonitor(env)

def torso(S, is_training):
    # custom haiku function for the shared preprocessor
    with hk.experimental.name_scope('torso'):
        net = hk.Sequential([...])
    return net(S)

def func_v(S, is_training):
    # custom haiku function
    X = torso(S, is_training)
    with hk.experimental.name_scope('v'):
        value = hk.Sequential([...])
    return value(X)  # output shape: (batch_size,)

def func_pi(S, is_training):
    # custom haiku function (for discrete actions in this example)
    X = torso(S, is_training)
    with hk.experimental.name_scope('pi'):
        logits = hk.Sequential([...])
    return {'logits': logits(X)}  # logits shape: (batch_size, num_actions)

# function approximators
v = coax.V(func_v, env)
pi = coax.Policy(func_pi, env)

# specify how to update policy and value function
# (learning rates are illustrative)
vanilla_pg = coax.policy_objectives.VanillaPG(pi, optimizer=optax.adam(0.001))
simple_td = coax.td_learning.SimpleTD(v, optimizer=optax.adam(0.002))

# specify how to trace the transitions
tracer = coax.reward_tracing.NStep(n=5, gamma=0.9)
buffer = coax.experience_replay.SimpleReplayBuffer(capacity=256)

for ep in range(100):
    s, info = env.reset()

    for t in range(env.spec.max_episode_steps):
        a, logp = pi(s, return_logp=True)
        s_next, r, done, truncated, info = env.step(a)

        # N.B. vanilla-pg doesn't use logp but we include it to make it easy to
        # swap in another policy updater that does require it, e.g. ppo-clip
        tracer.add(s, a, r, done or truncated, logp)
        while tracer:
            buffer.add(tracer.pop())

        # update
        if len(buffer) == buffer.capacity:
            for _ in range(4 * buffer.capacity // 32):  # ~4 passes
                transition_batch = buffer.sample(batch_size=32)
                metrics_v, td_error = simple_td.update(transition_batch, return_td_error=True)
                metrics_pi = vanilla_pg.update(transition_batch, td_error)
                env.record_metrics(metrics_v)
                env.record_metrics(metrics_pi)

            # optional: sync shared parameters (this is not always optimal)
            pi.params, v.params = coax.utils.sync_shared_params(pi.params, v.params)

            buffer.clear()

        if done or truncated:
            break

        s = s_next