undoable_union_find

ソースコード

from titan_pylib.data_structures.union_find.undoable_union_find import UndoableUnionFind

view on github

展開済みコード

  1# from titan_pylib.data_structures.union_find.undoable_union_find import UndoableUnionFind
  2from collections import defaultdict
  3
  4
  5class UndoableUnionFind:
  6
  7    def __init__(self, n: int) -> None:
  8        """``n`` 個の要素からなる ``UndoableUnionFind`` を構築します。
  9        :math:`O(n)` です。
 10        """
 11        self._n: int = n
 12        self._parents: list[int] = [-1] * n
 13        self._group_count: int = n
 14        self._history: list[tuple[int, int]] = []
 15
 16    def root(self, x: int) -> int:
 17        """要素 ``x`` を含む集合の代表元を返します。
 18        :math:`O(\\log{n})` です。
 19        """
 20        while self._parents[x] >= 0:
 21            x = self._parents[x]
 22        return x
 23
 24    def unite(self, x: int, y: int) -> bool:
 25        """要素 ``x`` を含む集合と要素 ``y`` を含む集合を併合します。
 26        :math:`O(\\log{n})` です。
 27
 28        Returns:
 29          bool: もともと同じ集合であれば ``False``、そうでなければ ``True`` を返します。
 30        """
 31        x = self.root(x)
 32        y = self.root(y)
 33        if x == y:
 34            self._history.append((-1, -1))
 35            return False
 36        if self._parents[x] > self._parents[y]:
 37            x, y = y, x
 38        self._group_count -= 1
 39        self._history.append((x, self._parents[x]))
 40        self._history.append((y, self._parents[y]))
 41        self._parents[x] += self._parents[y]
 42        self._parents[y] = x
 43        return True
 44
 45    def undo(self) -> None:
 46        """直前の ``unite`` クエリを戻します。
 47        :math:`O(\\log{n})` です。
 48        """
 49        assert (
 50            self._history
 51        ), f"Error: {self.__class__.__name__}.undo() with non history."
 52        y, py = self._history.pop()
 53        if y == -1:
 54            return
 55        self._group_count += 1
 56        x, px = self._history.pop()
 57        self._parents[y] = py
 58        self._parents[x] = px
 59
 60    def size(self, x: int) -> int:
 61        """要素 ``x`` を含む集合の要素数を返します。
 62        :math:`O(\\log{n})` です。
 63        """
 64        return -self._parents[self.root(x)]
 65
 66    def same(self, x: int, y: int) -> bool:
 67        """
 68        要素 ``x`` と ``y`` が同じ集合に属するなら ``True`` を、
 69        そうでないなら ``False`` を返します。
 70        :math:`O(\\log{n})` です。
 71        """
 72        return self.root(x) == self.root(y)
 73
 74    def all_roots(self) -> list[int]:
 75        """全ての集合の代表元からなるリストを返します。
 76        :math:`O(n)` です。
 77
 78        Returns:
 79            list[int]: 昇順であることが保証されます。
 80        """
 81        return [i for i, x in enumerate(self._parents) if x < 0]
 82
 83    def all_group_members(self) -> defaultdict:
 84        """
 85        `key` に代表元、 `value` に `key` を代表元とする集合のリストをもつ `defaultdict` を返します。
 86        :math:`O(n\\log{n})` です。
 87        """
 88        group_members = defaultdict(list)
 89        for member in range(self._n):
 90            group_members[self.root(member)].append(member)
 91        return group_members
 92
 93    def group_count(self) -> int:
 94        """集合の総数を返します。
 95        :math:`O(1)` です。
 96        """
 97        return self._group_count
 98
 99    def clear(self) -> None:
100        """集合の連結状態をなくします(初期状態に戻します)。
101        :math:`O(n)` です。
102        """
103        for i in range(self._n):
104            self._parents[i] = -1
105
106    def __str__(self) -> str:
107        """よしなにします。
108
109        :math:`O(n\\log{n})` です。
110        """
111        return (
112            f"<{self.__class__.__name__}> [\n"
113            + "\n".join(f"  {k}: {v}" for k, v in self.all_group_members().items())
114            + "\n]"
115        )

仕様

class UndoableUnionFind(n: int)[source]

Bases: object

__str__() str[source]

よしなにします。

\(O(n\log{n})\) です。

all_group_members() defaultdict[source]

key に代表元、 valuekey を代表元とする集合のリストをもつ defaultdict を返します。 \(O(n\log{n})\) です。

all_roots() list[int][source]

全ての集合の代表元からなるリストを返します。 \(O(n)\) です。

Returns:

昇順であることが保証されます。

Return type:

list[int]

clear() None[source]

集合の連結状態をなくします(初期状態に戻します)。 \(O(n)\) です。

group_count() int[source]

集合の総数を返します。 \(O(1)\) です。

root(x: int) int[source]

要素 x を含む集合の代表元を返します。 \(O(\log{n})\) です。

same(x: int, y: int) bool[source]

要素 xy が同じ集合に属するなら True を、 そうでないなら False を返します。 \(O(\log{n})\) です。

size(x: int) int[source]

要素 x を含む集合の要素数を返します。 \(O(\log{n})\) です。

undo() None[source]

直前の unite クエリを戻します。 \(O(\log{n})\) です。

unite(x: int, y: int) bool[source]

要素 x を含む集合と要素 y を含む集合を併合します。 \(O(\log{n})\) です。

Returns:

もともと同じ集合であれば False、そうでなければ True を返します。

Return type:

bool