Twin-Delayed DDPG (TD3)
The TD3 algorithm is a variant of DDPG that replaces the ordinary q-learning updates with clipped double q-learning updates, in which the \(n\)-step bootstrapped target is constructed as:
\[G^{(n)}_t\ =\ R^{(n)}_t + I^{(n)}_t\,\min_{i,j}q_i(S_{t+n}, \arg\max_a q_j(S_{t+n}, a))\]
where \(R^{(n)}_t=\sum_{k=0}^{n-1}\gamma^kR_{t+k}\) is the \(n\)-step discounted partial return and \(I^{(n)}_t=\gamma^n\) is the bootstrap factor, which is set to zero whenever the episode terminates within the \(n\)-step window. The rest of the agent is essentially the same as that of DDPG.
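To make the target concrete, here is a minimal sketch (not part of the stub below) of how the single-step (\(n=1\)) target could be computed for one transition. It assumes the target networks pi_targ, q1_targ and q2_targ constructed further down; the helper name td3_target is hypothetical.
def td3_target(r, done, s_next, gamma=0.9):
    # the greedy action of the target policy stands in for the argmax over actions
    a_next = pi_targ.mode(s_next)
    # clipped double q-learning: bootstrap from the more pessimistic q-value
    q_min = min(q1_targ(s_next, a_next), q2_targ(s_next, a_next))
    return r + (1.0 - float(done)) * gamma * q_min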
import gymnasium
import coax
import jax
import jax.numpy as jnp
import haiku as hk
import optax
# pick environment
env = gymnasium.make(...)
env = coax.wrappers.TrainMonitor(env)  # records metrics and keeps a global step counter env.T
def func_pi(S, is_training):
    # custom haiku function (for continuous actions in this example)
    mu = hk.Sequential([...])(S)  # mu.shape: (batch_size, *action_space.shape)
    return {'mu': mu, 'logvar': jnp.full_like(mu, -10)}  # (almost) deterministic policy
def func_q(S, A, is_training):
    # custom haiku function; the q-function must see both the state and the action
    value = hk.Sequential([...])
    X = jnp.concatenate((S, A), axis=-1)  # combine state and action inputs
    return value(X)  # output shape: (batch_size,)
# define function approximators
pi = coax.Policy(func_pi, env)
q1 = coax.Q(func_q, env, action_preprocessor=pi.proba_dist.preprocess_variate)
q2 = coax.Q(func_q, env, action_preprocessor=pi.proba_dist.preprocess_variate)
# target networks
pi_targ = pi.copy()
q1_targ = q1.copy()
q2_targ = q2.copy()
# specify how to update the policy and the q-functions
determ_pg = coax.policy_objectives.DeterministicPG(pi, q1, optimizer=optax.adam(0.001))
qlearning1 = coax.td_learning.ClippedDoubleQLearning(
    q1, pi_targ_list=[pi_targ], q_targ_list=[q1_targ, q2_targ],
    optimizer=optax.adam(0.001))
qlearning2 = coax.td_learning.ClippedDoubleQLearning(
    q2, pi_targ_list=[pi_targ], q_targ_list=[q1_targ, q2_targ],
    optimizer=optax.adam(0.001))
# specify how to trace the transitions
tracer = coax.reward_tracing.NStep(n=1, gamma=0.9)
buffer = coax.experience_replay.SimpleReplayBuffer(capacity=1000000)
# action noise
noise = coax.utils.OrnsteinUhlenbeckNoise(mu=0., sigma=0.2, theta=0.15)
for ep in range(100):
    s, info = env.reset()
    noise.reset()
    noise.sigma *= 0.99  # slowly decrease noise scale

    for t in range(env.spec.max_episode_steps):
        a = noise(pi(s))  # add exploration noise to the policy action
        s_next, r, done, truncated, info = env.step(a)

        # add transition to buffer
        tracer.add(s, a, r, done or truncated)  # flush the tracer at episode boundaries
        while tracer:
            buffer.add(tracer.pop())

        # update
        if len(buffer) >= 128:
            transition_batch = buffer.sample(batch_size=32)

            # flip a coin to decide which of the q-functions to update
            qlearning = qlearning1 if jax.random.bernoulli(q1.rng) else qlearning2
            metrics = qlearning.update(transition_batch)
            env.record_metrics(metrics)

            # delay policy updates
            if env.T % 2 == 0:
                metrics = determ_pg.update(transition_batch)
                env.record_metrics(metrics)

            # sync target networks
            pi_targ.soft_update(pi, tau=0.01)
            q1_targ.soft_update(q1, tau=0.01)
            q2_targ.soft_update(q2, tau=0.01)

        if done or truncated:
            break

        s = s_next
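After training, the learned policy can be rolled out greedily by taking the mode of the policy distribution instead of sampling with exploration noise; a minimal sketch:
s, info = env.reset()
for t in range(env.spec.max_episode_steps):
    a = pi.mode(s)  # greedy action, no exploration noise
    s, r, done, truncated, info = env.step(a)
    if done or truncated:
        break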