Source code for titan_pylib.data_structures.union_find.undoable_union_find
1from collections import defaultdict
2
3
[docs]
4class UndoableUnionFind:
5
6 def __init__(self, n: int) -> None:
7 """``n`` 個の要素からなる ``UndoableUnionFind`` を構築します。
8 :math:`O(n)` です。
9 """
10 self._n: int = n
11 self._parents: list[int] = [-1] * n
12 self._group_count: int = n
13 self._history: list[tuple[int, int]] = []
14
[docs]
15 def root(self, x: int) -> int:
16 """要素 ``x`` を含む集合の代表元を返します。
17 :math:`O(\\log{n})` です。
18 """
19 while self._parents[x] >= 0:
20 x = self._parents[x]
21 return x
22
[docs]
23 def unite(self, x: int, y: int) -> bool:
24 """要素 ``x`` を含む集合と要素 ``y`` を含む集合を併合します。
25 :math:`O(\\log{n})` です。
26
27 Returns:
28 bool: もともと同じ集合であれば ``False``、そうでなければ ``True`` を返します。
29 """
30 x = self.root(x)
31 y = self.root(y)
32 if x == y:
33 self._history.append((-1, -1))
34 return False
35 if self._parents[x] > self._parents[y]:
36 x, y = y, x
37 self._group_count -= 1
38 self._history.append((x, self._parents[x]))
39 self._history.append((y, self._parents[y]))
40 self._parents[x] += self._parents[y]
41 self._parents[y] = x
42 return True
43
[docs]
44 def undo(self) -> None:
45 """直前の ``unite`` クエリを戻します。
46 :math:`O(\\log{n})` です。
47 """
48 assert (
49 self._history
50 ), f"Error: {self.__class__.__name__}.undo() with non history."
51 y, py = self._history.pop()
52 if y == -1:
53 return
54 self._group_count += 1
55 x, px = self._history.pop()
56 self._parents[y] = py
57 self._parents[x] = px
58
[docs]
59 def size(self, x: int) -> int:
60 """要素 ``x`` を含む集合の要素数を返します。
61 :math:`O(\\log{n})` です。
62 """
63 return -self._parents[self.root(x)]
64
[docs]
65 def same(self, x: int, y: int) -> bool:
66 """
67 要素 ``x`` と ``y`` が同じ集合に属するなら ``True`` を、
68 そうでないなら ``False`` を返します。
69 :math:`O(\\log{n})` です。
70 """
71 return self.root(x) == self.root(y)
72
[docs]
73 def all_roots(self) -> list[int]:
74 """全ての集合の代表元からなるリストを返します。
75 :math:`O(n)` です。
76
77 Returns:
78 list[int]: 昇順であることが保証されます。
79 """
80 return [i for i, x in enumerate(self._parents) if x < 0]
81
[docs]
82 def all_group_members(self) -> defaultdict:
83 """
84 `key` に代表元、 `value` に `key` を代表元とする集合のリストをもつ `defaultdict` を返します。
85 :math:`O(n\\log{n})` です。
86 """
87 group_members = defaultdict(list)
88 for member in range(self._n):
89 group_members[self.root(member)].append(member)
90 return group_members
91
[docs]
92 def group_count(self) -> int:
93 """集合の総数を返します。
94 :math:`O(1)` です。
95 """
96 return self._group_count
97
[docs]
98 def clear(self) -> None:
99 """集合の連結状態をなくします(初期状態に戻します)。
100 :math:`O(n)` です。
101 """
102 for i in range(self._n):
103 self._parents[i] = -1
104
[docs]
105 def __str__(self) -> str:
106 """よしなにします。
107
108 :math:`O(n\\log{n})` です。
109 """
110 return (
111 f"<{self.__class__.__name__}> [\n"
112 + "\n".join(f" {k}: {v}" for k, v in self.all_group_members().items())
113 + "\n]"
114 )