Source code for coax.experience_replay._simple

import random

import jax
import numpy as onp

from ..reward_tracing import TransitionBatch
from ._base import BaseReplayBuffer

# Names exported by ``from ._simple import *``.
# NOTE(review): the tuple contents were missing in this copy of the file; they
# are restored here as the single public class defined in this module --
# confirm against the upstream source.
__all__ = (
    'SimpleReplayBuffer',
)

class SimpleReplayBuffer(BaseReplayBuffer):
    r"""

    A simple ring buffer for experience replay.

    Transitions are stored individually in insertion order. Once the number of
    stored transitions exceeds ``capacity``, the oldest ones are discarded.

    Parameters
    ----------
    capacity : positive int

        The capacity of the experience replay buffer.

    random_seed : int, optional

        To get reproducible results.

    """
    def __init__(self, capacity, random_seed=None):
        self._capacity = int(capacity)
        # Use a private generator instead of seeding / swapping the state of
        # the module-level ``random`` generator: sampling stays reproducible
        # for a given seed, is unaffected by other users of ``random``, and
        # does not disturb their random streams either.
        self._rng = random.Random(random_seed)
        self.clear()  # sets self._storage and self._index

    @property
    def capacity(self):
        r""" The maximum number of transitions kept in the buffer. """
        return self._capacity

    def add(self, transition_batch):
        r"""

        Add a transition to the experience replay buffer.

        The batch's ``idx`` attribute is overwritten with consecutive indices
        that continue the buffer's running counter. The batch is then split
        into single transitions, which are appended to the storage.

        Parameters
        ----------
        transition_batch : TransitionBatch

            A :class:`TransitionBatch <coax.reward_tracing.TransitionBatch>` object.

        Raises
        ------
        TypeError

            If ``transition_batch`` is not a ``TransitionBatch``.

        """
        if not isinstance(transition_batch, TransitionBatch):
            raise TypeError(
                f"transition_batch must be a TransitionBatch, got: {type(transition_batch)}")

        batch_size = transition_batch.batch_size
        transition_batch.idx = onp.arange(self._index, self._index + batch_size)
        self._index += batch_size
        self._storage.extend(transition_batch.to_singles())

        # Evict the oldest entries in one slice deletion rather than calling
        # ``pop(0)`` repeatedly, which shifts the whole list once per item.
        overflow = len(self._storage) - self.capacity
        if overflow > 0:
            del self._storage[:overflow]

    def sample(self, batch_size=32):
        r"""

        Get a batch of transitions to be used for bootstrapped updates.

        Transitions are drawn without replacement and their leaves are
        concatenated along the first axis.

        Parameters
        ----------
        batch_size : positive int, optional

            The desired batch size of the sample.

        Returns
        -------
        transitions : TransitionBatch

            A :class:`TransitionBatch <coax.reward_tracing.TransitionBatch>` object.

        Raises
        ------
        ValueError

            If ``batch_size`` is larger than the number of stored transitions.

        """
        transitions = self._rng.sample(self._storage, batch_size)
        # ``jax.tree_util.tree_map`` is the stable spelling; the top-level
        # ``jax.tree_map`` alias is deprecated and absent from newer releases.
        return jax.tree_util.tree_map(
            lambda *leaves: onp.concatenate(leaves, axis=0), *transitions)

    def clear(self):
        r""" Clear the experience replay buffer. """
        self._storage = []
        self._index = 0  # running counter used to assign ``idx`` in add()

    def __len__(self):
        return len(self._storage)

    def __bool__(self):
        return bool(len(self))

    def __iter__(self):
        return iter(self._storage)