hld_segment_tree¶
ソースコード¶
from titan_pylib.graph.hld.hld_segment_tree import HLDSegmentTree
展開済みコード¶
1# from titan_pylib.graph.hld.hld_segment_tree import HLDSegmentTree
2# from titan_pylib.data_structures.segment_tree.segment_tree import SegmentTree
3# from titan_pylib.data_structures.segment_tree.segment_tree_interface import (
4# SegmentTreeInterface,
5# )
6from abc import ABC, abstractmethod
7from typing import TypeVar, Generic, Union, Iterable, Callable
8
9T = TypeVar("T")
10
11
12class SegmentTreeInterface(ABC, Generic[T]):
13
14 @abstractmethod
15 def __init__(self, n_or_a: Union[int, Iterable[T]], op: Callable[[T, T], T], e: T):
16 raise NotImplementedError
17
18 @abstractmethod
19 def set(self, k: int, v: T) -> None:
20 raise NotImplementedError
21
22 @abstractmethod
23 def get(self, k: int) -> T:
24 raise NotImplementedError
25
26 @abstractmethod
27 def prod(self, l: int, r: int) -> T:
28 raise NotImplementedError
29
30 @abstractmethod
31 def all_prod(self) -> T:
32 raise NotImplementedError
33
34 @abstractmethod
35 def max_right(self, l: int, f: Callable[[T], bool]) -> int:
36 raise NotImplementedError
37
38 @abstractmethod
39 def min_left(self, r: int, f: Callable[[T], bool]) -> int:
40 raise NotImplementedError
41
42 @abstractmethod
43 def tolist(self) -> list[T]:
44 raise NotImplementedError
45
46 @abstractmethod
47 def __getitem__(self, k: int) -> T:
48 raise NotImplementedError
49
50 @abstractmethod
51 def __setitem__(self, k: int, v: T) -> None:
52 raise NotImplementedError
53
54 @abstractmethod
55 def __str__(self):
56 raise NotImplementedError
57
58 @abstractmethod
59 def __repr__(self):
60 raise NotImplementedError
61from typing import Generic, Iterable, TypeVar, Callable, Union
62
63T = TypeVar("T")
64
65
66class SegmentTree(SegmentTreeInterface, Generic[T]):
67 """セグ木です。非再帰です。"""
68
69 def __init__(
70 self, n_or_a: Union[int, Iterable[T]], op: Callable[[T, T], T], e: T
71 ) -> None:
72 """``SegmentTree`` を構築します。
73 :math:`O(n)` です。
74
75 Args:
76 n_or_a (Union[int, Iterable[T]]): ``n: int`` のとき、 ``e`` を初期値として長さ ``n`` の ``SegmentTree`` を構築します。
77 ``a: Iterable[T]`` のとき、 ``a`` から ``SegmentTree`` を構築します。
78 op (Callable[[T, T], T]): 2項演算の関数です。
79 e (T): 単位元です。
80 """
81 self._op = op
82 self._e = e
83 if isinstance(n_or_a, int):
84 self._n = n_or_a
85 self._log = (self._n - 1).bit_length()
86 self._size = 1 << self._log
87 self._data = [e] * (self._size << 1)
88 else:
89 n_or_a = list(n_or_a)
90 self._n = len(n_or_a)
91 self._log = (self._n - 1).bit_length()
92 self._size = 1 << self._log
93 _data = [e] * (self._size << 1)
94 _data[self._size : self._size + self._n] = n_or_a
95 for i in range(self._size - 1, 0, -1):
96 _data[i] = op(_data[i << 1], _data[i << 1 | 1])
97 self._data = _data
98
99 def set(self, k: int, v: T) -> None:
100 """一点更新です。
101 :math:`O(\\log{n})` です。
102
103 Args:
104 k (int): 更新するインデックスです。
105 v (T): 更新する値です。
106
107 制約:
108 :math:`-n \\leq n \\leq k < n`
109 """
110 assert (
111 -self._n <= k < self._n
112 ), f"IndexError: {self.__class__.__name__}.set({k}, {v}), n={self._n}"
113 if k < 0:
114 k += self._n
115 k += self._size
116 self._data[k] = v
117 for _ in range(self._log):
118 k >>= 1
119 self._data[k] = self._op(self._data[k << 1], self._data[k << 1 | 1])
120
121 def get(self, k: int) -> T:
122 """一点取得です。
123 :math:`O(1)` です。
124
125 Args:
126 k (int): インデックスです。
127
128 制約:
129 :math:`-n \\leq n \\leq k < n`
130 """
131 assert (
132 -self._n <= k < self._n
133 ), f"IndexError: {self.__class__.__name__}.get({k}), n={self._n}"
134 if k < 0:
135 k += self._n
136 return self._data[k + self._size]
137
138 def prod(self, l: int, r: int) -> T:
139 """区間 ``[l, r)`` の総積を返します。
140 :math:`O(\\log{n})` です。
141
142 Args:
143 l (int): インデックスです。
144 r (int): インデックスです。
145
146 制約:
147 :math:`0 \\leq l \\leq r \\leq n`
148 """
149 assert (
150 0 <= l <= r <= self._n
151 ), f"IndexError: {self.__class__.__name__}.prod({l}, {r})"
152 l += self._size
153 r += self._size
154 lres = self._e
155 rres = self._e
156 while l < r:
157 if l & 1:
158 lres = self._op(lres, self._data[l])
159 l += 1
160 if r & 1:
161 rres = self._op(self._data[r ^ 1], rres)
162 l >>= 1
163 r >>= 1
164 return self._op(lres, rres)
165
166 def all_prod(self) -> T:
167 """区間 ``[0, n)`` の総積を返します。
168 :math:`O(1)` です。
169 """
170 return self._data[1]
171
172 def max_right(self, l: int, f: Callable[[T], bool]) -> int:
173 """Find the largest index R s.t. f([l, R)) == True. / O(\\log{n})"""
174 assert (
175 0 <= l <= self._n
176 ), f"IndexError: {self.__class__.__name__}.max_right({l}, f) index out of range"
177 # assert f(self._e), \
178 # f'{self.__class__.__name__}.max_right({l}, f), f({self._e}) must be true.'
179 if l == self._n:
180 return self._n
181 l += self._size
182 s = self._e
183 while True:
184 while l & 1 == 0:
185 l >>= 1
186 if not f(self._op(s, self._data[l])):
187 while l < self._size:
188 l <<= 1
189 if f(self._op(s, self._data[l])):
190 s = self._op(s, self._data[l])
191 l |= 1
192 return l - self._size
193 s = self._op(s, self._data[l])
194 l += 1
195 if l & -l == l:
196 break
197 return self._n
198
199 def min_left(self, r: int, f: Callable[[T], bool]) -> int:
200 """Find the smallest index L s.t. f([L, r)) == True. / O(\\log{n})"""
201 assert (
202 0 <= r <= self._n
203 ), f"IndexError: {self.__class__.__name__}.min_left({r}, f) index out of range"
204 # assert f(self._e), \
205 # f'{self.__class__.__name__}.min_left({r}, f), f({self._e}) must be true.'
206 if r == 0:
207 return 0
208 r += self._size
209 s = self._e
210 while True:
211 r -= 1
212 while r > 1 and r & 1:
213 r >>= 1
214 if not f(self._op(self._data[r], s)):
215 while r < self._size:
216 r = r << 1 | 1
217 if f(self._op(self._data[r], s)):
218 s = self._op(self._data[r], s)
219 r ^= 1
220 return r + 1 - self._size
221 s = self._op(self._data[r], s)
222 if r & -r == r:
223 break
224 return 0
225
226 def tolist(self) -> list[T]:
227 """リストにして返します。
228 :math:`O(n)` です。
229 """
230 return [self.get(i) for i in range(self._n)]
231
232 def show(self) -> None:
233 """デバッグ用のメソッドです。"""
234 print(
235 f"<{self.__class__.__name__}> [\n"
236 + "\n".join(
237 [
238 " "
239 + " ".join(
240 map(str, [self._data[(1 << i) + j] for j in range(1 << i)])
241 )
242 for i in range(self._log + 1)
243 ]
244 )
245 + "\n]"
246 )
247
248 def __getitem__(self, k: int) -> T:
249 assert (
250 -self._n <= k < self._n
251 ), f"IndexError: {self.__class__.__name__}.__getitem__({k}), n={self._n}"
252 return self.get(k)
253
254 def __setitem__(self, k: int, v: T):
255 assert (
256 -self._n <= k < self._n
257 ), f"IndexError: {self.__class__.__name__}.__setitem__{k}, {v}), n={self._n}"
258 self.set(k, v)
259
260 def __len__(self) -> int:
261 return self._n
262
263 def __str__(self) -> str:
264 return str(self.tolist())
265
266 def __repr__(self) -> str:
267 return f"{self.__class__.__name__}({self})"
268# from titan_pylib.graph.hld.hld import HLD
269from typing import Any, Iterator
270
271
272class HLD:
273
274 def __init__(self, G: list[list[int]], root: int):
275 """``root`` を根とする木 ``G`` を HLD します。
276 :math:`O(n)` です。
277
278 Args:
279 G (list[list[int]]): 木を表す隣接リストです。
280 root (int): 根です。
281 """
282 n = len(G)
283 self.n: int = n
284 self.G: list[list[int]] = G
285 self.size: list[int] = [1] * n
286 self.par: list[int] = [-1] * n
287 self.dep: list[int] = [-1] * n
288 self.nodein: list[int] = [0] * n
289 self.nodeout: list[int] = [0] * n
290 self.head: list[int] = [0] * n
291 self.hld: list[int] = [0] * n
292 self._dfs(root)
293
294 def _dfs(self, root: int) -> None:
295 dep, par, size, G = self.dep, self.par, self.size, self.G
296 dep[root] = 0
297 stack = [~root, root]
298 while stack:
299 v = stack.pop()
300 if v >= 0:
301 dep_nxt = dep[v] + 1
302 for x in G[v]:
303 if dep[x] != -1:
304 continue
305 dep[x] = dep_nxt
306 stack.append(~x)
307 stack.append(x)
308 else:
309 v = ~v
310 G_v, dep_v = G[v], dep[v]
311 for i, x in enumerate(G_v):
312 if dep[x] < dep_v:
313 par[v] = x
314 continue
315 size[v] += size[x]
316 if size[x] > size[G_v[0]]:
317 G_v[0], G_v[i] = G_v[i], G_v[0]
318
319 head, nodein, nodeout, hld = self.head, self.nodein, self.nodeout, self.hld
320 curtime = 0
321 stack = [~root, root]
322 while stack:
323 v = stack.pop()
324 if v >= 0:
325 if par[v] == -1:
326 head[v] = v
327 nodein[v] = curtime
328 hld[curtime] = v
329 curtime += 1
330 if not G[v]:
331 continue
332 G_v0 = G[v][0]
333 for x in reversed(G[v]):
334 if x == par[v]:
335 continue
336 head[x] = head[v] if x == G_v0 else x
337 stack.append(~x)
338 stack.append(x)
339 else:
340 nodeout[~v] = curtime
341
342 def build_list(self, a: list[Any]) -> list[Any]:
343 """``hld配列`` を基にインデックスを振りなおします。非破壊的です。
344 :math:`O(n)` です。
345
346 Args:
347 a (list[Any]): 元の配列です。
348
349 Returns:
350 list[Any]: 振りなおし後の配列です。
351 """
352 return [a[e] for e in self.hld]
353
354 def for_each_vertex_path(self, u: int, v: int) -> Iterator[tuple[int, int]]:
355 """``u-v`` パスに対応する区間のインデックスを返します。
356 :math:`O(\\log{n})` です。
357 """
358 head, nodein, dep, par = self.head, self.nodein, self.dep, self.par
359 while head[u] != head[v]:
360 if dep[head[u]] < dep[head[v]]:
361 u, v = v, u
362 yield nodein[head[u]], nodein[u] + 1
363 u = par[head[u]]
364 if dep[u] < dep[v]:
365 u, v = v, u
366 yield nodein[v], nodein[u] + 1
367
368 def for_each_vertex_subtree(self, v: int) -> Iterator[tuple[int, int]]:
369 """頂点 ``v`` の部分木に対応する区間のインデックスを返します。
370 :math:`O(1)` です。
371 """
372 yield self.nodein[v], self.nodeout[v]
373
374 def path_kth_elm(self, s: int, t: int, k: int) -> int:
375 """``s`` から ``t`` に向かって ``k`` 個進んだ頂点のインデックスを返します。
376 存在しないときは ``-1`` を返します。
377 :math:`O(\\log{n})` です。
378 """
379 head, dep, par = self.head, self.dep, self.par
380 lca = self.lca(s, t)
381 d = dep[s] + dep[t] - 2 * dep[lca]
382 if d < k:
383 return -1
384 if dep[s] - dep[lca] < k:
385 s = t
386 k = d - k
387 hs = head[s]
388 while dep[s] - dep[hs] < k:
389 k -= dep[s] - dep[hs] + 1
390 s = par[hs]
391 hs = head[s]
392 return self.hld[self.nodein[s] - k]
393
394 def lca(self, u: int, v: int) -> int:
395 """``u``, ``v`` の LCA を返します。
396 :math:`O(\\log{n})` です。
397 """
398 nodein, head, par = self.nodein, self.head, self.par
399 while True:
400 if nodein[u] > nodein[v]:
401 u, v = v, u
402 if head[u] == head[v]:
403 return u
404 v = par[head[v]]
405
406 def dist(self, u: int, v: int) -> int:
407 return self.dep[u] + self.dep[v] - 2 * self.dep[self.lca(u, v)]
408
409 def is_on_path(self, u: int, v: int, a: int) -> bool:
410 """Return True if (a is on path(u - v)) else False. / O(logN)"""
411 return self.dist(u, a) + self.dist(a, v) == self.dist(u, v)
412from typing import Union, Iterable, Callable, TypeVar, Generic
413
414T = TypeVar("T")
415
416
417class HLDSegmentTree(Generic[T]):
418 """セグ木搭載HLDです。
419
420 Note:
421 **非可換に対応していません。**
422 """
423
424 def __init__(
425 self, hld: HLD, n_or_a: Union[int, Iterable[T]], op: Callable[[T, T], T], e: T
426 ) -> None:
427 self.hld: HLD = hld
428 n_or_a = (
429 n_or_a if isinstance(n_or_a, int) else self.hld.build_list(list(n_or_a))
430 )
431 self.seg: SegmentTree[T] = SegmentTree(n_or_a, op, e)
432 self.op: Callable[[T, T], T] = op
433 self.e: T = e
434
435 def path_prod(self, u: int, v: int) -> T:
436 """頂点 ``u`` から頂点 ``v`` へのパスの集約値を返します。
437 :math:`O(\\log^2{n})` です。
438
439 Note:
440 **非可換に対応していません。**
441
442 Args:
443 u (int): パスの端点です。
444 v (int): パスの端点です。
445
446 Returns:
447 T: 求める集約値です。
448 """
449 head, nodein, dep, par = (
450 self.hld.head,
451 self.hld.nodein,
452 self.hld.dep,
453 self.hld.par,
454 )
455 res = self.e
456 while head[u] != head[v]:
457 if dep[head[u]] < dep[head[v]]:
458 u, v = v, u
459 res = self.op(res, self.seg.prod(nodein[head[u]], nodein[u] + 1))
460 u = par[head[u]]
461 if dep[u] < dep[v]:
462 u, v = v, u
463 return self.op(res, self.seg.prod(nodein[v], nodein[u] + 1))
464
465 def get(self, k: int) -> T:
466 """頂点の値を返します。
467 :math:`O(\\log{n})` です。
468
469 Args:
470 k (int): 頂点のインデックスです。
471
472 Returns:
473 T: 頂点の値です。
474 """
475 return self.seg[self.hld.nodein[k]]
476
477 def set(self, k: int, v: T) -> None:
478 """頂点の値を更新します。
479 :math:`O(\\log{n})` です。
480
481 Args:
482 k (int): 頂点のインデックスです。
483 v (T): 更新する値です。
484 """
485 self.seg[self.hld.nodein[k]] = v
486
487 __getitem__ = get
488 __setitem__ = set
489
490 def subtree_prod(self, v: int) -> T:
491 """部分木の集約値を返します。
492 :math:`O(\\log{n})` です。
493
494 Args:
495 v (int): 根とする頂点です。
496
497 Returns:
498 T: 求める集約値です。
499 """
500 return self.seg.prod(self.hld.nodein[v], self.hld.nodeout[v])
仕様¶
- class HLDSegmentTree(hld: HLD, n_or_a: int | Iterable[T], op: Callable[[T, T], T], e: T)[source]¶
Bases:
Generic
[T
]セグ木搭載HLDです。
Note
非可換に対応していません。
- __getitem__(k: int) T ¶
頂点の値を返します。 \(O(\log{n})\) です。
- Parameters:
k (int) – 頂点のインデックスです。
- Returns:
頂点の値です。
- Return type:
T
- __setitem__(k: int, v: T) None ¶
頂点の値を更新します。 \(O(\log{n})\) です。
- Parameters:
k (int) – 頂点のインデックスです。
v (T) – 更新する値です。
- get(k: int) T [source]¶
頂点の値を返します。 \(O(\log{n})\) です。
- Parameters:
k (int) – 頂点のインデックスです。
- Returns:
頂点の値です。
- Return type:
T
- path_prod(u: int, v: int) T [source]¶
頂点
u
から頂点v
へのパスの集約値を返します。 \(O(\log^2{n})\) です。Note
非可換に対応していません。
- Parameters:
u (int) – パスの端点です。
v (int) – パスの端点です。
- Returns:
求める集約値です。
- Return type:
T