Source code for coax.utils._misc

import os
import time
import logging
from importlib import reload, import_module
from types import ModuleType

import jax.numpy as jnp
import numpy as onp
import pandas as pd
import lz4.frame
import cloudpickle as pickle
from PIL import Image


# Public names exported by ``from coax.utils._misc import *`` (kept in alphabetical order).
__all__ = (
    # decorators & logging
    'docstring', 'enable_logging',
    # (de)serialization
    'dump', 'dumps', 'load', 'loads',
    # environment helpers
    'generate_gif', 'get_env_attr', 'getattr_safe', 'has_env_attr',
    # type checks
    'is_policy', 'is_qfunction', 'is_reward_function', 'is_stochastic',
    'is_transition_model', 'is_vfunction',
    # misc
    'pretty_repr', 'pretty_print', 'reload_recursive', 'render_episode',
)


def docstring(obj):
    r'''

    Decorator that copies ``obj.__doc__`` onto the decorated object.

    Parameters
    ----------
    obj : object

        The object whose docstring you wish to copy onto the wrapped object.

    Returns
    -------
    decorator : callable

        A decorator that sets ``__doc__`` on its argument and returns that same argument.

    Examples
    --------
    >>> def f(x):
    ...     """Some docstring"""
    ...     return x * x
    ...
    >>> def g(x):
    ...     return 13 - x
    ...
    >>> g.__doc__ = f.__doc__

    This can be abbreviated by:

    >>> @docstring(f)
    ... def g(x):
    ...     return 13 - x
    ...

    '''
    def _copy_doc(target):
        # read obj.__doc__ at decoration time, exactly like the manual assignment would
        target.__doc__ = obj.__doc__
        return target

    return _copy_doc
def enable_logging(name=None, level=logging.INFO, output_filepath=None, output_level=None):
    r"""

    Enable logging output.

    At its core, this executes the following two lines of code:

    .. code:: python

        import logging
        logging.basicConfig(level=logging.INFO)

    Parameters
    ----------
    name : str, optional

        Name of the process that is logging. This can be set to whatever you like. If given, it
        is prepended to every log line.

    level : int, optional

        Logging level for the default :py:class:`StreamHandler <logging.StreamHandler>`. The
        default setting is ``level=logging.INFO`` (which is 20). If you'd like to see more verbose
        logging messages you might set ``level=logging.DEBUG``.

    output_filepath : str, optional

        If provided, a :py:class:`FileHandler <logging.FileHandler>` will be added to the root
        logger via:

        .. code:: python

            file_handler = logging.FileHandler(output_filepath)
            logging.getLogger('').addHandler(file_handler)

        The file handler uses the same message format as the console output. Missing parent
        directories are created.

    output_level : int, optional

        Logging level for the :py:class:`FileHandler <logging.FileHandler>`. If left unspecified,
        this defaults to ``level``, i.e. the same level as the default
        :py:class:`StreamHandler <logging.StreamHandler>`.

    Notes
    -----
    :func:`logging.basicConfig` does nothing if the root logger already has handlers, so in that
    case neither the console format nor the root level is changed by this function.

    """
    if name is None:
        fmt = '[%(name)s|%(levelname)s] %(message)s'
    else:
        fmt = f'[{name}|%(name)s|%(levelname)s] %(message)s'
    logging.basicConfig(level=level, format=fmt)

    if output_filepath is not None:
        os.makedirs(os.path.dirname(output_filepath) or '.', exist_ok=True)
        fh = logging.FileHandler(output_filepath)
        fh.setLevel(level if output_level is None else output_level)
        # without an explicit formatter the file would only receive the bare message,
        # dropping the process/logger/level prefix that the console output has
        fh.setFormatter(logging.Formatter(fmt))
        logging.getLogger('').addHandler(fh)
