TD Learning¶
TD learning for state value functions \(v(s)\).

TD learning with SARSA updates.

TD learning with expected-SARSA updates.

TD learning with Q-learning updates.

TD learning with Double-DQN style double q-learning updates, in which the target network is only used in selecting the would-be next action.

TD learning with soft Q-learning updates.

TD learning with TD3 style double q-learning updates, in which the target network is only used in selecting the would-be next action.

This is a collection of objects that are used to update value functions via Temporal Difference
(TD) learning. A state value function coax.V is updated using coax.td_learning.SimpleTD.
To update a state-action value function coax.Q, there are multiple options available. The
difference between the options is the manner in which the TD-target is constructed.
Object Reference¶
 class coax.td_learning.SimpleTD(v, v_targ=None, optimizer=None, loss_function=None, policy_regularizer=None)[source]¶
TD-learning for state value functions \(v(s)\). The \(n\)-step bootstrapped target is constructed as:
\[G^{(n)}_t\ =\ R^{(n)}_t + I^{(n)}_t\,v_\text{targ}(S_{t+n})\]where
\[\begin{split}R^{(n)}_t\ &=\ \sum_{k=0}^{n-1}\gamma^kR_{t+k} \\ I^{(n)}_t\ &=\ \left\{\begin{matrix} 0 & \text{if $S_{t+n}$ is a terminal state} \\ \gamma^n & \text{otherwise} \end{matrix}\right.\end{split}\] Parameters:
v (V) – The main state value function to update.
v_targ (V, optional) – The state value function that is used for constructing the TD-target. If this is left unspecified, we set v_targ = v internally.
optimizer (optax optimizer, optional) – An optax-style optimizer. The default optimizer is optax.adam(1e-3).
loss_function (callable, optional) – The loss function that will be used to regress to the (bootstrapped) target. The loss function is expected to be of the form:
\[L(y_\text{true}, y_\text{pred}, w)\in\mathbb{R}\]
where \(w>0\) are sample weights. If left unspecified, this defaults to coax.value_losses.huber(). Check out the coax.value_losses module for other predefined loss functions.
policy_regularizer (Regularizer, optional) – If provided, this policy regularizer is added to the TD-target. A typical example is to use a coax.regularizers.EntropyRegularizer, which adds the policy entropy to the target. In this case, we minimize the following loss shifted by the entropy term:
\[L(y_\text{true} + \beta\,H[\pi], y_\text{pred})\]
Note that the coefficient \(\beta\) plays the role of the temperature in SAC-style agents.
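To make the target construction concrete, here is a minimal plain-Python sketch of the \(n\)-step bootstrapped target for a state value function. The helper names are illustrative, not part of the coax API:

```python
def nstep_return(rewards, gamma):
    # R^(n)_t = sum_{k=0}^{n-1} gamma^k * R_{t+k}
    return sum(gamma ** k * r for k, r in enumerate(rewards))

def simple_td_target(rewards, gamma, v_next, done):
    # G^(n)_t = R^(n)_t + I^(n)_t * v_targ(S_{t+n}),
    # where I^(n)_t = 0 if S_{t+n} is terminal, gamma^n otherwise
    n = len(rewards)
    bootstrap = 0.0 if done else gamma ** n * v_next
    return nstep_return(rewards, gamma) + bootstrap

# two-step example: rewards [1.0, 0.5], gamma = 0.9, v_targ(S_{t+2}) = 3.0
target = simple_td_target([1.0, 0.5], 0.9, v_next=3.0, done=False)
# 1.0 + 0.9*0.5 + 0.81*3.0 = 3.88
```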
 apply_grads(grads, function_state)¶
Update the model parameters (weights) of the underlying function approximator given precomputed gradients.
This method is useful in situations in which computation of the gradients is delegated to a separate (remote) process.
 Parameters:
grads (pytree with ndarray leaves) – A batch of gradients, generated by the grads method.
function_state (pytree) – The internal state of the forward-pass function. See Q.function_state and haiku.transform_with_state() for more details.
 grads_and_metrics(transition_batch)¶
Compute the gradients associated with a batch of transitions.
 Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
 Returns:
grads (pytree with ndarray leaves) – A batch of gradients.
function_state (pytree) – The internal state of the forward-pass function. See Q.function_state and haiku.transform_with_state() for more details.
metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.
td_error (ndarray) – The non-aggregated TD-errors, shape == (batch_size,).
 td_error(transition_batch)¶
Compute the TD-errors associated with a batch of transitions. We define the TD-error as the negative gradient of the loss_function with respect to the predicted value:
\[\text{td_error}_i\ =\ -\frac{\partial L(y, \hat{y})}{\partial \hat{y}_i}\]
Note that this reduces to the ordinary definition \(\text{td_error}=y-\hat{y}\) when we use the coax.value_losses.mse() loss function.
 Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
 Returns:
td_errors (ndarray, shape: [batch_size]) – A batch of TD-errors.
 update(transition_batch, return_td_error=False)¶
Update the model parameters (weights) of the underlying function approximator.
 Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
return_td_error (bool, optional) – Whether to return the TD-errors.
 Returns:
metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.
td_error (ndarray, optional) – The non-aggregated TD-errors, shape == (batch_size,). This is only returned if we set return_td_error=True.
 class coax.td_learning.Sarsa(q, q_targ=None, optimizer=None, loss_function=None, policy_regularizer=None)[source]¶
TD-learning with SARSA updates. The \(n\)-step bootstrapped target is constructed as:
\[G^{(n)}_t\ =\ R^{(n)}_t + I^{(n)}_t\,q_\text{targ}(S_{t+n}, A_{t+n})\]where \(A_{t+n}\) is sampled from experience and
\[\begin{split}R^{(n)}_t\ &=\ \sum_{k=0}^{n-1}\gamma^kR_{t+k} \\ I^{(n)}_t\ &=\ \left\{\begin{matrix} 0 & \text{if $S_{t+n}$ is a terminal state} \\ \gamma^n & \text{otherwise} \end{matrix}\right.\end{split}\] Parameters:
q (Q) – The main q-function to update.
q_targ (Q, optional) – The q-function that is used for constructing the TD-target. If this is left unspecified, we set q_targ = q internally.
optimizer (optax optimizer, optional) – An optax-style optimizer. The default optimizer is optax.adam(1e-3).
loss_function (callable, optional) – The loss function that will be used to regress to the (bootstrapped) target. The loss function is expected to be of the form:
\[L(y_\text{true}, y_\text{pred}, w)\in\mathbb{R}\]
where \(w>0\) are sample weights. If left unspecified, this defaults to coax.value_losses.huber(). Check out the coax.value_losses module for other predefined loss functions.
policy_regularizer (Regularizer, optional) – If provided, this policy regularizer is added to the TD-target. A typical example is to use a coax.regularizers.EntropyRegularizer, which adds the policy entropy to the target. In this case, we minimize the following loss shifted by the entropy term:
\[L(y_\text{true} + \beta\,H[\pi], y_\text{pred})\]
Note that the coefficient \(\beta\) plays the role of the temperature in SAC-style agents.
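The SARSA target bootstraps on the q-value of the next action actually taken in the episode. A minimal plain-Python sketch (illustrative names, not the coax API):

```python
def sarsa_target(rewards, gamma, q_targ, s_next, a_next, done):
    # G^(n)_t = R^(n)_t + I^(n)_t * q_targ(S_{t+n}, A_{t+n}),
    # with A_{t+n} taken from the recorded experience
    n = len(rewards)
    r_n = sum(gamma ** k * r for k, r in enumerate(rewards))
    return r_n + (0.0 if done else gamma ** n * q_targ[s_next, a_next])

# tabular target q-function; the agent happened to take 'left' next
q_targ = {('s2', 'left'): 1.5, ('s2', 'right'): 2.5}
target = sarsa_target([1.0], 0.9, q_targ, 's2', 'left', done=False)
# 1.0 + 0.9 * 1.5 = 2.35
```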
 apply_grads(grads, function_state)¶
Update the model parameters (weights) of the underlying function approximator given precomputed gradients.
This method is useful in situations in which computation of the gradients is delegated to a separate (remote) process.
 Parameters:
grads (pytree with ndarray leaves) – A batch of gradients, generated by the grads method.
function_state (pytree) – The internal state of the forward-pass function. See Q.function_state and haiku.transform_with_state() for more details.
 grads_and_metrics(transition_batch)¶
Compute the gradients associated with a batch of transitions.
 Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
 Returns:
grads (pytree with ndarray leaves) – A batch of gradients.
function_state (pytree) – The internal state of the forward-pass function. See Q.function_state and haiku.transform_with_state() for more details.
metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.
td_error (ndarray) – The non-aggregated TD-errors, shape == (batch_size,).
 td_error(transition_batch)¶
Compute the TD-errors associated with a batch of transitions. We define the TD-error as the negative gradient of the loss_function with respect to the predicted value:
\[\text{td_error}_i\ =\ -\frac{\partial L(y, \hat{y})}{\partial \hat{y}_i}\]
Note that this reduces to the ordinary definition \(\text{td_error}=y-\hat{y}\) when we use the coax.value_losses.mse() loss function.
 Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
 Returns:
td_errors (ndarray, shape: [batch_size]) – A batch of TD-errors.
 update(transition_batch, return_td_error=False)¶
Update the model parameters (weights) of the underlying function approximator.
 Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
return_td_error (bool, optional) – Whether to return the TD-errors.
 Returns:
metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.
td_error (ndarray, optional) – The non-aggregated TD-errors, shape == (batch_size,). This is only returned if we set return_td_error=True.
 class coax.td_learning.ExpectedSarsa(q, pi_targ, q_targ=None, optimizer=None, loss_function=None, policy_regularizer=None)[source]¶
TD-learning with expected-SARSA updates. The \(n\)-step bootstrapped target is constructed as:
\[G^{(n)}_t\ =\ R^{(n)}_t + I^{(n)}_t\,\mathop{\mathbb{E}}_{a\sim\pi_\text{targ}(.|S_{t+n})}\, q_\text{targ}\left(S_{t+n}, a\right)\]Note that the ordinary SARSA target is the sampled estimate of the above target. Also, as usual, the \(n\)-step reward and indicator are defined as:
\[\begin{split}R^{(n)}_t\ &=\ \sum_{k=0}^{n-1}\gamma^kR_{t+k} \\ I^{(n)}_t\ &=\ \left\{\begin{matrix} 0 & \text{if $S_{t+n}$ is a terminal state} \\ \gamma^n & \text{otherwise} \end{matrix}\right.\end{split}\] Parameters:
q (Q) – The main q-function to update.
pi_targ (Policy) – The policy that is used for constructing the TD-target.
q_targ (Q, optional) – The q-function that is used for constructing the TD-target. If this is left unspecified, we set q_targ = q internally.
optimizer (optax optimizer, optional) – An optax-style optimizer. The default optimizer is optax.adam(1e-3).
loss_function (callable, optional) – The loss function that will be used to regress to the (bootstrapped) target. The loss function is expected to be of the form:
\[L(y_\text{true}, y_\text{pred}, w)\in\mathbb{R}\]
where \(w>0\) are sample weights. If left unspecified, this defaults to coax.value_losses.huber(). Check out the coax.value_losses module for other predefined loss functions.
policy_regularizer (Regularizer, optional) – If provided, this policy regularizer is added to the TD-target. A typical example is to use a coax.regularizers.EntropyRegularizer, which adds the policy entropy to the target. In this case, we minimize the following loss shifted by the entropy term:
\[L(y_\text{true} + \beta\,H[\pi], y_\text{pred})\]
Note that the coefficient \(\beta\) plays the role of the temperature in SAC-style agents.
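Instead of bootstrapping on one sampled next action, the expected-SARSA target averages the target q-values under pi_targ. A minimal plain-Python sketch (illustrative names, not the coax API):

```python
def expected_sarsa_target(rewards, gamma, q_targ, pi_targ, s_next, done):
    # bootstrap on E_{a ~ pi_targ(.|S_{t+n})} q_targ(S_{t+n}, a)
    n = len(rewards)
    r_n = sum(gamma ** k * r for k, r in enumerate(rewards))
    exp_q = sum(p * q_targ[s_next, a] for a, p in pi_targ[s_next].items())
    return r_n + (0.0 if done else gamma ** n * exp_q)

q_targ = {('s2', 'left'): 1.5, ('s2', 'right'): 2.5}
pi_targ = {'s2': {'left': 0.2, 'right': 0.8}}
target = expected_sarsa_target([1.0], 0.9, q_targ, pi_targ, 's2', done=False)
# 1.0 + 0.9 * (0.2*1.5 + 0.8*2.5) = 3.07
```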
 apply_grads(grads, function_state)¶
Update the model parameters (weights) of the underlying function approximator given precomputed gradients.
This method is useful in situations in which computation of the gradients is delegated to a separate (remote) process.
 Parameters:
grads (pytree with ndarray leaves) – A batch of gradients, generated by the grads method.
function_state (pytree) – The internal state of the forward-pass function. See Q.function_state and haiku.transform_with_state() for more details.
 grads_and_metrics(transition_batch)¶
Compute the gradients associated with a batch of transitions.
 Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
 Returns:
grads (pytree with ndarray leaves) – A batch of gradients.
function_state (pytree) – The internal state of the forward-pass function. See Q.function_state and haiku.transform_with_state() for more details.
metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.
td_error (ndarray) – The non-aggregated TD-errors, shape == (batch_size,).
 td_error(transition_batch)¶
Compute the TD-errors associated with a batch of transitions. We define the TD-error as the negative gradient of the loss_function with respect to the predicted value:
\[\text{td_error}_i\ =\ -\frac{\partial L(y, \hat{y})}{\partial \hat{y}_i}\]
Note that this reduces to the ordinary definition \(\text{td_error}=y-\hat{y}\) when we use the coax.value_losses.mse() loss function.
 Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
 Returns:
td_errors (ndarray, shape: [batch_size]) – A batch of TD-errors.
 update(transition_batch, return_td_error=False)¶
Update the model parameters (weights) of the underlying function approximator.
 Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
return_td_error (bool, optional) – Whether to return the TD-errors.
 Returns:
metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.
td_error (ndarray, optional) – The non-aggregated TD-errors, shape == (batch_size,). This is only returned if we set return_td_error=True.
 class coax.td_learning.QLearning(q, pi_targ=None, q_targ=None, optimizer=None, loss_function=None, policy_regularizer=None)[source]¶
TD-learning with Q-learning updates.
The \(n\)-step bootstrapped target for discrete actions is constructed as:
\[G^{(n)}_t\ =\ R^{(n)}_t + I^{(n)}_t\,\max_a q_\text{targ}\left(S_{t+n}, a\right)\]For non-discrete action spaces, this uses a DDPG-style target:
\[G^{(n)}_t\ =\ R^{(n)}_t + I^{(n)}_t\,q_\text{targ}\left( S_{t+n}, a_\text{targ}(S_{t+n})\right)\]where \(a_\text{targ}(s)\) is the mode of the underlying conditional probability distribution \(\pi_\text{targ}(.|s)\). Even though these two formulations of the q-learning target are equivalent, the implementation of the latter does require additional input, namely pi_targ.
The \(n\)-step reward and indicator (referenced above) are defined as:
\[\begin{split}R^{(n)}_t\ &=\ \sum_{k=0}^{n-1}\gamma^kR_{t+k} \\ I^{(n)}_t\ &=\ \left\{\begin{matrix} 0 & \text{if $S_{t+n}$ is a terminal state} \\ \gamma^n & \text{otherwise} \end{matrix}\right.\end{split}\] Parameters:
q (Q) – The main q-function to update.
pi_targ (Policy, optional) – The policy that is used for constructing the TD-target. This is ignored if the action space is discrete and required otherwise.
q_targ (Q, optional) – The q-function that is used for constructing the TD-target. If this is left unspecified, we set q_targ = q internally.
optimizer (optax optimizer, optional) – An optax-style optimizer. The default optimizer is optax.adam(1e-3).
loss_function (callable, optional) – The loss function that will be used to regress to the (bootstrapped) target. The loss function is expected to be of the form:
\[L(y_\text{true}, y_\text{pred}, w)\in\mathbb{R}\]
where \(w>0\) are sample weights. If left unspecified, this defaults to coax.value_losses.huber(). Check out the coax.value_losses module for other predefined loss functions.
policy_regularizer (Regularizer, optional) – If provided, this policy regularizer is added to the TD-target. A typical example is to use a coax.regularizers.EntropyRegularizer, which adds the policy entropy to the target. In this case, we minimize the following loss shifted by the entropy term:
\[L(y_\text{true} + \beta\,H[\pi], y_\text{pred})\]
Note that the coefficient \(\beta\) plays the role of the temperature in SAC-style agents.
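For discrete actions, the Q-learning target bootstraps on the greedy value of the next state. A minimal plain-Python sketch (illustrative names, not the coax API):

```python
def q_learning_target(rewards, gamma, q_targ, s_next, done):
    # discrete actions: bootstrap on max_a q_targ(S_{t+n}, a)
    n = len(rewards)
    r_n = sum(gamma ** k * r for k, r in enumerate(rewards))
    max_q = max(q_targ[s_next].values())
    return r_n + (0.0 if done else gamma ** n * max_q)

q_targ = {'s2': {'left': 1.5, 'right': 2.5}}
target = q_learning_target([1.0], 0.9, q_targ, 's2', done=False)
# 1.0 + 0.9 * 2.5 = 3.25
```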
 apply_grads(grads, function_state)¶
Update the model parameters (weights) of the underlying function approximator given precomputed gradients.
This method is useful in situations in which computation of the gradients is delegated to a separate (remote) process.
 Parameters:
grads (pytree with ndarray leaves) – A batch of gradients, generated by the grads method.
function_state (pytree) – The internal state of the forward-pass function. See Q.function_state and haiku.transform_with_state() for more details.
 grads_and_metrics(transition_batch)¶
Compute the gradients associated with a batch of transitions.
 Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
 Returns:
grads (pytree with ndarray leaves) – A batch of gradients.
function_state (pytree) – The internal state of the forward-pass function. See Q.function_state and haiku.transform_with_state() for more details.
metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.
td_error (ndarray) – The non-aggregated TD-errors, shape == (batch_size,).
 td_error(transition_batch)¶
Compute the TD-errors associated with a batch of transitions. We define the TD-error as the negative gradient of the loss_function with respect to the predicted value:
\[\text{td_error}_i\ =\ -\frac{\partial L(y, \hat{y})}{\partial \hat{y}_i}\]
Note that this reduces to the ordinary definition \(\text{td_error}=y-\hat{y}\) when we use the coax.value_losses.mse() loss function.
 Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
 Returns:
td_errors (ndarray, shape: [batch_size]) – A batch of TD-errors.
 update(transition_batch, return_td_error=False)¶
Update the model parameters (weights) of the underlying function approximator.
 Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
return_td_error (bool, optional) – Whether to return the TD-errors.
 Returns:
metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.
td_error (ndarray, optional) – The non-aggregated TD-errors, shape == (batch_size,). This is only returned if we set return_td_error=True.
 class coax.td_learning.DoubleQLearning(q, pi_targ=None, q_targ=None, optimizer=None, loss_function=None, policy_regularizer=None)[source]¶
TD-learning with Double-DQN style double q-learning updates, in which the target network is only used in selecting the would-be next action. The \(n\)-step bootstrapped target is thus constructed as:
\[\begin{split}a_\text{greedy}\ &=\ \arg\max_a q_\text{targ}(S_{t+n}, a) \\ G^{(n)}_t\ &=\ R^{(n)}_t + I^{(n)}_t\,q(S_{t+n}, a_\text{greedy})\end{split}\]where
\[\begin{split}R^{(n)}_t\ &=\ \sum_{k=0}^{n-1}\gamma^kR_{t+k} \\ I^{(n)}_t\ &=\ \left\{\begin{matrix} 0 & \text{if $S_{t+n}$ is a terminal state} \\ \gamma^n & \text{otherwise} \end{matrix}\right.\end{split}\] Parameters:
q (Q) – The main q-function to update.
pi_targ (Policy, optional) – The policy that is used for constructing the TD-target. This is ignored if the action space is discrete and required otherwise.
q_targ (Q, optional) – The q-function that is used for constructing the TD-target. If this is left unspecified, we set q_targ = q internally.
optimizer (optax optimizer, optional) – An optax-style optimizer. The default optimizer is optax.adam(1e-3).
loss_function (callable, optional) – The loss function that will be used to regress to the (bootstrapped) target. The loss function is expected to be of the form:
\[L(y_\text{true}, y_\text{pred}, w)\in\mathbb{R}\]
where \(w>0\) are sample weights. If left unspecified, this defaults to coax.value_losses.huber(). Check out the coax.value_losses module for other predefined loss functions.
policy_regularizer (Regularizer, optional) – If provided, this policy regularizer is added to the TD-target. A typical example is to use a coax.regularizers.EntropyRegularizer, which adds the policy entropy to the target. In this case, we minimize the following loss shifted by the entropy term:
\[L(y_\text{true} + \beta\,H[\pi], y_\text{pred})\]
Note that the coefficient \(\beta\) plays the role of the temperature in SAC-style agents.
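Following the formula above, one network selects the greedy action and the other evaluates it, which decouples action selection from value estimation. A minimal plain-Python sketch for discrete actions (illustrative names, not the coax API):

```python
def double_q_target(rewards, gamma, q_main, q_targ, s_next, done):
    # a_greedy = argmax_a q_targ(S_{t+n}, a); bootstrap on q(S_{t+n}, a_greedy)
    n = len(rewards)
    r_n = sum(gamma ** k * r for k, r in enumerate(rewards))
    a_greedy = max(q_targ[s_next], key=q_targ[s_next].get)
    return r_n + (0.0 if done else gamma ** n * q_main[s_next][a_greedy])

q_targ = {'s2': {'left': 1.5, 'right': 2.5}}   # selects 'right'
q_main = {'s2': {'left': 1.8, 'right': 2.1}}   # evaluates 'right'
target = double_q_target([1.0], 0.9, q_main, q_targ, 's2', done=False)
# 1.0 + 0.9 * 2.1 = 2.89
```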
 apply_grads(grads, function_state)¶
Update the model parameters (weights) of the underlying function approximator given precomputed gradients.
This method is useful in situations in which computation of the gradients is delegated to a separate (remote) process.
 Parameters:
grads (pytree with ndarray leaves) – A batch of gradients, generated by the grads method.
function_state (pytree) – The internal state of the forward-pass function. See Q.function_state and haiku.transform_with_state() for more details.
 grads_and_metrics(transition_batch)¶
Compute the gradients associated with a batch of transitions.
 Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
 Returns:
grads (pytree with ndarray leaves) – A batch of gradients.
function_state (pytree) – The internal state of the forward-pass function. See Q.function_state and haiku.transform_with_state() for more details.
metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.
td_error (ndarray) – The non-aggregated TD-errors, shape == (batch_size,).
 td_error(transition_batch)¶
Compute the TD-errors associated with a batch of transitions. We define the TD-error as the negative gradient of the loss_function with respect to the predicted value:
\[\text{td_error}_i\ =\ -\frac{\partial L(y, \hat{y})}{\partial \hat{y}_i}\]
Note that this reduces to the ordinary definition \(\text{td_error}=y-\hat{y}\) when we use the coax.value_losses.mse() loss function.
 Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
 Returns:
td_errors (ndarray, shape: [batch_size]) – A batch of TD-errors.
 update(transition_batch, return_td_error=False)¶
Update the model parameters (weights) of the underlying function approximator.
 Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
return_td_error (bool, optional) – Whether to return the TD-errors.
 Returns:
metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.
td_error (ndarray, optional) – The non-aggregated TD-errors, shape == (batch_size,). This is only returned if we set return_td_error=True.
 class coax.td_learning.SoftQLearning(q, q_targ=None, optimizer=None, loss_function=None, policy_regularizer=None, temperature=1.0)[source]¶
TD-learning with soft Q-learning updates. The \(n\)-step bootstrapped target is constructed as:
\[G^{(n)}_t\ =\ R^{(n)}_t + I^{(n)}_t\,\tau\log\sum_{a'}\exp\left(q_\text{targ}(S_{t+n}, a') / \tau\right)\]where
\[\begin{split}R^{(n)}_t\ &=\ \sum_{k=0}^{n-1}\gamma^kR_{t+k} \\ I^{(n)}_t\ &=\ \left\{\begin{matrix} 0 & \text{if $S_{t+n}$ is a terminal state} \\ \gamma^n & \text{otherwise} \end{matrix}\right.\end{split}\] Parameters:
q (Q) – The main q-function to update.
q_targ (Q, optional) – The q-function that is used for constructing the TD-target. If this is left unspecified, we set q_targ = q internally.
optimizer (optax optimizer, optional) – An optax-style optimizer. The default optimizer is optax.adam(1e-3).
loss_function (callable, optional) – The loss function that will be used to regress to the (bootstrapped) target. The loss function is expected to be of the form:
\[L(y_\text{true}, y_\text{pred}, w)\in\mathbb{R}\]
where \(w>0\) are sample weights. If left unspecified, this defaults to coax.value_losses.huber(). Check out the coax.value_losses module for other predefined loss functions.
policy_regularizer (Regularizer, optional) – If provided, this policy regularizer is added to the TD-target. A typical example is to use a coax.regularizers.EntropyRegularizer, which adds the policy entropy to the target. In this case, we minimize the following loss shifted by the entropy term:
\[L(y_\text{true} + \beta\,H[\pi], y_\text{pred})\]
Note that the coefficient \(\beta\) plays the role of the temperature in SAC-style agents.
temperature (float, optional) – The Boltzmann temperature \(\tau>0\).
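The soft Q-learning target replaces the hard max over next actions with a temperature-scaled log-sum-exp, which approaches the hard max as \(\tau \to 0\). A minimal plain-Python sketch (illustrative names, not the coax API):

```python
import math

def soft_q_target(rewards, gamma, q_targ, s_next, done, tau=1.0):
    # bootstrap on tau * log sum_a' exp(q_targ(S_{t+n}, a') / tau)
    n = len(rewards)
    r_n = sum(gamma ** k * r for k, r in enumerate(rewards))
    soft_max = tau * math.log(
        sum(math.exp(q / tau) for q in q_targ[s_next].values()))
    return r_n + (0.0 if done else gamma ** n * soft_max)

q_targ = {'s2': {'left': 1.5, 'right': 2.5}}
target_soft = soft_q_target([1.0], 0.9, q_targ, 's2', False, tau=1.0)
# with a small temperature the soft maximum is close to the hard max (2.5)
target_cold = soft_q_target([1.0], 0.9, q_targ, 's2', False, tau=0.01)
```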
 apply_grads(grads, function_state)¶
Update the model parameters (weights) of the underlying function approximator given precomputed gradients.
This method is useful in situations in which computation of the gradients is delegated to a separate (remote) process.
 Parameters:
grads (pytree with ndarray leaves) – A batch of gradients, generated by the grads method.
function_state (pytree) – The internal state of the forward-pass function. See Q.function_state and haiku.transform_with_state() for more details.
 grads_and_metrics(transition_batch)¶
Compute the gradients associated with a batch of transitions.
 Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
 Returns:
grads (pytree with ndarray leaves) – A batch of gradients.
function_state (pytree) – The internal state of the forward-pass function. See Q.function_state and haiku.transform_with_state() for more details.
metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.
td_error (ndarray) – The non-aggregated TD-errors, shape == (batch_size,).
 td_error(transition_batch)¶
Compute the TD-errors associated with a batch of transitions. We define the TD-error as the negative gradient of the loss_function with respect to the predicted value:
\[\text{td_error}_i\ =\ -\frac{\partial L(y, \hat{y})}{\partial \hat{y}_i}\]
Note that this reduces to the ordinary definition \(\text{td_error}=y-\hat{y}\) when we use the coax.value_losses.mse() loss function.
 Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
 Returns:
td_errors (ndarray, shape: [batch_size]) – A batch of TD-errors.
 update(transition_batch, return_td_error=False)¶
Update the model parameters (weights) of the underlying function approximator.
 Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
return_td_error (bool, optional) – Whether to return the TD-errors.
 Returns:
metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.
td_error (ndarray, optional) – The non-aggregated TD-errors, shape == (batch_size,). This is only returned if we set return_td_error=True.
 class coax.td_learning.ClippedDoubleQLearning(q, pi_targ_list=None, q_targ_list=None, optimizer=None, loss_function=None, policy_regularizer=None)[source]¶
TD-learning with TD3 style double q-learning updates, in which the target network is only used in selecting the would-be next action.
For discrete actions, the \(n\)-step bootstrapped target is constructed as:
\[G^{(n)}_t\ =\ R^{(n)}_t + I^{(n)}_t\,\min_{i,j}q_i(S_{t+n}, \arg\max_a q_j(S_{t+n}, a))\]where \(q_i(s,a)\) is the \(i\)-th target q-function provided in q_targ_list.
Similarly, for non-discrete actions, the target is constructed as:
\[G^{(n)}_t\ =\ R^{(n)}_t + I^{(n)}_t\,\min_{i,j}q_i(S_{t+n}, a_j(S_{t+n}))\]where \(a_j(s)\) is the mode of the \(j\)-th target policy provided in pi_targ_list.
As usual, the \(n\)-step reward and indicator are defined as:
\[\begin{split}R^{(n)}_t\ &=\ \sum_{k=0}^{n-1}\gamma^kR_{t+k} \\ I^{(n)}_t\ &=\ \left\{\begin{matrix} 0 & \text{if $S_{t+n}$ is a terminal state} \\ \gamma^n & \text{otherwise} \end{matrix}\right.\end{split}\] Parameters:
q (Q) – The main q-function to update.
pi_targ_list (list of Policy, optional) – The list of policies that are used for constructing the TD-target. This is ignored if the action space is discrete and required otherwise.
q_targ_list (list of Q) – The list of q-functions that are used for constructing the TD-target.
optimizer (optax optimizer, optional) – An optax-style optimizer. The default optimizer is optax.adam(1e-3).
loss_function (callable, optional) – The loss function that will be used to regress to the (bootstrapped) target. The loss function is expected to be of the form:
\[L(y_\text{true}, y_\text{pred}, w)\in\mathbb{R}\]
where \(w>0\) are sample weights. If left unspecified, this defaults to coax.value_losses.huber(). Check out the coax.value_losses module for other predefined loss functions.
policy_regularizer (Regularizer, optional) – If provided, this policy regularizer is added to the TD-target. A typical example is to use a coax.regularizers.EntropyRegularizer, which adds the policy entropy to the target. In this case, we minimize the following loss shifted by the entropy term:
\[L(y_\text{true} + \beta\,H[\pi], y_\text{pred})\]
Note that the coefficient \(\beta\) plays the role of the temperature in SAC-style agents.
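The clipped target takes a pessimistic minimum over every (evaluator, selector) pair of target q-functions, which counteracts the overestimation bias of a single max. A minimal plain-Python sketch for discrete actions (illustrative names, not the coax API):

```python
def clipped_double_q_target(rewards, gamma, q_targ_list, s_next, done):
    # discrete actions: bootstrap on
    #   min_{i,j} q_i(S_{t+n}, argmax_a q_j(S_{t+n}, a))
    n = len(rewards)
    r_n = sum(gamma ** k * r for k, r in enumerate(rewards))
    clipped = min(
        q_i[s_next][max(q_j[s_next], key=q_j[s_next].get)]
        for q_i in q_targ_list for q_j in q_targ_list)
    return r_n + (0.0 if done else gamma ** n * clipped)

q1 = {'s2': {'left': 1.5, 'right': 2.5}}  # argmax: 'right'
q2 = {'s2': {'left': 2.2, 'right': 2.1}}  # argmax: 'left'
target = clipped_double_q_target([1.0], 0.9, [q1, q2], 's2', done=False)
# candidates: q1@'right'=2.5, q1@'left'=1.5, q2@'right'=2.1, q2@'left'=2.2
# min is 1.5, so the target is 1.0 + 0.9 * 1.5 = 2.35
```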
 apply_grads(grads, function_state)¶
Update the model parameters (weights) of the underlying function approximator given precomputed gradients.
This method is useful in situations in which computation of the gradients is delegated to a separate (remote) process.
 Parameters:
grads (pytree with ndarray leaves) – A batch of gradients, generated by the grads method.
function_state (pytree) – The internal state of the forward-pass function. See Q.function_state and haiku.transform_with_state() for more details.
 grads_and_metrics(transition_batch)¶
Compute the gradients associated with a batch of transitions.
 Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
 Returns:
grads (pytree with ndarray leaves) – A batch of gradients.
function_state (pytree) – The internal state of the forward-pass function. See Q.function_state and haiku.transform_with_state() for more details.
metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.
td_error (ndarray) – The non-aggregated TD-errors, shape == (batch_size,).
 td_error(transition_batch)¶
Compute the TD-errors associated with a batch of transitions. We define the TD-error as the negative gradient of the loss_function with respect to the predicted value:
\[\text{td_error}_i\ =\ -\frac{\partial L(y, \hat{y})}{\partial \hat{y}_i}\]
Note that this reduces to the ordinary definition \(\text{td_error}=y-\hat{y}\) when we use the coax.value_losses.mse() loss function.
 Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
 Returns:
td_errors (ndarray, shape: [batch_size]) – A batch of TD-errors.
 update(transition_batch, return_td_error=False)¶
Update the model parameters (weights) of the underlying function approximator.
 Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
return_td_error (bool, optional) – Whether to return the TD-errors.
 Returns:
metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.
td_error (ndarray, optional) – The non-aggregated TD-errors, shape == (batch_size,). This is only returned if we set return_td_error=True.
 class coax.td_learning.SoftClippedDoubleQLearning(q, pi_targ_list=None, q_targ_list=None, optimizer=None, loss_function=None, policy_regularizer=None)[source]¶
TD-learning with soft clipped double q-learning updates. This works like ClippedDoubleQLearning, except that the action for the next state is sampled from the target policy instead of taking its mode (see target_func).
 apply_grads(grads, function_state)¶
Update the model parameters (weights) of the underlying function approximator given precomputed gradients.
This method is useful in situations in which computation of the gradients is delegated to a separate (remote) process.
 Parameters:
grads (pytree with ndarray leaves) – A batch of gradients, generated by the grads method.
function_state (pytree) – The internal state of the forward-pass function. See Q.function_state and haiku.transform_with_state() for more details.
 grads_and_metrics(transition_batch)¶
Compute the gradients associated with a batch of transitions.
 Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
 Returns:
grads (pytree with ndarray leaves) – A batch of gradients.
function_state (pytree) – The internal state of the forward-pass function. See Q.function_state and haiku.transform_with_state() for more details.
metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.
td_error (ndarray) – The non-aggregated TD-errors, shape == (batch_size,).
 target_func(target_params, target_state, rng, transition_batch)[source]¶
This does almost the same as ClippedDoubleQLearning.target_func except that the action for the next state is sampled instead of taking the mode.
 td_error(transition_batch)¶
Compute the TD-errors associated with a batch of transitions. We define the TD-error as the negative gradient of the loss_function with respect to the predicted value:
\[\text{td_error}_i\ =\ -\frac{\partial L(y, \hat{y})}{\partial \hat{y}_i}\]
Note that this reduces to the ordinary definition \(\text{td_error}=y-\hat{y}\) when we use the coax.value_losses.mse() loss function.
 Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
 Returns:
td_errors (ndarray, shape: [batch_size]) – A batch of TD-errors.
 update(transition_batch, return_td_error=False)¶
Update the model parameters (weights) of the underlying function approximator.
 Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
return_td_error (bool, optional) – Whether to return the TD-errors.
 Returns:
metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.
td_error (ndarray, optional) – The non-aggregated TD-errors, shape == (batch_size,). This is only returned if we set return_td_error=True.