1from typing import Union, Callable, TypeVar, Generic, Iterable
2
# Element type of the monoid stored in the tree.
T = TypeVar("T")
# Type of the lazy maps that are applied to ranges of elements.
F = TypeVar("F")


class LazySegmentTree(Generic[T, F]):
    """Lazy segment tree: range map application and range product queries.

    Holds a sequence of ``n`` elements of a monoid ``(T, op, e)``. Both
    ``apply(l, r, f)`` (apply map ``f`` to every element of ``[l, r)``) and
    ``prod(l, r)`` (``op``-product of ``[l, r)``) walk O(log n) levels.

    Layout: node 1 is the root and node ``k`` has children ``2k`` and
    ``2k + 1``. The leaves are ``data[size : 2 * size]``; leaves at index
    ``>= n`` are padding holding ``e``. ``data[k]`` always already reflects
    ``lazy[k]``; ``lazy[k]`` is the map still owed to both children of
    internal node ``k``.

    Args:
        n_or_a: Either the length ``n`` (every element starts as ``e``) or an
            iterable of initial elements.
        op: Binary operation on elements; assumed associative with identity
            ``e``.
        mapping: ``mapping(f, x)`` is element ``x`` after applying map ``f``.
            Because maps are applied to node aggregates, it is assumed to
            distribute over ``op``.
        composition: ``composition(f, g)`` is the single map equivalent to
            applying ``g`` first and then ``f``.
        e: Identity element of ``op``; also used for padding and empty ranges.
        id: Identity map. Tags comparing equal (``==``) to it are skipped.
    """

    def __init__(
        self,
        n_or_a: Union[int, Iterable[T]],
        op: Callable[[T, T], T],
        mapping: Callable[[F, T], T],
        composition: Callable[[F, F], F],
        e: T,
        id: F,
    ) -> None:
        self.op: Callable[[T, T], T] = op
        self.mapping: Callable[[F, T], T] = mapping
        self.composition: Callable[[F, F], F] = composition
        self.e: T = e
        self.id: F = id
        if isinstance(n_or_a, int):
            # A negative length would silently produce a nonsensical tree.
            assert (
                n_or_a >= 0
            ), f"ValueError: {self.__class__.__name__}({n_or_a}), n must be non-negative"
            self.n = n_or_a
            self.log = (self.n - 1).bit_length()
            self.size = 1 << self.log
            self.data = [e] * (self.size << 1)
        else:
            a = list(n_or_a)
            self.n = len(a)
            self.log = (self.n - 1).bit_length()
            self.size = 1 << self.log
            data = [e] * (self.size << 1)
            data[self.size : self.size + self.n] = a
            # Build internal nodes bottom-up from their children.
            for i in range(self.size - 1, 0, -1):
                data[i] = op(data[i << 1], data[i << 1 | 1])
            self.data = data
        # One pending-map slot per internal node (leaves never hold a tag).
        self.lazy = [id] * self.size

    def _update(self, k: int) -> None:
        """Recompute internal node ``k`` from its two children."""
        self.data[k] = self.op(self.data[k << 1], self.data[k << 1 | 1])

    def _all_apply(self, k: int, f: F) -> None:
        """Apply ``f`` to node ``k``'s aggregate and queue it for its children."""
        self.data[k] = self.mapping(f, self.data[k])
        if k >= self.size:
            # Leaf: there are no children to defer the map to.
            return
        # f is applied after whatever map was already pending.
        self.lazy[k] = self.composition(f, self.lazy[k])

    def _propagate(self, k: int) -> None:
        """Push internal node ``k``'s pending map down to its two children."""
        if self.lazy[k] == self.id:
            return
        self._all_apply(k << 1, self.lazy[k])
        self._all_apply(k << 1 | 1, self.lazy[k])
        self.lazy[k] = self.id

    def apply_point(self, k: int, f: F) -> None:
        """Apply map ``f`` to the single element at index ``k``.

        Negative ``k`` counts from the end, as in ``__getitem__``.
        """
        assert (
            -self.n <= k < self.n
        ), f"IndexError: {self.__class__.__name__}.apply_point({k}, {f}), n={self.n}"
        if k < 0:
            k += self.n
        k += self.size
        # Clear pending maps on the root-to-leaf path before touching the leaf.
        for i in range(self.log, 0, -1):
            self._propagate(k >> i)
        self.data[k] = self.mapping(f, self.data[k])
        for i in range(1, self.log + 1):
            self._update(k >> i)

    def _upper_propagate(self, l: int, r: int) -> None:
        """Push pending maps above the boundaries of leaf range ``[l, r)``.

        Only ancestors whose subtree is partly covered are pushed. The second
        condition skips ``(r - 1) >> i`` when it is the same node that was
        just pushed for ``l``.
        """
        for i in range(self.log, 0, -1):
            if l >> i << i != l:
                self._propagate(l >> i)
            if (r >> i << i != r) and (l >> i != (r - 1) >> i or l >> i << i == l):
                self._propagate((r - 1) >> i)

    def apply(self, l: int, r: int, f: F) -> None:
        """Apply map ``f`` to every element with index in ``[l, r)``."""
        assert (
            0 <= l <= r <= self.n
        ), f"IndexError: {self.__class__.__name__}.apply({l}, {r}, {f}), n={self.n}"
        if l == r:
            return
        if f == self.id:
            return
        l += self.size
        r += self.size
        self._upper_propagate(l, r)
        l2, r2 = l, r
        # Tag the O(log n) canonical nodes that exactly cover [l, r).
        while l < r:
            if l & 1:
                self._all_apply(l, f)
                l += 1
            if r & 1:
                self._all_apply(r ^ 1, f)
            l >>= 1
            r >>= 1
        # Recompute the partly covered ancestors of both boundaries; rr is
        # skipped when it is the same node just recomputed as ll.
        ll, rr = l2, r2 - 1
        for i in range(1, self.log + 1):
            ll >>= 1
            rr >>= 1
            if ll << i != l2:
                self._update(ll)
            if (ll << i == l2 or ll != rr) and (r2 >> i << i != r2):
                self._update(rr)

    def all_apply(self, f: F) -> None:
        """Apply map ``f`` to every element.

        Delegates to ``apply(0, n, f)``. This updates the root aggregate
        (``all_prod``) immediately and never pushes ``f`` into the padding
        leaves beyond ``n``.
        """
        self.apply(0, self.n, f)

    def prod(self, l: int, r: int) -> T:
        """Return the ``op``-product of elements ``[l, r)`` (``e`` if empty)."""
        assert (
            0 <= l <= r <= self.n
        ), f"IndexError: {self.__class__.__name__}.prod({l}, {r}), n={self.n}"
        if l == r:
            return self.e
        l += self.size
        r += self.size
        self._upper_propagate(l, r)
        # Left and right partial products are kept apart; op may not commute.
        lres = self.e
        rres = self.e
        while l < r:
            if l & 1:
                lres = self.op(lres, self.data[l])
                l += 1
            if r & 1:
                rres = self.op(self.data[r ^ 1], rres)
            l >>= 1
            r >>= 1
        return self.op(lres, rres)

    def all_prod(self) -> T:
        """Return the ``op``-product of all elements (the root aggregate)."""
        return self.data[1]

    def all_propagate(self) -> None:
        """Push every pending map down to the leaves.

        Internal nodes are ``1 .. size - 1``, and lower indexes are visited
        first, so parents are pushed before their children.
        """
        for i in range(1, self.size):
            self._propagate(i)

    def tolist(self) -> list[T]:
        """Return the current elements as a new list of length ``n``."""
        self.all_propagate()
        return self.data[self.size : self.size + self.n]

    def max_right(self, l: int, f: Callable[[T], bool]) -> int:
        """Return ``r`` with ``f(prod(l, r))`` True and ``r == n`` or ``f(prod(l, r + 1))`` False.

        Assumes ``f(e)`` is True and that ``f`` is monotone (once False on a
        growing range, it stays False); then ``r`` is the largest such index.
        """
        assert 0 <= l <= self.n
        if l == self.n:
            return self.n
        l += self.size
        for i in range(self.log, 0, -1):
            self._propagate(l >> i)
        s = self.e
        while True:
            # Climb while l is a left child: its parent starts at the same leaf.
            while l & 1 == 0:
                l >>= 1
            if not f(self.op(s, self.data[l])):
                # The answer lies inside node l; descend to the failing leaf.
                while l < self.size:
                    self._propagate(l)
                    l <<= 1
                    if f(self.op(s, self.data[l])):
                        s = self.op(s, self.data[l])
                        l |= 1
                return l - self.size
            s = self.op(s, self.data[l])
            l += 1
            # A power of two means the scan has passed the rightmost node.
            if l & -l == l:
                break
        return self.n

    def min_left(self, r: int, f: Callable[[T], bool]) -> int:
        """Return ``l`` with ``f(prod(l, r))`` True and ``l == 0`` or ``f(prod(l - 1, r))`` False.

        Assumes ``f(e)`` is True and that ``f`` is monotone; then ``l`` is the
        smallest such index.
        """
        assert 0 <= r <= self.n
        if r == 0:
            return 0
        r += self.size
        for i in range(self.log, 0, -1):
            self._propagate((r - 1) >> i)
        s = self.e
        while True:
            r -= 1
            # Climb while r is a right child: its parent ends at the same leaf.
            while r > 1 and r & 1:
                r >>= 1
            if not f(self.op(self.data[r], s)):
                # The answer lies inside node r; descend to the failing leaf.
                while r < self.size:
                    self._propagate(r)
                    r = r << 1 | 1
                    if f(self.op(self.data[r], s)):
                        s = self.op(self.data[r], s)
                        r ^= 1
                return r + 1 - self.size
            s = self.op(self.data[r], s)
            # A power of two means the scan has reached the leftmost node.
            if r & -r == r:
                break
        return 0

    def __getitem__(self, k: int) -> T:
        """Return element ``k``; negative ``k`` counts from the end."""
        assert (
            -self.n <= k < self.n
        ), f"IndexError: {self.__class__.__name__}[{k}], n={self.n}"
        if k < 0:
            k += self.n
        k += self.size
        for i in range(self.log, 0, -1):
            self._propagate(k >> i)
        return self.data[k]

    def __setitem__(self, k: int, v: T) -> None:
        """Overwrite element ``k`` with ``v``; negative ``k`` counts from the end."""
        assert (
            -self.n <= k < self.n
        ), f"IndexError: {self.__class__.__name__}[{k}] = {v}, n={self.n}"
        if k < 0:
            k += self.n
        k += self.size
        for i in range(self.log, 0, -1):
            self._propagate(k >> i)
        self.data[k] = v
        for i in range(1, self.log + 1):
            self._update(k >> i)

    def __str__(self) -> str:
        return str(self.tolist())

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self})"