Deep Deterministic Policy Gradients (DDPG)

The Deep Deterministic Policy Gradients (DDPG) algorithm differs from other policy objectives in that it learns a policy directly from a (type-I) q-function. The policy objective is:

\[J(\theta; s,a)\ =\ q_{\varphi_\text{targ}}(s, a_\theta(s))\]

Here \(a_\theta(s)\) is the mode of the underlying conditional probability distribution \(\pi_\theta(.|s)\). See e.g. the mode method of coax.proba_dists.NormalDist. In other words, we evaluate the policy according to the current estimate of its best-case performance. This is implemented by the coax.policy_objectives.DeterministicPG updater class.
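
To make the objective concrete, here is a minimal pure-Python sketch of gradient ascent on \(J\) for a toy one-dimensional problem. It is not the coax implementation: the linear policy, the quadratic critic and all function names are assumptions chosen so that the gradient can be written by hand using the chain rule \(\partial J/\partial\theta = (\partial q/\partial a)(\partial a_\theta/\partial\theta)\).

```python
# Toy 1-D illustration; every function below is a hypothetical stand-in.

def a_theta(theta, s):
    """Deterministic action: the mode of a Gaussian policy with mean theta * s."""
    return theta * s

def q(s, a):
    """Hypothetical critic: the value is highest when a == 2 * s."""
    return -(a - 2.0 * s) ** 2

def dq_da(s, a):
    """Analytic derivative of the toy critic with respect to the action."""
    return -2.0 * (a - 2.0 * s)

def objective(theta, states):
    """J(theta), averaged over a batch of states."""
    return sum(q(s, a_theta(theta, s)) for s in states) / len(states)

def objective_grad(theta, states):
    """Chain rule: dJ/dtheta = dq/da * da/dtheta, where da/dtheta = s."""
    return sum(dq_da(s, a_theta(theta, s)) * s for s in states) / len(states)

states = [-1.0, 0.5, 2.0]
theta = 0.0
for _ in range(200):
    theta += 0.05 * objective_grad(theta, states)  # gradient *ascent* on J
```

In this toy setup the policy parameter moves towards the value that the critic prefers (here, theta close to 2). In practice the gradient would be obtained by automatic differentiation through both networks rather than by hand.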

Since the policy objective uses a q-function \(q_\varphi(s,a)\), we also need to learn that. At the moment of writing, there are two ways to learn \(q_\varphi(s,a)\) in coax.

Option 1: SARSA.

The first option is to use SARSA updates, whose \(n\)-step bootstrapped target is constructed as:

\[G^{(n)}_t\ =\ R^{(n)}_t + I^{(n)}_t\,q_{\varphi_\text{targ}}\!(S_{t+n}, A_{t+n})\]

where \(A_{t+n}\) is sampled from experience and

\[\begin{split}R^{(n)}_t\ &=\ \sum_{k=0}^{n-1}\gamma^kR_{t+k} \\ I^{(n)}_t\ &=\ \left\{\begin{matrix} 0 & \text{if $S_{t+n}$ is a terminal state} \\ \gamma^n & \text{otherwise} \end{matrix}\right.\end{split}\]

This is implemented by the coax.td_learning.Sarsa updater class.
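
The target above can be spelled out in a few lines of plain Python. This is only an illustrative sketch: the function names are made up for this example, and `q_next` stands for the value \(q_{\varphi_\text{targ}}(S_{t+n}, A_{t+n})\), which is assumed to have been computed elsewhere.

```python
def n_step_return(rewards, gamma):
    """R^(n)_t: discounted sum of the n rewards R_t, ..., R_{t+n-1}."""
    return sum(gamma ** k * r for k, r in enumerate(rewards))

def bootstrap_factor(n, gamma, terminal):
    """I^(n)_t: zero if S_{t+n} is terminal, gamma^n otherwise."""
    return 0.0 if terminal else gamma ** n

def sarsa_target(rewards, gamma, q_next, terminal):
    """G^(n)_t = R^(n)_t + I^(n)_t * q_targ(S_{t+n}, A_{t+n})."""
    n = len(rewards)
    return n_step_return(rewards, gamma) + bootstrap_factor(n, gamma, terminal) * q_next

# 3-step example with gamma = 0.5
g_bootstrapped = sarsa_target([1.0, 0.0, 2.0], gamma=0.5, q_next=4.0, terminal=False)
g_terminal = sarsa_target([1.0, 0.0, 2.0], gamma=0.5, q_next=4.0, terminal=True)
```

Note that the bootstrap term is dropped entirely when \(S_{t+n}\) is terminal, so only the discounted rewards remain.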

Option 2: Q-Learning.

The second option is to use q-learning updates, whose \(n\)-step bootstrapped target is instead constructed as:

\[G^{(n)}_t\ =\ R^{(n)}_t + I^{(n)}_t\,q_{\varphi_\text{targ}}\!\left( S_{t+n}, a_{\theta_\text{targ}}\!(S_{t+n})\right)\]

Here, \(a_{\theta_\text{targ}}\!(S_{t+n})\) is the mode introduced above, evaluated at the next state \(S_{t+n}\) using the target-model weights \(\theta_\text{targ}\). We call this q-learning because the TD-target is constructed as though the next action \(A_{t+n}\) were the greedy action. This is implemented by the coax.td_learning.QLearningMode updater class.
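
The only difference from the SARSA target is where the next action comes from: it is computed by the target policy at the next state instead of being read from experience. A minimal sketch, assuming hypothetical stand-ins `q_targ` and `a_targ` for the target q-function and the target policy's mode (neither is a coax object):

```python
def qlearning_target(rewards, gamma, s_next, terminal, q_targ, a_targ):
    """n-step target that bootstraps from the target policy's own action."""
    n = len(rewards)
    n_step_reward = sum(gamma ** k * r for k, r in enumerate(rewards))
    if terminal:
        return n_step_reward  # no bootstrapping from a terminal state
    a_next = a_targ(s_next)   # mode of the target policy, not the logged action
    return n_step_reward + gamma ** n * q_targ(s_next, a_next)

# hypothetical target models for a 1-D state and action
def a_targ(s):
    return 2.0 * s

def q_targ(s, a):
    return s + a

g = qlearning_target([1.0], gamma=0.5, s_next=1.0, terminal=False,
                     q_targ=q_targ, a_targ=a_targ)
g_done = qlearning_target([1.0], gamma=0.5, s_next=1.0, terminal=True,
                          q_targ=q_targ, a_targ=a_targ)
```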

For more details, see the Spinning Up page on DDPG, which includes links to the original papers.

import gymnasium
import coax
import optax
import haiku as hk
import jax.numpy as jnp

# pick environment
env = gymnasium.make(...)
env = coax.wrappers.TrainMonitor(env)

def func_pi(S, is_training):
    # custom haiku function (for continuous actions in this example)
    mu = hk.Sequential([...])(S)  # mu.shape: (batch_size, *action_space.shape)
    return {'mu': mu, 'logvar': jnp.full_like(mu, -10)}  # deterministic policy

def func_q(S, A, is_training):
    # custom haiku function
    value = hk.Sequential([...])
    X = jnp.concatenate((S, A), axis=-1)  # q must depend on the action too (assumes flat S and A)
    return value(X)  # output shape: (batch_size,)

# define function approximator
pi = coax.Policy(func_pi, env)
q = coax.Q(func_q, env, action_preprocessor=pi.proba_dist.preprocess_variate)

# target networks
pi_targ = pi.copy()
q_targ = q.copy()

# specify how to update policy and value function
determ_pg = coax.policy_objectives.DeterministicPG(pi, q, optimizer=optax.adam(0.001))
qlearning = coax.td_learning.QLearning(q, pi_targ, q_targ, optimizer=optax.adam(0.002))

# specify how to trace the transitions
tracer = coax.reward_tracing.NStep(n=1, gamma=0.9)
buffer = coax.experience_replay.SimpleReplayBuffer(capacity=1000000)

# action noise
noise = coax.utils.OrnsteinUhlenbeckNoise(mu=0., sigma=0.2, theta=0.15)

for ep in range(100):
    s, info = env.reset()
    noise.sigma *= 0.99  # slowly decrease noise scale

    for t in range(env.spec.max_episode_steps):
        a = noise(pi(s))
        s_next, r, done, truncated, info = env.step(a)

        # add transition to buffer
        tracer.add(s, a, r, done)
        while tracer:
            buffer.add(tracer.pop())

        # update
        if len(buffer) >= 32:  # wait until the buffer can fill one batch
            transition_batch = buffer.sample(batch_size=32)
            metrics_q = qlearning.update(transition_batch)
            metrics_pi = determ_pg.update(transition_batch)

        # periodically sync target models
        if ep % 10 == 0:
            pi_targ.soft_update(pi, tau=1.0)
            q_targ.soft_update(q, tau=1.0)

        if done or truncated:
            break

        s = s_next
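
The target-network sync in the loop above passes a `tau` argument. Assuming `soft_update` performs the usual exponential-smoothing (Polyak) update \(\theta_\text{targ} \leftarrow (1-\tau)\,\theta_\text{targ} + \tau\,\theta\), the effect of `tau` can be sketched with plain Python lists standing in for the model weights. The helper below is hypothetical and is not the coax method itself.

```python
def soft_update(target_weights, weights, tau):
    """Move each target weight a fraction tau of the way towards the online weight."""
    return [(1.0 - tau) * wt + tau * w for wt, w in zip(target_weights, weights)]

w = [1.0, -2.0]       # online weights (illustrative values)
w_targ = [0.0, 0.0]   # target weights

hard = soft_update(w_targ, w, tau=1.0)   # tau=1.0 copies the online weights
soft = soft_update(w_targ, w, tau=0.25)  # tau<1.0 only moves part of the way
```

Under this assumption, `tau=1.0` amounts to a periodic hard copy, whereas a small `tau` applied at every step gives the slowly tracking target networks commonly used with DDPG.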