Source code for titan_pylib.data_structures.union_find.union_find_heavy
1from collections import defaultdict
2
3
[docs]
class UnionFindHeavy:
    """Union-find (disjoint set union) that also stores the merge edges as a graph.

    Components are tracked with a parent array using union by size (in
    ``unite``) and path compression (in ``root``). Every merge of two distinct
    components records an undirected edge between the two former roots in
    ``self._G``. Those edges form a spanning forest of each component, which
    lets ``members`` walk exactly one component.

    ``get_edges`` reports how many unite calls (``unite``, ``unite_right`` or
    ``unite_left``, including calls whose endpoints were already connected)
    have landed in a component.
    """

    def __init__(self, n: int) -> None:
        """Create ``n`` singleton sets ``{0}, {1}, ..., {n - 1}``."""
        self._n: int = n
        self._group_numbers: int = n
        # _parents[v] < 0  -> v is a root and -_parents[v] is its component size.
        # _parents[v] >= 0 -> index of v's parent.
        self._parents: list[int] = [-1] * n
        # Per-component edge counter; only the entry at a root index is meaningful.
        self._edges: list[int] = [0] * n
        # Adjacency lists of the merge forest (one edge per merge of two components).
        self._G: list[list[int]] = [[] for _ in range(n)]

    def root(self, x: int) -> int:
        """Return the representative of ``x``'s component, compressing the path."""
        assert 0 <= x < self._n, f"{self.__class__.__name__}.root(x) IndexError, x={x}"
        a = x
        while self._parents[a] >= 0:
            a = self._parents[a]
        # Second pass: point every node on the walked path directly at the root.
        while self._parents[x] >= 0:
            y = x
            x = self._parents[x]
            self._parents[y] = a
        return a

    def _link(self, parent: int, child: int) -> None:
        """Attach root ``child`` under root ``parent``; both must be distinct roots."""
        self._G[parent].append(child)
        self._G[child].append(parent)
        self._group_numbers -= 1
        self._parents[parent] += self._parents[child]
        self._parents[child] = parent
        # The merged component owns both edge sets plus the edge that joined them.
        self._edges[parent] += self._edges[child] + 1
        self._edges[child] = 0

    def unite(self, x: int, y: int) -> bool:
        """Add an edge between ``x`` and ``y``, merging their components by size.

        Returns ``True`` if two different components were merged, ``False`` if
        ``x`` and ``y`` were already connected (the edge is still counted).
        """
        assert (
            0 <= x < self._n and 0 <= y < self._n
        ), f"IndexError: {self.__class__.__name__}.unite({x}, {y})"
        x = self.root(x)
        y = self.root(y)
        if x == y:
            # Redundant edge inside one component: count it exactly once.
            self._edges[x] += 1
            return False
        # _parents holds negated sizes, so the larger component has the smaller value.
        if self._parents[x] > self._parents[y]:
            x, y = y, x
        self._link(x, y)
        return True

    def get_edges(self, x: int) -> int:
        """Return the number of unite calls recorded in ``x``'s component."""
        return self._edges[self.root(x)]

    # x -> y
    def unite_right(self, x: int, y: int) -> int:
        """Merge so that ``y``'s root becomes the root; return the resulting root.

        Unlike ``unite``, this ignores component sizes. If ``x`` and ``y`` are
        already connected, their common root is returned (the edge is still counted).
        """
        assert (
            0 <= x < self._n and 0 <= y < self._n
        ), f"IndexError: {self.__class__.__name__}.unite_right(x: int, y: int), x={x}, y={y}"
        x = self.root(x)
        y = self.root(y)
        if x == y:
            self._edges[x] += 1
            return x
        self._link(y, x)
        return y

    # x <- y
    def unite_left(self, x: int, y: int) -> int:
        """Merge so that ``x``'s root becomes the root; return the resulting root.

        Unlike ``unite``, this ignores component sizes. If ``x`` and ``y`` are
        already connected, their common root is returned (the edge is still counted).
        """
        assert (
            0 <= x < self._n and 0 <= y < self._n
        ), f"IndexError: {self.__class__.__name__}.unite_left(x: int, y: int), x={x}, y={y}"
        x = self.root(x)
        y = self.root(y)
        if x == y:
            self._edges[x] += 1
            return x
        self._link(x, y)
        return x

    def size(self, x: int) -> int:
        """Return the number of elements in ``x``'s component."""
        assert (
            0 <= x < self._n
        ), f"IndexError: {self.__class__.__name__}.size(x: int), x={x}"
        return -self._parents[self.root(x)]

    def same(self, x: int, y: int) -> bool:
        """Return whether ``x`` and ``y`` belong to the same component."""
        assert (
            0 <= x < self._n and 0 <= y < self._n
        ), f"IndexError: {self.__class__.__name__}.same(x: int, y: int), x={x}, y={y}"
        return self.root(x) == self.root(y)

    def members(self, x: int) -> set[int]:
        """Return the set of elements in ``x``'s component.

        Walks the merge forest in ``self._G`` with an explicit stack, so the
        cost is proportional to the component size rather than to ``n``.
        """
        assert (
            0 <= x < self._n
        ), f"IndexError: {self.__class__.__name__}.members(x: int), x={x}"
        seen = {x}
        stack = [x]
        while stack:
            v = stack.pop()
            for u in self._G[v]:
                if u not in seen:
                    seen.add(u)
                    stack.append(u)
        return seen

    def all_roots(self) -> list[int]:
        """Return all roots in increasing index order. / O(n)"""
        return [i for i, p in enumerate(self._parents) if p < 0]

    def group_count(self) -> int:
        """Return the current number of components."""
        return self._group_numbers

    def all_group_members(self) -> defaultdict:
        """Map each root to the list of its members (members in increasing order)."""
        group_members = defaultdict(list)
        for member in range(self._n):
            group_members[self.root(member)].append(member)
        return group_members

    def clear(self) -> None:
        """Reset to ``n`` singleton sets with no recorded edges."""
        self._group_numbers = self._n
        for i in range(self._n):
            self._parents[i] = -1
            self._edges[i] = 0
            self._G[i].clear()

    def __str__(self) -> str:
        return (
            f"<{self.__class__.__name__}> [\n"
            + "\n".join(f" {k}: {v}" for k, v in self.all_group_members().items())
            + "\n]"
        )