import numpy as onp
# Public API of this module (controls `from <module> import *`).
__all__ = (
    'SegmentTree',
    'SumTree',
    'MinTree',
    'MaxTree',
)
class SegmentTree:
    r"""
    A `segment tree <https://en.wikipedia.org/wiki/Segment_tree>`_ data structure that allows
    for batched updating and batched partial-range (segment) reductions.

    Parameters
    ----------
    capacity : positive int

        Number of values to accommodate.

    reducer : function

        The reducer function: :code:`(float, float) -> float`.

    init_value : float

        The unit element relative to the reducer function. Some typical examples are: 0 if reducer
        is :data:`add <numpy.add>`, 1 for :data:`multiply <numpy.multiply>`, :math:`-\infty` for
        :data:`maximum <numpy.maximum>`, :math:`\infty` for :data:`minimum <numpy.minimum>`.

    Warning
    -------
    The :attr:`values` attribute and square-bracket lookups (:code:`tree[level, index]`) return
    references of the underlying storage array. Therefore, make sure that downstream code doesn't
    update these values in-place, which would corrupt the segment tree structure.

    """
    def __init__(self, capacity, reducer, init_value):
        self.capacity = capacity
        self.reducer = reducer
        self.init_value = float(init_value)
        # The tree is stored as one flat array of 2 ** height - 1 nodes: root at index 0,
        # leaves at the deepest level. Unused leaf slots keep init_value, which is the unit
        # element of the reducer and therefore never affects a reduction.
        self._height = int(onp.ceil(onp.log2(capacity))) + 1  # the +1 is for the values themselves
        self._arr = onp.full(shape=(2 ** self.height - 1,), fill_value=self.init_value)

    @property
    def height(self):
        r""" The height of the tree :math:`h\sim\log(\text{capacity})`. """
        return self._height

    @property
    def root_value(self):
        r"""
        The aggregated value, equivalent to
        :func:`reduce(reducer, values, init_value) <functools.reduce>`.

        """
        return self._arr[0]

    @property
    def values(self):
        r""" The values stored at the leaves of the tree. """
        start = 2 ** (self.height - 1) - 1
        stop = start + self.capacity
        return self._arr[start:stop]

    def __getitem__(self, lookup):
        """ Look up one whole tree level (tree[level]) or entries in it (tree[level, index]). """
        if isinstance(lookup, int):
            level_offset, level_size = self._check_level_lookup(lookup)
            return self._arr[level_offset:(level_offset + level_size)]
        if isinstance(lookup, tuple) and len(lookup) == 1:
            level, = lookup
            return self[level]
        if isinstance(lookup, tuple) and len(lookup) == 2:
            level, index = lookup
            return self[level][index]
        raise IndexError(
            "tree lookup must be of the form: tree[level] or tree[level, index], "
            "where 'level' is an int and 'index' is a 1d array lookup")

    def set_values(self, idx, values):
        r"""
        Set or update the :attr:`values`.

        Parameters
        ----------
        idx : 1d array of ints

            The indices of the values to be updated. If you wish to update all values use ellipses
            instead, e.g. :code:`tree.set_values(..., values)`.

        values : 1d array of floats

            The new values.

        """
        idx, level_offset, level_size = self._check_idx(idx)

        # update leaf-node values
        self._arr[level_offset + (idx % level_size)] = values

        # propagate the changes upwards: recompute each affected parent from its two children
        for level in range(self.height - 2, -1, -1):
            idx = onp.unique(idx // 2)           # local indices of affected parents
            left_child = level_offset + 2 * idx  # global index in the (old) child level
            right_child = left_child + 1
            level_offset = 2 ** level - 1        # offset of the parent level
            parent = level_offset + idx
            self._arr[parent] = self.reducer(self._arr[left_child], self._arr[right_child])

    def partial_reduce(self, start=0, stop=None):
        r"""
        Reduce values over a partial range of indices. This is an efficient, batched implementation
        of :func:`reduce(reducer, values[start:stop], init_value) <functools.reduce>`.

        Parameters
        ----------
        start : int or array of ints

            The lower bound of the range (inclusive).

        stop : int or array of ints, optional

            The upper bound of the range (exclusive). If left unspecified, this defaults to
            :attr:`capacity`.

        Returns
        -------
        value : float or 1d array of floats

            The result of the partial reduction (one value per (start, stop) pair).

        """
        # NOTE: This is an iterative implementation, which is a lot uglier than a recursive one.
        # The reason why we use an iterative approach is that it's easier for batch-processing.

        # i and j are 1d arrays (indices for self._arr)
        i, j = self._check_start_stop_to_i_j(start, stop)

        # trivial case
        done = (i == j)
        if done.all():
            return self._arr[i]

        # left/right accumulators (mask one of them to avoid over-counting if i == j)
        a, b = self._arr[i], onp.where(done, self.init_value, self._arr[j])

        # number of nodes in higher levels
        level_offset = 2 ** (self.height - 1) - 1

        # we start from the leaves and work up towards the root
        for level in range(self.height - 2, -1, -1):
            # get parent indices
            level_offset_parent = 2 ** level - 1
            i_parent = (i - level_offset) // 2 + level_offset_parent
            j_parent = (j - level_offset) // 2 + level_offset_parent

            # stop when we have a shared parent (possibly the root node, but not necessarily)
            done |= (i_parent == j_parent)
            if done.all():
                return self.reducer(a, b)

            # only accumulate right-child value if 'i' was a left child of 'i_parent'
            a = onp.where((i % 2 == 1) & ~done, self.reducer(a, self._arr[i + 1]), a)

            # only accumulate left-child value if 'j' was a right child of 'j_parent'
            b = onp.where((j % 2 == 0) & ~done, self.reducer(b, self._arr[j - 1]), b)

            # prepare for next loop
            i, j, level_offset = i_parent, j_parent, level_offset_parent

        assert False, 'this point should not be reached'

    def __repr__(self):
        s = ""
        for level in range(self.height):
            s += f"\n  level={level} : {repr(self[level])}"
        return f"{type(self).__name__}({s})"

    def _check_level_lookup(self, level):
        """ validate a level index and return (level_offset, level_size) for it """
        if not isinstance(level, int):
            raise IndexError(f"level lookup must be an int, got: {type(level)}")
        if not (-self.height <= level < self.height):
            raise IndexError(f"level index {level} is out of bounds; tree height: {self.height}")
        level %= self.height  # wrap negative level indices
        level_offset = 2 ** level - 1
        level_size = min(2 ** level, self.capacity)
        return level_offset, level_size

    def _check_level(self, level):
        """ validate a level index and return its non-negative equivalent """
        if level < -self.height or level >= self.height:
            raise IndexError(f"tree level index {level} out of range; tree height: {self.height}")
        return level % self.height

    def _check_idx(self, idx):
        """ some boilerplate to turn any compatible idx into a 1d integer array """
        level_offset, level_size = self._check_level_lookup(self.height - 1)
        if isinstance(idx, int):
            idx = onp.asarray([idx], dtype='int32')
        if idx is None or idx is Ellipsis:
            idx = onp.arange(level_size, dtype='int32')
        elif isinstance(idx, list) and all(isinstance(x, int) for x in idx):
            idx = onp.asarray(idx, dtype='int32')
        elif (isinstance(idx, onp.ndarray)
                and onp.issubdtype(idx.dtype, onp.integer)
                and idx.ndim <= 1):
            idx = idx.reshape(-1)
        else:
            raise IndexError("idx must be an int or a 1d integer array")
        if not onp.all((idx < level_size) & (idx >= -level_size)):
            raise IndexError("one or more entries in idx are out of range")
        return idx % level_size, level_offset, level_size

    def _check_start_stop_to_i_j(self, start, stop):
        """ some boilerplate to turn (start, stop) into left/right index arrays (i, j) """
        start_orig, stop_orig = start, stop

        # convert 'start' index to 1d array
        if isinstance(start, int):
            start = onp.array([start])
        if not (isinstance(start, onp.ndarray)
                and start.ndim == 1
                and onp.issubdtype(start.dtype, onp.integer)):
            raise TypeError("'start' must be an int or a 1d integer array")

        # convert 'stop' index to 1d array
        if stop is None:
            stop = onp.full_like(start, self.capacity)
        if isinstance(stop, int):
            stop = onp.full_like(start, stop)
        if not (isinstance(stop, onp.ndarray)
                and stop.ndim == 1
                and onp.issubdtype(stop.dtype, onp.integer)):
            raise TypeError("'stop' must be an int or a 1d integer array")

        # ensure that 'start' is the same size as 'stop'
        if start.size == 1 and stop.size > 1:
            start = onp.full_like(stop, start[0])

        # check compatible shapes
        if start.shape != stop.shape:
            raise ValueError(
                f"shapes must be equal, got: start.shape: {start.shape}, stop.shape: {stop.shape}")

        # convert to (i, j), where j is the *inclusive* version of 'stop' (which is exclusive)
        level_offset, level_size = self._check_level_lookup(self.height - 1)
        i = level_offset + start % level_size
        j = level_offset + (stop - 1) % level_size

        # check consistency of ranges
        if not onp.all((i >= level_offset) & (j < level_offset + level_size) & (i <= j)):
            raise IndexError(
                f"inconsistent ranges detected from (start, stop) = ({start_orig}, {stop_orig})")

        return i, j
class SumTree(SegmentTree):
    r"""
    A sum-tree data structure that allows for batched updating and batched weighted sampling.

    Both update and sampling operations have a time complexity of :math:`\mathcal{O}(\log N)` and a
    memory footprint of :math:`\mathcal{O}(N)`, where :math:`N` is the length of the underlying
    :attr:`values`.

    Parameters
    ----------
    capacity : positive int

        Number of values to accommodate.

    random_seed : int, optional

        Seed for the internal random number generator, to make sampling reproducible.

    """
    def __init__(self, capacity, random_seed=None):
        # a sum-tree is a segment tree specialized to the 'add' reducer (unit element: 0)
        super().__init__(capacity=capacity, reducer=onp.add, init_value=0)
        self.random_seed = random_seed

    @property
    def random_seed(self):
        r""" The seed of the internal random number generator. """
        return self._random_seed

    @random_seed.setter
    def random_seed(self, new_random_seed):
        # re-seeding replaces the RNG, so sampling is reproducible from this point onwards
        self._rnd = onp.random.RandomState(new_random_seed)
        self._random_seed = new_random_seed

    def sample(self, n):
        r"""
        Sample array indices using weighted sampling, where the sample weights are proportional to
        the values stored in :attr:`values`.

        Parameters
        ----------
        n : positive int

            The number of samples to return.

        Returns
        -------
        idx : array of ints

            The sampled indices, shape: (n,)

        Warning
        -------
        This method presumes (but doesn't check) that all :attr:`values` stored in the tree are
        non-negative.

        """
        if not (isinstance(n, int) and n > 0):
            raise TypeError("n must be a positive integer")
        return self.inverse_cdf(self._rnd.rand(n))

    def inverse_cdf(self, u):
        r"""
        Inverse of the cumulative distribution function (CDF) of the categorical distribution
        :math:`\text{Cat}(p)`, where :math:`p` are the normalized values :math:`p_i=`
        :attr:`values[i] / sum(values) <values>`.

        This function provides the machinery for the :attr:`sample` method.

        Parameters
        ----------
        u : float or 1d array of floats

            One or more numbers :math:`u\in[0,1]`. These are typically sampled from
            :math:`\text{Unif([0, 1])}`.

        Returns
        -------
        idx : array of ints

            The indices associated with :math:`u`, shape: (n,)

        Warning
        -------
        This method presumes (but doesn't check) that all :attr:`values` stored in the tree are
        non-negative.

        """
        # NOTE: This is an iterative implementation, which is a lot uglier than a recursive one.
        # The reason why we use an iterative approach is that it's easier for batch-processing.
        if self.root_value <= 0:
            raise RuntimeError("the root_value must be positive")

        # init (will be updated in loop)
        u, isscalar = self._check_u(u)
        values = u * self.root_value
        idx = onp.zeros_like(values, dtype='int32')  # this is ultimately what we'll return
        level_offset_parent = 0  # number of nodes in levels above parent

        # iterate down, from the root to the leaves
        for level in range(1, self.height):
            # get child indices
            level_offset = 2 ** level - 1
            left_child_idx = (idx - level_offset_parent) * 2 + level_offset
            right_child_idx = left_child_idx + 1

            # descend left if the left subtree's mass exceeds the remaining value; otherwise
            # descend right, subtracting the left subtree's mass from the remaining value
            left_child_values = self._arr[left_child_idx]
            pick_left_child = left_child_values > values
            idx = onp.where(pick_left_child, left_child_idx, right_child_idx)
            values = onp.where(pick_left_child, values, values - left_child_values)
            level_offset_parent = level_offset

        # convert global array indices to leaf-local indices
        idx = idx - level_offset_parent
        return idx[0] if isscalar else idx

    def _check_u(self, u):
        """ some boilerplate to check validity of 'u' array """
        isscalar = False
        if isinstance(u, (float, int)):
            u = onp.array([u], dtype='float')
            isscalar = True
        if isinstance(u, list) and all(isinstance(x, (float, int)) for x in u):
            u = onp.asarray(u, dtype='float')
        if not (isinstance(u, onp.ndarray)
                and u.ndim == 1 and onp.issubdtype(u.dtype, onp.floating)):
            raise TypeError("'u' must be a float or a 1d array of floats")
        if onp.any(u > 1) or onp.any(u < 0):
            raise ValueError("all values in 'u' must lie in the unit interval [0, 1]")
        return u, isscalar
class MinTree(SegmentTree):
    r"""
    A min-tree: a :class:`SegmentTree` specialized to the :data:`minimum <numpy.minimum>`
    reducer.

    Parameters
    ----------
    capacity : positive int

        Number of values to accommodate.

    """
    def __init__(self, capacity):
        # The unit element of the minimum reducer is positive infinity.
        unit = float('inf')
        super().__init__(capacity=capacity, reducer=onp.minimum, init_value=unit)
class MaxTree(SegmentTree):
    r"""
    A max-tree: a :class:`SegmentTree` specialized to the :data:`maximum <numpy.maximum>`
    reducer.

    Parameters
    ----------
    capacity : positive int

        Number of values to accommodate.

    """
    def __init__(self, capacity):
        # The unit element of the maximum reducer is negative infinity.
        unit = -float('inf')
        super().__init__(capacity=capacity, reducer=onp.maximum, init_value=unit)