scapegoat_tree_multiset¶
ソースコード¶
from titan_pylib.data_structures.scapegoat_tree.scapegoat_tree_multiset import ScapegoatTreeMultiset
展開済みコード¶
1# from titan_pylib.data_structures.scapegoat_tree.scapegoat_tree_multiset import ScapegoatTreeMultiset
2# from titan_pylib.my_class.ordered_multiset_interface import OrderedMultisetInterface
3# from titan_pylib.my_class.supports_less_than import SupportsLessThan
4from typing import Protocol
5
6
class SupportsLessThan(Protocol):
    """Structural type for any value that implements the ``<`` operator."""

    def __lt__(self, other) -> bool:
        """Return whether ``self`` orders strictly before ``other``."""
        ...
10from abc import ABC, abstractmethod
11from typing import Iterable, Optional, Iterator, TypeVar, Generic
12
# Element type: anything that can be compared with ``<``.
T = TypeVar("T", bound=SupportsLessThan)


class OrderedMultisetInterface(ABC, Generic[T]):
    """Abstract interface shared by the ordered multiset implementations.

    Every method is abstract; the default bodies only raise
    ``NotImplementedError`` so that an accidental ``super()`` call fails loudly.
    """

    @abstractmethod
    def __init__(self, a: Iterable[T]) -> None:
        """Construct the multiset from the elements of ``a``."""
        raise NotImplementedError

    @abstractmethod
    def add(self, key: T, cnt: int) -> None:
        """Insert ``key`` with multiplicity ``cnt``."""
        raise NotImplementedError

    @abstractmethod
    def discard(self, key: T, cnt: int) -> bool:
        """Remove up to ``cnt`` occurrences of ``key``; the result reports success."""
        raise NotImplementedError

    @abstractmethod
    def discard_all(self, key: T) -> bool:
        """Remove every occurrence of ``key``; the result reports success."""
        raise NotImplementedError

    @abstractmethod
    def count(self, key: T) -> int:
        """Return the multiplicity of ``key``."""
        raise NotImplementedError

    @abstractmethod
    def remove(self, key: T, cnt: int) -> None:
        """Remove ``cnt`` occurrences of ``key`` (strict counterpart of ``discard``)."""
        raise NotImplementedError

    @abstractmethod
    def le(self, key: T) -> Optional[T]:
        """Return the largest stored key ``<= key``, or ``None``."""
        raise NotImplementedError

    @abstractmethod
    def lt(self, key: T) -> Optional[T]:
        """Return the largest stored key ``< key``, or ``None``."""
        raise NotImplementedError

    @abstractmethod
    def ge(self, key: T) -> Optional[T]:
        """Return the smallest stored key ``>= key``, or ``None``."""
        raise NotImplementedError

    @abstractmethod
    def gt(self, key: T) -> Optional[T]:
        """Return the smallest stored key ``> key``, or ``None``."""
        raise NotImplementedError

    @abstractmethod
    def get_max(self) -> Optional[T]:
        """Return the largest stored key."""
        raise NotImplementedError

    @abstractmethod
    def get_min(self) -> Optional[T]:
        """Return the smallest stored key."""
        raise NotImplementedError

    @abstractmethod
    def pop_max(self) -> T:
        """Remove and return one occurrence of the largest key."""
        raise NotImplementedError

    @abstractmethod
    def pop_min(self) -> T:
        """Remove and return one occurrence of the smallest key."""
        raise NotImplementedError

    @abstractmethod
    def clear(self) -> None:
        """Remove every element."""
        raise NotImplementedError

    @abstractmethod
    def tolist(self) -> list[T]:
        """Return the elements as a list."""
        raise NotImplementedError

    @abstractmethod
    def __iter__(self) -> Iterator:
        """Start an iteration over the elements."""
        raise NotImplementedError

    @abstractmethod
    def __next__(self) -> T:
        """Return the next element of the running iteration."""
        raise NotImplementedError

    @abstractmethod
    def __contains__(self, key: T) -> bool:
        """Return whether ``key`` is stored."""
        raise NotImplementedError

    @abstractmethod
    def __len__(self) -> int:
        """Return the number of stored elements."""
        raise NotImplementedError

    @abstractmethod
    def __bool__(self) -> bool:
        """Return the truth value of the container."""
        raise NotImplementedError

    @abstractmethod
    def __str__(self) -> str:
        """Return a human-readable rendering."""
        raise NotImplementedError

    @abstractmethod
    def __repr__(self) -> str:
        """Return a debugging representation."""
        raise NotImplementedError
