Model Updaters

coax.model_updaters.ModelUpdater

Model updater that uses sampling for maximum-likelihood estimation.


This is a collection of objects that are used to update dynamics models, i.e. transition models and reward functions.

Object Reference

class coax.model_updaters.ModelUpdater(model, optimizer=None, loss_function=None, regularizer=None)[source]

Model updater that uses sampling for maximum-likelihood estimation.

Parameters:
  • model ([Stochastic]TransitionModel or [Stochastic]RewardFunction) – The main dynamics model to update.

  • optimizer (optax optimizer, optional) – An optax-style optimizer. The default optimizer is optax.adam(1e-3).

  • loss_function (callable, optional) –

    The loss function that will be used to regress to the (bootstrapped) target. The loss function is expected to be of the form:

    \[L(y_\text{true}, y_\text{pred}, w)\in\mathbb{R}\]

    where \(w>0\) are sample weights. If left unspecified, this defaults to coax.value_losses.huber(). Check out the coax.value_losses module for other predefined loss functions.

  • regularizer (Regularizer, optional) – A stochastic regularizer, see coax.regularizers.
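
To illustrate the loss_function contract described above, here is a minimal sketch of a callable of the form L(y_true, y_pred, w) that returns a single real number. It is not coax's own huber implementation: the name weighted_huber, the delta parameter, and the choice to normalize by the sum of the weights are assumptions made for this example. Plain Python lists stand in for the arrays that coax would pass.

```python
def weighted_huber(y_true, y_pred, w, delta=1.0):
    """Weighted mean Huber loss over a batch; returns a single float.

    Hypothetical illustration of the L(y_true, y_pred, w) signature,
    not the implementation used by coax.
    """
    if not (len(y_true) == len(y_pred) == len(w)):
        raise ValueError("y_true, y_pred and w must have the same length")
    if any(wi <= 0 for wi in w):
        raise ValueError("sample weights must be strictly positive")

    total = 0.0
    for yt, yp, wi in zip(y_true, y_pred, w):
        err = abs(yt - yp)
        if err <= delta:
            per_sample = 0.5 * err ** 2               # quadratic near zero
        else:
            per_sample = delta * (err - 0.5 * delta)  # linear in the tails
        total += wi * per_sample

    # Normalizing by the total weight is an assumption of this sketch.
    return total / sum(w)


uniform = weighted_huber([0.0, 0.0], [0.5, 3.0], [1.0, 1.0])
reweighted = weighted_huber([0.0, 0.0], [0.5, 3.0], [3.0, 1.0])
```

Giving the first sample a larger weight pulls the result towards that sample's smaller error, so reweighted comes out lower than uniform.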

apply_grads(grads, function_state)[source]

Update the model parameters (weights) of the underlying function approximator given pre-computed gradients.

This method is useful in situations in which computation of the gradients is delegated to a separate (remote) process.

Parameters:
  • grads (pytree with ndarray leaves) – A batch of gradients, generated by the grads_and_metrics method.

  • function_state (pytree) – The internal state of the forward-pass function. See Q.function_state and haiku.transform_with_state() for more details.

grads_and_metrics(transition_batch)[source]

Compute the gradients associated with a batch of transitions.

Parameters:
  • transition_batch (TransitionBatch) – A batch of transitions.

Returns:

  • grads (pytree with ndarray leaves) – A batch of gradients.

  • function_state (pytree) – The internal state of the forward-pass function. See Q.function_state and haiku.transform_with_state() for more details.

  • metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.

update(transition_batch)[source]

Update the model parameters (weights) of the underlying function approximator.

Parameters:
  • transition_batch (TransitionBatch) – A batch of transitions.

Returns:
  • metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.
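
The three methods documented above can be pictured with a toy stand-in. The sketch below is not coax code: ToyModelUpdater, its one-parameter linear model, the list-of-pairs batch format, and the plain gradient-descent step are all hypothetical simplifications. It also assumes that update amounts to grads_and_metrics followed by apply_grads, which the reference text does not state explicitly. The point is only to show how gradient computation can happen in one process and be applied in another.

```python
class ToyModelUpdater:
    """Hypothetical stand-in for illustration; not coax's implementation."""

    def __init__(self, weight=0.0, learning_rate=0.1):
        self.weight = weight                # single parameter of y = weight * x
        self.learning_rate = learning_rate

    def grads_and_metrics(self, transition_batch):
        # transition_batch is a list of (x, y_true) pairs here,
        # a simplification of coax's TransitionBatch.
        n = len(transition_batch)
        loss = sum((self.weight * x - y) ** 2 for x, y in transition_batch) / n
        grad = sum(2.0 * (self.weight * x - y) * x for x, y in transition_batch) / n
        function_state = None               # this toy model is stateless
        return grad, function_state, {"loss": loss}

    def apply_grads(self, grads, function_state):
        # Plain gradient-descent step; a real updater would use its optimizer.
        self.weight -= self.learning_rate * grads

    def update(self, transition_batch):
        # Assumed composition: compute gradients, then apply them.
        grads, function_state, metrics = self.grads_and_metrics(transition_batch)
        self.apply_grads(grads, function_state)
        return metrics


batch = [(1.0, 2.0), (2.0, 4.0)]            # samples from y = 2 * x

# Split workflow: a worker computes gradients, a learner applies them.
worker = ToyModelUpdater()
learner = ToyModelUpdater()
grads, function_state, metrics = worker.grads_and_metrics(batch)
learner.apply_grads(grads, function_state)

# Single-process workflow on an identical starting point.
local = ToyModelUpdater()
local_metrics = local.update(batch)
```

In this toy setting the learner and the single-process updater end up with the same weight, which is the property that makes the split workflow a drop-in replacement for calling update directly.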