Source code for titan_pylib.data_structures.segment_tree.range_set_range_composite
1from titan_pylib.data_structures.segment_tree.segment_tree import SegmentTree
2from titan_pylib.data_structures.set.wordsize_tree_set import WordsizeTreeSet
3from typing import Union, Callable, TypeVar, Generic, Iterable
4
5T = TypeVar("T")
6
7
[docs]
8class RangeSetRangeComposite(Generic[T]):
9 """区間更新+区間積です。"""
10
11 def __init__(
12 self,
13 n_or_a: Union[int, Iterable[T]],
14 op: Callable[[T, T], T],
15 pow_: Callable[[T, int], T],
16 e: T,
17 ) -> None:
18 """
19 :math:`O(nlogn)` です。
20
21 Args:
22 n_or_a (Union[int, Iterable[T]]): n or a
23 op (Callable[[T, T], T]): 2項演算です。
24 pow_ (Callable[[T, int], T]): 累乗演算です。
25 e (T): 単位元です。
26 """
27 self.op = op
28 self.pow = pow_
29 self.e = e
30 a = [e] * n_or_a if isinstance(n_or_a, int) else list(n_or_a)
31 a.append(e)
32 self.seg = SegmentTree(a, op, e)
33 self.n = len(self.seg)
34 self.indx = WordsizeTreeSet(self.n + 1, range(self.n + 1))
35 self.val = a
36 self.beki = [1] * self.n
37
[docs]
38 def prod(self, l: int, r: int) -> T:
39 """区間 ``[l, r)`` の総積を返します。
40 :math:`O(logn)` です。
41 ``op`` を :math:`O(logn)` 回、 ``pow_`` を :math:`O(1)` 回呼び出します。
42 """
43 ll = self.indx.ge(l)
44 rr = self.indx.le(r)
45 ans = self.e
46 if ll != l:
47 l0 = self.indx.le(l)
48 beki = self.beki[l0] - (l - l0) if l0 + self.beki[l0] <= r else r - l
49 ans = self.pow(self.val[l0], beki)
50 if ll < rr:
51 ans = self.op(ans, self.seg.prod(ll, rr))
52 if rr != r and l <= rr:
53 ans = self.op(ans, self.pow(self.val[rr], r - rr))
54 return ans
55
[docs]
56 def apply(self, l: int, r: int, f: T) -> None:
57 """区間 ``[l, r)`` を ``f`` に更新します。
58 :math:`O(logn)` です。
59 ``op`` を :math:`O(logn)` 回、 ``pow_`` を :math:`O(1)` 回呼び出します。
60 """
61 indx, val, beki, seg = self.indx, self.val, self.beki, self.seg
62
63 l0 = indx.le(l)
64 r0 = indx.le(r)
65 if l != l0:
66 seg[l0] = self.pow(val[l0], l - l0)
67 if r != r0:
68 beki[r] = beki[r0] - (r - r0)
69 indx.add(r)
70 val[r] = val[r0]
71 seg[r] = self.pow(val[r], beki[r])
72 if l != l0:
73 beki[l0] = l - l0
74
75 i = indx.gt(l)
76 while i < r:
77 seg[i] = self.e
78 indx.discard(i)
79 i = indx.gt(i)
80 val[l] = f
81 indx.add(l)
82 beki[l] = r - l
83 seg[l] = self.pow(f, beki[l])