1# from titan_pylib.data_structures.segment_tree.lazy_segment_tree import LazySegmentTree
2from typing import Union, Callable, TypeVar, Generic, Iterable
3
T = TypeVar("T")  # element type, combined with ``op``
F = TypeVar("F")  # lazy-operator type, applied to elements with ``mapping``


class LazySegmentTree(Generic[T, F]):
    """Lazy segment tree: range operator application and range product.

    Elements of type ``T`` are combined with ``op`` (identity ``e``).
    Operators of type ``F`` act on elements through ``mapping(f, x)`` and are
    merged with ``composition(f, g)``, where ``f`` is the newer operator and
    ``g`` the one already pending (identity ``id``).

    Leaves live at ``data[size : size + n]``. The ``size - n`` padding leaves
    hold ``e``; range operations only ever touch nodes lying inside
    ``[0, n)``, so padding stays ``e``.
    """

    def __init__(
        self,
        n_or_a: Union[int, Iterable[T]],
        op: Callable[[T, T], T],
        mapping: Callable[[F, T], T],
        composition: Callable[[F, F], F],
        e: T,
        id: F,
    ) -> None:
        """Build a tree of ``n`` copies of ``e``, or over an iterable of values.

        Args:
            n_or_a: Number of elements (all initialised to ``e``) or an
                iterable of initial values.
            op: Binary operation combining two elements (expected associative).
            mapping: ``mapping(f, x)`` returns ``x`` with operator ``f`` applied.
            composition: ``composition(f, g)`` merges operators; ``f`` is newer.
            e: Identity element of ``op``.
            id: Identity operator (name shadows the builtin; kept for API
                compatibility).
        """
        self.op: Callable[[T, T], T] = op
        self.mapping: Callable[[F, T], T] = mapping
        self.composition: Callable[[F, F], F] = composition
        self.e: T = e
        self.id: F = id
        a = None if isinstance(n_or_a, int) else list(n_or_a)
        self.n = n_or_a if a is None else len(a)
        self.log = (self.n - 1).bit_length()
        self.size = 1 << self.log
        data = [e] * (self.size << 1)
        if a is not None:
            data[self.size : self.size + self.n] = a
            # Build internal nodes bottom-up from the leaves.
            for i in range(self.size - 1, 0, -1):
                data[i] = op(data[i << 1], data[i << 1 | 1])
        self.data = data
        # lazy[k] is the operator pending for the children of internal node k.
        self.lazy = [id] * self.size

    def _update(self, k: int) -> None:
        """Recompute internal node ``k`` from its two children."""
        self.data[k] = self.op(self.data[k << 1], self.data[k << 1 | 1])

    def _all_apply(self, k: int, f: F) -> None:
        """Apply ``f`` to node ``k`` and, if internal, queue it for its children."""
        self.data[k] = self.mapping(f, self.data[k])
        if k >= self.size:
            return  # leaves have no lazy slot
        self.lazy[k] = self.composition(f, self.lazy[k])

    def _propagate(self, k: int) -> None:
        """Push the pending operator of internal node ``k`` down to its children."""
        if self.lazy[k] == self.id:
            return
        self._all_apply(k << 1, self.lazy[k])
        self._all_apply(k << 1 | 1, self.lazy[k])
        self.lazy[k] = self.id

    def apply_point(self, k: int, f: F) -> None:
        """Apply ``f`` to the single element at index ``k`` (negative allowed)."""
        assert (
            -self.n <= k < self.n
        ), f"IndexError: {self.__class__.__name__}.apply_point({k}, {f}), n={self.n}"
        if k < 0:
            k += self.n
        k += self.size
        for i in range(self.log, 0, -1):
            self._propagate(k >> i)
        self.data[k] = self.mapping(f, self.data[k])
        for i in range(1, self.log + 1):
            self._update(k >> i)

    def _upper_propagate(self, l: int, r: int) -> None:
        """Push pending operators on the boundary paths of leaf range ``[l, r)``.

        ``l`` and ``r`` are already offset by ``size``. The second condition
        skips pushing ``(r - 1) >> i`` when it is the same node just pushed
        for ``l``.
        """
        for i in range(self.log, 0, -1):
            if l >> i << i != l:
                self._propagate(l >> i)
            if (r >> i << i != r) and (l >> i != (r - 1) >> i or l >> i << i == l):
                self._propagate((r - 1) >> i)

    def apply(self, l: int, r: int, f: F) -> None:
        """Apply operator ``f`` to every element in ``[l, r)``."""
        assert (
            0 <= l <= r <= self.n
        ), f"IndexError: {self.__class__.__name__}.apply({l}, {r}, {f}), n={self.n}"
        if l == r:
            return
        if f == self.id:
            return
        l += self.size
        r += self.size
        self._upper_propagate(l, r)
        l2, r2 = l, r
        while l < r:
            if l & 1:
                self._all_apply(l, f)
                l += 1
            if r & 1:
                self._all_apply(r ^ 1, f)
            l >>= 1
            r >>= 1
        # Recompute ancestors of the two boundaries; skip a node shared by both.
        ll, rr = l2, r2 - 1
        for i in range(1, self.log + 1):
            ll >>= 1
            rr >>= 1
            if ll << i != l2:
                self._update(ll)
            if (ll << i == l2 or ll != rr) and (r2 >> i << i != r2):
                self._update(rr)

    def all_apply(self, f: F) -> None:
        """Apply operator ``f`` to every element.

        Delegates to ``apply(0, n, f)`` so that ``data[1]`` (read by
        ``all_prod``) is updated and the padding leaves are left untouched.
        When ``n`` is a power of two this reduces to a single root update.
        """
        self.apply(0, self.n, f)

    def prod(self, l: int, r: int) -> T:
        """Return ``op`` folded over the elements in ``[l, r)`` (``e`` if empty)."""
        assert (
            0 <= l <= r <= self.n
        ), f"IndexError: {self.__class__.__name__}.prod({l}, {r}), n={self.n}"
        if l == r:
            return self.e
        l += self.size
        r += self.size
        self._upper_propagate(l, r)
        lres = self.e
        rres = self.e
        while l < r:
            if l & 1:
                lres = self.op(lres, self.data[l])
                l += 1
            if r & 1:
                rres = self.op(self.data[r ^ 1], rres)
            l >>= 1
            r >>= 1
        return self.op(lres, rres)

    def all_prod(self) -> T:
        """Return ``op`` folded over all elements (the root value)."""
        return self.data[1]

    def all_propagate(self) -> None:
        """Push every pending operator down to the leaves."""
        # Node 0 is unused; internal nodes are 1 .. size - 1 (parents first).
        for i in range(1, self.size):
            self._propagate(i)

    def tolist(self) -> list[T]:
        """Return the current elements as a new list."""
        self.all_propagate()
        return self.data[self.size : self.size + self.n]

    def max_right(self, l: int, f: Callable[[T], bool]) -> int:
        """Return the largest ``r`` with ``f(prod(l, r))`` true.

        Like the AtCoder Library version, this assumes ``f(e)`` is true and
        that ``f`` is monotone (once false, it stays false as ``r`` grows).
        """
        assert 0 <= l <= self.n
        # assert f(self.e)
        if l == self.n:
            return self.n
        l += self.size
        for i in range(self.log, 0, -1):
            self._propagate(l >> i)
        s = self.e
        while True:
            while l & 1 == 0:
                l >>= 1
            if not f(self.op(s, self.data[l])):
                # Descend to the first leaf that makes f false.
                while l < self.size:
                    self._propagate(l)
                    l <<= 1
                    if f(self.op(s, self.data[l])):
                        s = self.op(s, self.data[l])
                        l |= 1
                return l - self.size
            s = self.op(s, self.data[l])
            l += 1
            if l & -l == l:
                break
        return self.n

    def min_left(self, r: int, f: Callable[[T], bool]) -> int:
        """Return the smallest ``l`` with ``f(prod(l, r))`` true.

        Assumes ``f(e)`` is true and ``f`` is monotone as ``l`` decreases.
        """
        assert 0 <= r <= self.n
        # assert f(self.e)
        if r == 0:
            return 0
        r += self.size
        for i in range(self.log, 0, -1):
            self._propagate((r - 1) >> i)
        s = self.e
        while True:
            r -= 1
            while r > 1 and r & 1:
                r >>= 1
            if not f(self.op(self.data[r], s)):
                # Descend to the last leaf that makes f false.
                while r < self.size:
                    self._propagate(r)
                    r = r << 1 | 1
                    if f(self.op(self.data[r], s)):
                        s = self.op(self.data[r], s)
                        r ^= 1
                return r + 1 - self.size
            s = self.op(self.data[r], s)
            if r & -r == r:
                break
        return 0

    def __getitem__(self, k: int) -> T:
        """Return the element at index ``k`` (negative allowed)."""
        assert (
            -self.n <= k < self.n
        ), f"IndexError: {self.__class__.__name__}[{k}], n={self.n}"
        if k < 0:
            k += self.n
        k += self.size
        for i in range(self.log, 0, -1):
            self._propagate(k >> i)
        return self.data[k]

    def __setitem__(self, k: int, v: T) -> None:
        """Overwrite the element at index ``k`` (negative allowed) with ``v``."""
        assert (
            -self.n <= k < self.n
        ), f"IndexError: {self.__class__.__name__}[{k}] = {v}, n={self.n}"
        if k < 0:
            k += self.n
        k += self.size
        for i in range(self.log, 0, -1):
            self._propagate(k >> i)
        self.data[k] = v
        for i in range(1, self.log + 1):
            self._update(k >> i)

    def __str__(self) -> str:
        return str(self.tolist())

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self})"