Toggle Light / Dark / Auto color theme
Toggle table of contents sidebar
Source code for titan_pylib.data_structures.union_find.weighted_union_find
1 from typing import Optional
2 from collections import defaultdict
3
4
[docs]
class WeightedUnionFind:
    """Union-find (disjoint set) that also tracks a weight offset per element.

    Each element ``i`` stores ``_weight[i]``, its offset relative to its
    parent; after ``root(i)`` compresses the path, the offset is relative to
    the root of its group.  ``_parents[i]`` is the parent index for non-root
    elements and ``-(group size)`` for roots.
    """

    def __init__(self, n: int):
        """Create ``n`` singleton groups ``0 .. n-1`` with all weights 0."""
        self._n: int = n
        self._group_numbers: int = n
        self._parents: list[int] = [-1] * n
        self._weight: list[int] = [0] * n

    def root(self, x: int) -> int:
        """Return the root of ``x``'s group, compressing the path on the way.

        After the call every element on the walked path points directly at
        the root and its ``_weight`` is the offset relative to that root.
        """
        parents = self._parents
        weight = self._weight
        path = []
        while parents[x] >= 0:
            path.append(x)
            x = parents[x]
        # Re-hang from the node nearest the root outwards, so that each
        # node's (still old) parent already holds a root-relative weight.
        for node in reversed(path):
            weight[node] += weight[parents[node]]
            parents[node] = x
        return x

    def unite(self, x: int, y: int, w: int) -> Optional[int]:
        """Unite x and y so that weight[y] = weight[x] + w. / O(α(N))

        Returns:
            The root of the merged group.  If ``x`` and ``y`` were already in
            the same group, the root is returned when the existing difference
            equals ``w`` and ``None`` when it contradicts ``w``.  A root index
            of 0 is falsy, so callers should compare the result with
            ``is None``.
        """
        rx = self.root(x)
        ry = self.root(y)
        if rx == ry:
            # Both weights are root-relative after the root() calls above.
            return rx if self._weight[y] - self._weight[x] == w else None
        # Offset that the root of y needs relative to the root of x.
        w += self._weight[x] - self._weight[y]
        self._group_numbers -= 1
        # Union by size: sizes are stored negated, so the greater value is
        # the smaller group.  Keep the larger group's root as ``rx``.
        if self._parents[rx] > self._parents[ry]:
            rx, ry = ry, rx
            w = -w
        self._parents[rx] += self._parents[ry]
        self._parents[ry] = rx
        self._weight[ry] = w
        return rx

    def size(self, x: int) -> int:
        """Return the number of elements in ``x``'s group."""
        return -self._parents[self.root(x)]

    def same(self, x: int, y: int) -> bool:
        """Return whether ``x`` and ``y`` belong to the same group."""
        return self.root(x) == self.root(y)

    def members(self, x: int) -> list[int]:
        """Return, in ascending order, every element in ``x``'s group."""
        x = self.root(x)
        return [i for i in range(self._n) if self.root(i) == x]

    def all_roots(self) -> list[int]:
        """Return the root of every group, in ascending order."""
        return [i for i, x in enumerate(self._parents) if x < 0]

    def group_count(self) -> int:
        """Return the current number of groups."""
        return self._group_numbers

    def all_group_members(self) -> defaultdict:
        """Return a mapping ``root -> list of members`` covering all elements."""
        group_members = defaultdict(list)
        for member in range(self._n):
            group_members[self.root(member)].append(member)
        return group_members

    def clear(self) -> None:
        """Reset to ``n`` singleton groups with all weights 0."""
        self._group_numbers = self._n
        for i in range(self._n):
            self._parents[i] = -1
            # Weights must be reset too: a stale offset on a fresh root would
            # corrupt every difference computed after the next unite().
            self._weight[i] = 0

    def diff(self, x: int, y: int) -> Optional[int]:
        """weight[y] - weight[x]

        Returns ``None`` when ``x`` and ``y`` are not in the same group.
        """
        if not self.same(x, y):
            return None
        return self._weight[y] - self._weight[x]

    def __str__(self) -> str:
        """Return a multi-line listing of every group as ``root: members``."""
        return (
            "<WeightedUnionFind> [\n"
            + "\n".join(f"  {k}: {v}" for k, v in self.all_group_members().items())
            + "\n]"
        )