Source code for titan_pylib.data_structures.union_find.weighted_union_find

 1from typing import Optional
 2from collections import defaultdict
 3
 4
class WeightedUnionFind:
    """Union-find (disjoint set union) that also tracks weight differences.

    Each element stores a weight relative to its parent. After ``root()``
    compresses a path, the weight of every element on that path is relative
    to its root. Roots always have weight 0, which lets ``diff`` answer
    ``weight[y] - weight[x]`` for any two elements of the same group.
    """

    def __init__(self, n: int):
        """Create ``n`` singleton groups ``0 .. n-1``, all with weight 0."""
        self._n: int = n
        self._group_numbers: int = n
        # A negative value -k marks a root whose group has size k;
        # a non-negative value is the index of the parent.
        self._parents: list[int] = [-1] * n
        # Weight of each element relative to its parent (0 for roots).
        self._weight: list[int] = [0] * n

    def root(self, x: int) -> int:
        """Return the root of ``x``'s group, compressing the path.

        After the call, every element on the path from ``x`` points directly
        at the root, and its stored weight is relative to that root.
        """
        path = [x]
        while self._parents[x] >= 0:
            x = self._parents[x]
            path.append(x)
        a = path.pop()
        # Process elements starting from the one nearest the root, so each
        # parent's weight is already root-relative when its child is updated.
        while path:
            x = path.pop()
            self._weight[x] += self._weight[self._parents[x]]
            self._parents[x] = a
        return a

    def unite(self, x: int, y: int, w: int) -> Optional[int]:
        """Unite x and y so that weight[y] = weight[x] + w. / amortized O(α(N))

        Returns the root of the merged group. If ``x`` and ``y`` are already in
        the same group, nothing changes: the root is returned when
        ``diff(x, y) == w``, otherwise ``None`` (the constraint contradicts
        existing information).
        """
        rx = self.root(x)
        ry = self.root(y)
        if rx == ry:
            return rx if self.diff(x, y) == w else None
        # After root(), weight[x] / weight[y] are relative to rx / ry.
        # Choose weight[ry] (relative to rx) so that
        # weight[y] + weight[ry] == weight[x] + w.
        w += self._weight[x] - self._weight[y]
        self._group_numbers -= 1
        # Union by size: keep the larger group's root as the new root.
        # If the roles swap, the offset flips sign.
        if self._parents[rx] > self._parents[ry]:
            rx, ry = ry, rx
            w = -w
        self._parents[rx] += self._parents[ry]
        self._parents[ry] = rx
        self._weight[ry] = w
        return rx

    def size(self, x: int) -> int:
        """Return the number of elements in ``x``'s group."""
        return -self._parents[self.root(x)]

    def same(self, x: int, y: int) -> bool:
        """Return whether ``x`` and ``y`` belong to the same group."""
        return self.root(x) == self.root(y)

    def members(self, x: int) -> list[int]:
        """Return all elements in ``x``'s group, in ascending order. O(N)."""
        x = self.root(x)
        return [i for i in range(self._n) if self.root(i) == x]

    def all_roots(self) -> list[int]:
        """Return the roots of all groups, in ascending order."""
        return [i for i, x in enumerate(self._parents) if x < 0]

    def group_count(self) -> int:
        """Return the current number of groups."""
        return self._group_numbers

    def all_group_members(self) -> defaultdict:
        """Return a mapping from each root to the list of its group's members."""
        group_members = defaultdict(list)
        for member in range(self._n):
            group_members[self.root(member)].append(member)
        return group_members

    def clear(self) -> None:
        """Reset to ``n`` singleton groups with all weights 0."""
        self._group_numbers = self._n
        for i in range(self._n):
            self._parents[i] = -1
            # Weights must be reset as well. root() and unite() rely on roots
            # having weight 0; a stale weight would corrupt later diff() results.
            self._weight[i] = 0

    def diff(self, x: int, y: int) -> Optional[int]:
        """Return weight[y] - weight[x], or None if x and y are in different groups."""
        if not self.same(x, y):
            return None
        # same() called root() on both, so both weights are relative to the
        # same root.
        return self._weight[y] - self._weight[x]

    def __str__(self) -> str:
        return (
            "<WeightedUnionFind> [\n"
            + "\n".join(f"  {k}: {v}" for k, v in self.all_group_members().items())
            + "\n]"
        )