Source code for titan_pylib.data_structures.union_find.undoable_union_find

  1from collections import defaultdict
  2
  3
[docs] 4class UndoableUnionFind: 5 6 def __init__(self, n: int) -> None: 7 """``n`` 個の要素からなる ``UndoableUnionFind`` を構築します。 8 :math:`O(n)` です。 9 """ 10 self._n: int = n 11 self._parents: list[int] = [-1] * n 12 self._group_count: int = n 13 self._history: list[tuple[int, int]] = [] 14
[docs] 15 def root(self, x: int) -> int: 16 """要素 ``x`` を含む集合の代表元を返します。 17 :math:`O(\\log{n})` です。 18 """ 19 while self._parents[x] >= 0: 20 x = self._parents[x] 21 return x
22
[docs] 23 def unite(self, x: int, y: int) -> bool: 24 """要素 ``x`` を含む集合と要素 ``y`` を含む集合を併合します。 25 :math:`O(\\log{n})` です。 26 27 Returns: 28 bool: もともと同じ集合であれば ``False``、そうでなければ ``True`` を返します。 29 """ 30 x = self.root(x) 31 y = self.root(y) 32 if x == y: 33 self._history.append((-1, -1)) 34 return False 35 if self._parents[x] > self._parents[y]: 36 x, y = y, x 37 self._group_count -= 1 38 self._history.append((x, self._parents[x])) 39 self._history.append((y, self._parents[y])) 40 self._parents[x] += self._parents[y] 41 self._parents[y] = x 42 return True
43
[docs] 44 def undo(self) -> None: 45 """直前の ``unite`` クエリを戻します。 46 :math:`O(\\log{n})` です。 47 """ 48 assert ( 49 self._history 50 ), f"Error: {self.__class__.__name__}.undo() with non history." 51 y, py = self._history.pop() 52 if y == -1: 53 return 54 self._group_count += 1 55 x, px = self._history.pop() 56 self._parents[y] = py 57 self._parents[x] = px
58
[docs] 59 def size(self, x: int) -> int: 60 """要素 ``x`` を含む集合の要素数を返します。 61 :math:`O(\\log{n})` です。 62 """ 63 return -self._parents[self.root(x)]
64
[docs] 65 def same(self, x: int, y: int) -> bool: 66 """ 67 要素 ``x`` と ``y`` が同じ集合に属するなら ``True`` を、 68 そうでないなら ``False`` を返します。 69 :math:`O(\\log{n})` です。 70 """ 71 return self.root(x) == self.root(y)
72
[docs] 73 def all_roots(self) -> list[int]: 74 """全ての集合の代表元からなるリストを返します。 75 :math:`O(n)` です。 76 77 Returns: 78 list[int]: 昇順であることが保証されます。 79 """ 80 return [i for i, x in enumerate(self._parents) if x < 0]
81
[docs] 82 def all_group_members(self) -> defaultdict: 83 """ 84 `key` に代表元、 `value` に `key` を代表元とする集合のリストをもつ `defaultdict` を返します。 85 :math:`O(n\\log{n})` です。 86 """ 87 group_members = defaultdict(list) 88 for member in range(self._n): 89 group_members[self.root(member)].append(member) 90 return group_members
91
[docs] 92 def group_count(self) -> int: 93 """集合の総数を返します。 94 :math:`O(1)` です。 95 """ 96 return self._group_count
97
[docs] 98 def clear(self) -> None: 99 """集合の連結状態をなくします(初期状態に戻します)。 100 :math:`O(n)` です。 101 """ 102 for i in range(self._n): 103 self._parents[i] = -1
104
[docs] 105 def __str__(self) -> str: 106 """よしなにします。 107 108 :math:`O(n\\log{n})` です。 109 """ 110 return ( 111 f"<{self.__class__.__name__}> [\n" 112 + "\n".join(f" {k}: {v}" for k, v in self.all_group_members().items()) 113 + "\n]" 114 )