1# from titan_pylib.data_structures.segment_tree.range_set_range_composite import RangeSetRangeComposite
2# from titan_pylib.data_structures.segment_tree.segment_tree import SegmentTree
3# from titan_pylib.data_structures.segment_tree.segment_tree_interface import (
4# SegmentTreeInterface,
5# )
6from abc import ABC, abstractmethod
7from typing import TypeVar, Generic, Union, Iterable, Callable
8
9T = TypeVar("T")
10
11
12class SegmentTreeInterface(ABC, Generic[T]):
13
14 @abstractmethod
15 def __init__(self, n_or_a: Union[int, Iterable[T]], op: Callable[[T, T], T], e: T):
16 raise NotImplementedError
17
18 @abstractmethod
19 def set(self, k: int, v: T) -> None:
20 raise NotImplementedError
21
22 @abstractmethod
23 def get(self, k: int) -> T:
24 raise NotImplementedError
25
26 @abstractmethod
27 def prod(self, l: int, r: int) -> T:
28 raise NotImplementedError
29
30 @abstractmethod
31 def all_prod(self) -> T:
32 raise NotImplementedError
33
34 @abstractmethod
35 def max_right(self, l: int, f: Callable[[T], bool]) -> int:
36 raise NotImplementedError
37
38 @abstractmethod
39 def min_left(self, r: int, f: Callable[[T], bool]) -> int:
40 raise NotImplementedError
41
42 @abstractmethod
43 def tolist(self) -> list[T]:
44 raise NotImplementedError
45
46 @abstractmethod
47 def __getitem__(self, k: int) -> T:
48 raise NotImplementedError
49
50 @abstractmethod
51 def __setitem__(self, k: int, v: T) -> None:
52 raise NotImplementedError
53
54 @abstractmethod
55 def __str__(self):
56 raise NotImplementedError
57
58 @abstractmethod
59 def __repr__(self):
60 raise NotImplementedError
61from typing import Generic, Iterable, TypeVar, Callable, Union
62
63T = TypeVar("T")
64
65
66class SegmentTree(SegmentTreeInterface, Generic[T]):
67 """セグ木です。非再帰です。"""
68
69 def __init__(
70 self, n_or_a: Union[int, Iterable[T]], op: Callable[[T, T], T], e: T
71 ) -> None:
72 """``SegmentTree`` を構築します。
73 :math:`O(n)` です。
74
75 Args:
76 n_or_a (Union[int, Iterable[T]]): ``n: int`` のとき、 ``e`` を初期値として長さ ``n`` の ``SegmentTree`` を構築します。
77 ``a: Iterable[T]`` のとき、 ``a`` から ``SegmentTree`` を構築します。
78 op (Callable[[T, T], T]): 2項演算の関数です。
79 e (T): 単位元です。
80 """
81 self._op = op
82 self._e = e
83 if isinstance(n_or_a, int):
84 self._n = n_or_a
85 self._log = (self._n - 1).bit_length()
86 self._size = 1 << self._log
87 self._data = [e] * (self._size << 1)
88 else:
89 n_or_a = list(n_or_a)
90 self._n = len(n_or_a)
91 self._log = (self._n - 1).bit_length()
92 self._size = 1 << self._log
93 _data = [e] * (self._size << 1)
94 _data[self._size : self._size + self._n] = n_or_a
95 for i in range(self._size - 1, 0, -1):
96 _data[i] = op(_data[i << 1], _data[i << 1 | 1])
97 self._data = _data
98
99 def set(self, k: int, v: T) -> None:
100 """一点更新です。
101 :math:`O(\\log{n})` です。
102
103 Args:
104 k (int): 更新するインデックスです。
105 v (T): 更新する値です。
106
107 制約:
108 :math:`-n \\leq n \\leq k < n`
109 """
110 assert (
111 -self._n <= k < self._n
112 ), f"IndexError: {self.__class__.__name__}.set({k}, {v}), n={self._n}"
113 if k < 0:
114 k += self._n
115 k += self._size
116 self._data[k] = v
117 for _ in range(self._log):
118 k >>= 1
119 self._data[k] = self._op(self._data[k << 1], self._data[k << 1 | 1])
120
121 def get(self, k: int) -> T:
122 """一点取得です。
123 :math:`O(1)` です。
124
125 Args:
126 k (int): インデックスです。
127
128 制約:
129 :math:`-n \\leq n \\leq k < n`
130 """
131 assert (
132 -self._n <= k < self._n
133 ), f"IndexError: {self.__class__.__name__}.get({k}), n={self._n}"
134 if k < 0:
135 k += self._n
136 return self._data[k + self._size]
137
138 def prod(self, l: int, r: int) -> T:
139 """区間 ``[l, r)`` の総積を返します。
140 :math:`O(\\log{n})` です。
141
142 Args:
143 l (int): インデックスです。
144 r (int): インデックスです。
145
146 制約:
147 :math:`0 \\leq l \\leq r \\leq n`
148 """
149 assert (
150 0 <= l <= r <= self._n
151 ), f"IndexError: {self.__class__.__name__}.prod({l}, {r})"
152 l += self._size
153 r += self._size
154 lres = self._e
155 rres = self._e
156 while l < r:
157 if l & 1:
158 lres = self._op(lres, self._data[l])
159 l += 1
160 if r & 1:
161 rres = self._op(self._data[r ^ 1], rres)
162 l >>= 1
163 r >>= 1
164 return self._op(lres, rres)
165
166 def all_prod(self) -> T:
167 """区間 ``[0, n)`` の総積を返します。
168 :math:`O(1)` です。
169 """
170 return self._data[1]
171
172 def max_right(self, l: int, f: Callable[[T], bool]) -> int:
173 """Find the largest index R s.t. f([l, R)) == True. / O(\\log{n})"""
174 assert (
175 0 <= l <= self._n
176 ), f"IndexError: {self.__class__.__name__}.max_right({l}, f) index out of range"
177 # assert f(self._e), \
178 # f'{self.__class__.__name__}.max_right({l}, f), f({self._e}) must be true.'
179 if l == self._n:
180 return self._n
181 l += self._size
182 s = self._e
183 while True:
184 while l & 1 == 0:
185 l >>= 1
186 if not f(self._op(s, self._data[l])):
187 while l < self._size:
188 l <<= 1
189 if f(self._op(s, self._data[l])):
190 s = self._op(s, self._data[l])
191 l |= 1
192 return l - self._size
193 s = self._op(s, self._data[l])
194 l += 1
195 if l & -l == l:
196 break
197 return self._n
198
199 def min_left(self, r: int, f: Callable[[T], bool]) -> int:
200 """Find the smallest index L s.t. f([L, r)) == True. / O(\\log{n})"""
201 assert (
202 0 <= r <= self._n
203 ), f"IndexError: {self.__class__.__name__}.min_left({r}, f) index out of range"
204 # assert f(self._e), \
205 # f'{self.__class__.__name__}.min_left({r}, f), f({self._e}) must be true.'
206 if r == 0:
207 return 0
208 r += self._size
209 s = self._e
210 while True:
211 r -= 1
212 while r > 1 and r & 1:
213 r >>= 1
214 if not f(self._op(self._data[r], s)):
215 while r < self._size:
216 r = r << 1 | 1
217 if f(self._op(self._data[r], s)):
218 s = self._op(self._data[r], s)
219 r ^= 1
220 return r + 1 - self._size
221 s = self._op(self._data[r], s)
222 if r & -r == r:
223 break
224 return 0
225
226 def tolist(self) -> list[T]:
227 """リストにして返します。
228 :math:`O(n)` です。
229 """
230 return [self.get(i) for i in range(self._n)]
231
232 def show(self) -> None:
233 """デバッグ用のメソッドです。"""
234 print(
235 f"<{self.__class__.__name__}> [\n"
236 + "\n".join(
237 [
238 " "
239 + " ".join(
240 map(str, [self._data[(1 << i) + j] for j in range(1 << i)])
241 )
242 for i in range(self._log + 1)
243 ]
244 )
245 + "\n]"
246 )
247
248 def __getitem__(self, k: int) -> T:
249 assert (
250 -self._n <= k < self._n
251 ), f"IndexError: {self.__class__.__name__}.__getitem__({k}), n={self._n}"
252 return self.get(k)
253
254 def __setitem__(self, k: int, v: T):
255 assert (
256 -self._n <= k < self._n
257 ), f"IndexError: {self.__class__.__name__}.__setitem__{k}, {v}), n={self._n}"
258 self.set(k, v)
259
260 def __len__(self) -> int:
261 return self._n
262
263 def __str__(self) -> str:
264 return str(self.tolist())
265
266 def __repr__(self) -> str:
267 return f"{self.__class__.__name__}({self})"
268# from titan_pylib.data_structures.set.wordsize_tree_set import WordsizeTreeSet
269from array import array
270from typing import Iterable, Optional
271
272
273class WordsizeTreeSet:
274 """``[0, u)`` の整数集合を管理する32分木です。
275 空間 :math:`O(u)` であることに注意してください。
276 """
277
278 def __init__(self, u: int, a: Iterable[int] = []) -> None:
279 """:math:`O(u)` です。"""
280 assert u >= 0
281 u += 1 # 念のため
282 self.u = u
283 data = []
284 len_ = 0
285 if a:
286 u >>= 5
287 A = array("I", bytes(4 * (u + 1)))
288 for a_ in a:
289 assert (
290 0 <= a_ < self.u
291 ), f"ValueError: {self.__class__.__name__}.__init__, {a_}, u={u}"
292 if A[a_ >> 5] >> (a_ & 31) & 1 == 0:
293 len_ += 1
294 A[a_ >> 5] |= 1 << (a_ & 31)
295 data.append(A)
296 while u:
297 a = array("I", bytes(4 * ((u >> 5) + 1)))
298 for i in range(u + 1):
299 if A[i]:
300 a[i >> 5] |= 1 << (i & 31)
301 data.append(a)
302 A = a
303 u >>= 5
304 else:
305 while u:
306 u >>= 5
307 data.append(array("I", bytes(4 * (u + 1))))
308 self.data: list[array[int]] = data
309 self.len: int = len_
310 self.len_data: int = len(data)
311
312 def add(self, v: int) -> bool:
313 """整数 ``v`` を個追加します。
314 :math:`O(\\log{u})` です。
315 """
316 assert (
317 0 <= v < self.u
318 ), f"ValueError: {self.__class__.__name__}.add({v}), u={self.u}"
319 if self.data[0][v >> 5] >> (v & 31) & 1:
320 return False
321 self.len += 1
322 for a in self.data:
323 a[v >> 5] |= 1 << (v & 31)
324 v >>= 5
325 return True
326
327 def discard(self, v: int) -> bool:
328 """整数 ``v`` を削除します。
329 :math:`O(\\log{u})` です。
330 """
331 assert (
332 0 <= v < self.u
333 ), f"ValueError: {self.__class__.__name__}.discard({v}), u={self.u}"
334 if self.data[0][v >> 5] >> (v & 31) & 1 == 0:
335 return False
336 self.len -= 1
337 for a in self.data:
338 a[v >> 5] &= ~(1 << (v & 31))
339 v >>= 5
340 if a[v]:
341 break
342 return True
343
344 def remove(self, v: int) -> None:
345 """整数 ``v`` を削除します。
346 :math:`O(\\log{u})` です。
347
348 Note: ``v`` が存在しないとき、例外を投げます。
349 """
350 assert (
351 0 <= v < self.u
352 ), f"ValueError: {self.__class__.__name__}.remove({v}), u={self.u}"
353 assert self.discard(v), f"ValueError: {v} not in self."
354
355 def ge(self, v: int) -> Optional[int]:
356 """``v`` 以上で最小の要素を返します。存在しないとき、 ``None``を返します。
357 :math:`O(\\log{u})` です。
358 """
359 assert (
360 0 <= v < self.u
361 ), f"ValueError: {self.__class__.__name__}.ge({v}), u={self.u}"
362 data = self.data
363 d = 0
364 while True:
365 if d >= self.len_data or v >> 5 >= len(data[d]):
366 return None
367 m = data[d][v >> 5] & ((~0) << (v & 31))
368 if m == 0:
369 d += 1
370 v = (v >> 5) + 1
371 else:
372 v = (v >> 5 << 5) + (m & -m).bit_length() - 1
373 if d == 0:
374 break
375 v <<= 5
376 d -= 1
377 return v
378
379 def gt(self, v: int) -> Optional[int]:
380 """``v`` より大きい値で最小の要素を返します。存在しないとき、 ``None``を返します。
381 :math:`O(\\log{u})` です。
382 """
383 assert (
384 0 <= v < self.u
385 ), f"ValueError: {self.__class__.__name__}.gt({v}), u={self.u}"
386 if v + 1 == self.u:
387 return
388 return self.ge(v + 1)
389
390 def le(self, v: int) -> Optional[int]:
391 """``v`` 以下で最大の要素を返します。存在しないとき、 ``None``を返します。
392 :math:`O(\\log{u})` です。
393 """
394 assert (
395 0 <= v < self.u
396 ), f"ValueError: {self.__class__.__name__}.le({v}), u={self.u}"
397 data = self.data
398 d = 0
399 while True:
400 if v < 0 or d >= self.len_data:
401 return None
402 m = data[d][v >> 5] & ~((~1) << (v & 31))
403 if m == 0:
404 d += 1
405 v = (v >> 5) - 1
406 else:
407 v = (v >> 5 << 5) + m.bit_length() - 1
408 if d == 0:
409 break
410 v <<= 5
411 v += 31
412 d -= 1
413 return v
414
415 def lt(self, v: int) -> Optional[int]:
416 """``v`` より小さい値で最大の要素を返します。存在しないとき、 ``None``を返します。
417 :math:`O(\\log{u})` です。
418 """
419 assert (
420 0 <= v < self.u
421 ), f"ValueError: {self.__class__.__name__}.lt({v}), u={self.u}"
422 if v - 1 == 0:
423 return
424 return self.le(v - 1)
425
426 def get_min(self) -> Optional[int]:
427 """`最小値を返します。存在しないとき、 ``None``を返します。
428 :math:`O(\\log{u})` です。
429 """
430 return self.ge(0)
431
432 def get_max(self) -> Optional[int]:
433 """最大値を返します。存在しないとき、 ``None``を返します。
434 :math:`O(\\log{u})` です。
435 """
436 return self.le(self.u - 1)
437
438 def pop_min(self) -> int:
439 """最小値を削除して返します。
440 :math:`O(\\log{u})` です。
441 """
442 v = self.get_min()
443 assert (
444 v is not None
445 ), f"IndexError: pop_min() from empty {self.__class__.__name__}."
446 self.discard(v)
447 return v
448
449 def pop_max(self) -> int:
450 """最大値を削除して返します。
451 :math:`O(\\log{u})` です。
452 """
453 v = self.get_max()
454 assert (
455 v is not None
456 ), f"IndexError: pop_max() from empty {self.__class__.__name__}."
457 self.discard(v)
458 return v
459
460 def clear(self) -> None:
461 """集合を空にします。
462 :math:`O(n\\log{u})` です。
463 """
464 for e in self:
465 self.discard(e)
466 self.len = 0
467
468 def tolist(self) -> list[int]:
469 """リストにして返します。
470 :math:`O(n\\log{u})` です。
471 """
472 return [x for x in self]
473
474 def __bool__(self):
475 return self.len > 0
476
477 def __len__(self):
478 return self.len
479
480 def __contains__(self, v: int):
481 assert (
482 0 <= v < self.u
483 ), f"ValueError: {v} in {self.__class__.__name__}, u={self.u}"
484 return self.data[0][v >> 5] >> (v & 31) & 1 == 1
485
486 def __iter__(self):
487 self._val = self.ge(0)
488 return self
489
490 def __next__(self):
491 if self._val is None:
492 raise StopIteration
493 pre = self._val
494 self._val = self.gt(pre)
495 return pre
496
497 def __str__(self):
498 return "{" + ", ".join(map(str, self)) + "}"
499
500 def __repr__(self):
501 return f"{self.__class__.__name__}({self.u}, {self})"
502from typing import Union, Callable, TypeVar, Generic, Iterable
503
504T = TypeVar("T")
505
506
507class RangeSetRangeComposite(Generic[T]):
508 """区間更新+区間積です。"""
509
510 def __init__(
511 self,
512 n_or_a: Union[int, Iterable[T]],
513 op: Callable[[T, T], T],
514 pow_: Callable[[T, int], T],
515 e: T,
516 ) -> None:
517 """
518 :math:`O(nlogn)` です。
519
520 Args:
521 n_or_a (Union[int, Iterable[T]]): n or a
522 op (Callable[[T, T], T]): 2項演算です。
523 pow_ (Callable[[T, int], T]): 累乗演算です。
524 e (T): 単位元です。
525 """
526 self.op = op
527 self.pow = pow_
528 self.e = e
529 a = [e] * n_or_a if isinstance(n_or_a, int) else list(n_or_a)
530 a.append(e)
531 self.seg = SegmentTree(a, op, e)
532 self.n = len(self.seg)
533 self.indx = WordsizeTreeSet(self.n + 1, range(self.n + 1))
534 self.val = a
535 self.beki = [1] * self.n
536
537 def prod(self, l: int, r: int) -> T:
538 """区間 ``[l, r)`` の総積を返します。
539 :math:`O(logn)` です。
540 ``op`` を :math:`O(logn)` 回、 ``pow_`` を :math:`O(1)` 回呼び出します。
541 """
542 ll = self.indx.ge(l)
543 rr = self.indx.le(r)
544 ans = self.e
545 if ll != l:
546 l0 = self.indx.le(l)
547 beki = self.beki[l0] - (l - l0) if l0 + self.beki[l0] <= r else r - l
548 ans = self.pow(self.val[l0], beki)
549 if ll < rr:
550 ans = self.op(ans, self.seg.prod(ll, rr))
551 if rr != r and l <= rr:
552 ans = self.op(ans, self.pow(self.val[rr], r - rr))
553 return ans
554
555 def apply(self, l: int, r: int, f: T) -> None:
556 """区間 ``[l, r)`` を ``f`` に更新します。
557 :math:`O(logn)` です。
558 ``op`` を :math:`O(logn)` 回、 ``pow_`` を :math:`O(1)` 回呼び出します。
559 """
560 indx, val, beki, seg = self.indx, self.val, self.beki, self.seg
561
562 l0 = indx.le(l)
563 r0 = indx.le(r)
564 if l != l0:
565 seg[l0] = self.pow(val[l0], l - l0)
566 if r != r0:
567 beki[r] = beki[r0] - (r - r0)
568 indx.add(r)
569 val[r] = val[r0]
570 seg[r] = self.pow(val[r], beki[r])
571 if l != l0:
572 beki[l0] = l - l0
573
574 i = indx.gt(l)
575 while i < r:
576 seg[i] = self.e
577 indx.discard(i)
578 i = indx.gt(i)
579 val[l] = f
580 indx.add(l)
581 beki[l] = r - l
582 seg[l] = self.pow(f, beki[l])