109# from titan_pylib.my_class.supports_less_than import SupportsLessThan
110# from titan_pylib.data_structures.bst_base.bst_multiset_node_base import (
111# BSTMultisetNodeBase,
112# )
113from typing import TypeVar, Generic, Optional
114
T = TypeVar("T")
Node = TypeVar("Node")
# TODO: describe the node shape (key, val, left, right) with a Protocol.


class BSTMultisetNodeBase(Generic[T, Node]):
    """Stateless helpers shared by binary-search-tree multisets.

    Every helper takes the root node of a (sub)tree. Nodes are accessed through
    the attributes ``key``, ``val`` (multiplicity of ``key``), ``left`` and
    ``right``; ``index`` / ``index_right`` additionally read ``valsize`` from
    left children.
    """

    @staticmethod
    def count(node, key: T) -> int:
        """Return the ``val`` of the node whose key equals ``key``, or 0."""
        while node is not None:
            if node.key == key:
                return node.val
            node = node.left if key < node.key else node.right
        return 0

    @staticmethod
    def get_min(node: Node) -> Optional[T]:
        """Return the leftmost key of the tree, or ``None`` for an empty tree."""
        if not node:
            return None
        while node.left:
            node = node.left
        return node.key

    @staticmethod
    def get_max(node: Node) -> Optional[T]:
        """Return the rightmost key of the tree, or ``None`` for an empty tree."""
        if not node:
            return None
        while node.right:
            node = node.right
        return node.key

    @staticmethod
    def contains(node: Node, key: T) -> bool:
        """Return whether some node of the tree has a key equal to ``key``."""
        while node:
            if key == node.key:
                return True
            node = node.left if key < node.key else node.right
        return False

    @staticmethod
    def tolist(node: Node) -> list[T]:
        """Return the keys in in-order, each repeated ``val`` times."""
        stack = []
        a = []
        # Iterative in-order traversal with an explicit stack.
        while stack or node:
            if node:
                stack.append(node)
                node = node.left
            else:
                node = stack.pop()
                a.extend([node.key] * node.val)
                node = node.right
        return a

    @staticmethod
    def tolist_items(node: Node) -> list[tuple[T, int]]:
        """Return ``(key, val)`` pairs in in-order."""
        stack = []
        a = []
        while stack or node:
            if node:
                stack.append(node)
                node = node.left
            else:
                node = stack.pop()
                a.append((node.key, node.val))
                node = node.right
        return a

    @staticmethod
    def _rle(a: list[T]) -> tuple[list[T], list[int]]:
        """Run-length encode ``a`` into parallel lists of keys and run lengths.

        Consecutive equal elements are merged into one run. An empty input
        yields ``([], [])`` (the previous implementation raised ``IndexError``).
        """
        keys, vals = [], []
        for elm in a:
            if keys and elm == keys[-1]:
                vals[-1] += 1
            else:
                keys.append(elm)
                vals.append(1)
        return keys, vals

    @staticmethod
    def le(node: Node, key: T) -> Optional[T]:
        """Return the largest key ``<= key``, or ``None`` if there is none.

        On an exact match the argument ``key`` itself is returned.
        """
        res = None
        while node is not None:
            if key == node.key:
                res = key
                break
            if key < node.key:
                node = node.left
            else:
                # node.key < key: remember it and look for a closer one.
                res = node.key
                node = node.right
        return res

    @staticmethod
    def lt(node: Node, key: T) -> Optional[T]:
        """Return the largest key ``< key``, or ``None`` if there is none."""
        res = None
        while node is not None:
            if key <= node.key:
                node = node.left
            else:
                res = node.key
                node = node.right
        return res

    @staticmethod
    def ge(node: Node, key: T) -> Optional[T]:
        """Return the smallest key ``>= key``, or ``None`` if there is none.

        On an exact match the argument ``key`` itself is returned.
        """
        res = None
        while node is not None:
            if key == node.key:
                res = key
                break
            if key < node.key:
                # node.key > key: remember it and look for a closer one.
                res = node.key
                node = node.left
            else:
                node = node.right
        return res

    @staticmethod
    def gt(node: Node, key: T) -> Optional[T]:
        """Return the smallest key ``> key``, or ``None`` if there is none."""
        res = None
        while node is not None:
            if key < node.key:
                res = node.key
                node = node.left
            else:
                node = node.right
        return res

    @staticmethod
    def index(node: Node, key: T) -> int:
        """Return the number of elements (with multiplicity) smaller than ``key``."""
        k = 0
        while node:
            if key == node.key:
                if node.left:
                    k += node.left.valsize
                break
            if key < node.key:
                node = node.left
            else:
                # Skip this node and its whole left subtree.
                k += node.val if node.left is None else node.left.valsize + node.val
                node = node.right
        return k

    @staticmethod
    def index_right(node: Node, key: T) -> int:
        """Return the number of elements (with multiplicity) ``<= key``."""
        k = 0
        while node:
            if key == node.key:
                k += node.val if node.left is None else node.left.valsize + node.val
                break
            if key < node.key:
                node = node.left
            else:
                k += node.val if node.left is None else node.left.valsize + node.val
                node = node.right
        return k
