Source code for titan_pylib.data_structures.union_find.union_find_heavy

  1from collections import defaultdict
  2
  3
class UnionFindHeavy:
    """Union-find over ``range(n)`` that also tracks per-component edge counts.

    Besides the usual parent array, every merge records an adjacency edge
    between the two merged roots in ``self._G``. Each merge links one vertex
    of each component, so every component stays connected in ``_G``. That
    lets :meth:`members` enumerate a component by traversal instead of
    scanning all ``n`` elements.
    """

    def __init__(self, n: int) -> None:
        self._n: int = n
        self._group_numbers: int = n
        # _parents[v] < 0  -> v is a root and -_parents[v] is its component size.
        # _parents[v] >= 0 -> the parent of v.
        self._parents: list[int] = [-1] * n
        # Number of unite*() calls (edges, including redundant ones) absorbed
        # into each component; only meaningful at roots.
        self._edges: list[int] = [0] * n
        # Adjacency lists linking roots at the moment they were merged.
        self._G: list[list[int]] = [[] for _ in range(n)]

    def root(self, x: int) -> int:
        """Return the representative of ``x``'s component, compressing the path."""
        assert 0 <= x < self._n, f"{self.__class__.__name__}.root(x) IndexError, x={x}"
        a = x
        while self._parents[a] >= 0:
            a = self._parents[a]
        # Second pass: point every node on the walked path directly at the root.
        while self._parents[x] >= 0:
            y = x
            x = self._parents[x]
            self._parents[y] = a
        return a

    def _link(self, parent: int, child: int) -> None:
        """Attach root ``child`` under root ``parent``; both must be distinct roots."""
        self._G[parent].append(child)
        self._G[child].append(parent)
        self._group_numbers -= 1
        self._parents[parent] += self._parents[child]  # sizes are stored negated
        self._parents[child] = parent
        # The merged component owns both old edge sets plus the new edge.
        self._edges[parent] += self._edges[child] + 1

    def unite(self, x: int, y: int) -> bool:
        """Add edge (x, y); merge by size. Return True iff two components merged.

        A redundant edge (x and y already connected) is still counted by
        :meth:`get_edges`, but returns False.
        """
        assert (
            0 <= x < self._n and 0 <= y < self._n
        ), f"IndexError: {self.__class__.__name__}.unite({x}, {y})"
        x = self.root(x)
        y = self.root(y)
        if x == y:
            self._edges[x] += 1
            return False
        # Make x the larger component (more negative _parents entry).
        if self._parents[x] > self._parents[y]:
            x, y = y, x
        self._link(x, y)
        return True

    def get_edges(self, x: int) -> int:
        """Return how many edges have been added to ``x``'s component."""
        return self._edges[self.root(x)]

    def unite_right(self, x: int, y: int) -> int:
        """Add edge (x, y), always hanging x's root under y's root (x -> y).

        Returns the root of the resulting component. If x and y are already
        connected, that root is returned unchanged; the edge is still counted.
        """
        assert (
            0 <= x < self._n and 0 <= y < self._n
        ), f"IndexError: {self.__class__.__name__}.unite_right(x: int, y: int), x={x}, y={y}"
        x = self.root(x)
        y = self.root(y)
        if x == y:
            self._edges[x] += 1
            return x
        self._link(y, x)
        return y

    def unite_left(self, x: int, y: int) -> int:
        """Add edge (x, y), always hanging y's root under x's root (x <- y).

        Returns the root of the resulting component. If x and y are already
        connected, that root is returned unchanged; the edge is still counted.
        """
        assert (
            0 <= x < self._n and 0 <= y < self._n
        ), f"IndexError: {self.__class__.__name__}.unite_left(x: int, y: int), x={x}, y={y}"
        x = self.root(x)
        y = self.root(y)
        if x == y:
            self._edges[x] += 1
            return x
        self._link(x, y)
        return x

    def size(self, x: int) -> int:
        """Return the number of elements in ``x``'s component."""
        assert (
            0 <= x < self._n
        ), f"IndexError: {self.__class__.__name__}.size(x: int), x={x}"
        return -self._parents[self.root(x)]

    def same(self, x: int, y: int) -> bool:
        """Return True iff ``x`` and ``y`` are in the same component."""
        assert (
            0 <= x < self._n and 0 <= y < self._n
        ), f"IndexError: {self.__class__.__name__}.same(x: int, y: int), x={x}, y={y}"
        return self.root(x) == self.root(y)

    def members(self, x: int) -> set[int]:
        """Return the set of elements in ``x``'s component.

        Walks the merge graph ``_G``, so the cost is proportional to the
        component, not to ``n``.
        """
        assert (
            0 <= x < self._n
        ), f"IndexError: {self.__class__.__name__}.members(x: int), x={x}"
        seen = {x}
        todo = [x]
        while todo:
            v = todo.pop()
            for nxt in self._G[v]:
                if nxt not in seen:
                    seen.add(nxt)
                    todo.append(nxt)
        return seen

    def all_roots(self) -> list[int]:
        """Return all roots. / O(n) — scans the whole parent array."""
        return [i for i, p in enumerate(self._parents) if p < 0]

    def group_count(self) -> int:
        """Return the current number of components."""
        return self._group_numbers

    def all_group_members(self) -> defaultdict:
        """Map each root to the list of its members (scans all ``n`` elements)."""
        group_members = defaultdict(list)
        for member in range(self._n):
            group_members[self.root(member)].append(member)
        return group_members

    def clear(self) -> None:
        """Reset to ``n`` singleton components with no edges."""
        self._group_numbers = self._n
        for i in range(self._n):
            self._parents[i] = -1
            self._edges[i] = 0
            self._G[i].clear()

    def __str__(self) -> str:
        return (
            f"<{self.__class__.__name__}> [\n"
            + "\n".join(f"  {k}: {v}" for k, v in self.all_group_members().items())
            + "\n]"
        )