# coax/wrappers/_train_monitor.py

import os
import re
import datetime
import logging
import time
from collections import deque
from collections.abc import Mapping

import numpy as np
import lz4.frame
import cloudpickle as pickle
from gymnasium import Wrapper
from gymnasium.spaces import Discrete
from tensorboardX import SummaryWriter

from .._base.mixins import LoggerMixin
from ..utils import enable_logging


__all__ = (
    'TrainMonitor',
)


class StreamingSample:
    """Fixed-size, uniformly random sample (reservoir sample) of a stream of items."""

    def __init__(self, maxlen, random_seed=None):
        self._deque = deque(maxlen=maxlen)
        self._count = 0
        self._rnd = np.random.RandomState(random_seed)

    def reset(self):
        self._deque = deque(maxlen=self.maxlen)
        self._count = 0

    def append(self, obj):
        self._count += 1
        if len(self) < self.maxlen:
            self._deque.append(obj)
        elif self._rnd.rand() < self.maxlen / self._count:
            i = self._rnd.randint(self.maxlen)
            self._deque[i] = obj

    @property
    def values(self):
        return list(self._deque)  # shallow copy

    @property
    def maxlen(self):
        return self._deque.maxlen

    def __len__(self):
        return len(self._deque)

    def __bool__(self):
        return bool(self._deque)
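
# The class above is a reservoir sampler (Algorithm R): after any number of appends, every item
# seen so far is kept with equal probability maxlen / _count, so `values` is a uniform random
# sample of the whole stream. A minimal usage sketch (the sizes and seed below are arbitrary,
# illustrative values):
#
#     sample = StreamingSample(maxlen=3, random_seed=13)
#     for a in range(1000):
#         sample.append(a)
#     assert len(sample) == 3   # never holds more than `maxlen` items
#     print(sample.values)      # roughly uniform draw from all 1000 appended items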


class TrainMonitor(Wrapper, LoggerMixin):
    r"""
    Environment wrapper for monitoring the training process.

    This wrapper logs some diagnostics at the end of each episode and it also gives us some handy
    attributes (listed below).

    Parameters
    ----------
    env : gymnasium environment

        A gymnasium environment.

    tensorboard_dir : str, optional

        If provided, TrainMonitor will log all diagnostics to be viewed in tensorboard. To view
        these, point tensorboard to the same dir:

        .. code:: bash

            $ tensorboard --logdir {tensorboard_dir}

    tensorboard_write_all : bool, optional

        You may record your training metrics using the :attr:`record_metrics` method. The
        ``tensorboard_write_all`` flag specifies whether to pass the metrics on to tensorboard
        immediately (``True``) or to wait and average them across the episode (``False``). The
        default setting (``False``) prevents tensorboard from being flooded with logs.

    log_all_metrics : bool, optional

        Whether to log all metrics. If ``log_all_metrics=False``, only a reduced set of metrics
        is logged.

    smoothing : positive int, optional

        The number of observations for smoothing the metrics. We use the following smooth update
        rule:

        .. math::

            n\ &\leftarrow\ \min(\text{smoothing}, n + 1) \\
            x_\text{avg}\ &\leftarrow\ x_\text{avg}
                + \frac{x_\text{obs} - x_\text{avg}}{n}

    \*\*logger_kwargs

        Keyword arguments to pass on to :func:`coax.utils.enable_logging`.

    Attributes
    ----------
    T : positive int

        Global step counter. This is not reset by ``env.reset()``, use ``env.reset_global()``
        instead.

    ep : positive int

        Global episode counter. This is not reset by ``env.reset()``, use ``env.reset_global()``
        instead.

    t : positive int

        Step counter within an episode.

    G : float

        The return, i.e. the amount of reward accumulated from the start of the current episode.

    avg_G : float

        The average return, i.e. ``G`` smoothed over (roughly) the last ``smoothing`` episodes
        using the update rule above.

    dt_ms : float

        The average wall time of a single step, in milliseconds.

    """
    _COUNTER_ATTRS = (
        'T', 'ep', 't', 'G', 'avg_G', '_n_avg_G', '_ep_starttime', '_ep_metrics', '_ep_actions',
        '_tensorboard_dir', '_period')

    def __init__(
            self, env,
            tensorboard_dir=None,
            tensorboard_write_all=False,
            log_all_metrics=False,
            smoothing=10,
            **logger_kwargs):

        super().__init__(env)
        self.log_all_metrics = log_all_metrics
        self.tensorboard_write_all = tensorboard_write_all
        self.smoothing = float(smoothing)
        self.reset_global()
        enable_logging(**logger_kwargs)
        self.logger.setLevel(logger_kwargs.get('level', logging.INFO))
        self._init_tensorboard(tensorboard_dir)
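    # A minimal usage sketch (CartPole-v1 is just an example; any gymnasium env works, and the
    # tensorboard_dir and metric values are illustrative):
    #
    #     import gymnasium
    #     import coax
    #
    #     env = gymnasium.make('CartPole-v1')
    #     env = coax.wrappers.TrainMonitor(env, tensorboard_dir='./data/tensorboard')
    #
    #     s, info = env.reset()
    #     for _ in range(100):
    #         a = env.action_space.sample()
    #         s, r, done, truncated, info = env.step(a)
    #         env.record_metrics({'dummy/loss': 0.5})  # illustrative metric
    #         if done or truncated:
    #             s, info = env.reset()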
    def reset_global(self):
        r"""Reset the global counters, not just the episodic ones."""
        self.T = 0
        self.ep = 0
        self.t = 0
        self.G = 0.0
        self.avg_G = 0.0
        self._n_avg_G = 0.0
        self._ep_starttime = time.time()
        self._ep_metrics = {}
        self._ep_actions = StreamingSample(maxlen=1000)
        self._period = {'T': {}, 'ep': {}}
    def reset(self, **kwargs):
        # write logs from previous episode:
        if self.ep:
            self._write_episode_logs()

        # increment global counters:
        self.T += 1
        self.ep += 1

        # reset episodic counters:
        self.t = 0
        self.G = 0.0
        self._ep_starttime = time.time()
        self._ep_metrics = {}
        self._ep_actions.reset()

        return self.env.reset(**kwargs)
    @property
    def dt_ms(self):
        if self.t <= 0:
            return np.nan
        return 1000 * (time.time() - self._ep_starttime) / self.t

    @property
    def avg_r(self):
        if self.t <= 0:
            return np.nan
        return self.G / self.t
    def step(self, a):
        self._ep_actions.append(a)
        s_next, r, done, truncated, info = self.env.step(a)
        if info is None:
            info = {}
        info['monitor'] = {'T': self.T, 'ep': self.ep}
        self.t += 1
        self.T += 1
        self.G += r
        if done or truncated:
            # smooth update of the average return (see the update rule in the class docstring):
            if self._n_avg_G < self.smoothing:
                self._n_avg_G += 1.
            self.avg_G += (self.G - self.avg_G) / self._n_avg_G

        return s_next, r, done, truncated, info
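    # Worked example of the avg_G update, assuming smoothing=10 and hypothetical episode
    # returns; while n < smoothing this is the plain running mean, afterwards it behaves like
    # an exponential moving average over roughly the last `smoothing` episodes:
    #
    #     avg_G, n = 0.0, 0.0
    #     for G in (1.0, 2.0, 3.0):      # hypothetical returns
    #         n = min(10.0, n + 1.0)     # smoothing=10
    #         avg_G += (G - avg_G) / n   # -> 1.0, 1.5, 2.0 (plain mean while n < 10)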
    def record_metrics(self, metrics):
        r"""
        Record metrics during the training process.

        These are used to print more diagnostics.

        Parameters
        ----------
        metrics : dict

            A dict of metrics, of type ``{name <str>: value <float>}``.

        """
        if not isinstance(metrics, Mapping):
            raise TypeError("metrics must be a Mapping")

        # write metrics to tensorboard
        if self.tensorboard is not None and self.tensorboard_write_all:
            for name, metric in metrics.items():
                self.tensorboard.add_scalar(str(name), float(metric), global_step=self.T)

        # compute episode averages
        for k, v in metrics.items():
            if k not in self._ep_metrics:
                self._ep_metrics[k] = v, 1.
            else:
                x, n = self._ep_metrics[k]
                self._ep_metrics[k] = x + v, n + 1
    def get_metrics(self):
        r"""
        Return the current state of the metrics.

        Returns
        -------
        metrics : dict

            A dict of metrics, of type ``{name <str>: value <float>}``.

        """
        return {k: float(x) / n for k, (x, n) in self._ep_metrics.items()}
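    # For example (illustrative metric name and values), recording the same metric twice within
    # an episode averages it:
    #
    #     env.record_metrics({'q/loss': 2.0})
    #     env.record_metrics({'q/loss': 4.0})
    #     env.get_metrics()   # -> {'q/loss': 3.0}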
    def period(self, name, T_period=None, ep_period=None):
        if T_period is not None:
            T_period = int(T_period)
            assert T_period > 0
            if name not in self._period['T']:
                self._period['T'][name] = 1
            if self.T >= self._period['T'][name] * T_period:
                self._period['T'][name] += 1
                return True
            return self.period(name, None, ep_period)
        if ep_period is not None:
            ep_period = int(ep_period)
            assert ep_period > 0
            if name not in self._period['ep']:
                self._period['ep'][name] = 1
            if self.ep >= self._period['ep'][name] * ep_period:
                self._period['ep'][name] += 1
                return True
        return False

    @property
    def tensorboard(self):
        if not hasattr(self, '_tensorboard'):
            assert self._tensorboard_dir is not None
            self._tensorboard = SummaryWriter(self._tensorboard_dir)
        return self._tensorboard

    def _init_tensorboard(self, tensorboard_dir):
        if tensorboard_dir is None:
            self._tensorboard_dir = None
            self._tensorboard = None
            return

        # append timestamp to disambiguate instances
        if not re.match(r'.*/\d{8}_\d{6}$', tensorboard_dir):
            tensorboard_dir = os.path.join(
                tensorboard_dir, datetime.datetime.now().strftime('%Y%m%d_%H%M%S'))

        # only set/update if necessary
        if tensorboard_dir != getattr(self, '_tensorboard_dir', None):
            self._tensorboard_dir = tensorboard_dir
            if hasattr(self, '_tensorboard'):
                del self._tensorboard

    def _write_episode_logs(self):
        metrics = (
            f'{k:s}: {float(x) / n:.3g}' for k, (x, n) in self._ep_metrics.items()
            if (self.log_all_metrics
                or str(k).endswith('/loss')
                or str(k).endswith('/entropy')
                or str(k).endswith('/kl_div')
                or str(k).startswith('throughput/')))
        self.logger.info(
            ',\t'.join((
                f'ep: {self.ep:d}',
                f'T: {self.T:,d}',
                f'G: {self.G:.3g}',
                f'avg_r: {self.avg_r:.3g}',
                f'avg_G: {self.avg_G:.3g}',
                f't: {self.t:d}',
                f'dt: {self.dt_ms:.3f}ms',
                *metrics)))

        if self.tensorboard is not None:
            metrics = {
                'episode/episode': self.ep,
                'episode/avg_reward': self.avg_r,
                'episode/return': self.G,
                'episode/steps': self.t,
                'episode/avg_step_duration_ms': self.dt_ms}
            for name, metric in metrics.items():
                self.tensorboard.add_scalar(str(name), float(metric), global_step=self.T)
            if self._ep_actions:
                if isinstance(self.action_space, Discrete):
                    bins = np.arange(self.action_space.n + 1)
                else:
                    bins = 'auto'  # see also: np.histogram_bin_edges.__doc__
                self.tensorboard.add_histogram(
                    tag='actions', values=self._ep_actions.values, global_step=self.T, bins=bins)
            if self._ep_metrics and not self.tensorboard_write_all:
                for k, (x, n) in self._ep_metrics.items():
                    self.tensorboard.add_scalar(str(k), float(x) / n, global_step=self.T)
            self.tensorboard.flush()

    def __getstate__(self):
        state = self.__dict__.copy()  # shallow copy
        if '_tensorboard' in state:
            del state['_tensorboard']  # remove reference to non-pickleable attr
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self._init_tensorboard(state['_tensorboard_dir'])
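    # A sketch of how the `period` helper above may be used to throttle side effects in a
    # training loop (the names and intervals are illustrative):
    #
    #     if env.period('checkpoint', T_period=10_000):
    #         ...   # runs at most once every 10,000 global steps
    #     if env.period('eval', ep_period=50):
    #         ...   # runs at most once every 50 episodes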
    def get_counters(self):
        r"""
        Get the current state of all internal counters.

        Returns
        -------
        counters : dict

            The dict that contains the counters.

        """
        return {k: getattr(self, k) for k in self._COUNTER_ATTRS}
    def set_counters(self, counters):
        r"""
        Restore the state of all internal counters.

        Parameters
        ----------
        counters : dict

            The dict that contains the counters.

        """
        if not (isinstance(counters, dict) and set(counters) == set(self._COUNTER_ATTRS)):
            raise TypeError(f"invalid counters dict: {counters}")
        self.__setstate__(counters)
    def save_counters(self, filepath):
        r"""
        Store the current state of all internal counters.

        Parameters
        ----------
        filepath : str

            The checkpoint file path.

        """
        counters = self.get_counters()
        os.makedirs(os.path.dirname(filepath) or '.', exist_ok=True)
        with lz4.frame.open(filepath, 'wb') as f:
            f.write(pickle.dumps(counters))
    def load_counters(self, filepath):
        r"""
        Restore the state of all internal counters.

        Parameters
        ----------
        filepath : str

            The checkpoint file path.

        """
        with lz4.frame.open(filepath, 'rb') as f:
            counters = pickle.loads(f.read())
        self.set_counters(counters)
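    # Round-trip sketch for counter checkpoints (the path and extension are illustrative; any
    # writable file path works):
    #
    #     env.save_counters('./data/checkpoints/counters.pkl.lz4')
    #     ...
    #     env.load_counters('./data/checkpoints/counters.pkl.lz4')  # resumes T, ep, avg_G, ...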