undoable_union_find¶
ソースコード¶
from titan_pylib.data_structures.union_find.undoable_union_find import UndoableUnionFind
展開済みコード¶
1# from titan_pylib.data_structures.union_find.undoable_union_find import UndoableUnionFind
2from collections import defaultdict
3
4
5class UndoableUnionFind:
6
7 def __init__(self, n: int) -> None:
8 """``n`` 個の要素からなる ``UndoableUnionFind`` を構築します。
9 :math:`O(n)` です。
10 """
11 self._n: int = n
12 self._parents: list[int] = [-1] * n
13 self._group_count: int = n
14 self._history: list[tuple[int, int]] = []
15
16 def root(self, x: int) -> int:
17 """要素 ``x`` を含む集合の代表元を返します。
18 :math:`O(\\log{n})` です。
19 """
20 while self._parents[x] >= 0:
21 x = self._parents[x]
22 return x
23
24 def unite(self, x: int, y: int) -> bool:
25 """要素 ``x`` を含む集合と要素 ``y`` を含む集合を併合します。
26 :math:`O(\\log{n})` です。
27
28 Returns:
29 bool: もともと同じ集合であれば ``False``、そうでなければ ``True`` を返します。
30 """
31 x = self.root(x)
32 y = self.root(y)
33 if x == y:
34 self._history.append((-1, -1))
35 return False
36 if self._parents[x] > self._parents[y]:
37 x, y = y, x
38 self._group_count -= 1
39 self._history.append((x, self._parents[x]))
40 self._history.append((y, self._parents[y]))
41 self._parents[x] += self._parents[y]
42 self._parents[y] = x
43 return True
44
45 def undo(self) -> None:
46 """直前の ``unite`` クエリを戻します。
47 :math:`O(\\log{n})` です。
48 """
49 assert (
50 self._history
51 ), f"Error: {self.__class__.__name__}.undo() with non history."
52 y, py = self._history.pop()
53 if y == -1:
54 return
55 self._group_count += 1
56 x, px = self._history.pop()
57 self._parents[y] = py
58 self._parents[x] = px
59
60 def size(self, x: int) -> int:
61 """要素 ``x`` を含む集合の要素数を返します。
62 :math:`O(\\log{n})` です。
63 """
64 return -self._parents[self.root(x)]
65
66 def same(self, x: int, y: int) -> bool:
67 """
68 要素 ``x`` と ``y`` が同じ集合に属するなら ``True`` を、
69 そうでないなら ``False`` を返します。
70 :math:`O(\\log{n})` です。
71 """
72 return self.root(x) == self.root(y)
73
74 def all_roots(self) -> list[int]:
75 """全ての集合の代表元からなるリストを返します。
76 :math:`O(n)` です。
77
78 Returns:
79 list[int]: 昇順であることが保証されます。
80 """
81 return [i for i, x in enumerate(self._parents) if x < 0]
82
83 def all_group_members(self) -> defaultdict:
84 """
85 `key` に代表元、 `value` に `key` を代表元とする集合のリストをもつ `defaultdict` を返します。
86 :math:`O(n\\log{n})` です。
87 """
88 group_members = defaultdict(list)
89 for member in range(self._n):
90 group_members[self.root(member)].append(member)
91 return group_members
92
93 def group_count(self) -> int:
94 """集合の総数を返します。
95 :math:`O(1)` です。
96 """
97 return self._group_count
98
99 def clear(self) -> None:
100 """集合の連結状態をなくします(初期状態に戻します)。
101 :math:`O(n)` です。
102 """
103 for i in range(self._n):
104 self._parents[i] = -1
105
106 def __str__(self) -> str:
107 """よしなにします。
108
109 :math:`O(n\\log{n})` です。
110 """
111 return (
112 f"<{self.__class__.__name__}> [\n"
113 + "\n".join(f" {k}: {v}" for k, v in self.all_group_members().items())
114 + "\n]"
115 )
仕様¶
- class UndoableUnionFind(n: int)[source]¶
Bases:
object
- all_group_members() defaultdict [source]¶
key に代表元、 value に key を代表元とする集合のリストをもつ defaultdict を返します。 \(O(n\log{n})\) です。
- all_roots() list[int] [source]¶
全ての集合の代表元からなるリストを返します。 \(O(n)\) です。
- Returns:
昇順であることが保証されます。
- Return type:
list[int]