1# from titan_pylib.data_structures.union_find.union_find_heavy import UnionFindHeavy
2from collections import defaultdict
3
4
class UnionFindHeavy:
    """Union-find over ``0..n-1`` that also tracks edge counts and members.

    On top of the usual parent array, the structure keeps:

    * ``_edges`` - for every root, the number of edges added inside its
      group. Each call to ``unite`` / ``unite_left`` / ``unite_right`` counts
      as one edge, including calls whose endpoints are already connected.
    * ``_G`` - an undirected graph receiving one edge (between the two old
      roots) per successful merge. Each group is therefore connected in
      ``_G``, which lets ``members`` collect a group by traversal.
    """

    def __init__(self, n: int) -> None:
        """Create ``n`` singleton groups with no edges."""
        self._n: int = n
        self._group_numbers: int = n
        # Root r stores -(group size); any other vertex stores its parent.
        self._parents: list[int] = [-1] * n
        # Edge count per group; only the entry at a group's root is meaningful.
        self._edges: list[int] = [0] * n
        # Adjacency lists of the merge graph (one edge per successful merge).
        self._G: list[list[int]] = [[] for _ in range(n)]

    def root(self, x: int) -> int:
        """Return the root of the group containing ``x``, compressing the path."""
        assert 0 <= x < self._n, f"{self.__class__.__name__}.root(x) IndexError, x={x}"
        parents = self._parents
        a = x
        while parents[a] >= 0:
            a = parents[a]
        # Second pass: point every vertex on the walked path directly at the root.
        while parents[x] >= 0:
            parents[x], x = a, parents[x]
        return a

    def _link(self, root: int, child: int) -> None:
        """Attach the root ``child`` under the distinct root ``root``."""
        self._G[root].append(child)
        self._G[child].append(root)
        self._group_numbers -= 1
        self._parents[root] += self._parents[child]
        self._parents[child] = root
        # The merged group owns both old edge sets plus the connecting edge.
        self._edges[root] += self._edges[child] + 1

    def unite(self, x: int, y: int) -> bool:
        """Add the edge ``x - y``, merging the smaller group into the larger.

        Returns:
            ``True`` if two different groups were merged, ``False`` if ``x``
            and ``y`` were already in the same group (the edge is still
            counted for ``get_edges``).
        """
        assert (
            0 <= x < self._n and 0 <= y < self._n
        ), f"IndexError: {self.__class__.__name__}.unite({x}, {y})"
        x = self.root(x)
        y = self.root(y)
        if x == y:
            self._edges[x] += 1
            return False
        # Sizes are stored negated, so the "greater" value is the smaller group.
        if self._parents[x] > self._parents[y]:
            x, y = y, x
        self._link(x, y)
        return True

    def get_edges(self, x: int) -> int:
        """Return the number of edges added inside the group containing ``x``."""
        return self._edges[self.root(x)]

    # x -> y
    def unite_right(self, x: int, y: int) -> int:
        """Add the edge ``x - y`` and make the root of ``y`` the merged root.

        Returns:
            The root of the merged group (the common root if ``x`` and ``y``
            were already connected).
        """
        assert (
            0 <= x < self._n and 0 <= y < self._n
        ), f"IndexError: {self.__class__.__name__}.unite_right(x: int, y: int), x={x}, y={y}"
        x = self.root(x)
        y = self.root(y)
        if x == y:
            self._edges[x] += 1
            return x
        self._link(y, x)
        return y

    # x <- y
    def unite_left(self, x: int, y: int) -> int:
        """Add the edge ``x - y`` and make the root of ``x`` the merged root.

        Returns:
            The root of the merged group (the common root if ``x`` and ``y``
            were already connected).
        """
        assert (
            0 <= x < self._n and 0 <= y < self._n
        ), f"IndexError: {self.__class__.__name__}.unite_left(x: int, y: int), x={x}, y={y}"
        x = self.root(x)
        y = self.root(y)
        if x == y:
            self._edges[x] += 1
            return x
        self._link(x, y)
        return x

    def size(self, x: int) -> int:
        """Return the number of vertices in the group containing ``x``."""
        assert (
            0 <= x < self._n
        ), f"IndexError: {self.__class__.__name__}.size(x: int), x={x}"
        return -self._parents[self.root(x)]

    def same(self, x: int, y: int) -> bool:
        """Return whether ``x`` and ``y`` belong to the same group."""
        assert (
            0 <= x < self._n and 0 <= y < self._n
        ), f"IndexError: {self.__class__.__name__}.same(x: int, y: int), x={x}, y={y}"
        return self.root(x) == self.root(y)

    def members(self, x: int) -> set[int]:
        """Return the set of vertices in the group containing ``x``.

        The group is collected by a traversal of the merge graph ``_G``
        starting at ``x``.
        """
        assert (
            0 <= x < self._n
        ), f"IndexError: {self.__class__.__name__}.members(x: int), x={x}"
        seen = {x}
        todo = [x]
        while todo:
            v = todo.pop()
            for nxt in self._G[v]:
                if nxt in seen:
                    continue
                seen.add(nxt)
                todo.append(nxt)
        return seen

    def all_roots(self) -> list[int]:
        """Return every root in increasing order. / O(n): scans the parent array."""
        return [i for i, x in enumerate(self._parents) if x < 0]

    def group_count(self) -> int:
        """Return the current number of groups."""
        return self._group_numbers

    def all_group_members(self) -> defaultdict:
        """Return a mapping ``root -> list of members`` covering every vertex."""
        group_members = defaultdict(list)
        for member in range(self._n):
            group_members[self.root(member)].append(member)
        return group_members

    def clear(self) -> None:
        """Reset to ``n`` singleton groups with no edges."""
        self._group_numbers = self._n
        for i in range(self._n):
            self._parents[i] = -1
            self._edges[i] = 0
            self._G[i].clear()

    def __str__(self) -> str:
        """Return a multi-line listing of every group as ``root: members``."""
        return (
            f"<{self.__class__.__name__}> [\n"
            + "\n".join(f" {k}: {v}" for k, v in self.all_group_members().items())
            + "\n]"
        )