segment_tree¶
ソースコード¶
from titan_pylib.data_structures.segment_tree.segment_tree import SegmentTree
展開済みコード¶
1# from titan_pylib.data_structures.segment_tree.segment_tree import SegmentTree
2# from titan_pylib.data_structures.segment_tree.segment_tree_interface import (
3# SegmentTreeInterface,
4# )
5from abc import ABC, abstractmethod
6from typing import TypeVar, Generic, Union, Iterable, Callable
7
8T = TypeVar("T")
9
10
11class SegmentTreeInterface(ABC, Generic[T]):
12
13 @abstractmethod
14 def __init__(self, n_or_a: Union[int, Iterable[T]], op: Callable[[T, T], T], e: T):
15 raise NotImplementedError
16
17 @abstractmethod
18 def set(self, k: int, v: T) -> None:
19 raise NotImplementedError
20
21 @abstractmethod
22 def get(self, k: int) -> T:
23 raise NotImplementedError
24
25 @abstractmethod
26 def prod(self, l: int, r: int) -> T:
27 raise NotImplementedError
28
29 @abstractmethod
30 def all_prod(self) -> T:
31 raise NotImplementedError
32
33 @abstractmethod
34 def max_right(self, l: int, f: Callable[[T], bool]) -> int:
35 raise NotImplementedError
36
37 @abstractmethod
38 def min_left(self, r: int, f: Callable[[T], bool]) -> int:
39 raise NotImplementedError
40
41 @abstractmethod
42 def tolist(self) -> list[T]:
43 raise NotImplementedError
44
45 @abstractmethod
46 def __getitem__(self, k: int) -> T:
47 raise NotImplementedError
48
49 @abstractmethod
50 def __setitem__(self, k: int, v: T) -> None:
51 raise NotImplementedError
52
53 @abstractmethod
54 def __str__(self):
55 raise NotImplementedError
56
57 @abstractmethod
58 def __repr__(self):
59 raise NotImplementedError
60from typing import Generic, Iterable, TypeVar, Callable, Union
61
62T = TypeVar("T")
63
64
65class SegmentTree(SegmentTreeInterface, Generic[T]):
66 """セグ木です。非再帰です。"""
67
68 def __init__(
69 self, n_or_a: Union[int, Iterable[T]], op: Callable[[T, T], T], e: T
70 ) -> None:
71 """``SegmentTree`` を構築します。
72 :math:`O(n)` です。
73
74 Args:
75 n_or_a (Union[int, Iterable[T]]): ``n: int`` のとき、 ``e`` を初期値として長さ ``n`` の ``SegmentTree`` を構築します。
76 ``a: Iterable[T]`` のとき、 ``a`` から ``SegmentTree`` を構築します。
77 op (Callable[[T, T], T]): 2項演算の関数です。
78 e (T): 単位元です。
79 """
80 self._op = op
81 self._e = e
82 if isinstance(n_or_a, int):
83 self._n = n_or_a
84 self._log = (self._n - 1).bit_length()
85 self._size = 1 << self._log
86 self._data = [e] * (self._size << 1)
87 else:
88 n_or_a = list(n_or_a)
89 self._n = len(n_or_a)
90 self._log = (self._n - 1).bit_length()
91 self._size = 1 << self._log
92 _data = [e] * (self._size << 1)
93 _data[self._size : self._size + self._n] = n_or_a
94 for i in range(self._size - 1, 0, -1):
95 _data[i] = op(_data[i << 1], _data[i << 1 | 1])
96 self._data = _data
97
98 def set(self, k: int, v: T) -> None:
99 """一点更新です。
100 :math:`O(\\log{n})` です。
101
102 Args:
103 k (int): 更新するインデックスです。
104 v (T): 更新する値です。
105
106 制約:
107 :math:`-n \\leq n \\leq k < n`
108 """
109 assert (
110 -self._n <= k < self._n
111 ), f"IndexError: {self.__class__.__name__}.set({k}, {v}), n={self._n}"
112 if k < 0:
113 k += self._n
114 k += self._size
115 self._data[k] = v
116 for _ in range(self._log):
117 k >>= 1
118 self._data[k] = self._op(self._data[k << 1], self._data[k << 1 | 1])
119
120 def get(self, k: int) -> T:
121 """一点取得です。
122 :math:`O(1)` です。
123
124 Args:
125 k (int): インデックスです。
126
127 制約:
128 :math:`-n \\leq n \\leq k < n`
129 """
130 assert (
131 -self._n <= k < self._n
132 ), f"IndexError: {self.__class__.__name__}.get({k}), n={self._n}"
133 if k < 0:
134 k += self._n
135 return self._data[k + self._size]
136
137 def prod(self, l: int, r: int) -> T:
138 """区間 ``[l, r)`` の総積を返します。
139 :math:`O(\\log{n})` です。
140
141 Args:
142 l (int): インデックスです。
143 r (int): インデックスです。
144
145 制約:
146 :math:`0 \\leq l \\leq r \\leq n`
147 """
148 assert (
149 0 <= l <= r <= self._n
150 ), f"IndexError: {self.__class__.__name__}.prod({l}, {r})"
151 l += self._size
152 r += self._size
153 lres = self._e
154 rres = self._e
155 while l < r:
156 if l & 1:
157 lres = self._op(lres, self._data[l])
158 l += 1
159 if r & 1:
160 rres = self._op(self._data[r ^ 1], rres)
161 l >>= 1
162 r >>= 1
163 return self._op(lres, rres)
164
165 def all_prod(self) -> T:
166 """区間 ``[0, n)`` の総積を返します。
167 :math:`O(1)` です。
168 """
169 return self._data[1]
170
171 def max_right(self, l: int, f: Callable[[T], bool]) -> int:
172 """Find the largest index R s.t. f([l, R)) == True. / O(\\log{n})"""
173 assert (
174 0 <= l <= self._n
175 ), f"IndexError: {self.__class__.__name__}.max_right({l}, f) index out of range"
176 # assert f(self._e), \
177 # f'{self.__class__.__name__}.max_right({l}, f), f({self._e}) must be true.'
178 if l == self._n:
179 return self._n
180 l += self._size
181 s = self._e
182 while True:
183 while l & 1 == 0:
184 l >>= 1
185 if not f(self._op(s, self._data[l])):
186 while l < self._size:
187 l <<= 1
188 if f(self._op(s, self._data[l])):
189 s = self._op(s, self._data[l])
190 l |= 1
191 return l - self._size
192 s = self._op(s, self._data[l])
193 l += 1
194 if l & -l == l:
195 break
196 return self._n
197
198 def min_left(self, r: int, f: Callable[[T], bool]) -> int:
199 """Find the smallest index L s.t. f([L, r)) == True. / O(\\log{n})"""
200 assert (
201 0 <= r <= self._n
202 ), f"IndexError: {self.__class__.__name__}.min_left({r}, f) index out of range"
203 # assert f(self._e), \
204 # f'{self.__class__.__name__}.min_left({r}, f), f({self._e}) must be true.'
205 if r == 0:
206 return 0
207 r += self._size
208 s = self._e
209 while True:
210 r -= 1
211 while r > 1 and r & 1:
212 r >>= 1
213 if not f(self._op(self._data[r], s)):
214 while r < self._size:
215 r = r << 1 | 1
216 if f(self._op(self._data[r], s)):
217 s = self._op(self._data[r], s)
218 r ^= 1
219 return r + 1 - self._size
220 s = self._op(self._data[r], s)
221 if r & -r == r:
222 break
223 return 0
224
225 def tolist(self) -> list[T]:
226 """リストにして返します。
227 :math:`O(n)` です。
228 """
229 return [self.get(i) for i in range(self._n)]
230
231 def show(self) -> None:
232 """デバッグ用のメソッドです。"""
233 print(
234 f"<{self.__class__.__name__}> [\n"
235 + "\n".join(
236 [
237 " "
238 + " ".join(
239 map(str, [self._data[(1 << i) + j] for j in range(1 << i)])
240 )
241 for i in range(self._log + 1)
242 ]
243 )
244 + "\n]"
245 )
246
247 def __getitem__(self, k: int) -> T:
248 assert (
249 -self._n <= k < self._n
250 ), f"IndexError: {self.__class__.__name__}.__getitem__({k}), n={self._n}"
251 return self.get(k)
252
253 def __setitem__(self, k: int, v: T):
254 assert (
255 -self._n <= k < self._n
256 ), f"IndexError: {self.__class__.__name__}.__setitem__{k}, {v}), n={self._n}"
257 self.set(k, v)
258
259 def __len__(self) -> int:
260 return self._n
261
262 def __str__(self) -> str:
263 return str(self.tolist())
264
265 def __repr__(self) -> str:
266 return f"{self.__class__.__name__}({self})"
仕様¶
- class SegmentTree(n_or_a: int | Iterable[T], op: Callable[[T, T], T], e: T)[source]¶
Bases:
SegmentTreeInterface
,Generic
[T
]セグ木です。非再帰です。
- max_right(l: int, f: Callable[[T], bool]) int [source]¶
Find the largest index R s.t. f([l, R)) == True. / O(log{n})
- min_left(r: int, f: Callable[[T], bool]) int [source]¶
Find the smallest index L s.t. f([L, r)) == True. / O(log{n})
- prod(l: int, r: int) T [source]¶
区間
[l, r)
の総積を返します。 です。- Parameters:
l (int) – インデックスです。
r (int) – インデックスです。
- 制約: