weighted_union_find
ソースコード
from titan_pylib.data_structures.union_find.weighted_union_find import WeightedUnionFind
展開済みコード
1# from titan_pylib.data_structures.union_find.weighted_union_find import WeightedUnionFind
2from typing import Optional
3from collections import defaultdict
4
5
class WeightedUnionFind:
    """Disjoint-set forest that also tracks a weight difference between members.

    Every element ``x`` carries ``_weight[x]``: its weight relative to its
    parent, which after path compression is its weight relative to the root
    of its group. A root always has relative weight 0. For two elements in
    the same group, ``diff(x, y)`` is ``weight[y] - weight[x]``.
    """

    def __init__(self, n: int):
        """Create ``n`` singleton groups containing elements ``0 .. n-1``."""
        self._n: int = n
        self._group_numbers: int = n
        # A negative entry marks a root and stores minus the group size;
        # a non-negative entry is the index of the parent.
        self._parents: list[int] = [-1] * n
        # Weight of each element relative to its parent (0 for roots).
        self._weight: list[int] = [0] * n

    def root(self, x: int) -> int:
        """Return the root of ``x``'s group, compressing the path on the way."""
        path = [x]
        while self._parents[x] >= 0:
            x = self._parents[x]
            path.append(x)
        a = path.pop()
        # Walk back from the node nearest the root towards the start node.
        # When a node is processed its parent already holds a weight relative
        # to the root, so one addition makes the node root-relative as well.
        while path:
            x = path.pop()
            self._weight[x] += self._weight[self._parents[x]]
            self._parents[x] = a
        return a

    def unite(self, x: int, y: int, w: int) -> Optional[int]:
        """Unite x and y so that weight[y] = weight[x] + w. / O(α(N))

        Returns the root of the merged group. If ``x`` and ``y`` are already
        in the same group, returns that root when the existing difference
        equals ``w`` and ``None`` when it contradicts ``w``. The returned
        root may be ``0``, so compare against ``None`` explicitly.
        """
        rx = self.root(x)
        ry = self.root(y)
        if rx == ry:
            # Both weights are root-relative after the two root() calls.
            return rx if self._weight[y] - self._weight[x] == w else None
        # Weight that root ``ry`` must get once it hangs below ``rx``.
        w += self._weight[x] - self._weight[y]
        self._group_numbers -= 1
        # Union by size: keep the larger group (more negative entry) as root.
        if self._parents[rx] > self._parents[ry]:
            rx, ry = ry, rx
            w = -w
        self._parents[rx] += self._parents[ry]
        self._parents[ry] = rx
        self._weight[ry] = w
        return rx

    def size(self, x: int) -> int:
        """Return the number of elements in ``x``'s group."""
        return -self._parents[self.root(x)]

    def same(self, x: int, y: int) -> bool:
        """Return whether ``x`` and ``y`` belong to the same group."""
        return self.root(x) == self.root(y)

    def members(self, x: int) -> list[int]:
        """Return all elements of ``x``'s group in increasing order."""
        x = self.root(x)
        return [i for i in range(self._n) if self.root(i) == x]

    def all_roots(self) -> list[int]:
        """Return the root of every group in increasing order."""
        return [i for i, x in enumerate(self._parents) if x < 0]

    def group_count(self) -> int:
        """Return the current number of groups."""
        return self._group_numbers

    def all_group_members(self) -> defaultdict:
        """Return a mapping ``root -> list of members`` covering every element."""
        group_members = defaultdict(list)
        for member in range(self._n):
            group_members[self.root(member)].append(member)
        return group_members

    def clear(self) -> None:
        """Reset to ``n`` singleton groups with all relative weights zero."""
        self._group_numbers = self._n
        # Slice assignment keeps the existing list objects alive.
        self._parents[:] = [-1] * self._n
        # Weights must be reset too: every element is a root again, and the
        # other methods rely on a root having relative weight 0.
        self._weight[:] = [0] * self._n

    def diff(self, x: int, y: int) -> Optional[int]:
        """weight[y] - weight[x], or ``None`` if they are in different groups."""
        if not self.same(x, y):
            return None
        return self._weight[y] - self._weight[x]

    def __str__(self) -> str:
        """Return a multi-line listing of every group as ``root: members``."""
        return (
            "<WeightedUnionFind> [\n"
            + "\n".join(f" {k}: {v}" for k, v in self.all_group_members().items())
            + "\n]"
        )