Source code for coax._core.worker

import time
import inspect
from abc import ABC, abstractmethod
from typing import Optional

import gymnasium
from jax.lib.xla_bridge import get_backend

from ..typing import Policy
from ..wrappers import TrainMonitor
from ..reward_tracing._base import BaseRewardTracer
from ..experience_replay._base import BaseReplayBuffer


__all__ = (
    'Worker',
)


class WorkerError(Exception):
    pass


class Worker(ABC):
    r"""
    The base class for defining workers as part of a distributed agent.

    Parameters
    ----------
    env : gymnasium.Env | str | function
        Specifies the gymnasium-style environment by either passing the env itself
        (gymnasium.Env), its name (str), or a function that generates the environment.
    param_store : Worker, optional
        A distributed agent is presumed to have one worker that plays the role of a parameter
        store. To define the parameter-store worker itself, you must leave
        :code:`param_store=None`. For other worker roles, however, :code:`param_store` must be
        provided.
    pi : Policy, optional
        The behavior policy that is used by rollout workers to generate experience.
    tracer : RewardTracer, optional
        The reward tracer that is used by rollout workers.
    buffer : ReplayBuffer, optional
        The experience-replay buffer that is populated by rollout workers and sampled from by
        learners.
    buffer_warmup : int, optional
        The warmup period for the experience replay buffer, i.e. the minimal number of
        transitions that need to be stored in the replay buffer before we start sampling from
        it.
    name : str, optional
        A human-readable identifier of the worker.

    """
    pi: Optional[Policy] = None
    tracer: Optional[BaseRewardTracer] = None
    buffer: Optional[BaseReplayBuffer] = None
    buffer_warmup: Optional[int] = None

    def __init__(
            self, env,
            param_store=None,
            pi=None,
            tracer=None,
            buffer=None,
            buffer_warmup=None,
            name=None):

        # import inline to avoid hard dependency on ray
        import ray
        import ray.actor
        self.__ray = ray

        self.env = _check_env(env, name)
        self.param_store = param_store
        self.pi = pi
        self.tracer = tracer
        self.buffer = buffer
        self.buffer_warmup = buffer_warmup
        self.name = name

        self.env.logger.info(f"JAX platform name: '{get_backend().platform}'")
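    # Illustrative note: a concrete subclass is typically instantiated once with
    # param_store=None (that instance *is* the parameter store) and several more times as
    # rollout workers and learners that are handed a reference to it; see the hedged
    # sketch after this class for a hypothetical subclass wired up this way.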
    @abstractmethod
    def get_state(self):
        r"""
        Get the internal state that is shared between workers.

        Returns
        -------
        state : object
            The internal state. This will be consumed by :func:`set_state(state) <set_state>`.

        """
        pass
    @abstractmethod
    def set_state(self, state):
        r"""
        Set the internal state that is shared between workers.

        Parameters
        ----------
        state : object
            The internal state, as returned by :func:`get_state`.

        """
        pass
    @abstractmethod
    def trace(self, s, a, r, done, logp=0.0, w=1.0):
        r"""
        This implements the reward-tracing step of a single, raw transition.

        Parameters
        ----------
        s : state observation
            A single state observation.
        a : action
            A single action.
        r : float
            A single observed reward.
        done : bool
            Whether the episode has finished.
        logp : float, optional
            The log-propensity :math:`\log\pi(a|s)`.
        w : float, optional
            Sample weight associated with the given state-action pair.

        """
        pass
    @abstractmethod
    def learn(self, transition_batch):
        r""" Update the model parameters given a transition batch. """
        pass
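    # Note for implementers: learn() is expected to return a dict of scalar metrics, since
    # learn_loop() below forwards its return value straight to push_metrics(). A typical body
    # is a single call such as `return self.qlearning.update(transition_batch)`, where
    # `qlearning` is whatever updater the subclass attached (a hypothetical name, not part of
    # this module).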
    def rollout(self):
        assert self.pi is not None
        s, info = self.env.reset()
        for t in range(self.env.spec.max_episode_steps):
            a, logp = self.pi(s, return_logp=True)
            s_next, r, done, truncated, info = self.env.step(a)

            self.trace(s, a, r, done or truncated, logp)

            if done or truncated:
                break

            s = s_next

    def rollout_loop(self, max_total_steps, reward_threshold=None):
        reward_threshold = _check_reward_threshold(reward_threshold, self.env)
        T_global = self.pull_getattr('env.T')
        while T_global < max_total_steps and self.env.avg_G < reward_threshold:
            self.pull_state()
            self.rollout()
            metrics = self.pull_metrics()
            metrics['throughput/rollout_loop'] = 1000 / self.env.dt_ms
            metrics['episode/T_global'] = T_global = self.pull_getattr('env.T') + self.env.t
            self.push_setattr('env.T', T_global)  # not exactly thread-safe, but that's okay
            self.env.record_metrics(metrics)

    def learn_loop(self, max_total_steps, batch_size=32):
        throughput = 0.
        while self.pull_getattr('env.T') < max_total_steps:
            t_start = time.time()
            self.pull_state()
            metrics = self.learn(self.buffer_sample(batch_size=batch_size))
            metrics['throughput/learn_loop'] = throughput
            self.push_state()
            self.push_metrics(metrics)
            throughput = batch_size / (time.time() - t_start)

    def buffer_len(self):
        if self.param_store is None:
            assert self.buffer is not None
            len_ = len(self.buffer)
        elif isinstance(self.param_store, self.__ray.actor.ActorHandle):
            len_ = self.__ray.get(self.param_store.buffer_len.remote())
        else:
            len_ = self.param_store.buffer_len()
        return len_

    def buffer_add(self, transition_batch, Adv=None):
        if self.param_store is None:
            assert self.buffer is not None
            if 'Adv' in inspect.signature(self.buffer.add).parameters:  # duck typing
                self.buffer.add(transition_batch, Adv=Adv)
            else:
                self.buffer.add(transition_batch)
        elif isinstance(self.param_store, self.__ray.actor.ActorHandle):
            self.__ray.get(self.param_store.buffer_add.remote(transition_batch, Adv))
        else:
            self.param_store.buffer_add(transition_batch, Adv)

    def buffer_update(self, transition_batch_idx, Adv):
        if self.param_store is None:
            assert self.buffer is not None
            self.buffer.update(transition_batch_idx, Adv=Adv)
        elif isinstance(self.param_store, self.__ray.actor.ActorHandle):
            self.__ray.get(self.param_store.buffer_update.remote(transition_batch_idx, Adv))
        else:
            self.param_store.buffer_update(transition_batch_idx, Adv)

    def buffer_sample(self, batch_size=32):
        buffer_warmup = max(self.buffer_warmup or 0, batch_size)
        wait_secs = 1 / 1024.
        buffer_len = self.buffer_len()
        while buffer_len < buffer_warmup:
            self.env.logger.debug(
                f"buffer insufficiently populated: {buffer_len}/{buffer_warmup}; "
                f"waiting for {wait_secs}s")
            time.sleep(wait_secs)
            wait_secs = min(30, wait_secs * 2)  # wait at most 30s between tries
            buffer_len = self.buffer_len()

        if self.param_store is None:
            assert self.buffer is not None
            transition_batch = self.buffer.sample(batch_size=batch_size)
        elif isinstance(self.param_store, self.__ray.actor.ActorHandle):
            transition_batch = self.__ray.get(
                self.param_store.buffer_sample.remote(batch_size=batch_size))
        else:
            transition_batch = self.param_store.buffer_sample(batch_size=batch_size)
        assert transition_batch is not None
        return transition_batch

    def pull_state(self):
        assert self.param_store is not None, "cannot call pull_state on param_store itself"
        if isinstance(self.param_store, self.__ray.actor.ActorHandle):
            self.set_state(self.__ray.get(self.param_store.get_state.remote()))
        else:
            self.set_state(self.param_store.get_state())

    def push_state(self):
        assert self.param_store is not None, "cannot call push_state on param_store itself"
        if isinstance(self.param_store, self.__ray.actor.ActorHandle):
            self.__ray.get(self.param_store.set_state.remote(self.get_state()))
        else:
            self.param_store.set_state(self.get_state())

    def pull_metrics(self):
        if self.param_store is None:
            metrics = self.env.get_metrics()
        elif isinstance(self.param_store, self.__ray.actor.ActorHandle):
            metrics = self.__ray.get(self.param_store.pull_metrics.remote()).copy()
        else:
            metrics = self.param_store.pull_metrics()
        return metrics

    def push_metrics(self, metrics):
        if self.param_store is None:
            self.env.record_metrics(metrics)
        elif isinstance(self.param_store, self.__ray.actor.ActorHandle):
            self.__ray.get(self.param_store.push_metrics.remote(metrics))
        else:
            self.param_store.push_metrics(metrics)

    def pull_getattr(self, name, default_value=...):
        if self.param_store is None:
            value = _getattr_recursive(self, name, default_value)
        elif isinstance(self.param_store, self.__ray.actor.ActorHandle):
            value = self.__ray.get(self.param_store.pull_getattr.remote(name, default_value))
        else:
            value = self.param_store.pull_getattr(name, default_value)
        return value

    def push_setattr(self, name, value):
        if self.param_store is None:
            _setattr_recursive(self, name, value)
        elif isinstance(self.param_store, self.__ray.actor.ActorHandle):
            self.__ray.get(self.param_store.push_setattr.remote(name, value))
        else:
            self.param_store.push_setattr(name, value)
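# -- illustrative subclass (hedged sketch, not exported) ------------------------------------------ #
# A minimal sketch of what a concrete Worker can look like, loosely in the style of a
# distributed value-based agent. The attributes `q` and `qlearning` are hypothetical stand-ins
# for a value function and its updater, which a real subclass would construct in its own
# __init__ (not shown here); nothing below is part of the public API of this module.

class _ExampleApexWorker(Worker):
    def get_state(self):
        # the state shared between workers: the value-function weights
        return self.q.params, self.q.function_state

    def set_state(self, state):
        self.q.params, self.q.function_state = state

    def trace(self, s, a, r, done, logp=0.0, w=1.0):
        # feed the raw transition to the tracer, then flush any transitions that are ready
        # into the (possibly remote) replay buffer
        self.tracer.add(s, a, r, done, logp, w)
        while self.tracer:
            self.buffer_add(self.tracer.pop())

    def learn(self, transition_batch):
        # returns a metrics dict, as expected by learn_loop
        return self.qlearning.update(transition_batch)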
# -- some helper functions (boilerplate) --------------------------------------------------------- #

def _check_env(env, name):
    if isinstance(env, gymnasium.Env):
        pass
    elif isinstance(env, str):
        env = gymnasium.make(env)
    elif hasattr(env, '__call__'):
        env = env()
    else:
        raise TypeError(f"env must be a gymnasium.Env, str or callable; got: {type(env)}")

    if getattr(getattr(env, 'spec', None), 'max_episode_steps', None) is None:
        raise ValueError(
            "env.spec.max_episode_steps not set; please register env with "
            "gymnasium.register('Foo-v0', entry_point='foo.Foo', max_episode_steps=...) "
            "or wrap your env with: "
            "env = gymnasium.wrappers.TimeLimit(env, max_episode_steps=...)")

    if not isinstance(env, TrainMonitor):
        env = TrainMonitor(env, name=name, log_all_metrics=True)

    return env


def _check_reward_threshold(reward_threshold, env):
    if reward_threshold is None:
        reward_threshold = getattr(getattr(env, 'spec', None), 'reward_threshold', None)
    if reward_threshold is None:
        reward_threshold = float('inf')
    return reward_threshold


def _getattr_recursive(obj, name, default=...):
    if '.' not in name:
        return getattr(obj, name) if default is Ellipsis else getattr(obj, name, default)
    name, subname = name.split('.', 1)
    return _getattr_recursive(getattr(obj, name), subname, default)


def _setattr_recursive(obj, name, value):
    if '.' not in name:
        return setattr(obj, name, value)
    name, subname = name.split('.', 1)
    return _setattr_recursive(getattr(obj, name), subname, value)
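# -- usage sketch (illustrative) ------------------------------------------------------------------ #
# How the pieces are typically wired together with ray, assuming the hypothetical
# _ExampleApexWorker above. In a real setup the rollout workers would also be given `pi=...`
# and `tracer=...`, and the parameter store a `buffer=...`; these are omitted here for
# brevity. The function below is never called at import time.

def _example_distributed_run():
    import ray

    ray.init()
    RemoteWorker = ray.remote(_ExampleApexWorker)  # wrap the subclass as a ray actor class

    # one actor plays the parameter store (param_store=None) ...
    param_store = RemoteWorker.remote('CartPole-v1', name='param_store')

    # ... and the other actors are handed its actor handle
    rollouts = [
        RemoteWorker.remote('CartPole-v1', param_store=param_store, name=f'rollout_{i}')
        for i in range(2)]
    learner = RemoteWorker.remote('CartPole-v1', param_store=param_store, name='learner')

    # run the rollout and learn loops concurrently; ray.get blocks until all loops return
    ray.get(
        [w.rollout_loop.remote(max_total_steps=100_000) for w in rollouts] +
        [learner.learn_loop.remote(max_total_steps=100_000)])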