scapegoat_tree_multiset

ソースコード

from titan_pylib.data_structures.scapegoat_tree.scapegoat_tree_multiset import ScapegoatTreeMultiset

view on github

展開済みコード

  1# from titan_pylib.data_structures.scapegoat_tree.scapegoat_tree_multiset import ScapegoatTreeMultiset
  2# from titan_pylib.my_class.ordered_multiset_interface import OrderedMultisetInterface
  3# from titan_pylib.my_class.supports_less_than import SupportsLessThan
  4from typing import Protocol
  5
  6
  7class SupportsLessThan(Protocol):
  8
  9    def __lt__(self, other) -> bool: ...
 10from abc import ABC, abstractmethod
 11from typing import Iterable, Optional, Iterator, TypeVar, Generic
 12
 13T = TypeVar("T", bound=SupportsLessThan)
 14
 15
 16class OrderedMultisetInterface(ABC, Generic[T]):
 17
 18    @abstractmethod
 19    def __init__(self, a: Iterable[T]) -> None:
 20        raise NotImplementedError
 21
 22    @abstractmethod
 23    def add(self, key: T, cnt: int) -> None:
 24        raise NotImplementedError
 25
 26    @abstractmethod
 27    def discard(self, key: T, cnt: int) -> bool:
 28        raise NotImplementedError
 29
 30    @abstractmethod
 31    def discard_all(self, key: T) -> bool:
 32        raise NotImplementedError
 33
 34    @abstractmethod
 35    def count(self, key: T) -> int:
 36        raise NotImplementedError
 37
 38    @abstractmethod
 39    def remove(self, key: T, cnt: int) -> None:
 40        raise NotImplementedError
 41
 42    @abstractmethod
 43    def le(self, key: T) -> Optional[T]:
 44        raise NotImplementedError
 45
 46    @abstractmethod
 47    def lt(self, key: T) -> Optional[T]:
 48        raise NotImplementedError
 49
 50    @abstractmethod
 51    def ge(self, key: T) -> Optional[T]:
 52        raise NotImplementedError
 53
 54    @abstractmethod
 55    def gt(self, key: T) -> Optional[T]:
 56        raise NotImplementedError
 57
 58    @abstractmethod
 59    def get_max(self) -> Optional[T]:
 60        raise NotImplementedError
 61
 62    @abstractmethod
 63    def get_min(self) -> Optional[T]:
 64        raise NotImplementedError
 65
 66    @abstractmethod
 67    def pop_max(self) -> T:
 68        raise NotImplementedError
 69
 70    @abstractmethod
 71    def pop_min(self) -> T:
 72        raise NotImplementedError
 73
 74    @abstractmethod
 75    def clear(self) -> None:
 76        raise NotImplementedError
 77
 78    @abstractmethod
 79    def tolist(self) -> list[T]:
 80        raise NotImplementedError
 81
 82    @abstractmethod
 83    def __iter__(self) -> Iterator:
 84        raise NotImplementedError
 85
 86    @abstractmethod
 87    def __next__(self) -> T:
 88        raise NotImplementedError
 89
 90    @abstractmethod
 91    def __contains__(self, key: T) -> bool:
 92        raise NotImplementedError
 93
 94    @abstractmethod
 95    def __len__(self) -> int:
 96        raise NotImplementedError
 97
 98    @abstractmethod
 99    def __bool__(self) -> bool:
100        raise NotImplementedError
101
102    @abstractmethod
103    def __str__(self) -> str:
104        raise NotImplementedError
105
106    @abstractmethod
107    def __repr__(self) -> str:
108        raise NotImplementedError
# from titan_pylib.my_class.supports_less_than import SupportsLessThan
# from titan_pylib.data_structures.bst_base.bst_multiset_node_base import (
#     BSTMultisetNodeBase,
# )
from typing import TypeVar, Generic, Optional

T = TypeVar("T")
Node = TypeVar("Node")
# TODO: describe the node shape with a Protocol.  Nodes are read through the
# attributes `key`, `val` (multiplicity), `left`, `right` and, for
# index()/index_right(), `valsize` (total multiplicity of the subtree).


