hld_lazy_segment_tree¶
ソースコード¶
from titan_pylib.graph.hld.hld_lazy_segment_tree import HLDLazySegmentTree
展開済みコード¶
1# from titan_pylib.graph.hld.hld_lazy_segment_tree import HLDLazySegmentTree
2# from titan_pylib.data_structures.segment_tree.lazy_segment_tree import LazySegmentTree
3from typing import Union, Callable, TypeVar, Generic, Iterable
4
5T = TypeVar("T")
6F = TypeVar("F")
7
8
9class LazySegmentTree(Generic[T, F]):
10 """遅延セグ木です。"""
11
12 def __init__(
13 self,
14 n_or_a: Union[int, Iterable[T]],
15 op: Callable[[T, T], T],
16 mapping: Callable[[F, T], T],
17 composition: Callable[[F, F], F],
18 e: T,
19 id: F,
20 ) -> None:
21 self.op: Callable[[T, T], T] = op
22 self.mapping: Callable[[F, T], T] = mapping
23 self.composition: Callable[[F, F], F] = composition
24 self.e: T = e
25 self.id: F = id
26 if isinstance(n_or_a, int):
27 self.n = n_or_a
28 self.log = (self.n - 1).bit_length()
29 self.size = 1 << self.log
30 self.data = [e] * (self.size << 1)
31 else:
32 a = list(n_or_a)
33 self.n = len(a)
34 self.log = (self.n - 1).bit_length()
35 self.size = 1 << self.log
36 data = [e] * (self.size << 1)
37 data[self.size : self.size + self.n] = a
38 for i in range(self.size - 1, 0, -1):
39 data[i] = op(data[i << 1], data[i << 1 | 1])
40 self.data = data
41 self.lazy = [id] * self.size
42
43 def _update(self, k: int) -> None:
44 self.data[k] = self.op(self.data[k << 1], self.data[k << 1 | 1])
45
46 def _all_apply(self, k: int, f: F) -> None:
47 self.data[k] = self.mapping(f, self.data[k])
48 if k >= self.size:
49 return
50 self.lazy[k] = self.composition(f, self.lazy[k])
51
52 def _propagate(self, k: int) -> None:
53 if self.lazy[k] == self.id:
54 return
55 self._all_apply(k << 1, self.lazy[k])
56 self._all_apply(k << 1 | 1, self.lazy[k])
57 self.lazy[k] = self.id
58
59 def apply_point(self, k: int, f: F) -> None:
60 k += self.size
61 for i in range(self.log, 0, -1):
62 self._propagate(k >> i)
63 self.data[k] = self.mapping(f, self.data[k])
64 for i in range(1, self.log + 1):
65 self._update(k >> i)
66
67 def _upper_propagate(self, l: int, r: int) -> None:
68 for i in range(self.log, 0, -1):
69 if l >> i << i != l:
70 self._propagate(l >> i)
71 if (r >> i << i != r) and (l >> i != (r - 1) >> i or l >> i << i == l):
72 self._propagate((r - 1) >> i)
73
74 def apply(self, l: int, r: int, f: F) -> None:
75 assert (
76 0 <= l <= r <= self.n
77 ), f"IndexError: {self.__class__.__name__}.apply({l}, {r}, {f}), n={self.n}"
78 if l == r:
79 return
80 if f == self.id:
81 return
82 l += self.size
83 r += self.size
84 self._upper_propagate(l, r)
85 l2, r2 = l, r
86 while l < r:
87 if l & 1:
88 self._all_apply(l, f)
89 l += 1
90 if r & 1:
91 self._all_apply(r ^ 1, f)
92 l >>= 1
93 r >>= 1
94 ll, rr = l2, r2 - 1
95 for i in range(1, self.log + 1):
96 ll >>= 1
97 rr >>= 1
98 if ll << i != l2:
99 self._update(ll)
100 if (ll << i == l2 or ll != rr) and (r2 >> i << i != r2):
101 self._update(rr)
102
103 def all_apply(self, f: F) -> None:
104 self.lazy[1] = self.composition(f, self.lazy[1])
105
106 def prod(self, l: int, r: int) -> T:
107 assert (
108 0 <= l <= r <= self.n
109 ), f"IndexError: {self.__class__.__name__}.prod({l}, {r}), n={self.n}"
110 if l == r:
111 return self.e
112 l += self.size
113 r += self.size
114 self._upper_propagate(l, r)
115 lres = self.e
116 rres = self.e
117 while l < r:
118 if l & 1:
119 lres = self.op(lres, self.data[l])
120 l += 1
121 if r & 1:
122 rres = self.op(self.data[r ^ 1], rres)
123 l >>= 1
124 r >>= 1
125 return self.op(lres, rres)
126
127 def all_prod(self) -> T:
128 return self.data[1]
129
130 def all_propagate(self) -> None:
131 for i in range(self.size):
132 self._propagate(i)
133
134 def tolist(self) -> list[T]:
135 self.all_propagate()
136 return self.data[self.size : self.size + self.n]
137
138 def max_right(self, l, f) -> int:
139 assert 0 <= l <= self.n
140 # assert f(self.e)
141 if l == self.size:
142 return self.n
143 l += self.size
144 for i in range(self.log, 0, -1):
145 self._propagate(l >> i)
146 s = self.e
147 while True:
148 while l & 1 == 0:
149 l >>= 1
150 if not f(self.op(s, self.data[l])):
151 while l < self.size:
152 self._propagate(l)
153 l <<= 1
154 if f(self.op(s, self.data[l])):
155 s = self.op(s, self.data[l])
156 l |= 1
157 return l - self.size
158 s = self.op(s, self.data[l])
159 l += 1
160 if l & -l == l:
161 break
162 return self.n
163
164 def min_left(self, r: int, f) -> int:
165 assert 0 <= r <= self.n
166 # assert f(self.e)
167 if r == 0:
168 return 0
169 r += self.size
170 for i in range(self.log, 0, -1):
171 self._propagate((r - 1) >> i)
172 s = self.e
173 while True:
174 r -= 1
175 while r > 1 and r & 1:
176 r >>= 1
177 if not f(self.op(self.data[r], s)):
178 while r < self.size:
179 self._propagate(r)
180 r = r << 1 | 1
181 if f(self.op(self.data[r], s)):
182 s = self.op(self.data[r], s)
183 r ^= 1
184 return r + 1 - self.size
185 s = self.op(self.data[r], s)
186 if r & -r == r:
187 break
188 return 0
189
190 def __getitem__(self, k: int) -> T:
191 assert (
192 -self.n <= k < self.n
193 ), f"IndexError: {self.__class__.__name__}[{k}], n={self.n}"
194 if k < 0:
195 k += self.n
196 k += self.size
197 for i in range(self.log, 0, -1):
198 self._propagate(k >> i)
199 return self.data[k]
200
201 def __setitem__(self, k: int, v: T):
202 assert (
203 -self.n <= k < self.n
204 ), f"IndexError: {self.__class__.__name__}[{k}] = {v}, n={self.n}"
205 if k < 0:
206 k += self.n
207 k += self.size
208 for i in range(self.log, 0, -1):
209 self._propagate(k >> i)
210 self.data[k] = v
211 for i in range(1, self.log + 1):
212 self._update(k >> i)
213
214 def __str__(self) -> str:
215 return str(self.tolist())
216
217 def __repr__(self):
218 return f"{self.__class__.__name__}({self})"
219# from titan_pylib.graph.hld.hld import HLD
220from typing import Any, Iterator
221
222
223class HLD:
224
225 def __init__(self, G: list[list[int]], root: int):
226 """``root`` を根とする木 ``G`` を HLD します。
227 :math:`O(n)` です。
228
229 Args:
230 G (list[list[int]]): 木を表す隣接リストです。
231 root (int): 根です。
232 """
233 n = len(G)
234 self.n: int = n
235 self.G: list[list[int]] = G
236 self.size: list[int] = [1] * n
237 self.par: list[int] = [-1] * n
238 self.dep: list[int] = [-1] * n
239 self.nodein: list[int] = [0] * n
240 self.nodeout: list[int] = [0] * n
241 self.head: list[int] = [0] * n
242 self.hld: list[int] = [0] * n
243 self._dfs(root)
244
245 def _dfs(self, root: int) -> None:
246 dep, par, size, G = self.dep, self.par, self.size, self.G
247 dep[root] = 0
248 stack = [~root, root]
249 while stack:
250 v = stack.pop()
251 if v >= 0:
252 dep_nxt = dep[v] + 1
253 for x in G[v]:
254 if dep[x] != -1:
255 continue
256 dep[x] = dep_nxt
257 stack.append(~x)
258 stack.append(x)
259 else:
260 v = ~v
261 G_v, dep_v = G[v], dep[v]
262 for i, x in enumerate(G_v):
263 if dep[x] < dep_v:
264 par[v] = x
265 continue
266 size[v] += size[x]
267 if size[x] > size[G_v[0]]:
268 G_v[0], G_v[i] = G_v[i], G_v[0]
269
270 head, nodein, nodeout, hld = self.head, self.nodein, self.nodeout, self.hld
271 curtime = 0
272 stack = [~root, root]
273 while stack:
274 v = stack.pop()
275 if v >= 0:
276 if par[v] == -1:
277 head[v] = v
278 nodein[v] = curtime
279 hld[curtime] = v
280 curtime += 1
281 if not G[v]:
282 continue
283 G_v0 = G[v][0]
284 for x in reversed(G[v]):
285 if x == par[v]:
286 continue
287 head[x] = head[v] if x == G_v0 else x
288 stack.append(~x)
289 stack.append(x)
290 else:
291 nodeout[~v] = curtime
292
293 def build_list(self, a: list[Any]) -> list[Any]:
294 """``hld配列`` を基にインデックスを振りなおします。非破壊的です。
295 :math:`O(n)` です。
296
297 Args:
298 a (list[Any]): 元の配列です。
299
300 Returns:
301 list[Any]: 振りなおし後の配列です。
302 """
303 return [a[e] for e in self.hld]
304
305 def for_each_vertex_path(self, u: int, v: int) -> Iterator[tuple[int, int]]:
306 """``u-v`` パスに対応する区間のインデックスを返します。
307 :math:`O(\\log{n})` です。
308 """
309 head, nodein, dep, par = self.head, self.nodein, self.dep, self.par
310 while head[u] != head[v]:
311 if dep[head[u]] < dep[head[v]]:
312 u, v = v, u
313 yield nodein[head[u]], nodein[u] + 1
314 u = par[head[u]]
315 if dep[u] < dep[v]:
316 u, v = v, u
317 yield nodein[v], nodein[u] + 1
318
319 def for_each_vertex_subtree(self, v: int) -> Iterator[tuple[int, int]]:
320 """頂点 ``v`` の部分木に対応する区間のインデックスを返します。
321 :math:`O(1)` です。
322 """
323 yield self.nodein[v], self.nodeout[v]
324
325 def path_kth_elm(self, s: int, t: int, k: int) -> int:
326 """``s`` から ``t`` に向かって ``k`` 個進んだ頂点のインデックスを返します。
327 存在しないときは ``-1`` を返します。
328 :math:`O(\\log{n})` です。
329 """
330 head, dep, par = self.head, self.dep, self.par
331 lca = self.lca(s, t)
332 d = dep[s] + dep[t] - 2 * dep[lca]
333 if d < k:
334 return -1
335 if dep[s] - dep[lca] < k:
336 s = t
337 k = d - k
338 hs = head[s]
339 while dep[s] - dep[hs] < k:
340 k -= dep[s] - dep[hs] + 1
341 s = par[hs]
342 hs = head[s]
343 return self.hld[self.nodein[s] - k]
344
345 def lca(self, u: int, v: int) -> int:
346 """``u``, ``v`` の LCA を返します。
347 :math:`O(\\log{n})` です。
348 """
349 nodein, head, par = self.nodein, self.head, self.par
350 while True:
351 if nodein[u] > nodein[v]:
352 u, v = v, u
353 if head[u] == head[v]:
354 return u
355 v = par[head[v]]
356
357 def dist(self, u: int, v: int) -> int:
358 return self.dep[u] + self.dep[v] - 2 * self.dep[self.lca(u, v)]
359
360 def is_on_path(self, u: int, v: int, a: int) -> bool:
361 """Return True if (a is on path(u - v)) else False. / O(logN)"""
362 return self.dist(u, a) + self.dist(a, v) == self.dist(u, v)
363from typing import Union, Iterable, Callable, TypeVar, Generic
364
365T = TypeVar("T")
366F = TypeVar("F")
367
368
369class HLDLazySegmentTree(Generic[T, F]):
370 """遅延セグ木搭載HLDです。
371
372 Note:
373 **非可換に対応してます。**
374 """
375
376 def __init__(
377 self,
378 hld: HLD,
379 n_or_a: Union[int, Iterable[T]],
380 op: Callable[[T, T], T],
381 mapping: Callable[[F, T], T],
382 composition: Callable[[F, F], F],
383 e: T,
384 id: F,
385 ) -> None:
386 self.hld: HLD = hld
387 a = (
388 [e] * n_or_a
389 if isinstance(n_or_a, int)
390 else self.hld.build_list(list(n_or_a))
391 )
392 self.seg: LazySegmentTree[T, F] = LazySegmentTree(
393 a, op, mapping, composition, e, id
394 )
395 self.rseg: LazySegmentTree[T, F] = LazySegmentTree(
396 a[::-1], op, mapping, composition, e, id
397 )
398 self.op: Callable[[T, T], T] = op
399 self.e: T = e
400
401 def path_prod(self, u: int, v: int) -> T:
402 """頂点 ``u`` から頂点 ``v`` へのパスの集約値を返します。
403 :math:`O(\\log^2{n})` です。
404
405 Args:
406 u (int): パスの **始点** です。
407 v (int): パスの **終点** です。
408
409 Returns:
410 T: 求める集約値です。
411 """
412 head, nodein, dep, par, n = (
413 self.hld.head,
414 self.hld.nodein,
415 self.hld.dep,
416 self.hld.par,
417 self.hld.n,
418 )
419 lres, rres = self.e, self.e
420 seg, rseg = self.seg, self.rseg
421 while head[u] != head[v]:
422 if dep[head[u]] > dep[head[v]]:
423 lres = self.op(lres, rseg.prod(n - nodein[u] - 1, n - nodein[head[u]]))
424 u = par[head[u]]
425 else:
426 rres = self.op(seg.prod(nodein[head[v]], nodein[v] + 1), rres)
427 v = par[head[v]]
428 if dep[u] > dep[v]:
429 lres = self.op(lres, rseg.prod(n - nodein[u] - 1, n - nodein[v]))
430 else:
431 lres = self.op(lres, seg.prod(nodein[u], nodein[v] + 1))
432 return self.op(lres, rres)
433
434 def path_apply(self, u: int, v: int, f: F) -> None:
435 """頂点 ``u`` から頂点 ``v`` へのパスに作用させます。
436 :math:`O(\\log^2{n})` です。
437
438 Args:
439 u (int): パスの **始点** です。
440 v (int): パスの **終点** です。
441 f (F): 作用素です。
442 """
443 head, nodein, dep, par = (
444 self.hld.head,
445 self.hld.nodein,
446 self.hld.dep,
447 self.hld.par,
448 )
449 while head[u] != head[v]:
450 if dep[head[u]] < dep[head[v]]:
451 u, v = v, u
452 self.seg.apply(nodein[head[u]], nodein[u] + 1, f)
453 self.rseg.apply(
454 self.hld.n - (nodein[u] + 1 - 1) - 1,
455 self.hld.n - nodein[head[u]] - 1 + 1,
456 f,
457 )
458 u = par[head[u]]
459 if dep[u] < dep[v]:
460 u, v = v, u
461 self.seg.apply(nodein[v], nodein[u] + 1, f)
462 self.rseg.apply(
463 self.hld.n - (nodein[u] + 1 - 1) - 1, self.hld.n - nodein[v] - 1 + 1, f
464 )
465
466 def get(self, k: int) -> T:
467 """頂点の値を返します。
468 :math:`O(\\log{n})` です。
469
470 Args:
471 k (int): 頂点のインデックスです。
472
473 Returns:
474 T: 頂点の値です。
475 """
476 return self.seg[self.hld.nodein[k]]
477
478 def set(self, k: int, v: T) -> None:
479 """頂点の値を更新します。
480 :math:`O(\\log{n})` です。
481
482 Args:
483 k (int): 頂点のインデックスです。
484 v (T): 更新する値です。
485 """
486 self.seg[self.hld.nodein[k]] = v
487 self.rseg[self.hld.n - self.hld.nodein[k] - 1] = v
488
489 __getitem__ = get
490 __setitem__ = set
491
492 def subtree_prod(self, v: int) -> T:
493 """部分木の集約値を返します。
494 :math:`O(\\log{n})` です。
495
496 Args:
497 v (int): 根とする頂点です。
498
499 Returns:
500 T: 求める集約値です。
501 """
502 return self.seg.prod(self.hld.nodein[v], self.hld.nodeout[v])
503
504 def subtree_apply(self, v: int, f: F) -> None:
505 """部分木に作用させます。
506 :math:`O(\\log{n})` です。
507
508 Args:
509 v (int): 根とする頂点です。
510 f (F): 作用素です。
511 """
512 self.seg.apply(self.hld.nodein[v], self.hld.nodeout[v], f)
513 self.rseg.apply(
514 self.hld.n - self.hld.nodeout[v] - 1 - 1,
515 self.hld.n - self.hld.nodein[v] - 1 + 1,
516 f,
517 )
仕様¶
- class HLDLazySegmentTree(hld: HLD, n_or_a: int | Iterable[T], op: Callable[[T, T], T], mapping: Callable[[F, T], T], composition: Callable[[F, F], F], e: T, id: F)[source]¶
Bases:
Generic
[T
,F
]遅延セグ木搭載HLDです。
Note
非可換に対応してます。
- __getitem__(k: int) T ¶
頂点の値を返します。 \(O(\log{n})\) です。
- Parameters:
k (int) – 頂点のインデックスです。
- Returns:
頂点の値です。
- Return type:
T
- __setitem__(k: int, v: T) None ¶
頂点の値を更新します。 \(O(\log{n})\) です。
- Parameters:
k (int) – 頂点のインデックスです。
v (T) – 更新する値です。
- get(k: int) T [source]¶
頂点の値を返します。 \(O(\log{n})\) です。
- Parameters:
k (int) – 頂点のインデックスです。
- Returns:
頂点の値です。
- Return type:
T
- path_apply(u: int, v: int, f: F) None [source]¶
頂点
u
から頂点v
へのパスに作用させます。 \(O(\log^2{n})\) です。- Parameters:
u (int) – パスの 始点 です。
v (int) – パスの 終点 です。
f (F) – 作用素です。
- path_prod(u: int, v: int) T [source]¶
頂点
u
から頂点v
へのパスの集約値を返します。 \(O(\log^2{n})\) です。- Parameters:
u (int) – パスの 始点 です。
v (int) – パスの 終点 です。
- Returns:
求める集約値です。
- Return type:
T
- set(k: int, v: T) None [source]¶
頂点の値を更新します。 \(O(\log{n})\) です。
- Parameters:
k (int) – 頂点のインデックスです。
v (T) – 更新する値です。