partial_persistent_union_find

ソースコード

from titan_pylib.data_structures.union_find.partial_persistent_union_find import PartialPersistentUnionFind

view on github

展開済みコード

  1# from titan_pylib.data_structures.union_find.partial_persistent_union_find import PartialPersistentUnionFind
  2# from titan_pylib.data_structures.array.partial_persistent_array import (
  3#     PartialPersistentArray,
  4# )
  5from typing import Iterable, TypeVar, Generic
  6
  7T = TypeVar("T")
  8
  9
 10class PartialPersistentArray(Generic[T]):
 11    """部分永続配列です。
 12    最新版の更新と、過去へのアクセスが可能です。
 13    """
 14
 15    def __init__(self, a: Iterable[T]):
 16        """``a`` から部分永続配列を構築します。
 17        初期配列のバージョンは ``0`` です。
 18        :math:`O(n)` です。
 19        """
 20        self.a: list[list[T]] = [[e] for e in a]
 21        self.t: list[list[int]] = [[0] for _ in range(len(self.a))]
 22        self.last_time: int = 0
 23
 24    def set(self, k: int, v: T, t: int) -> None:
 25        """位置 ``k`` を ``v`` に更新します。
 26        :math:`O(1)` です。
 27
 28        Args:
 29          k (int): インデックスです。
 30          v (T): 更新する値です。
 31          t (int): 新たな配列のバージョンです。
 32        """
 33        assert t >= self.last_time
 34        assert t > self.t[k][-1]
 35        self.a[k].append(v)
 36        self.t[k].append(t)
 37        self.last_time = t
 38
 39    def get(self, k: int, t: int = -1) -> T:
 40        """位置 ``k`` 、バージョン ``t`` の要素を返します。
 41        :math:`O(\\log{n})` です。
 42
 43        Args:
 44          k (int): インデックスです。
 45          t (int, optional): バージョンです。デフォルトは最新バージョンです。
 46        """
 47        if t == -1 or t >= self.t[k][-1]:
 48            return self.a[k][-1]
 49        tk = self.t[k]
 50        ok, ng = 0, len(tk)
 51        while ng - ok > 1:
 52            mid = (ok + ng) // 2
 53            if tk[mid] <= t:
 54                ok = mid
 55            else:
 56                ng = mid
 57        return self.a[k][ok]
 58
 59    def tolist(self, t: int) -> list[T]:
 60        """バージョン ``t`` の配列を返します。
 61        :math:`O(n\\log{n})` です。
 62        """
 63        return [self.get(i, t) for i in range(len(self))]
 64
 65    def show(self, t: int) -> None:
 66        print(f"Time: {t}", end="")
 67        print([self.get(i, t) for i in range(len(self))])
 68
 69    def show_all(self) -> None:
 70        """すべてのバージョンの配列を表示します。"""
 71        for i in range(self.last_time):
 72            self.show(i)
 73
 74    def __len__(self):
 75        return len(self.a)
 76
 77
 78class PartialPersistentUnionFind:
 79
 80    def __init__(self, n: int):
 81        self._n: int = n
 82        self._parents: PartialPersistentArray[int] = PartialPersistentArray([-1] * n)
 83        self._last_time: int = 0
 84
 85    def root(self, x: int, t: int = -1) -> int:
 86        assert t == -1 or t <= self._last_time
 87        while True:
 88            p = self._parents.get(x, t)
 89            if p < 0:
 90                return x
 91            x = p
 92
 93    def unite(self, x: int, y: int, t: int) -> bool:
 94        assert t == -1 or t >= self._last_time
 95        self._last_time = t
 96        x = self.root(x, t)
 97        y = self.root(y, t)
 98        if x == y:
 99            return False
100        if self._parents.get(x, t) > self._parents.get(y, t):
101            x, y = y, x
102        self._parents.set(x, self._parents.get(x, t) + self._parents.get(y, t), t)
103        self._parents.set(y, x, t)
104        return True
105
106    def size(self, x: int, t: int = -1) -> int:
107        assert t == -1 or t <= self._last_time
108        return -self._parents.get(self.root(x, t), t)
109
110    def same(self, x: int, y: int, t: int = -1) -> bool:
111        assert t == -1 or t <= self._last_time
112        return self.root(x, t) == self.root(y, t)

仕様

class PartialPersistentUnionFind(n: int)[source]

Bases: object

root(x: int, t: int = -1) int[source]
same(x: int, y: int, t: int = -1) bool[source]
size(x: int, t: int = -1) int[source]
unite(x: int, y: int, t: int) bool[source]