1# from titan_pylib.graph.euler_tour import EulerTour
2# from titan_pylib.data_structures.fenwick_tree.fenwick_tree import FenwickTree
3from typing import Union, Iterable, Optional
4
5
class FenwickTree:
    """Fenwick tree (binary indexed tree) over a sequence of ``n`` numbers.

    Callers use 0-indexed positions.  Internally ``_tree`` is 1-indexed and
    ``_tree[i]`` stores the sum of the 0-indexed positions ``[i - (i & -i), i)``.
    """

    def __init__(self, n_or_a: Union[Iterable[int], int]):
        """Build the tree in :math:`O(n)`.

        Args:
            n_or_a (Union[Iterable[int], int]): if an ``int`` ``n``, build a tree
                of length ``n`` filled with ``0``; if an ``Iterable``, build a
                tree holding those values (the caller's list is not modified).
        """
        if isinstance(n_or_a, int):
            self._size = n_or_a
            self._tree = [0] * (self._size + 1)
        else:
            a = n_or_a if isinstance(n_or_a, list) else list(n_or_a)
            _size = len(a)
            _tree = [0] + a  # a new list, so ``a`` itself is never mutated
            # Nodes are finished in increasing order, so each node can push its
            # complete partial sum into its parent exactly once: O(n) total.
            for i in range(1, _size):
                parent = i + (i & -i)
                if parent <= _size:
                    _tree[parent] += _tree[i]
            self._size = _size
            self._tree = _tree
        # Power of two >= n: the first step width of the binary-descent searches.
        self._s = 1 << (self._size - 1).bit_length()

    def pref(self, r: int) -> int:
        """Return the sum of ``[0, r)`` in :math:`O(\\log{n})`."""
        assert (
            0 <= r <= self._size
        ), f"IndexError: {self.__class__.__name__}.pref({r}), n={self._size}"
        ret, _tree = 0, self._tree
        while r > 0:
            ret += _tree[r]
            r &= r - 1  # drop the lowest set bit
        return ret

    def suff(self, l: int) -> int:
        """Return the sum of ``[l, n)`` in :math:`O(\\log{n})`.

        ``l == n`` is accepted and yields ``0`` (the empty suffix).
        """
        assert (
            0 <= l <= self._size
        ), f"IndexError: {self.__class__.__name__}.suff({l}), n={self._size}"
        return self.pref(self._size) - self.pref(l)

    def sum(self, l: int, r: int) -> int:
        """Return the sum of ``[l, r)`` in :math:`O(\\log{n})`."""
        assert (
            0 <= l <= r <= self._size
        ), f"IndexError: {self.__class__.__name__}.sum({l}, {r}), n={self._size}"
        _tree = self._tree
        res = 0
        # Walk the two prefix chains only until they meet; the common tail of
        # pref(r) and pref(l) would cancel anyway, so it is never visited.
        while r > l:
            res += _tree[r]
            r &= r - 1
        while l > r:
            res -= _tree[l]
            l &= l - 1
        return res

    prod = sum

    def __getitem__(self, k: int) -> int:
        """Return the value at position ``k`` (negative ``k`` counts from the end)."""
        assert (
            -self._size <= k < self._size
        ), f"IndexError: {self.__class__.__name__}[{k}], n={self._size}"
        if k < 0:
            k += self._size
        return self.sum(k, k + 1)

    def add(self, k: int, x: int) -> None:
        """Add ``x`` to the value at position ``k`` in :math:`O(\\log{n})`."""
        assert (
            0 <= k < self._size
        ), f"IndexError: {self.__class__.__name__}.add({k}, {x}), n={self._size}"
        k += 1
        _tree = self._tree
        while k <= self._size:
            _tree[k] += x
            k += k & -k  # move to the next node whose range covers k

    def __setitem__(self, k: int, x: int):
        """Set the value at position ``k`` to ``x`` in :math:`O(\\log{n})`."""
        assert (
            -self._size <= k < self._size
        ), f"IndexError: {self.__class__.__name__}[{k}] = {x}, n={self._size}"
        if k < 0:
            k += self._size
        pre = self[k]
        self.add(k, x - pre)

    def bisect_left(self, w: int) -> Optional[int]:
        """Return the smallest ``i`` with ``pref(i + 1) >= w``.

        Assumes every stored value is non-negative.  Returns ``n`` when the
        total is smaller than ``w``, and ``None`` when ``w == 0``.
        """
        i, s, _size, _tree = 0, self._s, self._size, self._tree
        while s:
            if i + s <= _size and _tree[i + s] < w:
                w -= _tree[i + s]
                i += s
            s >>= 1
        return i if w else None

    def bisect_right(self, w: int) -> int:
        """Return the largest ``i`` with ``pref(i) <= w``.

        Assumes every stored value is non-negative.
        """
        i, s, _size, _tree = 0, self._s, self._size, self._tree
        while s:
            if i + s <= _size and _tree[i + s] <= w:
                w -= _tree[i + s]
                i += s
            s >>= 1
        return i

    def _pop(self, k: int) -> int:
        """Treat the values as counts, remove the ``k``-th unit (0-indexed) and return its position.

        Every node skipped on the descent covers the answer, so it is
        decremented on the way down.  Assumes non-negative counts and
        ``k < pref(n)``.
        """
        assert k >= 0
        i, acc, s, _size, _tree = 0, 0, self._s, self._size, self._tree
        while s:
            if i + s <= _size:
                if acc + _tree[i + s] <= k:
                    acc += _tree[i + s]
                    i += s
                else:
                    _tree[i + s] -= 1
            s >>= 1
        return i

    def tolist(self) -> list[int]:
        """Return the current values as a list in :math:`O(n)`."""
        _size = self._size
        a = self._tree[:]
        # Undo the O(n) build.  Going from the right, a node's children all have
        # smaller indices, so a[i] still holds its full partial sum when it is
        # subtracted from its parent.
        for i in range(_size, 0, -1):
            parent = i + (i & -i)
            if parent <= _size:
                a[parent] -= a[i]
        return a[1:]

    @staticmethod
    def get_inversion_num(a: list[int], compress: bool = False) -> int:
        """Return the number of pairs ``i < j`` with ``a[i] > a[j]``.

        Args:
            a (list[int]): the sequence.
            compress (bool): if ``True``, coordinate-compress ``a`` first, so any
                comparable values work.  Otherwise the values must be
                non-negative integers.
        """
        inv = 0
        if compress:
            a_ = sorted(set(a))
            z = {e: i for i, e in enumerate(a_)}
            fw = FenwickTree(len(a_) + 1)
            for i, e in enumerate(a):
                # i earlier elements, minus those <= e, are greater than e.
                inv += i - fw.pref(z[e] + 1)
                fw.add(z[e], 1)
        else:
            # Size by the largest value too, so values above len(a) are accepted.
            fw = FenwickTree(max(len(a), max(a, default=0)) + 1)
            for i, e in enumerate(a):
                inv += i - fw.pref(e + 1)
                fw.add(e, 1)
        return inv

    def __str__(self):
        return str(self.tolist())

    def __repr__(self):
        return f"{self.__class__.__name__}({self})"
