Source code for titan_pylib.data_structures.union_find.union_find

  1from collections import defaultdict
  2
  3
[docs] 4class UnionFind: 5 6 def __init__(self, n: int) -> None: 7 """``n`` 個の要素からなる ``UnionFind`` を構築します。 8 :math:`O(n)` です。 9 """ 10 self._n: int = n 11 self._group_numbers: int = n 12 self._parents: list[int] = [-1] * n 13
[docs] 14 def root(self, x: int) -> int: 15 """要素 ``x`` を含む集合の代表元を返します。 16 :math:`O(\\alpha(n))` です。 17 """ 18 a = x 19 while self._parents[a] >= 0: 20 a = self._parents[a] 21 while self._parents[x] >= 0: 22 y = x 23 x = self._parents[x] 24 self._parents[y] = a 25 return a
26
[docs] 27 def unite(self, x: int, y: int) -> bool: 28 """要素 ``x`` を含む集合と要素 ``y`` を含む集合を併合します。 29 :math:`O(\\alpha(n))` です。 30 31 Returns: 32 bool: もともと同じ集合であれば ``False``、そうでなければ ``True`` を返します。 33 """ 34 x = self.root(x) 35 y = self.root(y) 36 if x == y: 37 return False 38 self._group_numbers -= 1 39 if self._parents[x] > self._parents[y]: 40 x, y = y, x 41 self._parents[x] += self._parents[y] 42 self._parents[y] = x 43 return True
44
[docs] 45 def unite_right(self, x: int, y: int) -> int: 46 # x -> y 47 x = self.root(x) 48 y = self.root(y) 49 if x == y: 50 return x 51 self._group_numbers -= 1 52 self._parents[y] += self._parents[x] 53 self._parents[x] = y 54 return y
55
[docs] 56 def unite_left(self, x: int, y: int) -> int: 57 # x <- y 58 x = self.root(x) 59 y = self.root(y) 60 if x == y: 61 return x 62 self._group_numbers -= 1 63 self._parents[x] += self._parents[y] 64 self._parents[y] = x 65 return x
66
[docs] 67 def size(self, x: int) -> int: 68 """要素 ``x`` を含む集合の要素数を返します。 69 :math:`O(\\alpha(n))` です。 70 """ 71 return -self._parents[self.root(x)]
72
[docs] 73 def same(self, x: int, y: int) -> bool: 74 """ 75 要素 ``x`` と ``y`` が同じ集合に属するなら ``True`` を、 76 そうでないなら ``False`` を返します。 77 :math:`O(\\alpha(n))` です。 78 """ 79 return self.root(x) == self.root(y)
80
[docs] 81 def members(self, x: int) -> list[int]: 82 """要素 ``x`` を含む集合を返します。""" 83 x = self.root(x) 84 return [i for i in range(self._n) if self.root(i) == x]
85
[docs] 86 def all_roots(self) -> list[int]: 87 """全ての集合の代表元からなるリストを返します。 88 :math:`O(n)` です。 89 90 Returns: 91 list[int]: 昇順であることが保証されます。 92 """ 93 return [i for i, x in enumerate(self._parents) if x < 0]
94
[docs] 95 def group_count(self) -> int: 96 """集合の総数を返します。 97 :math:`O(1)` です。 98 """ 99 return self._group_numbers
100
[docs] 101 def all_group_members(self) -> defaultdict: 102 """ 103 key に代表元、 value に key を代表元とする集合のリストをもつ defaultdict を返します。 104 :math:`O(n\\alpha(n))` です。 105 """ 106 group_members = defaultdict(list) 107 for member in range(self._n): 108 group_members[self.root(member)].append(member) 109 return group_members
110
[docs] 111 def clear(self) -> None: 112 """集合の連結状態をなくします(初期状態に戻します)。 113 :math:`O(n)` です。 114 """ 115 self._group_numbers = self._n 116 for i in range(self._n): 117 self._parents[i] = -1
118
[docs] 119 def __str__(self) -> str: 120 """よしなにします。 121 :math:`O(n\\alpha(n))` です。 122 """ 123 return ( 124 f"<{self.__class__.__name__}> [\n" 125 + "\n".join(f" {k}: {v}" for k, v in self.all_group_members().items()) 126 + "\n]" 127 )