from typing import Generic, Iterable, Optional, TypeVar

from titan_pylib.data_structures.bst_base.bst_set_node_base import BSTSetNodeBase
from titan_pylib.my_class.ordered_set_interface import OrderedSetInterface
from titan_pylib.my_class.supports_less_than import SupportsLessThan

# Key type stored in the set; must support ``<`` comparisons.
T = TypeVar("T", bound=SupportsLessThan)

class TreapSet(OrderedSetInterface, Generic[T]):
    """Ordered set implemented as a treap.

    The tree is a binary search tree on the keys and, at the same time, a
    *min*-heap on per-node pseudo-random priorities (``add`` and ``discard``
    always lift the child whose priority is smaller).  The random priorities
    keep the tree balanced in expectation.

    Priorities come from a fixed-seed xorshift generator, so the sequence is
    identical on every run.  Original author's note (translated): "I wonder
    whether this could ever get hacked.  For now only set and multiset
    variants exist."
    """

    class Random:
        """Xorshift128 generator with a fixed seed.

        The state lives on the class itself, so every treap draws from one
        shared sequence.
        """

        _x, _y, _z, _w = 123456789, 362436069, 521288629, 88675123

        @classmethod
        def random(cls) -> int:
            """Advance the shared state and return the next 32-bit value."""
            t = (cls._x ^ ((cls._x << 11) & 0xFFFFFFFF)) & 0xFFFFFFFF
            cls._x, cls._y, cls._z = cls._y, cls._z, cls._w
            # All operands are < 2**32, so the mask only documents the width.
            cls._w = ((cls._w ^ (cls._w >> 19)) ^ (t ^ (t >> 8))) & 0xFFFFFFFF
            return cls._w

    class Node:
        """Tree node: a key, its heap priority and two child links."""

        def __init__(self, key: T, priority: int = -1):
            self.key: T = key
            self.left: Optional["TreapSet.Node"] = None
            self.right: Optional["TreapSet.Node"] = None
            # -1 is a sentinel for "draw a fresh priority"; generated values
            # are non-negative 32-bit integers, so they never collide with it.
            self.priority: int = (
                TreapSet.Random.random() if priority == -1 else priority
            )

        def __str__(self):
            if self.left is None and self.right is None:
                return f"key: {self.key, self.priority}\n"
            return f"key: {self.key, self.priority},\n left: {self.left},\n right: {self.right}\n"

    def __init__(self, a: Iterable[T] = ()):
        """Create a set holding the keys of ``a``.

        Args:
            a: Initial keys.  They are passed through
                ``BSTSetNodeBase.sort_unique`` before the tree is built
                (presumably sorting and removing duplicates; see that helper).
        """
        self.root: Optional["TreapSet.Node"] = None
        self._len: int = 0
        if not isinstance(a, list):
            a = list(a)
        if a:
            self._build(a)

    def _build(self, a: list[T]) -> None:
        """Replace the tree with a balanced treap over the keys in ``a``.

        Every index range ``[l, r)`` is rooted at its middle element, so
        subtree sizes differ by at most one.  Priorities are drawn from the
        shared generator, sorted ascending and handed out in *preorder*.  A
        parent therefore always receives a smaller priority than any of its
        descendants, which is the min-heap order that ``add`` and ``discard``
        maintain.
        """
        Node = TreapSet.Node
        a = BSTSetNodeBase[T, TreapSet.Node].sort_unique(a)
        n = len(a)
        self._len = n
        if n == 0:
            self.root = None
            return
        priorities = iter(sorted(TreapSet.Random.random() for _ in range(n)))

        def rec(l: int, r: int) -> TreapSet.Node:
            mid = (l + r) >> 1
            # Take the priority before recursing: preorder keeps heap order.
            node = Node(a[mid], next(priorities))
            if l != mid:
                node.left = rec(l, mid)
            if mid + 1 != r:
                node.right = rec(mid + 1, r)
            return node

        self.root = rec(0, n)

    def _rotate_L(self, node: Node) -> Node:
        """Lift ``node.left`` above ``node`` (a right rotation); return it."""
        u = node.left
        node.left = u.right
        u.right = node
        return u

    def _rotate_R(self, node: Node) -> Node:
        """Lift ``node.right`` above ``node`` (a left rotation); return it."""
        u = node.right
        node.right = u.left
        u.left = node
        return u

    def add(self, key: T) -> bool:
        """Insert ``key``.

        Returns:
            True if the key was inserted, False if it was already present.
        """
        if self.root is None:
            self.root = TreapSet.Node(key)
            self._len = 1
            return True
        node = self.root
        path = []
        # Bit i of ``di`` (counting from the least-significant bit) records
        # whether the step out of path[-1 - i] went left (1) or right (0).
        di = 0
        while node is not None:
            if key == node.key:
                return False
            path.append(node)
            if key < node.key:
                di = (di << 1) | 1
                node = node.left
            else:
                di <<= 1
                node = node.right
        if di & 1:
            path[-1].left = TreapSet.Node(key)
        else:
            path[-1].right = TreapSet.Node(key)
        # Bubble the new node up while its priority is smaller than its
        # parent's, re-linking each lifted subtree into its grandparent.
        while path:
            node = path.pop()
            new_node = None
            if di & 1:
                if node.left.priority < node.priority:
                    new_node = self._rotate_L(node)
            elif node.right.priority < node.priority:
                new_node = self._rotate_R(node)
            di >>= 1
            if new_node is None:
                # Heap order already held above this node before the insert
                # and nothing above it changed, so no further rotation can
                # be needed.
                break
            if path:
                if di & 1:
                    path[-1].left = new_node
                else:
                    path[-1].right = new_node
            else:
                self.root = new_node
        self._len += 1
        return True

    def discard(self, key: T) -> bool:
        """Remove ``key`` if present.

        The node is rotated downward (always lifting the child with the
        smaller priority, to keep min-heap order) until it has at most one
        child.  It is then spliced out.

        Returns:
            True if the key was removed, False if it was not present.
        """
        node = self.root
        pnode = None
        while node is not None:
            if key == node.key:
                break
            pnode = node
            node = node.left if key < node.key else node.right
        else:
            return False
        self._len -= 1
        while node.left is not None and node.right is not None:
            if node.left.priority < node.right.priority:
                new_node = self._rotate_L(node)
            else:
                new_node = self._rotate_R(node)
            # Hang the lifted child where ``node`` used to be.  ``node``
            # stays its child and keeps sinking.
            if pnode is None:
                self.root = new_node
            elif node.key < pnode.key:
                pnode.left = new_node
            else:
                pnode.right = new_node
            pnode = new_node
        child = node.right if node.left is None else node.left
        if pnode is None:
            self.root = child
        elif node.key < pnode.key:
            pnode.left = child
        else:
            pnode.right = child
        return True

    def remove(self, key: T) -> None:
        """Remove ``key``.

        Raises:
            KeyError: if ``key`` is not present.
        """
        if self.discard(key):
            return
        raise KeyError(key)

    # The ordered queries below delegate to the shared BSTSetNodeBase helpers,
    # passing this tree's root.

    def le(self, key: T) -> Optional[T]:
        """Delegate to ``BSTSetNodeBase.le`` (expected: largest key <= ``key``, else None)."""
        return BSTSetNodeBase[T, TreapSet.Node].le(self.root, key)

    def lt(self, key: T) -> Optional[T]:
        """Delegate to ``BSTSetNodeBase.lt`` (expected: largest key < ``key``, else None)."""
        return BSTSetNodeBase[T, TreapSet.Node].lt(self.root, key)

    def ge(self, key: T) -> Optional[T]:
        """Delegate to ``BSTSetNodeBase.ge`` (expected: smallest key >= ``key``, else None)."""
        return BSTSetNodeBase[T, TreapSet.Node].ge(self.root, key)

    def gt(self, key: T) -> Optional[T]:
        """Delegate to ``BSTSetNodeBase.gt`` (expected: smallest key > ``key``, else None)."""
        return BSTSetNodeBase[T, TreapSet.Node].gt(self.root, key)

    def get_min(self) -> Optional[T]:
        """Delegate to ``BSTSetNodeBase.get_min`` on this tree."""
        return BSTSetNodeBase[T, TreapSet.Node].get_min(self.root)

    def get_max(self) -> Optional[T]:
        """Delegate to ``BSTSetNodeBase.get_max`` on this tree."""
        return BSTSetNodeBase[T, TreapSet.Node].get_max(self.root)

    def pop_min(self) -> T:
        """Remove and return the smallest key.

        The leftmost node has no left child, so it is replaced by its right
        subtree.  No rotation is needed.
        """
        # NOTE: the assert is skipped under ``python -O``; an empty set then
        # fails with AttributeError instead.
        assert self.root, f"IndexError: pop_min() from Empty {self.__class__.__name__}."
        node = self.root
        pnode = None
        while node.left is not None:
            pnode = node
            node = node.left
        self._len -= 1
        res = node.key
        if pnode is None:
            self.root = self.root.right
        else:
            pnode.left = node.right
        return res

    def pop_max(self) -> T:
        """Remove and return the largest key.

        The rightmost node has no right child, so it is replaced by its left
        subtree.  No rotation is needed.
        """
        # NOTE: the assert is skipped under ``python -O``; an empty set then
        # fails with AttributeError instead.
        assert self.root, f"IndexError: pop_max() from Empty {self.__class__.__name__}."
        node = self.root
        pnode = None
        while node.right is not None:
            pnode = node
            node = node.right
        self._len -= 1
        res = node.key
        if pnode is None:
            self.root = self.root.left
        else:
            pnode.right = node.left
        return res

    def clear(self) -> None:
        """Remove every key."""
        self.root = None
        # Keep the size counter in sync with the now-empty tree so that
        # len() and bool() report an empty set.
        self._len = 0

    def tolist(self) -> list[T]:
        """Delegate to ``BSTSetNodeBase.tolist`` on this tree."""
        return BSTSetNodeBase[T, TreapSet.Node].tolist(self.root)

    def __iter__(self):
        # NOTE(review): the iteration cursor lives on the instance and
        # __iter__ returns self.  Nested or interleaved iterations over the
        # same set therefore share one cursor: an inner loop exhausts it and
        # ends the outer loop too.  Kept as-is because __next__ is part of
        # the class's public surface.
        self._it = self.get_min()
        return self

    def __next__(self):
        # Each step runs a fresh ``gt`` query from the root, so keys added or
        # removed during iteration are observed.
        if self._it is None:
            raise StopIteration
        res = self._it
        self._it = self.gt(self._it)
        return res

    def __contains__(self, key: T):
        return BSTSetNodeBase[T, TreapSet.Node].contains(self.root, key)

    def __len__(self):
        return self._len

    def __bool__(self):
        return self._len > 0

    def __str__(self):
        return "{" + ", ".join(map(str, self.tolist())) + "}"

    def __repr__(self):
        return f"{self.__class__.__name__}({self.tolist()})"