168# from titan_pylib.data_structures.segment_tree.segment_tree_RmQ import SegmentTreeRmQ
169# from titan_pylib.data_structures.segment_tree.segment_tree_interface import (
170# SegmentTreeInterface,
171# )
172from abc import ABC, abstractmethod
173from typing import TypeVar, Generic, Union, Iterable, Callable
174
T = TypeVar("T")


class SegmentTreeInterface(ABC, Generic[T]):
    """Abstract interface that the segment tree classes implement.

    Every member is abstract.  A subclass that delegates to one of them via
    ``super()`` gets ``NotImplementedError``.
    """

    @abstractmethod
    def __init__(self, n_or_a: Union[int, Iterable[T]], op: Callable[[T, T], T], e: T):
        """Build from a length or from initial values, with the binary operation
        ``op`` and element ``e`` (presumably the identity of ``op``)."""
        raise NotImplementedError()

    # --- point access -----------------------------------------------------

    @abstractmethod
    def set(self, k: int, v: T) -> None:
        """Store ``v`` at position ``k``."""
        raise NotImplementedError()

    @abstractmethod
    def get(self, k: int) -> T:
        """Return the value at position ``k``."""
        raise NotImplementedError()

    @abstractmethod
    def __getitem__(self, k: int) -> T:
        """Indexing form of ``get``."""
        raise NotImplementedError()

    @abstractmethod
    def __setitem__(self, k: int, v: T) -> None:
        """Item-assignment form of ``set``."""
        raise NotImplementedError()

    # --- range queries ------------------------------------------------------

    @abstractmethod
    def prod(self, l: int, r: int) -> T:
        """Fold the values in ``[l, r)``."""
        raise NotImplementedError()

    @abstractmethod
    def all_prod(self) -> T:
        """Fold every stored value."""
        raise NotImplementedError()

    @abstractmethod
    def max_right(self, l: int, f: Callable[[T], bool]) -> int:
        """Binary search to the right of ``l`` with predicate ``f``."""
        raise NotImplementedError()

    @abstractmethod
    def min_left(self, r: int, f: Callable[[T], bool]) -> int:
        """Binary search to the left of ``r`` with predicate ``f``."""
        raise NotImplementedError()

    # --- conversion -----------------------------------------------------------

    @abstractmethod
    def tolist(self) -> list[T]:
        """Return the stored values as a list."""
        raise NotImplementedError()

    @abstractmethod
    def __str__(self):
        raise NotImplementedError()

    @abstractmethod
    def __repr__(self):
        raise NotImplementedError()
