Regularizers

coax.regularizers.EntropyRegularizer

Policy regularization term based on the entropy of the policy.

coax.regularizers.KLDivRegularizer

Policy regularization term based on the Kullback-Leibler divergence of the policy relative to a given set of priors.


This is a collection of regularizers that can be used to put soft constraints on stochastic function approximators. Such a regularization term is typically added to the loss/objective to avoid premature exploitation of a policy.
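In practice, a regularizer is constructed from the function approximator it regularizes and then handed to a policy-objective updater such as coax.policy_objectives.VanillaPG. The sketch below is illustrative rather than canonical: the environment, the network architecture and the hyperparameters are placeholders, and older coax versions import gym instead of gymnasium.

    import coax
    import gymnasium as gym
    import haiku as hk
    import jax.numpy as jnp
    import optax

    env = gym.make('CartPole-v1')

    def func_pi(S, is_training):
        # a small feed-forward network producing categorical logits (illustrative)
        logits = hk.Sequential((hk.Linear(8), jnp.tanh, hk.Linear(env.action_space.n)))
        return {'logits': logits(S)}

    pi = coax.Policy(func_pi, env)

    # adds -beta * H[pi(.|s)] to the policy loss
    policy_reg = coax.regularizers.EntropyRegularizer(pi, beta=0.01)

    # hand the regularizer to a policy-objective updater
    vanilla_pg = coax.policy_objectives.VanillaPG(
        pi, regularizer=policy_reg, optimizer=optax.adam(0.001))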

Object Reference

class coax.regularizers.EntropyRegularizer(f, beta=0.001)[source]

Policy regularization term based on the entropy of the policy.

The regularization term is to be added to the loss function:

\[\text{loss}(\theta; s,a)\ =\ -J(\theta; s,a) - \beta\,H[\pi_\theta(.|s)]\]

where \(J(\theta)\) is the bare policy objective.
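To make the sign convention explicit, the standalone sketch below (not the class's internal implementation; the helper name entropy_bonus is made up for illustration) evaluates the added term \(-\beta\,H\) for a categorical distribution parametrized by logits:

    import jax.numpy as jnp
    from jax.nn import log_softmax

    def entropy_bonus(logits, beta=0.001):
        # H[pi(.|s)] = -sum_a pi(a|s) log pi(a|s), computed from raw logits
        log_p = log_softmax(logits)
        entropy = -jnp.sum(jnp.exp(log_p) * log_p, axis=-1)
        return -beta * entropy  # the term that gets *added* to the loss

    print(entropy_bonus(jnp.array([0., 0., 0.])))       # uniform: strongest (most negative) bonus
    print(entropy_bonus(jnp.array([10., -10., -10.])))  # near-deterministic: bonus close to 0

A high-entropy (near-uniform) policy thus lowers the loss the most, which is what discourages premature exploitation.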

Parameters:
  • f (stochastic function approximator) – The stochastic function approximator (e.g. coax.Policy) to regularize.

  • beta (non-negative float) – The coefficient that determines the strength of the overall regularization term.

property function

JIT-compiled function that returns the values for the regularization term.

Parameters:
  • dist_params (pytree with ndarray leaves) – The distribution parameters of the (conditional) probability distribution.

  • beta (non-negative float) – The coefficient that determines the strength of the overall regularization term.

property metrics_func

JIT-compiled function that returns the performance metrics for the regularization term.

Parameters:
  • dist_params (pytree with ndarray leaves) – The distribution parameters of the (conditional) probability distribution.

  • beta (non-negative float) – The coefficient that determines the strength of the overall regularization term.

class coax.regularizers.KLDivRegularizer(f, beta=0.001, priors=None)[source]

Policy regularization term based on the Kullback-Leibler divergence of the policy relative to a given set of priors.

The regularization term is to be added to the loss function:

\[\text{loss}(\theta; s,a)\ =\ -J(\theta; s,a) + \beta\,KL[\pi_\theta, \pi_\text{prior}]\]

where \(J(\theta)\) is the bare policy objective. Also, to keep the notation uncluttered, we abbreviated \(\pi(.|s)\) by \(\pi\).
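The analogous standalone sketch (again not the class's internal implementation; kl_penalty is a made-up helper) evaluates the added term \(\beta\,KL[\pi_\theta, \pi_\text{prior}]\) for two categorical distributions parametrized by logits:

    import jax.numpy as jnp
    from jax.nn import log_softmax

    def kl_penalty(logits, prior_logits, beta=0.001):
        # KL[pi, pi_prior] = sum_a pi(a|s) * (log pi(a|s) - log pi_prior(a|s))
        log_p, log_q = log_softmax(logits), log_softmax(prior_logits)
        kl = jnp.sum(jnp.exp(log_p) * (log_p - log_q), axis=-1)
        return beta * kl  # the term that gets *added* to the loss

    print(kl_penalty(jnp.zeros(3), jnp.zeros(3)))             # identical dists: zero penalty
    print(kl_penalty(jnp.array([5., 0., 0.]), jnp.zeros(3)))  # diverging from prior: positive penalty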

Parameters:
  • f (stochastic function approximator) – The stochastic function approximator (e.g. coax.Policy) to regularize.

  • beta (non-negative float) – The coefficient that determines the strength of the overall regularization term.

  • priors (pytree with ndarray leaves, optional) – The distribution parameters that correspond to the priors. If left unspecified, we’ll use proba_dist.default_priors, see e.g. NormalDist.default_priors.

property function

JIT-compiled function that returns the values for the regularization term.

Parameters:
  • dist_params (pytree with ndarray leaves) – The distribution parameters of the (conditional) probability distribution.

  • beta (non-negative float) – The coefficient that determines the strength of the overall regularization term.

  • priors (pytree with ndarray leaves) – The distribution parameters that correspond to the priors.

property metrics_func

JIT-compiled function that returns the performance metrics for the regularization term.

Parameters:
  • dist_params (pytree with ndarray leaves) – The distribution parameters of the (conditional) probability distribution.

  • beta (non-negative float) – The coefficient that determines the strength of the overall regularization term.

  • priors (pytree with ndarray leaves) – The distribution parameters that correspond to the priors.
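As with the entropy regularizer, a KLDivRegularizer is handed to a policy-objective updater. Below is a minimal sketch, assuming a previously constructed policy pi over three discrete actions; the pytree structure of priors (zero logits, i.e. a uniform prior) is an illustrative assumption and should match the policy's proba_dist.default_priors in practice.

    import coax
    import jax.numpy as jnp
    import optax

    # keep pi close to a fixed prior; zero logits encode a uniform categorical (illustrative)
    priors = {'logits': jnp.zeros(3)}

    policy_reg = coax.regularizers.KLDivRegularizer(pi, beta=0.01, priors=priors)

    # attach to a policy-objective updater, e.g. PPO with a clipped objective
    ppo_clip = coax.policy_objectives.PPOClip(
        pi, regularizer=policy_reg, optimizer=optax.adam(0.001))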