persistent_union_find

ソースコード

from titan_pylib.data_structures.union_find.persistent_union_find import PersistentUnionFind

view on github

展開済みコード

  1# from titan_pylib.data_structures.union_find.persistent_union_find import PersistentUnionFind
  2# from titan_pylib.data_structures.array.persistent_array import PersistentArray
  3from typing import Iterable, TypeVar, Generic, Optional
  4
  5T = TypeVar("T")
  6
  7
  8class PersistentArray(Generic[T]):
  9
 10    class _Node:
 11
 12        def __init__(self, key: T):
 13            self.key: T = key
 14            self.left: Optional[PersistentArray._Node] = None
 15            self.right: Optional[PersistentArray._Node] = None
 16
 17        def copy(self) -> "PersistentArray._Node":
 18            node = PersistentArray._Node(self.key)
 19            node.left = self.left
 20            node.right = self.right
 21            return node
 22
 23    def __init__(
 24        self, a: Iterable[T] = [], _root: Optional["PersistentArray._Node"] = None
 25    ) -> None:
 26        self.root = self._build(a) if _root is None else _root
 27
 28    def _build(self, a: Iterable[T]) -> Optional["PersistentArray._Node"]:
 29        pool = [PersistentArray._Node(e) for e in a]
 30        self.n = len(pool)
 31        if not pool:
 32            return None
 33        n = len(pool)
 34        for i in range(1, n + 1):
 35            if 2 * i - 1 < n:
 36                pool[i - 1].left = pool[2 * i - 1]
 37            if 2 * i < n:
 38                pool[i - 1].right = pool[2 * i]
 39        return pool[0]
 40
 41    def _new(self, root: Optional["PersistentArray._Node"]) -> "PersistentArray[T]":
 42        res = PersistentArray(_root=root)
 43        res.n = self.n
 44        return res
 45
 46    def set(self, k: int, v: T) -> "PersistentArray[T]":
 47        assert 0 <= k < len(self), f"IndexError: {self.__class__.__name__}.set({k})"
 48        assert self.root
 49        node = self.root
 50        new_node = node.copy()
 51        res = self._new(new_node)
 52        k += 1
 53        b = k.bit_length()
 54        for i in range(b - 2, -1, -1):
 55            if k >> i & 1:
 56                node = node.right
 57                new_node.right = node.copy()
 58                new_node = new_node.right
 59            else:
 60                node = node.left
 61                new_node.left = node.copy()
 62                new_node = new_node.left
 63        new_node.key = v
 64        return res
 65
 66    def get(self, k: int) -> T:
 67        assert 0 <= k < len(self), f"IndexError: {self.__class__.__name__}.get({k})"
 68        node = self.root
 69        k += 1
 70        b = k.bit_length()
 71        for i in range(b - 2, -1, -1):
 72            if k >> i & 1:
 73                node = node.right
 74            else:
 75                node = node.left
 76        return node.key
 77
 78    __getitem__ = get
 79
 80    def copy(self) -> "PersistentArray[T]":
 81        return self._new(None if self.root is None else self.root.copy())
 82
 83    def tolist(self) -> list[T]:
 84        node = self.root
 85        a: list[T] = []
 86        if not node:
 87            return a
 88        q = [node]
 89        for node in q:
 90            a.append(node.key)
 91            if node.left:
 92                q.append(node.left)
 93            if node.right:
 94                q.append(node.right)
 95        return a
 96
 97    def __len__(self):
 98        return self.n
 99
100    def __str__(self):
101        return str(self.tolist())
102
103    def __repr__(self):
104        return f"{self.__class__.__name__}({self})"
105from typing import Optional
106
107
108class PersistentUnionFind:
109
110    def __init__(self, n: int, _parents: Optional[PersistentArray[int]] = None) -> None:
111        """``n`` 個の要素からなる ``PersistentUnionFind`` を構築します。
112        :math:`O(n)` です。
113        """
114        self._n: int = n
115        self._parents: PersistentArray[int] = (
116            PersistentArray([-1] * n) if _parents is None else _parents
117        )
118
119    def _new(self, _parents: PersistentArray[int]) -> "PersistentUnionFind":
120        return PersistentUnionFind(self._n, _parents)
121
122    def copy(self) -> "PersistentUnionFind":
123        """コピーします。
124        :math:`O(1)` です。
125        """
126        return self._new(self._parents.copy())
127
128    def root(self, x: int) -> int:
129        """要素 ``x`` を含む集合の代表元を返します。
130        :math:`O(\\log^2{n})` です。
131        """
132        _parents = self._parents
133        while True:
134            p = _parents.get(x)
135            if p < 0:
136                return x
137            x = p
138
139    def unite(self, x: int, y: int, update: bool = True) -> "PersistentUnionFind":
140        """要素 ``x`` を含む集合と要素 ``y`` を含む集合を併合します。
141        :math:`O(\\log^2{n})` です。
142
143        Args:
144          x (int): 集合の要素です。
145          y (int): 集合の要素です。
146          update (bool, optional): 併合後を新しいインスタンスにするなら ``True`` です。
147
148        Returns:
149          PersistentUnionFind: 併合後の uf です。
150        """
151        x = self.root(x)
152        y = self.root(y)
153        res_parents = self._parents.copy() if update else self._parents
154        if x == y:
155            return self._new(res_parents)
156        px, py = res_parents.get(x), res_parents.get(y)
157        if px > py:
158            x, y = y, x
159        res_parents = res_parents.set(x, px + py)
160        res_parents = res_parents.set(y, x)
161        return self._new(res_parents)
162
163    def size(self, x: int) -> int:
164        """要素 ``x`` を含む集合の要素数を返します。
165        :math:`O(\\log^2{n})` です。
166        """
167        return -self._parents.get(self.root(x))
168
169    def same(self, x: int, y: int) -> bool:
170        """
171        要素 ``x`` と ``y`` が同じ集合に属するなら ``True`` を、
172        そうでないなら ``False`` を返します。
173        :math:`O(\\log^2{n})` です。
174        """
175        return self.root(x) == self.root(y)

仕様

class PersistentUnionFind(n: int, _parents: PersistentArray[int] | None = None)[source]

Bases: object

copy() PersistentUnionFind[source]

コピーします。 \(O(1)\) です。

root(x: int) int[source]

要素 x を含む集合の代表元を返します。 \(O(\log^2{n})\) です。

same(x: int, y: int) bool[source]

要素 xy が同じ集合に属するなら True を、 そうでないなら False を返します。 \(O(\log^2{n})\) です。

size(x: int) int[source]

要素 x を含む集合の要素数を返します。 \(O(\log^2{n})\) です。

unite(x: int, y: int, update: bool = True) PersistentUnionFind[source]

要素 x を含む集合と要素 y を含む集合を併合します。 \(O(\log^2{n})\) です。

Parameters:
  • x (int) – 集合の要素です。

  • y (int) – 集合の要素です。

  • update (bool, optional) – 併合後を新しいインスタンスにするなら True です。

Returns:

併合後の uf です。

Return type:

PersistentUnionFind