Wrappers

coax.wrappers.TrainMonitor

Environment wrapper for monitoring the training process.

coax.wrappers.FrameStacking

Wrapper that does frame stacking (see DQN paper).

coax.wrappers.BoxActionsToReals

This wrapper decompactifies a Box action space to the reals.

coax.wrappers.BoxActionsToDiscrete

This wrapper splits a Box action space into bins.

coax.wrappers.MetaPolicyEnv

Wrap a gymnasium-style environment such that it may be used by a meta-policy, i.e. a bandit that selects a policy (an arm), which is then used to sample a lower-level action that is fed to the original environment.


Gymnasium provides a nice modular interface to extend existing environments using environment wrappers. Here we list some wrappers that are used throughout the coax package.

The most notable wrapper, and the one you’ll most likely want to use, is coax.wrappers.TrainMonitor. It wraps the environment so that the training logs are easy to view, writing them both to the standard logging module and to tensorboard (through the tensorboardX package).

Object Reference

class coax.wrappers.TrainMonitor(env, tensorboard_dir=None, tensorboard_write_all=False, log_all_metrics=False, smoothing=10, **logger_kwargs)

Environment wrapper for monitoring the training process.

This wrapper logs some diagnostics at the end of each episode and it also gives us some handy attributes (listed below).

Parameters:
  • env (gymnasium environment) – A gymnasium environment.

  • tensorboard_dir (str, optional) –

    If provided, TrainMonitor will log all diagnostics to be viewed in tensorboard. To view these, point tensorboard to the same dir:

    $ tensorboard --logdir {tensorboard_dir}
    

  • tensorboard_write_all (bool, optional) – You may record your training metrics using the record_metrics method. The tensorboard_write_all flag specifies whether to pass the metrics on to tensorboard immediately (True) or to wait and average them across the episode (False). The default setting (False) prevents tensorboard from being flooded by logs.

  • log_all_metrics (bool, optional) – Whether to log all metrics. If log_all_metrics=False, only a reduced set of metrics is logged.

  • smoothing (positive int, optional) –

    The number of observations used for smoothing the metrics. We use the following smoothing update rule:

    \[\begin{split}n\ &\leftarrow\ \min(\text{smoothing}, n + 1) \\ x_\text{avg}\ &\leftarrow\ x_\text{avg} + \frac{x_\text{obs} - x_\text{avg}}{n}\end{split}\]

  • **logger_kwargs – Keyword arguments to pass on to coax.utils.enable_logging().
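The smoothing update rule above can be written out in a few lines of plain Python. This is a sketch of the formula only, not coax's internal implementation; the function name smooth_update is hypothetical.

```python
def smooth_update(x_avg, n, x_obs, smoothing=10):
    """One step of the smoothing rule (hypothetical helper):

    n     <- min(smoothing, n + 1)
    x_avg <- x_avg + (x_obs - x_avg) / n
    """
    n = min(smoothing, n + 1)
    x_avg = x_avg + (x_obs - x_avg) / n
    return x_avg, n


# feed three observations through the rule with a small smoothing window
x_avg, n = 0.0, 0
for x in [1.0, 2.0, 3.0]:
    x_avg, n = smooth_update(x_avg, n, x, smoothing=2)
print(x_avg, n)  # 2.25 2
```

For the first smoothing observations the rule reproduces the exact running mean; after that, each new observation receives a fixed weight of 1/smoothing, so the average behaves like an exponential moving average.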

Variables:
  • T (positive int) – Global step counter. This is not reset by env.reset(); use env.reset_global() instead.

  • ep (positive int) – Global episode counter. This is not reset by env.reset(); use env.reset_global() instead.

  • t (positive int) – Step counter within an episode.

  • G (float) – The return, i.e. the amount of reward accumulated since the start of the current episode.

  • avg_G (float) – The average return G, averaged over the past 100 episodes.

  • dt_ms (float) – The average wall time of a single step, in milliseconds.
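To make the distinction between global and episodic counters concrete, here is a minimal sketch of the bookkeeping behind T, ep, t and G. It is not coax's implementation: ToyEnv and CounterMonitor are hypothetical, and the sketch assumes the gymnasium step API that returns (s_next, r, terminated, truncated, info).

```python
class ToyEnv:
    """Hypothetical stand-in env: episodes last 3 steps, reward 1.0 per step."""

    def reset(self):
        self._t = 0
        return 0, {}

    def step(self, a):
        self._t += 1
        return self._t, 1.0, self._t >= 3, False, {}


class CounterMonitor:
    """Sketch of T / ep / t / G bookkeeping (not coax's code)."""

    def __init__(self, env):
        self.env = env
        self.reset_global()

    def reset_global(self):
        # clears everything, including the global counters T and ep
        self.T, self.ep, self.t, self.G = 0, 0, 0, 0.0

    def reset(self):
        # a new episode: bump ep, clear only the episodic counters t and G
        self.ep += 1
        self.t, self.G = 0, 0.0
        return self.env.reset()

    def step(self, a):
        s_next, r, terminated, truncated, info = self.env.step(a)
        self.T += 1   # global step counter
        self.t += 1   # step counter within the episode
        self.G += r   # return accumulated in this episode
        return s_next, r, terminated, truncated, info

    def get_counters(self):
        return {'T': self.T, 'ep': self.ep, 't': self.t, 'G': self.G}

    def set_counters(self, counters):
        self.T, self.ep = counters['T'], counters['ep']
        self.t, self.G = counters['t'], counters['G']


env = CounterMonitor(ToyEnv())
for _ in range(2):
    s, info = env.reset()
    done = False
    while not done:
        s, r, terminated, truncated, info = env.step(0)
        done = terminated or truncated
print(env.get_counters())  # {'T': 6, 'ep': 2, 't': 3, 'G': 3.0}
```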

classmethod class_name() → str

Returns the class name of the wrapper.

close()

Closes the wrapper and env.

get_counters()

Get the current state of all internal counters.

Returns:

counter (dict) – The dict that contains the counters.

get_metrics()

Return the current state of the metrics.

Returns:

metrics (dict) – A dict of metrics, of type {name <str>: value <float>}.

load_counters(filepath)

Restore the state of all internal counters.

Parameters:

filepath (str) – The checkpoint file path.

record_metrics(metrics)

Record metrics during the training process.

These are used to print more diagnostics.

Parameters:

metrics (dict) – A dict of metrics, of type {name <str>: value <float>}.
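The tensorboard_write_all parameter determines what happens to recorded metrics. The sketch below contrasts the two modes under the assumption that, with tensorboard_write_all=False, each metric is averaged over the episode and written once at the episode's end. EpisodeMetrics, end_episode and the written list (a stand-in for a tensorboard writer) are hypothetical names, not coax API.

```python
from collections import defaultdict


class EpisodeMetrics:
    """Hypothetical sketch: buffer metrics during an episode, average at the end."""

    def __init__(self, write_all=False):
        self.write_all = write_all
        self.written = []                 # stand-in for a tensorboard writer
        self._buffer = defaultdict(list)  # metric name -> values this episode

    def record_metrics(self, metrics):
        if self.write_all:
            # pass every value on immediately
            self.written.extend(metrics.items())
        else:
            # hold the values until the episode ends
            for name, value in metrics.items():
                self._buffer[name].append(float(value))

    def end_episode(self):
        # write one averaged value per metric name, then clear the buffer
        for name, values in self._buffer.items():
            self.written.append((name, sum(values) / len(values)))
        self._buffer.clear()


m = EpisodeMetrics(write_all=False)
m.record_metrics({'loss': 1.0})
m.record_metrics({'loss': 3.0})
m.end_episode()
print(m.written)  # [('loss', 2.0)]
```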

render() → RenderFrame | list[RenderFrame] | None

Calls render() of the wrapped env. This method can be overridden to change the returned data.

reset()

Calls reset() of the wrapped env. This method can be overridden to change the returned data.

reset_global()

Reset the global counters, not just the episodic ones.

save_counters(filepath)

Store the current state of all internal counters.

Parameters:

filepath (str) – The checkpoint file path.

set_counters(counters)

Restore the state of all internal counters.

Parameters:

counters (dict) – The dict that contains the counters.

step(a)

Calls step() of the wrapped env. This method can be overridden to change the returned data.

property action_space: spaces.Space[ActType] | spaces.Space[WrapperActType]

Returns the Env action_space, unless it has been overwritten, in which case the wrapper action_space is used.

property metadata: dict[str, Any]

Returns the Env metadata.

property np_random: Generator

Returns the Env np_random attribute.

property observation_space: spaces.Space[ObsType] | spaces.Space[WrapperObsType]

Returns the Env observation_space, unless it has been overwritten, in which case the wrapper observation_space is used.

property render_mode: str | None

Returns the Env render_mode.

property reward_range: tuple[SupportsFloat, SupportsFloat]

Returns the Env reward_range, unless it has been overwritten, in which case the wrapper reward_range is used.

property spec: EnvSpec | None

Returns the Env spec attribute.

property unwrapped: Env[ObsType, ActType]

Returns the base environment of the wrapper.

This will be the bare gymnasium.Env environment, underneath all layers of wrappers.

class coax.wrappers.FrameStacking(env, num_frames)

Wrapper that does frame stacking (see DQN paper).

This implementation is different from most implementations in that it doesn’t perform the stacking itself. Instead, it just returns a tuple of frames (untouched), which may be stacked downstream.

The benefit of this implementation is twofold. First, it respects the gymnasium.spaces API, where each observation is truly an element of the observation space (this is not true of the gymnasium implementation, which uses a custom data class to keep its memory footprint minimal). Second, this implementation is compatible with the jax.tree_util module, which means that the observations can be fed into jit-compiled functions directly.

Example

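import gymnasium
from coax.wrappers import FrameStacking

# assumes the Atari environments (ale-py) are installed and registered
env = gymnasium.make('PongNoFrameskip-v0')
print(env.observation_space)  # Box(210, 160, 3)

# each observation becomes a tuple of num_frames untouched frames
env = FrameStacking(env, num_frames=2)
print(env.observation_space)  # Tuple((Box(210, 160, 3), Box(210, 160, 3)))

Parameters: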
  • env (gymnasium-style environment) – The original environment to be wrapped.

  • num_frames (positive int) – Number of frames to stack.
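The tuple-of-frames idea described above can be sketched without gymnasium at all. TupleFrameStacker is a hypothetical helper, and padding the buffer with copies of the first frame on reset is an assumption rather than documented coax behavior. The final np.stack call shows the downstream stacking step that FrameStacking deliberately leaves to the user.

```python
from collections import deque

import numpy as np


class TupleFrameStacker:
    """Hypothetical sketch: keep the last num_frames observations as a tuple."""

    def __init__(self, num_frames):
        self.num_frames = num_frames
        self._frames = deque(maxlen=num_frames)

    def reset(self, s):
        # assumption: pad the buffer with copies of the first frame
        self._frames.clear()
        self._frames.extend([s] * self.num_frames)
        return tuple(self._frames)

    def step(self, s_next):
        # the oldest frame drops out automatically because of maxlen
        self._frames.append(s_next)
        return tuple(self._frames)


stacker = TupleFrameStacker(num_frames=2)
obs = stacker.reset(np.zeros((210, 160, 3), dtype=np.uint8))
obs = stacker.step(np.ones((210, 160, 3), dtype=np.uint8))
print(type(obs).__name__, len(obs))   # tuple 2

# stacking happens downstream, e.g. along a new trailing axis
print(np.stack(obs, axis=-1).shape)   # (210, 160, 3, 2)
```

Because the observation is a plain tuple of arrays, it is already a valid pytree for jax.tree_util, which is the compatibility argument made above.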

classmethod class_name() → str

Returns the class name of the wrapper.

close()

Closes the wrapper and env.

render() → RenderFrame | list[RenderFrame] | None

Calls render() of the wrapped env. This method can be overridden to change the returned data.

reset(**kwargs)

Calls reset() of the wrapped env. This method can be overridden to change the returned data.

step(action)

Calls step() of the wrapped env. This method can be overridden to change the returned data.

property action_space: spaces.Space[ActType] | spaces.Space[WrapperActType]

Returns the Env action_space, unless it has been overwritten, in which case the wrapper action_space is used.

property metadata: dict[str, Any]

Returns the Env metadata.

property np_random: Generator

Returns the Env np_random attribute.

property observation_space: spaces.Space[ObsType] | spaces.Space[WrapperObsType]

Returns the Env observation_space, unless it has been overwritten, in which case the wrapper observation_space is used.

property render_mode: str | None

Returns the Env render_mode.

property reward_range: tuple[SupportsFloat, SupportsFloat]

Returns the Env reward_range, unless it has been overwritten, in which case the wrapper reward_range is used.

property spec: EnvSpec | None

Returns the Env spec attribute.

property unwrapped: Env[ObsType, ActType]

Returns the base environment of the wrapper.

This will be the bare gymnasium.Env environment, underneath all layers of wrappers.

class coax.wrappers.BoxActionsToReals(env)

This wrapper decompactifies a Box action space to the reals. This is required in order to be able to use a Gaussian policy.

In practice, the wrapped environment expects an input action \(a_\text{real}\in\mathbb{R}^n\), which it then compactifies back to a Box of the right size:

\[a_\text{box}\ =\ \text{low} + (\text{high}-\text{low})\times\text{sigmoid}(a_\text{real})\]

Technically, the transformed space is still a Box, but that’s only because we assume that the values lie between large but finite bounds, \(a_\text{real}\in[-10^{15}, 10^{15}]^n\).
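The compactification formula can be checked numerically with NumPy. compactify implements the sigmoid map above; decompactify is a hypothetical inverse (the logit), included only to show that every interior point of the Box corresponds to exactly one real-valued action. Neither function is part of the coax API.

```python
import numpy as np


def compactify(a_real, low, high):
    """a_box = low + (high - low) * sigmoid(a_real), elementwise."""
    return low + (high - low) / (1.0 + np.exp(-a_real))


def decompactify(a_box, low, high):
    """Hypothetical inverse (logit); finite only for a_box strictly inside (low, high)."""
    p = (a_box - low) / (high - low)
    return np.log(p) - np.log1p(-p)


low = np.array([-2.0, 0.0])
high = np.array([2.0, 1.0])
a_real = np.array([0.0, 3.0])

a_box = compactify(a_real, low, high)
print(a_box)  # first entry is the midpoint of [-2, 2], i.e. 0.0
print(np.allclose(decompactify(a_box, low, high), a_real))  # True
```

A Gaussian policy can output any real vector, and the sigmoid guarantees the resulting action always lies inside the original bounds.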

classmethod class_name() → str

Returns the class name of the wrapper.

close()

Closes the wrapper and env.

render() → RenderFrame | list[RenderFrame] | None

Calls render() of the wrapped env. This method can be overridden to change the returned data.

reset(*, seed: int | None = None, options: dict[str, Any] | None = None) → tuple[WrapperObsType, dict[str, Any]]

Calls reset() of the wrapped env. This method can be overridden to change the returned data.

step(a)

Calls step() of the wrapped env. This method can be overridden to change the returned data.

property action_space: spaces.Space[ActType] | spaces.Space[WrapperActType]

Returns the Env action_space, unless it has been overwritten, in which case the wrapper action_space is used.

property metadata: dict[str, Any]

Returns the Env metadata.

property np_random: Generator

Returns the Env np_random attribute.

property observation_space: spaces.Space[ObsType] | spaces.Space[WrapperObsType]

Returns the Env observation_space, unless it has been overwritten, in which case the wrapper observation_space is used.

property render_mode: str | None

Returns the Env render_mode.

property reward_range: tuple[SupportsFloat, SupportsFloat]

Returns the Env reward_range, unless it has been overwritten, in which case the wrapper reward_range is used.

property spec: EnvSpec | None

Returns the Env spec attribute.

property unwrapped: Env[ObsType, ActType]

Returns the base environment of the wrapper.

This will be the bare gymnasium.Env environment, underneath all layers of wrappers.

class coax.wrappers.BoxActionsToDiscrete(env, num_bins, random_seed=None)

This wrapper splits a Box action space into bins. The resulting action space is either Discrete or MultiDiscrete, depending on the shape of the original action space.

Parameters:
  • env (gymnasium-style environment) – The original environment to be wrapped.

  • num_bins (int or tuple of ints) – The number of bins to use. A multi-dimensional box requires a tuple of num_bins instead of a single integer.

  • random_seed (int, optional) – Sets the random state to get reproducible results.
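A minimal sketch of how Box actions could be split into bins. It assumes equal-width bins per dimension and that a chosen bin is turned back into a continuous action by sampling uniformly within it; the random_seed parameter suggests some randomness, but this sampling scheme is an assumption, as are the helper names make_bin_edges and discrete_to_box.

```python
import numpy as np


def make_bin_edges(low, high, num_bins):
    """Split each dimension of [low, high] into num_bins equal-width bins."""
    return [np.linspace(lo, hi, n + 1) for lo, hi, n in zip(low, high, num_bins)]


def discrete_to_box(a_discrete, edges, rng):
    """Map one bin index per dimension to a continuous action.

    Assumption: the value is drawn uniformly within the selected bin.
    """
    return np.array([rng.uniform(e[i], e[i + 1]) for e, i in zip(edges, a_discrete)])


low, high = np.array([-1.0, 0.0]), np.array([1.0, 10.0])
edges = make_bin_edges(low, high, num_bins=(4, 5))  # like MultiDiscrete([4, 5])
rng = np.random.default_rng(13)                     # plays the role of random_seed

a = discrete_to_box([0, 4], edges, rng)
print(a)  # a[0] in [-1.0, -0.5], a[1] in [8.0, 10.0]
```

Seeding the generator makes the continuous actions reproducible, which matches the stated purpose of random_seed.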

classmethod class_name() → str

Returns the class name of the wrapper.

close()

Closes the wrapper and env.

render() → RenderFrame | list[RenderFrame] | None

Calls render() of the wrapped env. This method can be overridden to change the returned data.

reset(*, seed: int | None = None, options: dict[str, Any] | None = None) → tuple[WrapperObsType, dict[str, Any]]

Calls reset() of the wrapped env. This method can be overridden to change the returned data.

step(a)

Calls step() of the wrapped env. This method can be overridden to change the returned data.

property action_space: spaces.Space[ActType] | spaces.Space[WrapperActType]

Returns the Env action_space, unless it has been overwritten, in which case the wrapper action_space is used.

property metadata: dict[str, Any]

Returns the Env metadata.

property np_random: Generator

Returns the Env np_random attribute.

property observation_space: spaces.Space[ObsType] | spaces.Space[WrapperObsType]

Returns the Env observation_space, unless it has been overwritten, in which case the wrapper observation_space is used.

property render_mode: str | None

Returns the Env render_mode.

property reward_range: tuple[SupportsFloat, SupportsFloat]

Returns the Env reward_range, unless it has been overwritten, in which case the wrapper reward_range is used.

property spec: EnvSpec | None

Returns the Env spec attribute.

property unwrapped: Env[ObsType, ActType]

Returns the base environment of the wrapper.

This will be the bare gymnasium.Env environment, underneath all layers of wrappers.

class coax.wrappers.MetaPolicyEnv(env, *arms)

Wrap a gymnasium-style environment such that it may be used by a meta-policy, i.e. a bandit that selects a policy (an arm), which is then used to sample a lower-level action that is fed to the original environment. In other words, the actions that the step method expects are meta-actions that select different arms. The lower-level actions (and their log-propensities) that are sampled internally are stored in the info dict returned by the step method.

Parameters:
  • env (gymnasium-style environment) – The original environment to be wrapped into a meta-policy env.

  • *arms (functions) – Callable objects that take a state observation \(s\) and return an action \(a\) (and optionally, log-propensity \(\log\pi(a|s)\)). See for example coax.Policy.__call__ or coax.Policy.mode.

classmethod class_name() → str

Returns the class name of the wrapper.

close()

Closes the wrapper and env.

render() → RenderFrame | list[RenderFrame] | None

Calls render() of the wrapped env. This method can be overridden to change the returned data.

reset()

Calls reset() of the wrapped env. This method can be overridden to change the returned data.

step(a_meta)

Calls step() of the wrapped env. This method can be overridden to change the returned data.

property action_space: spaces.Space[ActType] | spaces.Space[WrapperActType]

Returns the Env action_space, unless it has been overwritten, in which case the wrapper action_space is used.

property metadata: dict[str, Any]

Returns the Env metadata.

property np_random: Generator

Returns the Env np_random attribute.

property observation_space: spaces.Space[ObsType] | spaces.Space[WrapperObsType]

Returns the Env observation_space, unless it has been overwritten, in which case the wrapper observation_space is used.

property render_mode: str | None

Returns the Env render_mode.

property reward_range: tuple[SupportsFloat, SupportsFloat]

Returns the Env reward_range, unless it has been overwritten, in which case the wrapper reward_range is used.

property spec: EnvSpec | None

Returns the Env spec attribute.

property unwrapped: Env[ObsType, ActType]

Returns the base environment of the wrapper.

This will be the bare gymnasium.Env environment, underneath all layers of wrappers.