class BSTMultisetNodeBase(Generic[T, Node]):
    """Tree-agnostic helpers for binary-search-tree multisets.

    Every helper is a ``staticmethod``.  The node helpers take a subtree root
    (``None``/falsy for an empty tree) and only read the nodes; they never
    modify the tree.  Equal keys share one node whose ``val`` is the key's
    multiplicity.
    """

    @staticmethod
    def count(node, key: T) -> int:
        """Return the multiplicity of ``key`` in the subtree (0 if absent)."""
        while node is not None:
            if node.key == key:
                return node.val
            node = node.left if key < node.key else node.right
        return 0

    @staticmethod
    def get_min(node: Node) -> Optional[T]:
        """Return the smallest key, or ``None`` for an empty tree."""
        if not node:
            return None
        while node.left:
            node = node.left
        return node.key

    @staticmethod
    def get_max(node: Node) -> Optional[T]:
        """Return the largest key, or ``None`` for an empty tree."""
        if not node:
            return None
        while node.right:
            node = node.right
        return node.key

    @staticmethod
    def contains(node: Node, key: T) -> bool:
        """Return whether a node with ``key`` exists in the subtree."""
        while node:
            if key == node.key:
                return True
            node = node.left if key < node.key else node.right
        return False

    @staticmethod
    def tolist(node: Node) -> list[T]:
        """Return every key repeated by its multiplicity, in sorted order."""
        stack = []
        a = []
        # Iterative in-order traversal (no recursion-depth limit).
        while stack or node:
            if node:
                stack.append(node)
                node = node.left
            else:
                node = stack.pop()
                a.extend([node.key] * node.val)
                node = node.right
        return a

    @staticmethod
    def tolist_items(node: Node) -> list[tuple[T, int]]:
        """Return ``(key, multiplicity)`` pairs in sorted key order."""
        stack = []
        a = []
        while stack or node:
            if node:
                stack.append(node)
                node = node.left
            else:
                node = stack.pop()
                a.append((node.key, node.val))
                node = node.right
        return a

    @staticmethod
    def _rle(a: list[T]) -> tuple[list[T], list[int]]:
        """Run-length encode ``a``.

        Returns ``(keys, counts)``, where adjacent equal elements of ``a`` are
        collapsed into one key with its run length.  An empty input yields
        ``([], [])``.
        """
        keys: list[T] = []
        vals: list[int] = []
        for elm in a:
            if keys and elm == keys[-1]:
                vals[-1] += 1
            else:
                keys.append(elm)
                vals.append(1)
        return keys, vals

    @staticmethod
    def le(node: Node, key: T) -> Optional[T]:
        """Return the largest key ``<= key``, or ``None`` if there is none."""
        res = None
        while node is not None:
            if key == node.key:
                res = key
                break
            if key < node.key:
                node = node.left
            else:
                res = node.key
                node = node.right
        return res

    @staticmethod
    def lt(node: Node, key: T) -> Optional[T]:
        """Return the largest key ``< key``, or ``None`` if there is none."""
        res = None
        while node is not None:
            if key <= node.key:
                node = node.left
            else:
                res = node.key
                node = node.right
        return res

    @staticmethod
    def ge(node: Node, key: T) -> Optional[T]:
        """Return the smallest key ``>= key``, or ``None`` if there is none."""
        res = None
        while node is not None:
            if key == node.key:
                res = key
                break
            if key < node.key:
                res = node.key
                node = node.left
            else:
                node = node.right
        return res

    @staticmethod
    def gt(node: Node, key: T) -> Optional[T]:
        """Return the smallest key ``> key``, or ``None`` if there is none."""
        res = None
        while node is not None:
            if key < node.key:
                res = node.key
                node = node.left
            else:
                node = node.right
        return res

    @staticmethod
    def index(node: Node, key: T) -> int:
        """Return how many elements (counting multiplicity) are ``< key``.

        Requires each node to carry ``valsize``.
        """
        k = 0
        while node:
            if key == node.key:
                if node.left:
                    k += node.left.valsize
                break
            if key < node.key:
                node = node.left
            else:
                # Everything in the left subtree plus this node precedes key.
                k += node.val if node.left is None else node.left.valsize + node.val
                node = node.right
        return k

    @staticmethod
    def index_right(node: Node, key: T) -> int:
        """Return how many elements (counting multiplicity) are ``<= key``.

        Requires each node to carry ``valsize``.
        """
        k = 0
        while node:
            if key == node.key:
                k += node.val if node.left is None else node.left.valsize + node.val
                break
            if key < node.key:
                node = node.left
            else:
                k += node.val if node.left is None else node.left.valsize + node.val
                node = node.right
        return k