227# from titan_pylib.my_class.supports_less_than import SupportsLessThan
228from typing import Protocol
229
230
class SupportsLessThan(Protocol):
    """Structural type for values that can be ordered with ``<``."""

    def __lt__(self, other) -> bool:
        """Return whether ``self`` sorts strictly before ``other``."""
        ...


from typing import Generic, Iterable, TypeVar, Union

# Element type of the RmQ segment tree: anything comparable with ``<``.
T = TypeVar("T", bound=SupportsLessThan)
237
238
class SegmentTreeRmQ(SegmentTreeInterface, Generic[T]):
    """Segment tree for range-minimum queries with point updates.

    ``e`` fills unused leaves and is the result of an empty query.  It must not
    be smaller than any stored value, so that ``min(x, e) == x``.  Elements are
    only compared with ``<``.
    """

    def __init__(self, _n_or_a: Union[int, Iterable[T]], e: T) -> None:
        """Build the tree in :math:`O(n)`.

        Args:
            _n_or_a: if an ``int`` ``n``, build ``n`` copies of ``e``; otherwise
                build from the given values.
            e: neutral element of ``min`` for the stored values.
        """
        self._e = e
        if isinstance(_n_or_a, int):
            self._n = _n_or_a
            self._log = (self._n - 1).bit_length()
            self._size = 1 << self._log
            self._data = [self._e] * (self._size << 1)
        else:
            _n_or_a = list(_n_or_a)
            self._n = len(_n_or_a)
            self._log = (self._n - 1).bit_length()
            self._size = 1 << self._log
            _data = [self._e] * (self._size << 1)
            _data[self._size : self._size + self._n] = _n_or_a
            # Node i is the minimum of its children 2i and 2i+1; fill bottom-up.
            for i in range(self._size - 1, 0, -1):
                _data[i] = (
                    _data[i << 1]
                    if _data[i << 1] < _data[i << 1 | 1]
                    else _data[i << 1 | 1]
                )
            self._data = _data

    def set(self, k: int, v: T) -> None:
        """Set ``a[k] = v`` in :math:`O(\\log{n})`; negative ``k`` counts from the end."""
        if k < 0:
            k += self._n
        assert (
            0 <= k < self._n
        ), f"IndexError: {self.__class__.__name__}.set({k}: int, {v}: T), n={self._n}"
        _data = self._data
        k += self._size
        _data[k] = v
        # Recompute every ancestor of the updated leaf.
        for _ in range(self._log):
            k >>= 1
            left, right = _data[k << 1], _data[k << 1 | 1]
            _data[k] = left if left < right else right

    def get(self, k: int) -> T:
        """Return ``a[k]`` in :math:`O(1)`; negative ``k`` counts from the end."""
        if k < 0:
            k += self._n
        assert (
            0 <= k < self._n
        ), f"IndexError: {self.__class__.__name__}.get({k}: int), n={self._n}"
        return self._data[k + self._size]

    def prod(self, l: int, r: int) -> T:
        """Return ``min(a[l:r])`` (``e`` for an empty range) in :math:`O(\\log{n})`."""
        assert (
            0 <= l <= r <= self._n
        ), f"IndexError: {self.__class__.__name__}.prod({l}: int, {r}: int)"
        _data = self._data
        l += self._size
        r += self._size
        res = self._e
        while l < r:
            if l & 1:
                if _data[l] < res:
                    res = _data[l]
                l += 1
            if r & 1:
                r ^= 1  # r is odd here, so this is r - 1
                if _data[r] < res:
                    res = _data[r]
            l >>= 1
            r >>= 1
        return res

    def all_prod(self) -> T:
        """Return the minimum over all values (``e`` if empty) in :math:`O(1)`."""
        return self._data[1]

    def max_right(self, l: int, f=lambda lr: lr):
        """Return the largest ``r`` such that ``f(min(a[l:r]))`` is true.

        ``f(e)`` must be true, and ``f`` is assumed monotone: once it is false
        on a range, it stays false on every extension of that range.
        """
        assert (
            0 <= l <= self._n
        ), f"IndexError: {self.__class__.__name__}.max_right({l}, f) index out of range"
        assert f(
            self._e
        ), f"{self.__class__.__name__}.max_right({l}, f), f({self._e}) must be true."
        if l == self._n:
            return self._n
        l += self._size
        s = self._e
        while True:
            while l & 1 == 0:
                l >>= 1
            if not f(min(s, self._data[l])):
                # The answer lies inside node l: descend towards it.
                while l < self._size:
                    l <<= 1
                    if f(min(s, self._data[l])):
                        if self._data[l] < s:
                            s = self._data[l]
                        l += 1
                return l - self._size
            s = min(s, self._data[l])
            l += 1
            if l & -l == l:  # walked past the right end of the tree
                break
        return self._n

    def min_left(self, r: int, f=lambda lr: lr):
        """Return the smallest ``l`` such that ``f(min(a[l:r]))`` is true.

        ``f(e)`` must be true, and ``f`` is assumed monotone: once it is false
        on a range, it stays false on every extension of that range.
        """
        assert (
            0 <= r <= self._n
        ), f"IndexError: {self.__class__.__name__}.min_left({r}, f) index out of range"
        assert f(
            self._e
        ), f"{self.__class__.__name__}.min_left({r}, f), f({self._e}) must be true."
        if r == 0:
            return 0
        r += self._size
        s = self._e
        while True:
            r -= 1
            while r > 1 and r & 1:
                r >>= 1
            if not f(min(self._data[r], s)):
                # The answer lies inside node r: descend towards it.
                while r < self._size:
                    r = r << 1 | 1
                    if f(min(self._data[r], s)):
                        if self._data[r] < s:
                            s = self._data[r]
                        r -= 1
                return r + 1 - self._size
            s = min(self._data[r], s)
            if r & -r == r:  # walked past the left end of the tree
                break
        return 0

    def tolist(self) -> list[T]:
        """Return the stored values as a list."""
        return [self.get(i) for i in range(self._n)]

    def show(self) -> None:
        """Print the tree level by level (for debugging)."""
        print(
            f"<{self.__class__.__name__}> [\n"
            + "\n".join(
                [
                    "  "
                    + " ".join(
                        map(str, [self._data[(1 << i) + j] for j in range(1 << i)])
                    )
                    for i in range(self._log + 1)
                ]
            )
            + "\n]"
        )

    def __getitem__(self, k: int) -> T:
        assert (
            -self._n <= k < self._n
        ), f"IndexError: {self.__class__.__name__}.__getitem__({k}: int), n={self._n}"
        return self.get(k)

    def __setitem__(self, k: int, v: T):
        assert (
            -self._n <= k < self._n
        ), f"IndexError: {self.__class__.__name__}.__setitem__({k}: int, {v}: T), n={self._n}"
        self.set(k, v)

    def __str__(self):
        return "[" + ", ".join(map(str, (self.get(i) for i in range(self._n)))) + "]"

    def __repr__(self):
        return f"{self.__class__.__name__}({self})"
