TD Learning¶

SimpleTD | TD-learning for state value functions \(v(s)\).
Sarsa | TD-learning with SARSA updates.
ExpectedSarsa | TD-learning with expected-SARSA updates.
QLearning | TD-learning with Q-Learning updates.
DoubleQLearning | TD-learning with Double-DQN style double q-learning updates, in which the target network is only used in selecting the would-be next action.
SoftQLearning | TD-learning with soft Q-learning updates.
ClippedDoubleQLearning | TD-learning with TD3 style double q-learning updates, in which the target network is only used in selecting the would-be next action.
SoftClippedDoubleQLearning | Same as ClippedDoubleQLearning, except that the next action is sampled from the target policy instead of taken as its mode.
This is a collection of objects that are used to update value functions via Temporal Difference (TD) learning. A state value function coax.V is updated using coax.td_learning.SimpleTD. To update a state-action value function coax.Q, there are multiple options available; they differ in the manner in which the TD-target is constructed.
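The typical workflow combines one of these updaters with a reward tracer, which turns raw transitions into \(n\)-step TransitionBatch objects. The following is a minimal sketch of that loop, assuming the classic Gym API and a Haiku forward function; the network architecture and hyperparameters are illustrative, not prescriptive:

```python
import coax
import gym
import haiku as hk
import jax

env = gym.make('CartPole-v0')

def func(S, is_training):
    # type-2 q-function: maps a batch of states to q-values for all actions
    seq = hk.Sequential((
        hk.Linear(8), jax.nn.relu,
        hk.Linear(env.action_space.n),
    ))
    return seq(S)

q = coax.Q(func, env)
pi = coax.EpsilonGreedy(q, epsilon=0.1)

q_targ = q.copy()  # separate target network
qlearning = coax.td_learning.QLearning(q, q_targ=q_targ)
tracer = coax.reward_tracing.NStep(n=1, gamma=0.9)  # constructs n-step transitions

for ep in range(100):
    s = env.reset()
    done = False
    while not done:
        a = pi(s)
        s_next, r, done, info = env.step(a)

        tracer.add(s, a, r, done)
        while tracer:
            transition_batch = tracer.pop()
            qlearning.update(transition_batch)  # one TD-learning step
            q_targ.soft_update(q, tau=0.1)      # slowly sync the target network

        s = s_next
```

Any of the other updaters documented below can be swapped in for QLearning, with the constructor arguments listed in their respective reference entries.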
Object Reference¶
- class coax.td_learning.SimpleTD(v, v_targ=None, optimizer=None, loss_function=None, policy_regularizer=None)[source]¶
TD-learning for state value functions \(v(s)\). The \(n\)-step bootstrapped target is constructed as:
\[G^{(n)}_t\ =\ R^{(n)}_t + I^{(n)}_t\,v_\text{targ}(S_{t+n})\]
where
\[\begin{split}R^{(n)}_t\ &=\ \sum_{k=0}^{n-1}\gamma^kR_{t+k} \\ I^{(n)}_t\ &=\ \left\{\begin{matrix} 0 & \text{if $S_{t+n}$ is a terminal state} \\ \gamma^n & \text{otherwise} \end{matrix}\right.\end{split}\]
- Parameters:
  - v (V) – The main state value function to update.
  - v_targ (V, optional) – The state value function that is used for constructing the TD-target. If this is left unspecified, we set v_targ = v internally.
  - optimizer (optax optimizer, optional) – An optax-style optimizer. The default optimizer is optax.adam(1e-3).
  - loss_function (callable, optional) – The loss function that will be used to regress to the (bootstrapped) target. The loss function is expected to be of the form:
    \[L(y_\text{true}, y_\text{pred}, w)\in\mathbb{R}\]
    where \(w>0\) are sample weights. If left unspecified, this defaults to coax.value_losses.huber(). Check out the coax.value_losses module for other predefined loss functions.
  - policy_regularizer (Regularizer, optional) – If provided, this policy regularizer is added to the TD-target. A typical example is to use a coax.regularizers.EntropyRegularizer, which adds the policy entropy to the target. In this case, we minimize the following loss shifted by the entropy term:
    \[L(y_\text{true} + \beta\,H[\pi], y_\text{pred})\]
    Note that the coefficient \(\beta\) plays the role of the temperature in SAC-style agents.
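To make the bootstrapped target concrete, here is a small numeric sketch of \(G^{(n)}_t\) (illustrative only; plain NumPy, with a made-up number standing in for \(v_\text{targ}(S_{t+n})\)):

```python
import numpy as np

gamma, n = 0.9, 3
rewards = np.array([1.0, 0.5, 2.0])  # R_t, ..., R_{t+n-1}
done = False                         # whether S_{t+n} is a terminal state
v_targ_value = 4.0                   # stand-in for v_targ(S_{t+n})

Rn = sum(gamma**k * r for k, r in enumerate(rewards))  # n-step discounted reward
In = 0.0 if done else gamma**n                         # bootstrapping indicator
G = Rn + In * v_targ_value                             # the n-step target
print(Rn, In, G)  # -> 3.07, 0.729, 5.986 (up to float rounding)
```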
- apply_grads(grads, function_state)¶
Update the model parameters (weights) of the underlying function approximator given pre-computed gradients.
This method is useful in situations in which computation of the gradients is delegated to a separate (remote) process.
- Parameters:
  - grads (pytree with ndarray leaves) – A batch of gradients, generated by the grads method.
  - function_state (pytree) – The internal state of the forward-pass function. See Q.function_state and haiku.transform_with_state() for more details.
- grads_and_metrics(transition_batch)¶
Compute the gradients associated with a batch of transitions.
- Parameters:
  - transition_batch (TransitionBatch) – A batch of transitions.
- Returns:
  - grads (pytree with ndarray leaves) – A batch of gradients.
  - function_state (pytree) – The internal state of the forward-pass function. See Q.function_state and haiku.transform_with_state() for more details.
  - metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.
  - td_error (ndarray) – The non-aggregated TD-errors, shape == (batch_size,).
- td_error(transition_batch)¶
Compute the TD-errors associated with a batch of transitions. We define the TD-error as the negative gradient of the loss_function with respect to the predicted value:
\[\text{td_error}_i\ =\ -\frac{\partial L(y, \hat{y})}{\partial \hat{y}_i}\]
Note that this reduces to the ordinary definition \(\text{td_error}=y-\hat{y}\) when we use the coax.value_losses.mse() loss function.
- Parameters:
  - transition_batch (TransitionBatch) – A batch of transitions.
- Returns:
  - td_errors (ndarray, shape: [batch_size]) – A batch of TD-errors.
- update(transition_batch, return_td_error=False)¶
Update the model parameters (weights) of the underlying function approximator.
- Parameters:
  - transition_batch (TransitionBatch) – A batch of transitions.
  - return_td_error (bool, optional) – Whether to return the TD-errors.
- Returns:
  - metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.
  - td_error (ndarray, optional) – The non-aggregated TD-errors, shape == (batch_size,). This is only returned if we set return_td_error=True.
- class coax.td_learning.Sarsa(q, q_targ=None, optimizer=None, loss_function=None, policy_regularizer=None)[source]¶
TD-learning with SARSA updates. The \(n\)-step bootstrapped target is constructed as:
\[G^{(n)}_t\ =\ R^{(n)}_t + I^{(n)}_t\,q_\text{targ}(S_{t+n}, A_{t+n})\]
where \(A_{t+n}\) is sampled from experience and
\[\begin{split}R^{(n)}_t\ &=\ \sum_{k=0}^{n-1}\gamma^kR_{t+k} \\ I^{(n)}_t\ &=\ \left\{\begin{matrix} 0 & \text{if $S_{t+n}$ is a terminal state} \\ \gamma^n & \text{otherwise} \end{matrix}\right.\end{split}\]
- Parameters:
  - q (Q) – The main q-function to update.
  - q_targ (Q, optional) – The q-function that is used for constructing the TD-target. If this is left unspecified, we set q_targ = q internally.
  - optimizer (optax optimizer, optional) – An optax-style optimizer. The default optimizer is optax.adam(1e-3).
  - loss_function (callable, optional) – The loss function that will be used to regress to the (bootstrapped) target. The loss function is expected to be of the form:
    \[L(y_\text{true}, y_\text{pred}, w)\in\mathbb{R}\]
    where \(w>0\) are sample weights. If left unspecified, this defaults to coax.value_losses.huber(). Check out the coax.value_losses module for other predefined loss functions.
  - policy_regularizer (Regularizer, optional) – If provided, this policy regularizer is added to the TD-target. A typical example is to use a coax.regularizers.EntropyRegularizer, which adds the policy entropy to the target. In this case, we minimize the following loss shifted by the entropy term:
    \[L(y_\text{true} + \beta\,H[\pi], y_\text{pred})\]
    Note that the coefficient \(\beta\) plays the role of the temperature in SAC-style agents.
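For intuition, here is a toy numeric sketch of the one-step (\(n=1\)) SARSA target with a tabular q-function (illustrative only; the arrays are made up):

```python
import numpy as np

gamma = 0.9
q_targ = np.array([[1.0, 2.0],   # q_targ(s, a) for 3 states, 2 actions
                   [0.5, 1.5],
                   [3.0, 0.0]])

r, s_next, a_next, done = 1.0, 2, 0, False  # transition sampled from experience

# one-step SARSA target: bootstrap on the action actually taken next
G = r + (0.0 if done else gamma * q_targ[s_next, a_next])
print(G)  # 1.0 + 0.9 * 3.0 = 3.7
```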
- apply_grads(grads, function_state)¶
Update the model parameters (weights) of the underlying function approximator given pre-computed gradients.
This method is useful in situations in which computation of the gradients is delegated to a separate (remote) process.
- Parameters:
  - grads (pytree with ndarray leaves) – A batch of gradients, generated by the grads method.
  - function_state (pytree) – The internal state of the forward-pass function. See Q.function_state and haiku.transform_with_state() for more details.
- grads_and_metrics(transition_batch)¶
Compute the gradients associated with a batch of transitions.
- Parameters:
  - transition_batch (TransitionBatch) – A batch of transitions.
- Returns:
  - grads (pytree with ndarray leaves) – A batch of gradients.
  - function_state (pytree) – The internal state of the forward-pass function. See Q.function_state and haiku.transform_with_state() for more details.
  - metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.
  - td_error (ndarray) – The non-aggregated TD-errors, shape == (batch_size,).
- td_error(transition_batch)¶
Compute the TD-errors associated with a batch of transitions. We define the TD-error as the negative gradient of the loss_function with respect to the predicted value:
\[\text{td_error}_i\ =\ -\frac{\partial L(y, \hat{y})}{\partial \hat{y}_i}\]
Note that this reduces to the ordinary definition \(\text{td_error}=y-\hat{y}\) when we use the coax.value_losses.mse() loss function.
- Parameters:
  - transition_batch (TransitionBatch) – A batch of transitions.
- Returns:
  - td_errors (ndarray, shape: [batch_size]) – A batch of TD-errors.
- update(transition_batch, return_td_error=False)¶
Update the model parameters (weights) of the underlying function approximator.
- Parameters:
  - transition_batch (TransitionBatch) – A batch of transitions.
  - return_td_error (bool, optional) – Whether to return the TD-errors.
- Returns:
  - metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.
  - td_error (ndarray, optional) – The non-aggregated TD-errors, shape == (batch_size,). This is only returned if we set return_td_error=True.
- class coax.td_learning.ExpectedSarsa(q, pi_targ, q_targ=None, optimizer=None, loss_function=None, policy_regularizer=None)[source]¶
TD-learning with expected-SARSA updates. The \(n\)-step bootstrapped target is constructed as:
\[G^{(n)}_t\ =\ R^{(n)}_t + I^{(n)}_t\,\mathop{\mathbb{E}}_{a\sim\pi_\text{targ}(.|S_{t+n})}\, q_\text{targ}\left(S_{t+n}, a\right)\]
Note that the ordinary SARSA target is a sampled estimate of the above target. Also, as usual, the \(n\)-step reward and indicator are defined as:
\[\begin{split}R^{(n)}_t\ &=\ \sum_{k=0}^{n-1}\gamma^kR_{t+k} \\ I^{(n)}_t\ &=\ \left\{\begin{matrix} 0 & \text{if $S_{t+n}$ is a terminal state} \\ \gamma^n & \text{otherwise} \end{matrix}\right.\end{split}\]
- Parameters:
  - q (Q) – The main q-function to update.
  - pi_targ (Policy) – The policy that is used for constructing the TD-target.
  - q_targ (Q, optional) – The q-function that is used for constructing the TD-target. If this is left unspecified, we set q_targ = q internally.
  - optimizer (optax optimizer, optional) – An optax-style optimizer. The default optimizer is optax.adam(1e-3).
  - loss_function (callable, optional) – The loss function that will be used to regress to the (bootstrapped) target. The loss function is expected to be of the form:
    \[L(y_\text{true}, y_\text{pred}, w)\in\mathbb{R}\]
    where \(w>0\) are sample weights. If left unspecified, this defaults to coax.value_losses.huber(). Check out the coax.value_losses module for other predefined loss functions.
  - policy_regularizer (Regularizer, optional) – If provided, this policy regularizer is added to the TD-target. A typical example is to use a coax.regularizers.EntropyRegularizer, which adds the policy entropy to the target. In this case, we minimize the following loss shifted by the entropy term:
    \[L(y_\text{true} + \beta\,H[\pi], y_\text{pred})\]
    Note that the coefficient \(\beta\) plays the role of the temperature in SAC-style agents.
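A toy numeric sketch of the expectation in this target, for one-step updates with a discrete action space (illustrative only; the numbers are made up):

```python
import numpy as np

gamma = 0.9
q_targ = np.array([1.0, 2.0, 4.0])    # q_targ(S_{t+n}, a) for 3 actions
pi_probs = np.array([0.2, 0.5, 0.3])  # pi_targ(a|S_{t+n})

r, done = 1.0, False
# expected-SARSA target: average over the target policy instead of sampling
G = r + (0.0 if done else gamma * pi_probs @ q_targ)
print(G)  # 1.0 + 0.9 * (0.2 + 1.0 + 1.2) = 3.16
```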
- apply_grads(grads, function_state)¶
Update the model parameters (weights) of the underlying function approximator given pre-computed gradients.
This method is useful in situations in which computation of the gradients is delegated to a separate (remote) process.
- Parameters:
  - grads (pytree with ndarray leaves) – A batch of gradients, generated by the grads method.
  - function_state (pytree) – The internal state of the forward-pass function. See Q.function_state and haiku.transform_with_state() for more details.
- grads_and_metrics(transition_batch)¶
Compute the gradients associated with a batch of transitions.
- Parameters:
  - transition_batch (TransitionBatch) – A batch of transitions.
- Returns:
  - grads (pytree with ndarray leaves) – A batch of gradients.
  - function_state (pytree) – The internal state of the forward-pass function. See Q.function_state and haiku.transform_with_state() for more details.
  - metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.
  - td_error (ndarray) – The non-aggregated TD-errors, shape == (batch_size,).
- td_error(transition_batch)¶
Compute the TD-errors associated with a batch of transitions. We define the TD-error as the negative gradient of the loss_function with respect to the predicted value:
\[\text{td_error}_i\ =\ -\frac{\partial L(y, \hat{y})}{\partial \hat{y}_i}\]
Note that this reduces to the ordinary definition \(\text{td_error}=y-\hat{y}\) when we use the coax.value_losses.mse() loss function.
- Parameters:
  - transition_batch (TransitionBatch) – A batch of transitions.
- Returns:
  - td_errors (ndarray, shape: [batch_size]) – A batch of TD-errors.
- update(transition_batch, return_td_error=False)¶
Update the model parameters (weights) of the underlying function approximator.
- Parameters:
  - transition_batch (TransitionBatch) – A batch of transitions.
  - return_td_error (bool, optional) – Whether to return the TD-errors.
- Returns:
  - metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.
  - td_error (ndarray, optional) – The non-aggregated TD-errors, shape == (batch_size,). This is only returned if we set return_td_error=True.
- class coax.td_learning.QLearning(q, pi_targ=None, q_targ=None, optimizer=None, loss_function=None, policy_regularizer=None)[source]¶
TD-learning with Q-Learning updates.
The \(n\)-step bootstrapped target for discrete actions is constructed as:
\[G^{(n)}_t\ =\ R^{(n)}_t + I^{(n)}_t\,\max_aq_\text{targ}\left(S_{t+n}, a\right)\]
For non-discrete action spaces, this uses a DDPG-style target:
\[G^{(n)}_t\ =\ R^{(n)}_t + I^{(n)}_t\,q_\text{targ}\left( S_{t+n}, a_\text{targ}(S_{t+n})\right)\]
where \(a_\text{targ}(s)\) is the mode of the underlying conditional probability distribution \(\pi_\text{targ}(.|s)\). Even though these two formulations of the q-learning target are equivalent, the implementation of the latter does require additional input, namely pi_targ.
The \(n\)-step reward and indicator (referenced above) are defined as:
\[\begin{split}R^{(n)}_t\ &=\ \sum_{k=0}^{n-1}\gamma^kR_{t+k} \\ I^{(n)}_t\ &=\ \left\{\begin{matrix} 0 & \text{if $S_{t+n}$ is a terminal state} \\ \gamma^n & \text{otherwise} \end{matrix}\right.\end{split}\]
- Parameters:
  - q (Q) – The main q-function to update.
  - pi_targ (Policy, optional) – The policy that is used for constructing the TD-target. This is ignored if the action space is discrete and required otherwise.
  - q_targ (Q, optional) – The q-function that is used for constructing the TD-target. If this is left unspecified, we set q_targ = q internally.
  - optimizer (optax optimizer, optional) – An optax-style optimizer. The default optimizer is optax.adam(1e-3).
  - loss_function (callable, optional) – The loss function that will be used to regress to the (bootstrapped) target. The loss function is expected to be of the form:
    \[L(y_\text{true}, y_\text{pred}, w)\in\mathbb{R}\]
    where \(w>0\) are sample weights. If left unspecified, this defaults to coax.value_losses.huber(). Check out the coax.value_losses module for other predefined loss functions.
  - policy_regularizer (Regularizer, optional) – If provided, this policy regularizer is added to the TD-target. A typical example is to use a coax.regularizers.EntropyRegularizer, which adds the policy entropy to the target. In this case, we minimize the following loss shifted by the entropy term:
    \[L(y_\text{true} + \beta\,H[\pi], y_\text{pred})\]
    Note that the coefficient \(\beta\) plays the role of the temperature in SAC-style agents.
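A toy numeric sketch of the discrete-action q-learning target (illustrative only; the numbers are made up):

```python
import numpy as np

gamma = 0.9
q_targ = np.array([1.0, 2.0, 4.0])  # q_targ(S_{t+n}, a) for 3 discrete actions

r, done = 1.0, False
# q-learning target: bootstrap on the greedy action
G = r + (0.0 if done else gamma * q_targ.max())
print(G)  # 1.0 + 0.9 * 4.0 = 4.6
```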
- apply_grads(grads, function_state)¶
Update the model parameters (weights) of the underlying function approximator given pre-computed gradients.
This method is useful in situations in which computation of the gradients is delegated to a separate (remote) process.
- Parameters:
  - grads (pytree with ndarray leaves) – A batch of gradients, generated by the grads method.
  - function_state (pytree) – The internal state of the forward-pass function. See Q.function_state and haiku.transform_with_state() for more details.
- grads_and_metrics(transition_batch)¶
Compute the gradients associated with a batch of transitions.
- Parameters:
  - transition_batch (TransitionBatch) – A batch of transitions.
- Returns:
  - grads (pytree with ndarray leaves) – A batch of gradients.
  - function_state (pytree) – The internal state of the forward-pass function. See Q.function_state and haiku.transform_with_state() for more details.
  - metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.
  - td_error (ndarray) – The non-aggregated TD-errors, shape == (batch_size,).
- td_error(transition_batch)¶
Compute the TD-errors associated with a batch of transitions. We define the TD-error as the negative gradient of the loss_function with respect to the predicted value:
\[\text{td_error}_i\ =\ -\frac{\partial L(y, \hat{y})}{\partial \hat{y}_i}\]
Note that this reduces to the ordinary definition \(\text{td_error}=y-\hat{y}\) when we use the coax.value_losses.mse() loss function.
- Parameters:
  - transition_batch (TransitionBatch) – A batch of transitions.
- Returns:
  - td_errors (ndarray, shape: [batch_size]) – A batch of TD-errors.
- update(transition_batch, return_td_error=False)¶
Update the model parameters (weights) of the underlying function approximator.
- Parameters:
  - transition_batch (TransitionBatch) – A batch of transitions.
  - return_td_error (bool, optional) – Whether to return the TD-errors.
- Returns:
  - metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.
  - td_error (ndarray, optional) – The non-aggregated TD-errors, shape == (batch_size,). This is only returned if we set return_td_error=True.
- class coax.td_learning.DoubleQLearning(q, pi_targ=None, q_targ=None, optimizer=None, loss_function=None, policy_regularizer=None)[source]¶
TD-learning with Double-DQN style double q-learning updates, in which the target network is only used in selecting the would-be next action. The \(n\)-step bootstrapped target is thus constructed as:
\[\begin{split}a_\text{greedy}\ &=\ \arg\max_a q_\text{targ}(S_{t+n}, a) \\ G^{(n)}_t\ &=\ R^{(n)}_t + I^{(n)}_t\,q(S_{t+n}, a_\text{greedy})\end{split}\]
where
\[\begin{split}R^{(n)}_t\ &=\ \sum_{k=0}^{n-1}\gamma^kR_{t+k} \\ I^{(n)}_t\ &=\ \left\{\begin{matrix} 0 & \text{if $S_{t+n}$ is a terminal state} \\ \gamma^n & \text{otherwise} \end{matrix}\right.\end{split}\]
- Parameters:
  - q (Q) – The main q-function to update.
  - pi_targ (Policy, optional) – The policy that is used for constructing the TD-target. This is ignored if the action space is discrete and required otherwise.
  - q_targ (Q, optional) – The q-function that is used for constructing the TD-target. If this is left unspecified, we set q_targ = q internally.
  - optimizer (optax optimizer, optional) – An optax-style optimizer. The default optimizer is optax.adam(1e-3).
  - loss_function (callable, optional) – The loss function that will be used to regress to the (bootstrapped) target. The loss function is expected to be of the form:
    \[L(y_\text{true}, y_\text{pred}, w)\in\mathbb{R}\]
    where \(w>0\) are sample weights. If left unspecified, this defaults to coax.value_losses.huber(). Check out the coax.value_losses module for other predefined loss functions.
  - policy_regularizer (Regularizer, optional) – If provided, this policy regularizer is added to the TD-target. A typical example is to use a coax.regularizers.EntropyRegularizer, which adds the policy entropy to the target. In this case, we minimize the following loss shifted by the entropy term:
    \[L(y_\text{true} + \beta\,H[\pi], y_\text{pred})\]
    Note that the coefficient \(\beta\) plays the role of the temperature in SAC-style agents.
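A toy numeric illustration of this decoupling, following the formula above (the target network selects the action, the main network evaluates it; the values are made up):

```python
import numpy as np

gamma = 0.9
q      = np.array([1.0, 2.0, 4.0])  # main q-function at S_{t+n}
q_targ = np.array([1.5, 3.0, 2.5])  # target q-function at S_{t+n}

r, done = 1.0, False
a_greedy = q_targ.argmax()                       # target net selects the action...
G = r + (0.0 if done else gamma * q[a_greedy])   # ...main net evaluates it
print(a_greedy, G)  # 1, 1.0 + 0.9 * 2.0 = 2.8
```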
- apply_grads(grads, function_state)¶
Update the model parameters (weights) of the underlying function approximator given pre-computed gradients.
This method is useful in situations in which computation of the gradients is delegated to a separate (remote) process.
- Parameters:
  - grads (pytree with ndarray leaves) – A batch of gradients, generated by the grads method.
  - function_state (pytree) – The internal state of the forward-pass function. See Q.function_state and haiku.transform_with_state() for more details.
- grads_and_metrics(transition_batch)¶
Compute the gradients associated with a batch of transitions.
- Parameters:
  - transition_batch (TransitionBatch) – A batch of transitions.
- Returns:
  - grads (pytree with ndarray leaves) – A batch of gradients.
  - function_state (pytree) – The internal state of the forward-pass function. See Q.function_state and haiku.transform_with_state() for more details.
  - metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.
  - td_error (ndarray) – The non-aggregated TD-errors, shape == (batch_size,).
- td_error(transition_batch)¶
Compute the TD-errors associated with a batch of transitions. We define the TD-error as the negative gradient of the loss_function with respect to the predicted value:
\[\text{td_error}_i\ =\ -\frac{\partial L(y, \hat{y})}{\partial \hat{y}_i}\]
Note that this reduces to the ordinary definition \(\text{td_error}=y-\hat{y}\) when we use the coax.value_losses.mse() loss function.
- Parameters:
  - transition_batch (TransitionBatch) – A batch of transitions.
- Returns:
  - td_errors (ndarray, shape: [batch_size]) – A batch of TD-errors.
- update(transition_batch, return_td_error=False)¶
Update the model parameters (weights) of the underlying function approximator.
- Parameters:
  - transition_batch (TransitionBatch) – A batch of transitions.
  - return_td_error (bool, optional) – Whether to return the TD-errors.
- Returns:
  - metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.
  - td_error (ndarray, optional) – The non-aggregated TD-errors, shape == (batch_size,). This is only returned if we set return_td_error=True.
- class coax.td_learning.SoftQLearning(q, q_targ=None, optimizer=None, loss_function=None, policy_regularizer=None, temperature=1.0)[source]¶
TD-learning with soft Q-learning updates. The \(n\)-step bootstrapped target is constructed as:
\[G^{(n)}_t\ =\ R^{(n)}_t + I^{(n)}_t\,\tau\log\sum_{a'}\exp\left(q_\text{targ}(S_{t+n}, a') / \tau\right)\]
where
\[\begin{split}R^{(n)}_t\ &=\ \sum_{k=0}^{n-1}\gamma^kR_{t+k} \\ I^{(n)}_t\ &=\ \left\{\begin{matrix} 0 & \text{if $S_{t+n}$ is a terminal state} \\ \gamma^n & \text{otherwise} \end{matrix}\right.\end{split}\]
- Parameters:
  - q (Q) – The main q-function to update.
  - q_targ (Q, optional) – The q-function that is used for constructing the TD-target. If this is left unspecified, we set q_targ = q internally.
  - optimizer (optax optimizer, optional) – An optax-style optimizer. The default optimizer is optax.adam(1e-3).
  - loss_function (callable, optional) – The loss function that will be used to regress to the (bootstrapped) target. The loss function is expected to be of the form:
    \[L(y_\text{true}, y_\text{pred}, w)\in\mathbb{R}\]
    where \(w>0\) are sample weights. If left unspecified, this defaults to coax.value_losses.huber(). Check out the coax.value_losses module for other predefined loss functions.
  - policy_regularizer (Regularizer, optional) – If provided, this policy regularizer is added to the TD-target. A typical example is to use a coax.regularizers.EntropyRegularizer, which adds the policy entropy to the target. In this case, we minimize the following loss shifted by the entropy term:
    \[L(y_\text{true} + \beta\,H[\pi], y_\text{pred})\]
    Note that the coefficient \(\beta\) plays the role of the temperature in SAC-style agents.
  - temperature (float, optional) – The Boltzmann temperature \(\tau>0\).
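A toy numeric sketch of the soft (log-sum-exp) value in this target (illustrative only; made-up numbers). For \(\tau>0\) the soft value is a smooth upper bound on \(\max_{a'} q_\text{targ}\), and it approaches the hard max as \(\tau\to 0\):

```python
import numpy as np

gamma, tau = 0.9, 0.5
q_targ = np.array([1.0, 2.0, 4.0])  # q_targ(S_{t+n}, a') for 3 actions

r, done = 1.0, False
# soft value: tau * logsumexp(q / tau), a smoothed version of max(q)
v_soft = tau * np.log(np.sum(np.exp(q_targ / tau)))
G = r + (0.0 if done else gamma * v_soft)
print(v_soft, G)  # v_soft is ~4.01, slightly above max(q_targ) = 4.0
```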
- apply_grads(grads, function_state)¶
Update the model parameters (weights) of the underlying function approximator given pre-computed gradients.
This method is useful in situations in which computation of the gradients is delegated to a separate (remote) process.
- Parameters:
  - grads (pytree with ndarray leaves) – A batch of gradients, generated by the grads method.
  - function_state (pytree) – The internal state of the forward-pass function. See Q.function_state and haiku.transform_with_state() for more details.
- grads_and_metrics(transition_batch)¶
Compute the gradients associated with a batch of transitions.
- Parameters:
  - transition_batch (TransitionBatch) – A batch of transitions.
- Returns:
  - grads (pytree with ndarray leaves) – A batch of gradients.
  - function_state (pytree) – The internal state of the forward-pass function. See Q.function_state and haiku.transform_with_state() for more details.
  - metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.
  - td_error (ndarray) – The non-aggregated TD-errors, shape == (batch_size,).
- td_error(transition_batch)¶
Compute the TD-errors associated with a batch of transitions. We define the TD-error as the negative gradient of the loss_function with respect to the predicted value:
\[\text{td_error}_i\ =\ -\frac{\partial L(y, \hat{y})}{\partial \hat{y}_i}\]
Note that this reduces to the ordinary definition \(\text{td_error}=y-\hat{y}\) when we use the coax.value_losses.mse() loss function.
- Parameters:
  - transition_batch (TransitionBatch) – A batch of transitions.
- Returns:
  - td_errors (ndarray, shape: [batch_size]) – A batch of TD-errors.
- update(transition_batch, return_td_error=False)¶
Update the model parameters (weights) of the underlying function approximator.
- Parameters:
  - transition_batch (TransitionBatch) – A batch of transitions.
  - return_td_error (bool, optional) – Whether to return the TD-errors.
- Returns:
  - metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.
  - td_error (ndarray, optional) – The non-aggregated TD-errors, shape == (batch_size,). This is only returned if we set return_td_error=True.
- class coax.td_learning.ClippedDoubleQLearning(q, pi_targ_list=None, q_targ_list=None, optimizer=None, loss_function=None, policy_regularizer=None)[source]¶
TD-learning with TD3 style double q-learning updates, in which the target network is only used in selecting the would-be next action.
For discrete actions, the \(n\)-step bootstrapped target is constructed as:
\[G^{(n)}_t\ =\ R^{(n)}_t + I^{(n)}_t\,\min_{i,j}q_i(S_{t+n}, \arg\max_a q_j(S_{t+n}, a))\]
where \(q_i(s,a)\) is the \(i\)-th target q-function provided in q_targ_list.
Similarly, for non-discrete actions, the target is constructed as:
\[G^{(n)}_t\ =\ R^{(n)}_t + I^{(n)}_t\,\min_{i,j}q_i(S_{t+n}, a_j(S_{t+n}))\]
where \(a_i(s)\) is the mode of the \(i\)-th target policy provided in pi_targ_list. As usual, the \(n\)-step reward and indicator are defined as:
\[\begin{split}R^{(n)}_t\ &=\ \sum_{k=0}^{n-1}\gamma^kR_{t+k} \\ I^{(n)}_t\ &=\ \left\{\begin{matrix} 0 & \text{if $S_{t+n}$ is a terminal state} \\ \gamma^n & \text{otherwise} \end{matrix}\right.\end{split}\]
- Parameters:
  - q (Q) – The main q-function to update.
  - pi_targ_list (list of Policy, optional) – The list of policies that are used for constructing the TD-target. This is ignored if the action space is discrete and required otherwise.
  - q_targ_list (list of Q) – The list of q-functions that are used for constructing the TD-target.
  - optimizer (optax optimizer, optional) – An optax-style optimizer. The default optimizer is optax.adam(1e-3).
  - loss_function (callable, optional) – The loss function that will be used to regress to the (bootstrapped) target. The loss function is expected to be of the form:
    \[L(y_\text{true}, y_\text{pred}, w)\in\mathbb{R}\]
    where \(w>0\) are sample weights. If left unspecified, this defaults to coax.value_losses.huber(). Check out the coax.value_losses module for other predefined loss functions.
  - policy_regularizer (Regularizer, optional) – If provided, this policy regularizer is added to the TD-target. A typical example is to use a coax.regularizers.EntropyRegularizer, which adds the policy entropy to the target. In this case, we minimize the following loss shifted by the entropy term:
    \[L(y_\text{true} + \beta\,H[\pi], y_\text{pred})\]
    Note that the coefficient \(\beta\) plays the role of the temperature in SAC-style agents.
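A toy numeric sketch of the clipped target for discrete actions, enumerating all (evaluator \(i\), selector \(j\)) pairs (illustrative only; two made-up target q-functions):

```python
import numpy as np

gamma = 0.9
# two target q-functions evaluated at S_{t+n}, 3 discrete actions each
q1 = np.array([1.0, 2.0, 4.0])
q2 = np.array([1.5, 3.0, 2.5])

r, done = 1.0, False
# min over all (i, j) pairs of q_i(s, argmax_a q_j(s, a))
candidates = [qi[qj.argmax()] for qi in (q1, q2) for qj in (q1, q2)]
G = r + (0.0 if done else gamma * min(candidates))
print(candidates, G)  # [4.0, 2.0, 2.5, 3.0] -> G = 1.0 + 0.9 * 2.0 = 2.8
```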
- apply_grads(grads, function_state)¶
Update the model parameters (weights) of the underlying function approximator given pre-computed gradients.
This method is useful in situations in which computation of the gradients is delegated to a separate (remote) process.
- Parameters:
  - grads (pytree with ndarray leaves) – A batch of gradients, generated by the grads method.
  - function_state (pytree) – The internal state of the forward-pass function. See Q.function_state and haiku.transform_with_state() for more details.
- grads_and_metrics(transition_batch)¶
Compute the gradients associated with a batch of transitions.
- Parameters:
  - transition_batch (TransitionBatch) – A batch of transitions.
- Returns:
  - grads (pytree with ndarray leaves) – A batch of gradients.
  - function_state (pytree) – The internal state of the forward-pass function. See Q.function_state and haiku.transform_with_state() for more details.
  - metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.
  - td_error (ndarray) – The non-aggregated TD-errors, shape == (batch_size,).
- td_error(transition_batch)¶
Compute the TD-errors associated with a batch of transitions. We define the TD-error as the negative gradient of the loss_function with respect to the predicted value:
\[\text{td_error}_i\ =\ -\frac{\partial L(y, \hat{y})}{\partial \hat{y}_i}\]
Note that this reduces to the ordinary definition \(\text{td_error}=y-\hat{y}\) when we use the coax.value_losses.mse() loss function.
- Parameters:
  - transition_batch (TransitionBatch) – A batch of transitions.
- Returns:
  - td_errors (ndarray, shape: [batch_size]) – A batch of TD-errors.
- update(transition_batch, return_td_error=False)¶
Update the model parameters (weights) of the underlying function approximator.
- Parameters:
  - transition_batch (TransitionBatch) – A batch of transitions.
  - return_td_error (bool, optional) – Whether to return the TD-errors.
- Returns:
  - metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.
  - td_error (ndarray, optional) – The non-aggregated TD-errors, shape == (batch_size,). This is only returned if we set return_td_error=True.
- class coax.td_learning.SoftClippedDoubleQLearning(q, pi_targ_list=None, q_targ_list=None, optimizer=None, loss_function=None, policy_regularizer=None)[source]¶
TD-learning with soft clipped double q-learning updates. The TD-target is constructed as in ClippedDoubleQLearning, except that the action for the next state is sampled from the target policy instead of taking its mode (see target_func below). The parameters are the same as for ClippedDoubleQLearning.
- apply_grads(grads, function_state)¶
Update the model parameters (weights) of the underlying function approximator given pre-computed gradients.
This method is useful in situations in which computation of the gradients is delegated to a separate (remote) process.
- Parameters:
  - grads (pytree with ndarray leaves) – A batch of gradients, generated by the grads method.
  - function_state (pytree) – The internal state of the forward-pass function. See Q.function_state and haiku.transform_with_state() for more details.
- grads_and_metrics(transition_batch)¶
Compute the gradients associated with a batch of transitions.
- Parameters:
  - transition_batch (TransitionBatch) – A batch of transitions.
- Returns:
  - grads (pytree with ndarray leaves) – A batch of gradients.
  - function_state (pytree) – The internal state of the forward-pass function. See Q.function_state and haiku.transform_with_state() for more details.
  - metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.
  - td_error (ndarray) – The non-aggregated TD-errors, shape == (batch_size,).
- target_func(target_params, target_state, rng, transition_batch)[source]¶
This does almost the same as ClippedDoubleQLearning.target_func, except that the action for the next state is sampled instead of taking the mode.
- td_error(transition_batch)¶
Compute the TD-errors associated with a batch of transitions. We define the TD-error as the negative gradient of the loss_function with respect to the predicted value:
\[\text{td_error}_i\ =\ -\frac{\partial L(y, \hat{y})}{\partial \hat{y}_i}\]
Note that this reduces to the ordinary definition \(\text{td_error}=y-\hat{y}\) when we use the coax.value_losses.mse() loss function.
- Parameters:
  - transition_batch (TransitionBatch) – A batch of transitions.
- Returns:
  - td_errors (ndarray, shape: [batch_size]) – A batch of TD-errors.
- update(transition_batch, return_td_error=False)¶
Update the model parameters (weights) of the underlying function approximator.
- Parameters:
  - transition_batch (TransitionBatch) – A batch of transitions.
  - return_td_error (bool, optional) – Whether to return the TD-errors.
- Returns:
  - metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.
  - td_error (ndarray, optional) – The non-aggregated TD-errors, shape == (batch_size,). This is only returned if we set return_td_error=True.
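As a rough illustration of the difference with ClippedDoubleQLearning, the next action is drawn from the target policy rather than taken as its mode. A toy sketch for a single target policy and two target q-functions (illustrative only; the distribution and q-values are made up):

```python
import numpy as np

rng = np.random.default_rng(0)
gamma = 0.9
q1 = np.array([1.0, 2.0, 4.0])        # two target q-functions at S_{t+n}
q2 = np.array([1.5, 3.0, 2.5])
pi_probs = np.array([0.1, 0.6, 0.3])  # target policy pi_targ(.|S_{t+n})

r, done = 1.0, False
a = rng.choice(3, p=pi_probs)         # sampled next action, not the mode
G = r + (0.0 if done else gamma * min(q1[a], q2[a]))  # clipped (min) bootstrap
print(a, G)
```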