Source code for coax.model_updaters._model_updater

import jax
import jax.numpy as jnp
import haiku as hk
import optax

from ..utils import (
    get_grads_diagnostics, is_stochastic, is_reward_function, is_transition_model, jit)
from ..value_losses import huber
from ..regularizers import Regularizer


__all__ = (
    'ModelUpdater',
)


class ModelUpdater:
    r"""
    Model updater that uses *sampling* for maximum-likelihood estimation.

    Parameters
    ----------
    model : [Stochastic]TransitionModel or [Stochastic]RewardFunction
        The main dynamics model to update.

    optimizer : optax optimizer, optional
        An optax-style optimizer. The default optimizer is
        :func:`optax.adam(1e-3) <optax.adam>`.

    loss_function : callable, optional
        The loss function that will be used to regress to the (bootstrapped) target. The loss
        function is expected to be of the form:

        .. math::

            L(y_\text{true}, y_\text{pred}, w) \in \mathbb{R}

        where :math:`w > 0` are sample weights. If left unspecified, this defaults to
        :func:`coax.value_losses.huber`. Check out the :mod:`coax.value_losses` module for other
        predefined loss functions.

    regularizer : Regularizer, optional
        A stochastic regularizer, see :mod:`coax.regularizers`.

    """
    def __init__(self, model, optimizer=None, loss_function=None, regularizer=None):
        if not (is_reward_function(model) or is_transition_model(model)):
            raise TypeError(f"model must be a dynamics model, got: {type(model)}")
        if not isinstance(regularizer, (Regularizer, type(None))):
            raise TypeError(f"regularizer must be a Regularizer, got: {type(regularizer)}")

        self.model = model
        self.loss_function = huber if loss_function is None else loss_function
        self.regularizer = regularizer

        # optimizer
        self._optimizer = optax.adam(1e-3) if optimizer is None else optimizer
        self._optimizer_state = self.optimizer.init(self.model.params)

        def apply_grads_func(opt, opt_state, params, grads):
            updates, new_opt_state = opt.update(grads, opt_state, params)
            new_params = optax.apply_updates(params, updates)
            return new_opt_state, new_params

        def loss_func(params, state, hyperparams, rng, transition_batch):
            rngs = hk.PRNGSequence(rng)
            S = self.model.observation_preprocessor(next(rngs), transition_batch.S)
            A = self.model.action_preprocessor(next(rngs), transition_batch.A)

            # clip importance weights to reduce variance
            W = jnp.clip(transition_batch.W, 0.1, 10.)

            if is_stochastic(self.model):
                dist_params, new_state = \
                    self.model.function_type1(params, state, next(rngs), S, A, True)
                y_pred = self.model.proba_dist.sample(dist_params, next(rngs))
            else:
                y_pred, new_state = \
                    self.model.function_type1(params, state, next(rngs), S, A, True)

            if is_transition_model(self.model):
                y_true = self.model.observation_preprocessor(next(rngs), transition_batch.S_next)
            elif is_reward_function(self.model):
                y_true = self.model.value_transform.transform_func(transition_batch.Rn)
            else:
                raise AssertionError(f"unexpected model type: {type(self.model)}")

            loss = self.loss_function(y_true, y_pred, W)
            metrics = {
                f'{self.__class__.__name__}/loss': loss,
                f'{self.__class__.__name__}/loss_bare': loss,
            }

            # add regularization term
            if self.regularizer is not None:
                hparams = hyperparams['regularizer']
                loss = loss + jnp.mean(W * self.regularizer.function(dist_params, **hparams))
                metrics[f'{self.__class__.__name__}/loss'] = loss
                metrics.update(self.regularizer.metrics_func(dist_params, **hparams))

            return loss, (metrics, new_state)

        def grads_and_metrics_func(params, state, hyperparams, rng, transition_batch):
            grads, (metrics, new_state) = jax.grad(loss_func, has_aux=True)(
                params, state, hyperparams, rng, transition_batch)

            # add some diagnostics of the gradients
            metrics.update(get_grads_diagnostics(grads, f'{self.__class__.__name__}/grads_'))

            return grads, new_state, metrics

        self._apply_grads_func = jit(apply_grads_func, static_argnums=0)
        self._grads_and_metrics_func = jit(grads_and_metrics_func)
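    # --- Illustrative usage sketch (not part of the original module) --------------------------
    # A minimal example of constructing this updater. The name `transition_model` (a
    # [Stochastic]TransitionModel built elsewhere) and the explicit optimizer/loss choices
    # below are assumptions; by default the updater falls back to optax.adam(1e-3) and
    # coax.value_losses.huber.
    #
    #   model_updater = ModelUpdater(
    #       transition_model,               # hypothetical, pre-built dynamics model
    #       optimizer=optax.adam(1e-3),     # same as the documented default
    #       loss_function=huber,            # explicit, equivalent to the default
    #   )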
    def update(self, transition_batch):
        r"""
        Update the model parameters (weights) of the underlying function approximator.

        Parameters
        ----------
        transition_batch : TransitionBatch
            A batch of transitions.

        Returns
        -------
        metrics : dict of scalar ndarrays
            The structure of the metrics dict is ``{name: score}``.

        """
        grads, function_state, metrics = self.grads_and_metrics(transition_batch)
        if any(jnp.any(jnp.isnan(g)) for g in jax.tree_util.tree_leaves(grads)):
            raise RuntimeError(f"found nan's in grads: {grads}")
        self.apply_grads(grads, function_state)
        return metrics
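    # --- Illustrative usage sketch (not part of the original module) --------------------------
    # A single training step, assuming `model_updater` was constructed as above and
    # `transition_batch` comes from the surrounding tracing/replay machinery (an assumption
    # about the training loop, not something this module prescribes).
    #
    #   metrics = model_updater.update(transition_batch)
    #   # metrics contains e.g. 'ModelUpdater/loss', 'ModelUpdater/loss_bare' and gradient
    #   # diagnostics prefixed with 'ModelUpdater/grads_'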
    def apply_grads(self, grads, function_state):
        r"""
        Update the model parameters (weights) of the underlying function approximator given
        pre-computed gradients.

        This method is useful in situations in which computation of the gradients is delegated
        to a separate (remote) process.

        Parameters
        ----------
        grads : pytree with ndarray leaves
            A batch of gradients, generated by the :attr:`grads` method.

        function_state : pytree
            The internal state of the forward-pass function. See :attr:`Q.function_state
            <coax.Q.function_state>` and :func:`haiku.transform_with_state` for more details.

        """
        self.model.function_state = function_state
        self.optimizer_state, self.model.params = self._apply_grads_func(
            self.optimizer, self.optimizer_state, self.model.params, grads)
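    # --- Illustrative usage sketch (not part of the original module) --------------------------
    # Splitting gradient computation from the parameter update, e.g. when gradients are
    # computed by a separate (remote) worker; variable names are placeholders.
    #
    #   grads, function_state, metrics = model_updater.grads_and_metrics(transition_batch)
    #   ...  # ship `grads` and `function_state` from the worker to the learner process
    #   model_updater.apply_grads(grads, function_state)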
    def grads_and_metrics(self, transition_batch):
        r"""
        Compute the gradients associated with a batch of transitions.

        Parameters
        ----------
        transition_batch : TransitionBatch
            A batch of transitions.

        Returns
        -------
        grads : pytree with ndarray leaves
            A batch of gradients.

        function_state : pytree
            The internal state of the forward-pass function. See :attr:`Q.function_state
            <coax.Q.function_state>` and :func:`haiku.transform_with_state` for more details.

        metrics : dict of scalar ndarrays
            The structure of the metrics dict is ``{name: score}``.

        """
        return self._grads_and_metrics_func(
            self.model.params, self.model.function_state, self.hyperparams, self.model.rng,
            transition_batch)
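    # --- Illustrative usage sketch (not part of the original module) --------------------------
    # Inspecting gradients without applying them, e.g. to monitor per-leaf gradient norms.
    #
    #   grads, function_state, metrics = model_updater.grads_and_metrics(transition_batch)
    #   grad_norms = jax.tree_util.tree_map(jnp.linalg.norm, grads)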
    @property
    def hyperparams(self):
        return hk.data_structures.to_immutable_dict({
            'regularizer': getattr(self.regularizer, 'hyperparams', {})})

    @property
    def optimizer(self):
        return self._optimizer

    @optimizer.setter
    def optimizer(self, new_optimizer):
        new_optimizer_state_structure = jax.tree_util.tree_structure(
            new_optimizer.init(self.model.params))
        if new_optimizer_state_structure != jax.tree_util.tree_structure(self.optimizer_state):
            raise AttributeError(
                "cannot set optimizer attr: mismatch in optimizer_state structure")
        self._optimizer = new_optimizer

    @property
    def optimizer_state(self):
        return self._optimizer_state

    @optimizer_state.setter
    def optimizer_state(self, new_optimizer_state):
        new_tree_structure = jax.tree_util.tree_structure(new_optimizer_state)
        tree_structure = jax.tree_util.tree_structure(self.optimizer_state)
        if new_tree_structure != tree_structure:
            raise AttributeError("cannot set optimizer_state attr: mismatch in tree structure")
        self._optimizer_state = new_optimizer_state
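# --- Illustrative usage sketch (not part of the original module) ------------------------------
# The `optimizer` setter only accepts a new optax optimizer whose state has the same tree
# structure as the current optimizer state; `model_updater` is an assumed, pre-constructed
# instance.
#
#   model_updater.optimizer = optax.adam(3e-4)   # ok: adam-for-adam keeps the state structure
#   model_updater.optimizer = optax.sgd(1e-3)    # typically raises AttributeError (structure
#                                                # mismatch with the adam optimizer state)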