red_black_tree_multiset
ソースコード
from titan_pylib.data_structures.red_black_tree.red_black_tree_multiset import RedBlackTreeMultiset
展開済みコード
1# from titan_pylib.data_structures.red_black_tree.red_black_tree_multiset import RedBlackTreeMultiset
2# from titan_pylib.my_class.supports_less_than import SupportsLessThan
3from typing import Protocol
4
5
class SupportsLessThan(Protocol):
    """Structural type for values that can be ordered with ``<``."""

    def __lt__(self, other) -> bool:
        ...
9# from titan_pylib.my_class.ordered_multiset_interface import OrderedMultisetInterface
10# from titan_pylib.my_class.supports_less_than import SupportsLessThan
11from abc import ABC, abstractmethod
12from typing import Iterable, Optional, Iterator, TypeVar, Generic
13
T = TypeVar("T", bound=SupportsLessThan)


class OrderedMultisetInterface(ABC, Generic[T]):
    """Abstract API of an ordered multiset (a sorted bag of comparable keys).

    Every method is abstract, so a subclass can only be instantiated once it
    overrides all of them.
    """

    # ---- construction and mutation ------------------------------------

    @abstractmethod
    def __init__(self, a: Iterable[T]) -> None:
        """Populate the multiset with the elements of ``a``."""
        raise NotImplementedError()

    @abstractmethod
    def add(self, key: T, cnt: int) -> None:
        """Insert ``cnt`` copies of ``key``."""
        raise NotImplementedError()

    @abstractmethod
    def discard(self, key: T, cnt: int) -> bool:
        """Remove up to ``cnt`` copies of ``key``; report whether it was present."""
        raise NotImplementedError()

    @abstractmethod
    def discard_all(self, key: T) -> bool:
        """Remove every copy of ``key``; report whether it was present."""
        raise NotImplementedError()

    @abstractmethod
    def count(self, key: T) -> int:
        """Return the multiplicity of ``key``."""
        raise NotImplementedError()

    @abstractmethod
    def remove(self, key: T, cnt: int) -> None:
        """Remove ``cnt`` copies of ``key``, failing if there are not enough."""
        raise NotImplementedError()

    # ---- neighbour queries ---------------------------------------------

    @abstractmethod
    def le(self, key: T) -> Optional[T]:
        """Return the largest stored key ``<= key``, or None."""
        raise NotImplementedError()

    @abstractmethod
    def lt(self, key: T) -> Optional[T]:
        """Return the largest stored key ``< key``, or None."""
        raise NotImplementedError()

    @abstractmethod
    def ge(self, key: T) -> Optional[T]:
        """Return the smallest stored key ``>= key``, or None."""
        raise NotImplementedError()

    @abstractmethod
    def gt(self, key: T) -> Optional[T]:
        """Return the smallest stored key ``> key``, or None."""
        raise NotImplementedError()

    # ---- extremes ------------------------------------------------------

    @abstractmethod
    def get_max(self) -> Optional[T]:
        """Return the largest key, or None when empty."""
        raise NotImplementedError()

    @abstractmethod
    def get_min(self) -> Optional[T]:
        """Return the smallest key, or None when empty."""
        raise NotImplementedError()

    @abstractmethod
    def pop_max(self) -> T:
        """Remove one copy of the largest key and return it."""
        raise NotImplementedError()

    @abstractmethod
    def pop_min(self) -> T:
        """Remove one copy of the smallest key and return it."""
        raise NotImplementedError()

    # ---- whole-container operations --------------------------------------

    @abstractmethod
    def clear(self) -> None:
        """Remove every element."""
        raise NotImplementedError()

    @abstractmethod
    def tolist(self) -> list[T]:
        """Return all elements in ascending order, repetitions included."""
        raise NotImplementedError()

    @abstractmethod
    def __iter__(self) -> Iterator:
        """Start an ascending iteration over the elements."""
        raise NotImplementedError()

    @abstractmethod
    def __next__(self) -> T:
        """Return the next element of the current iteration."""
        raise NotImplementedError()

    @abstractmethod
    def __contains__(self, key: T) -> bool:
        """Return whether at least one copy of ``key`` is stored."""
        raise NotImplementedError()

    @abstractmethod
    def __len__(self) -> int:
        """Return the number of stored elements, counting duplicates."""
        raise NotImplementedError()

    @abstractmethod
    def __bool__(self) -> bool:
        """Return whether the multiset is non-empty."""
        raise NotImplementedError()

    @abstractmethod
    def __str__(self) -> str:
        """Return a human-readable rendering."""
        raise NotImplementedError()

    @abstractmethod
    def __repr__(self) -> str:
        """Return a debugging representation."""
        raise NotImplementedError()
