Source code for titan_pylib.data_structures.segment_tree.segment_tree_RSQ

  1from titan_pylib.data_structures.segment_tree.segment_tree_interface import (
  2    SegmentTreeInterface,
  3)
  4from titan_pylib.my_class.supports_add import SupportsAdd
  5from typing import Generic, Iterable, TypeVar, Union
  6
  7T = TypeVar("T", bound=SupportsAdd)
  8
  9
class SegmentTreeRSQ(SegmentTreeInterface, Generic[T]):
    """RSQ (range sum query) segment tree.

    Holds ``n`` values in a flat list ``_data`` laid out as a perfect binary
    tree with ``_size = 2**_log`` leaves: node ``i`` has children ``2*i`` and
    ``2*i + 1``, and element ``k`` is stored at ``_data[_size + k]``. Every
    internal node holds ``left + right`` of its two children, so a point
    update or a range sum costs O(log n) additions.

    Values are always combined as ``left + right`` in index order and never
    with ``+=``. Non-commutative summable types (``str``, ``list``, ...)
    therefore give correctly ordered results. Mutable ones never modify the
    identity ``e`` in place, which matters because ``e`` is shared by every
    padding leaf.
    """

    def __init__(self, _n_or_a: Union[int, Iterable[T]], e: T = 0) -> None:
        """Build the tree.

        Args:
            _n_or_a: Either the number of elements ``n`` (all initialised to
                ``e``) or an iterable of initial values.
            e: Additive identity; also used to fill unused leaves.
        """
        self._e = e
        if isinstance(_n_or_a, int):
            n = _n_or_a
            values = None
        else:
            values = list(_n_or_a)
            n = len(values)
        self._n = n
        # Smallest exponent with 2**log >= n (for n >= 1).
        self._log = (n - 1).bit_length()
        self._size = 1 << self._log
        data = [e] * (self._size << 1)
        if values is not None:
            data[self._size : self._size + n] = values
            # Fill internal nodes bottom-up from their children.
            for i in range(self._size - 1, 0, -1):
                data[i] = data[i << 1] + data[i << 1 | 1]
        self._data = data

    def set(self, k: int, v: T) -> None:
        """Replace element ``k`` with ``v`` (negative ``k`` counts from the end)
        and recompute all of its ancestors."""
        assert (
            -self._n <= k < self._n
        ), f"IndexError: {self.__class__.__name__}.set({k}: int, {v}: T), n={self._n}"
        if k < 0:
            k += self._n
        data = self._data
        k += self._size
        data[k] = v
        for _ in range(self._log):
            k >>= 1
            data[k] = data[k << 1] + data[k << 1 | 1]

    def get(self, k: int) -> T:
        """Return element ``k`` (negative ``k`` counts from the end)."""
        assert (
            -self._n <= k < self._n
        ), f"IndexError: {self.__class__.__name__}.get({k}: int), n={self._n}"
        if k < 0:
            k += self._n
        return self._data[k + self._size]

    def prod(self, l: int, r: int) -> T:
        """Return the sum of elements in the half-open range ``[l, r)``.

        Elements are combined left to right; an empty range yields ``e``.
        """
        assert (
            0 <= l <= r <= self._n
        ), f"IndexError: {self.__class__.__name__}.prod({l}: int, {r}: int)"
        data = self._data
        l += self._size
        r += self._size
        # Separate accumulators keep the left pieces before the right pieces,
        # so the result is ordered even for a non-commutative ``+``.
        left = self._e
        right = self._e
        while l < r:
            if l & 1:
                left = left + data[l]
                l += 1
            if r & 1:
                r -= 1
                right = data[r] + right
            l >>= 1
            r >>= 1
        return left + right

    def all_prod(self) -> T:
        """Return the sum of all elements (the root node)."""
        return self._data[1]

    def max_right(self, l: int, f=lambda lr: lr) -> int:
        """Binary-search rightwards from ``l``.

        Assuming ``f`` is monotone (once false, it stays false as the range
        grows to the right), return the largest ``r`` in ``[l, n]`` such that
        ``f(prod(l, r))`` is true.

        Args:
            l: Left end, ``0 <= l <= n``.
            f: Predicate on partial sums; ``f(e)`` must be true.
        """
        assert (
            0 <= l <= self._n
        ), f"IndexError: {self.__class__.__name__}.max_right({l}, f) index out of range"
        assert f(
            self._e
        ), f"{self.__class__.__name__}.max_right({l}, f), f({self._e}) must be true."
        if l == self._n:
            return self._n
        data = self._data
        size = self._size
        l += size
        s = self._e
        while True:
            # Climb while l is a left child: its parent covers the same start.
            while (l & 1) == 0:
                l >>= 1
            if not f(s + data[l]):
                # The answer lies inside node l: descend towards the leaf.
                while l < size:
                    l <<= 1
                    if f(s + data[l]):
                        s = s + data[l]
                        l += 1
                return l - size
            s = s + data[l]
            l += 1
            # l is a power of two: the whole suffix has been consumed.
            if (l & -l) == l:
                break
        return self._n

    def min_left(self, r: int, f=lambda lr: lr) -> int:
        """Binary-search leftwards from ``r``.

        Assuming ``f`` is monotone (once false, it stays false as the range
        grows to the left), return the smallest ``l`` in ``[0, r]`` such that
        ``f(prod(l, r))`` is true.

        Args:
            r: Right end (exclusive), ``0 <= r <= n``.
            f: Predicate on partial sums; ``f(e)`` must be true.
        """
        assert (
            0 <= r <= self._n
        ), f"IndexError: {self.__class__.__name__}.min_left({r}, f) index out of range"
        assert f(
            self._e
        ), f"{self.__class__.__name__}.min_left({r}, f), f({self._e}) must be true."
        if r == 0:
            return 0
        data = self._data
        size = self._size
        r += size
        s = self._e
        while True:
            r -= 1
            # Climb while r is a right child: its parent covers the same end.
            while r > 1 and (r & 1):
                r >>= 1
            if not f(data[r] + s):
                # The answer lies inside node r: descend towards the leaf.
                while r < size:
                    r = r << 1 | 1
                    if f(data[r] + s):
                        s = data[r] + s
                        r -= 1
                return r + 1 - size
            s = data[r] + s
            # r is a power of two: the whole prefix has been consumed.
            if (r & -r) == r:
                break
        return 0

    def tolist(self) -> list[T]:
        """Return the ``n`` stored elements as a new list."""
        return self._data[self._size : self._size + self._n]

    def show(self) -> None:
        """Print every level of the internal tree, root first (debug helper)."""
        rows = []
        for depth in range(self._log + 1):
            start = 1 << depth
            level = self._data[start : start << 1]
            rows.append("  " + " ".join(map(str, level)))
        print(f"<{self.__class__.__name__}> [\n" + "\n".join(rows) + "\n]")

    def __getitem__(self, k: int) -> T:
        assert (
            -self._n <= k < self._n
        ), f"IndexError: {self.__class__.__name__}.__getitem__({k}: int), n={self._n}"
        return self.get(k)

    def __setitem__(self, k: int, v: T):
        assert (
            -self._n <= k < self._n
        ), f"IndexError: {self.__class__.__name__}.__setitem__({k}: int, {v}: T), n={self._n}"
        self.set(k, v)

    def __str__(self):
        return "[" + ", ".join(map(str, self.tolist())) + "]"

    def __repr__(self):
        return f"{self.__class__.__name__}({self})"