402
403
class EulerTour:
    """Euler tour of a weighted tree, supporting LCA, subtree sums and path sums.

    The tour has ``2 * n`` positions:

    * the entry of every vertex;
    * a return to the parent after each child's subtree is finished;
    * a final return to the root.

    Four Fenwick trees are kept over those positions:

    * vertex costs for subtree sums: ``+cost`` at ``nodein[v]``;
    * vertex costs for root paths: ``+cost`` at ``nodein[v]`` and ``-cost`` at
      ``nodeout[v]``;
    * edge costs for subtree sums: ``+cost`` at the child endpoint's entry;
    * edge costs for root paths: ``+cost`` at the child endpoint's entry and
      ``-cost`` once that child's subtree is finished.
    """

    def __init__(
        self,
        G: list[list[tuple[int, int]]],
        root: int,
        vertexcost: Optional[list[int]] = None,
    ) -> None:
        """Build the tour in :math:`O(n \\log n)`.

        Args:
            G: adjacency list; ``G[v]`` holds ``(neighbor, edge_cost)`` pairs.
            root: root vertex.
            vertexcost: initial vertex costs (all ``0`` if omitted or empty).
                The list is copied, so later updates do not modify it.
        """
        n = len(G)
        vertexcost = list(vertexcost) if vertexcost else [0] * n

        path = [0] * (2 * n)  # vertex visited at each tour position
        vcost1 = [0] * (2 * n)  # vertex costs, for subtree sums
        ecost1 = [0] * (2 * n)  # edge costs, for subtree sums
        ecost2 = [0] * (2 * n)  # edge costs, for root-path sums
        nodein = [0] * n
        nodeout = [0] * n
        depth = [-1] * n

        curtime = -1
        depth[root] = 0
        # (v, c)  : enter v through an edge of cost c.
        # (~p, c) : return to p after finishing a child reached through cost c.
        stack: list[tuple[int, int]] = [(~root, 0), (root, 0)]
        while stack:
            curtime += 1
            v, ec = stack.pop()
            if v >= 0:
                nodein[v] = curtime
                path[curtime] = v
                ecost1[curtime] = ec
                ecost2[curtime] = ec
                vcost1[curtime] = vertexcost[v]
                if len(G[v]) == 1:
                    # Leaf: its subtree is just this position.  A degree-1 root
                    # gets nodeout overwritten by its later return entries.
                    nodeout[v] = curtime + 1
                for x, c in G[v]:
                    if depth[x] != -1:
                        continue
                    depth[x] = depth[v] + 1
                    stack.append((~v, c))
                    stack.append((x, c))
            else:
                v = ~v
                path[curtime] = v
                # Cancel the cost of the edge to the child that was just finished.
                ecost2[curtime] = -ec
                nodeout[v] = curtime

        # Vertex costs for root-path sums use the same layout add_vertex keeps:
        # +cost at nodein[v] and -cost at nodeout[v].  A prefix up to nodein[u]
        # then contains exactly the costs of u and its ancestors.
        vcost2 = [0] * (2 * n)
        for v in range(n):
            vcost2[nodein[v]] += vertexcost[v]
            vcost2[nodeout[v]] -= vertexcost[v]

        # ---------------------- #

        self._n = n
        self._depth = depth
        self._nodein = nodein
        self._nodeout = nodeout
        self._vertexcost = vertexcost
        self._path = path

        self._vcost_subtree = FenwickTree(vcost1)
        self._vcost_path = FenwickTree(vcost2)
        self._ecost_subtree = FenwickTree(ecost1)
        self._ecost_path = FenwickTree(ecost2)

        # Pack (depth, position) into one int.  A range minimum then yields the
        # shallowest tour position, and the low bits recover the position.
        bit = len(path).bit_length()
        self.msk = (1 << bit) - 1
        a: list[int] = [(depth[v] << bit) + i for i, v in enumerate(path)]
        self._st: SegmentTreeRmQ[int] = SegmentTreeRmQ(a, e=max(a))

    def lca(self, u: int, v: int) -> int:
        """Return the lowest common ancestor of ``u`` and ``v`` in :math:`O(\\log n)`."""
        if u == v:
            return u
        l = min(self._nodein[u], self._nodein[v])
        r = max(self._nodeout[u], self._nodeout[v])
        ind = self._st.prod(l, r) & self.msk
        return self._path[ind]

    def lca_mul(self, a: list[int]) -> int:
        """Return the lowest common ancestor of every vertex in ``a``.

        Raises:
            ValueError: if ``a`` is empty.
        """
        if not a:
            raise ValueError("lca_mul() requires at least one vertex")
        l = min(self._nodein[e] for e in a)
        r = max(self._nodeout[e] for e in a)
        ind = self._st.prod(l, r) & self.msk
        return self._path[ind]

    def subtree_vcost(self, v: int) -> int:
        """Return the sum of vertex costs in the subtree of ``v`` (``v`` included)."""
        l = self._nodein[v]
        r = self._nodeout[v]
        return self._vcost_subtree.prod(l, r)

    def subtree_ecost(self, v: int) -> int:
        """Return the sum of edge costs inside the subtree of ``v``.

        The edge from ``v`` to its parent is excluded.
        """
        l = self._nodein[v]
        r = self._nodeout[v]
        # Skip position l, which holds the cost of the edge into v.
        return self._ecost_subtree.prod(l + 1, r)

    def _path_vcost(self, v: int) -> int:
        """Return the sum of vertex costs on the root-to-``v`` path, ``v`` included."""
        return self._vcost_path.pref(self._nodein[v] + 1)

    def _path_ecost(self, v: int) -> int:
        """Return the sum of edge costs on the root-to-``v`` path."""
        return self._ecost_path.pref(self._nodein[v] + 1)

    def path_vcost(self, u: int, v: int) -> int:
        """Return the sum of vertex costs on the ``u``-``v`` path, endpoints included."""
        a = self.lca(u, v)
        return (
            self._path_vcost(u)
            + self._path_vcost(v)
            - 2 * self._path_vcost(a)
            + self._vertexcost[a]  # the LCA was subtracted twice
        )

    def path_ecost(self, u: int, v: int) -> int:
        """Return the sum of edge costs on the ``u``-``v`` path."""
        return (
            self._path_ecost(u)
            + self._path_ecost(v)
            - 2 * self._path_ecost(self.lca(u, v))
        )

    def add_vertex(self, v: int, w: int) -> None:
        """Add ``w`` to the cost of vertex ``v`` in :math:`O(\\log n)`."""
        l = self._nodein[v]
        r = self._nodeout[v]
        self._vcost_subtree.add(l, w)
        self._vcost_path.add(l, w)
        self._vcost_path.add(r, -w)
        self._vertexcost[v] += w

    def set_vertex(self, v: int, w: int) -> None:
        """Set the cost of vertex ``v`` to ``w`` in :math:`O(\\log n)`."""
        self.add_vertex(v, w - self._vertexcost[v])

    def add_edge(self, u: int, v: int, w: int) -> None:
        """Add ``w`` to the cost of the edge ``u``-``v`` in :math:`O(\\log n)`.

        ``u`` and ``v`` must be adjacent.
        """
        if self._depth[u] < self._depth[v]:
            u, v = v, u
        # The edge cost is stored at the deeper (child) endpoint u.
        l = self._nodein[u]
        r = self._nodeout[u]
        # Subtree sums are plain range sums over point values: only a point add.
        self._ecost_subtree.add(l, w)
        # Path sums: the edge counts for every entry position inside u's subtree.
        # Prefix sums are only read at entry positions, and r + 1 lies past every
        # entry in u's subtree, so this cancels w outside it.
        self._ecost_path.add(l, w)
        self._ecost_path.add(r + 1, -w)

    def set_edge(self, u: int, v: int, w: int) -> None:
        """Set the cost of the edge ``u``-``v`` to ``w`` in :math:`O(\\log n)`.

        ``u`` and ``v`` must be adjacent.
        """
        self.add_edge(u, v, w - self.path_ecost(u, v))