# Source code for coax.reward_tracing._transition

from functools import partial

import jax
import jax.numpy as jnp
import numpy as onp

from .._base.mixins import CopyMixin
from ..utils import pretty_repr

__all__ = (
'TransitionBatch',
)

class TransitionBatch(CopyMixin):
    r"""

    A container object for a batch of MDP transitions.

    Parameters
    ----------

    S : pytree with ndarray leaves

        A batch of state observations :math:`S_t`.

    A : ndarray

        A batch of actions :math:`A_t`.

    logP : ndarray

        A batch of log-propensities :math:`\log\pi(A_t|S_t)`.

    Rn : ndarray

        A batch of partial (:math:`\gamma`-discounted) returns. For instance,
        in :math:`n`-step bootstrapping these are given by:

        .. math::

            R^{(n)}_t\ =\ \sum_{k=0}^{n-1}\gamma^kR_{t+k}

        In other words, it's the part of the :math:`n`-step return *without*
        the bootstrapping term.

    In : ndarray

        A batch of bootstrap factors. For instance, in :math:`n`-step
        bootstrapping these are given by :math:`I^{(n)}_t=\gamma^n` when
        bootstrapping and :math:`I^{(n)}_t=0` otherwise. Bootstrap factors are
        used in constructing the :math:`n`-step bootstrapped target:

        .. math::

            G^{(n)}_t\ =\ R^{(n)}_t + I^{(n)}_t\,Q(S_{t+1}, A_{t+1})

    S_next : pytree with ndarray leaves

        A batch of next-state observations :math:`S_{t+n}`. This is typically
        used to construct the TD target in :math:`n`-step bootstrapping.

    A_next : ndarray, optional

        A batch of next-actions :math:`A_{t+n}`. This is typically used to
        construct the TD target in :math:`n`-step bootstrapping when using SARSA
        updates.

    logP_next : ndarray, optional

        A batch of log-propensities :math:`\log\pi(A_{t+n}|S_{t+n})`.

    W : ndarray, optional

        A batch of importance weights associated with the sampling procedure that generated each
        transition. For example, we need these values when we sample transitions from a
        :class:`PrioritizedReplayBuffer <coax.experience_replay.PrioritizedReplayBuffer>`.

    """
    __slots__ = ('S', 'A', 'logP', 'Rn', 'In', 'S_next',
                 'A_next', 'logP_next', 'W', 'idx', 'extra_info')

    def __init__(self, S, A, logP, Rn, In, S_next, A_next=None, logP_next=None, W=None, idx=None,
                 extra_info=None):

        self.S = S
        self.A = A
        self.logP = logP
        self.Rn = Rn
        self.In = In
        self.S_next = S_next
        self.A_next = A_next
        self.logP_next = logP_next
        # default importance weights are uniform (all ones)
        self.W = onp.ones_like(Rn) if W is None else W
        # default identifiers are just the positions within this batch
        self.idx = onp.arange(Rn.shape[0], dtype='int32') if idx is None else idx
        self.extra_info = extra_info

    @classmethod
    def from_single(
            cls, s, a, logp, r, done, gamma,
            s_next=None, a_next=None, logp_next=None, w=1, idx=None, extra_info=None):
        r"""

        Create a TransitionBatch (with batch_size=1) from a single transition.

        Parameters
        ----------
        s : state observation

            A single state observation :math:`S_t`.

        a : action

            A single action :math:`A_t`.

        logp : non-positive float

            The log-propensity :math:`\log\pi(A_t|S_t)`.

        r : float or array of floats

            A single reward :math:`R_t`.

        done : bool

            Whether the episode has finished.

        gamma : float in the unit interval [0, 1]

            The discount factor used to construct the bootstrap factor.

        s_next : state observation

            A single next-state observation :math:`S_{t+1}`.

        a_next : action

            A single next-action :math:`A_{t+1}`.

        logp_next : non-positive float

            The log-propensity :math:`\log\pi(A_{t+1}|S_{t+1})`.

        w : positive float, optional

            The importance weight associated with the sampling procedure that generated this
            transition.

        idx : int, optional

            The identifier of this particular transition.

        extra_info : pytree, optional

            Any additional info to attach to this transition.

        """

        # check types
        array = (int, float, onp.ndarray, jnp.ndarray)
        if not (isinstance(logp, array) and onp.all(logp <= 0)):
            raise TypeError(f"logp must be non-positive float(s), got: {logp}")
        if not isinstance(r, array):
            raise TypeError(f"r must be a scalar or an array, got: {r}")
        if not isinstance(done, bool):
            raise TypeError(f"done must be a bool, got: {done}")
        if not (isinstance(gamma, (float, int)) and 0 <= gamma <= 1):
            raise TypeError(f"gamma must be a float in the unit interval [0, 1], got: {gamma}")
        if not (logp_next is None or (isinstance(logp_next, array) and onp.all(logp_next <= 0))):
            raise TypeError(f"logp_next must be None or non-positive float(s), got: {logp_next}")
        if not (isinstance(w, (float, int)) and w > 0):
            raise TypeError(f"w must be a positive float, got: {w}")

        return cls(
            S=_single_to_batch(s),
            A=_single_to_batch(a),
            logP=_single_to_batch(logp),
            Rn=_single_to_batch(r),
            # bootstrap factor: gamma when the episode continues, 0 when done
            In=_single_to_batch(float(gamma) * (1. - bool(done))),
            S_next=_single_to_batch(s_next) if s_next is not None else None,
            A_next=_single_to_batch(a_next) if a_next is not None else None,
            logP_next=_single_to_batch(logp_next) if logp_next is not None else None,
            W=_single_to_batch(float(w)),
            idx=_single_to_batch(idx) if idx is not None else None,
            extra_info=_single_to_batch(extra_info) if extra_info is not None else None
        )

    @property
    def batch_size(self):
        r""" The size of the batch, i.e. the leading dimension of :attr:`Rn`. """
        return onp.shape(self.Rn)[0]

    def to_singles(self):
        r"""

        Get an iterator of single transitions.

        Returns
        -------
        transition_batches : iterator of TransitionBatch

            An iterator of :class:`TransitionBatch <coax.reward_tracing.TransitionBatch>` objects
            with batch_size=1, yielded in batch order.

        """
        if self.batch_size == 1:
            yield self
            return  # break out of generator

        def lookup(i, pytree):
            s = slice(i, i + 1)  # ndim-preserving lookup
            return jax.tree_util.tree_map(lambda leaf: leaf[s], pytree)

        for i in range(self.batch_size):
            yield TransitionBatch(*map(partial(lookup, i), self))

    def items(self):
        r""" Iterate over the (name, value) pairs of all fields. """
        for k in self.__slots__:
            yield k, getattr(self, k)

    def _asdict(self):
        # plain-dict view of all fields; used e.g. by pretty_repr
        return dict(self.items())

    def __repr__(self):
        return pretty_repr(self)

    def __iter__(self):
        # iterate over the field values, in __slots__ order
        return (getattr(self, a) for a in self.__slots__)

    def __getitem__(self, int_or_slice):
        # positional lookup of field values, in __slots__ order
        return tuple(self).__getitem__(int_or_slice)

    def __eq__(self, other):
        # leaf-wise comparison: approximate equality for arrays, exact otherwise
        return (type(self) is type(other)) and all(
            onp.allclose(a, b) if isinstance(a, (onp.ndarray, jnp.ndarray))
            else (a is b if a is None else a == b)
            for a, b in zip(jax.tree_util.tree_leaves(self), jax.tree_util.tree_leaves(other)))

def _single_to_batch(pytree):
    r""" Add a leading batch axis of size 1 to every leaf of *pytree*.

    Note that we're pulling everything out of jax.numpy and into ordinary
    (host-side) numpy land.
    """
    return jax.tree_util.tree_map(lambda arr: onp.expand_dims(arr, axis=0), pytree)

def _flatten(transition_batch):
    # all field values become pytree children; there is no static metadata
    return tuple(transition_batch), None


def _unflatten(treedef, leaves):
    # fields come back in __slots__ order, matching the __init__ signature
    return TransitionBatch(*leaves)


# Make TransitionBatch a jax pytree node so it works with jax transformations.
jax.tree_util.register_pytree_node(TransitionBatch, _flatten, _unflatten)