Model Updaters
ModelUpdater: Model updater that uses sampling for maximum-likelihood estimation.
This is a collection of objects that are used to update dynamics models, i.e. transition models and reward functions.
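The updater is described as performing maximum-likelihood estimation. As an illustration only, suppose a stochastic reward function outputs the mean and log standard deviation of a Gaussian over rewards. This is an assumption made for the example: coax's actual distribution parameterization, and how it uses sampling, are not shown here. Under that assumption, maximum likelihood means minimizing the weighted negative log-likelihood of the observed rewards, which the hypothetical helper gaussian_nll computes in plain Python:

```python
import math


def gaussian_nll(y_true, mu, log_std, w):
    """Hypothetical helper: weighted mean negative log-likelihood under N(mu, exp(log_std)**2).

    All arguments are equal-length sequences of floats; w holds per-sample weights.
    This illustrates maximum-likelihood estimation in general, not coax's internal loss.
    """
    total = 0.0
    for y, m, s, wi in zip(y_true, mu, log_std, w):
        var = math.exp(2.0 * s)
        # log N(y | m, var) = -0.5 * [log(2 * pi * var) + (y - m)**2 / var]
        log_p = -0.5 * (math.log(2.0 * math.pi * var) + (y - m) ** 2 / var)
        total += wi * -log_p
    return total / len(y_true)
```

Driving this value down by adjusting mu and log_std (for example with gradient descent) is the maximum-likelihood fit; the weights w play the same role as the sample weights in the loss_function signature documented for ModelUpdater.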
Object Reference
- class coax.model_updaters.ModelUpdater(model, optimizer=None, loss_function=None, regularizer=None)
Model updater that uses sampling for maximum-likelihood estimation.
- Parameters:
model ([Stochastic]TransitionModel or [Stochastic]RewardFunction) – The main dynamics model to update.
optimizer (optax optimizer, optional) – An optax-style optimizer. The default optimizer is optax.adam(1e-3).
loss_function (callable, optional) – The loss function that will be used to regress to the (bootstrapped) target. The loss function is expected to be of the form:
\[L(y_\text{true}, y_\text{pred}, w)\in\mathbb{R}\]
where \(w>0\) are sample weights. If left unspecified, this defaults to coax.value_losses.huber(). Check out the coax.value_losses module for other predefined loss functions.
regularizer (Regularizer, optional) – A stochastic regularizer; see coax.regularizers.
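A custom loss_function only needs to match the form L(y_true, y_pred, w) given above. The sketch below is a hypothetical weighted Huber loss written with NumPy purely to show that signature; it is not the source of coax.value_losses.huber, and the names weighted_huber and delta are assumptions made for this example.

```python
import numpy as np


def weighted_huber(y_true, y_pred, w, delta=1.0):
    """Hypothetical loss with the signature L(y_true, y_pred, w) -> scalar.

    Quadratic for |error| <= delta, linear beyond it; per-sample losses
    are multiplied by the sample weights w and then averaged.
    """
    err = np.abs(np.asarray(y_pred, dtype=float) - np.asarray(y_true, dtype=float))
    quad = np.minimum(err, delta)  # portion of the error inside the quadratic zone
    lin = err - quad               # portion of the error beyond delta
    per_sample = 0.5 * quad ** 2 + delta * lin
    return float(np.mean(np.asarray(w, dtype=float) * per_sample))
```

Because the updater differentiates its loss with JAX, a callable actually passed as loss_function would need to be written with jax.numpy rather than numpy (np.asarray and float() do not work on traced values); the arithmetic would stay the same.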
- apply_grads(grads, function_state)
Update the model parameters (weights) of the underlying function approximator given pre-computed gradients.
This method is useful in situations in which computation of the gradients is delegated to a separate (remote) process.
- Parameters:
grads (pytree with ndarray leaves) – A batch of gradients, generated by the grads_and_metrics method.
function_state (pytree) – The internal state of the forward-pass function. See Q.function_state and haiku.transform_with_state() for more details.
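To illustrate the division of labor that apply_grads enables, the sketch below has a stand-in "remote" worker compute gradients from a pickled parameter snapshot while the learner only applies them. Everything here is hypothetical: a linear reward model r_hat = w * x + b trained with mean squared error, nested dicts as pytrees, pickle as the transport, and plain SGD in place of an optax optimizer. The toy model has no forward-pass state, so function_state is omitted.

```python
import pickle


def tree_map2(fn, tree_a, tree_b):
    """Apply fn leaf-wise to two nested dicts that share the same structure."""
    if isinstance(tree_a, dict):
        return {k: tree_map2(fn, tree_a[k], tree_b[k]) for k in tree_a}
    return fn(tree_a, tree_b)


def remote_worker(payload):
    """Stands in for a separate process: unpickles (params, batch), returns pickled grads.

    Toy model (hypothetical): r_hat = w * x + b with mean squared error.
    """
    params, batch = pickle.loads(payload)
    w, b = params["linear"]["w"], params["linear"]["b"]
    n = len(batch)
    gw = sum(2.0 * (w * x + b - r) * x for x, r in batch) / n
    gb = sum(2.0 * (w * x + b - r) for x, r in batch) / n
    return pickle.dumps({"linear": {"w": gw, "b": gb}})


class Learner:
    """Owns the parameters and only applies gradients it receives."""

    def __init__(self, learning_rate=0.1):
        self.params = {"linear": {"w": 0.0, "b": 0.0}}
        self.learning_rate = learning_rate

    def apply_grads(self, grads):
        # Plain SGD, leaf by leaf: theta <- theta - lr * grad
        self.params = tree_map2(lambda p, g: p - self.learning_rate * g, self.params, grads)


learner = Learner()
batch = [(1.0, 2.0), (2.0, 4.0)]  # (x, r) pairs generated by r = 2 * x
for _ in range(2000):
    payload = pickle.dumps((learner.params, batch))  # "send" a snapshot to the worker
    grads = pickle.loads(remote_worker(payload))     # "receive" the gradients
    learner.apply_grads(grads)
```

The learner never evaluates the loss itself; it only needs gradients whose tree structure matches its parameters, which is the situation apply_grads is meant for.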
- grads_and_metrics(transition_batch)
Compute the gradients associated with a batch of transitions.
- Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
- Returns:
grads (pytree with ndarray leaves) – A batch of gradients.
function_state (pytree) – The internal state of the forward-pass function. See Q.function_state and haiku.transform_with_state() for more details.
metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.
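The metrics returned by grads_and_metrics form a flat {name: score} mapping. A minimal sketch of how such a dict might be assembled from a loss value and a grads pytree follows; the helpers tree_leaves and summarize and the metric names they produce are hypothetical and do not reproduce the keys coax actually emits.

```python
import math


def tree_leaves(tree):
    """Collect the leaves of a nested dict in insertion order."""
    if isinstance(tree, dict):
        return [leaf for v in tree.values() for leaf in tree_leaves(v)]
    return [tree]


def summarize(loss, grads, prefix="toy_updater"):
    """Build a flat {name: score} metrics dict from a loss and a grads pytree."""
    leaves = [float(g) for g in tree_leaves(grads)]
    return {
        f"{prefix}/loss": float(loss),
        f"{prefix}/grads_norm": math.sqrt(sum(g * g for g in leaves)),  # global L2 norm
        f"{prefix}/grads_max": max(abs(g) for g in leaves),             # largest |gradient|
    }
```

A flat mapping of scalar scores like this is easy to hand to a logger, whereas the nested grads pytree itself is not.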
- update(transition_batch)
Update the model parameters (weights) of the underlying function approximator.
- Parameters:
transition_batch (TransitionBatch) – A batch of transitions.
- Returns:
metrics (dict of scalar ndarrays) – The structure of the metrics dict is {name: score}.
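Judging from the signatures documented on this page, update behaves like grads_and_metrics followed by apply_grads, returning only the metrics. That composition is an assumption, not a statement of coax's actual implementation (which may add extra checks). The sketch below spells it out against a hypothetical stub, CountingUpdater, that merely mimics the documented return shapes.

```python
class CountingUpdater:
    """Hypothetical stub exposing grads_and_metrics/apply_grads with the documented shapes."""

    def __init__(self):
        self.applied = []  # records every (grads, function_state) pair passed to apply_grads

    def grads_and_metrics(self, transition_batch):
        grads = {"w": -float(len(transition_batch))}  # placeholder gradient pytree
        function_state = {}                            # no forward-pass state in this stub
        metrics = {"loss": float(len(transition_batch))}
        return grads, function_state, metrics

    def apply_grads(self, grads, function_state):
        self.applied.append((grads, function_state))


def update(updater, transition_batch):
    """Compute gradients, apply them, and hand back only the metrics."""
    grads, function_state, metrics = updater.grads_and_metrics(transition_batch)
    updater.apply_grads(grads, function_state)
    return metrics
```

Calling update on a single process and splitting the two steps across processes (as in the apply_grads use case) should then produce the same parameter change for the same batch.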