Source code for titan_pylib.data_structures.segment_tree.segment_tree

  1from titan_pylib.data_structures.segment_tree.segment_tree_interface import (
  2    SegmentTreeInterface,
  3)
  4from typing import Generic, Iterable, TypeVar, Callable, Union
  5
  6T = TypeVar("T")
  7
  8
[docs] 9class SegmentTree(SegmentTreeInterface, Generic[T]): 10 """セグ木です。非再帰です。""" 11 12 def __init__( 13 self, n_or_a: Union[int, Iterable[T]], op: Callable[[T, T], T], e: T 14 ) -> None: 15 """``SegmentTree`` を構築します。 16 :math:`O(n)` です。 17 18 Args: 19 n_or_a (Union[int, Iterable[T]]): ``n: int`` のとき、 ``e`` を初期値として長さ ``n`` の ``SegmentTree`` を構築します。 20 ``a: Iterable[T]`` のとき、 ``a`` から ``SegmentTree`` を構築します。 21 op (Callable[[T, T], T]): 2項演算の関数です。 22 e (T): 単位元です。 23 """ 24 self._op = op 25 self._e = e 26 if isinstance(n_or_a, int): 27 self._n = n_or_a 28 self._log = (self._n - 1).bit_length() 29 self._size = 1 << self._log 30 self._data = [e] * (self._size << 1) 31 else: 32 n_or_a = list(n_or_a) 33 self._n = len(n_or_a) 34 self._log = (self._n - 1).bit_length() 35 self._size = 1 << self._log 36 _data = [e] * (self._size << 1) 37 _data[self._size : self._size + self._n] = n_or_a 38 for i in range(self._size - 1, 0, -1): 39 _data[i] = op(_data[i << 1], _data[i << 1 | 1]) 40 self._data = _data 41
[docs] 42 def set(self, k: int, v: T) -> None: 43 """一点更新です。 44 :math:`O(\\log{n})` です。 45 46 Args: 47 k (int): 更新するインデックスです。 48 v (T): 更新する値です。 49 50 制約: 51 :math:`-n \\leq n \\leq k < n` 52 """ 53 assert ( 54 -self._n <= k < self._n 55 ), f"IndexError: {self.__class__.__name__}.set({k}, {v}), n={self._n}" 56 if k < 0: 57 k += self._n 58 k += self._size 59 self._data[k] = v 60 for _ in range(self._log): 61 k >>= 1 62 self._data[k] = self._op(self._data[k << 1], self._data[k << 1 | 1])
63
[docs] 64 def get(self, k: int) -> T: 65 """一点取得です。 66 :math:`O(1)` です。 67 68 Args: 69 k (int): インデックスです。 70 71 制約: 72 :math:`-n \\leq n \\leq k < n` 73 """ 74 assert ( 75 -self._n <= k < self._n 76 ), f"IndexError: {self.__class__.__name__}.get({k}), n={self._n}" 77 if k < 0: 78 k += self._n 79 return self._data[k + self._size]
80
[docs] 81 def prod(self, l: int, r: int) -> T: 82 """区間 ``[l, r)`` の総積を返します。 83 :math:`O(\\log{n})` です。 84 85 Args: 86 l (int): インデックスです。 87 r (int): インデックスです。 88 89 制約: 90 :math:`0 \\leq l \\leq r \\leq n` 91 """ 92 assert ( 93 0 <= l <= r <= self._n 94 ), f"IndexError: {self.__class__.__name__}.prod({l}, {r})" 95 l += self._size 96 r += self._size 97 lres = self._e 98 rres = self._e 99 while l < r: 100 if l & 1: 101 lres = self._op(lres, self._data[l]) 102 l += 1 103 if r & 1: 104 rres = self._op(self._data[r ^ 1], rres) 105 l >>= 1 106 r >>= 1 107 return self._op(lres, rres)
108
[docs] 109 def all_prod(self) -> T: 110 """区間 ``[0, n)`` の総積を返します。 111 :math:`O(1)` です。 112 """ 113 return self._data[1]
114
[docs] 115 def max_right(self, l: int, f: Callable[[T], bool]) -> int: 116 """Find the largest index R s.t. f([l, R)) == True. / O(\\log{n})""" 117 assert ( 118 0 <= l <= self._n 119 ), f"IndexError: {self.__class__.__name__}.max_right({l}, f) index out of range" 120 # assert f(self._e), \ 121 # f'{self.__class__.__name__}.max_right({l}, f), f({self._e}) must be true.' 122 if l == self._n: 123 return self._n 124 l += self._size 125 s = self._e 126 while True: 127 while l & 1 == 0: 128 l >>= 1 129 if not f(self._op(s, self._data[l])): 130 while l < self._size: 131 l <<= 1 132 if f(self._op(s, self._data[l])): 133 s = self._op(s, self._data[l]) 134 l |= 1 135 return l - self._size 136 s = self._op(s, self._data[l]) 137 l += 1 138 if l & -l == l: 139 break 140 return self._n
141
[docs] 142 def min_left(self, r: int, f: Callable[[T], bool]) -> int: 143 """Find the smallest index L s.t. f([L, r)) == True. / O(\\log{n})""" 144 assert ( 145 0 <= r <= self._n 146 ), f"IndexError: {self.__class__.__name__}.min_left({r}, f) index out of range" 147 # assert f(self._e), \ 148 # f'{self.__class__.__name__}.min_left({r}, f), f({self._e}) must be true.' 149 if r == 0: 150 return 0 151 r += self._size 152 s = self._e 153 while True: 154 r -= 1 155 while r > 1 and r & 1: 156 r >>= 1 157 if not f(self._op(self._data[r], s)): 158 while r < self._size: 159 r = r << 1 | 1 160 if f(self._op(self._data[r], s)): 161 s = self._op(self._data[r], s) 162 r ^= 1 163 return r + 1 - self._size 164 s = self._op(self._data[r], s) 165 if r & -r == r: 166 break 167 return 0
168
[docs] 169 def tolist(self) -> list[T]: 170 """リストにして返します。 171 :math:`O(n)` です。 172 """ 173 return [self.get(i) for i in range(self._n)]
174
[docs] 175 def show(self) -> None: 176 """デバッグ用のメソッドです。""" 177 print( 178 f"<{self.__class__.__name__}> [\n" 179 + "\n".join( 180 [ 181 " " 182 + " ".join( 183 map(str, [self._data[(1 << i) + j] for j in range(1 << i)]) 184 ) 185 for i in range(self._log + 1) 186 ] 187 ) 188 + "\n]" 189 )
190 191 def __getitem__(self, k: int) -> T: 192 assert ( 193 -self._n <= k < self._n 194 ), f"IndexError: {self.__class__.__name__}.__getitem__({k}), n={self._n}" 195 return self.get(k) 196 197 def __setitem__(self, k: int, v: T): 198 assert ( 199 -self._n <= k < self._n 200 ), f"IndexError: {self.__class__.__name__}.__setitem__{k}, {v}), n={self._n}" 201 self.set(k, v) 202 203 def __len__(self) -> int: 204 return self._n 205 206 def __str__(self) -> str: 207 return str(self.tolist()) 208 209 def __repr__(self) -> str: 210 return f"{self.__class__.__name__}({self})"