weighted_union_find

ソースコード

from titan_pylib.data_structures.union_find.weighted_union_find import WeightedUnionFind

view on github

展開済みコード

 1# from titan_pylib.data_structures.union_find.weighted_union_find import WeightedUnionFind
 2from typing import Optional
 3from collections import defaultdict
 4
 5
 6class WeightedUnionFind:
 7
 8    def __init__(self, n: int):
 9        self._n: int = n
10        self._group_numbers: int = n
11        self._parents: list[int] = [-1] * n
12        self._weight: list[int] = [0] * n
13
14    def root(self, x: int) -> int:
15        path = [x]
16        while self._parents[x] >= 0:
17            x = self._parents[x]
18            path.append(x)
19        a = path.pop()
20        while path:
21            x = path.pop()
22            self._weight[x] += self._weight[self._parents[x]]
23            self._parents[x] = a
24        return a
25
26    def unite(self, x: int, y: int, w: int) -> Optional[int]:
27        """Untie x and y, weight[y] = weight[x] + w. / O(α(N))"""
28        rx = self.root(x)
29        ry = self.root(y)
30        if rx == ry:
31            return rx if self.diff(x, y) == w else None
32        w += self._weight[x] - self._weight[y]
33        self._group_numbers -= 1
34        if self._parents[rx] > self._parents[ry]:
35            rx, ry = ry, rx
36            w = -w
37        self._parents[rx] += self._parents[ry]
38        self._parents[ry] = rx
39        self._weight[ry] = w
40        return rx
41
42    def size(self, x: int) -> int:
43        return -self._parents[self.root(x)]
44
45    def same(self, x: int, y: int) -> bool:
46        return self.root(x) == self.root(y)
47
48    def members(self, x: int) -> list[int]:
49        x = self.root(x)
50        return [i for i in range(self._n) if self.root(i) == x]
51
52    def all_roots(self) -> list[int]:
53        return [i for i, x in enumerate(self._parents) if x < 0]
54
55    def group_count(self) -> int:
56        return self._group_numbers
57
58    def all_group_members(self) -> defaultdict:
59        group_members = defaultdict(list)
60        for member in range(self._n):
61            group_members[self.root(member)].append(member)
62        return group_members
63
64    def clear(self) -> None:
65        self._group_numbers = self._n
66        for i in range(self._n):
67            # self._G[i].clear()
68            self._parents[i] = -1
69
70    def diff(self, x: int, y: int) -> Optional[int]:
71        """weight[y] - weight[x]"""
72        if not self.same(x, y):
73            return None
74        return self._weight[y] - self._weight[x]
75
76    def __str__(self) -> str:
77        return (
78            "<WeightedUnionFind> [\n"
79            + "\n".join(f"  {k}: {v}" for k, v in self.all_group_members().items())
80            + "\n]"
81        )

仕様

class WeightedUnionFind(n: int)[source]

Bases: object

all_group_members() defaultdict[source]
all_roots() list[int][source]
clear() None[source]
diff(x: int, y: int) int | None[source]

weight[y] - weight[x], or None if x and y are not in the same group.

group_count() int[source]
members(x: int) list[int][source]
root(x: int) int[source]
same(x: int, y: int) bool[source]
size(x: int) int[source]
unite(x: int, y: int, w: int) int | None[source]

Unite x and y so that weight[y] = weight[x] + w. / O(α(N))