110# from titan_pylib.data_structures.bst_base.bst_multiset_node_base import (
111# BSTMultisetNodeBase,
112# )
113from typing import TypeVar, Generic, Optional
114
T = TypeVar("T")
Node = TypeVar("Node")
# TODO: describe the required node attributes (key, val, left, right, and
# valsize for the index queries) with a typing.Protocol.


class BSTMultisetNodeBase(Generic[T, Node]):
    """Stateless helpers shared by binary-search-tree multisets.

    Every helper receives the root ``node`` of a tree whose nodes expose
    ``key``, ``val`` (multiplicity), ``left`` and ``right``; ``index`` and
    ``index_right`` additionally read ``valsize`` (total multiplicity of a
    subtree). An empty subtree is tested by truthiness everywhere, so it may
    be ``None`` or any falsy sentinel object.
    """

    @staticmethod
    def count(node: Node, key: T) -> int:
        """Return the multiplicity of ``key`` (0 when absent)."""
        while node:
            if key == node.key:
                return node.val
            node = node.left if key < node.key else node.right
        return 0

    @staticmethod
    def get_min(node: Node) -> Optional[T]:
        """Return the smallest key, or None for an empty tree."""
        if not node:
            return None
        while node.left:
            node = node.left
        return node.key

    @staticmethod
    def get_max(node: Node) -> Optional[T]:
        """Return the largest key, or None for an empty tree."""
        if not node:
            return None
        while node.right:
            node = node.right
        return node.key

    @staticmethod
    def contains(node: Node, key: T) -> bool:
        """Return whether ``key`` occurs in the tree."""
        while node:
            if key == node.key:
                return True
            node = node.left if key < node.key else node.right
        return False

    @staticmethod
    def tolist(node: Node) -> list[T]:
        """Return all keys in ascending order, each repeated ``val`` times."""
        stack = []
        a = []
        while stack or node:
            if node:
                stack.append(node)
                node = node.left
            else:
                node = stack.pop()
                a.extend([node.key] * node.val)
                node = node.right
        return a

    @staticmethod
    def tolist_items(node: Node) -> list[tuple[T, int]]:
        """Return ``(key, multiplicity)`` pairs in ascending key order."""
        stack = []
        a = []
        while stack or node:
            if node:
                stack.append(node)
                node = node.left
            else:
                node = stack.pop()
                a.append((node.key, node.val))
                node = node.right
        return a

    @staticmethod
    def _rle(a: list[T]) -> tuple[list[T], list[int]]:
        """Run-length encode ``a``: distinct consecutive keys and their run lengths.

        An empty input yields two empty lists.
        """
        keys: list[T] = []
        vals: list[int] = []
        for elm in a:
            if keys and elm == keys[-1]:
                vals[-1] += 1
            else:
                keys.append(elm)
                vals.append(1)
        return keys, vals

    @staticmethod
    def le(node: Node, key: T) -> Optional[T]:
        """Return the largest key ``<= key``, or None."""
        res = None
        while node:
            if key == node.key:
                # Exact hit: the query key itself is returned.
                res = key
                break
            if key < node.key:
                node = node.left
            else:
                res = node.key
                node = node.right
        return res

    @staticmethod
    def lt(node: Node, key: T) -> Optional[T]:
        """Return the largest key ``< key``, or None."""
        res = None
        while node:
            if key <= node.key:
                node = node.left
            else:
                res = node.key
                node = node.right
        return res

    @staticmethod
    def ge(node: Node, key: T) -> Optional[T]:
        """Return the smallest key ``>= key``, or None."""
        res = None
        while node:
            if key == node.key:
                # Exact hit: the query key itself is returned.
                res = key
                break
            if key < node.key:
                res = node.key
                node = node.left
            else:
                node = node.right
        return res

    @staticmethod
    def gt(node: Node, key: T) -> Optional[T]:
        """Return the smallest key ``> key``, or None."""
        res = None
        while node:
            if key < node.key:
                res = node.key
                node = node.left
            else:
                node = node.right
        return res

    @staticmethod
    def index(node: Node, key: T) -> int:
        """Return how many stored elements (with multiplicity) are ``< key``."""
        k = 0
        while node:
            if key == node.key:
                if node.left:
                    k += node.left.valsize
                break
            if key < node.key:
                node = node.left
            else:
                # The whole left subtree and this node are smaller than key.
                k += node.val + (node.left.valsize if node.left else 0)
                node = node.right
        return k

    @staticmethod
    def index_right(node: Node, key: T) -> int:
        """Return how many stored elements (with multiplicity) are ``<= key``."""
        k = 0
        while node:
            if key == node.key:
                k += node.val + (node.left.valsize if node.left else 0)
                break
            if key < node.key:
                node = node.left
            else:
                k += node.val + (node.left.valsize if node.left else 0)
                node = node.right
        return k