def dump(obj, filepath):
    r"""

    Save an object to disk.

    Parameters
    ----------
    obj : object

        Any python object.

    filepath : str

        Where to store the instance. Missing parent directories are created.

    Warning
    -------

    References between objects are only preserved if they are stored as part of a single object,
    for example:

    .. code:: python

        # b has a reference to a
        a = [13]
        b = {'a': a}

        # references preserved
        dump((a, b), 'ab.pkl.lz4')
        a_new, b_new = load('ab.pkl.lz4')
        b_new['a'].append(7)
        print(b_new)  # {'a': [13, 7]}
        print(a_new)  # [13, 7]          <-- updated

        # references not preserved
        dump(a, 'a.pkl.lz4')
        dump(b, 'b.pkl.lz4')
        a_new = load('a.pkl.lz4')
        b_new = load('b.pkl.lz4')
        b_new['a'].append(7)
        print(b_new)  # {'a': [13, 7]}
        print(a_new)  # [13]             <-- not updated!!

    Therefore, the safest way to create checkpoints is to store the entire state as a single
    object like a dict or a tuple.

    """
    dirpath = os.path.dirname(filepath)
    if dirpath:
        os.makedirs(dirpath, exist_ok=True)

    # Serialize *before* opening the file: opening with 'wb' truncates the target, so a pickling
    # error raised inside the with-block would wipe out an existing checkpoint.
    payload = pickle.dumps(obj)
    with lz4.frame.open(filepath, 'wb') as f:
        f.write(payload)
def dumps(obj):
    r"""

    Serialize an object to an lz4-compressed pickle byte-string.

    Parameters
    ----------
    obj : object

        Any python object.

    Returns
    -------
    s : bytes

        An lz4-compressed pickle byte-string.

    Warning
    -------

    References between objects are only preserved if they are stored as part of a single object,
    for example:

    .. code:: python

        # b has a reference to a
        a = [13]
        b = {'a': a}

        # references preserved
        s = dumps((a, b))
        a_new, b_new = loads(s)
        b_new['a'].append(7)
        print(b_new)  # {'a': [13, 7]}
        print(a_new)  # [13, 7]          <-- updated

        # references not preserved
        s_a = dumps(a)
        s_b = dumps(b)
        a_new = loads(s_a)
        b_new = loads(s_b)
        b_new['a'].append(7)
        print(b_new)  # {'a': [13, 7]}
        print(a_new)  # [13]             <-- not updated!!

    Therefore, the safest way to create checkpoints is to store the entire state as a single
    object like a dict or a tuple.

    """
    raw = pickle.dumps(obj)
    compressed = lz4.frame.compress(raw)
    return compressed
def load(filepath):
    r"""

    Load an object from a file that was created by :func:`dump(obj, filepath) <dump>`.

    Parameters
    ----------
    filepath : str

        File to load.

    Returns
    -------
    obj : object

        The unpickled object.

    Warning
    -------

    This unpickles the file contents, which can execute arbitrary code. Only load files from
    sources you trust.

    """
    with lz4.frame.open(filepath, 'rb') as fh:
        blob = fh.read()
        return pickle.loads(blob)
def loads(s):
    r"""

    Load an object from a byte-string that was created by :func:`dumps(obj) <dumps>`.

    Parameters
    ----------
    s : bytes

        An lz4-compressed pickle byte-string.

    Returns
    -------
    obj : object

        The unpickled object.

    Warning
    -------

    This unpickles the decompressed data, which can execute arbitrary code. Only load
    byte-strings from sources you trust.

    """
    return pickle.loads(lz4.frame.decompress(s))