276import math
277from typing import Final, TypeVar, Generic, Iterable, Optional, Iterator
278
T = TypeVar("T", bound=SupportsLessThan)


class ScapegoatTreeMultiset(OrderedMultisetInterface, Generic[T]):
    """Ordered multiset stored in a scapegoat tree.

    Each tree node holds one distinct key together with its multiplicity
    (``val``). Two subtree aggregates are maintained per node: ``size`` (number
    of distinct keys) and ``valsize`` (number of elements with multiplicity).
    ``len(self)`` counts elements with multiplicity; ``len_elm()`` counts
    distinct keys.
    """

    ALPHA: Final[float] = 0.75
    BETA: Final[float] = math.log2(1 / ALPHA)

    class Node:
        """Tree node: a key, its multiplicity and the subtree aggregates."""

        def __init__(self, key: T, val: int):
            self.key: T = key
            self.val: int = val  # multiplicity of ``key``
            self.size: int = 1  # distinct keys in this subtree
            self.valsize: int = val  # elements (with multiplicity) in this subtree
            self.left: Optional[ScapegoatTreeMultiset.Node] = None
            self.right: Optional[ScapegoatTreeMultiset.Node] = None

        def __str__(self):
            if self.left is None and self.right is None:
                return f"key:{self.key, self.val, self.size, self.valsize}\n"
            return f"key:{self.key, self.val, self.size, self.valsize},\n left:{self.left},\n right:{self.right}\n"

    def __init__(self, a: Iterable[T] = []):
        """Build the multiset from the elements of ``a``.

        The default list is never mutated: ``_build`` only reads it or rebinds
        a sorted copy.
        """
        self.root = None
        if not isinstance(a, list):
            a = list(a)
        self._build(a)

    def _build(self, a: list[T]) -> None:
        """Replace the tree with a balanced one holding the elements of ``a``.

        ``a`` is sorted first unless it is already non-decreasing; an empty
        ``a`` leaves ``self.root`` untouched.
        """
        Node = ScapegoatTreeMultiset.Node

        def rec(l: int, r: int) -> ScapegoatTreeMultiset.Node:
            # Build the subtree for the half-open run range [l, r).
            mid = (l + r) >> 1
            node = Node(x[mid], y[mid])
            if l != mid:
                node.left = rec(l, mid)
                node.size += node.left.size
                node.valsize += node.left.valsize
            if mid + 1 != r:
                node.right = rec(mid + 1, r)
                node.size += node.right.size
                node.valsize += node.right.valsize
            return node

        if not all(a[i] <= a[i + 1] for i in range(len(a) - 1)):
            a = sorted(a)
        if not a:
            return
        # x: distinct keys in order, y: their multiplicities.
        x, y = BSTMultisetNodeBase._rle(a)
        self.root = rec(0, len(x))

    def _rebuild(self, node: Node) -> Node:
        """Rebalance the subtree rooted at ``node`` and return its new root.

        The existing node objects are reused; their links and aggregates are
        recomputed.
        """

        def rec(l: int, r: int) -> ScapegoatTreeMultiset.Node:
            mid = (l + r) >> 1
            node = a[mid]
            node.size = 1
            node.valsize = node.val
            if l != mid:
                node.left = rec(l, mid)
                node.size += node.left.size
                node.valsize += node.left.valsize
            else:
                node.left = None
            if mid + 1 != r:
                node.right = rec(mid + 1, r)
                node.size += node.right.size
                node.valsize += node.right.valsize
            else:
                node.right = None
            return node

        # Collect the subtree's nodes in in-order.
        a = []
        stack = []
        while stack or node:
            if node:
                stack.append(node)
                node = node.left
            else:
                node = stack.pop()
                a.append(node)
                node = node.right
        return rec(0, len(a))

    def _kth_elm(self, k: int) -> tuple[T, int]:
        """Return ``(key, val)`` of the node containing the ``k``-th element.

        ``k`` indexes elements with multiplicity; a negative ``k`` counts from
        the end. Falls through (returning ``None``) when ``k`` is out of range.
        """
        if k < 0:
            k += len(self)
        node = self.root
        while node:
            # t = number of elements in the left subtree plus this node's copies.
            t = (node.val + node.left.valsize) if node.left else node.val
            if t - node.val <= k and k < t:
                return node.key, node.val
            elif t > k:
                node = node.left
            else:
                node = node.right
                k -= t

    def _kth_elm_tree(self, k: int) -> tuple[T, int]:
        """Return ``(key, val)`` of the ``k``-th distinct key.

        A negative ``k`` counts from the end. An out-of-range ``k`` trips the
        ``IndexError`` assertion.
        """
        if k < 0:
            k += self.len_elm()
        node = self.root
        while node:
            t = node.left.size if node.left else 0
            if t == k:
                return node.key, node.val
            if t > k:
                node = node.left
            else:
                node = node.right
                k -= t + 1
        assert False, "IndexError"

    def add(self, key: T, val: int = 1) -> None:
        """Insert ``val`` copies of ``key``; a non-positive ``val`` is a no-op."""
        if val <= 0:
            return
        if not self.root:
            self.root = ScapegoatTreeMultiset.Node(key, val)
            return
        node = self.root
        path = []
        while node:
            path.append(node)
            if key == node.key:
                # Existing key: only multiplicities change, no rebalancing.
                node.val += val
                for p in path:
                    p.valsize += val
                return
            node = node.left if key < node.key else node.right
        if key < path[-1].key:
            path[-1].left = ScapegoatTreeMultiset.Node(key, val)
        else:
            path[-1].right = ScapegoatTreeMultiset.Node(key, val)
        # NOTE(review): BETA is defined with log2 while this uses the natural
        # log, so rebuilds trigger earlier than the textbook depth bound.
        # Results are unaffected; kept as-is.
        # ``len_elm()`` is still the key count before this insertion.
        if len(path) * ScapegoatTreeMultiset.BETA > math.log(self.len_elm()):
            node_size = 1
            # Walk up until an ancestor whose child on the path is too heavy
            # (the scapegoat); ``+ 1`` accounts for the not-yet-counted new node.
            while path:
                pnode = path.pop()
                pnode_size = pnode.size + 1
                if ScapegoatTreeMultiset.ALPHA * pnode_size < node_size:
                    break
                node_size = pnode_size
            new_node = self._rebuild(pnode)
            if not path:
                self.root = new_node
                return
            if new_node.key < path[-1].key:
                path[-1].left = new_node
            else:
                path[-1].right = new_node
        # Remaining ancestors (all of them if nothing was rebuilt) gain the node.
        for p in path:
            p.size += 1
            p.valsize += val

    def _discard(self, key: T) -> bool:
        """Unlink the node holding ``key`` and fix the ancestors' aggregates.

        Precondition (not checked): ``key`` is present and its ``val`` is 1.
        """
        path = []
        node = self.root
        di, cnt = 1, 0
        while node:
            if key == node.key:
                break
            path.append(node)
            di = key < node.key
            node = node.left if di else node.right
        if node.left and node.right:
            # Two children: move the in-order predecessor into this node and
            # unlink the predecessor instead.
            path.append(node)
            lmax = node.left
            di = 0 if lmax.right else 1
            while lmax.right:
                cnt += 1
                path.append(lmax)
                lmax = lmax.right
            lmax_val = lmax.val
            node.key = lmax.key
            node.val = lmax_val
            node = lmax
        cnode = node.left if node.left else node.right
        if path:
            if di == 1:
                path[-1].left = cnode
            else:
                path[-1].right = cnode
        else:
            self.root = cnode
            return True
        # Nodes between the target and the predecessor lost the predecessor.
        for _ in range(cnt):
            p = path.pop()
            p.size -= 1
            p.valsize -= lmax_val
        # Everything above lost one key holding a single copy.
        for p in path:
            p.size -= 1
            p.valsize -= 1
        return True

    def discard(self, key: T, val=1) -> bool:
        """Remove up to ``val`` copies of ``key``.

        Returns ``False`` only when ``key`` is absent (and ``val`` is positive);
        a non-positive ``val`` returns ``True`` without touching the tree. When
        ``val`` is at least the stored multiplicity the node is removed
        entirely (previously ``val == multiplicity >= 2`` left a zero-count
        node behind).
        """
        if val <= 0:
            return True
        path = []
        node = self.root
        while node:
            path.append(node)
            if key == node.key:
                break
            node = node.left if key < node.key else node.right
        else:
            return False
        if val < node.val:
            # Some copies remain: only the counts along the path change.
            node.val -= val
            for p in path:
                p.valsize -= val
            return True
        # Remove the key completely: drop to a single copy, then unlink it.
        surplus = node.val - 1
        if surplus > 0:
            node.val = 1
            for p in path:
                p.valsize -= surplus
        self._discard(key)
        return True

    def remove(self, key: T, val: int = 1) -> None:
        """Remove ``val`` copies of ``key``.

        Raises:
            KeyError: if fewer than ``val`` copies are stored.
        """
        c = self.count(key)
        if c < val:
            raise KeyError(key)
        self.discard(key, val)

    def count(self, key: T) -> int:
        """Return the multiplicity of ``key`` (0 if absent)."""
        node = self.root
        while node:
            if key == node.key:
                return node.val
            node = node.left if key < node.key else node.right
        return 0

    def discard_all(self, key: T) -> bool:
        """Remove every copy of ``key``; see ``discard`` for the return value."""
        return self.discard(key, self.count(key))

    def le(self, key: T) -> Optional[T]:
        """Return the largest stored key ``<= key``, or ``None``."""
        return BSTMultisetNodeBase.le(self.root, key)

    def lt(self, key: T) -> Optional[T]:
        """Return the largest stored key ``< key``, or ``None``."""
        return BSTMultisetNodeBase.lt(self.root, key)

    def ge(self, key: T) -> Optional[T]:
        """Return the smallest stored key ``>= key``, or ``None``."""
        return BSTMultisetNodeBase.ge(self.root, key)

    def gt(self, key: T) -> Optional[T]:
        """Return the smallest stored key ``> key``, or ``None``."""
        return BSTMultisetNodeBase.gt(self.root, key)

    def index(self, key: T) -> int:
        """Return the number of elements (with multiplicity) ``< key``."""
        return BSTMultisetNodeBase.index(self.root, key)

    def index_right(self, key: T) -> int:
        """Return the number of elements (with multiplicity) ``<= key``."""
        return BSTMultisetNodeBase.index_right(self.root, key)

    def index_keys(self, key: T) -> int:
        """Return the number of distinct keys ``< key``."""
        k = 0
        node = self.root
        while node:
            if key == node.key:
                if node.left:
                    k += node.left.size
                break
            if key < node.key:
                node = node.left
            else:
                # Skip this node (one distinct key) and its left subtree.
                k += 1 + (node.left.size if node.left else 0)
                node = node.right
        return k

    def index_right_keys(self, key: T) -> int:
        """Return the number of distinct keys ``<= key``."""
        k = 0
        node = self.root
        while node:
            if key == node.key:
                k += 1 + (node.left.size if node.left else 0)
                break
            if key < node.key:
                node = node.left
            else:
                k += 1 + (node.left.size if node.left else 0)
                node = node.right
        return k

    def pop(self, k: int = -1) -> T:
        """Remove and return the ``k``-th element (with multiplicity).

        A negative ``k`` counts from the end. An out-of-range ``k``, including
        any ``k`` on an empty multiset, trips the assertion in ``__getitem__``.
        """
        if k < 0:
            k += len(self)
        x = self[k]
        self.discard(x)
        return x

    def pop_min(self) -> T:
        """Remove and return one copy of the smallest key."""
        return self.pop(0)

    def pop_max(self) -> T:
        """Remove and return one copy of the largest key."""
        return self.pop(-1)

    def items(self) -> Iterator[tuple[T, int]]:
        """Yield ``(key, multiplicity)`` for every distinct key in order."""
        for i in range(self.len_elm()):
            yield self._kth_elm_tree(i)

    def keys(self) -> Iterator[T]:
        """Yield every distinct key in order."""
        for i in range(self.len_elm()):
            yield self._kth_elm_tree(i)[0]

    def values(self) -> Iterator[int]:
        """Yield the multiplicity of every distinct key, in key order."""
        for i in range(self.len_elm()):
            yield self._kth_elm_tree(i)[1]

    def show(self) -> None:
        """Print the contents as ``{key: multiplicity, ...}``."""
        print(
            "{" + ", ".join(map(lambda x: f"{x[0]}: {x[1]}", self.tolist_items())) + "}"
        )

    def get_elm(self, k: int) -> T:
        """Return the ``k``-th distinct key (negative ``k`` counts from the end)."""
        assert (
            -self.len_elm() <= k < self.len_elm()
        ), f"IndexError: {self.__class__.__name__}.get_elm({k}), len_elm=({self.len_elm()})"
        return self._kth_elm_tree(k)[0]

    def len_elm(self) -> int:
        """Return the number of distinct keys."""
        return self.root.size if self.root else 0

    def tolist(self) -> list[T]:
        """Return all elements in order, duplicates included."""
        return BSTMultisetNodeBase.tolist(self.root)

    def tolist_items(self) -> list[tuple[T, int]]:
        """Return ``(key, multiplicity)`` pairs in key order."""
        return BSTMultisetNodeBase.tolist_items(self.root)

    def clear(self) -> None:
        """Remove every element."""
        self.root = None

    def get_max(self) -> T:
        """Return the largest key; trips an assertion on an empty multiset."""
        return self._kth_elm_tree(-1)[0]

    def get_min(self) -> T:
        """Return the smallest key; trips an assertion on an empty multiset."""
        return self._kth_elm_tree(0)[0]

    def __contains__(self, key: T):
        return BSTMultisetNodeBase.contains(self.root, key)

    def __getitem__(self, k: int) -> T:
        """Return the ``k``-th element with multiplicity (negative ``k`` allowed)."""
        assert (
            -len(self) <= k < len(self)
        ), f"IndexError: {self.__class__.__name__}[{k}], len={len(self)}"
        return self._kth_elm(k)[0]

    def __iter__(self):
        # The cursor lives on the instance, so only one iteration at a time
        # is supported.
        self.__iter = 0
        return self

    def __next__(self):
        if self.__iter == len(self):
            raise StopIteration
        res = self._kth_elm(self.__iter)[0]
        self.__iter += 1
        return res

    def __reversed__(self):
        for i in range(len(self)):
            yield self._kth_elm(-i - 1)[0]

    def __len__(self):
        return self.root.valsize if self.root else 0

    def __bool__(self):
        return self.root is not None

    def __str__(self):
        return "{" + ", ".join(map(str, self.tolist())) + "}"

    def __repr__(self):
        # ``tolist`` must be called; formatting the bound method printed its repr.
        return f"{self.__class__.__name__}({self.tolist()})"
仕様¶
- class ScapegoatTreeMultiset(a: Iterable[T] = [])[source]¶
  Bases: OrderedMultisetInterface, Generic[T]
  - ALPHA: Final[float] = 0.75¶
  - BETA: Final[float] = 0.41503749927884376¶