Source code for coax.wrappers._meta_policy

import inspect

import gymnasium


class MetaPolicyEnv(gymnasium.Wrapper):
    r"""

    Wrap a gymnasium-style environment such that it may be used by a
    meta-policy, i.e. a bandit that selects a policy (an *arm*), which is then
    used to sample a lower-level action and fed the original environment. In
    other words, the actions that the :attr:`step` method expects are
    *meta-actions*, selecting different *arms*. The lower-level actions (and
    their log-propensities) that are sampled internally are stored in the
    ``info`` dict, returned by the :attr:`step` method.

    Parameters
    ----------
    env : gymnasium-style environment

        The original environment to be wrapped into a meta-policy env.

    \*arms : functions

        Callable objects that take a state observation :math:`s` and return an
        action :math:`a` (and optionally, log-propensity
        :math:`\log\pi(a|s)`). See for example :attr:`coax.Policy.__call__` or
        :attr:`coax.Policy.mode`.

    """

    def __init__(self, env, *arms):
        super().__init__(env)
        self.arms = arms
        # One meta-action per arm: the meta-policy picks an index into `arms`.
        self.action_space = gymnasium.spaces.Discrete(len(arms))
        # Most recent observation of the wrapped env; None until reset().
        self._s = None

    def reset(self, **kwargs):
        r"""

        Reset the wrapped environment and remember the initial observation.

        Parameters
        ----------
        \*\*kwargs

            Keyword arguments (e.g. ``seed``, ``options``) forwarded unchanged
            to the wrapped environment's ``reset`` method.

        Returns
        -------
        s : state observation

            The initial observation returned by the wrapped environment.

        info : dict

            The info dict returned by the wrapped environment.

        """
        self._s, info = self.env.reset(**kwargs)
        return self._s, info

    def step(self, a_meta):
        r"""

        Select an arm, sample a lower-level action from it and step the
        wrapped environment with that action.

        Parameters
        ----------
        a_meta : int

            The meta-action, i.e. the index of the arm to use.

        Returns
        -------
        s_next, r, done, truncated, info

            The transition returned by the wrapped environment. The ``info``
            dict is extended with the sampled lower-level action under
            ``'a'`` and its log-propensity under ``'logp'`` (``0.`` if the
            selected arm does not accept a ``return_logp`` argument).

        """
        assert self.action_space.contains(a_meta), "a_meta is invalid"
        assert self._s is not None, "please call env.reset() first"

        pi = self.arms[a_meta]
        if self._accepts_return_logp(pi):
            a, logp = pi(self._s, return_logp=True)
        else:
            a, logp = pi(self._s), 0.

        self._s, r, done, truncated, info = self.env.step(a)
        # The wrapped env may hand back None/empty info; make sure we have a
        # dict to attach the lower-level action to.
        info = info or {}
        info.update({'a': a, 'logp': logp})
        return self._s, r, done, truncated, info

    @staticmethod
    def _accepts_return_logp(pi):
        """Return True if the callable ``pi`` has a ``return_logp`` parameter."""
        # inspect.getargspec was removed in Python 3.11; inspect.signature is
        # its supported replacement and also copes with keyword-only
        # parameters and annotations.
        try:
            parameters = inspect.signature(pi).parameters
        except (TypeError, ValueError):
            # Signature cannot be introspected (e.g. some builtins): treat the
            # arm as a plain ``s -> a`` function.
            return False
        return 'return_logp' in parameters