def _reload(module, reload_all, reloaded, logger): if isinstance(module, ModuleType): module_name = module.__name__ elif isinstance(module, str): module_name, module = module, import_module(module) else: raise TypeError( "'module' must be either a module or str; " f"got: {module.__class__.__name__}") for attr_name in dir(module): attr = getattr(module, attr_name) check = ( # is it a module? isinstance(attr, ModuleType) # has it already been reloaded? and attr.__name__ not in reloaded # is it a proper submodule? (or just reload all) and (reload_all or attr.__name__.startswith(module_name)) ) if check: _reload(attr, reload_all, reloaded, logger) logger.debug(f"reloading module: {module_name}") reload(module) reloaded.add(module_name)
def reload_recursive(module, reload_external_modules=False):
    """

    Recursively reload a module, reloading the modules it references before the module itself.

    Parameters
    ----------
    module : ModuleType or str

        The module (or its dotted name) to reload.

    reload_external_modules : bool, optional

        Whether to reload all referenced modules, including external ones which aren't submodules
        of ``module``.

    """
    visited = set()
    log = logging.getLogger('coax.utils.reload_recursive')
    _reload(module, reload_external_modules, visited, log)
def render_episode(env, policy=None, step_delay_ms=0):
    r"""

    Run a single episode with env.render() calls with each time step.

    This uses the gymnasium API: ``env.reset()`` returns ``(observation, info)`` and
    ``env.step(a)`` returns ``(observation, reward, terminated, truncated, info)``. The episode
    ends when it is either terminated or truncated.

    Parameters
    ----------
    env : gymnasium environment

        A gymnasium environment.

    policy : callable, optional

        A policy object that is used to pick actions: ``a = policy(s)``. If left unspecified,
        we'll just take random actions instead, i.e. ``a = env.action_space.sample()``.

    step_delay_ms : non-negative float

        The number of milliseconds to wait between consecutive timesteps. This can be used to slow
        down the rendering. After the episode ends we wait five times this amount.

    """
    from ..wrappers import TrainMonitor
    if isinstance(env, TrainMonitor):
        env = env.env  # unwrap to strip off TrainMonitor

    s, info = env.reset()
    env.render()

    while True:
        a = env.action_space.sample() if policy is None else policy(s)
        s_next, r, done, truncated, info = env.step(a)
        env.render()
        time.sleep(step_delay_ms / 1e3)
        if done or truncated:
            break
        s = s_next

    # linger on the final frame a bit longer
    time.sleep(5 * step_delay_ms / 1e3)
def has_env_attr(env, attr, max_depth=100):
    r"""

    Check if a potentially wrapped environment has a given attribute.

    The wrapper chain is walked from the outside in by following ``.env``, looking at no more
    than ``max_depth`` objects in total.

    Parameters
    ----------
    env : gymnasium environment

        A potentially wrapped environment.

    attr : str

        The attribute name.

    max_depth : positive int, optional

        The maximum depth of wrappers to traverse.

    Returns
    -------
    bool

        Whether any of the inspected (wrapped) envs has the attribute.

    """
    node = env
    for _ in range(max_depth):
        if hasattr(node, attr):
            return True
        if not hasattr(node, 'env'):
            # innermost env reached without finding the attribute
            return False
        node = node.env
    return False
def get_env_attr(env, attr, default='__ERROR__', max_depth=100):
    r"""

    Get the given attribute from a potentially wrapped environment.

    Note that the wrapped envs are traversed from the outside in. Once the attribute is found, the
    search stops. This means that an inner wrapped env may carry the same (possibly conflicting)
    attribute. This situation is *not* resolved by this function.

    Parameters
    ----------
    env : gymnasium environment

        A potentially wrapped environment.

    attr : str

        The attribute name.

    default : object, optional

        Value to return if the attribute isn't found. If left unspecified, an
        :class:`AttributeError` is raised instead.

    max_depth : positive int, optional

        The maximum depth of wrappers to traverse.

    Returns
    -------
    value : object

        The attribute value from the outermost env that has it, or ``default``.

    Raises
    ------
    AttributeError

        If the attribute isn't found and no ``default`` was given.

    """
    e = env
    for _ in range(max_depth):
        if hasattr(e, attr):
            return getattr(e, attr)
        if not hasattr(e, 'env'):
            break
        e = e.env

    # Only compare against the string sentinel when ``default`` is itself a str: ``==`` may be
    # elementwise (numpy/jax arrays) or may raise for arbitrary objects.
    if isinstance(default, str) and default == '__ERROR__':
        raise AttributeError("env is missing attribute: {}".format(attr))
    return default
def _to_gif_frame(rgb_array, resize_to=None):
    """Convert an RGB array into a palette-mode PIL image, optionally resized to ``(w, h)``."""
    frame = Image.fromarray(rgb_array)
    frame = frame.convert('P', palette=Image.ADAPTIVE)
    if resize_to is not None:
        frame = frame.resize(resize_to)
    return frame


def generate_gif(env, filepath, policy=None, resize_to=None, duration=50, max_episode_steps=None):
    r"""

    Store a gif from the episode frames.

    The first frame shows the state right after ``env.reset()``. Each subsequent frame shows the
    state after one call to ``env.step()``.

    Parameters
    ----------
    env : gymnasium environment

        The environment to record from. It must render RGB arrays, e.g. by being created with
        ``render_mode='rgb_array'``.

    filepath : str

        Location of the output gif file. Missing parent directories are created.

    policy : callable, optional

        A policy object that is used to pick actions: ``a = policy(s)``. If left unspecified,
        we'll just take random actions instead, i.e. ``a = env.action_space.sample()``.

    resize_to : tuple of ints, optional

        The size of the output frames, ``(width, height)``. Notice the ordering: first **width**,
        then **height**. This is the convention PIL uses.

    duration : float, optional

        Time between frames in the animated gif, in milliseconds.

    max_episode_steps : int, optional

        The maximum number of steps in the episode. If left unspecified, we'll attempt to get the
        value from ``env.spec.max_episode_steps`` and if that fails (or is ``None``) we default to
        10000.

    Raises
    ------
    TypeError

        If ``resize_to`` is given but isn't a tuple of size 2.

    RuntimeError

        If the environment doesn't render RGB arrays.

    """
    logger = logging.getLogger('generate_gif')

    # validate once, up front, instead of on every step
    if resize_to is not None and not (isinstance(resize_to, tuple) and len(resize_to) == 2):
        raise TypeError("expected a tuple of size 2, resize_to=(w, h)")

    # env.spec may be missing or None, and spec.max_episode_steps may itself be None
    max_episode_steps = max_episode_steps \
        or getattr(getattr(env, 'spec', None), 'max_episode_steps', None) \
        or 10000

    from ..wrappers import TrainMonitor
    if isinstance(env, TrainMonitor):
        env = env.env  # unwrap to strip off TrainMonitor

    s, info = env.reset()

    # render the initial state; this also checks that rendering produces RGB arrays
    rgb = env.render()
    if not (getattr(env, 'render_mode', None) == 'rgb_array' or isinstance(rgb, onp.ndarray)):
        raise RuntimeError("Cannot generate GIF if env.render_mode != 'rgb_array'.")
    frames = [_to_gif_frame(rgb, resize_to)]

    # collect one frame per step
    for _ in range(max_episode_steps):
        a = env.action_space.sample() if policy is None else policy(s)
        s, r, done, truncated, info = env.step(a)
        frames.append(_to_gif_frame(env.render(), resize_to))
        if done or truncated:
            break

    # generate gif
    os.makedirs(os.path.dirname(filepath) or '.', exist_ok=True)
    frames[0].save(
        fp=filepath, format='GIF', append_images=frames[1:], save_all=True,
        duration=duration, loop=0)

    logger.info("recorded episode to: {}".format(filepath))
def is_transition_model(obj):
    r"""

    Check whether an object is a dynamics (transition) model.

    Parameters
    ----------
    obj

        Object to check.

    Returns
    -------
    bool

        Whether ``obj`` is a :class:`TransitionModel` or a :class:`StochasticTransitionModel`.

    """
    # deferred imports: importing these at module level would create a circular dependence
    from .._core.transition_model import TransitionModel
    from .._core.stochastic_transition_model import StochasticTransitionModel
    model_types = (TransitionModel, StochasticTransitionModel)
    return isinstance(obj, model_types)
def is_reward_function(obj):
    r"""

    Check whether an object is a reward function.

    Parameters
    ----------
    obj

        Object to check.

    Returns
    -------
    bool

        Whether ``obj`` is a reward function, i.e. a :class:`RewardFunction` or a
        :class:`StochasticRewardFunction`.

    """
    # import at runtime to avoid circular dependence
    from .._core.reward_function import RewardFunction
    from .._core.stochastic_reward_function import StochasticRewardFunction
    return isinstance(obj, (RewardFunction, StochasticRewardFunction))
def is_vfunction(obj):
    r"""

    Check whether an object is a :class:`state value function <coax.V>`, or V-function.

    Parameters
    ----------
    obj

        Object to check.

    Returns
    -------
    bool

        Whether ``obj`` is a V-function (either a :class:`V` or a :class:`StochasticV`).

    """
    # deferred imports: importing these at module level would create a circular dependence
    from .._core.v import V
    from .._core.stochastic_v import StochasticV
    vfunction_types = (V, StochasticV)
    return isinstance(obj, vfunction_types)
def is_qfunction(obj):
    r"""

    Check whether an object is a :class:`state-action value function <coax.Q>`, or Q-function.

    Parameters
    ----------
    obj

        Object to check.

    Returns
    -------
    bool

        Whether ``obj`` is a Q-function, i.e. a :class:`Q`, a :class:`StochasticQ` or a
        :class:`SuccessorStateQ`.

    """
    # import at runtime to avoid circular dependence
    from .._core.q import Q
    from .._core.stochastic_q import StochasticQ
    from .._core.successor_state_q import SuccessorStateQ
    return isinstance(obj, (Q, StochasticQ, SuccessorStateQ))
def is_stochastic(obj):
    r"""

    Check whether an object is a stochastic function approximator.

    Parameters
    ----------
    obj

        Object to check.

    Returns
    -------
    bool

        Whether ``obj`` is one of :class:`Policy`, :class:`StochasticV`, :class:`StochasticQ`,
        :class:`StochasticRewardFunction` or :class:`StochasticTransitionModel`.

    """
    # deferred imports: importing these at module level would create a circular dependence
    from .._core.policy import Policy
    from .._core.stochastic_v import StochasticV
    from .._core.stochastic_q import StochasticQ
    from .._core.stochastic_reward_function import StochasticRewardFunction
    from .._core.stochastic_transition_model import StochasticTransitionModel
    stochastic_types = (
        Policy,
        StochasticV,
        StochasticQ,
        StochasticRewardFunction,
        StochasticTransitionModel,
    )
    return isinstance(obj, stochastic_types)
def is_policy(obj):
    r"""

    Check whether an object is a :doc:`policy <policies>`.

    Parameters
    ----------
    obj

        Object to check.

    Returns
    -------
    bool

        Whether ``obj`` is a :class:`Policy`, an :class:`EpsilonGreedy` or a
        :class:`BoltzmannPolicy`.

    """
    # deferred imports: importing these at module level would create a circular dependence
    from .._core.policy import Policy
    from .._core.value_based_policy import EpsilonGreedy, BoltzmannPolicy
    policy_types = (Policy, EpsilonGreedy, BoltzmannPolicy)
    return isinstance(obj, policy_types)
def pretty_repr(o, d=0):
    r"""

    Generate pretty :func:`repr` (string representations).

    Arrays (numpy, jax, :class:`pandas.Index`) are summarized by shape, dtype and, when computable,
    min/median/max. Pandas objects, namedtuple-like objects (anything with ``_asdict``), tuples,
    lists and mappings (anything with ``items``) are expanded recursively with one entry per line.
    Anything else falls back to :func:`repr`.

    Parameters
    ----------
    o : object

        Any object.

    d : int, optional

        The depth of the recursion. This is used to determine the indentation level in recursive
        calls, so we typically keep this 0.

    Returns
    -------
    pretty_repr : str

        A nicely formatted string representation of ``o``.

    """
    i = " "  # indentation string
    indent = i * (d + 1)

    if isinstance(o, (jnp.ndarray, onp.ndarray, pd.Index)):
        try:
            summary = f", min={onp.min(o):.3g}, median={onp.median(o):.3g}, max={onp.max(o):.3g}"
        except Exception:
            # e.g. empty arrays or non-numeric dtypes: just omit the summary
            summary = ""
        return f"array(shape={o.shape}, dtype={str(o.dtype)}{summary:s})"

    if isinstance(o, (pd.Series, pd.DataFrame)):
        sep = ',\n' + indent
        items = zip(('index', 'data'), (o.index, o.values))
        # open with a bare newline (not ``sep``), otherwise the output starts with "Series(,"
        body = '\n' + indent + sep.join(f"{k}={pretty_repr(v, d + 1)}" for k, v in items)
        return f"{type(o).__name__}({body})"

    if hasattr(o, '_asdict'):
        sep = '\n' + indent
        body = sep + sep.join(f"{k}={pretty_repr(v, d + 1)}" for k, v in o._asdict().items())
        return f"{type(o).__name__}({body})"

    if isinstance(o, tuple):
        if not o:
            return "()"
        sep = ',\n' + indent
        body = '\n' + indent + sep.join(f"{pretty_repr(v, d + 1)}" for v in o)
        return f"({body})"

    if isinstance(o, list):
        if not o:
            return "[]"
        sep = ',\n' + indent
        body = '\n' + indent + sep.join(f"{pretty_repr(v, d + 1)}" for v in o)
        return f"[{body}]"

    if hasattr(o, 'items'):
        entries = list(o.items())
        if not entries:
            return "{}"
        sep = ',\n' + indent
        body = '\n' + indent + sep.join(
            f"{repr(k)}: {pretty_repr(v, d + 1)}" for k, v in entries)
        return f"{{{body}}}"

    return repr(o)
def pretty_print(obj):
    r"""

    Print the output of :func:`pretty_repr(obj) <coax.utils.pretty_repr>` to stdout.

    Parameters
    ----------
    obj : object

        Any object.

    """
    text = pretty_repr(obj)
    print(text)
def getattr_safe(obj, name, default=None):
    """
    A safe implementation of :func:`getattr <python3:getattr>`.

    If an attr exists, but calling getattr raises an error, this implementation will silence the
    error and return the ``default`` value.

    Parameters
    ----------
    obj : object

        Any object.

    name : str

        The name of the attribute.

    default : object, optional

        The default value to return if getattr fails.

    Returns
    -------
    attr : object

        The attribute ``obj.name`` or ``default``.

    """
    try:
        return getattr(obj, name, default)
    except Exception:
        # deliberate best-effort: e.g. a property whose getter raises something other than
        # AttributeError
        return default