1# from titan_pylib.data_structures.union_find.partial_persistent_union_find import PartialPersistentUnionFind
2# from titan_pylib.data_structures.array.partial_persistent_array import (
3# PartialPersistentArray,
4# )
5from typing import Iterable, TypeVar, Generic
6
7T = TypeVar("T")
8
9
10class PartialPersistentArray(Generic[T]):
11 """部分永続配列です。
12 最新版の更新と、過去へのアクセスが可能です。
13 """
14
15 def __init__(self, a: Iterable[T]):
16 """``a`` から部分永続配列を構築します。
17 初期配列のバージョンは ``0`` です。
18 :math:`O(n)` です。
19 """
20 self.a: list[list[T]] = [[e] for e in a]
21 self.t: list[list[int]] = [[0] for _ in range(len(self.a))]
22 self.last_time: int = 0
23
24 def set(self, k: int, v: T, t: int) -> None:
25 """位置 ``k`` を ``v`` に更新します。
26 :math:`O(1)` です。
27
28 Args:
29 k (int): インデックスです。
30 v (T): 更新する値です。
31 t (int): 新たな配列のバージョンです。
32 """
33 assert t >= self.last_time
34 assert t > self.t[k][-1]
35 self.a[k].append(v)
36 self.t[k].append(t)
37 self.last_time = t
38
39 def get(self, k: int, t: int = -1) -> T:
40 """位置 ``k`` 、バージョン ``t`` の要素を返します。
41 :math:`O(\\log{n})` です。
42
43 Args:
44 k (int): インデックスです。
45 t (int, optional): バージョンです。デフォルトは最新バージョンです。
46 """
47 if t == -1 or t >= self.t[k][-1]:
48 return self.a[k][-1]
49 tk = self.t[k]
50 ok, ng = 0, len(tk)
51 while ng - ok > 1:
52 mid = (ok + ng) // 2
53 if tk[mid] <= t:
54 ok = mid
55 else:
56 ng = mid
57 return self.a[k][ok]
58
59 def tolist(self, t: int) -> list[T]:
60 """バージョン ``t`` の配列を返します。
61 :math:`O(n\\log{n})` です。
62 """
63 return [self.get(i, t) for i in range(len(self))]
64
65 def show(self, t: int) -> None:
66 print(f"Time: {t}", end="")
67 print([self.get(i, t) for i in range(len(self))])
68
69 def show_all(self) -> None:
70 """すべてのバージョンの配列を表示します。"""
71 for i in range(self.last_time):
72 self.show(i)
73
74 def __len__(self):
75 return len(self.a)
76
77
78class PartialPersistentUnionFind:
79
80 def __init__(self, n: int):
81 self._n: int = n
82 self._parents: PartialPersistentArray[int] = PartialPersistentArray([-1] * n)
83 self._last_time: int = 0
84
85 def root(self, x: int, t: int = -1) -> int:
86 assert t == -1 or t <= self._last_time
87 while True:
88 p = self._parents.get(x, t)
89 if p < 0:
90 return x
91 x = p
92
93 def unite(self, x: int, y: int, t: int) -> bool:
94 assert t == -1 or t >= self._last_time
95 self._last_time = t
96 x = self.root(x, t)
97 y = self.root(y, t)
98 if x == y:
99 return False
100 if self._parents.get(x, t) > self._parents.get(y, t):
101 x, y = y, x
102 self._parents.set(x, self._parents.get(x, t) + self._parents.get(y, t), t)
103 self._parents.set(y, x, t)
104 return True
105
106 def size(self, x: int, t: int = -1) -> int:
107 assert t == -1 or t <= self._last_time
108 return -self._parents.get(self.root(x, t), t)
109
110 def same(self, x: int, y: int, t: int = -1) -> bool:
111 assert t == -1 or t <= self._last_time
112 return self.root(x, t) == self.root(y, t)