Policy Objectives

coax.policy_objectives.VanillaPG

A vanilla policy-gradient objective, a.k.a. REINFORCE-style objective.

coax.policy_objectives.PPOClip

PPO-clip policy objective.

coax.policy_objectives.DeterministicPG

A deterministic policy-gradient objective, a.k.a. DDPG-style objective.

coax.policy_objectives.SoftPG

Soft policy-gradient objective.

This is a collection of policy objectives that can be used in policy-gradient methods.
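
For orientation, here is a minimal construction sketch, assuming the OpenAI Gym CartPole-v0 environment and a Haiku forward-pass function (an illustrative sketch, not the library's canonical example):

    import coax
    import gym
    import haiku as hk
    import jax.numpy as jnp

    env = gym.make('CartPole-v0')

    def func_pi(S, is_training):
        # forward pass that returns the parameters of a categorical distribution
        seq = hk.Sequential((hk.Linear(64), jnp.tanh, hk.Linear(env.action_space.n)))
        return {'logits': seq(S)}

    pi = coax.Policy(func_pi, env)

    # pick one of the policy objectives documented below
    vanilla_pg = coax.policy_objectives.VanillaPG(pi)

Transition batches and advantage estimates are produced elsewhere in the training loop, e.g. by a reward tracer (see coax.reward_tracing) and a learned state-value function; the per-class examples below assume they are available as transition_batch and Adv.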

Object Reference

class coax.policy_objectives.VanillaPG(pi, optimizer=None, regularizer=None)[source]

A vanilla policy-gradient objective, a.k.a. REINFORCE-style objective.

\[J(\theta; s,a)\ =\ \mathcal{A}(s,a)\,\log\pi_\theta(a|s)\]

This objective has the property that its gradient with respect to \(\theta\) yields the REINFORCE-style policy gradient.

Parameters:
  • pi (Policy) – The parametrized policy \(\pi_\theta(a|s)\).

  • optimizer (optax optimizer, optional) – An optax-style optimizer. The default optimizer is optax.adam(1e-3).

  • regularizer (Regularizer, optional) – A policy regularizer, see coax.regularizers.
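
A minimal update sketch (pi, transition_batch and Adv are assumed to be available from the surrounding training loop, as in the construction sketch at the top of this page):

    import coax
    import optax

    # the REINFORCE-style policy objective
    vanilla_pg = coax.policy_objectives.VanillaPG(pi, optimizer=optax.adam(1e-3))

    # transition_batch: a TransitionBatch; Adv: advantage estimates for those transitions
    metrics = vanilla_pg.update(transition_batch, Adv)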

apply_grads(grads, function_state)

Update the model parameters (weights) of the underlying function approximator given pre-computed gradients.

This method is useful in situations in which computation of the gradients is delegated to a separate (remote) process.

Parameters:
  • grads (pytree with ndarray leaves) – A batch of gradients, generated by the grads method.

  • function_state (pytree) – The internal state of the forward-pass function. See Policy.function_state and haiku.transform_with_state() for more details.
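
As a sketch of this delegated workflow (the transport between the worker and the learner is left abstract and is not part of the library API):

    # on a (remote) worker process: compute gradients only
    grads, function_state, metrics = vanilla_pg.grads_and_metrics(transition_batch, Adv)

    # ... ship grads and function_state to the learner process ...

    # on the learner process: apply the pre-computed gradients
    vanilla_pg.apply_grads(grads, function_state)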

grads_and_metrics(transition_batch, Adv)

Compute the gradients associated with a batch of transitions with corresponding advantages.

Parameters:
  • transition_batch (TransitionBatch) – A batch of transitions.

  • Adv (ndarray) – A batch of advantages \(\mathcal{A}(s,a)=q(s,a)-v(s)\).

Returns:

  • grads (pytree with ndarray leaves) – A batch of gradients.

  • function_state (pytree) – The internal state of the forward-pass function. See Policy.function_state and haiku.transform_with_state() for more details.

  • metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.

update(transition_batch, Adv)

Update the model parameters (weights) of the underlying function approximator.

Parameters:
  • transition_batch (TransitionBatch) – A batch of transitions.

  • Adv (ndarray) – A batch of advantages \(\mathcal{A}(s,a)=q(s,a)-v(s)\).

Returns:

metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.

class coax.policy_objectives.PPOClip(pi, optimizer=None, regularizer=None, epsilon=0.2)[source]

PPO-clip policy objective.

\[J(\theta; s,a)\ =\ \min\Big( \rho_\theta\,\mathcal{A}(s,a)\,,\ \bar{\rho}_\theta\,\mathcal{A}(s,a)\Big)\]

where \(\rho_\theta\) and \(\bar{\rho}_\theta\) are the bare and clipped probability ratios, respectively:

\[\rho_\theta\ =\ \frac{\pi_\theta(a|s)}{\pi_{\theta_\text{old}}(a|s)}\ , \qquad \bar{\rho}_\theta\ =\ \big[\rho_\theta\big]^{1+\epsilon}_{1-\epsilon}\]

This objective has the property that it allows for slightly more off-policy updates than the vanilla policy gradient.

Parameters:
  • pi (Policy) – The parametrized policy \(\pi_\theta(a|s)\).

  • optimizer (optax optimizer, optional) – An optax-style optimizer. The default optimizer is optax.adam(1e-3).

  • regularizer (Regularizer, optional) – A policy regularizer, see coax.regularizers.

  • epsilon (positive float, optional) – The clipping parameter \(\epsilon\) that is used to define the clipped importance weight \(\bar{\rho}\).
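
To make the clipping concrete, here is a small jax.numpy sketch of the clipped surrogate defined above; the function and variable names (ppo_clip_objective, ratio) are illustrative, not part of the library:

    import jax.numpy as jnp

    def ppo_clip_objective(ratio, Adv, epsilon=0.2):
        # ratio = pi_new(a|s) / pi_old(a|s); Adv = advantage estimates
        ratio_clipped = jnp.clip(ratio, 1 - epsilon, 1 + epsilon)
        return jnp.minimum(ratio * Adv, ratio_clipped * Adv)

The class itself is constructed and updated in the same way as VanillaPG above, with epsilon as an additional keyword argument.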

apply_grads(grads, function_state)

Update the model parameters (weights) of the underlying function approximator given pre-computed gradients.

This method is useful in situations in which computation of the gradients is delegated to a separate (remote) process.

Parameters:
  • grads (pytree with ndarray leaves) – A batch of gradients, generated by the grads method.

  • function_state (pytree) – The internal state of the forward-pass function. See Policy.function_state and haiku.transform_with_state() for more details.

grads_and_metrics(transition_batch, Adv)

Compute the gradients associated with a batch of transitions with corresponding advantages.

Parameters:
  • transition_batch (TransitionBatch) – A batch of transitions.

  • Adv (ndarray) – A batch of advantages \(\mathcal{A}(s,a)=q(s,a)-v(s)\).

Returns:

  • grads (pytree with ndarray leaves) – A batch of gradients.

  • function_state (pytree) – The internal state of the forward-pass function. See Policy.function_state and haiku.transform_with_state() for more details.

  • metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.

update(transition_batch, Adv)

Update the model parameters (weights) of the underlying function approximator.

Parameters:
  • transition_batch (TransitionBatch) – A batch of transitions.

  • Adv (ndarray) – A batch of advantages \(\mathcal{A}(s,a)=q(s,a)-v(s)\).

Returns:

metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.

class coax.policy_objectives.DeterministicPG(pi, q_targ, optimizer=None, regularizer=None)[source]

A deterministic policy-gradient objective, a.k.a. DDPG-style objective. See Deep Deterministic Policy Gradient and references therein for more details.

\[J(\theta; s,a)\ =\ q_\text{targ}(s, a_\theta(s))\]

Here \(a_\theta(s)\) is the mode of the underlying conditional probability distribution \(\pi_\theta(.|s)\). See e.g. the mode method of coax.proba_dists.NormalDist. In other words, we evaluate the policy according to the current estimate of its best-case performance.

This objective has the property that it explicitly maximizes the q-value.

The full policy loss is constructed as:

\[\text{loss}(\theta; s,a)\ =\ -J(\theta; s,a) - \beta_\text{ent}\,H[\pi_\theta] + \beta_\text{kl-div}\, KL[\pi_{\theta_\text{prior}}, \pi_\theta]\]

N.B. in order to unclutter the notation we abbreviated \(\pi(.|s)\) by \(\pi\).

Parameters:
  • pi (Policy) – The parametrized policy \(\pi_\theta(a|s)\).

  • q_targ (Q) – The target state-action value function \(q_\text{targ}(s,a)\).

  • optimizer (optax optimizer, optional) – An optax-style optimizer. The default optimizer is optax.adam(1e-3).

  • regularizer (Regularizer, optional) – A policy regularizer, see coax.regularizers.
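
A minimal construction sketch for a DDPG-style setup, assuming pi and q have already been built as coax.Policy and coax.Q function approximators (the use of copy and soft_update mirrors the library's DDPG example, but the snippet is illustrative rather than canonical):

    import coax
    import optax

    # target network for the q-function
    q_targ = q.copy()

    # DDPG-style policy objective; note that the advantage argument is ignored
    determ_pg = coax.policy_objectives.DeterministicPG(pi, q_targ, optimizer=optax.adam(1e-3))

    # update from a batch of transitions (Adv may be omitted)
    metrics = determ_pg.update(transition_batch)

    # periodically sync the target network
    q_targ.soft_update(q, tau=0.001)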

apply_grads(grads, function_state)

Update the model parameters (weights) of the underlying function approximator given pre-computed gradients.

This method is useful in situations in which computation of the gradients is delegated to a separate (remote) process.

Parameters:
  • grads (pytree with ndarray leaves) – A batch of gradients, generated by the grads method.

  • function_state (pytree) – The internal state of the forward-pass function. See Policy.function_state and haiku.transform_with_state() for more details.

grads_and_metrics(transition_batch, Adv=None)[source]

Compute the gradients associated with a batch of transitions with corresponding advantages.

Parameters:
  • transition_batch (TransitionBatch) – A batch of transitions.

  • Adv (ndarray, ignored) – This input is ignored; it is included for consistency with other policy objectives.

Returns:

  • grads (pytree with ndarray leaves) – A batch of gradients.

  • function_state (pytree) – The internal state of the forward-pass function. See Policy.function_state and haiku.transform_with_state() for more details.

  • metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.

update(transition_batch, Adv=None)[source]

Update the model parameters (weights) of the underlying function approximator.

Parameters:
  • transition_batch (TransitionBatch) – A batch of transitions.

  • Adv (ndarray, ignored) – This input is ignored; it is included for consistency with other policy objectives.

Returns:

metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.

class coax.policy_objectives.SoftPG(pi, q_targ_list, optimizer=None, regularizer=None)[source]

Soft policy-gradient objective, in which the policy is evaluated against one or more target q-functions \(q_\text{targ}\) supplied via q_targ_list.

Parameters:
  • pi (Policy) – The parametrized policy \(\pi_\theta(a|s)\).

  • q_targ_list (list of Q) – The target state-action value function(s) \(q_\text{targ}(s,a)\).

  • optimizer (optax optimizer, optional) – An optax-style optimizer. The default optimizer is optax.adam(1e-3).

  • regularizer (Regularizer, optional) – A policy regularizer, see coax.regularizers.
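
A brief construction sketch with two target q-functions, assuming pi, q1 and q2 have already been built as coax.Policy and coax.Q function approximators (illustrative only):

    import coax
    import optax

    # target networks for the two q-functions
    q1_targ, q2_targ = q1.copy(), q2.copy()

    # soft policy-gradient objective over a list of target q-functions
    soft_pg = coax.policy_objectives.SoftPG(pi, [q1_targ, q2_targ], optimizer=optax.adam(1e-3))

    # update from a batch of transitions (the Adv argument is ignored)
    metrics = soft_pg.update(transition_batch)
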
apply_grads(grads, function_state)

Update the model parameters (weights) of the underlying function approximator given pre-computed gradients.

This method is useful in situations in which computation of the gradients is delegated to a separate (remote) process.

Parameters:
  • grads (pytree with ndarray leaves) – A batch of gradients, generated by the grads method.

  • function_state (pytree) – The internal state of the forward-pass function. See Policy.function_state and haiku.transform_with_state() for more details.

grads_and_metrics(transition_batch, Adv=None)[source]

Compute the gradients associated with a batch of transitions with corresponding advantages.

Parameters:
  • transition_batch (TransitionBatch) – A batch of transitions.

  • Adv (ndarray, ignored) – This input is ignored; it is included for consistency with other policy objectives.

Returns:

  • grads (pytree with ndarray leaves) – A batch of gradients.

  • function_state (pytree) – The internal state of the forward-pass function. See Policy.function_state and haiku.transform_with_state() for more details.

  • metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.

update(transition_batch, Adv=None)[source]

Update the model parameters (weights) of the underlying function approximator.

Parameters:
  • transition_batch (TransitionBatch) – A batch of transitions.

  • Adv (ndarray, ignored) – This input is ignored; it is included for consistency with other policy objectives.

Returns:

metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.