from typing import Generic, Iterable, Iterator, Optional, TypeVar

from titan_pylib.data_structures.wbt._wbt_multiset_node import _WBTMultisetNode
3
# Type variable for the element (key) type stored in WBTMultiset.
T = TypeVar("T")
5
6
class WBTMultiset(Generic[T]):
    """Ordered multiset backed by a weight-balanced binary search tree.

    Each distinct key is stored in a single ``_WBTMultisetNode`` together with
    its multiplicity ``_count``. Every node also caches ``_count_size``, the
    total multiplicity of its subtree (verified by ``check``). The
    order-statistic operations (``find_order``, ``index``, ``index_right``,
    ``__getitem__``, ``pop``) descend on that value. The nodes holding the
    smallest and largest keys are cached in ``_min`` / ``_max``, so
    ``get_min`` / ``get_max`` need no tree walk.

    Keys must be mutually comparable with ``<`` and ``==``.
    """

    __slots__ = "_root", "_min", "_max"

    def __init__(self, a: Iterable[T] = ()) -> None:
        """Create a multiset holding every element of ``a`` (duplicates kept).

        ``a`` is only iterated, never mutated. The default is an immutable
        empty tuple to avoid a shared mutable default argument.
        """
        self._root: Optional[_WBTMultisetNode[T]] = None
        self._min: Optional[_WBTMultisetNode[T]] = None
        self._max: Optional[_WBTMultisetNode[T]] = None
        self.__build(a)

    def __build(self, a: Iterable[T]) -> None:
        """Bulk-load the tree from ``a``. Called once, from ``__init__``.

        The elements are sorted if needed and run-length encoded into
        (distinct key, multiplicity) pairs. A tree is then built by rooting
        each index range at its midpoint.
        """

        def build(
            l: int, r: int, pnode: Optional[_WBTMultisetNode[T]] = None
        ) -> Optional[_WBTMultisetNode[T]]:
            # Subtree over keys[l:r]; the midpoint becomes its root.
            if l == r:
                return None
            mid = (l + r) // 2
            node = _WBTMultisetNode(keys[mid], vals[mid])
            node._left = build(l, mid, node)
            node._right = build(mid + 1, r, node)
            node._par = pnode
            # Refresh the node's cached aggregates (presumably _size and
            # _count_size) now that both children are attached.
            node._update()
            return node

        a = list(a)
        if not a:
            return
        if not all(a[i] <= a[i + 1] for i in range(len(a) - 1)):
            a.sort()
        # Run-length encode: keys[i] is a distinct element, vals[i] its multiplicity.
        keys: list[T] = []
        vals: list[int] = []
        for elm in a:
            if keys and elm == keys[-1]:
                vals[-1] += 1
            else:
                keys.append(elm)
                vals.append(1)
        self._root = build(0, len(keys))
        self._max = self._root._max()
        self._min = self._root._min()

    def add(self, key: T, count: int = 1) -> None:
        """Insert ``count`` occurrences of ``key``.

        If ``key`` is already present, only its multiplicity grows. Otherwise a
        new leaf is attached and the tree is rebalanced starting from the
        leaf's parent.
        """
        if not self._root:
            self._root = _WBTMultisetNode(key, count)
            self._max = self._root
            self._min = self._root
            return
        pnode = None
        node = self._root
        while node:
            # Every node on the search path gains ``count`` elements in its subtree.
            node._count_size += count
            if key == node._key:
                node._count += count
                return
            pnode = node
            node = node._left if key < node._key else node._right
        new_node = _WBTMultisetNode(key, count)
        new_node._par = pnode
        if key < pnode._key:
            pnode._left = new_node
            if key < self._min._key:
                self._min = new_node
        else:
            pnode._right = new_node
            if key > self._max._key:
                self._max = new_node
        # Rebalance along the path upward; the result becomes the new root.
        self._root = pnode._rebalance()

    def find_key(self, key: T) -> Optional[_WBTMultisetNode[T]]:
        """Return the node storing ``key``, or ``None`` if ``key`` is absent."""
        node = self._root
        while node:
            if key == node._key:
                return node
            node = node._left if key < node._key else node._right
        return None

    def find_order(self, k: int) -> _WBTMultisetNode[T]:
        """Return the node holding the ``k``-th smallest element (0-indexed).

        ``k`` counts individual occurrences, not distinct keys. Negative ``k``
        counts from the end, as with list indexing.

        Raises:
            AssertionError: if ``k`` is out of range.
        """
        n = len(self)
        assert -n <= k < n, f"IndexError: {self.__class__.__name__}.find_order({k}), len={n}"
        if k < 0:
            k += n
        node = self._root
        while True:
            # t = number of occurrences in node's left subtree plus node itself.
            t = node._left._count_size + node._count if node._left else node._count
            if t - node._count <= k < t:
                return node
            if t > k:
                node = node._left
            else:
                node = node._right
                k -= t

    def count(self, key: T) -> int:
        """Return the multiplicity of ``key`` (0 if absent)."""
        node = self.find_key(key)
        return node._count if node is not None else 0

    def remove_iter(self, node: _WBTMultisetNode[T]) -> None:
        """Unlink ``node`` (and all its occurrences) from the tree.

        A node with two children is not physically detached. Its in-order
        predecessor ``mnode`` (max of the left subtree) is unlinked instead and
        then moved into ``node``'s position, so existing references to
        ``mnode`` stay valid.
        """
        if node is self._min:
            # Presumably the in-order successor (node helper, defined elsewhere).
            self._min = self._min._next()
        if node is self._max:
            # Presumably the in-order predecessor.
            self._max = self._max._prev()
        delnode = node
        pnode, mnode = node._par, None
        if node._left and node._right:
            pnode, mnode = node, node._left
            while mnode._right:
                pnode, mnode = mnode, mnode._right
            # delnode keeps standing in the tree for mnode while rebalancing,
            # so give it mnode's multiplicity for the aggregate recomputation.
            node._count = mnode._count
            node = mnode
        # ``node`` now has at most one child; splice it out.
        cnode = node._right if not node._left else node._left
        if cnode:
            cnode._par = pnode
        if pnode:
            if pnode._left is node:
                pnode._left = cnode
            else:
                pnode._right = cnode
            self._root = pnode._rebalance()
        else:
            self._root = cnode
        if mnode:
            if self._root is delnode:
                self._root = mnode
            # Presumably transfers delnode's links/aggregates onto mnode.
            mnode._copy_from(delnode)

    def _remove_count(self, node: _WBTMultisetNode[T], count: int) -> None:
        """Drop ``count`` occurrences stored in ``node``; unlink it if none remain.

        When the node survives, ``_count_size`` is decremented on the node and
        on every ancestor, so subtree totals (and ``len``) stay consistent.
        """
        if node._count <= count:
            self.remove_iter(node)
            return
        node._count -= count
        while node:
            node._count_size -= count
            node = node._par

    def remove(self, key: T, count: int = 1) -> None:
        """Remove ``count`` occurrences of ``key``.

        All occurrences are removed if fewer than ``count`` remain.

        Raises:
            AssertionError: if ``key`` is not present.
        """
        node = self.find_key(key)
        assert node, f"KeyError: {key} is not found."
        self._remove_count(node, count)

    def discard(self, key: T, count: int = 1) -> bool:
        """Like ``remove`` but return ``False`` instead of failing when absent."""
        node = self.find_key(key)
        if node is None:
            return False
        self._remove_count(node, count)
        return True

    def pop(self, k: int = -1) -> T:
        """Remove and return one occurrence of the ``k``-th smallest element.

        ``k`` is 0-indexed over all occurrences. Negative values count from the
        end, so the default ``-1`` pops the largest element.
        """
        node = self.find_order(k)
        key = node._key
        self._remove_count(node, 1)
        return key

    def le_iter(self, key: T) -> Optional[_WBTMultisetNode[T]]:
        """Return the node with the largest key ``<= key``, or ``None``."""
        res = None
        node = self._root
        while node:
            if key == node._key:
                return node
            if key < node._key:
                node = node._left
            else:
                res = node
                node = node._right
        return res

    def lt_iter(self, key: T) -> Optional[_WBTMultisetNode[T]]:
        """Return the node with the largest key ``< key``, or ``None``."""
        res = None
        node = self._root
        while node:
            if key <= node._key:
                node = node._left
            else:
                res = node
                node = node._right
        return res

    def ge_iter(self, key: T) -> Optional[_WBTMultisetNode[T]]:
        """Return the node with the smallest key ``>= key``, or ``None``."""
        res = None
        node = self._root
        while node:
            if key == node._key:
                return node
            if key < node._key:
                res = node
                node = node._left
            else:
                node = node._right
        return res

    def gt_iter(self, key: T) -> Optional[_WBTMultisetNode[T]]:
        """Return the node with the smallest key ``> key``, or ``None``."""
        res = None
        node = self._root
        while node:
            if key < node._key:
                res = node
                node = node._left
            else:
                node = node._right
        return res

    def le(self, key: T) -> Optional[T]:
        """Return the largest element ``<= key``, or ``None``."""
        res = None
        node = self._root
        while node:
            if key == node._key:
                return key
            if key < node._key:
                node = node._left
            else:
                res = node._key
                node = node._right
        return res

    def lt(self, key: T) -> Optional[T]:
        """Return the largest element ``< key``, or ``None``."""
        res = None
        node = self._root
        while node:
            if key <= node._key:
                node = node._left
            else:
                res = node._key
                node = node._right
        return res

    def ge(self, key: T) -> Optional[T]:
        """Return the smallest element ``>= key``, or ``None``."""
        res = None
        node = self._root
        while node:
            if key == node._key:
                return key
            if key < node._key:
                res = node._key
                node = node._left
            else:
                node = node._right
        return res

    def gt(self, key: T) -> Optional[T]:
        """Return the smallest element ``> key``, or ``None``."""
        res = None
        node = self._root
        while node:
            if key < node._key:
                res = node._key
                node = node._left
            else:
                node = node._right
        return res

    def index(self, key: T) -> int:
        """Return the number of elements (with multiplicity) strictly less than ``key``."""
        k = 0
        node = self._root
        while node:
            if key == node._key:
                k += node._left._count_size if node._left else 0
                break
            if key < node._key:
                node = node._left
            else:
                # Everything in the left subtree plus this node is < key.
                k += node._left._count_size + node._count if node._left else node._count
                node = node._right
        return k

    def index_right(self, key: T) -> int:
        """Return the number of elements (with multiplicity) less than or equal to ``key``."""
        k = 0
        node = self._root
        while node:
            if key == node._key:
                k += node._left._count_size + node._count if node._left else node._count
                break
            if key < node._key:
                node = node._left
            else:
                k += node._left._count_size + node._count if node._left else node._count
                node = node._right
        return k

    def tolist(self) -> list[T]:
        """Return all elements in ascending order, duplicates repeated."""
        return list(self)

    def get_min(self) -> T:
        """Return the smallest element. Asserts the multiset is non-empty."""
        assert self._min
        return self._min._key

    def get_max(self) -> T:
        """Return the largest element. Asserts the multiset is non-empty."""
        assert self._max
        return self._max._key

    def pop_min(self) -> T:
        """Remove and return one occurrence of the smallest element."""
        assert self._min
        key = self._min._key
        # Via _remove_count so ancestors' _count_size stays in sync with len().
        self._remove_count(self._min, 1)
        return key

    def pop_max(self) -> T:
        """Remove and return one occurrence of the largest element."""
        assert self._max
        key = self._max._key
        self._remove_count(self._max, 1)
        return key

    def check(self) -> None:
        """Debug helper: assert BST ordering, cached aggregates and balance."""
        if self._root is None:
            return

        def dfs(node: _WBTMultisetNode[T]) -> tuple[int, int, int]:
            # Returns (node count, total multiplicity, height) of the subtree.
            h = 0
            s = 1
            cs = node._count
            if node._left:
                assert node._key > node._left._key
                ls, lcs, lh = dfs(node._left)
                s += ls
                cs += lcs
                h = max(h, lh)
            if node._right:
                assert node._key < node._right._key
                rs, rcs, rh = dfs(node._right)
                s += rs
                cs += rcs
                h = max(h, rh)
            assert node._size == s
            assert node._count_size == cs
            node._balance_check()
            return s, cs, h + 1

        dfs(self._root)

    def __contains__(self, key: T) -> bool:
        return self.find_key(key) is not None

    def __getitem__(self, k: int) -> T:
        """Return the ``k``-th smallest element; negative ``k`` counts from the end."""
        assert (
            -len(self) <= k < len(self)
        ), f"IndexError: {self.__class__.__name__}[{k}], len={len(self)}"
        if k < 0:
            k += len(self)
        if k == 0:
            return self.get_min()
        if k == len(self) - 1:
            return self.get_max()
        return self.find_order(k)._key

    def __delitem__(self, k: int) -> None:
        """Remove one occurrence of the ``k``-th smallest element (negative ``k`` allowed)."""
        self._remove_count(self.find_order(k), 1)

    def __len__(self) -> int:
        """Total number of elements, counting multiplicity."""
        return self._root._count_size if self._root else 0

    def __iter__(self) -> Iterator[T]:
        """Yield elements in ascending order (iterative in-order traversal)."""
        stack: list[_WBTMultisetNode[T]] = []
        node = self._root
        while stack or node:
            if node:
                stack.append(node)
                node = node._left
            else:
                node = stack.pop()
                for _ in range(node._count):
                    yield node._key
                node = node._right

    def __reversed__(self) -> Iterator[T]:
        """Yield elements in descending order."""
        stack: list[_WBTMultisetNode[T]] = []
        node = self._root
        while stack or node:
            if node:
                stack.append(node)
                node = node._right
            else:
                node = stack.pop()
                for _ in range(node._count):
                    yield node._key
                node = node._left

    def __str__(self) -> str:
        return "{" + ", ".join(map(str, self)) + "}"

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            + "["
            + ", ".join(map(str, self.tolist()))
            + "])"
        )