1from titan_pylib.my_class.ordered_multiset_interface import OrderedMultisetInterface
2from titan_pylib.my_class.supports_less_than import SupportsLessThan
3from typing import Generic, Iterable, Iterator, TypeVar, Optional
4
# Type of the stored keys: ordered with ``<`` and matched with ``==``.
T = TypeVar("T", bound=SupportsLessThan)
6
7
class AVLTreeMultiset3(OrderedMultisetInterface, Generic[T]):
    """Multiset implemented as an AVL tree.

    Each distinct key is stored once in a ``Node`` together with its
    multiplicity ``val`` (this implementation uses ``class Node()``).
    Every node additionally caches:

    * ``size``    -- number of distinct keys in its subtree,
    * ``valsize`` -- number of elements (counting multiplicity) in its subtree,
    * ``balance`` -- height(left) - height(right), kept within ``{-1, 0, 1}``.

    Search paths are recorded as a list of ancestors plus a bit string ``di``
    whose lowest bit is the direction taken from the deepest ancestor
    (1 = went left, 0 = went right).
    """

    class Node:
        """A tree node holding one distinct key and its multiplicity."""

        def __init__(self, key: T, val: int):
            self.key: T = key
            self.val: int = val  # multiplicity of ``key``
            self.valsize: int = val  # total multiplicity in this subtree
            self.size: int = 1  # number of distinct keys in this subtree
            self.left: Optional["AVLTreeMultiset3.Node"] = None
            self.right: Optional["AVLTreeMultiset3.Node"] = None
            self.balance: int = 0  # height(left) - height(right)

        def __str__(self):
            if self.left is None and self.right is None:
                return f"key:{self.key, self.val, self.size, self.valsize}\n"
            return f"key:{self.key, self.val, self.size, self.valsize},\n left:{self.left},\n right:{self.right}\n"

    def __init__(self, a: Iterable[T] = ()):
        """Build a multiset from the elements of ``a`` (may be empty).

        An immutable default is used so no state can leak between instances.
        """
        self.node: Optional["AVLTreeMultiset3.Node"] = None
        # _build handles an empty input itself.
        self._build(a)

    def _rle(self, L: list[T]) -> tuple[list[T], list[int]]:
        """Run-length encode a non-empty sorted list into (keys, counts)."""
        x, y = [L[0]], [1]
        for a in L[1:]:
            if a == x[-1]:
                y[-1] += 1
            else:
                x.append(a)
                y.append(1)
        return x, y

    def _build(self, a: Iterable[T]) -> None:
        """Build a perfectly balanced tree over the elements of ``a``.

        Does nothing if ``a`` is empty.
        """
        Node = AVLTreeMultiset3.Node

        def sort(l: int, r: int) -> tuple[Node, int]:
            # Build the subtree for keys x[l:r]; return (root, height).
            mid = (l + r) >> 1
            node = Node(x[mid], y[mid])
            h = 0
            if l != mid:
                left, hl = sort(l, mid)
                node.left = left
                node.size += left.size
                node.valsize += left.valsize
                node.balance = hl
                h = hl
            if mid + 1 != r:
                right, hr = sort(mid + 1, r)
                node.right = right
                node.size += right.size
                node.valsize += right.valsize
                node.balance -= hr
                if hr > h:
                    h = hr
            return node, h + 1

        a = sorted(a)
        if not a:
            return
        x, y = self._rle(a)
        self.node = sort(0, len(x))[0]

    def _rotate_L(self, node: Node) -> Node:
        """Single rotation for a left-heavy ``node`` (balance +2).

        The left child ``u`` becomes the subtree root and is returned; the
        caller must re-link it into the parent.
        """
        u = node.left
        # u now spans the whole subtree.
        u.size = node.size
        u.valsize = node.valsize
        # node loses u and u's left subtree (it keeps u.right).
        if u.left is None:
            node.size -= 1
            node.valsize -= u.val
        else:
            node.size -= u.left.size + 1
            node.valsize -= u.left.valsize + u.val
        node.left = u.right
        u.right = node
        if u.balance == 1:
            u.balance = 0
            node.balance = 0
        else:
            # u.balance == 0: reachable only from _discard.
            u.balance = -1
            node.balance = 1
        return u

    def _rotate_R(self, node: Node) -> Node:
        """Single rotation for a right-heavy ``node`` (balance -2).

        Mirror image of ``_rotate_L``; returns the new subtree root.
        """
        u = node.right
        u.size = node.size
        u.valsize = node.valsize
        # node loses u and u's right subtree (it keeps u.left).
        if u.right is None:
            node.size -= 1
            node.valsize -= u.val
        else:
            node.size -= u.right.size + 1
            node.valsize -= u.right.valsize + u.val
        node.right = u.left
        u.left = node
        if u.balance == -1:
            u.balance = 0
            node.balance = 0
        else:
            # u.balance == 0: reachable only from _discard.
            u.balance = 1
            node.balance = -1
        return u

    def _update_balance(self, node: Node) -> None:
        """Fix balances after a double rotation whose new root is ``node``.

        ``node.balance`` still holds its pre-rotation value, which decides
        which child received the taller of its former subtrees.
        """
        if node.balance == 1:
            node.right.balance = -1
            node.left.balance = 0
        elif node.balance == -1:
            node.right.balance = 0
            node.left.balance = 1
        else:
            node.right.balance = 0
            node.left.balance = 0
        node.balance = 0

    def _rotate_LR(self, node: Node) -> Node:
        """Left-right double rotation for ``node``.

        Applies when ``node`` has balance +2 and its left child is
        right-heavy.  ``E = node.left.right`` becomes the subtree root, with
        children ``B = node.left`` and ``node``.  Returns ``E``.
        """
        B = node.left
        E = B.right
        E.size = node.size
        E.valsize = node.valsize
        # node keeps only E.right from B's subtree; B loses E and E.right.
        if E.right is None:
            node.size -= B.size
            node.valsize -= B.valsize
            B.size -= 1
            B.valsize -= E.val
        else:
            node.size -= B.size - E.right.size
            node.valsize -= B.valsize - E.right.valsize
            B.size -= E.right.size + 1
            B.valsize -= E.right.valsize + E.val
        B.right = E.left
        E.left = B
        node.left = E.right
        E.right = node
        self._update_balance(E)
        return E

    def _rotate_RL(self, node: Node) -> Node:
        """Right-left double rotation for ``node``.

        Applies when ``node`` has balance -2 and its right child is
        left-heavy.  ``D = node.right.left`` becomes the subtree root, with
        children ``node`` and ``C = node.right``.  Returns ``D``.
        """
        C = node.right
        D = C.left
        D.size = node.size
        D.valsize = node.valsize
        # node keeps only D.left from C's subtree; C loses D and D.left.
        if D.left is None:
            node.size -= C.size
            node.valsize -= C.valsize
            C.size -= 1
            C.valsize -= D.val
        else:
            node.size -= C.size - D.left.size
            node.valsize -= C.valsize - D.left.valsize
            C.size -= D.left.size + 1
            C.valsize -= D.left.valsize + D.val
        C.left = D.right
        D.right = C
        node.right = D.left
        D.left = node
        self._update_balance(D)
        return D

    def _kth_elm(self, k: int) -> tuple[T, int]:
        """Return ``(key, multiplicity)`` of the k-th smallest element.

        Elements are counted with multiplicity; a negative ``k`` counts from
        the end.

        Raises:
            IndexError: if ``k`` is out of range (including an empty set).
        """
        n = len(self)
        if k < 0:
            k += n
        if not 0 <= k < n:
            raise IndexError("AVLTreeMultiset3 index out of range")
        node = self.node
        while True:
            # t = number of elements in the left subtree plus this node.
            t = node.val if node.left is None else node.val + node.left.valsize
            if t - node.val <= k < t:
                return node.key, node.val
            elif t > k:
                node = node.left
            else:
                node = node.right
                k -= t

    def _kth_elm_tree(self, k: int) -> tuple[T, int]:
        """Return ``(key, multiplicity)`` of the k-th smallest distinct key.

        A negative ``k`` counts from the end.
        """
        if k < 0:
            k += self.len_elm()
        assert 0 <= k < self.len_elm()
        node = self.node
        while True:
            t = 0 if node.left is None else node.left.size
            if t == k:
                return node.key, node.val
            elif t > k:
                node = node.left
            else:
                node = node.right
                k -= t + 1

    def _discard(self, node: Node, path: list[Node], di: int) -> bool:
        """Unlink ``node`` from the tree and rebalance.

        Args:
            node: node to remove; ``node.val`` must be 1, because every
                ancestor's ``valsize`` is decremented by exactly 1.
            path: ancestors of ``node`` from the root; consumed (popped).
            di: direction bits for ``path`` (lowest bit = the step into
                ``node``, 1 = left).

        Returns:
            Always True.
        """
        # Bit 1 in fdi marks path entries that lose the predecessor node
        # (valsize -= lmax_val) rather than a single element (valsize -= 1).
        fdi = 0
        lmax_val = 1
        if node.left is not None and node.right is not None:
            # Two children: copy the in-order predecessor (max of the left
            # subtree) into ``node`` and unlink the predecessor instead.
            path.append(node)
            di = (di << 1) | 1
            lmax = node.left
            while lmax.right is not None:
                path.append(lmax)
                di <<= 1
                fdi = (fdi << 1) | 1
                lmax = lmax.right
            lmax_val = lmax.val
            node.key = lmax.key
            node.val = lmax_val
            node = lmax
        # ``node`` now has at most one child; splice it out.
        cnode = node.right if node.left is None else node.left
        if not path:
            self.node = cnode
            return True
        if di & 1:
            path[-1].left = cnode
        else:
            path[-1].right = cnode
        while path:
            new_node = None
            pnode = path.pop()
            pnode.balance -= 1 if di & 1 else -1
            pnode.size -= 1
            pnode.valsize -= lmax_val if fdi & 1 else 1
            di >>= 1
            fdi >>= 1
            if pnode.balance == 2:
                new_node = (
                    self._rotate_LR(pnode)
                    if pnode.left.balance < 0
                    else self._rotate_L(pnode)
                )
            elif pnode.balance == -2:
                new_node = (
                    self._rotate_RL(pnode)
                    if pnode.right.balance > 0
                    else self._rotate_R(pnode)
                )
            elif pnode.balance != 0:
                # Height of this subtree did not change: stop rebalancing.
                break
            if new_node is not None:
                if not path:
                    self.node = new_node
                    return True
                if di & 1:
                    path[-1].left = new_node
                else:
                    path[-1].right = new_node
                if new_node.balance != 0:
                    break
        # Remaining ancestors only need their aggregates updated.
        while path:
            pnode = path.pop()
            pnode.size -= 1
            pnode.valsize -= lmax_val if fdi & 1 else 1
            fdi >>= 1
        return True

    def _sub_val(self, node: Node, path: list[Node], val: int) -> None:
        """Subtract ``val`` copies from ``node`` and from its ancestors ``path``.

        The tree shape is not changed.
        """
        node.val -= val
        node.valsize -= val
        for p in path:
            p.valsize -= val

    def _pop_node(self, node: Node, path: list[Node], di: int) -> None:
        """Remove one copy of ``node.key``, unlinking ``node`` if it was the last."""
        if node.val == 1:
            self._discard(node, path, di)
        else:
            self._sub_val(node, path, 1)

    def discard(self, key: T, val: int = 1) -> bool:
        """Remove up to ``val`` copies of ``key``.

        If ``val`` is at least the current multiplicity, the key is removed
        from the tree entirely.

        Returns:
            False if ``key`` is absent, True otherwise.
        """
        path = []
        di = 0
        node = self.node
        while node is not None:
            if key == node.key:
                break
            path.append(node)
            di <<= 1
            if key < node.key:
                di |= 1
                node = node.left
            else:
                node = node.right
        else:
            return False
        if val >= node.val:
            # Drop all but one copy, then unlink the node; _discard requires
            # node.val == 1.  (Using ``>=`` ensures that removing exactly
            # ``count(key)`` copies never leaves a zero-count node behind.)
            self._sub_val(node, path, node.val - 1)
            self._discard(node, path, di)
        else:
            self._sub_val(node, path, val)
        return True

    def discard_all(self, key: T) -> None:
        """Remove every copy of ``key``; no-op if ``key`` is absent."""
        self.discard(key, self.count(key))

    def remove(self, key: T, val: int = 1) -> None:
        """Like ``discard`` but raise ``KeyError`` if ``key`` is absent."""
        if self.discard(key, val):
            return
        raise KeyError(key)

    def add(self, key: T, val: int = 1) -> None:
        """Insert ``val`` copies of ``key``."""
        if self.node is None:
            self.node = AVLTreeMultiset3.Node(key, val)
            return
        pnode = self.node
        di = 0
        path = []
        while pnode is not None:
            if key == pnode.key:
                # Existing key: only multiplicities change, shape is kept.
                pnode.val += val
                pnode.valsize += val
                for p in path:
                    p.valsize += val
                return
            elif key < pnode.key:
                path.append(pnode)
                di <<= 1
                di |= 1
                pnode = pnode.left
            else:
                path.append(pnode)
                di <<= 1
                pnode = pnode.right
        if di & 1:
            path[-1].left = AVLTreeMultiset3.Node(key, val)
        else:
            path[-1].right = AVLTreeMultiset3.Node(key, val)
        new_node = None
        while path:
            pnode = path.pop()
            pnode.size += 1
            pnode.valsize += val
            pnode.balance += 1 if di & 1 else -1
            di >>= 1
            if pnode.balance == 0:
                # Subtree height unchanged: no further rebalancing needed.
                break
            if pnode.balance == 2:
                new_node = (
                    self._rotate_LR(pnode)
                    if pnode.left.balance < 0
                    else self._rotate_L(pnode)
                )
                break
            elif pnode.balance == -2:
                new_node = (
                    self._rotate_RL(pnode)
                    if pnode.right.balance > 0
                    else self._rotate_R(pnode)
                )
                break
        if new_node is not None:
            if path:
                if di & 1:
                    path[-1].left = new_node
                else:
                    path[-1].right = new_node
            else:
                self.node = new_node
        # Ancestors above the point where rebalancing stopped.
        for p in path:
            p.size += 1
            p.valsize += val

    def count(self, key: T) -> int:
        """Return the multiplicity of ``key`` (0 if absent)."""
        node = self.node
        while node is not None:
            if node.key == key:
                return node.val
            elif key < node.key:
                node = node.left
            else:
                node = node.right
        return 0

    def le(self, key: T) -> Optional[T]:
        """Return the largest stored key ``<= key``, or None."""
        res = None
        node = self.node
        while node is not None:
            if key == node.key:
                res = key
                break
            elif key < node.key:
                node = node.left
            else:
                res = node.key
                node = node.right
        return res

    def lt(self, key: T) -> Optional[T]:
        """Return the largest stored key ``< key``, or None."""
        res = None
        node = self.node
        while node is not None:
            if key <= node.key:
                node = node.left
            else:
                res = node.key
                node = node.right
        return res

    def ge(self, key: T) -> Optional[T]:
        """Return the smallest stored key ``>= key``, or None."""
        res = None
        node = self.node
        while node is not None:
            if key == node.key:
                res = key
                break
            elif key < node.key:
                res = node.key
                node = node.left
            else:
                node = node.right
        return res

    def gt(self, key: T) -> Optional[T]:
        """Return the smallest stored key ``> key``, or None."""
        res = None
        node = self.node
        while node is not None:
            if key < node.key:
                res = node.key
                node = node.left
            else:
                node = node.right
        return res

    def index(self, key: T) -> int:
        """Return the number of elements (with multiplicity) ``< key``."""
        k = 0
        node = self.node
        while node is not None:
            if key == node.key:
                if node.left is not None:
                    k += node.left.valsize
                break
            elif key < node.key:
                node = node.left
            else:
                k += node.val if node.left is None else node.left.valsize + node.val
                node = node.right
        return k

    def index_right(self, key: T) -> int:
        """Return the number of elements (with multiplicity) ``<= key``."""
        k = 0
        node = self.node
        while node is not None:
            if key == node.key:
                k += node.val if node.left is None else node.left.valsize + node.val
                break
            elif key < node.key:
                node = node.left
            else:
                k += node.val if node.left is None else node.left.valsize + node.val
                node = node.right
        return k

    def index_keys(self, key: T) -> int:
        """Return the number of distinct keys ``< key``."""
        k = 0
        node = self.node
        while node is not None:
            if key == node.key:
                if node.left is not None:
                    k += node.left.size
                break
            elif key < node.key:
                node = node.left
            else:
                # Distinct-key count: this node contributes 1, not node.val.
                k += 1 if node.left is None else node.left.size + 1
                node = node.right
        return k

    def index_right_keys(self, key: T) -> int:
        """Return the number of distinct keys ``<= key``."""
        k = 0
        node = self.node
        while node is not None:
            if key == node.key:
                k += 1 if node.left is None else node.left.size + 1
                break
            elif key < node.key:
                node = node.left
            else:
                k += 1 if node.left is None else node.left.size + 1
                node = node.right
        return k

    def get_min(self) -> Optional[T]:
        """Return the smallest key, or None if empty."""
        if self.node is None:
            return
        node = self.node
        while node.left is not None:
            node = node.left
        return node.key

    def get_max(self) -> Optional[T]:
        """Return the largest key, or None if empty."""
        if self.node is None:
            return
        node = self.node
        while node.right is not None:
            node = node.right
        return node.key

    def pop(self, k: int = -1) -> T:
        """Remove and return the k-th smallest element (counting multiplicity).

        A negative ``k`` counts from the end; the default removes the maximum.

        Raises:
            IndexError: if the multiset is empty or ``k`` is out of range.
        """
        n = len(self)
        if k < 0:
            k += n
        if not 0 <= k < n:
            raise IndexError("pop index out of range")
        node = self.node
        path = []
        if k == n - 1:
            # Fast path: the maximum is the rightmost node (all bits 0 = right).
            while node.right is not None:
                path.append(node)
                node = node.right
            x = node.key
            self._pop_node(node, path, 0)
            return x
        di = 0
        while True:
            t = node.val if node.left is None else node.val + node.left.valsize
            if t - node.val <= k < t:
                break
            path.append(node)
            di <<= 1
            if t > k:
                di |= 1
                node = node.left
            else:
                node = node.right
                k -= t
        x = node.key
        self._pop_node(node, path, di)
        return x

    def pop_max(self) -> T:
        """Remove and return the largest element (the set must be non-empty)."""
        assert self
        return self.pop()

    def pop_min(self) -> T:
        """Remove and return the smallest element (the set must be non-empty)."""
        assert self
        node = self.node
        path = []
        while node.left is not None:
            path.append(node)
            node = node.left
        x = node.key
        # Every step went left, so all len(path) direction bits are 1.
        self._pop_node(node, path, (1 << len(path)) - 1)
        return x

    def items(self) -> Iterator[tuple[T, int]]:
        """Yield ``(key, multiplicity)`` pairs in ascending key order."""
        for i in range(self.len_elm()):
            yield self._kth_elm_tree(i)

    def keys(self) -> Iterator[T]:
        """Yield the distinct keys in ascending order."""
        for i in range(self.len_elm()):
            yield self._kth_elm_tree(i)[0]

    def values(self) -> Iterator[int]:
        """Yield the multiplicities in ascending key order."""
        for i in range(self.len_elm()):
            yield self._kth_elm_tree(i)[1]

    def len_elm(self) -> int:
        """Return the number of distinct keys."""
        return 0 if self.node is None else self.node.size

    def show(self) -> None:
        """Print the multiset as ``{key: count, ...}``."""
        print(
            "{" + ", ".join(map(lambda x: f"{x[0]}: {x[1]}", self.tolist_items())) + "}"
        )

    def clear(self) -> None:
        """Remove all elements."""
        self.node = None

    def get_elm(self, k: int) -> T:
        """Return the k-th smallest distinct key."""
        return self._kth_elm_tree(k)[0]

    def tolist(self) -> list[T]:
        """Return all elements in ascending order, repeated by multiplicity."""
        a = []
        if self.node is None:
            return a

        def rec(node):
            if node.left is not None:
                rec(node.left)
            a.extend([node.key] * node.val)
            if node.right is not None:
                rec(node.right)

        rec(self.node)
        return a

    def tolist_items(self) -> list[tuple[T, int]]:
        """Return ``(key, multiplicity)`` pairs in ascending key order."""
        a = []
        if self.node is None:
            return a

        def rec(node):
            if node.left is not None:
                rec(node.left)
            a.append((node.key, node.val))
            if node.right is not None:
                rec(node.right)

        rec(self.node)
        return a

    def __getitem__(self, k: int):
        """Return the k-th smallest element (counting multiplicity)."""
        return self._kth_elm(k)[0]

    def __contains__(self, key: T):
        node = self.node
        while node is not None:
            if node.key == key:
                return True
            node = node.left if key < node.key else node.right
        return False

    def __iter__(self):
        # NOTE: the multiset is its own iterator, so nested loops over the
        # same instance share one cursor.
        self.__iter = 0
        return self

    def __next__(self):
        # ``>=`` (not ``==``) so iteration still ends if elements were removed.
        if self.__iter >= len(self):
            raise StopIteration
        # Yield keys (repeated by multiplicity), consistent with __getitem__,
        # __reversed__ and tolist().
        res = self._kth_elm(self.__iter)[0]
        self.__iter += 1
        return res

    def __reversed__(self):
        for i in range(len(self)):
            yield self._kth_elm(-i - 1)[0]

    def __len__(self):
        """Number of elements counting multiplicity."""
        return 0 if self.node is None else self.node.valsize

    def __bool__(self):
        return self.node is not None

    def __str__(self):
        return "{" + ", ".join(map(str, self.tolist())) + "}"

    def __repr__(self):
        return f"AVLTreeMultiset3({self.tolist()})"