from functools import partial
import jax
import jax.numpy as jnp
import numpy as onp
from .._base.mixins import CopyMixin
from ..utils import pretty_repr
__all__ = (
'TransitionBatch',
)
[docs]class TransitionBatch(CopyMixin):
r"""
A container object for a batch of MDP transitions.
Parameters
----------
S : pytree with ndarray leaves
A batch of state observations :math:`S_t`.
A : ndarray
A batch of actions :math:`A_t`.
logP : ndarray
A batch of log-propensities :math:`\log\pi(A_t|S_t)`.
Rn : ndarray
A batch of partial (:math:`\gamma`-discounted) returns. For instance,
in :math:`n`-step bootstrapping these are given by:
.. math::
R^{(n)}_t\ &=\ \sum_{k=0}^{n-1}\gamma^kR_{t+k} \\
In other words, it's the part of the :math:`n`-step return *without*
the bootstrapping term.
In : ndarray
A batch of bootstrap factors. For instance, in :math:`n`-step
bootstrapping these are given by :math:`I^{(n)}_t=\gamma^n` when
bootstrapping and :math:`I^{(n)}_t=0` otherwise. Bootstrap factors are
used in constructing the :math:`n`-step bootstrapped target:
.. math::
G^{(n)}_t\ =\ R^{(n)}_t + I^{(n)}_t\,Q(S_{t+1}, A_{t+1})
S_next : pytree with ndarray leaves
A batch of next-state observations :math:`S_{t+n}`. This is typically
used to contruct the TD target in :math:`n`-step bootstrapping.
A_next : ndarray, optional
A batch of next-actions :math:`A_{t+n}`. This is typically used to
contruct the TD target in :math:`n`-step bootstrapping when using SARSA
updates.
logP_next : ndarray, optional
A batch of log-propensities :math:`\log\pi(A_{t+n}|S_{t+n})`.
W : ndarray, optional
A batch of importance weights associated with the sampling procedure that generated each
transition. For example, we need these values when we sample transitions from a
:class:`PrioritizedReplayBuffer <coax.experience_replay.PrioritizedReplayBuffer>`.
"""
__slots__ = ('S', 'A', 'logP', 'Rn', 'In', 'S_next',
'A_next', 'logP_next', 'W', 'idx', 'extra_info')
def __init__(self, S, A, logP, Rn, In, S_next, A_next=None, logP_next=None, W=None, idx=None,
extra_info=None):
self.S = S
self.A = A
self.logP = logP
self.Rn = Rn
self.In = In
self.S_next = S_next
self.A_next = A_next
self.logP_next = logP_next
self.W = onp.ones_like(Rn) if W is None else W
self.idx = onp.arange(Rn.shape[0], dtype='int32') if idx is None else idx
self.extra_info = extra_info
[docs] @classmethod
def from_single(
cls, s, a, logp, r, done, gamma,
s_next=None, a_next=None, logp_next=None, w=1, idx=None, extra_info=None):
r"""
Create a TransitionBatch (with batch_size=1) from a single transition.
Attributes
----------
s : state observation
A single state observation :math:`S_t`.
a : action
A single action :math:`A_t`.
logp : non-positive float
The log-propensity :math:`\log\pi(A_t|S_t)`.
r : float or array of floats
A single reward :math:`R_t`.
done : bool
Whether the episode has finished.
info : dict or None
Some additional info about the current time step.
s_next : state observation
A single next-state observation :math:`S_{t+1}`.
a_next : action
A single next-action :math:`A_{t+1}`.
logp_next : non-positive float
The log-propensity :math:`\log\pi(A_{t+1}|S_{t+1})`.
w : positive float, optional
The importance weight associated with the sampling procedure that generated this
transition.
idx : int, optional
The identifier of this particular transition.
"""
# check types
array = (int, float, onp.ndarray, jnp.ndarray)
if not (isinstance(logp, array) and onp.all(logp <= 0)):
raise TypeError(f"logp must be non-positive float(s), got: {logp}")
if not isinstance(r, array):
raise TypeError(f"r must be a scalar or an array, got: {r}")
if not isinstance(done, bool):
raise TypeError(f"done must be a bool, got: {done}")
if not (isinstance(gamma, (float, int)) and 0 <= gamma <= 1):
raise TypeError(f"gamma must be a float in the unit interval [0, 1], got: {gamma}")
if not (logp_next is None or (isinstance(logp_next, array) and onp.all(logp_next <= 0))):
raise TypeError(f"logp_next must be None or non-positive float(s), got: {logp_next}")
if not (isinstance(w, (float, int)) and w > 0):
raise TypeError(f"w must be a positive float, got: {w}")
return cls(
S=_single_to_batch(s),
A=_single_to_batch(a),
logP=_single_to_batch(logp),
Rn=_single_to_batch(r),
In=_single_to_batch(float(gamma) * (1. - bool(done))),
S_next=_single_to_batch(s_next) if s_next is not None else None,
A_next=_single_to_batch(a_next) if a_next is not None else None,
logP_next=_single_to_batch(logp_next) if logp_next is not None else None,
W=_single_to_batch(float(w)),
idx=_single_to_batch(idx) if idx is not None else None,
extra_info=_single_to_batch(extra_info) if extra_info is not None else None
)
@property
def batch_size(self):
return onp.shape(self.Rn)[0]
[docs] def to_singles(self):
r"""
Get an iterator of single transitions.
Returns
-------
transition_batches : iterator of TransitionBatch
An iterator of :class:`TransitionBatch <coax.reward_tracing.TransitionBatch>` objects
with ``batch_size=1``.
**Note:** The iterator walks through the individual transitions *in reverse order*.
"""
if self.batch_size == 1:
yield self
return # break out of generator
def lookup(i, pytree):
s = slice(i, i + 1) # ndim-preserving lookup
return jax.tree_map(lambda leaf: leaf[s], pytree)
for i in range(self.batch_size):
yield TransitionBatch(*map(partial(lookup, i), self))
def items(self):
for k in self.__slots__:
yield k, getattr(self, k)
def _asdict(self):
return dict(self.items())
def __repr__(self):
return pretty_repr(self)
def __iter__(self):
return (getattr(self, a) for a in self.__slots__)
def __getitem__(self, int_or_slice):
return tuple(self).__getitem__(int_or_slice)
def __eq__(self, other):
return (type(self) is type(other)) and all(
onp.allclose(a, b) if isinstance(a, (onp.ndarray, jnp.ndarray))
else (a is b if a is None else a == b)
for a, b in zip(jax.tree_util.tree_leaves(self), jax.tree_util.tree_leaves(other)))
def _single_to_batch(pytree):
# notice that we're pulling eveyrthing out of jax.numpy and into ordinary numpy land
return jax.tree_map(lambda arr: onp.expand_dims(arr, axis=0), pytree)
jax.tree_util.register_pytree_node(
TransitionBatch,
lambda tn: (tuple(tn), None),
lambda treedef, leaves: TransitionBatch(*leaves))