persistent_union_find¶
ソースコード¶
from titan_pylib.data_structures.union_find.persistent_union_find import PersistentUnionFind
展開済みコード¶
1# from titan_pylib.data_structures.union_find.persistent_union_find import PersistentUnionFind
2# from titan_pylib.data_structures.array.persistent_array import PersistentArray
3from typing import Iterable, TypeVar, Generic, Optional
4
5T = TypeVar("T")
6
7
8class PersistentArray(Generic[T]):
9
10 class _Node:
11
12 def __init__(self, key: T):
13 self.key: T = key
14 self.left: Optional[PersistentArray._Node] = None
15 self.right: Optional[PersistentArray._Node] = None
16
17 def copy(self) -> "PersistentArray._Node":
18 node = PersistentArray._Node(self.key)
19 node.left = self.left
20 node.right = self.right
21 return node
22
23 def __init__(
24 self, a: Iterable[T] = [], _root: Optional["PersistentArray._Node"] = None
25 ) -> None:
26 self.root = self._build(a) if _root is None else _root
27
28 def _build(self, a: Iterable[T]) -> Optional["PersistentArray._Node"]:
29 pool = [PersistentArray._Node(e) for e in a]
30 self.n = len(pool)
31 if not pool:
32 return None
33 n = len(pool)
34 for i in range(1, n + 1):
35 if 2 * i - 1 < n:
36 pool[i - 1].left = pool[2 * i - 1]
37 if 2 * i < n:
38 pool[i - 1].right = pool[2 * i]
39 return pool[0]
40
41 def _new(self, root: Optional["PersistentArray._Node"]) -> "PersistentArray[T]":
42 res = PersistentArray(_root=root)
43 res.n = self.n
44 return res
45
46 def set(self, k: int, v: T) -> "PersistentArray[T]":
47 assert 0 <= k < len(self), f"IndexError: {self.__class__.__name__}.set({k})"
48 assert self.root
49 node = self.root
50 new_node = node.copy()
51 res = self._new(new_node)
52 k += 1
53 b = k.bit_length()
54 for i in range(b - 2, -1, -1):
55 if k >> i & 1:
56 node = node.right
57 new_node.right = node.copy()
58 new_node = new_node.right
59 else:
60 node = node.left
61 new_node.left = node.copy()
62 new_node = new_node.left
63 new_node.key = v
64 return res
65
66 def get(self, k: int) -> T:
67 assert 0 <= k < len(self), f"IndexError: {self.__class__.__name__}.get({k})"
68 node = self.root
69 k += 1
70 b = k.bit_length()
71 for i in range(b - 2, -1, -1):
72 if k >> i & 1:
73 node = node.right
74 else:
75 node = node.left
76 return node.key
77
78 __getitem__ = get
79
80 def copy(self) -> "PersistentArray[T]":
81 return self._new(None if self.root is None else self.root.copy())
82
83 def tolist(self) -> list[T]:
84 node = self.root
85 a: list[T] = []
86 if not node:
87 return a
88 q = [node]
89 for node in q:
90 a.append(node.key)
91 if node.left:
92 q.append(node.left)
93 if node.right:
94 q.append(node.right)
95 return a
96
97 def __len__(self):
98 return self.n
99
100 def __str__(self):
101 return str(self.tolist())
102
103 def __repr__(self):
104 return f"{self.__class__.__name__}({self})"
105from typing import Optional
106
107
108class PersistentUnionFind:
109
110 def __init__(self, n: int, _parents: Optional[PersistentArray[int]] = None) -> None:
111 """``n`` 個の要素からなる ``PersistentUnionFind`` を構築します。
112 :math:`O(n)` です。
113 """
114 self._n: int = n
115 self._parents: PersistentArray[int] = (
116 PersistentArray([-1] * n) if _parents is None else _parents
117 )
118
119 def _new(self, _parents: PersistentArray[int]) -> "PersistentUnionFind":
120 return PersistentUnionFind(self._n, _parents)
121
122 def copy(self) -> "PersistentUnionFind":
123 """コピーします。
124 :math:`O(1)` です。
125 """
126 return self._new(self._parents.copy())
127
128 def root(self, x: int) -> int:
129 """要素 ``x`` を含む集合の代表元を返します。
130 :math:`O(\\log^2{n})` です。
131 """
132 _parents = self._parents
133 while True:
134 p = _parents.get(x)
135 if p < 0:
136 return x
137 x = p
138
139 def unite(self, x: int, y: int, update: bool = True) -> "PersistentUnionFind":
140 """要素 ``x`` を含む集合と要素 ``y`` を含む集合を併合します。
141 :math:`O(\\log^2{n})` です。
142
143 Args:
144 x (int): 集合の要素です。
145 y (int): 集合の要素です。
146 update (bool, optional): 併合後を新しいインスタンスにするなら ``True`` です。
147
148 Returns:
149 PersistentUnionFind: 併合後の uf です。
150 """
151 x = self.root(x)
152 y = self.root(y)
153 res_parents = self._parents.copy() if update else self._parents
154 if x == y:
155 return self._new(res_parents)
156 px, py = res_parents.get(x), res_parents.get(y)
157 if px > py:
158 x, y = y, x
159 res_parents = res_parents.set(x, px + py)
160 res_parents = res_parents.set(y, x)
161 return self._new(res_parents)
162
163 def size(self, x: int) -> int:
164 """要素 ``x`` を含む集合の要素数を返します。
165 :math:`O(\\log^2{n})` です。
166 """
167 return -self._parents.get(self.root(x))
168
169 def same(self, x: int, y: int) -> bool:
170 """
171 要素 ``x`` と ``y`` が同じ集合に属するなら ``True`` を、
172 そうでないなら ``False`` を返します。
173 :math:`O(\\log^2{n})` です。
174 """
175 return self.root(x) == self.root(y)
仕様¶
- class PersistentUnionFind(n: int, _parents: PersistentArray[int] | None = None)[source]¶
Bases:
object
- copy() PersistentUnionFind [source]¶
コピーします。 \(O(1)\) です。
- same(x: int, y: int) bool [source]¶
要素
x
とy
が同じ集合に属するならTrue
を、 そうでないならFalse
を返します。 \(O(\log^2{n})\) です。
- unite(x: int, y: int, update: bool = True) PersistentUnionFind [source]¶
要素
x
を含む集合と要素y
を含む集合を併合します。 \(O(\log^2{n})\) です。- Parameters:
x (int) – 集合の要素です。
y (int) – 集合の要素です。
update (bool, optional) – 併合後を新しいインスタンスにするなら
True
です。
- Returns:
併合後の uf です。
- Return type: