TD Learning¶
TD learning for state value functions \(v(s)\).

TD learning with SARSA updates.

TD learning with expected-SARSA updates.

TD learning with Q-learning updates.

TD learning with Double-DQN style double q-learning updates, in which the target network is only used in selecting the would-be next action.

TD learning with soft Q-learning updates.

TD learning with TD3 style double q-learning updates, in which the target network is only used in selecting the would-be next action.

This is a collection of objects that are used to update value functions via Temporal Difference
(TD) learning. A state value function coax.V is updated using coax.td_learning.SimpleTD.
To update a state-action value function coax.Q, there are multiple options available. The
difference between the options is the manner in which the TD-target is constructed.
Object Reference¶
 class coax.td_learning.SimpleTD(v, v_targ=None, optimizer=None, loss_function=None, policy_regularizer=None)[source]¶
TD-learning for state value functions \(v(s)\). The \(n\)-step bootstrapped target is constructed as:
\[G^{(n)}_t\ =\ R^{(n)}_t + I^{(n)}_t\,v_\text{targ}(S_{t+n})\]where
\[\begin{split}R^{(n)}_t\ &=\ \sum_{k=0}^{n-1}\gamma^kR_{t+k} \\ I^{(n)}_t\ &=\ \left\{\begin{matrix} 0 & \text{if $S_{t+n}$ is a terminal state} \\ \gamma^n & \text{otherwise} \end{matrix}\right.\end{split}\] Parameters:
v (V) – The main state value function to update.
v_targ (V, optional) – The state value function that is used for constructing the TD-target. If this is left unspecified, we set v_targ = v internally.
optimizer (optax optimizer, optional) – An optax-style optimizer. The default optimizer is optax.adam(1e-3).
loss_function (callable, optional) – The loss function that will be used to regress to the (bootstrapped) target. The loss function is expected to be of the form:
\[L(y_\text{true}, y_\text{pred}, w)\in\mathbb{R}\]
where \(w>0\) are sample weights. If left unspecified, this defaults to coax.value_losses.huber(). Check out the coax.value_losses module for other predefined loss functions.
policy_regularizer (Regularizer, optional) – If provided, this policy regularizer is added to the TD-target. A typical example is to use a coax.regularizers.EntropyRegularizer, which adds the policy entropy to the target. In this case, we minimize the following loss shifted by the entropy term:
\[L(y_\text{true} + \beta\,H[\pi], y_\text{pred})\]
Note that the coefficient \(\beta\) plays the role of the temperature in SAC-style agents.
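To make the target construction concrete, here is a minimal plain-Python sketch of the \(n\)-step bootstrapped target for a state value function. The helper names are illustrative, not part of the coax API:

```python
def nstep_return(rewards, gamma):
    # R^(n)_t = sum_{k=0}^{n-1} gamma^k * R_{t+k}
    return sum(gamma ** k * r for k, r in enumerate(rewards))

def simple_td_target(rewards, gamma, v_next, done):
    # G^(n)_t = R^(n)_t + I^(n)_t * v_targ(S_{t+n}),
    # where I^(n)_t = 0 if S_{t+n} is terminal, gamma^n otherwise
    n = len(rewards)
    bootstrap = 0.0 if done else gamma ** n * v_next
    return nstep_return(rewards, gamma) + bootstrap

# two-step example: rewards [1.0, 0.5], gamma = 0.9, v_targ(S_{t+2}) = 3.0
target = simple_td_target([1.0, 0.5], 0.9, v_next=3.0, done=False)
# 1.0 + 0.9*0.5 + 0.81*3.0 = 3.88
```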
 apply_grads(grads, function_state)¶
Update the model parameters (weights) of the underlying function approximator given precomputed gradients.
This method is useful in situations in which computation of the gradients is delegated to a separate (remote) process.
 Parameters:
grads (pytree with ndarray leaves) – A batch of gradients, generated by the grads method.
function_state (pytree) – The internal state of the forward-pass function. See Q.function_state and haiku.transform_with_state() for more details.
 grads_and_metrics(transition_batch)¶
Compute the gradients associated with a batch of transitions.
 Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
 Returns:
grads (pytree with ndarray leaves) – A batch of gradients.
function_state (pytree) – The internal state of the forward-pass function. See Q.function_state and haiku.transform_with_state() for more details.
metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.
td_error (ndarray) – The non-aggregated TD-errors, shape == (batch_size,).
 td_error(transition_batch)¶
Compute the TD-errors associated with a batch of transitions. We define the TD-error as the negative gradient of the loss_function with respect to the predicted value:
\[\text{td_error}_i\ =\ -\frac{\partial L(y, \hat{y})}{\partial \hat{y}_i}\]
Note that this reduces to the ordinary definition \(\text{td_error}=y-\hat{y}\) when we use the coax.value_losses.mse() loss function.
 Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
 Returns:
td_errors (ndarray, shape: [batch_size]) – A batch of TD-errors.
 update(transition_batch, return_td_error=False)¶
Update the model parameters (weights) of the underlying function approximator.
 Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
return_td_error (bool, optional) – Whether to return the TD-errors.
 Returns:
metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.
td_error (ndarray, optional) – The non-aggregated TD-errors, shape == (batch_size,). This is only returned if we set return_td_error=True.
 class coax.td_learning.Sarsa(q, q_targ=None, optimizer=None, loss_function=None, policy_regularizer=None)[source]¶
TD-learning with SARSA updates. The \(n\)-step bootstrapped target is constructed as:
\[G^{(n)}_t\ =\ R^{(n)}_t + I^{(n)}_t\,q_\text{targ}(S_{t+n}, A_{t+n})\]where \(A_{t+n}\) is sampled from experience and
\[\begin{split}R^{(n)}_t\ &=\ \sum_{k=0}^{n-1}\gamma^kR_{t+k} \\ I^{(n)}_t\ &=\ \left\{\begin{matrix} 0 & \text{if $S_{t+n}$ is a terminal state} \\ \gamma^n & \text{otherwise} \end{matrix}\right.\end{split}\] Parameters:
q (Q) – The main q-function to update.
q_targ (Q, optional) – The q-function that is used for constructing the TD-target. If this is left unspecified, we set q_targ = q internally.
optimizer (optax optimizer, optional) – An optax-style optimizer. The default optimizer is optax.adam(1e-3).
loss_function (callable, optional) – The loss function that will be used to regress to the (bootstrapped) target. The loss function is expected to be of the form:
\[L(y_\text{true}, y_\text{pred}, w)\in\mathbb{R}\]
where \(w>0\) are sample weights. If left unspecified, this defaults to coax.value_losses.huber(). Check out the coax.value_losses module for other predefined loss functions.
policy_regularizer (Regularizer, optional) – If provided, this policy regularizer is added to the TD-target. A typical example is to use a coax.regularizers.EntropyRegularizer, which adds the policy entropy to the target. In this case, we minimize the following loss shifted by the entropy term:
\[L(y_\text{true} + \beta\,H[\pi], y_\text{pred})\]
Note that the coefficient \(\beta\) plays the role of the temperature in SAC-style agents.
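The SARSA target bootstraps on the q-value of the next action actually taken in the episode. A minimal plain-Python sketch (illustrative names, not the coax API):

```python
def sarsa_target(rewards, gamma, q_targ, s_next, a_next, done):
    # G^(n)_t = R^(n)_t + I^(n)_t * q_targ(S_{t+n}, A_{t+n}),
    # with A_{t+n} taken from the recorded experience
    n = len(rewards)
    r_n = sum(gamma ** k * r for k, r in enumerate(rewards))
    return r_n + (0.0 if done else gamma ** n * q_targ[s_next, a_next])

# tabular target q-function; the agent happened to take 'left' next
q_targ = {('s2', 'left'): 1.5, ('s2', 'right'): 2.5}
target = sarsa_target([1.0], 0.9, q_targ, 's2', 'left', done=False)
# 1.0 + 0.9 * 1.5 = 2.35
```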
 apply_grads(grads, function_state)¶
Update the model parameters (weights) of the underlying function approximator given precomputed gradients.
This method is useful in situations in which computation of the gradients is delegated to a separate (remote) process.
 Parameters:
grads (pytree with ndarray leaves) – A batch of gradients, generated by the grads method.
function_state (pytree) – The internal state of the forward-pass function. See Q.function_state and haiku.transform_with_state() for more details.
 grads_and_metrics(transition_batch)¶
Compute the gradients associated with a batch of transitions.
 Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
 Returns:
grads (pytree with ndarray leaves) – A batch of gradients.
function_state (pytree) – The internal state of the forward-pass function. See Q.function_state and haiku.transform_with_state() for more details.
metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.
td_error (ndarray) – The non-aggregated TD-errors, shape == (batch_size,).
 td_error(transition_batch)¶
Compute the TD-errors associated with a batch of transitions. We define the TD-error as the negative gradient of the loss_function with respect to the predicted value:
\[\text{td_error}_i\ =\ -\frac{\partial L(y, \hat{y})}{\partial \hat{y}_i}\]
Note that this reduces to the ordinary definition \(\text{td_error}=y-\hat{y}\) when we use the coax.value_losses.mse() loss function.
 Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
 Returns:
td_errors (ndarray, shape: [batch_size]) – A batch of TD-errors.
 update(transition_batch, return_td_error=False)¶
Update the model parameters (weights) of the underlying function approximator.
 Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
return_td_error (bool, optional) – Whether to return the TD-errors.
 Returns:
metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.
td_error (ndarray, optional) – The non-aggregated TD-errors, shape == (batch_size,). This is only returned if we set return_td_error=True.
 class coax.td_learning.ExpectedSarsa(q, pi_targ, q_targ=None, optimizer=None, loss_function=None, policy_regularizer=None)[source]¶
TD-learning with expected-SARSA updates. The \(n\)-step bootstrapped target is constructed as:
\[G^{(n)}_t\ =\ R^{(n)}_t + I^{(n)}_t\,\mathop{\mathbb{E}}_{a\sim\pi_\text{targ}(.|S_{t+n})}\, q_\text{targ}\left(S_{t+n}, a\right)\]Note that the ordinary SARSA target is the sampled estimate of the above target. Also, as usual, the \(n\)-step reward and indicator are defined as:
\[\begin{split}R^{(n)}_t\ &=\ \sum_{k=0}^{n-1}\gamma^kR_{t+k} \\ I^{(n)}_t\ &=\ \left\{\begin{matrix} 0 & \text{if $S_{t+n}$ is a terminal state} \\ \gamma^n & \text{otherwise} \end{matrix}\right.\end{split}\] Parameters:
q (Q) – The main q-function to update.
pi_targ (Policy) – The policy that is used for constructing the TD-target.
q_targ (Q, optional) – The q-function that is used for constructing the TD-target. If this is left unspecified, we set q_targ = q internally.
optimizer (optax optimizer, optional) – An optax-style optimizer. The default optimizer is optax.adam(1e-3).
loss_function (callable, optional) – The loss function that will be used to regress to the (bootstrapped) target. The loss function is expected to be of the form:
\[L(y_\text{true}, y_\text{pred}, w)\in\mathbb{R}\]
where \(w>0\) are sample weights. If left unspecified, this defaults to coax.value_losses.huber(). Check out the coax.value_losses module for other predefined loss functions.
policy_regularizer (Regularizer, optional) – If provided, this policy regularizer is added to the TD-target. A typical example is to use a coax.regularizers.EntropyRegularizer, which adds the policy entropy to the target. In this case, we minimize the following loss shifted by the entropy term:
\[L(y_\text{true} + \beta\,H[\pi], y_\text{pred})\]
Note that the coefficient \(\beta\) plays the role of the temperature in SAC-style agents.
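Instead of bootstrapping on one sampled next action, the expected-SARSA target averages the target q-values under pi_targ. A minimal plain-Python sketch (illustrative names, not the coax API):

```python
def expected_sarsa_target(rewards, gamma, q_targ, pi_targ, s_next, done):
    # bootstrap on E_{a ~ pi_targ(.|S_{t+n})} q_targ(S_{t+n}, a)
    n = len(rewards)
    r_n = sum(gamma ** k * r for k, r in enumerate(rewards))
    exp_q = sum(p * q_targ[s_next, a] for a, p in pi_targ[s_next].items())
    return r_n + (0.0 if done else gamma ** n * exp_q)

q_targ = {('s2', 'left'): 1.5, ('s2', 'right'): 2.5}
pi_targ = {'s2': {'left': 0.2, 'right': 0.8}}
target = expected_sarsa_target([1.0], 0.9, q_targ, pi_targ, 's2', done=False)
# 1.0 + 0.9 * (0.2*1.5 + 0.8*2.5) = 3.07
```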
 apply_grads(grads, function_state)¶
Update the model parameters (weights) of the underlying function approximator given precomputed gradients.
This method is useful in situations in which computation of the gradients is delegated to a separate (remote) process.
 Parameters:
grads (pytree with ndarray leaves) – A batch of gradients, generated by the grads method.
function_state (pytree) – The internal state of the forward-pass function. See Q.function_state and haiku.transform_with_state() for more details.
 grads_and_metrics(transition_batch)¶
Compute the gradients associated with a batch of transitions.
 Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
 Returns:
grads (pytree with ndarray leaves) – A batch of gradients.
function_state (pytree) – The internal state of the forward-pass function. See Q.function_state and haiku.transform_with_state() for more details.
metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.
td_error (ndarray) – The non-aggregated TD-errors, shape == (batch_size,).
 td_error(transition_batch)¶
Compute the TD-errors associated with a batch of transitions. We define the TD-error as the negative gradient of the loss_function with respect to the predicted value:
\[\text{td_error}_i\ =\ -\frac{\partial L(y, \hat{y})}{\partial \hat{y}_i}\]
Note that this reduces to the ordinary definition \(\text{td_error}=y-\hat{y}\) when we use the coax.value_losses.mse() loss function.
 Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
 Returns:
td_errors (ndarray, shape: [batch_size]) – A batch of TD-errors.
 update(transition_batch, return_td_error=False)¶
Update the model parameters (weights) of the underlying function approximator.
 Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
return_td_error (bool, optional) – Whether to return the TD-errors.
 Returns:
metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.
td_error (ndarray, optional) – The non-aggregated TD-errors, shape == (batch_size,). This is only returned if we set return_td_error=True.
 class coax.td_learning.QLearning(q, pi_targ=None, q_targ=None, optimizer=None, loss_function=None, policy_regularizer=None)[source]¶
TD-learning with Q-learning updates.
The \(n\)-step bootstrapped target for discrete actions is constructed as:
\[G^{(n)}_t\ =\ R^{(n)}_t + I^{(n)}_t\,\max_a q_\text{targ}\left(S_{t+n}, a\right)\]For non-discrete action spaces, this uses a DDPG-style target:
\[G^{(n)}_t\ =\ R^{(n)}_t + I^{(n)}_t\,q_\text{targ}\left( S_{t+n}, a_\text{targ}(S_{t+n})\right)\]where \(a_\text{targ}(s)\) is the mode of the underlying conditional probability distribution \(\pi_\text{targ}(.|s)\). Even though these two formulations of the q-learning target are equivalent, the implementation of the latter does require additional input, namely pi_targ.
The \(n\)-step reward and indicator (referenced above) are defined as:
\[\begin{split}R^{(n)}_t\ &=\ \sum_{k=0}^{n-1}\gamma^kR_{t+k} \\ I^{(n)}_t\ &=\ \left\{\begin{matrix} 0 & \text{if $S_{t+n}$ is a terminal state} \\ \gamma^n & \text{otherwise} \end{matrix}\right.\end{split}\] Parameters:
q (Q) – The main q-function to update.
pi_targ (Policy, optional) – The policy that is used for constructing the TD-target. This is ignored if the action space is discrete and required otherwise.
q_targ (Q, optional) – The q-function that is used for constructing the TD-target. If this is left unspecified, we set q_targ = q internally.
optimizer (optax optimizer, optional) – An optax-style optimizer. The default optimizer is optax.adam(1e-3).
loss_function (callable, optional) – The loss function that will be used to regress to the (bootstrapped) target. The loss function is expected to be of the form:
\[L(y_\text{true}, y_\text{pred}, w)\in\mathbb{R}\]
where \(w>0\) are sample weights. If left unspecified, this defaults to coax.value_losses.huber(). Check out the coax.value_losses module for other predefined loss functions.
policy_regularizer (Regularizer, optional) – If provided, this policy regularizer is added to the TD-target. A typical example is to use a coax.regularizers.EntropyRegularizer, which adds the policy entropy to the target. In this case, we minimize the following loss shifted by the entropy term:
\[L(y_\text{true} + \beta\,H[\pi], y_\text{pred})\]
Note that the coefficient \(\beta\) plays the role of the temperature in SAC-style agents.
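For discrete actions, the Q-learning target bootstraps on the greedy value of the next state. A minimal plain-Python sketch (illustrative names, not the coax API):

```python
def q_learning_target(rewards, gamma, q_targ, s_next, done):
    # discrete actions: bootstrap on max_a q_targ(S_{t+n}, a)
    n = len(rewards)
    r_n = sum(gamma ** k * r for k, r in enumerate(rewards))
    max_q = max(q_targ[s_next].values())
    return r_n + (0.0 if done else gamma ** n * max_q)

q_targ = {'s2': {'left': 1.5, 'right': 2.5}}
target = q_learning_target([1.0], 0.9, q_targ, 's2', done=False)
# 1.0 + 0.9 * 2.5 = 3.25
```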
 apply_grads(grads, function_state)¶
Update the model parameters (weights) of the underlying function approximator given precomputed gradients.
This method is useful in situations in which computation of the gradients is delegated to a separate (remote) process.
 Parameters:
grads (pytree with ndarray leaves) – A batch of gradients, generated by the grads method.
function_state (pytree) – The internal state of the forward-pass function. See Q.function_state and haiku.transform_with_state() for more details.
 grads_and_metrics(transition_batch)¶
Compute the gradients associated with a batch of transitions.
 Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
 Returns:
grads (pytree with ndarray leaves) – A batch of gradients.
function_state (pytree) – The internal state of the forward-pass function. See Q.function_state and haiku.transform_with_state() for more details.
metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.
td_error (ndarray) – The non-aggregated TD-errors, shape == (batch_size,).
 td_error(transition_batch)¶
Compute the TD-errors associated with a batch of transitions. We define the TD-error as the negative gradient of the loss_function with respect to the predicted value:
\[\text{td_error}_i\ =\ -\frac{\partial L(y, \hat{y})}{\partial \hat{y}_i}\]
Note that this reduces to the ordinary definition \(\text{td_error}=y-\hat{y}\) when we use the coax.value_losses.mse() loss function.
 Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
 Returns:
td_errors (ndarray, shape: [batch_size]) – A batch of TD-errors.
 update(transition_batch, return_td_error=False)¶
Update the model parameters (weights) of the underlying function approximator.
 Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
return_td_error (bool, optional) – Whether to return the TD-errors.
 Returns:
metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.
td_error (ndarray, optional) – The non-aggregated TD-errors, shape == (batch_size,). This is only returned if we set return_td_error=True.
 class coax.td_learning.DoubleQLearning(q, pi_targ=None, q_targ=None, optimizer=None, loss_function=None, policy_regularizer=None)[source]¶
TD-learning with Double-DQN style double q-learning updates, in which the target network is only used in selecting the would-be next action. The \(n\)-step bootstrapped target is thus constructed as:
\[\begin{split}a_\text{greedy}\ &=\ \arg\max_a q_\text{targ}(S_{t+n}, a) \\ G^{(n)}_t\ &=\ R^{(n)}_t + I^{(n)}_t\,q(S_{t+n}, a_\text{greedy})\end{split}\]where
\[\begin{split}R^{(n)}_t\ &=\ \sum_{k=0}^{n-1}\gamma^kR_{t+k} \\ I^{(n)}_t\ &=\ \left\{\begin{matrix} 0 & \text{if $S_{t+n}$ is a terminal state} \\ \gamma^n & \text{otherwise} \end{matrix}\right.\end{split}\] Parameters:
q (Q) – The main q-function to update.
pi_targ (Policy, optional) – The policy that is used for constructing the TD-target. This is ignored if the action space is discrete and required otherwise.
q_targ (Q, optional) – The q-function that is used for constructing the TD-target. If this is left unspecified, we set q_targ = q internally.
optimizer (optax optimizer, optional) – An optax-style optimizer. The default optimizer is optax.adam(1e-3).
loss_function (callable, optional) – The loss function that will be used to regress to the (bootstrapped) target. The loss function is expected to be of the form:
\[L(y_\text{true}, y_\text{pred}, w)\in\mathbb{R}\]
where \(w>0\) are sample weights. If left unspecified, this defaults to coax.value_losses.huber(). Check out the coax.value_losses module for other predefined loss functions.
policy_regularizer (Regularizer, optional) – If provided, this policy regularizer is added to the TD-target. A typical example is to use a coax.regularizers.EntropyRegularizer, which adds the policy entropy to the target. In this case, we minimize the following loss shifted by the entropy term:
\[L(y_\text{true} + \beta\,H[\pi], y_\text{pred})\]
Note that the coefficient \(\beta\) plays the role of the temperature in SAC-style agents.
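Following the formula above, one network selects the greedy action and the other evaluates it, which decouples action selection from value estimation. A minimal plain-Python sketch for discrete actions (illustrative names, not the coax API):

```python
def double_q_target(rewards, gamma, q_main, q_targ, s_next, done):
    # a_greedy = argmax_a q_targ(S_{t+n}, a); bootstrap on q(S_{t+n}, a_greedy)
    n = len(rewards)
    r_n = sum(gamma ** k * r for k, r in enumerate(rewards))
    a_greedy = max(q_targ[s_next], key=q_targ[s_next].get)
    return r_n + (0.0 if done else gamma ** n * q_main[s_next][a_greedy])

q_targ = {'s2': {'left': 1.5, 'right': 2.5}}   # selects 'right'
q_main = {'s2': {'left': 1.8, 'right': 2.1}}   # evaluates 'right'
target = double_q_target([1.0], 0.9, q_main, q_targ, 's2', done=False)
# 1.0 + 0.9 * 2.1 = 2.89
```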
 apply_grads(grads, function_state)¶
Update the model parameters (weights) of the underlying function approximator given precomputed gradients.
This method is useful in situations in which computation of the gradients is delegated to a separate (remote) process.
 Parameters:
grads (pytree with ndarray leaves) – A batch of gradients, generated by the grads method.
function_state (pytree) – The internal state of the forward-pass function. See Q.function_state and haiku.transform_with_state() for more details.
 grads_and_metrics(transition_batch)¶
Compute the gradients associated with a batch of transitions.
 Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
 Returns:
grads (pytree with ndarray leaves) – A batch of gradients.
function_state (pytree) – The internal state of the forward-pass function. See Q.function_state and haiku.transform_with_state() for more details.
metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.
td_error (ndarray) – The non-aggregated TD-errors, shape == (batch_size,).
 td_error(transition_batch)¶
Compute the TD-errors associated with a batch of transitions. We define the TD-error as the negative gradient of the loss_function with respect to the predicted value:
\[\text{td_error}_i\ =\ -\frac{\partial L(y, \hat{y})}{\partial \hat{y}_i}\]
Note that this reduces to the ordinary definition \(\text{td_error}=y-\hat{y}\) when we use the coax.value_losses.mse() loss function.
 Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
 Returns:
td_errors (ndarray, shape: [batch_size]) – A batch of TD-errors.
 update(transition_batch, return_td_error=False)¶
Update the model parameters (weights) of the underlying function approximator.
 Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
return_td_error (bool, optional) – Whether to return the TD-errors.
 Returns:
metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.
td_error (ndarray, optional) – The non-aggregated TD-errors, shape == (batch_size,). This is only returned if we set return_td_error=True.
 class coax.td_learning.SoftQLearning(q, q_targ=None, optimizer=None, loss_function=None, policy_regularizer=None, temperature=1.0)[source]¶
TD-learning with soft Q-learning updates. The \(n\)-step bootstrapped target is constructed as:
\[G^{(n)}_t\ =\ R^{(n)}_t + I^{(n)}_t\,\tau\log\sum_{a'}\exp\left(q_\text{targ}(S_{t+n}, a') / \tau\right)\]where
\[\begin{split}R^{(n)}_t\ &=\ \sum_{k=0}^{n-1}\gamma^kR_{t+k} \\ I^{(n)}_t\ &=\ \left\{\begin{matrix} 0 & \text{if $S_{t+n}$ is a terminal state} \\ \gamma^n & \text{otherwise} \end{matrix}\right.\end{split}\] Parameters:
q (Q) – The main q-function to update.
q_targ (Q, optional) – The q-function that is used for constructing the TD-target. If this is left unspecified, we set q_targ = q internally.
optimizer (optax optimizer, optional) – An optax-style optimizer. The default optimizer is optax.adam(1e-3).
loss_function (callable, optional) – The loss function that will be used to regress to the (bootstrapped) target. The loss function is expected to be of the form:
\[L(y_\text{true}, y_\text{pred}, w)\in\mathbb{R}\]
where \(w>0\) are sample weights. If left unspecified, this defaults to coax.value_losses.huber(). Check out the coax.value_losses module for other predefined loss functions.
policy_regularizer (Regularizer, optional) – If provided, this policy regularizer is added to the TD-target. A typical example is to use a coax.regularizers.EntropyRegularizer, which adds the policy entropy to the target. In this case, we minimize the following loss shifted by the entropy term:
\[L(y_\text{true} + \beta\,H[\pi], y_\text{pred})\]
Note that the coefficient \(\beta\) plays the role of the temperature in SAC-style agents.
temperature (float, optional) – The Boltzmann temperature \(\tau>0\).
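The soft Q-learning target replaces the hard max over next actions with a temperature-scaled log-sum-exp, which approaches the hard max as \(\tau \to 0\). A minimal plain-Python sketch (illustrative names, not the coax API):

```python
import math

def soft_q_target(rewards, gamma, q_targ, s_next, done, tau=1.0):
    # bootstrap on tau * log sum_a' exp(q_targ(S_{t+n}, a') / tau)
    n = len(rewards)
    r_n = sum(gamma ** k * r for k, r in enumerate(rewards))
    soft_max = tau * math.log(
        sum(math.exp(q / tau) for q in q_targ[s_next].values()))
    return r_n + (0.0 if done else gamma ** n * soft_max)

q_targ = {'s2': {'left': 1.5, 'right': 2.5}}
target_soft = soft_q_target([1.0], 0.9, q_targ, 's2', False, tau=1.0)
# with a small temperature the soft maximum is close to the hard max (2.5)
target_cold = soft_q_target([1.0], 0.9, q_targ, 's2', False, tau=0.01)
```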
 apply_grads(grads, function_state)¶
Update the model parameters (weights) of the underlying function approximator given precomputed gradients.
This method is useful in situations in which computation of the gradients is delegated to a separate (remote) process.
 Parameters:
grads (pytree with ndarray leaves) – A batch of gradients, generated by the grads method.
function_state (pytree) – The internal state of the forward-pass function. See Q.function_state and haiku.transform_with_state() for more details.
 grads_and_metrics(transition_batch)¶
Compute the gradients associated with a batch of transitions.
 Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
 Returns:
grads (pytree with ndarray leaves) – A batch of gradients.
function_state (pytree) – The internal state of the forward-pass function. See Q.function_state and haiku.transform_with_state() for more details.
metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.
td_error (ndarray) – The non-aggregated TD-errors, shape == (batch_size,).
 td_error(transition_batch)¶
Compute the TD-errors associated with a batch of transitions. We define the TD-error as the negative gradient of the loss_function with respect to the predicted value:
\[\text{td_error}_i\ =\ -\frac{\partial L(y, \hat{y})}{\partial \hat{y}_i}\]
Note that this reduces to the ordinary definition \(\text{td_error}=y-\hat{y}\) when we use the coax.value_losses.mse() loss function.
 Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
 Returns:
td_errors (ndarray, shape: [batch_size]) – A batch of TD-errors.
 update(transition_batch, return_td_error=False)¶
Update the model parameters (weights) of the underlying function approximator.
 Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
return_td_error (bool, optional) – Whether to return the TD-errors.
 Returns:
metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.
td_error (ndarray, optional) – The non-aggregated TD-errors, shape == (batch_size,). This is only returned if we set return_td_error=True.
 class coax.td_learning.ClippedDoubleQLearning(q, pi_targ_list=None, q_targ_list=None, optimizer=None, loss_function=None, policy_regularizer=None)[source]¶
TD-learning with TD3 style double q-learning updates, in which the target network is only used in selecting the would-be next action.
For discrete actions, the \(n\)-step bootstrapped target is constructed as:
\[G^{(n)}_t\ =\ R^{(n)}_t + I^{(n)}_t\,\min_{i,j}q_i(S_{t+n}, \arg\max_a q_j(S_{t+n}, a))\]where \(q_i(s,a)\) is the \(i\)-th target q-function provided in q_targ_list.
Similarly, for non-discrete actions, the target is constructed as:
\[G^{(n)}_t\ =\ R^{(n)}_t + I^{(n)}_t\,\min_{i,j}q_i(S_{t+n}, a_j(S_{t+n}))\]where \(a_j(s)\) is the mode of the \(j\)-th target policy provided in pi_targ_list.
As usual, the \(n\)-step reward and indicator are defined as:
\[\begin{split}R^{(n)}_t\ &=\ \sum_{k=0}^{n-1}\gamma^kR_{t+k} \\ I^{(n)}_t\ &=\ \left\{\begin{matrix} 0 & \text{if $S_{t+n}$ is a terminal state} \\ \gamma^n & \text{otherwise} \end{matrix}\right.\end{split}\] Parameters:
q (Q) – The main q-function to update.
pi_targ_list (list of Policy, optional) – The list of policies that are used for constructing the TD-target. This is ignored if the action space is discrete and required otherwise.
q_targ_list (list of Q) – The list of q-functions that are used for constructing the TD-target.
optimizer (optax optimizer, optional) – An optax-style optimizer. The default optimizer is optax.adam(1e-3).
loss_function (callable, optional) – The loss function that will be used to regress to the (bootstrapped) target. The loss function is expected to be of the form:
\[L(y_\text{true}, y_\text{pred}, w)\in\mathbb{R}\]
where \(w>0\) are sample weights. If left unspecified, this defaults to coax.value_losses.huber(). Check out the coax.value_losses module for other predefined loss functions.
policy_regularizer (Regularizer, optional) – If provided, this policy regularizer is added to the TD-target. A typical example is to use a coax.regularizers.EntropyRegularizer, which adds the policy entropy to the target. In this case, we minimize the following loss shifted by the entropy term:
\[L(y_\text{true} + \beta\,H[\pi], y_\text{pred})\]
Note that the coefficient \(\beta\) plays the role of the temperature in SAC-style agents.
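The clipped target takes a pessimistic minimum over every (evaluator, selector) pair of target q-functions, which counteracts the overestimation bias of a single max. A minimal plain-Python sketch for discrete actions (illustrative names, not the coax API):

```python
def clipped_double_q_target(rewards, gamma, q_targ_list, s_next, done):
    # discrete actions: bootstrap on
    #   min_{i,j} q_i(S_{t+n}, argmax_a q_j(S_{t+n}, a))
    n = len(rewards)
    r_n = sum(gamma ** k * r for k, r in enumerate(rewards))
    clipped = min(
        q_i[s_next][max(q_j[s_next], key=q_j[s_next].get)]
        for q_i in q_targ_list for q_j in q_targ_list)
    return r_n + (0.0 if done else gamma ** n * clipped)

q1 = {'s2': {'left': 1.5, 'right': 2.5}}  # argmax: 'right'
q2 = {'s2': {'left': 2.2, 'right': 2.1}}  # argmax: 'left'
target = clipped_double_q_target([1.0], 0.9, [q1, q2], 's2', done=False)
# candidates: q1@'right'=2.5, q1@'left'=1.5, q2@'right'=2.1, q2@'left'=2.2
# min is 1.5, so the target is 1.0 + 0.9 * 1.5 = 2.35
```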
 apply_grads(grads, function_state)¶
Update the model parameters (weights) of the underlying function approximator given precomputed gradients.
This method is useful in situations in which computation of the gradients is delegated to a separate (remote) process.
 Parameters:
grads (pytree with ndarray leaves) – A batch of gradients, generated by the grads method.
function_state (pytree) – The internal state of the forward-pass function. See Q.function_state and haiku.transform_with_state() for more details.
 grads_and_metrics(transition_batch)¶
Compute the gradients associated with a batch of transitions.
 Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
 Returns:
grads (pytree with ndarray leaves) – A batch of gradients.
function_state (pytree) – The internal state of the forward-pass function. See Q.function_state and haiku.transform_with_state() for more details.
metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.
td_error (ndarray) – The non-aggregated TD-errors, shape == (batch_size,).
 td_error(transition_batch)¶
Compute the TD-errors associated with a batch of transitions. We define the TD-error as the negative gradient of the loss_function with respect to the predicted value:
\[\text{td_error}_i\ =\ -\frac{\partial L(y, \hat{y})}{\partial \hat{y}_i}\]
Note that this reduces to the ordinary definition \(\text{td_error}=y-\hat{y}\) when we use the coax.value_losses.mse() loss function.
 Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
 Returns:
td_errors (ndarray, shape: [batch_size]) – A batch of TD-errors.
 update(transition_batch, return_td_error=False)¶
Update the model parameters (weights) of the underlying function approximator.
 Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
return_td_error (bool, optional) – Whether to return the TD-errors.
 Returns:
metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.
td_error (ndarray, optional) – The non-aggregated TD-errors, shape == (batch_size,). This is only returned if we set return_td_error=True.
 class coax.td_learning.SoftClippedDoubleQLearning(q, pi_targ_list=None, q_targ_list=None, optimizer=None, loss_function=None, policy_regularizer=None)[source]¶
TD-learning with soft clipped double q-learning updates. This works like ClippedDoubleQLearning, except that the action for the next state is sampled from the target policy instead of taking its mode (see target_func).
 apply_grads(grads, function_state)¶
Update the model parameters (weights) of the underlying function approximator given precomputed gradients.
This method is useful in situations in which computation of the gradients is delegated to a separate (remote) process.
 Parameters:
grads (pytree with ndarray leaves) – A batch of gradients, generated by the grads method.
function_state (pytree) – The internal state of the forward-pass function. See Q.function_state and haiku.transform_with_state() for more details.
 grads_and_metrics(transition_batch)¶
Compute the gradients associated with a batch of transitions.
 Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
 Returns:
grads (pytree with ndarray leaves) – A batch of gradients.
function_state (pytree) – The internal state of the forward-pass function. See Q.function_state and haiku.transform_with_state() for more details.
metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.
td_error (ndarray) – The non-aggregated TD-errors, shape == (batch_size,).
 target_func(target_params, target_state, rng, transition_batch)[source]¶
This does almost the same as ClippedDoubleQLearning.target_func except that the action for the next state is sampled instead of taking the mode.
 td_error(transition_batch)¶
Compute the TD-errors associated with a batch of transitions. We define the TD-error as the negative gradient of the loss_function with respect to the predicted value:
\[\text{td_error}_i\ =\ -\frac{\partial L(y, \hat{y})}{\partial \hat{y}_i}\]
Note that this reduces to the ordinary definition \(\text{td_error}=y-\hat{y}\) when we use the coax.value_losses.mse() loss function.
 Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
 Returns:
td_errors (ndarray, shape: [batch_size]) – A batch of TD-errors.
 update(transition_batch, return_td_error=False)¶
Update the model parameters (weights) of the underlying function approximator.
 Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
return_td_error (bool, optional) – Whether to return the TD-errors.
 Returns:
metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.
td_error (ndarray, optional) – The non-aggregated TD-errors, shape == (batch_size,). This is only returned if we set return_td_error=True.