1from titan_pylib.data_structures.segment_tree.segment_tree_interface import (
2 SegmentTreeInterface,
3)
4from typing import Generic, Iterable, TypeVar, Callable, Union
5
6T = TypeVar("T")
7
8
[docs]
9class SegmentTree(SegmentTreeInterface, Generic[T]):
10 """セグ木です。非再帰です。"""
11
12 def __init__(
13 self, n_or_a: Union[int, Iterable[T]], op: Callable[[T, T], T], e: T
14 ) -> None:
15 """``SegmentTree`` を構築します。
16 :math:`O(n)` です。
17
18 Args:
19 n_or_a (Union[int, Iterable[T]]): ``n: int`` のとき、 ``e`` を初期値として長さ ``n`` の ``SegmentTree`` を構築します。
20 ``a: Iterable[T]`` のとき、 ``a`` から ``SegmentTree`` を構築します。
21 op (Callable[[T, T], T]): 2項演算の関数です。
22 e (T): 単位元です。
23 """
24 self._op = op
25 self._e = e
26 if isinstance(n_or_a, int):
27 self._n = n_or_a
28 self._log = (self._n - 1).bit_length()
29 self._size = 1 << self._log
30 self._data = [e] * (self._size << 1)
31 else:
32 n_or_a = list(n_or_a)
33 self._n = len(n_or_a)
34 self._log = (self._n - 1).bit_length()
35 self._size = 1 << self._log
36 _data = [e] * (self._size << 1)
37 _data[self._size : self._size + self._n] = n_or_a
38 for i in range(self._size - 1, 0, -1):
39 _data[i] = op(_data[i << 1], _data[i << 1 | 1])
40 self._data = _data
41
[docs]
42 def set(self, k: int, v: T) -> None:
43 """一点更新です。
44 :math:`O(\\log{n})` です。
45
46 Args:
47 k (int): 更新するインデックスです。
48 v (T): 更新する値です。
49
50 制約:
51 :math:`-n \\leq n \\leq k < n`
52 """
53 assert (
54 -self._n <= k < self._n
55 ), f"IndexError: {self.__class__.__name__}.set({k}, {v}), n={self._n}"
56 if k < 0:
57 k += self._n
58 k += self._size
59 self._data[k] = v
60 for _ in range(self._log):
61 k >>= 1
62 self._data[k] = self._op(self._data[k << 1], self._data[k << 1 | 1])
63
[docs]
64 def get(self, k: int) -> T:
65 """一点取得です。
66 :math:`O(1)` です。
67
68 Args:
69 k (int): インデックスです。
70
71 制約:
72 :math:`-n \\leq n \\leq k < n`
73 """
74 assert (
75 -self._n <= k < self._n
76 ), f"IndexError: {self.__class__.__name__}.get({k}), n={self._n}"
77 if k < 0:
78 k += self._n
79 return self._data[k + self._size]
80
[docs]
81 def prod(self, l: int, r: int) -> T:
82 """区間 ``[l, r)`` の総積を返します。
83 :math:`O(\\log{n})` です。
84
85 Args:
86 l (int): インデックスです。
87 r (int): インデックスです。
88
89 制約:
90 :math:`0 \\leq l \\leq r \\leq n`
91 """
92 assert (
93 0 <= l <= r <= self._n
94 ), f"IndexError: {self.__class__.__name__}.prod({l}, {r})"
95 l += self._size
96 r += self._size
97 lres = self._e
98 rres = self._e
99 while l < r:
100 if l & 1:
101 lres = self._op(lres, self._data[l])
102 l += 1
103 if r & 1:
104 rres = self._op(self._data[r ^ 1], rres)
105 l >>= 1
106 r >>= 1
107 return self._op(lres, rres)
108
[docs]
109 def all_prod(self) -> T:
110 """区間 ``[0, n)`` の総積を返します。
111 :math:`O(1)` です。
112 """
113 return self._data[1]
114
[docs]
115 def max_right(self, l: int, f: Callable[[T], bool]) -> int:
116 """Find the largest index R s.t. f([l, R)) == True. / O(\\log{n})"""
117 assert (
118 0 <= l <= self._n
119 ), f"IndexError: {self.__class__.__name__}.max_right({l}, f) index out of range"
120 # assert f(self._e), \
121 # f'{self.__class__.__name__}.max_right({l}, f), f({self._e}) must be true.'
122 if l == self._n:
123 return self._n
124 l += self._size
125 s = self._e
126 while True:
127 while l & 1 == 0:
128 l >>= 1
129 if not f(self._op(s, self._data[l])):
130 while l < self._size:
131 l <<= 1
132 if f(self._op(s, self._data[l])):
133 s = self._op(s, self._data[l])
134 l |= 1
135 return l - self._size
136 s = self._op(s, self._data[l])
137 l += 1
138 if l & -l == l:
139 break
140 return self._n
141
[docs]
142 def min_left(self, r: int, f: Callable[[T], bool]) -> int:
143 """Find the smallest index L s.t. f([L, r)) == True. / O(\\log{n})"""
144 assert (
145 0 <= r <= self._n
146 ), f"IndexError: {self.__class__.__name__}.min_left({r}, f) index out of range"
147 # assert f(self._e), \
148 # f'{self.__class__.__name__}.min_left({r}, f), f({self._e}) must be true.'
149 if r == 0:
150 return 0
151 r += self._size
152 s = self._e
153 while True:
154 r -= 1
155 while r > 1 and r & 1:
156 r >>= 1
157 if not f(self._op(self._data[r], s)):
158 while r < self._size:
159 r = r << 1 | 1
160 if f(self._op(self._data[r], s)):
161 s = self._op(self._data[r], s)
162 r ^= 1
163 return r + 1 - self._size
164 s = self._op(self._data[r], s)
165 if r & -r == r:
166 break
167 return 0
168
[docs]
169 def tolist(self) -> list[T]:
170 """リストにして返します。
171 :math:`O(n)` です。
172 """
173 return [self.get(i) for i in range(self._n)]
174
[docs]
175 def show(self) -> None:
176 """デバッグ用のメソッドです。"""
177 print(
178 f"<{self.__class__.__name__}> [\n"
179 + "\n".join(
180 [
181 " "
182 + " ".join(
183 map(str, [self._data[(1 << i) + j] for j in range(1 << i)])
184 )
185 for i in range(self._log + 1)
186 ]
187 )
188 + "\n]"
189 )
190
191 def __getitem__(self, k: int) -> T:
192 assert (
193 -self._n <= k < self._n
194 ), f"IndexError: {self.__class__.__name__}.__getitem__({k}), n={self._n}"
195 return self.get(k)
196
197 def __setitem__(self, k: int, v: T):
198 assert (
199 -self._n <= k < self._n
200 ), f"IndexError: {self.__class__.__name__}.__setitem__{k}, {v}), n={self._n}"
201 self.set(k, v)
202
203 def __len__(self) -> int:
204 return self._n
205
206 def __str__(self) -> str:
207 return str(self.tolist())
208
209 def __repr__(self) -> str:
210 return f"{self.__class__.__name__}({self})"