Deep Deterministic Policy Gradients (DDPG)

The Deep Deterministic Policy Gradients (DDPG) algorithm is a little different from other policy-objective updaters, in that it learns the policy directly from a (type-I) q-function. The policy objective is:

\[J(\theta; s,a)\ =\ q_{\varphi_\text{targ}}(s, a_\theta(s))\]

Here \(a_\theta(s)\) is the mode of the underlying conditional probability distribution \(\pi_\theta(\cdot|s)\). See e.g. the mode method of coax.proba_dists.NormalDist. In other words, we evaluate the policy according to the current estimate of its best-case performance. This is implemented by the coax.policy_objectives.DeterministicPG updater class.
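
Writing out the gradient makes the mechanics clear: \(\nabla_\theta J = \nabla_a q_{\varphi_\text{targ}}(s, a)\big|_{a=a_\theta(s)}\,\nabla_\theta a_\theta(s)\), i.e. the q-function is differentiated through the action. Below is a minimal, self-contained JAX sketch of this objective (not the coax internals); the linear mu and q_targ functions, the toy shapes, and the parameter names theta and phi are made up purely for illustration.

import jax
import jax.numpy as jnp


def mu(theta, S):
    # the mode a_theta(s); here just a linear map, for illustration only
    return S @ theta['w'] + theta['b']


def q_targ(phi, S, A):
    # a stand-in type-I q-function q_phi(s, a), linear in the concatenated (s, a)
    X = jnp.concatenate((S, A), axis=-1)
    return X @ phi['w'] + phi['b']


def objective(theta, phi, S):
    # J(theta) = batch average of q_phi_targ(s, a_theta(s))
    return jnp.mean(q_targ(phi, S, mu(theta, S)))


# toy shapes: 4-dim states, 2-dim actions, batch of 8 (all made up)
theta = {'w': jnp.zeros((4, 2)), 'b': jnp.zeros(2)}
phi = {'w': jnp.ones(6), 'b': jnp.zeros(())}
S = jnp.ones((8, 4))

# gradient *ascent* on theta; the target q-function parameters phi are held fixed
grads = jax.grad(objective)(theta, phi, S)
theta = jax.tree_util.tree_map(lambda p, g: p + 1e-3 * g, theta, grads)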

Since the policy objective uses a q-function \(q_\varphi(s,a)\), we also need to learn that. At the moment of writing, there are two ways to learn \(q_\varphi(s,a)\) in coax.

Option 1: SARSA.

The first option is to use SARSA updates, whose \(n\)-step bootstrapped target is constructed as:

\[G^{(n)}_t\ =\ R^{(n)}_t + I^{(n)}_t\,q_{\varphi_\text{targ}}\!(S_{t+n}, A_{t+n})\]

where \(A_{t+n}\) is sampled from experience and

\[\begin{split}R^{(n)}_t\ &=\ \sum_{k=0}^{n-1}\gamma^kR_{t+k} \\ I^{(n)}_t\ &=\ \left\{\begin{matrix} 0 & \text{if $S_{t+n}$ is a terminal state} \\ \gamma^n & \text{otherwise} \end{matrix}\right.\end{split}\]

This is implemented by the coax.td_learning.Sarsa updater class.
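
As a concrete illustration, the sketch below assembles this \(n\)-step SARSA target from a single trajectory slice. It is not the coax implementation; in practice the coax.reward_tracing.NStep tracer and the coax.td_learning.Sarsa updater do this bookkeeping for you, and the q_targ stand-in below is a hypothetical placeholder for the target q-function.

import jax.numpy as jnp


def nstep_sarsa_target(rewards, s_next, a_next, done, gamma, q_targ):
    # G_t^(n) = R_t^(n) + I_t^(n) * q_targ(S_{t+n}, A_{t+n})
    n = len(rewards)
    Rn = jnp.sum(gamma ** jnp.arange(n) * jnp.asarray(rewards))  # discounted partial return R_t^(n)
    In = 0.0 if done else gamma ** n                             # bootstrapping factor I_t^(n)
    return Rn + In * q_targ(s_next, a_next)


# hypothetical usage; replace the lambda with your target q-function q_{varphi_targ}(s, a)
q_targ = lambda s, a: 0.0
G = nstep_sarsa_target(rewards=[1.0, 0.0, 1.0], s_next=None, a_next=None,
                       done=False, gamma=0.9, q_targ=q_targ)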

Option 2: Q-Learning.

The second option is to use q-learning updates, whose \(n\)-step bootstrapped target is instead constructed as:

\[G^{(n)}_t\ =\ R^{(n)}_t + I^{(n)}_t\,q_{\varphi_\text{targ}}\!\left( S_{t+n}, a_{\theta_\text{targ}}\!(S_{t+n})\right)\]

Here, \(a_{\theta_\text{targ}}\!(S_{t+n})\) is the mode introduced above, evaluated at the next state using the target-model weights \(\theta_\text{targ}\). We call this q-learning because the TD target is constructed as though the next action \(A_{t+n}\) had been the greedy action. This is implemented by the coax.td_learning.QLearningMode updater class.
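
The only difference from the SARSA target above is where the bootstrapped action comes from: the observed \(A_{t+n}\) is replaced by the mode of the target policy. Here is the same kind of hedged sketch, with mode_targ standing in (hypothetically) for \(a_{\theta_\text{targ}}\):

import jax.numpy as jnp


def nstep_qlearning_target(rewards, s_next, done, gamma, q_targ, mode_targ):
    # G_t^(n) = R_t^(n) + I_t^(n) * q_targ(S_{t+n}, a_{theta_targ}(S_{t+n}))
    n = len(rewards)
    Rn = jnp.sum(gamma ** jnp.arange(n) * jnp.asarray(rewards))  # discounted partial return R_t^(n)
    In = 0.0 if done else gamma ** n                             # bootstrapping factor I_t^(n)
    a_greedy = mode_targ(s_next)  # greedy action: the mode of pi_{theta_targ}(.|S_{t+n})
    return Rn + In * q_targ(s_next, a_greedy)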

For more details, have a look at OpenAI's Spinning Up page on DDPG, which includes links to the original papers.


ddpg.py

import gymnasium
import coax
import optax
import haiku as hk
import jax.numpy as jnp


# pick environment
env = gymnasium.make(...)
env = coax.wrappers.TrainMonitor(env)


def func_pi(S, is_training):
    # custom haiku function (for continuous actions in this example)
    mu = hk.Sequential([...])(S)  # mu.shape: (batch_size, *action_space.shape)
    return {'mu': mu, 'logvar': jnp.full_like(mu, -10)}  # (almost) deterministic policy


def func_q(S, A, is_training):
    # custom haiku function
    value = hk.Sequential([...])
    X = jnp.concatenate((S, A), axis=-1)  # a type-I q-function sees both the state and the action
    return value(X)  # output shape: (batch_size,)


# define function approximator
pi = coax.Policy(func_pi, env)
q = coax.Q(func_q, env, action_preprocessor=pi.proba_dist.preprocess_variate)


# target networks
pi_targ = pi.copy()
q_targ = q.copy()


# specify how to update policy and value function
determ_pg = coax.policy_objectives.DeterministicPG(pi, q_targ, optimizer=optax.adam(0.001))
qlearning = coax.td_learning.QLearningMode(q, pi_targ, q_targ, optimizer=optax.adam(0.002))


# specify how to trace the transitions
tracer = coax.reward_tracing.NStep(n=1, gamma=0.9)
buffer = coax.experience_replay.SimpleReplayBuffer(capacity=1000000)


# action noise
noise = coax.utils.OrnsteinUhlenbeckNoise(mu=0., sigma=0.2, theta=0.15)


for ep in range(100):
    s, info = env.reset()
    noise.reset()
    noise.sigma *= 0.99  # slowly decrease noise scale

    for t in range(env.spec.max_episode_steps):
        a = noise(pi(s))
        s_next, r, done, truncated, info = env.step(a)

        # add transition to buffer
        tracer.add(s, a, r, done or truncated)
        while tracer:
            buffer.add(tracer.pop())

        # update
        transition_batch = buffer.sample(batch_size=32)
        metrics_q = qlearning.update(transition_batch)
        metrics_pi = determ_pg.update(transition_batch)
        env.record_metrics(metrics_q)
        env.record_metrics(metrics_pi)

        # periodically sync target models
        if ep % 10 == 0:
            pi_targ.soft_update(pi, tau=1.0)
            q_targ.soft_update(q, tau=1.0)

        if done or truncated:
            break

        s = s_next