Utilities¶
This is a collection of utility (helper) functions used throughout the package.
Add Ornstein-Uhlenbeck noise to continuous actions.

Stepwise linear function. 

A segment tree data structure that allows for batched updating and batched partial-range (segment) reductions.

A sum-tree data structure that allows for batched updating and batched weighted sampling.

A min-tree data structure, which is a SegmentTree whose reducer is minimum.

A max-tree data structure, which is a SegmentTree whose reducer is maximum.

This is a little hack to ensure that argmax breaks ties randomly, which is something that numpy.argmax() doesn’t do.

This is a little hack to ensure that argmin breaks ties randomly, which is something that numpy.argmin() doesn’t do.

Extract a single instance from a pytree of array batches. 

This helper function is mostly for internal use. 

Check whether two preprocessors are the same. 

A safe implementation of the logit function \(x\mapsto\log(x/(1-x))\).

The default preprocessor for a given space. 

A helper function that implements discrete differentiation for stacked state observations. 

A helper function that implements discrete differentiation for stacked state observations. 

A simple decorator that sets the __doc__ attribute of the decorated object to obj.__doc__.

A double-ReLU, whose output is the concatenated result of relu(-arr) and relu(arr).

Save an object to disk. 

Serialize an object to an lz4-compressed pickle bytestring.

Enable logging output. 

Store a gif from the episode frames. 

Get the given attribute from a potentially wrapped environment. 

Given a pytree of grads, return a dict that contains the quantiles of the magnitudes of each individual component. 

Given a pytree, return a dict that contains the quantiles of the magnitudes of each individual component. 

Generate a single transition from the environment. 

Check if a potentially wrapped environment has a given attribute. 

Given a numpy array, return its corresponding integer index array. 

Check whether an object is a policy. 

Check whether an object is a state-action value function, or Q-function.

Check whether an object is a reward function.

Check whether an object is a stochastic function approximator.

Check whether an object is a dynamics model.

Check whether an object is a state value function, or V-function.

This helper uses a slightly looser definition of scalar compared to numpy.isscalar(), in that it also considers single-item arrays to be scalars.

An alternative to jax.jit() that returns a picklable JIT-compiled function.

Load an object from a file that was created by dump(obj, filepath).

Load an object from a bytestring that was created by dumps(obj).

Create a Gym environment for a DeepMind Control suite task. 

Merge dicts into a single dict. 

Print pretty_repr(obj).

Generate a pretty repr (string representation) of an object.

Generate batch_size quantile fractions that split the interval \([0, 1]\) into num_quantiles equally spaced fractions.

Generate batch_size quantile fractions that split the interval \([0, 1]\) into num_quantiles uniformly distributed fractions.

Embed the given quantile fractions \(\tau\) in an n-dimensional space using cosine basis functions.

Recursively reload a module (in order of dependence). 

Run a single episode, calling env.render() at each time step.

Safely sample from a gymnasium-style space.

Take a single instance and turn it into a batch of size 1. 

Apply jnp.stack to the leaves of a pytree.

Synchronize shared params. 

Flatten and concatenate all leaves into a single flat ndarray. 

Apply a batched function on a single instance, which effectively does the inverse of what jax.vmap() does.
Object Reference¶
 class coax.utils.OrnsteinUhlenbeckNoise(mu=0.0, sigma=1.0, theta=0.15, min_value=None, max_value=None, random_seed=None)[source]¶
Add Ornstein-Uhlenbeck noise to continuous actions.
\[A_t\ \mapsto\ \widetilde{A}_t\ =\ A_t + X_t\]
As a side effect, the Ornstein-Uhlenbeck noise \(X_t\) is updated with every function call:
\[X_t\ =\ X_{t-1} - \theta\,\left(X_{t-1} - \mu\right) + \sigma\,\varepsilon\]
where \(\varepsilon\) is white noise, i.e. \(\varepsilon\sim\mathcal{N}(0,\mathbb{I})\).
The authors of the DDPG paper chose to use Ornstein-Uhlenbeck noise “[…] in order to explore well in physical environments that have momentum.”
 Parameters:
mu (float or ndarray, optional) – The mean \(\mu\) towards which the Ornstein-Uhlenbeck process should revert; must be broadcastable with the input actions.
sigma (positive float or ndarray, optional) – The spread of the noise \(\sigma>0\) of the Ornstein-Uhlenbeck process; must be broadcastable with the input actions.
theta (positive float or ndarray, optional) – The (element-wise) dissipation rate \(\theta>0\) of the Ornstein-Uhlenbeck process; must be broadcastable with the input actions.
min_value (float or ndarray, optional) – The lower bound used for clipping the output action; must be broadcastable with the input actions.
max_value (float or ndarray, optional) – The upper bound used for clipping the output action; must be broadcastable with the input actions.
random_seed (int, optional) – Sets the random state to get reproducible results.
 class coax.utils.StepwiseLinearFunction(*steps)[source]¶
Stepwise linear function. The function remains flat outside of the regions defined by steps.
 Parameters:
*steps (sequence of tuples (int, float)) – Each step (timestep, value) fixes the output value at timestep to the provided value.
Example
Here’s an example of the exploration schedule in a DQN agent:
pi = coax.EpsilonGreedy(q, epsilon=1.0)
epsilon = StepwiseLinearFunction((0, 1.0), (1000000, 0.1), (2000000, 0.01))

for _ in range(num_episodes):
    pi.epsilon = epsilon(T)  # T is a global step counter
    ...
Notice that the function is flat outside the interpolation range provided by steps.
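The flat-outside-the-range behavior can be sketched with numpy.interp, which interpolates linearly between given points and clamps beyond them. This is an illustrative stand-in, not the coax implementation:

```python
import numpy as np

# Illustrative sketch (not the coax implementation): numpy.interp interpolates
# linearly between the given (timestep, value) points and, like
# StepwiseLinearFunction, stays flat outside the interpolation range.
steps = [(0, 1.0), (1000000, 0.1), (2000000, 0.01)]
timesteps, values = zip(*steps)

def epsilon(t):
    return float(np.interp(t, timesteps, values))
```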
 class coax.utils.SegmentTree(capacity, reducer, init_value)[source]¶
A segment tree data structure that allows for batched updating and batched partial-range (segment) reductions.
 Parameters:
capacity (positive int) – Number of values to accommodate.
reducer (function) – The reducer function: (float, float) -> float.
init_value (float) – The unit element relative to the reducer function. Some typical examples are: 0 if reducer is add, 1 for multiply, \(-\infty\) for maximum, \(\infty\) for minimum.
Warning
The values attribute and square-bracket lookups (tree[level, index]) return references to the underlying storage array. Therefore, make sure that downstream code doesn’t update these values in place, which would corrupt the segment tree structure.
 partial_reduce(start=0, stop=None)[source]¶
Reduce values over a partial range of indices. This is an efficient, batched implementation of reduce(reducer, values[start:stop], init_value).
 set_values(idx, values)[source]¶
Set or update the values.
 Parameters:
idx (1d array of ints) – The indices of the values to be updated. If you wish to update all values use ellipses instead, e.g. tree.set_values(..., values).
values (1d array of floats) – The new values.
 property height¶
The height of the tree \(h\sim\log(\text{capacity})\).
 property root_value¶
The aggregated value, equivalent to reduce(reducer, values, init_value).
 property values¶
The values stored at the leaves of the tree.
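To make the mechanics concrete, here is a minimal, unbatched segment-tree sketch. The class name and single-value API are hypothetical; the real SegmentTree is batched and more general:

```python
import operator
from functools import reduce

# Minimal illustrative segment tree (capacity assumed to be a power of two).
# Leaves live in the second half of the storage array; each internal node
# holds the reduction of its two children, so the root aggregates everything.
class TinySegmentTree:
    def __init__(self, capacity, reducer, init_value):
        self.capacity = capacity
        self.reducer = reducer
        self.init_value = init_value
        self._arr = [init_value] * (2 * capacity)

    def set_value(self, i, value):
        i += self.capacity
        self._arr[i] = value
        while i > 1:  # walk up to the root, re-reducing each parent
            i //= 2
            self._arr[i] = self.reducer(self._arr[2 * i], self._arr[2 * i + 1])

    @property
    def root_value(self):
        return self._arr[1]

    def partial_reduce(self, start=0, stop=None):
        # naive O(n) version over the leaves; the tree structure allows O(log n)
        stop = self.capacity if stop is None else stop
        leaves = self._arr[self.capacity + start:self.capacity + stop]
        return reduce(self.reducer, leaves, self.init_value)
```

For example, with reducer=operator.add and init_value=0, root_value is the sum of all stored values.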
 class coax.utils.SumTree(capacity, random_seed=None)[source]¶
A sum-tree data structure that allows for batched updating and batched weighted sampling.
Both update and sampling operations have a time complexity of \(\mathcal{O}(\log N)\) and a memory footprint of \(\mathcal{O}(N)\), where \(N\) is the length of the underlying values.
 Parameters:
capacity (positive int) – Number of values to accommodate.
random_seed (int, optional) – Sets the random state to get reproducible results.
 inverse_cdf(u)[source]¶
Inverse of the cumulative distribution function (CDF) of the categorical distribution \(\text{Cat}(p)\), where \(p\) are the normalized values \(p_i=\) values[i] / sum(values).
This function provides the machinery for the sample method.
 Parameters:
u (float or 1d array of floats) – One or more numbers \(u\in[0,1]\). These are typically sampled from \(\text{Unif}([0, 1])\).
 Returns:
idx (array of ints) – The indices associated with \(u\), shape: (n,)
Warning
This method presumes (but doesn’t check) that all values stored in the tree are non-negative.
 partial_reduce(start=0, stop=None)¶
Reduce values over a partial range of indices. This is an efficient, batched implementation of reduce(reducer, values[start:stop], init_value).
 sample(n)[source]¶
Sample array indices using weighted sampling, where the sample weights are proportional to the values stored in values.
 Parameters:
n (positive int) – The number of samples to return.
 Returns:
idx (array of ints) – The sampled indices, shape: (n,)
Warning
This method presumes (but doesn’t check) that all values stored in the tree are non-negative.
 set_values(idx, values)¶
Set or update the values.
 Parameters:
idx (1d array of ints) – The indices of the values to be updated. If you wish to update all values use ellipses instead, e.g. tree.set_values(..., values).
values (1d array of floats) – The new values.
 property height¶
The height of the tree \(h\sim\log(\text{capacity})\).
 property root_value¶
The aggregated value, equivalent to reduce(reducer, values, init_value).
 property values¶
The values stored at the leaves of the tree.
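The inverse-CDF idea behind sample can be illustrated with a naive numpy version, which costs O(N) per draw where the tree achieves O(log N). The function name here is hypothetical:

```python
import numpy as np

# Naive illustration of weighted sampling via the inverse CDF; a SumTree
# computes the same index by descending the tree in O(log N).
def naive_inverse_cdf(values, u):
    cdf = np.cumsum(values) / np.sum(values)  # normalized cumulative weights
    return int(np.searchsorted(cdf, u))      # first index where cdf >= u

values = np.array([1.0, 0.0, 3.0, 0.0])
```

Indices with zero weight are never returned, and indices with larger weights cover a proportionally larger slice of [0, 1].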
 class coax.utils.MinTree(capacity)[source]¶
A min-tree data structure, which is a SegmentTree whose reducer is minimum.
 Parameters:
capacity (positive int) – Number of values to accommodate.
 partial_reduce(start=0, stop=None)¶
Reduce values over a partial range of indices. This is an efficient, batched implementation of reduce(reducer, values[start:stop], init_value).
 set_values(idx, values)¶
Set or update the values.
 Parameters:
idx (1d array of ints) – The indices of the values to be updated. If you wish to update all values use ellipses instead, e.g. tree.set_values(..., values).
values (1d array of floats) – The new values.
 property height¶
The height of the tree \(h\sim\log(\text{capacity})\).
 property root_value¶
The aggregated value, equivalent to reduce(reducer, values, init_value).
 property values¶
The values stored at the leaves of the tree.
 class coax.utils.MaxTree(capacity)[source]¶
A max-tree data structure, which is a SegmentTree whose reducer is maximum.
 Parameters:
capacity (positive int) – Number of values to accommodate.
 partial_reduce(start=0, stop=None)¶
Reduce values over a partial range of indices. This is an efficient, batched implementation of reduce(reducer, values[start:stop], init_value).
 set_values(idx, values)¶
Set or update the values.
 Parameters:
idx (1d array of ints) – The indices of the values to be updated. If you wish to update all values use ellipses instead, e.g. tree.set_values(..., values).
values (1d array of floats) – The new values.
 property height¶
The height of the tree \(h\sim\log(\text{capacity})\).
 property root_value¶
The aggregated value, equivalent to reduce(reducer, values, init_value).
 property values¶
The values stored at the leaves of the tree.
 coax.utils.argmax(rng, arr, axis=-1)[source]¶
This is a little hack to ensure that argmax breaks ties randomly, which is something that numpy.argmax() doesn’t do.
 Parameters:
rng (jax.random.PRNGKey) – A pseudo-random number generator key.
arr (array_like) – Input array.
axis (int, optional) – By default, the index is into the flattened array, otherwise along the specified axis.
 Returns:
index_array (ndarray of ints) – Array of indices into the array. It has the same shape as arr with the dimension along axis removed.
 coax.utils.argmin(rng, arr, axis=-1)[source]¶
This is a little hack to ensure that argmin breaks ties randomly, which is something that numpy.argmin() doesn’t do.
Note: random tie breaking is only done for 1d arrays; for multidimensional inputs, we fall back to the numpy version.
 Parameters:
rng (jax.random.PRNGKey) – A pseudo-random number generator key.
arr (array_like) – Input array.
axis (int, optional) – By default, the index is into the flattened array, otherwise along the specified axis.
 Returns:
index_array (ndarray of ints) – Array of indices into the array. It has the same shape as arr with the dimension along axis removed.
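Random tie-breaking for a 1d array can be sketched in plain numpy. The function name is hypothetical, and numpy's Generator stands in for the jax.random.PRNGKey that the coax versions use:

```python
import numpy as np

# Sketch of argmax with random tie-breaking for 1d arrays: collect all
# indices tied for the maximum, then pick one of them at random.
def argmax_random_tiebreak(rng, arr):
    arr = np.asarray(arr)
    candidates = np.flatnonzero(arr == arr.max())  # all indices tied for max
    return int(rng.choice(candidates))

rng = np.random.default_rng(13)
```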
 coax.utils.batch_to_single(pytree, index=0)[source]¶
Extract a single instance from a pytree of array batches.
This just does a leaf[index] on all leaf nodes of the pytree.
 Parameters:
pytree_batch (pytree with ndarray leaves) – A pytree representing a batch.
 Returns:
pytree_single (pytree with ndarray leaves) – A pytree representing e.g. a single state observation.
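For a plain dict of arrays (the simplest pytree), the leaf[index] operation looks as follows. The function name is hypothetical; the real implementation uses jax's tree utilities to handle arbitrary pytrees:

```python
import numpy as np

# Dict-only sketch of batch_to_single: index every array leaf.
def batch_to_single_sketch(pytree_batch, index=0):
    if isinstance(pytree_batch, dict):
        return {k: batch_to_single_sketch(v, index) for k, v in pytree_batch.items()}
    return pytree_batch[index]

batch = {'obs': np.arange(6).reshape(3, 2), 'reward': np.array([1.0, 2.0, 3.0])}
single = batch_to_single_sketch(batch, index=1)
```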
 coax.utils.check_array(arr, ndim=None, ndim_min=None, ndim_max=None, dtype=None, shape=None, axis_size=None, axis=None, except_np=False)[source]¶
This helper function is mostly for internal use. It is used to check a few common properties of a numpy array.
 Raises:
TypeError – If one of the checks fails.
 coax.utils.check_preprocessors(space, *preprocessors, num_samples=20, random_seed=None)[source]¶
Check whether two preprocessors are the same.
 Parameters:
space (gymnasium.Space) – The domain of the preprocessors.
*preprocessors – Preprocessor functions, which are functions with input signature: func(rng: PRNGKey, x: Element[space]) -> Any.
num_samples (positive int) – The number of samples in which to run checks.
 Returns:
match (bool) – Whether the preprocessors match.
 coax.utils.clipped_logit(x, epsilon=1e-15)[source]¶
A safe implementation of the logit function \(x\mapsto\log(x/(1-x))\). It clips the arguments of the log function from below so as to avoid evaluating it at 0:
\[\text{logit}_\epsilon(x)\ =\ \log(\max(\epsilon, x)) - \log(\max(\epsilon, 1 - x))\]
 Parameters:
x (ndarray) – Input numpy array whose entries lie on the unit interval, \(x_i\in [0, 1]\).
epsilon (float, optional) – The small number with which to clip the arguments of the logarithm from below.
 Returns:
z (ndarray, dtype: float, shape: same as input) – The output logits whose entries lie on the real line, \(z_i\in\mathbb{R}\).
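The clipped formula above translates directly to numpy (a sketch; coax operates on jax arrays):

```python
import numpy as np

# Direct numpy transcription of the clipped-logit formula:
# log(max(eps, x)) - log(max(eps, 1 - x)), which stays finite at x = 0 and 1.
def clipped_logit_sketch(x, epsilon=1e-15):
    x = np.asarray(x, dtype=float)
    return np.log(np.maximum(epsilon, x)) - np.log(np.maximum(epsilon, 1 - x))
```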
 coax.utils.default_preprocessor(space)[source]¶
The default preprocessor for a given space.
 Parameters:
space (gymnasium.Space) – The domain of the preprocessor.
 Returns:
preprocessor (Callable[PRNGKey, Element[space], Any]) – The preprocessor function. See NormalDist.preprocess_variate for an example.
 coax.utils.diff_transform(X, dtype='float32')[source]¶
A helper function that implements discrete differentiation for stacked state observations. See diff_transform_matrix() for a detailed description.
M = diff_transform_matrix(num_frames=X.shape[-1])
X_transformed = np.dot(X, M)
 Parameters:
X (ndarray) – An array whose shape is such that the last axis is the frame-stack axis, i.e. X.shape[-1] == num_frames.
 Returns:
X_transformed (ndarray) – The shape is the same as the input shape, but the entries along the last axis are mixed to represent position, velocity, acceleration, etc.
 coax.utils.diff_transform_matrix(num_frames, dtype='float32')[source]¶
A helper function that implements discrete differentiation for stacked state observations.
Let’s say we have a feature vector \(X\) consisting of four stacked frames, i.e. the shape would be: [batch_size, height, width, 4].
The corresponding diff-transform matrix with num_frames=4 is a \(4\times 4\) matrix given by:
\[\begin{split}M_\text{diff}^{(4)}\ =\ \begin{pmatrix} -1 & 0 & 0 & 0 \\ 3 & 1 & 0 & 0 \\ -3 & -2 & -1 & 0 \\ 1 & 1 & 1 & 1 \end{pmatrix}\end{split}\]
such that the diff-transformed feature vector is readily computed as:
\[X_\text{diff}\ =\ X\, M_\text{diff}^{(4)}\]
The diff-transformation preserves the shape, but it reorganizes the frames in such a way that they look more like canonical variables. You can think of \(X_\text{diff}\) as the stacked variables \(x\), \(\dot{x}\), \(\ddot{x}\), etc. (in reverse order). These represent the position, velocity, acceleration, etc. of pixels in a single frame.
 Parameters:
num_frames (positive int) – The number of stacked frames in the original \(X\).
dtype (dtype, optional) – The output data type.
 Returns:
M (2d Tensor, shape: [num_frames, num_frames]) – A square matrix that is intended to be multiplied from the right, e.g. X_diff = K.dot(X_orig, M), where we assume that the frames are stacked in axis=-1 of X_orig, in chronological order.
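Such a matrix can be built from backward-difference (binomial) coefficients: column j holds the coefficients of the (n-1-j)-th order difference, so the last column is position, the one before it velocity, and so on. This is an illustrative construction, not the coax implementation:

```python
import numpy as np
from math import comb

# Build a diff-transform matrix from backward-difference coefficients.
# Column j encodes the (n-1-j)-th order difference over the most recent
# (n-j) frames, with alternating-sign binomial coefficients.
def diff_transform_matrix_sketch(num_frames):
    n = num_frames
    m = np.zeros((n, n))
    for j in range(n):
        k = n - 1 - j  # order of the difference encoded in column j
        for i in range(k + 1):
            m[n - 1 - i, j] = (-1) ** i * comb(k, i)
    return m
```

Sanity check: for a constant stack of frames all difference columns vanish and only the position column survives.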
 coax.utils.docstring(obj)[source]¶
A simple decorator that sets the __doc__ attribute to obj.__doc__ on the decorated object, see example below.
 Parameters:
obj (object) – The object whose docstring you wish to copy onto the wrapped object.
Examples
>>> def f(x):
...     """Some docstring"""
...     return x * x
...
>>> def g(x):
...     return 13 - x
...
>>> g.__doc__ = f.__doc__
This can be abbreviated by:
>>> @docstring(f)
... def g(x):
...     return 13 - x
...
 coax.utils.double_relu(arr)[source]¶
A double-ReLU, whose output is the concatenated result of relu(-arr) and relu(arr).
This activation function has the advantage that no signal is lost between layers.
 Parameters:
arr (ndarray) – The input array, e.g. activations.
 Returns:
doubled_arr – The output array, e.g. input for next layer.
Examples
>>> import coax
>>> import jax.numpy as jnp
>>> x = jnp.array([[-11, -8],
...                [ 17,  5],
...                [-13,  7],
...                [ 19, -3]])
...
>>> coax.utils.double_relu(x)
DeviceArray([[11,  8,  0,  0],
             [ 0,  0, 17,  5],
             [13,  0,  0,  7],
             [ 0,  3, 19,  0]], dtype=int32)
There are two things we may observe from the above example. The first is that all components from the original array are passed on as output. The second thing is that half of the output components (along axis=-1) are masked out, which means that the doubling of array size doesn’t result in doubling the amount of “activation” passed on to the next layer. It merely allows for the neural net to learn conditional branches in its internal logic.
 coax.utils.dump(obj, filepath)[source]¶
Save an object to disk.
Warning
References between objects are only preserved if they are stored as part of a single object, for example:
# b has a reference to a
a = [13]
b = {'a': a}

# references preserved
dump((a, b), 'ab.pkl.lz4')
a_new, b_new = load('ab.pkl.lz4')
b_new['a'].append(7)
print(b_new)  # {'a': [13, 7]}
print(a_new)  # [13, 7]   <-- updated

# references not preserved
dump(a, 'a.pkl.lz4')
dump(b, 'b.pkl.lz4')
a_new = load('a.pkl.lz4')
b_new = load('b.pkl.lz4')
b_new['a'].append(7)
print(b_new)  # {'a': [13, 7]}
print(a_new)  # [13]      <-- not updated!!
Therefore, the safest way to create checkpoints is to store the entire state as a single object like a dict or a tuple.
 coax.utils.dumps(obj)[source]¶
Serialize an object to an lz4-compressed pickle bytestring.
 Parameters:
obj (object) – Any python object.
 Returns:
s (bytes) – An lz4-compressed pickle bytestring.
Warning
References between objects are only preserved if they are stored as part of a single object, for example:
# b has a reference to a
a = [13]
b = {'a': a}

# references preserved
s = dumps((a, b))
a_new, b_new = loads(s)
b_new['a'].append(7)
print(b_new)  # {'a': [13, 7]}
print(a_new)  # [13, 7]   <-- updated

# references not preserved
s_a = dumps(a)
s_b = dumps(b)
a_new = loads(s_a)
b_new = loads(s_b)
b_new['a'].append(7)
print(b_new)  # {'a': [13, 7]}
print(a_new)  # [13]      <-- not updated!!
Therefore, the safest way to create checkpoints is to store the entire state as a single object like a dict or a tuple.
 coax.utils.enable_logging(name=None, level=20, output_filepath=None, output_level=None)[source]¶
Enable logging output.
This executes the following two lines of code:
import logging
logging.basicConfig(level=logging.INFO)
 Parameters:
name (str, optional) – Name of the process that is logging. This can be set to whatever you like.
level (int, optional) – Logging level for the default StreamHandler. The default setting is level=logging.INFO (which is 20). If you’d like to see more verbose logging messages you might set level=logging.DEBUG.
output_filepath (str, optional) – If provided, a FileHandler will be added to the root logger via:
file_handler = logging.FileHandler(output_filepath)
logging.getLogger('').addHandler(file_handler)
output_level (int, optional) – Logging level for the FileHandler. If left unspecified, this defaults to level, i.e. the same level as the default StreamHandler.
 coax.utils.generate_gif(env, filepath, policy=None, resize_to=None, duration=50, max_episode_steps=None)[source]¶
Store a gif from the episode frames.
 Parameters:
env (gymnasium environment) – The environment to record from.
filepath (str) – Location of the output gif file.
policy (callable, optional) – A policy object that is used to pick actions: a = policy(s). If left unspecified, we’ll just take random actions instead, i.e. a = env.action_space.sample().
resize_to (tuple of ints, optional) – The size of the output frames, (width, height). Notice the ordering: first width, then height. This is the convention PIL uses.
duration (float, optional) – Time between frames in the animated gif, in milliseconds.
max_episode_steps (int, optional) – The maximum number of steps in the episode. If left unspecified, we’ll attempt to get the value from env.spec.max_episode_steps and if that fails we default to 10000.
 coax.utils.get_env_attr(env, attr, default='__ERROR__', max_depth=100)[source]¶
Get the given attribute from a potentially wrapped environment.
Note that the wrapped envs are traversed from the outside in. Once the attribute is found, the search stops. This means that an inner wrapped env may carry the same (possibly conflicting) attribute. This situation is not resolved by this function.
 Parameters:
env (gymnasium environment) – A potentially wrapped environment.
attr (str) – The attribute name.
max_depth (positive int, optional) – The maximum depth of wrappers to traverse.
 coax.utils.get_grads_diagnostics(grads, key_prefix='', keep_tree_structure=False)[source]¶
Given a pytree of grads, return a dict that contains the quantiles of the magnitudes of each individual component.
This is meant to be a high-level diagnostic. It first extracts the leaves of the pytree, then flattens each leaf and then it computes the elementwise magnitude. Then, it concatenates all magnitudes into one long flat array. The quantiles are computed on this array.
 Parameters:
grads (a pytree with ndarray leaves) – The gradients of some loss function with respect to the model parameters (weights).
key_prefix (str, optional) – The prefix to add to the output dict keys.
keep_tree_structure (bool, optional) – Whether to keep the tree structure, i.e. to compute the grads diagnostics for each individual leaf. If False (default), we only compute the global grads diagnostics.
 Returns:
grads_diagnostics (dict<str, float>) – A dict with structure {name: score}.
 coax.utils.get_magnitude_quantiles(pytree, key_prefix='')[source]¶
Given a pytree, return a dict that contains the quantiles of the magnitudes of each individual component.
This is meant to be a high-level diagnostic. It first extracts the leaves of the pytree, then flattens each leaf and then it computes the elementwise magnitude. Then, it concatenates all magnitudes into one long flat array. The quantiles are computed on this array.
 Parameters:
pytree (a pytree with ndarray leaves) – A typical example is a pytree of model params (weights) or gradients with respect to such model params.
key_prefix (str, optional) – The prefix to add to the output dict keys.
 Returns:
magnitude_quantiles (dict) – A dict with keys: ['min', 'p25', 'p50', 'p75', 'max']. The values of the dict are non-negative floats that represent the magnitude quantiles.
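The procedure just described can be sketched for a plain dict of arrays standing in for a pytree (the function name is hypothetical):

```python
import numpy as np

# Sketch of the global magnitude-quantile computation: flatten all leaves,
# take elementwise magnitudes, concatenate, then compute the quantiles.
def magnitude_quantiles_sketch(pytree, key_prefix=''):
    magnitudes = np.concatenate([np.abs(np.ravel(leaf)) for leaf in pytree.values()])
    qs = np.quantile(magnitudes, [0.0, 0.25, 0.5, 0.75, 1.0])
    names = ['min', 'p25', 'p50', 'p75', 'max']
    return {key_prefix + name: float(q) for name, q in zip(names, qs)}
```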
 coax.utils.get_transition_batch(env, batch_size=1, gamma=0.9, random_seed=None)[source]¶
Generate a single transition from the environment.
This basically does a single step on the environment and then closes it.
 Parameters:
env (gymnasium environment) – A gymnasium-style environment.
batch_size (positive int, optional) – The desired batch size of the sample.
random_seed (int, optional) – In order to generate the transition, we do some random sampling from the provided spaces. This random_seed sets the seed for the pseudo-random number generators.
 Returns:
transition_batch (TransitionBatch) – A batch of transitions.
 coax.utils.has_env_attr(env, attr, max_depth=100)[source]¶
Check if a potentially wrapped environment has a given attribute.
 Parameters:
env (gymnasium environment) – A potentially wrapped environment.
attr (str) – The attribute name.
max_depth (positive int, optional) – The maximum depth of wrappers to traverse.
 coax.utils.idx(arr, axis=0)[source]¶
Given a numpy array, return its corresponding integer index array.
 Parameters:
arr (array) – Input array.
axis (int, optional) – The axis along which we’d like to get an index.
 Returns:
index (1d array, shape: arr.shape[axis]) – An index array [0, 1, 2, …].
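This behavior amounts to a numpy one-liner (the function name here is hypothetical):

```python
import numpy as np

# Hypothetical one-liner with the behavior described above: an integer
# index array [0, 1, 2, ...] for the chosen axis.
def idx_sketch(arr, axis=0):
    return np.arange(arr.shape[axis])
```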
 coax.utils.is_policy(obj)[source]¶
Check whether an object is a policy.
 Parameters:
obj – Object to check.
 Returns:
bool – Whether obj is a policy.
 coax.utils.is_qfunction(obj)[source]¶
Check whether an object is a state-action value function, or Q-function.
 Parameters:
obj – Object to check.
 Returns:
bool – Whether obj is a Q-function and (optionally) whether it is of model-type 1 or 2.
 coax.utils.is_reward_function(obj)[source]¶
Check whether an object is a reward function.
 Parameters:
obj – Object to check.
 Returns:
bool – Whether obj is a reward function.
 coax.utils.is_stochastic(obj)[source]¶
Check whether an object is a stochastic function approximator.
 Parameters:
obj – Object to check.
 Returns:
bool – Whether obj is a stochastic function approximator.
 coax.utils.is_transition_model(obj)[source]¶
Check whether an object is a dynamics model.
 Parameters:
obj – Object to check.
 Returns:
bool – Whether obj is a dynamics model.
 coax.utils.is_vfunction(obj)[source]¶
Check whether an object is a state value function, or V-function.
 Parameters:
obj – Object to check.
 Returns:
bool – Whether obj is a V-function.
 coax.utils.isscalar(num)[source]¶
This helper uses a slightly looser definition of scalar compared to numpy.isscalar() (and jax.numpy.isscalar()) in that it also considers single-item arrays to be scalars.
 Parameters:
num (number or ndarray) – Input array.
 Returns:
isscalar (bool) – Whether the input is either a number or a single-item array.
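A numpy-only sketch of this looser check (the coax version also accepts jax arrays; the function name is hypothetical):

```python
import numpy as np

# Looser scalar check: plain scalars, plus arrays holding exactly one item.
def isscalar_sketch(num):
    return np.isscalar(num) or (isinstance(num, np.ndarray) and num.size == 1)
```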
 coax.utils.jit(func, static_argnums=(), donate_argnums=())[source]¶
An alternative to jax.jit() that returns a picklable JIT-compiled function.
Note that jax.jit() produces non-picklable functions, because the JIT compilation depends on the device and backend. In order to facilitate serialization, this function does not allow the user to specify device or backend. Instead, jax.jit() is called with the defaults: jax.jit(..., device=None, backend=None).
Check out the original jax.jit() docs for a more detailed description of the arguments.
 coax.utils.load(filepath)[source]¶
Load an object from a file that was created by dump(obj, filepath).
 Parameters:
filepath (str) – File to load.
 coax.utils.loads(s)[source]¶
Load an object from a bytestring that was created by dumps(obj).
 Parameters:
s (bytes) – An lz4-compressed pickle bytestring.
 coax.utils.make_dmc(domain, task, seed=0, max_episode_steps=1000, height=84, width=84, camera_id=0)[source]¶
Create a Gym environment for a DeepMind Control suite task.
 Parameters:
 Returns:
env (gymnasium.Env) – Gym environment.
 coax.utils.merge_dicts(*dicts)[source]¶
Merge dicts into a single dict.
WARNING: duplicate keys are not resolved.
 Parameters:
*dicts (*dict) – Multiple dictionaries.
 Returns:
merged (dict) – A single dictionary.
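A minimal sketch of this merging, which makes the duplicate-key caveat explicit (the function name is hypothetical):

```python
# Merge dicts left to right; duplicate keys are NOT resolved, so later
# dicts silently overwrite values from earlier ones.
def merge_dicts_sketch(*dicts):
    merged = {}
    for d in dicts:
        merged.update(d)
    return merged
```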
 coax.utils.pretty_print(obj)[source]¶
Print pretty_repr(obj).
 Parameters:
obj (object) – Any object.
 coax.utils.quantiles(batch_size, num_quantiles=200)[source]¶
Generate batch_size quantile fractions that split the interval \([0, 1]\) into num_quantiles equally spaced fractions.
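One common convention for equally spaced quantile fractions is the midpoint rule \(\tau_i = (i + 0.5)/n\); whether coax uses exactly this convention is an assumption here, so treat the sketch as illustrative:

```python
import numpy as np

# Midpoint convention tau_i = (i + 0.5) / n, tiled to the requested batch
# size (assumed convention, for illustration only).
def quantiles_sketch(batch_size, num_quantiles=200):
    tau = (np.arange(num_quantiles) + 0.5) / num_quantiles
    return np.tile(tau, (batch_size, 1))
```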
 coax.utils.quantiles_uniform(rng, batch_size, num_quantiles=32)[source]¶
Generate batch_size quantile fractions that split the interval \([0, 1]\) into num_quantiles uniformly distributed fractions.
 Parameters:
 Returns:
quantile_fractions (ndarray) – Array of quantile fractions.
 coax.utils.quantile_cos_embedding(quantile_fractions, n=64)[source]¶
Embed the given quantile fractions \(\tau\) in an n-dimensional space using cosine basis functions.
\[\phi(\tau) = \cos(\tau i \pi) \qquad 0 \leq i \lt n\] Parameters:
quantile_fractions (ndarray) – Array of quantile fractions \(\tau\) to be embedded.
n (int) – The dimensionality of the embedding. By default 64.
 Returns:
quantile_embs (ndarray) – Array of quantile embeddings with shape (quantile_fractions.shape[0], n).
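The embedding formula above transcribes directly to numpy (sketch; the real function operates on jax arrays):

```python
import numpy as np

# Direct transcription of phi_i(tau) = cos(tau * i * pi) for 0 <= i < n.
def quantile_cos_embedding_sketch(quantile_fractions, n=64):
    i = np.arange(n)
    return np.cos(np.outer(quantile_fractions, i) * np.pi)
```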
 coax.utils.reload_recursive(module, reload_external_modules=False)[source]¶
Recursively reload a module (in order of dependence).
 coax.utils.render_episode(env, policy=None, step_delay_ms=0)[source]¶
Run a single episode, calling env.render() at each time step.
 Parameters:
env (gymnasium environment) – A gymnasium environment.
policy (callable, optional) – A policy object that is used to pick actions: a = policy(s). If left unspecified, we’ll just take random actions instead, i.e. a = env.action_space.sample().
step_delay_ms (non-negative float) – The number of milliseconds to wait between consecutive timesteps. This can be used to slow down the rendering.
 coax.utils.safe_sample(space, seed=None)[source]¶
Safely sample from a gymnasium-style space.
 Parameters:
space (gymnasium.Space) – A gymnasium-style space.
seed (int, optional) – The seed for the pseudo-random number generator.
 Returns:
sample – A single sample from the given space.
 coax.utils.single_to_batch(pytree)[source]¶
Take a single instance and turn it into a batch of size 1.
This just does an np.expand_dims(leaf, axis=0) on all leaf nodes of the pytree.
 Parameters:
pytree_single (pytree with ndarray leaves) – A pytree representing e.g. a single state observation.
 Returns:
pytree_batch (pytree with ndarray leaves) – A pytree representing a batch with batch_size=1.
 coax.utils.stack_trees(*trees)[source]¶
Apply jnp.stack to the leaves of a pytree.
 Parameters:
trees (sequence of pytrees with ndarray leaves) – A typical example are pytrees containing the parameters and function states of a model that should be used in a function which is vectorized by jax.vmap. The trees have to have the same pytree structure.
 Returns:
pytree (pytree with ndarray leaves) – A single pytree with the same structure as the inputs, whose leaves are the stacked leaves of trees.
 coax.utils.sync_shared_params(*params, weights=None)[source]¶
Synchronize shared params. See the A2C stub for example usage.
 Parameters:
*params (multiple hk.Params objects) – The parameter dicts that contain shared parameters.
weights (list of positive floats) – The relative weights to use for averaging the shared params.
 Returns:
params (tuple of hk.Params objects) – Same as input *params but with synchronized shared params.
 coax.utils.tree_ravel(pytree)[source]¶
Flatten and concatenate all leaves into a single flat ndarray.
 Parameters:
pytree (a pytree with ndarray leaves) – A typical example is a pytree of model parameters (weights) or gradients with respect to such model params.
 Returns:
arr (ndarray with ndim=1) – A single flat array.
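For a dict of arrays standing in for a pytree, the flatten-and-concatenate step looks like this (sketch; the real implementation handles arbitrary pytrees via jax's tree utilities):

```python
import numpy as np

# Dict-of-arrays sketch of tree_ravel: flatten every leaf, then concatenate
# them all into one flat 1d array.
def tree_ravel_sketch(pytree):
    return np.concatenate([np.ravel(leaf) for leaf in pytree.values()])
```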
 coax.utils.unvectorize(f, in_axes=0, out_axes=0)[source]¶
Apply a batched function on a single instance, which effectively does the inverse of what jax.vmap() does.
 Parameters:
f (callable) – A batched function.
in_axes (int or tuple of ints, optional) – Specify the batch axes of the inputs of the function f. If left unspecified, this defaults to axis 0 for all inputs.
out_axes (int, optional) – Specify the batch axes of the outputs of the function f. These axes will be dropped by jnp.squeeze. If left unspecified, this defaults to axis 0 for all outputs.
 Returns:
f_single (callable) – The unvectorized version of f.
Examples
Haiku uses a batch-oriented design (although some components may be batch-agnostic). To create a function that acts on a single instance, we can use unvectorize() as follows:
import jax.numpy as jnp
import haiku as hk
import coax

def f(x_batch):
    return hk.Linear(11)(x_batch)

rngs = hk.PRNGSequence(42)
x_batch = jnp.zeros(shape=(3, 5))  # batch of 3 instances
x_single = jnp.zeros(shape=(5,))   # single instance

init, f_batch = hk.transform(f)
params = init(next(rngs), x_batch)

y_batch = f_batch(params, next(rngs), x_batch)
assert y_batch.shape == (3, 11)

f_single = coax.unvectorize(f_batch, in_axes=(None, None, 0), out_axes=0)
y_single = f_single(params, next(rngs), x_single)
assert y_single.shape == (11,)
Alternatively, and perhaps more conveniently, we can unvectorize the function before doing the Haiku transform:
init, f_single = hk.transform(coax.unvectorize(f))
params = init(next(rngs), x_single)
y_single = f_single(params, next(rngs), x_single)
assert y_single.shape == (11,)