avl_tree_multiset3
ソースコード
from titan_pylib.data_structures.avl_tree.avl_tree_multiset3 import AVLTreeMultiset3
展開済みコード
1# from titan_pylib.data_structures.avl_tree.avl_tree_multiset3 import AVLTreeMultiset3
2# from titan_pylib.my_class.ordered_multiset_interface import OrderedMultisetInterface
3# from titan_pylib.my_class.supports_less_than import SupportsLessThan
from typing import Protocol


class SupportsLessThan(Protocol):
    """Structural type: any object that can be compared with ``<``.

    Only ``__lt__`` is required, so plain ints, strings, tuples, and user
    classes defining ``__lt__`` all satisfy this protocol.
    """

    def __lt__(self, other) -> bool:
        ...
from abc import ABC, abstractmethod
from typing import Generic, Iterable, Iterator, Optional, TypeVar

# Element type for ordered containers: any value that supports ``<``
# (see SupportsLessThan).
T = TypeVar("T", bound=SupportsLessThan)
14
15
class OrderedMultisetInterface(ABC, Generic[T]):
    """Abstract interface for an ordered multiset of comparable keys.

    Every method is abstract; the bodies here only raise
    ``NotImplementedError`` and exist to fix the required signatures.
    """

    # ---- construction / mutation ------------------------------------------

    @abstractmethod
    def __init__(self, a: Iterable[T]) -> None:
        """Initialise the multiset from the elements of ``a``."""
        raise NotImplementedError()

    @abstractmethod
    def add(self, key: T, cnt: int) -> None:
        """Insert ``cnt`` copies of ``key``."""
        raise NotImplementedError()

    @abstractmethod
    def discard(self, key: T, cnt: int) -> bool:
        """Remove up to ``cnt`` copies of ``key``; report whether it was present."""
        raise NotImplementedError()

    @abstractmethod
    def discard_all(self, key: T) -> bool:
        """Remove every copy of ``key``."""
        raise NotImplementedError()

    @abstractmethod
    def count(self, key: T) -> int:
        """Return the multiplicity of ``key`` (0 when absent)."""
        raise NotImplementedError()

    @abstractmethod
    def remove(self, key: T, cnt: int) -> None:
        """Like ``discard``, but an absent ``key`` is an error."""
        raise NotImplementedError()

    # ---- ordered neighbour queries ----------------------------------------

    @abstractmethod
    def le(self, key: T) -> Optional[T]:
        """Largest stored key ``<= key``, or ``None``."""
        raise NotImplementedError()

    @abstractmethod
    def lt(self, key: T) -> Optional[T]:
        """Largest stored key ``< key``, or ``None``."""
        raise NotImplementedError()

    @abstractmethod
    def ge(self, key: T) -> Optional[T]:
        """Smallest stored key ``>= key``, or ``None``."""
        raise NotImplementedError()

    @abstractmethod
    def gt(self, key: T) -> Optional[T]:
        """Smallest stored key ``> key``, or ``None``."""
        raise NotImplementedError()

    # ---- extremes ---------------------------------------------------------

    @abstractmethod
    def get_max(self) -> Optional[T]:
        """Largest key, or ``None`` when empty."""
        raise NotImplementedError()

    @abstractmethod
    def get_min(self) -> Optional[T]:
        """Smallest key, or ``None`` when empty."""
        raise NotImplementedError()

    @abstractmethod
    def pop_max(self) -> T:
        """Remove one copy of the largest key and return it."""
        raise NotImplementedError()

    @abstractmethod
    def pop_min(self) -> T:
        """Remove one copy of the smallest key and return it."""
        raise NotImplementedError()

    # ---- bulk / conversion ------------------------------------------------

    @abstractmethod
    def clear(self) -> None:
        """Remove every element."""
        raise NotImplementedError()

    @abstractmethod
    def tolist(self) -> list[T]:
        """All elements in ascending order, each repeated by its multiplicity."""
        raise NotImplementedError()

    # ---- Python protocols -------------------------------------------------

    @abstractmethod
    def __iter__(self) -> Iterator:
        """Start iteration (the container acts as its own iterator)."""
        raise NotImplementedError()

    @abstractmethod
    def __next__(self) -> T:
        """Return the next element of the current iteration."""
        raise NotImplementedError()

    @abstractmethod
    def __contains__(self, key: T) -> bool:
        """Whether ``key`` is stored at least once."""
        raise NotImplementedError()

    @abstractmethod
    def __len__(self) -> int:
        """Number of stored elements."""
        raise NotImplementedError()

    @abstractmethod
    def __bool__(self) -> bool:
        """Whether the multiset is non-empty."""
        raise NotImplementedError()

    @abstractmethod
    def __str__(self) -> str:
        """Human-readable rendering."""
        raise NotImplementedError()

    @abstractmethod
    def __repr__(self) -> str:
        """Debug rendering."""
        raise NotImplementedError()
