union_find_heavy

ソースコード

from titan_pylib.data_structures.union_find.union_find_heavy import UnionFindHeavy

view on github

展開済みコード

  1# from titan_pylib.data_structures.union_find.union_find_heavy import UnionFindHeavy
  2from collections import defaultdict
  3
  4
  5class UnionFindHeavy:
  6
  7    def __init__(self, n: int) -> None:
  8        self._n: int = n
  9        self._group_numbers: int = n
 10        self._parents: list[int] = [-1] * n  # defaultdict(lambda: -1)
 11        # self._roots = set(range(n))
 12        self._edges: list[int] = [0] * n
 13        self._G: list[list[int]] = [[] for _ in range(n)]
 14
 15    def root(self, x: int) -> int:
 16        assert 0 <= x < self._n, f"{self.__class__.__name__}.root(x) IndexError, x={x}"
 17        a = x
 18        while self._parents[a] >= 0:
 19            a = self._parents[a]
 20        # return a  # not compressing path.
 21        while self._parents[x] >= 0:
 22            y = x
 23            x = self._parents[x]
 24            self._parents[y] = a
 25        return a
 26
 27    def unite(self, x: int, y: int) -> bool:
 28        assert (
 29            0 <= x < self._n and 0 <= y < self._n
 30        ), f"IndexError: {self.__class__.__name__}.unite({x}, {y})"
 31        x = self.root(x)
 32        y = self.root(y)
 33        self._edges[x] += 1
 34        self._edges[y] += 1
 35        if x == y:
 36            return False
 37        self._G[x].append(y)
 38        self._G[y].append(x)
 39        self._group_numbers -= 1
 40        if self._parents[x] > self._parents[y]:
 41            x, y = y, x
 42        self._parents[x] += self._parents[y]
 43        self._parents[y] = x
 44        # self._roots.discard(y)
 45        return True
 46
 47    def get_edges(self, x: int) -> int:
 48        return self._edges[self.root(x)]
 49
 50    # x -> y
 51    def unite_right(self, x: int, y: int) -> int:
 52        assert (
 53            0 <= x < self._n and 0 <= y < self._n
 54        ), f"IndexError: {self.__class__.__name__}.unite_right(x: int, y: int), x={x}, y={y}"
 55        x = self.root(x)
 56        y = self.root(y)
 57        if x == y:
 58            return x
 59        self._G[x].append(y)
 60        self._G[y].append(x)
 61        self._group_numbers -= 1
 62        self._parents[y] += self._parents[x]
 63        self._parents[x] = y
 64        # self._roots.discard(y)
 65        return y
 66
 67    # x <- y
 68    def unite_left(self, x: int, y: int) -> int:
 69        assert (
 70            0 <= x < self._n and 0 <= y < self._n
 71        ), f"IndexError: {self.__class__.__name__}.unite_left(x: int, y: int), x={x}, y={y}"
 72        x = self.root(x)
 73        y = self.root(y)
 74        if x == y:
 75            return x
 76        self._G[x].append(y)
 77        self._G[y].append(x)
 78        self._group_numbers -= 1
 79        self._parents[x] += self._parents[y]
 80        self._parents[y] = x
 81        # self._roots.discard(y)
 82        return x
 83
 84    def size(self, x: int) -> int:
 85        assert (
 86            0 <= x < self._n
 87        ), f"IndexError: {self.__class__.__name__}.size(x: int), x={x}"
 88        return -self._parents[self.root(x)]
 89
 90    def same(self, x: int, y: int) -> bool:
 91        assert (
 92            0 <= x < self._n and 0 <= y < self._n
 93        ), f"IndexError: {self.__class__.__name__}.same(x: int, y: int), x={x}, y={y}"
 94        return self.root(x) == self.root(y)
 95
 96    def members(self, x: int) -> set[int]:
 97        assert (
 98            0 <= x < self._n
 99        ), f"IndexError: {self.__class__.__name__}.members(x: int), x={x}"
100        seen = set([x])
101        todo = [x]
102        while todo:
103            v = todo.pop()
104            for x in self._G[v]:
105                if x in seen:
106                    continue
107                todo.append(x)
108                seen.add(x)
109        return seen
110
111    def all_roots(self) -> list[int]:
112        """Return all roots. / O(1)"""
113        # return self._roots
114        return [i for i, x in enumerate(self._parents) if x < 0]
115
116    def group_count(self) -> int:
117        return self._group_numbers
118
119    def all_group_members(self) -> defaultdict:
120        group_members = defaultdict(list)
121        for member in range(self._n):
122            group_members[self.root(member)].append(member)
123        return group_members
124
125    def clear(self) -> None:
126        self._group_numbers = self._n
127        for i in range(self._n):
128            self._parents[i] = -1
129            self._G[i].clear()
130
131    def __str__(self) -> str:
132        return (
133            f"<{self.__class__.__name__}> [\n"
134            + "\n".join(f"  {k}: {v}" for k, v in self.all_group_members().items())
135            + "\n]"
136        )

仕様

class UnionFindHeavy(n: int)[source]

Bases: object

all_group_members() defaultdict[source]
all_roots() list[int][source]

Return all roots. / O(n)

clear() None[source]
get_edges(x: int) int[source]
group_count() int[source]
members(x: int) set[int][source]
root(x: int) int[source]
same(x: int, y: int) bool[source]
size(x: int) int[source]
unite(x: int, y: int) bool[source]
unite_left(x: int, y: int) int[source]
unite_right(x: int, y: int) int[source]