Policy Objectives¶
This is a collection of policy objectives that can be used in policy-gradient methods.

coax.policy_objectives.VanillaPG – A vanilla policy-gradient objective, a.k.a. REINFORCE-style objective.
coax.policy_objectives.PPOClip – PPO-clip policy objective.
coax.policy_objectives.DeterministicPG – A deterministic policy-gradient objective, a.k.a. DDPG-style objective.
coax.policy_objectives.SoftPG – Soft policy-gradient objective.
Object Reference¶
 class coax.policy_objectives.VanillaPG(pi, optimizer=None, regularizer=None)[source]¶
A vanilla policy-gradient objective, a.k.a. REINFORCE-style objective.
\[J(\theta; s,a)\ =\ \mathcal{A}(s,a)\,\log\pi_\theta(a|s)\]This objective has the property that its gradient with respect to \(\theta\) yields the REINFORCE-style policy gradient.
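To make the REINFORCE property concrete, here is a standalone numpy sketch (not the coax implementation) for a tabular softmax policy: the gradient of \(\mathcal{A}(s,a)\,\log\pi_\theta(a|s)\) with respect to the logits \(\theta\) works out to \(\mathcal{A}(s,a)\big(1[k=a]-\pi_\theta(k)\big)\), which we verify against a numerical gradient:

```python
import numpy as np

def softmax(logits):
    z = logits - logits.max()
    e = np.exp(z)
    return e / e.sum()

def vanilla_pg_grad(theta, a, adv):
    """Gradient of J(theta) = adv * log pi_theta(a) for a softmax policy.

    For softmax logits, d/dtheta_k log pi(a) = 1[k == a] - pi(k), so the
    REINFORCE-style gradient is adv * (one_hot(a) - pi).
    """
    pi = softmax(theta)
    one_hot = np.eye(len(theta))[a]
    return adv * (one_hot - pi)

# Numerical check of the analytic gradient via central differences.
theta = np.array([0.1, -0.5, 0.3])
a, adv = 2, 1.7
eps = 1e-6
num_grad = np.zeros_like(theta)
for k in range(len(theta)):
    t = theta.copy(); t[k] += eps
    up = adv * np.log(softmax(t)[a])
    t = theta.copy(); t[k] -= eps
    down = adv * np.log(softmax(t)[a])
    num_grad[k] = (up - down) / (2 * eps)

assert np.allclose(vanilla_pg_grad(theta, a, adv), num_grad, atol=1e-5)
```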
 Parameters:
pi (Policy) – The parametrized policy \(\pi_\theta(a|s)\).
optimizer (optax optimizer, optional) – An optax-style optimizer. The default optimizer is optax.adam(1e-3).
regularizer (Regularizer, optional) – A policy regularizer, see coax.regularizers.
 apply_grads(grads, function_state)¶
Update the model parameters (weights) of the underlying function approximator given precomputed gradients.
This method is useful in situations in which computation of the gradients is delegated to a separate (remote) process.
 Parameters:
grads (pytree with ndarray leaves) – A batch of gradients, generated by the grads method.
function_state (pytree) – The internal state of the forward-pass function. See Policy.function_state and haiku.transform_with_state() for more details.
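The split into grads_and_metrics and apply_grads supports exactly this kind of delegation: a worker computes gradients, a learner applies them. A toy sketch of the pattern (plain numpy with hypothetical helper names, not coax internals), using a simple least-squares objective in place of a policy objective:

```python
import numpy as np

def grads_and_metrics(params, x, y):
    """Worker side: compute gradients for a least-squares objective."""
    err = params @ x - y
    grads = err * x                      # d/dparams of 0.5 * err**2
    metrics = {"loss": 0.5 * err ** 2}   # {name: score} structure
    return grads, metrics

def apply_grads(params, grads, learning_rate=0.1):
    """Learner side: apply precomputed gradients (plain gradient descent)."""
    return params - learning_rate * grads

params = np.array([0.0, 0.0])
x, y = np.array([1.0, 2.0]), 3.0
for _ in range(200):
    grads, metrics = grads_and_metrics(params, x, y)   # could run remotely
    params = apply_grads(params, grads)                # runs on the learner

assert abs(params @ x - 3.0) < 1e-3   # the two halves together perform the update
```

An update method then amounts to calling both halves back to back on the same process.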
 grads_and_metrics(transition_batch, Adv)¶
Compute the gradients associated with a batch of transitions with corresponding advantages.
 Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
Adv (ndarray) – A batch of advantages \(\mathcal{A}(s,a)=q(s,a)-v(s)\).
 Returns:
grads (pytree with ndarray leaves) – A batch of gradients.
function_state (pytree) – The internal state of the forward-pass function. See Policy.function_state and haiku.transform_with_state() for more details.
metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.
 update(transition_batch, Adv)¶
Update the model parameters (weights) of the underlying function approximator.
 Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
Adv (ndarray) – A batch of advantages \(\mathcal{A}(s,a)=q(s,a)-v(s)\).
 Returns:
metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.
 class coax.policy_objectives.PPOClip(pi, optimizer=None, regularizer=None, epsilon=0.2)[source]¶
PPO-clip policy objective.
\[J(\theta; s,a)\ =\ \min\Big( \rho_\theta\,\mathcal{A}(s,a)\,,\ \bar{\rho}_\theta\,\mathcal{A}(s,a)\Big)\]where \(\rho_\theta\) and \(\bar{\rho}_\theta\) are the bare and clipped probability ratios, respectively:
\[\rho_\theta\ =\ \frac{\pi_\theta(a|s)}{\pi_{\theta_\text{old}}(a|s)}\ , \qquad \bar{\rho}_\theta\ =\ \big[\rho_\theta\big]^{1+\epsilon}_{1-\epsilon}\]This objective has the property that it allows for slightly more off-policy updates than the vanilla policy gradient.
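A minimal numpy sketch of the clipped objective (an illustration, not the coax implementation), showing how the clipping caps the incentive to move \(\pi_\theta\) far away from \(\pi_{\theta_\text{old}}\):

```python
import numpy as np

def ppo_clip_objective(logp_new, logp_old, adv, epsilon=0.2):
    """J = min(rho * A, clip(rho, 1 - eps, 1 + eps) * A)."""
    rho = np.exp(logp_new - logp_old)                      # bare ratio
    rho_bar = np.clip(rho, 1.0 - epsilon, 1.0 + epsilon)   # clipped ratio
    return np.minimum(rho * adv, rho_bar * adv)

# For A > 0 the objective saturates once rho exceeds 1 + eps (1.5 -> 1.2):
assert np.isclose(ppo_clip_objective(np.log(1.5), 0.0, 1.0), 1.2)
# For A < 0 it saturates once rho drops below 1 - eps (0.5 -> 0.8):
assert np.isclose(ppo_clip_objective(np.log(0.5), 0.0, -1.0), -0.8)
```

Outside the \([1-\epsilon, 1+\epsilon]\) band the objective is flat in \(\rho_\theta\), so its gradient vanishes and the update no longer pushes the ratio further out.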
 Parameters:
pi (Policy) – The parametrized policy \(\pi_\theta(a|s)\).
optimizer (optax optimizer, optional) – An optax-style optimizer. The default optimizer is optax.adam(1e-3).
regularizer (Regularizer, optional) – A policy regularizer, see coax.regularizers.
epsilon (positive float, optional) – The clipping parameter \(\epsilon\) that is used to define the clipped importance weight \(\bar{\rho}\).
 apply_grads(grads, function_state)¶
Update the model parameters (weights) of the underlying function approximator given precomputed gradients.
This method is useful in situations in which computation of the gradients is delegated to a separate (remote) process.
 Parameters:
grads (pytree with ndarray leaves) – A batch of gradients, generated by the grads method.
function_state (pytree) – The internal state of the forward-pass function. See Policy.function_state and haiku.transform_with_state() for more details.
 grads_and_metrics(transition_batch, Adv)¶
Compute the gradients associated with a batch of transitions with corresponding advantages.
 Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
Adv (ndarray) – A batch of advantages \(\mathcal{A}(s,a)=q(s,a)-v(s)\).
 Returns:
grads (pytree with ndarray leaves) – A batch of gradients.
function_state (pytree) – The internal state of the forward-pass function. See Policy.function_state and haiku.transform_with_state() for more details.
metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.
 update(transition_batch, Adv)¶
Update the model parameters (weights) of the underlying function approximator.
 Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
Adv (ndarray) – A batch of advantages \(\mathcal{A}(s,a)=q(s,a)-v(s)\).
 Returns:
metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.
 class coax.policy_objectives.DeterministicPG(pi, q_targ, optimizer=None, regularizer=None)[source]¶
A deterministic policy-gradient objective, a.k.a. DDPG-style objective. See Deep Deterministic Policy Gradient and references therein for more details.
\[J(\theta; s,a)\ =\ q_\text{targ}(s, a_\theta(s))\]Here \(a_\theta(s)\) is the mode of the underlying conditional probability distribution \(\pi_\theta(\cdot|s)\); see e.g. the mode method of coax.proba_dists.NormalDist. In other words, we evaluate the policy according to the current estimate of its best-case performance. This objective has the property that it explicitly maximizes the q-value.
The full policy loss is constructed as:
\[\text{loss}(\theta; s,a)\ =\ -J(\theta; s,a) - \beta_\text{ent}\,H[\pi_\theta] + \beta_\text{kldiv}\, KL[\pi_{\theta_\text{prior}}, \pi_\theta]\]N.B. in order to unclutter the notation we abbreviated \(\pi(\cdot|s)\) by \(\pi\).
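To see the chain rule behind this objective at work, here is a toy numpy sketch (hypothetical quadratic q-function and linear deterministic policy, not the coax implementation): ascending \(J(\theta)=q_\text{targ}(s, a_\theta(s))\) uses \(\partial J/\partial\theta = \partial q/\partial a \cdot \partial a_\theta/\partial\theta\):

```python
import numpy as np

rng = np.random.default_rng(0)

def q_targ(s, a):
    """Toy target q-function, maximized at a = 2 * s."""
    return -(a - 2.0 * s) ** 2

def dq_da(s, a):
    """Analytic derivative of q_targ with respect to the action."""
    return -2.0 * (a - 2.0 * s)

# Deterministic linear policy a_theta(s) = theta * s.  Gradient ascent on
# J(theta) = q_targ(s, a_theta(s)) via the chain rule: dJ/dtheta = dq/da * s.
theta, lr = 0.0, 0.05
for _ in range(500):
    s = rng.uniform(0.5, 1.5)
    a = theta * s
    theta += lr * dq_da(s, a) * s

assert abs(theta - 2.0) < 0.05  # the policy converges to the q-maximizing action
```

The update pushes the policy's action directly uphill on \(q_\text{targ}\), which is the "explicitly maximizes the q-value" property noted above.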
 Parameters:
pi (Policy) – The parametrized policy \(\pi_\theta(a|s)\).
q_targ (Q) – The target state-action value function \(q_\text{targ}(s,a)\).
optimizer (optax optimizer, optional) – An optax-style optimizer. The default optimizer is optax.adam(1e-3).
regularizer (Regularizer, optional) – A policy regularizer, see coax.regularizers.
 apply_grads(grads, function_state)¶
Update the model parameters (weights) of the underlying function approximator given precomputed gradients.
This method is useful in situations in which computation of the gradients is delegated to a separate (remote) process.
 Parameters:
grads (pytree with ndarray leaves) – A batch of gradients, generated by the grads method.
function_state (pytree) – The internal state of the forward-pass function. See Policy.function_state and haiku.transform_with_state() for more details.
 grads_and_metrics(transition_batch, Adv=None)[source]¶
Compute the gradients associated with a batch of transitions with corresponding advantages.
 Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
Adv (ndarray, ignored) – This input is ignored; it is included for consistency with other policy objectives.
 Returns:
grads (pytree with ndarray leaves) – A batch of gradients.
function_state (pytree) – The internal state of the forward-pass function. See Policy.function_state and haiku.transform_with_state() for more details.
metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.
 update(transition_batch, Adv=None)[source]¶
Update the model parameters (weights) of the underlying function approximator.
 Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
Adv (ndarray, ignored) – This input is ignored; it is included for consistency with other policy objectives.
 Returns:
metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.
 class coax.policy_objectives.SoftPG(pi, q_targ_list, optimizer=None, regularizer=None)[source]¶
Soft policy-gradient objective, as used in soft actor-critic style agents; note that it takes a list of target q-functions q_targ_list in place of the single q_targ used by DeterministicPG.
 apply_grads(grads, function_state)¶
Update the model parameters (weights) of the underlying function approximator given precomputed gradients.
This method is useful in situations in which computation of the gradients is delegated to a separate (remote) process.
 Parameters:
grads (pytree with ndarray leaves) – A batch of gradients, generated by the grads method.
function_state (pytree) – The internal state of the forward-pass function. See Policy.function_state and haiku.transform_with_state() for more details.
 grads_and_metrics(transition_batch, Adv=None)[source]¶
Compute the gradients associated with a batch of transitions with corresponding advantages.
 Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
Adv (ndarray, ignored) – This input is ignored; it is included for consistency with other policy objectives.
 Returns:
grads (pytree with ndarray leaves) – A batch of gradients.
function_state (pytree) – The internal state of the forward-pass function. See Policy.function_state and haiku.transform_with_state() for more details.
metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.
 update(transition_batch, Adv=None)[source]¶
Update the model parameters (weights) of the underlying function approximator.
 Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
Adv (ndarray, ignored) – This input is ignored; it is included for consistency with other policy objectives.
 Returns:
metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.