266        while node:
267            if key == node.key:
268                k += node.val if node.left is None else node.left.valsize + node.val
269                break
270            if key < node.key:
271                node = node.left
272            else:
273                k += node.val if node.left is None else node.left.valsize + node.val
274                node = node.right
275        return k
276import math
277from typing import Final, TypeVar, Generic, Iterable, Optional, Iterator
278
279T = TypeVar("T", bound=SupportsLessThan)
280
281
282class ScapegoatTreeMultiset(OrderedMultisetInterface, Generic[T]):
283
284    ALPHA: Final[float] = 0.75
285    BETA: Final[float] = math.log2(1 / ALPHA)
286
287    class Node:
288
289        def __init__(self, key: T, val: int):
290            self.key: T = key
291            self.val: int = val
292            self.size: int = 1
293            self.valsize: int = val
294            self.left: Optional[ScapegoatTreeMultiset.Node] = None
295            self.right: Optional[ScapegoatTreeMultiset.Node] = None
296
297        def __str__(self):
298            if self.left is None and self.right is None:
299                return f"key:{self.key, self.val, self.size, self.valsize}\n"
300            return f"key:{self.key, self.val, self.size, self.valsize},\n left:{self.left},\n right:{self.right}\n"
301
302    def __init__(self, a: Iterable[T] = []):
303        self.root = None
304        if not isinstance(a, list):
305            a = list(a)
306        self._build(a)
307
308    def _build(self, a: list[T]) -> None:
309        Node = ScapegoatTreeMultiset.Node
310
311        def rec(l: int, r: int) -> ScapegoatTreeMultiset.Node:
312            mid = (l + r) >> 1
313            node = Node(x[mid], y[mid])
314            if l != mid:
315                node.left = rec(l, mid)
316                node.size += node.left.size
317                node.valsize += node.left.valsize
318            if mid + 1 != r:
319                node.right = rec(mid + 1, r)
320                node.size += node.right.size
321                node.valsize += node.right.valsize
322            return node
323
324        if not all(a[i] <= a[i + 1] for i in range(len(a) - 1)):
325            a = sorted(a)
326        if not a:
327            return
328        x, y = BSTMultisetNodeBase[T, ScapegoatTreeMultiset.Node]._rle(a)
329        self.root = rec(0, len(x))
330
331    def _rebuild(self, node: Node) -> Node:
332        def rec(l: int, r: int) -> ScapegoatTreeMultiset.Node:
333            mid = (l + r) >> 1
334            node = a[mid]
335            node.size = 1
336            node.valsize = node.val
337            if l != mid:
338                node.left = rec(l, mid)
339                node.size += node.left.size
340                node.valsize += node.left.valsize
341            else:
342                node.left = None
343            if mid + 1 != r:
344                node.right = rec(mid + 1, r)
345                node.size += node.right.size
346                node.valsize += node.right.valsize
347            else:
348                node.right = None
349            return node
350
351        a = []
352        stack = []
353        while stack or node:
354            if node:
355                stack.append(node)
356                node = node.left
357            else:
358                node = stack.pop()
359                a.append(node)
360                node = node.right
361        return rec(0, len(a))
362
363    def _kth_elm(self, k: int) -> tuple[T, int]:
364        if k < 0:
365            k += len(self)
366        node = self.root
367        while node:
368            t = (node.val + node.left.valsize) if node.left else node.val
369            if t - node.val <= k and k < t:
370                return node.key, node.val
371            elif t > k:
372                node = node.left
373            else:
374                node = node.right
375                k -= t
376
377    def _kth_elm_tree(self, k: int) -> tuple[T, int]:
378        if k < 0:
379            k += self.len_elm()
380        node = self.root
381        while node:
382            t = node.left.size if node.left else 0
383            if t == k:
384                return node.key, node.val
385            if t > k:
386                node = node.left
387            else:
388                node = node.right
389                k -= t + 1
390        assert False, "IndexError"
391
392    def add(self, key: T, val: int = 1) -> None:
393        if val <= 0:
394            return
395        if not self.root:
396            self.root = ScapegoatTreeMultiset.Node(key, val)
397            return
398        node = self.root
399        path = []
400        while node:
401            path.append(node)
402            if key == node.key:
403                node.val += val
404                for p in path:
405                    p.valsize += val
406                return
407            node = node.left if key < node.key else node.right
408        if key < path[-1].key:
409            path[-1].left = ScapegoatTreeMultiset.Node(key, val)
410        else:
411            path[-1].right = ScapegoatTreeMultiset.Node(key, val)
412        if len(path) * ScapegoatTreeMultiset.BETA > math.log(self.len_elm()):
413            node_size = 1
414            while path:
415                pnode = path.pop()
416                pnode_size = pnode.size + 1
417                if ScapegoatTreeMultiset.ALPHA * pnode_size < node_size:
418                    break
419                node_size = pnode_size
420            new_node = self._rebuild(pnode)
421            if not path:
422                self.root = new_node
423                return
424            if new_node.key < path[-1].key:
425                path[-1].left = new_node
426            else:
427                path[-1].right = new_node
428        for p in path:
429            p.size += 1
430            p.valsize += val
431
432    def _discard(self, key: T) -> bool:
433        path = []
434        node = self.root
435        di, cnt = 1, 0
436        while node:
437            if key == node.key:
438                break
439            path.append(node)
440            di = key < node.key
441            node = node.left if di else node.right
442        if node.left and node.right:
443            path.append(node)
444            lmax = node.left
445            di = 0 if lmax.right else 1
446            while lmax.right:
447                cnt += 1
448                path.append(lmax)
449                lmax = lmax.right
450            lmax_val = lmax.val
451            node.key = lmax.key
452            node.val = lmax_val
453            node = lmax
454        cnode = node.left if node.left else node.right
455        if path:
456            if di == 1:
457                path[-1].left = cnode
458            else:
459                path[-1].right = cnode
460        else:
461            self.root = cnode
462            return True
463        for _ in range(cnt):
464            p = path.pop()
465            p.size -= 1
466            p.valsize -= lmax_val
467        for p in path:
468            p.size -= 1
469            p.valsize -= 1
470        return True
471
472    def discard(self, key: T, val=1) -> bool:
473        if val <= 0:
474            return True
475        path = []
476        node = self.root
477        while node:
478            path.append(node)
479            if key == node.key:
480                break
481            node = node.left if key < node.key else node.right
482        else:
483            return False
484        if val > node.val:
485            val = node.val - 1
486            if val > 0:
487                node.val -= val
488                while path:
489                    path.pop().valsize -= val
490        if node.val == 1:
491            self._discard(key)
492        else:
493            node.val -= val
494            while path:
495                path.pop().valsize -= val
496        return True
497
498    def remove(self, key: T, val: int = 1) -> None:
499        c = self.count(key)
500        if c > val:
501            raise KeyError(key)
502        self.discard(key, val)
503
504    def count(self, key: T) -> int:
505        node = self.root
506        while node:
507            if key == node.key:
508                return node.val
509            node = node.left if key < node.key else node.right
510        return 0
511
512    def discard_all(self, key: T) -> bool:
513        return self.discard(key, self.count(key))
514
515    def le(self, key: T) -> Optional[T]:
516        return BSTMultisetNodeBase[T, ScapegoatTreeMultiset.Node].le(self.root, key)
517
518    def lt(self, key: T) -> Optional[T]:
519        return BSTMultisetNodeBase[T, ScapegoatTreeMultiset.Node].lt(self.root, key)
520
521    def ge(self, key: T) -> Optional[T]:
522        return BSTMultisetNodeBase[T, ScapegoatTreeMultiset.Node].ge(self.root, key)
523
524    def gt(self, key: T) -> Optional[T]:
525        return BSTMultisetNodeBase[T, ScapegoatTreeMultiset.Node].gt(self.root, key)
526
527    def index(self, key: T) -> int:
528        return BSTMultisetNodeBase[T, ScapegoatTreeMultiset.Node].index(self.root, key)
529
530    def index_right(self, key: T) -> int:
531        return BSTMultisetNodeBase[T, ScapegoatTreeMultiset.Node].index_right(
532            self.root, key
533        )
534
535    def index_keys(self, key: T) -> int:
536        k = 0
537        node = self.root
538        while node:
539            if key == node.key:
540                if node.left:
541                    k += node.left.size
542                break
543            elif key < node.key:
544                node = node.left
545            else:
546                k += node.val if node.left is None else node.left.size + node.val
547                node = node.right
548        return k
549
550    def index_right_keys(self, key: T) -> int:
551        k = 0
552        node = self.root
553        while node:
554            if key == node.key:
555                k += node.val if node.left is None else node.left.size + node.val
556                break
557            if key < node.key:
558                node = node.left
559            else:
560                k += node.val if node.left is None else node.left.size + node.val
561                node = node.right
562        return k
563
564    def pop(self, k: int = -1) -> T:
565        if k < 0:
566            k += self.root.valsize
567        x = self[k]
568        self.discard(x)
569        return x
570
571    def pop_min(self) -> T:
572        return self.pop(0)
573
574    def pop_max(self) -> T:
575        return self.pop(-1)
576
577    def items(self) -> Iterator[tuple[T, int]]:
578        for i in range(self.len_elm()):
579            yield self._kth_elm_tree(i)
580
581    def keys(self) -> Iterator[T]:
582        for i in range(self.len_elm()):
583            yield self._kth_elm_tree(i)[0]
584
585    def values(self) -> Iterator[int]:
586        for i in range(self.len_elm()):
587            yield self._kth_elm_tree(i)[1]
588
589    def show(self) -> None:
590        print(
591            "{" + ", ".join(map(lambda x: f"{x[0]}: {x[1]}", self.tolist_items())) + "}"
592        )
593
594    def get_elm(self, k: int) -> T:
595        assert (
596            -self.len_elm() <= k < self.len_elm()
597        ), f"IndexError: {self.__class__.__name__}.get_elm({k}), len_elm=({self.len_elm()})"
598        return self._kth_elm_tree(k)[0]
599
600    def len_elm(self) -> int:
601        return self.root.size if self.root else 0
602
603    def tolist(self) -> list[T]:
604        return BSTMultisetNodeBase[T, ScapegoatTreeMultiset.Node].tolist(self.root)
605
606    def tolist_items(self) -> list[tuple[T, int]]:
607        return BSTMultisetNodeBase[T, ScapegoatTreeMultiset.Node].tolist_items(
608            self.root
609        )
610
611    def clear(self) -> None:
612        self.root = None
613
614    def get_max(self) -> T:
615        return self._kth_elm_tree(-1)[0]
616
617    def get_min(self) -> T:
618        return self._kth_elm_tree(0)[0]
619
620    def __contains__(self, key: T):
621        return BSTMultisetNodeBase[T, ScapegoatTreeMultiset.Node].contains(
622            self.root, key
623        )
624
625    def __getitem__(self, k: int) -> T:
626        assert (
627            -len(self) <= k < len(self)
628        ), f"IndexError: {self.__class__.__name__}[{k}], len={len(self)}"
629        return self._kth_elm(k)[0]
630
631    def __iter__(self):
632        self.__iter = 0
633        return self
634
635    def __next__(self):
636        if self.__iter == len(self):
637            raise StopIteration
638        res = self._kth_elm(self.__iter)[0]
639        self.__iter += 1
640        return res
641
642    def __reversed__(self):
643        for i in range(len(self)):
644            yield self._kth_elm(-i - 1)[0]
645
646    def __len__(self):
647        return self.root.valsize if self.root else 0
648
649    def __bool__(self):
650        return self.root is not None
651
652    def __str__(self):
653        return "{" + ", ".join(map(str, self.tolist())) + "}"
654
655    def __repr__(self):
656        return f"{self.__class__.__name__}({self.tolist})"

仕様

class ScapegoatTreeMultiset(a: Iterable[T] = [])[source]

Bases: OrderedMultisetInterface, Generic[T]

ALPHA: Final[float] = 0.75
BETA: Final[float] = 0.41503749927884376
class Node(key: T, val: int)[source]

Bases: object

add(key: T, val: int = 1) → None [source]
clear() → None [source]
count(key: T) → int [source]
discard(key: T, val=1) → bool [source]
discard_all(key: T) → bool [source]
ge(key: T) → T | None [source]
get_elm(k: int) → T [source]
get_max() → T [source]
get_min() → T [source]
gt(key: T) → T | None [source]
index(key: T) → int [source]
index_keys(key: T) → int [source]
index_right(key: T) → int [source]
index_right_keys(key: T) → int [source]
items() → Iterator[tuple[T, int]] [source]
keys() → Iterator[T] [source]
le(key: T) → T | None [source]
len_elm() → int [source]
lt(key: T) → T | None [source]
pop(k: int = -1) → T [source]
pop_max() → T [source]
pop_min() → T [source]
remove(key: T, val: int = 1) → None [source]
show() → None [source]
tolist() → list[T] [source]
tolist_items() → list[tuple[T, int]] [source]
values() → Iterator[int] [source]