Policy Objectives¶
This is a collection of policy objectives that can be used in policy-gradient methods.

coax.policy_objectives.VanillaPG – A vanilla policy-gradient objective, a.k.a. REINFORCE-style objective.
coax.policy_objectives.PPOClip – PPO-clip policy objective.
coax.policy_objectives.DeterministicPG – A deterministic policy-gradient objective, a.k.a. DDPG-style objective.
coax.policy_objectives.SoftPG – Soft policy-gradient objective.
Object Reference¶
 class coax.policy_objectives.VanillaPG(pi, optimizer=None, regularizer=None)[source]¶
A vanilla policy-gradient objective, a.k.a. REINFORCE-style objective.
\[J(\theta; s,a)\ =\ \mathcal{A}(s,a)\,\log\pi_\theta(a|s)\]This objective has the property that its gradient with respect to \(\theta\) yields the REINFORCE-style policy gradient.
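To make the REINFORCE property concrete, here is a standalone numpy sketch (not the coax implementation) for a tabular softmax policy: the gradient of \(\mathcal{A}(s,a)\,\log\pi_\theta(a|s)\) with respect to the logits \(\theta\) works out to \(\mathcal{A}(s,a)\big(1[k=a]-\pi_\theta(k)\big)\), which we verify against a numerical gradient:

```python
import numpy as np

def softmax(logits):
    z = logits - logits.max()
    e = np.exp(z)
    return e / e.sum()

def vanilla_pg_grad(theta, a, adv):
    """Gradient of J(theta) = adv * log pi_theta(a) for a softmax policy.

    For softmax logits, d/dtheta_k log pi(a) = 1[k == a] - pi(k), so the
    REINFORCE-style gradient is adv * (one_hot(a) - pi).
    """
    pi = softmax(theta)
    one_hot = np.eye(len(theta))[a]
    return adv * (one_hot - pi)

# Numerical check of the analytic gradient via central differences.
theta = np.array([0.1, -0.5, 0.3])
a, adv = 2, 1.7
eps = 1e-6
num_grad = np.zeros_like(theta)
for k in range(len(theta)):
    t = theta.copy(); t[k] += eps
    up = adv * np.log(softmax(t)[a])
    t = theta.copy(); t[k] -= eps
    down = adv * np.log(softmax(t)[a])
    num_grad[k] = (up - down) / (2 * eps)

assert np.allclose(vanilla_pg_grad(theta, a, adv), num_grad, atol=1e-5)
```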
 Parameters:
pi (Policy) – The parametrized policy \(\pi_\theta(a|s)\).
optimizer (optax optimizer, optional) – An optax-style optimizer. The default optimizer is optax.adam(1e-3).
regularizer (Regularizer, optional) – A policy regularizer, see coax.regularizers.
 apply_grads(grads, function_state)¶
Update the model parameters (weights) of the underlying function approximator given precomputed gradients.
This method is useful in situations in which computation of the gradients is delegated to a separate (remote) process.
 Parameters:
grads (pytree with ndarray leaves) – A batch of gradients, generated by the grads method.
function_state (pytree) – The internal state of the forward-pass function. See Policy.function_state and haiku.transform_with_state() for more details.
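The split into grads_and_metrics and apply_grads supports exactly this kind of delegation: a worker computes gradients, a learner applies them. A toy sketch of the pattern (plain numpy with hypothetical helper names, not coax internals), using a simple least-squares objective in place of a policy objective:

```python
import numpy as np

def grads_and_metrics(params, x, y):
    """Worker side: compute gradients for a least-squares objective."""
    err = params @ x - y
    grads = err * x                      # d/dparams of 0.5 * err**2
    metrics = {"loss": 0.5 * err ** 2}   # {name: score} structure
    return grads, metrics

def apply_grads(params, grads, learning_rate=0.1):
    """Learner side: apply precomputed gradients (plain gradient descent)."""
    return params - learning_rate * grads

params = np.array([0.0, 0.0])
x, y = np.array([1.0, 2.0]), 3.0
for _ in range(200):
    grads, metrics = grads_and_metrics(params, x, y)   # could run remotely
    params = apply_grads(params, grads)                # runs on the learner

assert abs(params @ x - 3.0) < 1e-3   # the two halves together perform the update
```

An update method then amounts to calling both halves back to back on the same process.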
 grads_and_metrics(transition_batch, Adv)¶
Compute the gradients associated with a batch of transitions with corresponding advantages.
 Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
Adv (ndarray) – A batch of advantages \(\mathcal{A}(s,a)=q(s,a)-v(s)\).
 Returns:
grads (pytree with ndarray leaves) – A batch of gradients.
function_state (pytree) – The internal state of the forward-pass function. See Policy.function_state and haiku.transform_with_state() for more details.
metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.
 update(transition_batch, Adv)¶
Update the model parameters (weights) of the underlying function approximator.
 Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
Adv (ndarray) – A batch of advantages \(\mathcal{A}(s,a)=q(s,a)-v(s)\).
 Returns:
metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.
 class coax.policy_objectives.PPOClip(pi, optimizer=None, regularizer=None, epsilon=0.2)[source]¶
PPO-clip policy objective.
\[J(\theta; s,a)\ =\ \min\Big( \rho_\theta\,\mathcal{A}(s,a)\,,\ \bar{\rho}_\theta\,\mathcal{A}(s,a)\Big)\]where \(\rho_\theta\) and \(\bar{\rho}_\theta\) are the bare and clipped probability ratios, respectively:
\[\rho_\theta\ =\ \frac{\pi_\theta(a|s)}{\pi_{\theta_\text{old}}(a|s)}\ , \qquad \bar{\rho}_\theta\ =\ \big[\rho_\theta\big]^{1+\epsilon}_{1-\epsilon}\]This objective has the property that it allows for slightly more off-policy updates than the vanilla policy gradient.
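A minimal numpy sketch of the clipped objective (an illustration, not the coax implementation), showing how the clipping caps the incentive to move \(\pi_\theta\) far away from \(\pi_{\theta_\text{old}}\):

```python
import numpy as np

def ppo_clip_objective(logp_new, logp_old, adv, epsilon=0.2):
    """J = min(rho * A, clip(rho, 1 - eps, 1 + eps) * A)."""
    rho = np.exp(logp_new - logp_old)                      # bare ratio
    rho_bar = np.clip(rho, 1.0 - epsilon, 1.0 + epsilon)   # clipped ratio
    return np.minimum(rho * adv, rho_bar * adv)

# For A > 0 the objective saturates once rho exceeds 1 + eps (1.5 -> 1.2):
assert np.isclose(ppo_clip_objective(np.log(1.5), 0.0, 1.0), 1.2)
# For A < 0 it saturates once rho drops below 1 - eps (0.5 -> 0.8):
assert np.isclose(ppo_clip_objective(np.log(0.5), 0.0, -1.0), -0.8)
```

Outside the \([1-\epsilon, 1+\epsilon]\) band the objective is flat in \(\rho_\theta\), so its gradient vanishes and the update no longer pushes the ratio further out.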
 Parameters:
pi (Policy) – The parametrized policy \(\pi_\theta(a|s)\).
optimizer (optax optimizer, optional) – An optax-style optimizer. The default optimizer is optax.adam(1e-3).
regularizer (Regularizer, optional) – A policy regularizer, see coax.regularizers.
epsilon (positive float, optional) – The clipping parameter \(\epsilon\) that is used to define the clipped importance weight \(\bar{\rho}\).
 apply_grads(grads, function_state)¶
Update the model parameters (weights) of the underlying function approximator given precomputed gradients.
This method is useful in situations in which computation of the gradients is delegated to a separate (remote) process.
 Parameters:
grads (pytree with ndarray leaves) – A batch of gradients, generated by the grads method.
function_state (pytree) – The internal state of the forward-pass function. See Policy.function_state and haiku.transform_with_state() for more details.
 grads_and_metrics(transition_batch, Adv)¶
Compute the gradients associated with a batch of transitions with corresponding advantages.
 Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
Adv (ndarray) – A batch of advantages \(\mathcal{A}(s,a)=q(s,a)-v(s)\).
 Returns:
grads (pytree with ndarray leaves) – A batch of gradients.
function_state (pytree) – The internal state of the forward-pass function. See Policy.function_state and haiku.transform_with_state() for more details.
metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.
 update(transition_batch, Adv)¶
Update the model parameters (weights) of the underlying function approximator.
 Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
Adv (ndarray) – A batch of advantages \(\mathcal{A}(s,a)=q(s,a)-v(s)\).
 Returns:
metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.
 class coax.policy_objectives.DeterministicPG(pi, q_targ, optimizer=None, regularizer=None)[source]¶
A deterministic policy-gradient objective, a.k.a. DDPG-style objective. See Deep Deterministic Policy Gradient and references therein for more details.
\[J(\theta; s,a)\ =\ q_\text{targ}(s, a_\theta(s))\]Here \(a_\theta(s)\) is the mode of the underlying conditional probability distribution \(\pi_\theta(\cdot|s)\); see e.g. the mode method of coax.proba_dists.NormalDist. In other words, we evaluate the policy according to the current estimate of its best-case performance. This objective has the property that it explicitly maximizes the q-value.
The full policy loss is constructed as:
\[\text{loss}(\theta; s,a)\ =\ -J(\theta; s,a) - \beta_\text{ent}\,H[\pi_\theta] + \beta_\text{kldiv}\, KL[\pi_{\theta_\text{prior}}, \pi_\theta]\]N.B. in order to unclutter the notation we abbreviated \(\pi(\cdot|s)\) by \(\pi\).
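To see the chain rule behind this objective at work, here is a toy numpy sketch (hypothetical quadratic q-function and linear deterministic policy, not the coax implementation): ascending \(J(\theta)=q_\text{targ}(s, a_\theta(s))\) uses \(\partial J/\partial\theta = \partial q/\partial a \cdot \partial a_\theta/\partial\theta\):

```python
import numpy as np

rng = np.random.default_rng(0)

def q_targ(s, a):
    """Toy target q-function, maximized at a = 2 * s."""
    return -(a - 2.0 * s) ** 2

def dq_da(s, a):
    """Analytic derivative of q_targ with respect to the action."""
    return -2.0 * (a - 2.0 * s)

# Deterministic linear policy a_theta(s) = theta * s.  Gradient ascent on
# J(theta) = q_targ(s, a_theta(s)) via the chain rule: dJ/dtheta = dq/da * s.
theta, lr = 0.0, 0.05
for _ in range(500):
    s = rng.uniform(0.5, 1.5)
    a = theta * s
    theta += lr * dq_da(s, a) * s

assert abs(theta - 2.0) < 0.05  # the policy converges to the q-maximizing action
```

The update pushes the policy's action directly uphill on \(q_\text{targ}\), which is the "explicitly maximizes the q-value" property noted above.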
 Parameters:
pi (Policy) – The parametrized policy \(\pi_\theta(a|s)\).
q_targ (Q) – The target state-action value function \(q_\text{targ}(s,a)\).
optimizer (optax optimizer, optional) – An optax-style optimizer. The default optimizer is optax.adam(1e-3).
regularizer (Regularizer, optional) – A policy regularizer, see coax.regularizers.
 apply_grads(grads, function_state)¶
Update the model parameters (weights) of the underlying function approximator given precomputed gradients.
This method is useful in situations in which computation of the gradients is delegated to a separate (remote) process.
 Parameters:
grads (pytree with ndarray leaves) – A batch of gradients, generated by the grads method.
function_state (pytree) – The internal state of the forward-pass function. See Policy.function_state and haiku.transform_with_state() for more details.
 grads_and_metrics(transition_batch, Adv=None)[source]¶
Compute the gradients associated with a batch of transitions with corresponding advantages.
 Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
Adv (ndarray, ignored) – This input is ignored; it is included for consistency with other policy objectives.
 Returns:
grads (pytree with ndarray leaves) – A batch of gradients.
function_state (pytree) – The internal state of the forward-pass function. See Policy.function_state and haiku.transform_with_state() for more details.
metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.
 update(transition_batch, Adv=None)[source]¶
Update the model parameters (weights) of the underlying function approximator.
 Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
Adv (ndarray, ignored) – This input is ignored; it is included for consistency with other policy objectives.
 Returns:
metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.
 class coax.policy_objectives.SoftPG(pi, q_targ_list, optimizer=None, regularizer=None)[source]¶
Soft policy-gradient objective, as used in soft actor-critic style agents; note that it takes a list of target q-functions q_targ_list in place of the single q_targ used by DeterministicPG.
 apply_grads(grads, function_state)¶
Update the model parameters (weights) of the underlying function approximator given precomputed gradients.
This method is useful in situations in which computation of the gradients is delegated to a separate (remote) process.
 Parameters:
grads (pytree with ndarray leaves) – A batch of gradients, generated by the grads method.
function_state (pytree) – The internal state of the forward-pass function. See Policy.function_state and haiku.transform_with_state() for more details.
 grads_and_metrics(transition_batch, Adv=None)[source]¶
Compute the gradients associated with a batch of transitions with corresponding advantages.
 Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
Adv (ndarray, ignored) – This input is ignored; it is included for consistency with other policy objectives.
 Returns:
grads (pytree with ndarray leaves) – A batch of gradients.
function_state (pytree) – The internal state of the forward-pass function. See Policy.function_state and haiku.transform_with_state() for more details.
metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.
 update(transition_batch, Adv=None)[source]¶
Update the model parameters (weights) of the underlying function approximator.
 Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
Adv (ndarray, ignored) – This input is ignored; it is included for consistency with other policy objectives.
 Returns:
metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.