276from typing import Iterable, Optional, TypeVar, Generic
277
T = TypeVar("T", bound=SupportsLessThan)


class RedBlackTreeMultiset(OrderedMultisetInterface, Generic[T]):
    """Ordered multiset implemented as a red-black tree.

    Each distinct key occupies a single node that also stores its
    multiplicity (``cnt``); ``size`` is the total multiplicity. Missing
    children and the root's parent point at the shared, falsy sentinel
    ``_NIL``. The nodes holding the smallest / largest key are cached in
    ``min_node`` / ``max_node`` (``None`` while the multiset is empty).
    Node colors: ``col == 1`` is red, ``col == 0`` is black.
    """

    class Node:
        """Tree node holding one distinct key and its multiplicity."""

        def __init__(self, key: T, cnt: int = 1):
            self.key: T = key
            self.left = RedBlackTreeMultiset._NIL
            self.right = RedBlackTreeMultiset._NIL
            self.par = RedBlackTreeMultiset._NIL
            self.col: int = 0  # 0: black, 1: red
            self.cnt: int = cnt

        @property
        def count(self) -> int:
            """Multiplicity of ``key``."""
            return self.cnt

        def _min(self) -> "RedBlackTreeMultiset.Node":
            """Return the leftmost node of the subtree rooted here."""
            now = self
            while now.left:
                now = now.left
            return now

        def _max(self) -> "RedBlackTreeMultiset.Node":
            """Return the rightmost node of the subtree rooted here."""
            now = self
            while now.right:
                now = now.right
            return now

        def _next(self) -> Optional["RedBlackTreeMultiset.Node"]:
            """Return the in-order successor, or None for the last node."""
            now = self
            pre = RedBlackTreeMultiset._NIL
            flag = now.right is pre  # no right subtree: successor is an ancestor
            # Climb while we are arriving from a right child.
            while now.right is pre:
                pre, now = now, now.par
            if not now:
                return None
            return now if flag and pre is now.left else now.right._min()

        def _prev(self) -> Optional["RedBlackTreeMultiset.Node"]:
            """Return the in-order predecessor, or None for the first node."""
            now, pre = self, RedBlackTreeMultiset._NIL
            flag = now.left is pre  # no left subtree: predecessor is an ancestor
            # Climb while we are arriving from a left child.
            while now.left is pre:
                pre, now = now, now.par
            if not now:
                return None
            return now if flag and pre is now.right else now.left._max()

        def __add__(self, other: int) -> Optional["RedBlackTreeMultiset.Node"]:
            """Return the node ``other`` steps later in key order.

            Returns None once the last node is passed; a non-positive
            ``other`` returns this node unchanged.
            """
            res = self
            for _ in range(other):
                res = res._next()
                if res is None:
                    break
            return res

        def __sub__(self, other: int) -> Optional["RedBlackTreeMultiset.Node"]:
            """Return the node ``other`` steps earlier in key order.

            Returns None once the first node is passed; a non-positive
            ``other`` returns this node unchanged.
            """
            res = self
            for _ in range(other):
                res = res._prev()
                if res is None:
                    break
            return res

        __iadd__ = __add__
        __isub__ = __sub__

        def __str__(self):
            if (not self.left) and (not self.right):
                return f"(key,col,par.key):{self.key, self.col, self.par.key}\n"
            return f"(key,col,par.key):{self.key, self.col, self.par.key},\n left:{self.left},\n right:{self.right}\n"

    class _NILNode:
        """Falsy black sentinel used for every missing child / parent link."""

        key = None
        left = None
        right = None
        par = None
        col = 0
        cnt = 0

        def _min(self):
            return None

        def _max(self):
            return None

        def __bool__(self):
            return False

        def __str__(self):
            return "_NIL"

    _NIL = _NILNode()

    def __init__(self, a: Iterable[T] = ()):
        """Create the multiset from the elements of ``a`` (duplicates allowed)."""
        self.node = RedBlackTreeMultiset._NIL
        self.size = 0
        self.min_node = None
        self.max_node = None
        if not isinstance(a, list):
            a = list(a)
        if a:
            self._build(a)

    def _build(self, a: list[T]) -> None:
        """Build a height-balanced tree from the non-empty list ``a``.

        ``a`` is sorted if needed, run-length encoded into distinct keys with
        counts, and the tree is formed by recursive midpoint splitting.
        """

        def sort(l: int, r: int, d: int):
            mid = (l + r) >> 1
            node = Node(x[mid], y[mid])
            # Whole depth levels share a color so every root-to-NIL path has
            # the same number of black nodes: when the level count is even,
            # odd depths are red; when it is odd, even depths >= 2 are red.
            node.col = int((not flag and d & 1) or (flag and d > 1 and not d & 1))
            if l != mid:
                node.left = sort(l, mid, d + 1)
                node.left.par = node
            if mid + 1 != r:
                node.right = sort(mid + 1, r, d + 1)
                node.right.par = node
            return node

        if not all(a[i] <= a[i + 1] for i in range(len(a) - 1)):
            a = sorted(a)
        Node = RedBlackTreeMultiset.Node
        x, y = BSTMultisetNodeBase[T, RedBlackTreeMultiset.Node]._rle(a)
        # Parity of len(x).bit_length(), the level count of the midpoint tree.
        flag = len(x).bit_length() & 1
        self.node = sort(0, len(x), 0)
        self.min_node = self.node._min()
        self.max_node = self.node._max()
        self.size = len(a)

    def _rotate_left(self, node: Node) -> None:
        """Rotate left around ``node`` (its right child takes its place)."""
        u = node.right
        p = node.par
        node.right = u.left
        if u.left:
            u.left.par = node
        u.par = p
        if not p:
            self.node = u
        elif node is p.left:
            p.left = u
        else:
            p.right = u
        u.left = node
        node.par = u

    def _rotate_right(self, node: Node) -> None:
        """Rotate right around ``node`` (its left child takes its place)."""
        u = node.left
        p = node.par
        node.left = u.right
        if u.right:
            u.right.par = node
        u.par = p
        if not p:
            self.node = u
        elif node is p.right:
            p.right = u
        else:
            p.left = u
        u.right = node
        node.par = u

    def _transplant(self, u: Node, v: Node) -> None:
        """Put the subtree ``v`` where ``u`` hangs.

        ``v.par`` is updated even when ``v`` is ``_NIL``; the delete fix-up
        reads that parent link.
        """
        if not u.par:
            self.node = v
        elif u is u.par.left:
            u.par.left = v
        else:
            u.par.right = v
        v.par = u.par

    def _get_min(self, node: Node) -> Node:
        """Return the leftmost node of the subtree rooted at ``node``."""
        while node.left:
            node = node.left
        return node

    def _get_max(self, node: Node) -> Node:
        """Return the rightmost node of the subtree rooted at ``node``."""
        while node.right:
            node = node.right
        return node

    def _search(self, key: T) -> Optional[Node]:
        """Return the node whose key equals ``key``, or None if absent."""
        node = self.node
        while node:
            if key == node.key:
                return node
            node = node.left if key < node.key else node.right
        return None

    def add(self, key: T, cnt: int = 1) -> None:
        """Insert ``cnt`` copies of ``key``; a non-positive ``cnt`` is a no-op."""
        if cnt <= 0:
            return
        self.size += cnt
        if not self.node:
            node = RedBlackTreeMultiset.Node(key, cnt)
            self.node = node
            self.min_node = node
            self.max_node = node
            return
        pnode = RedBlackTreeMultiset._NIL
        node = self.node
        while node:
            pnode = node
            if key == node.key:
                # Existing key: only its multiplicity changes.
                node.cnt += cnt
                return
            node = node.left if key < node.key else node.right
        z = RedBlackTreeMultiset.Node(key, cnt)
        if key < self.min_node.key:
            self.min_node = z
        if key > self.max_node.key:
            self.max_node = z
        z.par = pnode
        if not pnode:
            self.node = z
        elif key < pnode.key:
            pnode.left = z
        else:
            pnode.right = z
        # The new node is red; repair any red-red violation up the tree.
        z.col = 1
        while z.par.col == 1:
            g = z.par.par
            if z.par is g.left:
                y = g.right
                if y.col:
                    z.par.col = 0
                    y.col = 0
                    g.col = 1
                    z = g
                else:
                    if z is z.par.right:
                        z = z.par
                        self._rotate_left(z)
                    z.par.col = 0
                    g.col = 1
                    self._rotate_right(g)
                    break
            else:
                y = g.left
                if y.col:
                    z.par.col = 0
                    y.col = 0
                    g.col = 1
                    z = g
                else:
                    if z is z.par.left:
                        z = z.par
                        self._rotate_right(z)
                    z.par.col = 0
                    g.col = 1
                    self._rotate_left(g)
                    break
        self.node.col = 0

    def discard_iter(self, node: Node) -> None:
        """Unlink ``node`` from the tree, removing every copy of its key."""
        assert isinstance(node, RedBlackTreeMultiset.Node)
        self.size -= node.cnt
        # Refresh the cached extremes first; the neighbour node objects found
        # here stay in the tree after the unlink below.
        if node.key == self.min_node.key:
            self.min_node = node._next()
        if node.key == self.max_node.key:
            self.max_node = node._prev()
        y = node
        y_col = y.col
        if not node.left:
            x = node.right
            self._transplant(node, node.right)
        elif not node.right:
            x = node.left
            self._transplant(node, node.left)
        else:
            # Two children: the in-order successor takes node's place.
            y = self._get_min(node.right)
            y_col = y.col
            x = y.right
            if y.par is node:
                x.par = y
            else:
                self._transplant(y, y.right)
                y.right = node.right
                y.right.par = y
            self._transplant(node, y)
            y.left = node.left
            y.left.par = y
            y.col = node.col
        if y_col:
            # A red node was removed: black heights are unchanged.
            return
        # Push the extra black carried by x upward until it can be absorbed.
        while x is not self.node and not x.col:
            if x is x.par.left:
                y = x.par
                w = y.right
                if w.col:
                    w.col = 0
                    y.col = 1
                    self._rotate_left(y)
                    w = y.right
                if not (w.left.col or w.right.col):
                    w.col = 1
                    x = y
                else:
                    if not w.right.col:
                        w.left.col = 0
                        w.col = 1
                        self._rotate_right(w)
                        w = y.right
                    w.col = y.col
                    y.col = 0
                    w.right.col = 0
                    self._rotate_left(x.par)
                    x = self.node
            else:
                y = x.par
                w = y.left
                if w.col:
                    w.col = 0
                    y.col = 1
                    self._rotate_right(y)
                    w = y.left
                if not (w.right.col or w.left.col):
                    w.col = 1
                    x = y
                else:
                    if not w.left.col:
                        w.right.col = 0
                        w.col = 1
                        self._rotate_left(w)
                        w = y.left
                    w.col = y.col
                    y.col = 0
                    w.left.col = 0
                    self._rotate_right(y)
                    x = self.node
        x.col = 0

    def discard(self, key: T, cnt: int = 1) -> bool:
        """Remove up to ``cnt`` copies of ``key``; return False if it is absent."""
        assert cnt >= 0
        node = self._search(key)
        if node is None:
            return False
        if node.cnt > cnt:
            node.cnt -= cnt
            self.size -= cnt
            return True
        self.discard_iter(node)
        return True

    def discard_all(self, key: T) -> bool:
        """Remove every copy of ``key``; return False if it is absent."""
        node = self._search(key)
        if node is None:
            return False
        self.discard_iter(node)
        return True

    def remove(self, key: T, cnt: int = 1) -> None:
        """Remove ``cnt`` copies of ``key``.

        Raises:
            KeyError: fewer than ``cnt`` copies are stored.
        """
        if self.count(key) < cnt:
            raise KeyError(key)
        self.discard(key, cnt)

    def count(self, key: T) -> int:
        """Return the multiplicity of ``key`` (0 when absent)."""
        node = self._search(key)
        return 0 if node is None else node.cnt

    def get_max(self) -> Optional[T]:
        """Return the largest key, or None when empty."""
        if self.max_node is None:
            return None
        return self.max_node.key

    def get_min(self) -> Optional[T]:
        """Return the smallest key, or None when empty."""
        if self.min_node is None:
            return None
        return self.min_node.key

    def get_max_iter(self) -> Optional[Node]:
        """Return the node holding the largest key, or None when empty."""
        return self.max_node

    def get_min_iter(self) -> Optional[Node]:
        """Return the node holding the smallest key, or None when empty."""
        return self.min_node

    def le(self, key: T) -> Optional[T]:
        """Return the largest stored key ``<= key``, or None."""
        node = self.le_iter(key)
        return None if node is None else node.key

    def lt(self, key: T) -> Optional[T]:
        """Return the largest stored key ``< key``, or None."""
        node = self.lt_iter(key)
        return None if node is None else node.key

    def ge(self, key: T) -> Optional[T]:
        """Return the smallest stored key ``>= key``, or None."""
        node = self.ge_iter(key)
        return None if node is None else node.key

    def gt(self, key: T) -> Optional[T]:
        """Return the smallest stored key ``> key``, or None."""
        node = self.gt_iter(key)
        return None if node is None else node.key

    def le_iter(self, key: T) -> Optional[Node]:
        """Return the node with the largest key ``<= key``, or None."""
        res, node = None, self.node
        while node:
            if key == node.key:
                res = node
                break
            elif key < node.key:
                node = node.left
            else:
                res = node
                node = node.right
        return res

    def lt_iter(self, key: T) -> Optional[Node]:
        """Return the node with the largest key ``< key``, or None."""
        res, node = None, self.node
        while node:
            if key <= node.key:
                node = node.left
            else:
                res = node
                node = node.right
        return res

    def ge_iter(self, key: T) -> Optional[Node]:
        """Return the node with the smallest key ``>= key``, or None."""
        res, node = None, self.node
        while node:
            if key == node.key:
                res = node
                break
            elif key < node.key:
                res = node
                node = node.left
            else:
                node = node.right
        return res

    def gt_iter(self, key: T) -> Optional[Node]:
        """Return the node with the smallest key ``> key``, or None."""
        res, node = None, self.node
        while node:
            if key < node.key:
                res = node
                node = node.left
            else:
                node = node.right
        return res

    def find(self, key: T) -> Optional[Node]:
        """Return the node holding ``key``, or None if absent."""
        return self._search(key)

    def tolist(self) -> list[T]:
        """Return all elements in ascending order, repetitions included."""
        node = self.node
        stack = []
        res = []
        while stack or node:
            if node:
                stack.append(node)
                node = node.left
            else:
                node = stack.pop()
                for _ in range(node.cnt):
                    res.append(node.key)
                node = node.right
        return res

    def pop_max(self) -> T:
        """Remove one copy of the largest key and return it.

        Raises:
            IndexError: the multiset is empty.
        """
        if not self.node:
            raise IndexError("pop_max() from empty RedBlackTreeMultiset")
        node = self.max_node
        if node.cnt == 1:
            self.discard_iter(node)
        else:
            node.cnt -= 1
            self.size -= 1
        return node.key

    def pop_min(self) -> T:
        """Remove one copy of the smallest key and return it.

        Raises:
            IndexError: the multiset is empty.
        """
        if not self.node:
            raise IndexError("pop_min() from empty RedBlackTreeMultiset")
        node = self.min_node
        if node.cnt == 1:
            self.discard_iter(node)
        else:
            node.cnt -= 1
            self.size -= 1
        return node.key

    def clear(self) -> None:
        """Remove every element."""
        self.node = RedBlackTreeMultiset._NIL
        self.size = 0
        self.min_node = None
        self.max_node = None

    def __iter__(self):
        # NOTE: the cursor lives on the instance itself, so nested loops over
        # the same multiset share (and disturb) one iteration state.
        self.it = self.min_node
        self.cnt = 0
        return self

    def __next__(self):
        # Empty multiset, or already exhausted: keep signalling the end.
        if self.it is None:
            raise StopIteration
        if self.cnt >= self.it.cnt:
            # All copies of the current key were produced; advance one node.
            self.it += 1
            self.cnt = 0
            if self.it is None:
                raise StopIteration
        self.cnt += 1
        return self.it.key

    def __bool__(self):
        return self.node is not RedBlackTreeMultiset._NIL

    def __contains__(self, key: T):
        return self._search(key) is not None

    def __len__(self):
        return self.size

    def __str__(self):
        return "{" + ", ".join(map(str, self.tolist())) + "}"

    def __repr__(self):
        return "RedBlackTreeMultiset(" + str(self.tolist()) + ")"
仕様
- class RedBlackTreeMultiset(a: Iterable[T] = []) [source]
  Bases: OrderedMultisetInterface, Generic[T]