# Deep Deterministic Policy Gradients (DDPG)

The Deep Deterministic Policy Gradients (DDPG) algorithm is a little different from other policy-gradient methods: it learns a policy directly from a (type-I) q-function. The policy objective is:

$J(\theta; s,a)\ =\ q_{\varphi_\text{targ}}(s, a_\theta(s))$

Here $$a_\theta(s)$$ is the mode of the underlying conditional probability distribution $$\pi_\theta(.|s)$$. See e.g. the mode method of coax.proba_dists.NormalDist. In other words, we evaluate the policy according to the current estimate of its best-case performance. This is implemented by the coax.policy_objectives.DeterministicPG updater class.
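
To see how this objective drives the policy update, here is a minimal, self-contained JAX sketch of differentiating $$J$$ through the action. This is not the coax implementation: the linear-Gaussian `policy_mode` and the fixed quadratic `q_targ` below are made-up stand-ins for the real policy and target q-function.

```python
import jax
import jax.numpy as jnp


def policy_mode(theta, s):
    # a_theta(s): the mode of a Gaussian policy whose mean is a linear
    # function of the state (toy parametrization, for illustration only)
    return theta['W'] @ s + theta['b']


def q_targ(s, a):
    # stand-in for the target q-function q_{phi_targ}(s, a); in DDPG this is
    # the learned (type-I) q-function, here just a fixed quadratic bowl
    return -jnp.sum((a - jnp.tanh(s[:2])) ** 2)


def objective(theta, s):
    # J(theta; s) = q_{phi_targ}(s, a_theta(s))
    return q_targ(s, policy_mode(theta, s))


theta = {'W': jnp.zeros((2, 3)), 'b': jnp.zeros(2)}
s = jnp.array([0.1, -0.4, 0.7])

# one step of gradient *ascent* on the policy parameters; the gradient flows
# from the q-value into theta only via the action a_theta(s)
grads = jax.grad(objective)(theta, s)
theta = jax.tree_util.tree_map(lambda p, g: p + 0.1 * g, theta, grads)
```

Note that the q-function's own weights are held fixed in this update; only the policy parameters $$\theta$$ receive a gradient.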

Since the policy objective uses a q-function $$q_\varphi(s,a)$$, we also need to learn that. At the moment of writing, there are two ways to learn $$q_\varphi(s,a)$$ in coax.

Option 1: SARSA.

The first option is to use SARSA updates, whose $$n$$-step bootstrapped target is constructed as:

$G^{(n)}_t\ =\ R^{(n)}_t + I^{(n)}_t\,q_{\varphi_\text{targ}}\!(S_{t+n}, A_{t+n})$

where $$A_{t+n}$$ is sampled from experience and

$\begin{split}R^{(n)}_t\ &=\ \sum_{k=0}^{n-1}\gamma^k\,R_{t+k} \\ I^{(n)}_t\ &=\ \begin{cases} 0 & \text{if }S_{t+n}\text{ is a terminal state} \\ \gamma^n & \text{otherwise} \end{cases}\end{split}$

This is implemented by the coax.td_learning.Sarsa updater class.
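
As a concrete illustration, here is a small sketch of how such a target could be computed for a single time step $$t$$. This is not the coax.reward_tracing / coax.td_learning internals; the helper name and its arguments are made up for this example:

```python
import jax.numpy as jnp


def nstep_sarsa_target(rewards, gamma, q_targ_value, done):
    # rewards:      [R_t, ..., R_{t+n-1}], the n rewards taken from experience
    # q_targ_value: q_{phi_targ}(S_{t+n}, A_{t+n}), with A_{t+n} also taken from experience
    # done:         whether S_{t+n} is a terminal state
    n = len(rewards)
    Rn = jnp.sum(gamma ** jnp.arange(n) * jnp.asarray(rewards))  # R_t^{(n)}
    In = 0.0 if done else gamma ** n                             # I_t^{(n)}
    return Rn + In * q_targ_value                                # G_t^{(n)}


# e.g. a 3-step target with gamma=0.9 and a bootstrapped value of 2.0
G = nstep_sarsa_target([1.0, 0.0, 1.0], gamma=0.9, q_targ_value=2.0, done=False)
```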

Option 2: Q-Learning.

The second option is to use q-learning updates, whose $$n$$-step bootstrapped target is instead constructed as:

$G^{(n)}_t\ =\ R^{(n)}_t + I^{(n)}_t\,q_{\varphi_\text{targ}}\!\left( S_{t+n}, a_{\theta_\text{targ}}\!(S_{t+n})\right)$

Here, $$a_{\theta_\text{targ}}\!(S_{t+n})$$ is the mode introduced above, evaluated using the target-model weights $$\theta_\text{targ}$$. We call this q-learning because the TD target is constructed as though the next action $$A_{t+n}$$ had been the greedy action. This is implemented by the coax.td_learning.QLearningMode updater class.
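
The only thing that changes relative to the SARSA sketch above is the bootstrap action: instead of the experienced $$A_{t+n}$$, we plug in the target policy's mode. Below is a similarly hedged sketch (again not the coax.td_learning.QLearningMode internals; the helper name and arguments are invented for illustration):

```python
def nstep_qlearning_target(rewards, gamma, s_next, a_targ_mode, q_targ, done):
    # a_targ_mode: a_{theta_targ}(S_{t+n}), the mode of the *target* policy at S_{t+n}
    # q_targ:      callable (s, a) -> q_{phi_targ}(s, a)
    n = len(rewards)
    Rn = sum(gamma ** k * r for k, r in enumerate(rewards))  # R_t^{(n)}
    In = 0.0 if done else gamma ** n                         # I_t^{(n)}
    return Rn + In * q_targ(s_next, a_targ_mode)             # G_t^{(n)}
```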

For more details, have a look at the Spinning Up page on DDPG (https://spinningup.openai.com/en/latest/algorithms/ddpg.html), which includes links to the original papers.

ddpg.py:

```python
import gymnasium
import coax
import optax
import haiku as hk
import jax.numpy as jnp

# pick environment
env = gymnasium.make(...)
env = coax.wrappers.TrainMonitor(env)


def func_pi(S, is_training):
    # custom haiku function (for continuous actions in this example)
    mu = hk.Sequential([...])(S)  # mu.shape: (batch_size, *action_space.shape)
    return {'mu': mu, 'logvar': jnp.full_like(mu, -10)}  # (almost) deterministic policy


def func_q(S, A, is_training):
    # custom haiku function; the q-function takes both S and A as input
    value = hk.Sequential([...])
    X = jnp.concatenate((S, A), axis=-1)
    return value(X)  # output shape: (batch_size,)


# define function approximators
pi = coax.Policy(func_pi, env)
q = coax.Q(func_q, env, action_preprocessor=pi.proba_dist.preprocess_variate)

# target networks
pi_targ = pi.copy()
q_targ = q.copy()

# specify how to update policy and value function
determ_pg = coax.policy_objectives.DeterministicPG(pi, q_targ, optimizer=optax.adam(0.001))
qlearning = coax.td_learning.QLearningMode(q, pi_targ, q_targ, optimizer=optax.adam(0.002))

# specify how to trace the transitions
tracer = coax.reward_tracing.NStep(n=1, gamma=0.9)
buffer = coax.experience_replay.SimpleReplayBuffer(capacity=1000000)

# action noise
noise = coax.utils.OrnsteinUhlenbeckNoise(mu=0., sigma=0.2, theta=0.15)

for ep in range(100):
    s, info = env.reset()
    noise.reset()
    noise.sigma *= 0.99  # slowly decrease noise scale

    for t in range(env.spec.max_episode_steps):
        a = noise(pi(s))
        s_next, r, done, truncated, info = env.step(a)

        # trace rewards and add transitions to the replay buffer
        tracer.add(s, a, r, done)
        while tracer:
            buffer.add(tracer.pop())

        # update
        if len(buffer) >= 32:  # wait until the buffer holds at least one batch
            transition_batch = buffer.sample(batch_size=32)
            metrics_q = qlearning.update(transition_batch)
            metrics_pi = determ_pg.update(transition_batch)
            env.record_metrics(metrics_q)
            env.record_metrics(metrics_pi)

        # periodically sync target models
        if ep % 10 == 0:
            pi_targ.soft_update(pi, tau=1.0)
            q_targ.soft_update(q, tau=1.0)

        if done or truncated:
            break

        s = s_next
```