109# from titan_pylib.my_class.supports_less_than import SupportsLessThan
from typing import Generic, Iterable, Iterator, Optional, TypeVar

# Key type for AVLTreeMultiset3. This rebinds the module-level ``T`` with the
# same name and bound; it appears to be an artifact of bundling several
# modules into one file and is harmless.
T = TypeVar("T", bound=SupportsLessThan)
113
114
class AVLTreeMultiset3(OrderedMultisetInterface, Generic[T]):
    """Multiset implemented as an AVL tree.

    Each distinct key is stored once, in a ``Node``, together with its
    multiplicity. Every node also caches the number of distinct keys
    (``size``) and the total multiplicity (``valsize``) of its subtree, so the
    order-statistic operations (``index``, ``pop(k)``, ``self[k]``, ...) take
    time proportional to the tree height.
    Uses ``class Node()``.
    """

    class Node:
        """Tree node holding one distinct key and its multiplicity."""

        def __init__(self, key: T, val: int):
            self.key: T = key
            self.val: int = val  # multiplicity of ``key``
            self.valsize: int = val  # sum of ``val`` over this subtree
            self.size: int = 1  # number of nodes (distinct keys) in this subtree
            self.left: Optional["AVLTreeMultiset3.Node"] = None
            self.right: Optional["AVLTreeMultiset3.Node"] = None
            self.balance: int = 0  # height(left) - height(right)

        def __str__(self):
            if self.left is None and self.right is None:
                return f"key:{self.key, self.val, self.size, self.valsize}\n"
            return f"key:{self.key, self.val, self.size, self.valsize},\n left:{self.left},\n right:{self.right}\n"

    def __init__(self, a: Iterable[T] = []):
        """Create a multiset containing the elements of ``a``.

        The default ``[]`` is only read, never mutated, so sharing it is safe.
        """
        self.node: Optional["AVLTreeMultiset3.Node"] = None
        if a:
            self._build(a)

    def _rle(self, L: list[T]) -> tuple[list[T], list[int]]:
        """Run-length encode a sorted, non-empty list into (keys, counts)."""
        x, y = [L[0]], [1]
        for i, a in enumerate(L):
            if i == 0:
                continue
            if a == x[-1]:
                y[-1] += 1
                continue
            x.append(a)
            y.append(1)
        return x, y

    def _build(self, a: Iterable[T]) -> None:
        """Replace the tree with a balanced tree built from ``a`` (if non-empty)."""
        Node = AVLTreeMultiset3.Node

        def sort(l: int, r: int) -> tuple[Node, int]:
            # Build a balanced subtree over x[l:r]; return (root, height).
            mid = (l + r) >> 1
            node = Node(x[mid], y[mid])
            h = 0
            if l != mid:
                left, hl = sort(l, mid)
                node.left = left
                node.size += left.size
                node.valsize += left.valsize
                node.balance = hl
                h = hl
            if mid + 1 != r:
                right, hr = sort(mid + 1, r)
                node.right = right
                node.size += right.size
                node.valsize += right.valsize
                node.balance -= hr
                if hr > h:
                    h = hr
            return node, h + 1

        a = sorted(a)
        if not a:
            return
        x, y = self._rle(a)
        self.node = sort(0, len(x))[0]

    def _rotate_L(self, node: Node) -> Node:
        """Single rotation for a left-heavy ``node``: promote ``node.left``.

        Fixes ``size``/``valsize``/``balance`` and returns the new subtree root.
        """
        u = node.left
        u.size = node.size
        u.valsize = node.valsize
        if u.left is None:
            node.size -= 1
            node.valsize -= u.val
        else:
            node.size -= u.left.size + 1
            node.valsize -= u.left.valsize + u.val
        node.left = u.right
        u.right = node
        if u.balance == 1:
            u.balance = 0
            node.balance = 0
        else:
            # u.balance == 0 can only happen during deletion.
            u.balance = -1
            node.balance = 1
        return u

    def _rotate_R(self, node: Node) -> Node:
        """Single rotation for a right-heavy ``node``: promote ``node.right``.

        Fixes ``size``/``valsize``/``balance`` and returns the new subtree root.
        """
        u = node.right
        u.size = node.size
        u.valsize = node.valsize
        if u.right is None:
            node.size -= 1
            node.valsize -= u.val
        else:
            node.size -= u.right.size + 1
            node.valsize -= u.right.valsize + u.val
        node.right = u.left
        u.left = node
        if u.balance == -1:
            u.balance = 0
            node.balance = 0
        else:
            # u.balance == 0 can only happen during deletion.
            u.balance = 1
            node.balance = -1
        return u

    def _update_balance(self, node: Node) -> None:
        """Set child balances after a double rotation, from the pivot's old balance."""
        if node.balance == 1:
            node.right.balance = -1
            node.left.balance = 0
        elif node.balance == -1:
            node.right.balance = 0
            node.left.balance = 1
        else:
            node.right.balance = 0
            node.left.balance = 0
        node.balance = 0

    def _rotate_LR(self, node: Node) -> Node:
        """Double rotation for a left-heavy ``node`` whose left child is right-heavy."""
        B = node.left
        E = B.right
        E.size = node.size
        E.valsize = node.valsize
        if E.right is None:
            node.size -= B.size
            node.valsize -= B.valsize
            B.size -= 1
            B.valsize -= E.val
        else:
            node.size -= B.size - E.right.size
            node.valsize -= B.valsize - E.right.valsize
            B.size -= E.right.size + 1
            B.valsize -= E.right.valsize + E.val
        B.right = E.left
        E.left = B
        node.left = E.right
        E.right = node
        self._update_balance(E)
        return E

    def _rotate_RL(self, node: Node) -> Node:
        """Double rotation for a right-heavy ``node`` whose right child is left-heavy."""
        C = node.right
        D = C.left
        D.size = node.size
        D.valsize = node.valsize
        if D.left is None:
            node.size -= C.size
            node.valsize -= C.valsize
            C.size -= 1
            C.valsize -= D.val
        else:
            node.size -= C.size - D.left.size
            node.valsize -= C.valsize - D.left.valsize
            C.size -= D.left.size + 1
            C.valsize -= D.left.valsize + D.val
        C.left = D.right
        D.right = C
        node.right = D.left
        D.left = node
        self._update_balance(D)
        return D

    def _kth_elm(self, k: int) -> tuple[T, int]:
        """Return ``(key, multiplicity)`` of the k-th element counting multiplicity.

        Negative ``k`` counts from the end.

        Raises:
            IndexError: if ``k`` is out of range.
        """
        n = len(self)
        if k < 0:
            k += n
        if not 0 <= k < n:
            raise IndexError("AVLTreeMultiset3 index out of range")
        node = self.node
        while True:
            t = node.val if node.left is None else node.val + node.left.valsize
            if t - node.val <= k < t:
                return node.key, node.val
            elif t > k:
                node = node.left
            else:
                node = node.right
                k -= t

    def _kth_elm_tree(self, k: int) -> tuple[T, int]:
        """Return ``(key, multiplicity)`` of the k-th distinct key (negative k allowed)."""
        if k < 0:
            k += self.len_elm()
        assert 0 <= k < self.len_elm()
        node = self.node
        while True:
            t = 0 if node.left is None else node.left.size
            if t == k:
                return node.key, node.val
            elif t > k:
                node = node.left
            else:
                node = node.right
                k -= t + 1

    def _discard(self, node: Node, path: list[Node], di: int) -> bool:
        """Unlink ``node`` (which must hold exactly one copy) and rebalance.

        Args:
            node: the node to remove; its ``val`` is assumed to be 1.
            path: ancestors of ``node`` from the root down (mutated here).
            di: direction bits for ``path``; bit ``i`` (from the LSB) is 1 if
                the ``i``-th ancestor counted from the bottom was left via its
                left child.
        """
        # fdi marks (bit = 1) path entries whose subtree loses the in-order
        # predecessor's whole multiplicity instead of a single copy.
        fdi = 0
        if node.left is not None and node.right is not None:
            # Two children: copy the in-order predecessor (max of left
            # subtree) into ``node`` and unlink the predecessor instead.
            path.append(node)
            di <<= 1
            di |= 1
            lmax = node.left
            while lmax.right is not None:
                path.append(lmax)
                di <<= 1
                fdi <<= 1
                fdi |= 1
                lmax = lmax.right
            lmax_val = lmax.val
            node.key = lmax.key
            node.val = lmax_val
            node = lmax
        cnode = node.right if node.left is None else node.left
        if path:
            if di & 1:
                path[-1].left = cnode
            else:
                path[-1].right = cnode
        else:
            self.node = cnode
            return True
        # Walk back up, updating counts and rebalancing while the height shrinks.
        while path:
            new_node = None
            pnode = path.pop()
            pnode.balance -= 1 if di & 1 else -1
            pnode.size -= 1
            pnode.valsize -= lmax_val if fdi & 1 else 1
            di >>= 1
            fdi >>= 1
            if pnode.balance == 2:
                new_node = (
                    self._rotate_LR(pnode)
                    if pnode.left.balance < 0
                    else self._rotate_L(pnode)
                )
            elif pnode.balance == -2:
                new_node = (
                    self._rotate_RL(pnode)
                    if pnode.right.balance > 0
                    else self._rotate_R(pnode)
                )
            elif pnode.balance != 0:
                # Subtree height unchanged: no further rebalancing needed.
                break
            if new_node is not None:
                if not path:
                    self.node = new_node
                    return True
                if di & 1:
                    path[-1].left = new_node
                else:
                    path[-1].right = new_node
                if new_node.balance != 0:
                    break
        # Remaining ancestors only need their counters adjusted.
        while path:
            pnode = path.pop()
            pnode.size -= 1
            pnode.valsize -= lmax_val if fdi & 1 else 1
            fdi >>= 1
        return True

    def discard(self, key: T, val: int = 1) -> bool:
        """Remove up to ``val`` copies of ``key``.

        If ``val`` is at least the current multiplicity, ``key`` is removed
        from the tree entirely. A non-positive ``val`` removes nothing.

        Returns:
            ``True`` if ``key`` was present, ``False`` otherwise.
        """
        path = []
        di = 0
        node = self.node
        while node is not None:
            if key == node.key:
                break
            path.append(node)
            di <<= 1
            if key < node.key:
                di |= 1
                node = node.left
            else:
                node = node.right
        else:
            return False
        if val <= 0:
            return True
        if val >= node.val:
            # _discard assumes the node holds exactly one copy (it subtracts 1
            # from each ancestor's valsize), so drop the extra copies first.
            extra = node.val - 1
            if extra:
                node.val -= extra
                node.valsize -= extra
                for p in path:
                    p.valsize -= extra
            self._discard(node, path, di)
        else:
            node.val -= val
            node.valsize -= val
            for p in path:
                p.valsize -= val
        return True

    def discard_all(self, key: T) -> bool:
        """Remove every copy of ``key``; return whether it was present."""
        return self.discard(key, self.count(key))

    def remove(self, key: T, val: int = 1) -> None:
        """Remove up to ``val`` copies of ``key``.

        Raises:
            KeyError: if ``key`` is not present.
        """
        if self.discard(key, val):
            return
        raise KeyError(key)

    def add(self, key: T, val: int = 1) -> None:
        """Insert ``val`` copies of ``key``, rebalancing if a new node is created."""
        if self.node is None:
            self.node = AVLTreeMultiset3.Node(key, val)
            return
        pnode = self.node
        di = 0
        path = []
        while pnode is not None:
            if key == pnode.key:
                # Existing key: only multiplicities change, shape is unchanged.
                pnode.val += val
                pnode.valsize += val
                for p in path:
                    p.valsize += val
                return
            elif key < pnode.key:
                path.append(pnode)
                di <<= 1
                di |= 1
                pnode = pnode.left
            else:
                path.append(pnode)
                di <<= 1
                pnode = pnode.right
        if di & 1:
            path[-1].left = AVLTreeMultiset3.Node(key, val)
        else:
            path[-1].right = AVLTreeMultiset3.Node(key, val)
        new_node = None
        while path:
            pnode = path.pop()
            pnode.size += 1
            pnode.valsize += val
            pnode.balance += 1 if di & 1 else -1
            di >>= 1
            if pnode.balance == 0:
                break
            if pnode.balance == 2:
                new_node = (
                    self._rotate_LR(pnode)
                    if pnode.left.balance < 0
                    else self._rotate_L(pnode)
                )
                break
            elif pnode.balance == -2:
                new_node = (
                    self._rotate_RL(pnode)
                    if pnode.right.balance > 0
                    else self._rotate_R(pnode)
                )
                break
        if new_node is not None:
            if path:
                if di & 1:
                    path[-1].left = new_node
                else:
                    path[-1].right = new_node
            else:
                self.node = new_node
        # Ancestors above the point where rebalancing stopped.
        for p in path:
            p.size += 1
            p.valsize += val

    def count(self, key: T) -> int:
        """Return the multiplicity of ``key`` (0 if absent)."""
        node = self.node
        while node is not None:
            if node.key == key:
                return node.val
            elif key < node.key:
                node = node.left
            else:
                node = node.right
        return 0

    def le(self, key: T) -> Optional[T]:
        """Largest key ``<= key``, or ``None``."""
        res = None
        node = self.node
        while node is not None:
            if key == node.key:
                res = key
                break
            elif key < node.key:
                node = node.left
            else:
                res = node.key
                node = node.right
        return res

    def lt(self, key: T) -> Optional[T]:
        """Largest key ``< key``, or ``None``."""
        res = None
        node = self.node
        while node is not None:
            if key <= node.key:
                node = node.left
            else:
                res = node.key
                node = node.right
        return res

    def ge(self, key: T) -> Optional[T]:
        """Smallest key ``>= key``, or ``None``."""
        res = None
        node = self.node
        while node is not None:
            if key == node.key:
                res = key
                break
            elif key < node.key:
                res = node.key
                node = node.left
            else:
                node = node.right
        return res

    def gt(self, key: T) -> Optional[T]:
        """Smallest key ``> key``, or ``None``."""
        res = None
        node = self.node
        while node is not None:
            if key < node.key:
                res = node.key
                node = node.left
            else:
                node = node.right
        return res

    def index(self, key: T) -> int:
        """Number of elements (counting multiplicity) strictly less than ``key``."""
        k = 0
        node = self.node
        while node is not None:
            if key == node.key:
                if node.left is not None:
                    k += node.left.valsize
                break
            elif key < node.key:
                node = node.left
            else:
                k += node.val if node.left is None else node.left.valsize + node.val
                node = node.right
        return k

    def index_right(self, key: T) -> int:
        """Number of elements (counting multiplicity) less than or equal to ``key``."""
        k = 0
        node = self.node
        while node is not None:
            if key == node.key:
                k += node.val if node.left is None else node.left.valsize + node.val
                break
            elif key < node.key:
                node = node.left
            else:
                k += node.val if node.left is None else node.left.valsize + node.val
                node = node.right
        return k

    def index_keys(self, key: T) -> int:
        """Number of distinct keys strictly less than ``key``."""
        k = 0
        node = self.node
        while node is not None:
            if key == node.key:
                if node.left is not None:
                    k += node.left.size
                break
            elif key < node.key:
                node = node.left
            else:
                # This node contributes one distinct key, not its multiplicity.
                k += 1 if node.left is None else node.left.size + 1
                node = node.right
        return k

    def index_right_keys(self, key: T) -> int:
        """Number of distinct keys less than or equal to ``key``."""
        k = 0
        node = self.node
        while node is not None:
            if key == node.key:
                k += 1 if node.left is None else node.left.size + 1
                break
            elif key < node.key:
                node = node.left
            else:
                k += 1 if node.left is None else node.left.size + 1
                node = node.right
        return k

    def get_min(self) -> Optional[T]:
        """Smallest key, or ``None`` when empty."""
        if self.node is None:
            return None
        node = self.node
        while node.left is not None:
            node = node.left
        return node.key

    def get_max(self) -> Optional[T]:
        """Largest key, or ``None`` when empty."""
        if self.node is None:
            return None
        node = self.node
        while node.right is not None:
            node = node.right
        return node.key

    def pop(self, k: int = -1) -> T:
        """Remove and return one copy of the k-th element (counting multiplicity).

        Negative ``k`` counts from the end; the default pops the maximum.

        Raises:
            IndexError: if ``k`` is out of range (including an empty multiset).
        """
        n = len(self)
        if k < 0:
            k += n
        if not 0 <= k < n:
            raise IndexError("pop index out of range")
        node = self.node
        path = []
        if k == n - 1:
            # Fast path for the maximum: follow right children only (di == 0).
            while node.right is not None:
                path.append(node)
                node = node.right
            x = node.key
            if node.val == 1:
                self._discard(node, path, 0)
            else:
                node.val -= 1
                node.valsize -= 1
                for p in path:
                    p.valsize -= 1
            return x
        di = 0
        while True:
            t = node.val if node.left is None else node.val + node.left.valsize
            if t - node.val <= k < t:
                x = node.key
                break
            elif t > k:
                path.append(node)
                di <<= 1
                di |= 1
                node = node.left
            else:
                path.append(node)
                di <<= 1
                node = node.right
                k -= t
        if node.val == 1:
            self._discard(node, path, di)
        else:
            node.val -= 1
            node.valsize -= 1
            for p in path:
                p.valsize -= 1
        return x

    def pop_max(self) -> T:
        """Remove and return one copy of the largest key.

        Raises:
            IndexError: if the multiset is empty.
        """
        if self.node is None:
            raise IndexError("pop from an empty AVLTreeMultiset3")
        return self.pop()

    def pop_min(self) -> T:
        """Remove and return one copy of the smallest key.

        Raises:
            IndexError: if the multiset is empty.
        """
        if self.node is None:
            raise IndexError("pop from an empty AVLTreeMultiset3")
        node = self.node
        path = []
        while node.left is not None:
            path.append(node)
            node = node.left
        x = node.key
        if node.val == 1:
            # Every step was to the left, so all direction bits are 1.
            self._discard(node, path, (1 << len(path)) - 1)
        else:
            node.val -= 1
            node.valsize -= 1
            for p in path:
                p.valsize -= 1
        return x

    def items(self) -> Iterator[tuple[T, int]]:
        """Yield ``(key, multiplicity)`` pairs in ascending key order."""
        for i in range(self.len_elm()):
            yield self._kth_elm_tree(i)

    def keys(self) -> Iterator[T]:
        """Yield distinct keys in ascending order."""
        for i in range(self.len_elm()):
            yield self._kth_elm_tree(i)[0]

    def values(self) -> Iterator[int]:
        """Yield multiplicities in ascending key order."""
        for i in range(self.len_elm()):
            yield self._kth_elm_tree(i)[1]

    def len_elm(self) -> int:
        """Number of distinct keys."""
        return 0 if self.node is None else self.node.size

    def show(self) -> None:
        """Print the multiset as ``{key: multiplicity, ...}``."""
        print(
            "{" + ", ".join(map(lambda x: f"{x[0]}: {x[1]}", self.tolist_items())) + "}"
        )

    def clear(self) -> None:
        """Remove every element."""
        self.node = None

    def get_elm(self, k: int) -> T:
        """Return the k-th distinct key (negative ``k`` allowed)."""
        return self._kth_elm_tree(k)[0]

    def tolist(self) -> list[T]:
        """All elements in ascending order, each repeated by its multiplicity."""
        a = []
        if self.node is None:
            return a

        def rec(node):
            if node.left is not None:
                rec(node.left)
            a.extend([node.key] * node.val)
            if node.right is not None:
                rec(node.right)

        rec(self.node)
        return a

    def tolist_items(self) -> list[tuple[T, int]]:
        """``(key, multiplicity)`` pairs in ascending key order."""
        a = []
        if self.node is None:
            return a

        def rec(node):
            if node.left is not None:
                rec(node.left)
            a.append((node.key, node.val))
            if node.right is not None:
                rec(node.right)

        rec(self.node)
        return a

    def __getitem__(self, k: int):
        """Return the k-th element counting multiplicity; raises IndexError if out of range."""
        return self._kth_elm(k)[0]

    def __contains__(self, key: T):
        node = self.node
        while node:
            if node.key == key:
                return True
            node = node.left if key < node.key else node.right
        return False

    def __iter__(self):
        # The interface requires the container to be its own iterator.
        self.__iter = 0
        return self

    def __next__(self):
        """Return the next element (a key, repeated by multiplicity)."""
        if self.__iter == len(self):
            raise StopIteration
        res = self._kth_elm(self.__iter)[0]
        self.__iter += 1
        return res

    def __reversed__(self):
        for i in range(len(self)):
            yield self._kth_elm(-i - 1)[0]

    def __len__(self):
        """Total number of elements, counting multiplicity."""
        return 0 if self.node is None else self.node.valsize

    def __bool__(self):
        return self.node is not None

    def __str__(self):
        return "{" + ", ".join(map(str, self.tolist())) + "}"

    def __repr__(self):
        return f"AVLTreeMultiset3({self.tolist()})"
仕様
- class AVLTreeMultiset3(a: Iterable[T] = [])[source]
  Bases: OrderedMultisetInterface, Generic[T]
  多重集合としての AVL 木です。``class Node()`` を用いています。