Utilities

This is a collection of utility (helper) functions used throughout the package.

coax.utils.OrnsteinUhlenbeckNoise

Add Ornstein-Uhlenbeck noise to continuous actions.

coax.utils.StepwiseLinearFunction

Stepwise linear function.

coax.utils.SegmentTree

A segment tree data structure that allows for batched updating and batched partial-range (segment) reductions.

coax.utils.SumTree

A sum-tree data structure that allows for batched updating and batched weighted sampling.

coax.utils.MinTree

A min-tree data structure, which is a SegmentTree whose reducer is minimum.

coax.utils.MaxTree

A max-tree data structure, which is a SegmentTree whose reducer is maximum.

coax.utils.argmax

This is a little hack to ensure that argmax breaks ties randomly, which is something that numpy.argmax() doesn't do.

coax.utils.argmin

This is a little hack to ensure that argmin breaks ties randomly, which is something that numpy.argmin() doesn't do.

coax.utils.batch_to_single

Extract a single instance from a pytree of array batches.

coax.utils.check_array

This helper function is mostly for internal use.

coax.utils.check_preprocessors

Check whether two preprocessors are the same.

coax.utils.clipped_logit

A safe implementation of the logit function \(x\mapsto\log(x/(1-x))\).

coax.utils.default_preprocessor

The default preprocessor for a given space.

coax.utils.diff_transform

A helper function that implements discrete differentiation for stacked state observations.

coax.utils.diff_transform_matrix

A helper function that implements discrete differentiation for stacked state observations.

coax.utils.docstring

A simple decorator that sets the __doc__ attribute to obj.__doc__ on the decorated object, see example below.

coax.utils.double_relu

A double-ReLU, whose output is the concatenated result of -relu(-arr) and relu(arr).

coax.utils.dump

Save an object to disk.

coax.utils.dumps

Serialize an object to an lz4-compressed pickle byte-string.

coax.utils.enable_logging

Enable logging output.

coax.utils.generate_gif

Store a gif from the episode frames.

coax.utils.get_env_attr

Get the given attribute from a potentially wrapped environment.

coax.utils.get_grads_diagnostics

Given a pytree of grads, return a dict that contains the quantiles of the magnitudes of each individual component.

coax.utils.get_magnitude_quantiles

Given a pytree, return a dict that contains the quantiles of the magnitudes of each individual component.

coax.utils.get_transition_batch

Generate a single transition from the environment.

coax.utils.has_env_attr

Check if a potentially wrapped environment has a given attribute.

coax.utils.idx

Given a numpy array, return its corresponding integer index array.

coax.utils.is_policy

Check whether an object is a policy.

coax.utils.is_qfunction

Check whether an object is a state-action value function, or Q-function.

coax.utils.is_reward_function

Check whether an object is a reward function.

coax.utils.is_stochastic

Check whether an object is a stochastic function approximator.

coax.utils.is_transition_model

Check whether an object is a dynamics model.

coax.utils.is_vfunction

Check whether an object is a state value function, or V-function.

coax.utils.isscalar

This helper uses a slightly looser definition of scalar compared to numpy.isscalar() (and jax.numpy.isscalar()) in that it also considers single-item arrays to be scalars.

coax.utils.jit

An alternative to jax.jit() that returns a picklable JIT-compiled function.

coax.utils.load

Load an object from a file that was created by dump(obj, filepath).

coax.utils.loads

Load an object from a byte-string that was created by dumps(obj).

coax.utils.make_dmc

Create a Gym environment for a DeepMind Control suite task.

coax.utils.merge_dicts

Merge dicts into a single dict.

coax.utils.pretty_print

Print pretty_repr(obj).

coax.utils.pretty_repr

Generate pretty repr() (string representations).

coax.utils.quantiles

Generate batch_size quantile fractions that split the interval \([0, 1]\) into num_quantiles equally spaced fractions.

coax.utils.quantiles_uniform

Generate batch_size quantile fractions that split the interval \([0, 1]\) into num_quantiles uniformly distributed fractions.

coax.utils.quantile_cos_embedding

Embed the given quantile fractions \(\tau\) in an n dimensional space using cosine basis functions.

coax.utils.reload_recursive

Recursively reload a module (in order of dependence).

coax.utils.render_episode

Run a single episode with env.render() calls with each time step.

coax.utils.safe_sample

Safely sample from a gymnasium-style space.

coax.utils.single_to_batch

Take a single instance and turn it into a batch of size 1.

coax.utils.stack_trees

Apply jnp.stack to the leaves of a pytree.

coax.utils.sync_shared_params

Synchronize shared params.

coax.utils.tree_ravel

Flatten and concatenate all leaves into a single flat ndarray.

coax.utils.unvectorize

Apply a batched function on a single instance, which effectively does the inverse of what jax.vmap() does.

Object Reference

class coax.utils.OrnsteinUhlenbeckNoise(mu=0.0, sigma=1.0, theta=0.15, min_value=None, max_value=None, random_seed=None)[source]

Add Ornstein-Uhlenbeck noise to continuous actions.

\[A_t\ \mapsto\ \widetilde{A}_t = A_t + X_t\]

As a side effect, the Ornstein-Uhlenbeck noise \(X_t\) is updated with every function call:

\[X_t\ =\ X_{t-1} - \theta\,\left(X_{t-1} - \mu\right) + \sigma\,\varepsilon\]

where \(\varepsilon\) is white noise, i.e. \(\varepsilon\sim\mathcal{N}(0,\mathbb{I})\).
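The update rule above can be sketched in plain Python for a scalar action. This is a minimal illustration and not the library's implementation: the class name OUNoiseSketch is hypothetical, the process is assumed to start at mu, and the optional clipping to min_value and max_value is left out.

```python
import random


class OUNoiseSketch:
    """Scalar Ornstein-Uhlenbeck noise (hypothetical helper, not coax's class)."""

    def __init__(self, mu=0.0, sigma=1.0, theta=0.15, random_seed=None):
        self.mu, self.sigma, self.theta = mu, sigma, theta
        self._rnd = random.Random(random_seed)
        self.reset()

    def reset(self):
        # Assumption: the process restarts at its mean mu.
        self.x = self.mu

    def __call__(self, a):
        eps = self._rnd.gauss(0.0, 1.0)  # white noise, eps ~ N(0, 1)
        # X_t = X_{t-1} - theta * (X_{t-1} - mu) + sigma * eps
        self.x = self.x - self.theta * (self.x - self.mu) + self.sigma * eps
        return a + self.x  # noisy action A_t + X_t
```

Because the noise state persists between calls, consecutive perturbations are correlated, unlike independent Gaussian noise drawn afresh at every step.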

The authors of the DDPG paper chose to use Ornstein-Uhlenbeck noise “[…] in order to explore well in physical environments that have momentum.”

Parameters:
  • mu (float or ndarray, optional) – The mean \(\mu\) towards which the Ornstein-Uhlenbeck process should revert; must be broadcastable with the input actions.

  • sigma (positive float or ndarray, optional) – The spread of the noise \(\sigma>0\) of the Ornstein-Uhlenbeck process; must be broadcastable with the input actions.

  • theta (positive float or ndarray, optional) – The (element-wise) dissipation rate \(\theta>0\) of the Ornstein-Uhlenbeck process; must be broadcastable with the input actions.

  • min_value (float or ndarray, optional) – The lower bound used for clipping the output action; must be broadcastable with the input actions.

  • max_value (float or ndarray, optional) – The upper bound used for clipping the output action; must be broadcastable with the input actions.

  • random_seed (int, optional) – Sets the random state to get reproducible results.

__call__(a)[source]

Add some Ornstein-Uhlenbeck noise to a continuous action.

Parameters:

a (action) – A single action \(A_t\).

Returns:

a_noisy (action) – An action with noise added \(\widetilde{A}_t = A_t + X_t\).

reset()[source]

Reset the Ornstein-Uhlenbeck process.

class coax.utils.StepwiseLinearFunction(*steps)[source]

Stepwise linear function. The function remains flat outside of the regions defined by steps.

Parameters:

*steps (sequence of tuples (int, float)) – Each step (timestep, value) fixes the output value at timestep to the provided value.

Example

Here’s an example of the exploration schedule in a DQN agent:

pi = coax.EpsilonGreedy(q, epsilon=1.0)
epsilon = StepwiseLinearFunction((0, 1.0), (1000000, 0.1), (2000000, 0.01))

for _ in range(num_episodes):
    pi.epsilon = epsilon(T)  # T is a global step counter
    ...

Notice that the function is flat outside the interpolation range provided by steps.
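The interpolation behaviour can be sketched as a plain function. This is an illustrative stand-in, not the class itself: the name stepwise_linear is hypothetical, and the sketch assumes that the timesteps in steps are strictly increasing once sorted.

```python
def stepwise_linear(steps, timestep):
    """Interpolate linearly between (timestep, value) steps; flat outside the range."""
    steps = sorted(steps)
    if timestep <= steps[0][0]:
        return steps[0][1]  # flat before the first step
    if timestep >= steps[-1][0]:
        return steps[-1][1]  # flat after the last step
    for (t0, v0), (t1, v1) in zip(steps, steps[1:]):
        if t0 <= timestep <= t1:
            frac = (timestep - t0) / (t1 - t0)
            return v0 + frac * (v1 - v0)


# The schedule from the example above.
schedule = [(0, 1.0), (1000000, 0.1), (2000000, 0.01)]
halfway = stepwise_linear(schedule, 500000)  # roughly 0.55
```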

__call__(timestep)[source]

Return the value according to the provided schedule.

class coax.utils.SegmentTree(capacity, reducer, init_value)[source]

A segment tree data structure that allows for batched updating and batched partial-range (segment) reductions.

Parameters:
  • capacity (positive int) – Number of values to accommodate.

  • reducer (function) – The reducer function: (float, float) -> float.

  • init_value (float) – The unit element relative to the reducer function. Some typical examples are: 0 if reducer is add, 1 for multiply, \(-\infty\) for maximum, \(\infty\) for minimum.

Warning

The values attribute and square-bracket lookups (tree[level, index]) return references to the underlying storage array. Therefore, make sure that downstream code doesn’t update these values in-place, which would corrupt the segment tree structure.

partial_reduce(start=0, stop=None)[source]

Reduce values over a partial range of indices. This is an efficient, batched implementation of reduce(reducer, values[start:stop], init_value).

Parameters:
  • start (int or array of ints) – The lower bound of the range (inclusive).

  • stop (int or array of ints, optional) – The upper bound of the range (exclusive). If left unspecified, this defaults to height.

Returns:

value (float) – The result of the partial reduction.

set_values(idx, values)[source]

Set or update the values.

Parameters:
  • idx (1d array of ints) – The indices of the values to be updated. If you wish to update all values use ellipses instead, e.g. tree.set_values(..., values).

  • values (1d array of floats) – The new values.

property height

The height of the tree \(h\sim\log(\text{capacity})\).

property root_value

The aggregated value, equivalent to reduce(reducer, values, init_value).

property values

The values stored at the leaves of the tree.
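To make the data structure concrete, here is a minimal, unbatched segment tree in plain Python. It is a sketch under stated assumptions, not the library's implementation: the class name SegmentTreeSketch and the single-index set_value method are hypothetical, the reducer is assumed to be associative and commutative, and stop is assumed to default to the capacity.

```python
import operator


class SegmentTreeSketch:
    """Unbatched segment tree stored in a flat list (illustration only)."""

    def __init__(self, capacity, reducer, init_value):
        self.capacity = capacity
        self.reducer = reducer
        self.init_value = init_value
        # Node i has children 2i and 2i+1; leaves live at [capacity, 2 * capacity).
        self._arr = [init_value] * (2 * capacity)

    def set_value(self, idx, value):
        i = idx + self.capacity
        self._arr[i] = value
        i //= 2
        while i >= 1:  # recompute every ancestor of the updated leaf
            self._arr[i] = self.reducer(self._arr[2 * i], self._arr[2 * i + 1])
            i //= 2

    def partial_reduce(self, start=0, stop=None):
        stop = self.capacity if stop is None else stop  # assumed default
        left, right = self.init_value, self.init_value
        i, j = start + self.capacity, stop + self.capacity
        while i < j:
            if i % 2 == 1:  # i is a right child: take it and move past it
                left = self.reducer(left, self._arr[i])
                i += 1
            if j % 2 == 1:  # j - 1 is a left child: take it
                j -= 1
                right = self.reducer(self._arr[j], right)
            i //= 2
            j //= 2
        return self.reducer(left, right)

    @property
    def root_value(self):
        return self._arr[1]


sum_tree = SegmentTreeSketch(5, operator.add, 0)
min_tree = SegmentTreeSketch(5, min, float("inf"))
for index, value in enumerate([3, 1, 4, 1, 5]):
    sum_tree.set_value(index, value)
    min_tree.set_value(index, value)
```

Both an update and a partial reduction touch only one or two nodes per level, which is where the logarithmic cost comes from.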

class coax.utils.SumTree(capacity, random_seed=None)[source]

A sum-tree data structure that allows for batched updating and batched weighted sampling.

Both update and sampling operations have a time complexity of \(\mathcal{O}(\log N)\) and a memory footprint of \(\mathcal{O}(N)\), where \(N\) is the length of the underlying values.

Parameters:
  • capacity (positive int) – Number of values to accommodate.

  • random_seed (int, optional) – Sets the random state to get reproducible results.

inverse_cdf(u)[source]

Inverse of the cumulative distribution function (CDF) of the categorical distribution \(\text{Cat}(p)\), where \(p\) are the normalized values \(p_i=\) values[i] / sum(values).

This function provides the machinery for the sample method.

Parameters:

u (float or 1d array of floats) – One or more numbers \(u\in[0,1]\). These are typically sampled from \(\text{Unif([0, 1])}\).

Returns:

idx (array of ints) – The indices associated with \(u\), shape: (n,)

Warning

This method presumes (but doesn’t check) that all values stored in the tree are non-negative.
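The idea behind inverse_cdf can be sketched with a small, unbatched sum tree: rescale u to the total mass, then walk down from the root, going left when the remaining mass fits into the left subtree and subtracting the left subtree's mass otherwise. The function names below are hypothetical, and the sketch assumes non-negative values, a positive total and u in the half-open interval [0, 1).

```python
import random


def build_sum_tree(values):
    """Return a flat binary sum tree; leaves are zero-padded to a power of two."""
    size = 1
    while size < len(values):
        size *= 2
    tree = [0.0] * (2 * size)
    tree[size:size + len(values)] = values
    for i in range(size - 1, 0, -1):
        tree[i] = tree[2 * i] + tree[2 * i + 1]
    return tree


def inverse_cdf(tree, u):
    """Map u in [0, 1) to the leaf index whose cumulative-mass interval contains it."""
    size = len(tree) // 2
    remaining = u * tree[1]  # rescale u to the total (unnormalized) mass
    i = 1
    while i < size:
        left = 2 * i
        if remaining < tree[left]:
            i = left
        else:
            remaining -= tree[left]
            i = left + 1
    return i - size


def sample(tree, n, rnd):
    """Weighted sampling: push uniform draws through the inverse CDF."""
    return [inverse_cdf(tree, rnd.random()) for _ in range(n)]


# Cumulative intervals: index 0 -> [0, 1), 1 -> [1, 3), 2 -> [3, 6), 3 -> [6, 8)
weights_tree = build_sum_tree([1, 2, 3, 2])
draws = sample(weights_tree, 100, random.Random(0))
```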

partial_reduce(start=0, stop=None)

Reduce values over a partial range of indices. This is an efficient, batched implementation of reduce(reducer, values[start:stop], init_value).

Parameters:
  • start (int or array of ints) – The lower bound of the range (inclusive).

  • stop (int or array of ints, optional) – The upper bound of the range (exclusive). If left unspecified, this defaults to height.

Returns:

value (float) – The result of the partial reduction.

sample(n)[source]

Sample array indices using weighted sampling, where the sample weights are proportional to the values stored in values.

Parameters:

n (positive int) – The number of samples to return.

Returns:

idx (array of ints) – The sampled indices, shape: (n,)

Warning

This method presumes (but doesn’t check) that all values stored in the tree are non-negative.

set_values(idx, values)

Set or update the values.

Parameters:
  • idx (1d array of ints) – The indices of the values to be updated. If you wish to update all values use ellipses instead, e.g. tree.set_values(..., values).

  • values (1d array of floats) – The new values.

property height

The height of the tree \(h\sim\log(\text{capacity})\).

property root_value

The aggregated value, equivalent to reduce(reducer, values, init_value).

property values

The values stored at the leaves of the tree.

class coax.utils.MinTree(capacity)[source]

A min-tree data structure, which is a SegmentTree whose reducer is minimum.

Parameters:

capacity (positive int) – Number of values to accommodate.

partial_reduce(start=0, stop=None)

Reduce values over a partial range of indices. This is an efficient, batched implementation of reduce(reducer, values[start:stop], init_value).

Parameters:
  • start (int or array of ints) – The lower bound of the range (inclusive).

  • stop (int or array of ints, optional) – The upper bound of the range (exclusive). If left unspecified, this defaults to height.

Returns:

value (float) – The result of the partial reduction.

set_values(idx, values)

Set or update the values.

Parameters:
  • idx (1d array of ints) – The indices of the values to be updated. If you wish to update all values use ellipses instead, e.g. tree.set_values(..., values).

  • values (1d array of floats) – The new values.

property height

The height of the tree \(h\sim\log(\text{capacity})\).

property root_value

The aggregated value, equivalent to reduce(reducer, values, init_value).

property values

The values stored at the leaves of the tree.

class coax.utils.MaxTree(capacity)[source]

A max-tree data structure, which is a SegmentTree whose reducer is maximum.

Parameters:

capacity (positive int) – Number of values to accommodate.

partial_reduce(start=0, stop=None)

Reduce values over a partial range of indices. This is an efficient, batched implementation of reduce(reducer, values[start:stop], init_value).

Parameters:
  • start (int or array of ints) – The lower bound of the range (inclusive).

  • stop (int or array of ints, optional) – The upper bound of the range (exclusive). If left unspecified, this defaults to height.

Returns:

value (float) – The result of the partial reduction.

set_values(idx, values)

Set or update the values.

Parameters:
  • idx (1d array of ints) – The indices of the values to be updated. If you wish to update all values use ellipses instead, e.g. tree.set_values(..., values).

  • values (1d array of floats) – The new values.

property height

The height of the tree \(h\sim\log(\text{capacity})\).

property root_value

The aggregated value, equivalent to reduce(reducer, values, init_value).

property values

The values stored at the leaves of the tree.

coax.utils.argmax(rng, arr, axis=-1)[source]

This is a little hack to ensure that argmax breaks ties randomly, which is something that numpy.argmax() doesn’t do.

Parameters:
  • rng (jax.random.PRNGKey) – A pseudo-random number generator key.

  • arr (array_like) – Input array.

  • axis (int, optional) – By default, the index is into the flattened array, otherwise along the specified axis.

Returns:

index_array (ndarray of ints) – Array of indices into the array. It has the same shape as arr.shape with the dimension along axis removed.
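Random tie-breaking can be illustrated for a one-dimensional input with the standard library alone. This is a sketch of the idea, not the library's code: the function name is hypothetical, and it uses Python's random module where the real function takes a JAX PRNG key.

```python
import random


def argmax_random_ties(values, rnd=random):
    """Return the index of a maximal entry, chosen uniformly among ties."""
    best = max(values)
    candidates = [i for i, v in enumerate(values) if v == best]
    return rnd.choice(candidates)
```

For the input [5, 1, 5], repeated calls return index 0 or index 2 with equal probability, whereas a plain argmax always returns the first maximal index.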

coax.utils.argmin(rng, arr, axis=-1)[source]

This is a little hack to ensure that argmin breaks ties randomly, which is something that numpy.argmin() doesn’t do.

Note: random tie breaking is only done for 1d arrays; for multidimensional inputs, we fall back to the numpy version.

Parameters:
  • rng (jax.random.PRNGKey) – A pseudo-random number generator key.

  • arr (array_like) – Input array.

  • axis (int, optional) – By default, the index is into the flattened array, otherwise along the specified axis.

Returns:

index_array (ndarray of ints) – Array of indices into the array. It has the same shape as arr.shape with the dimension along axis removed.

coax.utils.batch_to_single(pytree, index=0)[source]

Extract a single instance from a pytree of array batches.

This just does a leaf[index] on all leaf nodes of the pytree.

Parameters:

pytree (pytree with ndarray leaves) – A pytree representing a batch.

Returns:

pytree_single (pytree with ndarray leaves) – A pytree representing e.g. a single state observation.

coax.utils.check_array(arr, ndim=None, ndim_min=None, ndim_max=None, dtype=None, shape=None, axis_size=None, axis=None, except_np=False)[source]

This helper function is mostly for internal use. It is used to check a few common properties of a numpy array.

Raises:

TypeError – If one of the checks fails.

coax.utils.check_preprocessors(space, *preprocessors, num_samples=20, random_seed=None)[source]

Check whether two preprocessors are the same.

Parameters:
  • space (gymnasium.Space) – The domain of the preprocessors.

  • *preprocessors – Preprocessor functions, which are functions with input signature: func(rng: PRNGKey, x: Element[space]) -> Any.

  • num_samples (positive int) – The number of samples on which to run checks.

Returns:

match (bool) – Whether the preprocessors match.

coax.utils.clipped_logit(x, epsilon=1e-15)[source]

A safe implementation of the logit function \(x\mapsto\log(x/(1-x))\). It clips the arguments of the log function from below so as to avoid evaluating it at 0:

\[\text{logit}_\epsilon(x)\ =\ \log(\max(\epsilon, x)) - \log(\max(\epsilon, 1 - x))\]

Parameters:
  • x (ndarray) – Input numpy array whose entries lie on the unit interval, \(x_i\in [0, 1]\).

  • epsilon (float, optional) – The small number with which to clip the arguments of the logarithm from below.

Returns:

z (ndarray, dtype: float, shape: same as input) – The output logits whose entries lie on the real line, \(z_i\in\mathbb{R}\).
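For a single float, the clipped formula above translates directly into standard-library Python. The function name clipped_logit_scalar is hypothetical and only mirrors the equation; the library function operates on arrays.

```python
import math


def clipped_logit_scalar(x, epsilon=1e-15):
    """logit_eps(x) = log(max(eps, x)) - log(max(eps, 1 - x)) for one float."""
    return math.log(max(epsilon, x)) - math.log(max(epsilon, 1.0 - x))
```

With the clipping in place, the endpoints x = 0 and x = 1 map to large but finite values (about -34.5 and +34.5 for the default epsilon) instead of raising a math domain error.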

coax.utils.default_preprocessor(space)[source]

The default preprocessor for a given space.

Parameters:

space (gymnasium.Space) – The domain of the preprocessor.

Returns:

preprocessor (Callable[PRNGKey, Element[space], Any]) – The preprocessor function. See NormalDist.preprocess_variate for an example.

coax.utils.diff_transform(X, dtype='float32')[source]

A helper function that implements discrete differentiation for stacked state observations. See diff_transform_matrix() for a detailed description.

M = diff_transform_matrix(num_frames=X.shape[-1])
X_transformed = np.dot(X, M)

Parameters:

X (ndarray) – An array whose shape is such that the last axis is the frame-stack axis, i.e. X.shape[-1] == num_frames.

Returns:

X_transformed (ndarray) – The shape is the same as the input shape, but the entries along the last axis are mixed to represent position, velocity, acceleration, etc.

coax.utils.diff_transform_matrix(num_frames, dtype='float32')[source]

A helper function that implements discrete differentiation for stacked state observations.

Let’s say we have a feature vector \(X\) consisting of four stacked frames, i.e. the shape would be: [batch_size, height, width, 4].

The corresponding diff-transform matrix with num_frames=4 is a \(4\times 4\) matrix given by:

\[\begin{split}M_\text{diff}^{(4)}\ =\ \begin{pmatrix} -1 & 0 & 0 & 0 \\ 3 & 1 & 0 & 0 \\ -3 & -2 & -1 & 0 \\ 1 & 1 & 1 & 1 \end{pmatrix}\end{split}\]

such that the diff-transformed feature vector is readily computed as:

\[X_\text{diff}\ =\ X\, M_\text{diff}^{(4)}\]

The diff-transformation preserves the shape, but it reorganizes the frames in such a way that they look more like canonical variables. You can think of \(X_\text{diff}\) as the stacked variables \(x\), \(\dot{x}\), \(\ddot{x}\), etc. (in reverse order). These represent the position, velocity, acceleration, etc. of pixels in a single frame.

Parameters:
  • num_frames (positive int) – The number of stacked frames in the original \(X\).

  • dtype (dtype, optional) – The output data type.

Returns:

M (2d-Tensor, shape: [num_frames, num_frames]) – A square matrix that is intended to be multiplied from the right, e.g. X_diff = np.dot(X_orig, M), where we assume that the frames are stacked in axis=-1 of X_orig, in chronological order.
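Each column of the matrix shown above is a backward finite difference of decreasing order: the last column picks out the latest frame, the one before it takes the first difference, and so on. The sketch below uses one closed form that reproduces the \(4\times 4\) matrix for num_frames=4; whether the library builds the matrix this way is an assumption, and the function names are hypothetical.

```python
from math import comb


def diff_matrix_sketch(num_frames):
    """Column j holds the coefficients of the (n - 1 - j)-th backward difference."""
    n = num_frames
    return [
        [(-1) ** (n - 1 - i) * comb(n - 1 - j, i - j) if i >= j else 0
         for j in range(n)]
        for i in range(n)
    ]


def apply_diff_transform(frames, matrix):
    """Row vector times matrix: x_diff[j] = sum_i frames[i] * matrix[i][j]."""
    n = len(frames)
    return [sum(frames[i] * matrix[i][j] for i in range(n)) for j in range(n)]


M4 = diff_matrix_sketch(4)
# Four chronological samples of x(t) = t^2 at t = 0, 1, 2, 3.
x_diff = apply_diff_transform([0, 1, 4, 9], M4)
```

For these frames the result is [0, 2, 5, 9], read from right to left as the latest position (9), the latest first difference (5), the constant second difference (2) and a vanishing third difference (0).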

coax.utils.docstring(obj)[source]

A simple decorator that sets the __doc__ attribute to obj.__doc__ on the decorated object, see example below.

Parameters:

obj (object) – The object whose docstring you wish to copy onto the wrapped object.

Examples

>>> def f(x):
...     """Some docstring"""
...     return x * x
...
>>> def g(x):
...     return 13 - x
...
>>> g.__doc__ = f.__doc__

This can be abbreviated to:

>>> @docstring(f)
... def g(x):
...     return 13 - x
...
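A decorator with this behaviour can be written in a few lines. The version below is a sketch of the pattern under the hypothetical name copy_docstring, not necessarily how the library implements it.

```python
def copy_docstring(obj):
    """Return a decorator that copies obj.__doc__ onto the decorated object."""
    def decorator(func):
        func.__doc__ = obj.__doc__
        return func
    return decorator


def f(x):
    """Some docstring"""
    return x * x


@copy_docstring(f)
def g(x):
    return 13 - x
```

After decoration, g.__doc__ equals f.__doc__ while the behaviour of g is unchanged.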

coax.utils.double_relu(arr)[source]

A double-ReLU, whose output is the concatenated result of -relu(-arr) and relu(arr).

This activation function has the advantage that no signal is lost between layers.

Parameters:

arr (ndarray) – The input array, e.g. activations.

Returns:

doubled_arr – The output array, e.g. input for next layer.

Examples

>>> import coax
>>> import jax.numpy as jnp
>>> x = jnp.array([[-11, -8],
...                [ 17,  5],
...                [-13,  7],
...                [ 19, -3]])
...
>>> coax.utils.double_relu(x)
DeviceArray([[-11,  -8,   0,   0],
             [  0,   0,  17,   5],
             [-13,   0,   0,   7],
             [  0,  -3,  19,   0]], dtype=int32)

There are two things we may observe from the above example. The first is that all components from the original array are passed on as output. The second thing is that half of the output components (along axis=1) are masked out, which means that the doubling of array size doesn’t result in doubling the amount of “activation” passed on to the next layer. It merely allows for the neural net to learn conditional branches in its internal logic.

coax.utils.dump(obj, filepath)[source]

Save an object to disk.

Parameters:
  • obj (object) – Any python object.

  • filepath (str) – Where to store the instance.

Warning

References between objects are only preserved if they are stored as part of a single object, for example:

# b has a reference to a
a = [13]
b = {'a': a}

# references preserved
dump((a, b), 'ab.pkl.lz4')
a_new, b_new = load('ab.pkl.lz4')
b_new['a'].append(7)
print(b_new)  # {'a': [13, 7]}
print(a_new)  # [13, 7]         <-- updated

# references not preserved
dump(a, 'a.pkl.lz4')
dump(b, 'b.pkl.lz4')
a_new = load('a.pkl.lz4')
b_new = load('b.pkl.lz4')
b_new['a'].append(7)
print(b_new)  # {'a': [13, 7]}
print(a_new)  # [13]            <-- not updated!!

Therefore, the safest way to create checkpoints is to store the entire state as a single object like a dict or a tuple.

coax.utils.dumps(obj)[source]

Serialize an object to an lz4-compressed pickle byte-string.

Parameters:

obj (object) – Any python object.

Returns:

s (bytes) – An lz4-compressed pickle byte-string.

Warning

References between objects are only preserved if they are stored as part of a single object, for example:

# b has a reference to a
a = [13]
b = {'a': a}

# references preserved
s = dumps((a, b))
a_new, b_new = loads(s)
b_new['a'].append(7)
print(b_new)  # {'a': [13, 7]}
print(a_new)  # [13, 7]         <-- updated

# references not preserved
s_a = dumps(a)
s_b = dumps(b)
a_new = loads(s_a)
b_new = loads(s_b)
b_new['a'].append(7)
print(b_new)  # {'a': [13, 7]}
print(a_new)  # [13]            <-- not updated!!

Therefore, the safest way to create checkpoints is to store the entire state as a single object like a dict or a tuple.

coax.utils.enable_logging(name=None, level=20, output_filepath=None, output_level=None)[source]

Enable logging output.

This executes the following two lines of code:

import logging
logging.basicConfig(level=logging.INFO)

Parameters:
  • name (str, optional) – Name of the process that is logging. This can be set to whatever you like.

  • level (int, optional) – Logging level for the default StreamHandler. The default setting is level=logging.INFO (which is 20). If you’d like to see more verbose logging messages you might set level=logging.DEBUG.

  • output_filepath (str, optional) –

    If provided, a FileHandler will be added to the root logger via:

    file_handler = logging.FileHandler(output_filepath)
    logging.getLogger('').addHandler(file_handler)
    

  • output_level (int, optional) – Logging level for the FileHandler. If left unspecified, this defaults to level, i.e. the same level as the default StreamHandler.

coax.utils.generate_gif(env, filepath, policy=None, resize_to=None, duration=50, max_episode_steps=None)[source]

Store a gif from the episode frames.

Parameters:
  • env (gymnasium environment) – The environment to record from.

  • filepath (str) – Location of the output gif file.

  • policy (callable, optional) – A policy object that is used to pick actions: a = policy(s). If left unspecified, we’ll just take random actions instead, i.e. a = env.action_space.sample().

  • resize_to (tuple of ints, optional) – The size of the output frames, (width, height). Notice the ordering: first width, then height. This is the convention PIL uses.

  • duration (float, optional) – Time between frames in the animated gif, in milliseconds.

  • max_episode_steps (int, optional) – The maximum number of steps in the episode. If left unspecified, we’ll attempt to get the value from env.spec.max_episode_steps and if that fails we default to 10000.

coax.utils.get_env_attr(env, attr, default='__ERROR__', max_depth=100)[source]

Get the given attribute from a potentially wrapped environment.

Note that the wrapped envs are traversed from the outside in. Once the attribute is found, the search stops. This means that an inner wrapped env may carry the same (possibly conflicting) attribute. This situation is not resolved by this function.

Parameters:
  • env (gymnasium environment) – A potentially wrapped environment.

  • attr (str) – The attribute name.

  • max_depth (positive int, optional) – The maximum depth of wrappers to traverse.
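The outside-in traversal can be sketched with toy classes. Everything here is illustrative: the classes and the function name are hypothetical, the sketch assumes that each wrapper exposes the wrapped environment as its env attribute (as Gymnasium wrappers do), and raising AttributeError when nothing is found and no default is given is an assumption.

```python
_MISSING = object()  # sentinel meaning "no default supplied"


class InnerEnv:
    max_steps = 200


class ToyWrapper:
    def __init__(self, env):
        self.env = env


def get_env_attr_sketch(env, attr, default=_MISSING, max_depth=100):
    """Search the wrapper chain from the outside in and return the first match."""
    for _ in range(max_depth):
        if hasattr(env, attr):
            return getattr(env, attr)
        if not hasattr(env, "env"):
            break  # reached the innermost environment
        env = env.env  # step one wrapper inwards
    if default is _MISSING:
        raise AttributeError(f"env is missing attribute: {attr}")
    return default


wrapped = ToyWrapper(ToyWrapper(InnerEnv()))
shadowing = ToyWrapper(InnerEnv())
shadowing.max_steps = 50  # the outer value wins because the search stops here
```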

coax.utils.get_grads_diagnostics(grads, key_prefix='', keep_tree_structure=False)[source]

Given a pytree of grads, return a dict that contains the quantiles of the magnitudes of each individual component.

This is meant to be a high-level diagnostic. It first extracts the leaves of the pytree, then flattens each leaf and then it computes the element-wise magnitude. Then, it concatenates all magnitudes into one long flat array. The quantiles are computed on this array.

Parameters:
  • grads (a pytree with ndarray leaves) – The gradients of some loss function with respect to the model parameters (weights).

  • key_prefix (str, optional) – The prefix to add to the output dict keys.

  • keep_tree_structure (bool, optional) – Whether to keep the tree structure, i.e. to compute the grads diagnostics for each individual leaf. If False (default), we only compute the global grads diagnostics.

Returns:

grads_diagnostics (dict<str, float>) – A dict with structure {name: score}.

coax.utils.get_magnitude_quantiles(pytree, key_prefix='')[source]

Given a pytree, return a dict that contains the quantiles of the magnitudes of each individual component.

This is meant to be a high-level diagnostic. It first extracts the leaves of the pytree, then flattens each leaf and then it computes the element-wise magnitude. Then, it concatenates all magnitudes into one long flat array. The quantiles are computed on this array.

Parameters:
  • pytree (a pytree with ndarray leaves) – A typical example is a pytree of model params (weights) or gradients with respect to such model params.

  • key_prefix (str, optional) – The prefix to add to the output dict keys.

Returns:

magnitude_quantiles (dict) – A dict with keys: ['min', 'p25', 'p50', 'p75', 'max']. The values of the dict are non-negative floats that represent the magnitude quantiles.
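The procedure described above (collect the leaves, flatten, take magnitudes, compute quantiles over the pooled values) can be sketched with nested Python containers standing in for a pytree. The function names are hypothetical, and prepending key_prefix directly to each key, as well as the linear interpolation between data points, are assumptions.

```python
import statistics


def flatten_leaves(tree):
    """Yield every number in a nested structure of dicts, lists and tuples."""
    if isinstance(tree, dict):
        for value in tree.values():
            yield from flatten_leaves(value)
    elif isinstance(tree, (list, tuple)):
        for value in tree:
            yield from flatten_leaves(value)
    else:
        yield tree


def magnitude_quantiles_sketch(tree, key_prefix=""):
    """Pool the magnitudes of all entries and summarize them by quartiles."""
    magnitudes = sorted(abs(x) for x in flatten_leaves(tree))
    p25, p50, p75 = statistics.quantiles(magnitudes, n=4, method="inclusive")
    return {
        f"{key_prefix}min": magnitudes[0],
        f"{key_prefix}p25": p25,
        f"{key_prefix}p50": p50,
        f"{key_prefix}p75": p75,
        f"{key_prefix}max": magnitudes[-1],
    }


params = {"w": [[-4.0, 1.0], [2.0, -3.0]], "b": [5.0]}
summary = magnitude_quantiles_sketch(params, key_prefix="grads/")
```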

coax.utils.get_transition_batch(env, batch_size=1, gamma=0.9, random_seed=None)[source]

Generate a single transition from the environment.

This basically does a single step on the environment and then closes it.

Parameters:
  • env (gymnasium environment) – A gymnasium-style environment.

  • batch_size (positive int, optional) – The desired batch size of the sample.

  • random_seed (int, optional) – In order to generate the transition, we do some random sampling from the provided spaces. This random_seed sets the seed for the pseudo-random number generators.

Returns:

transition_batch (TransitionBatch) – A batch of transitions.

coax.utils.has_env_attr(env, attr, max_depth=100)[source]

Check if a potentially wrapped environment has a given attribute.

Parameters:
  • env (gymnasium environment) – A potentially wrapped environment.

  • attr (str) – The attribute name.

  • max_depth (positive int, optional) – The maximum depth of wrappers to traverse.

coax.utils.idx(arr, axis=0)[source]

Given a numpy array, return its corresponding integer index array.

Parameters:
  • arr (array) – Input array.

  • axis (int, optional) – The axis along which we’d like to get an index.

Returns:

index (1d array, shape: arr.shape[axis]) – An index array [0, 1, 2, …].

coax.utils.is_policy(obj)[source]

Check whether an object is a policy.

Parameters:

obj – Object to check.

Returns:

bool – Whether obj is a policy.

coax.utils.is_qfunction(obj)[source]

Check whether an object is a state-action value function, or Q-function.

Parameters:

obj – Object to check.

Returns:

bool – Whether obj is a Q-function and (optionally) whether it is of modeltype 1 or 2.

coax.utils.is_reward_function(obj)[source]

Check whether an object is a reward function.

Parameters:

obj – Object to check.

Returns:

bool – Whether obj is a reward function.

coax.utils.is_stochastic(obj)[source]

Check whether an object is a stochastic function approximator.

Parameters:

obj – Object to check.

Returns:

bool – Whether obj is a stochastic function approximator.

coax.utils.is_transition_model(obj)[source]

Check whether an object is a dynamics model.

Parameters:

obj – Object to check.

Returns:

bool – Whether obj is a dynamics function.

coax.utils.is_vfunction(obj)[source]

Check whether an object is a state value function, or V-function.

Parameters:

obj – Object to check.

Returns:

bool – Whether obj is a V-function.

coax.utils.isscalar(num)[source]

This helper uses a slightly looser definition of scalar compared to numpy.isscalar() (and jax.numpy.isscalar()) in that it also considers single-item arrays to be scalars.

Parameters:

num (number or ndarray) – Input array.

Returns:

isscalar (bool) – Whether the input number is either a number or a single-item array.

coax.utils.jit(func, static_argnums=(), donate_argnums=())[source]

An alternative to jax.jit() that returns a picklable JIT-compiled function.

Note that jax.jit() produces non-picklable functions, because the JIT compilation depends on the device and backend. In order to facilitate serialization, this function does not allow the user to specify device or backend. Instead jax.jit() is called with the default: jax.jit(..., device=None, backend=None).

Check out the original jax.jit() docs for a more detailed description of the arguments.

Parameters:
  • func (function) – Function to be JIT compiled.

  • static_argnums (int or tuple of ints) – Arguments to exclude from JIT compilation.

  • donate_argnums (int or tuple of ints) – Arguments to be donated, see jax.jit().

Returns:

jitted_func (JittedFunc) – A picklable JIT-compiled function.

coax.utils.load(filepath)[source]

Load an object from a file that was created by dump(obj, filepath).

Parameters:

filepath (str) – File to load.

coax.utils.loads(s)[source]

Load an object from a byte-string that was created by dumps(obj).

Parameters:

s (bytes) – An lz4-compressed pickle byte-string.

coax.utils.make_dmc(domain, task, seed=0, max_episode_steps=1000, height=84, width=84, camera_id=0)[source]

Create a Gym environment for a DeepMind Control suite task.

Parameters:
  • domain (str) – Name of the domain.

  • task (str) – Name of the task.

  • seed (int) – Random seed.

  • max_episode_steps (int) – Maximum number of steps per episode.

  • height (int) – Height of the observation.

  • width (int) – Width of the observation.

  • camera_id (int) – Camera ID.

Returns:

env (gymnasium.Env) – Gym environment.

coax.utils.merge_dicts(*dicts)[source]

Merge dicts into a single dict.

WARNING: duplicate keys are not resolved.

Parameters:

*dicts (*dict) – Multiple dictionaries.

Returns:

merged (dict) – A single dictionary.
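
The warning above does not say what happens to a duplicate key. The sketch below assumes the simplest behavior, namely that a later dict overwrites an earlier one as with dict.update(), and shows how a caller could detect overlap up front. merge_dicts_sketch is a hypothetical stand-in, not the coax implementation:

```python
def merge_dicts_sketch(*dicts):
    """Hypothetical stand-in: merge dicts from left to right."""
    merged = {}
    for d in dicts:
        # assumption: a later dict silently overwrites keys set by an earlier one
        merged.update(d)
    return merged


a = {"lr": 0.1, "gamma": 0.9}
b = {"gamma": 0.99, "n": 5}

duplicates = set(a) & set(b)      # keys that occur in both dicts
merged = merge_dicts_sketch(a, b)
```

Checking for duplicates before merging is a cheap way to avoid relying on whichever behavior the real function has.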

coax.utils.pretty_print(obj)[source]

Print pretty_repr(obj).

Parameters:

obj (object) – Any object.

coax.utils.pretty_repr(o, d=0)[source]

Generate a pretty repr() (string representation).

Parameters:
  • o (object) – Any object.

  • d (int, optional) – The depth of the recursion. This is used to determine the indentation level in recursive calls, so we typically keep this 0.

Returns:

pretty_repr (str) – A nicely formatted string representation of the object.
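
The role of the depth argument d can be illustrated with a toy formatter for nested dicts. This is a hypothetical sketch of the recursion pattern only; the output format of the real pretty_repr may look different:

```python
def pretty_repr_sketch(o, d=0):
    """Hypothetical illustration: d is the current recursion depth."""
    indent = "  " * d
    if isinstance(o, dict):
        lines = []
        for key, value in o.items():
            if isinstance(value, dict):
                # recurse one level deeper, which indents the nested block further
                lines.append(f"{indent}{key}:")
                lines.append(pretty_repr_sketch(value, d + 1))
            else:
                lines.append(f"{indent}{key}: {value!r}")
        return "\n".join(lines)
    return f"{indent}{o!r}"


text = pretty_repr_sketch({"model": {"layers": 2, "name": "mlp"}, "lr": 0.1})
print(text)
```

The top-level call keeps d=0; only the recursive calls pass a larger depth.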

coax.utils.quantiles(batch_size, num_quantiles=200)[source]

Generate batch_size quantile fractions that split the interval \([0, 1]\) into num_quantiles equally spaced fractions.

Parameters:
  • batch_size (int) – The batch size for which the quantile fractions should be generated.

  • num_quantiles (int, optional) – The number of quantile fractions. By default 200.

Returns:

quantile_fractions (ndarray) – Array of quantile fractions.
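
A minimal NumPy sketch of equally spaced quantile fractions is shown below. It assumes that the fractions are the left edges i / num_quantiles and that the same row is repeated for every batch element; neither detail is confirmed by the description above, and quantiles_sketch is a hypothetical name:

```python
import numpy as np


def quantiles_sketch(batch_size, num_quantiles=200):
    # assumption: fractions are i / num_quantiles for i = 0, ..., num_quantiles - 1
    fractions = np.arange(num_quantiles, dtype=np.float32) / num_quantiles
    # repeat the same row once per batch element
    return np.tile(fractions[None, :], (batch_size, 1))


quantile_fractions = quantiles_sketch(batch_size=3, num_quantiles=4)
print(quantile_fractions.shape)  # (3, 4)
```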

coax.utils.quantiles_uniform(rng, batch_size, num_quantiles=32)[source]

Generate batch_size quantile fractions that split the interval \([0, 1]\) into num_quantiles uniformly distributed fractions.

Parameters:
  • rng (jax.random.PRNGKey) – A pseudo-random number generator key.

  • batch_size (int) – The batch size for which the quantile fractions should be generated.

  • num_quantiles (int, optional) – The number of quantile fractions. By default 32.

Returns:

quantile_fractions (ndarray) – Array of quantile fractions.
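
One way to obtain randomly placed, increasing fractions is to draw uniform increments, normalize them per row, and take the cumulative sum. The sketch below does this with NumPy, using a numpy Generator in place of a jax.random.PRNGKey. Both the construction and the name quantiles_uniform_sketch are assumptions made for illustration:

```python
import numpy as np


def quantiles_uniform_sketch(rng, batch_size, num_quantiles=32):
    # rng is a numpy Generator here, standing in for a jax.random.PRNGKey
    increments = rng.uniform(size=(batch_size, num_quantiles))
    # normalize each row so that its increments sum to one
    increments = increments / increments.sum(axis=-1, keepdims=True)
    # cumulative sums give non-decreasing fractions that end at 1
    return np.cumsum(increments, axis=-1)


fractions = quantiles_uniform_sketch(np.random.default_rng(0), batch_size=2, num_quantiles=5)
print(fractions.shape)  # (2, 5)
```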

coax.utils.quantile_cos_embedding(quantile_fractions, n=64)[source]

Embed the given quantile fractions \(\tau\) in an n dimensional space using cosine basis functions.

\[\phi(\tau) = \cos(\tau i \pi) \qquad 0 \leq i \lt n\]

Parameters:
  • quantile_fractions (ndarray) – Array of quantile fractions \(\tau\) to be embedded.

  • n (int) – The dimensionality of the embedding. By default 64.

Returns:

quantile_embs (ndarray) – Array of quantile embeddings with shape (quantile_fractions.shape[0], n).
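
The formula above translates into a short broadcasting expression. The sketch below uses NumPy and follows the index range \(0 \leq i \lt n\) as written; quantile_cos_embedding_sketch is a hypothetical name and the coax implementation may differ in details:

```python
import numpy as np


def quantile_cos_embedding_sketch(quantile_fractions, n=64):
    tau = np.asarray(quantile_fractions, dtype=np.float64)
    i = np.arange(n)  # basis index, 0 <= i < n as in the formula above
    # broadcasting: (..., 1) * (n,) -> (..., n)
    return np.cos(tau[..., None] * i * np.pi)


embs = quantile_cos_embedding_sketch(np.array([0.0, 0.5, 1.0]), n=4)
print(embs.shape)  # (3, 4)
```

For a 1-d input of length m this gives an array of shape (m, n), matching the return shape stated above.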

coax.utils.reload_recursive(module, reload_external_modules=False)[source]

Recursively reload a module (in order of dependence).

Parameters:
  • module (ModuleType or str) – The module to reload.

  • reload_external_modules (bool, optional) – Whether to reload all referenced modules, including external ones which aren’t submodules of module.

coax.utils.render_episode(env, policy=None, step_delay_ms=0)[source]

Run a single episode, calling env.render() at each time step.

Parameters:
  • env (gymnasium environment) – A gymnasium environment.

  • policy (callable, optional) – A policy object that is used to pick actions: a = policy(s). If left unspecified, we’ll just take random actions instead, i.e. a = env.action_space.sample().

  • step_delay_ms (non-negative float) – The number of milliseconds to wait between consecutive timesteps. This can be used to slow down the rendering.

coax.utils.safe_sample(space, seed=None)[source]

Safely sample from a gymnasium-style space.

Parameters:
  • space (gymnasium.Space) – A gymnasium-style space.

  • seed (int, optional) – The seed for the pseudo-random number generator.

Returns:

sample – A single sample from the given space.

coax.utils.single_to_batch(pytree)[source]

Take a single instance and turn it into a batch of size 1.

This just does an np.expand_dims(leaf, axis=0) on all leaf nodes of the pytree.

Parameters:

pytree (pytree with ndarray leaves) – A pytree representing e.g. a single state observation.

Returns:

pytree_batch (pytree with ndarray leaves) – A pytree representing a batch with batch_size=1.
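
The effect on a nested observation can be sketched with NumPy. Here the hypothetical helper map_leaves stands in for a proper pytree map such as jax.tree_util.tree_map and only handles dicts, lists and tuples:

```python
import numpy as np


def map_leaves(fn, tree):
    """Hypothetical helper: apply fn to every leaf of a nested dict/list/tuple."""
    if isinstance(tree, dict):
        return {k: map_leaves(fn, v) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(map_leaves(fn, v) for v in tree)
    return fn(tree)


def single_to_batch_sketch(pytree):
    # add a leading batch axis of length 1 to every leaf
    return map_leaves(lambda leaf: np.expand_dims(leaf, axis=0), pytree)


obs = {"position": np.zeros(3), "image": np.zeros((4, 4))}
batch = single_to_batch_sketch(obs)
print(batch["position"].shape, batch["image"].shape)  # (1, 3) (1, 4, 4)
```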

coax.utils.stack_trees(*trees)[source]

Apply jnp.stack to the leaves of a pytree.

Parameters:

trees (sequence of pytrees with ndarray leaves) – A typical example is a set of pytrees containing the parameters and function states of a model that should be used in a function which is vectorized by jax.vmap. The trees must all have the same pytree structure.

Returns:

pytree (pytree with ndarray leaves) – A tuple of pytrees.

coax.utils.sync_shared_params(*params, weights=None)[source]

Synchronize shared params. See the A2C stub for example usage.

Parameters:
  • *params (multiple hk.Params objects) – The parameter dicts that contain shared parameters.

  • weights (list of positive floats) – The relative weights to use for averaging the shared params.

Returns:

params (tuple of hk.Params objects) – Same as input *params but with synchronized shared params.
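
The idea of weighted averaging can be sketched on flat dicts of NumPy arrays. This sketch assumes that a parameter counts as shared when its name occurs in more than one dict; real hk.Params objects are nested, and the coax implementation may identify shared parameters differently. sync_shared_sketch is a hypothetical name:

```python
import numpy as np


def sync_shared_sketch(*params, weights=None):
    """Hypothetical stand-in that works on flat {name: array} dicts."""
    weights = np.ones(len(params)) if weights is None else np.asarray(weights, dtype=float)
    # assumption: a parameter is "shared" when its name occurs in more than one dict
    names = [name for p in params for name in p]
    shared = {name for name in names if names.count(name) > 1}
    synced = [dict(p) for p in params]  # shallow copies, the inputs stay untouched
    for name in shared:
        owners = [i for i, p in enumerate(params) if name in p]
        w = weights[owners] / weights[owners].sum()  # normalized relative weights
        average = sum(w_i * params[i][name] for w_i, i in zip(w, owners))
        for i in owners:
            synced[i][name] = average
    return tuple(synced)


pi_params = {"body/w": np.array([1.0, 1.0]), "pi_head/w": np.array([5.0])}
v_params = {"body/w": np.array([3.0, 3.0]), "v_head/w": np.array([7.0])}
pi_synced, v_synced = sync_shared_sketch(pi_params, v_params, weights=[1.0, 3.0])
print(pi_synced["body/w"])  # weighted average: 0.25 * 1 + 0.75 * 3 = 2.5 per entry
```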

coax.utils.tree_ravel(pytree)[source]

Flatten and concatenate all leaves into a single flat ndarray.

Parameters:

pytree (a pytree with ndarray leaves) – A typical example is a pytree of model parameters (weights) or gradients with respect to such model params.

Returns:

arr (ndarray with ndim=1) – A single flat array.
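
A NumPy-only sketch of the operation is given below. The hypothetical helper iter_leaves replaces a proper pytree flattening such as jax.tree_util.tree_leaves; it visits dict keys in sorted order, which is an assumption about the ordering of the result:

```python
import numpy as np


def iter_leaves(tree):
    """Hypothetical helper: yield the leaves of a nested dict/list/tuple in order."""
    if isinstance(tree, dict):
        for key in sorted(tree):
            yield from iter_leaves(tree[key])
    elif isinstance(tree, (list, tuple)):
        for item in tree:
            yield from iter_leaves(item)
    else:
        yield np.asarray(tree)


def tree_ravel_sketch(pytree):
    # flatten each leaf to 1-d, then join all of them end to end
    return np.concatenate([leaf.ravel() for leaf in iter_leaves(pytree)])


params = {"linear": {"b": np.zeros(2), "w": np.ones((2, 3))}}
arr = tree_ravel_sketch(params)
print(arr.shape)  # (8,)
```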

coax.utils.unvectorize(f, in_axes=0, out_axes=0)[source]

Apply a batched function on a single instance, which effectively does the inverse of what jax.vmap() does.

Parameters:
  • f (callable) – A batched function.

  • in_axes (int or tuple of ints, optional) – Specify the batch axes of the inputs of the function f. If left unspecified, this defaults to axis 0 for all inputs.

  • out_axes (int, optional) – Specify the batch axes of the outputs of the function f. These axes will be dropped by jnp.squeeze. If left unspecified, this defaults to axis 0 for all outputs.

Returns:

f_single (callable) – The unvectorized version of f.

Examples

Haiku uses a batch-oriented design (although some components may be batch-agnostic). To create a function that acts on a single instance, we can use unvectorize() as follows:

import jax.numpy as jnp
import haiku as hk
import coax

def f(x_batch):
    return hk.Linear(11)(x_batch)


rngs = hk.PRNGSequence(42)

x_batch = jnp.zeros(shape=(3, 5))  # batch of 3 instances
x_single = jnp.zeros(shape=(5,))   # single instance

init, f_batch = hk.transform(f)
params = init(next(rngs), x_batch)
y_batch = f_batch(params, next(rngs), x_batch)
assert y_batch.shape == (3, 11)

f_single = coax.unvectorize(f_batch, in_axes=(None, None, 0), out_axes=0)
y_single = f_single(params, next(rngs), x_single)
assert y_single.shape == (11,)

Alternatively, and perhaps more conveniently, we can unvectorize the function before doing the Haiku transform:

init, f_single = hk.transform(coax.unvectorize(f))
params = init(next(rngs), x_single)
y_single = f_single(params, next(rngs), x_single)
assert y_single.shape == (11,)
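
To see what happens under the hood, the mechanics can be approximated in plain NumPy: add a length-1 batch axis to each input, call the batched function, and squeeze the batch axis out of the result. The sketch below is a hypothetical simplification (array arguments and a single array output only) and not the coax implementation:

```python
import numpy as np


def unvectorize_sketch(f, in_axes=0, out_axes=0):
    """Hypothetical stand-in for single-output functions with array arguments."""
    def f_single(*args):
        axes = in_axes if isinstance(in_axes, tuple) else (in_axes,) * len(args)
        # add a length-1 batch axis to every argument that has a batch axis
        batched = [a if ax is None else np.expand_dims(a, axis=ax) for a, ax in zip(args, axes)]
        out = f(*batched)
        # drop the length-1 batch axis from the output again
        return np.squeeze(out, axis=out_axes)
    return f_single


def f_batch(w, x_batch):
    return x_batch @ w  # (batch, 5) @ (5, 11) -> (batch, 11)


w = np.ones((5, 11))
f_single = unvectorize_sketch(f_batch, in_axes=(None, 0), out_axes=0)
y_single = f_single(w, np.ones(5))
print(y_single.shape)  # (11,)
```

As in the Haiku example, an entry of None in in_axes marks an argument without a batch axis, which is passed through unchanged.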