treap_multiset

ソースコード

from titan_pylib.data_structures.treap.treap_multiset import TreapMultiset

view on github

展開済みコード

  1# from titan_pylib.data_structures.treap.treap_multiset import TreapMultiset
  2# from titan_pylib.my_class.ordered_multiset_interface import OrderedMultisetInterface
  3# from titan_pylib.my_class.supports_less_than import SupportsLessThan
  4from typing import Protocol
  5
  6
  7class SupportsLessThan(Protocol):
  8
  9    def __lt__(self, other) -> bool: ...
 10from abc import ABC, abstractmethod
 11from typing import Iterable, Optional, Iterator, TypeVar, Generic
 12
 13T = TypeVar("T", bound=SupportsLessThan)
 14
 15
 16class OrderedMultisetInterface(ABC, Generic[T]):
 17
 18    @abstractmethod
 19    def __init__(self, a: Iterable[T]) -> None:
 20        raise NotImplementedError
 21
 22    @abstractmethod
 23    def add(self, key: T, cnt: int) -> None:
 24        raise NotImplementedError
 25
 26    @abstractmethod
 27    def discard(self, key: T, cnt: int) -> bool:
 28        raise NotImplementedError
 29
 30    @abstractmethod
 31    def discard_all(self, key: T) -> bool:
 32        raise NotImplementedError
 33
 34    @abstractmethod
 35    def count(self, key: T) -> int:
 36        raise NotImplementedError
 37
 38    @abstractmethod
 39    def remove(self, key: T, cnt: int) -> None:
 40        raise NotImplementedError
 41
 42    @abstractmethod
 43    def le(self, key: T) -> Optional[T]:
 44        raise NotImplementedError
 45
 46    @abstractmethod
 47    def lt(self, key: T) -> Optional[T]:
 48        raise NotImplementedError
 49
 50    @abstractmethod
 51    def ge(self, key: T) -> Optional[T]:
 52        raise NotImplementedError
 53
 54    @abstractmethod
 55    def gt(self, key: T) -> Optional[T]:
 56        raise NotImplementedError
 57
 58    @abstractmethod
 59    def get_max(self) -> Optional[T]:
 60        raise NotImplementedError
 61
 62    @abstractmethod
 63    def get_min(self) -> Optional[T]:
 64        raise NotImplementedError
 65
 66    @abstractmethod
 67    def pop_max(self) -> T:
 68        raise NotImplementedError
 69
 70    @abstractmethod
 71    def pop_min(self) -> T:
 72        raise NotImplementedError
 73
 74    @abstractmethod
 75    def clear(self) -> None:
 76        raise NotImplementedError
 77
 78    @abstractmethod
 79    def tolist(self) -> list[T]:
 80        raise NotImplementedError
 81
 82    @abstractmethod
 83    def __iter__(self) -> Iterator:
 84        raise NotImplementedError
 85
 86    @abstractmethod
 87    def __next__(self) -> T:
 88        raise NotImplementedError
 89
 90    @abstractmethod
 91    def __contains__(self, key: T) -> bool:
 92        raise NotImplementedError
 93
 94    @abstractmethod
 95    def __len__(self) -> int:
 96        raise NotImplementedError
 97
 98    @abstractmethod
 99    def __bool__(self) -> bool:
100        raise NotImplementedError
101
102    @abstractmethod
103    def __str__(self) -> str:
104        raise NotImplementedError
105
106    @abstractmethod
107    def __repr__(self) -> str:
108        raise NotImplementedError
109# from titan_pylib.my_class.supports_less_than import SupportsLessThan
110# from titan_pylib.data_structures.bst_base.bst_multiset_node_base import (
111#     BSTMultisetNodeBase,
112# )
113from typing import TypeVar, Generic, Optional
114
T = TypeVar("T")
Node = TypeVar("Node")
# TODO: formalise the required node attributes (key, val, left, right)
# as a typing.Protocol.


class BSTMultisetNodeBase(Generic[T, Node]):
    """Stateless helpers for multiset binary search trees with counted nodes.

    Every method takes the root of a (sub)tree and only reads it.  Nodes are
    expected to expose ``key`` (the element), ``val`` (its multiplicity) and
    ``left``/``right`` children (``None`` when absent), with keys in
    binary-search-tree order.  ``index``/``index_right`` additionally read
    ``valsize``, the summed ``val`` over a node's subtree.
    """

    @staticmethod
    def count(node, key: T) -> int:
        """Return the multiplicity of ``key`` in the tree (0 if absent)."""
        while node is not None:
            if node.key == key:
                return node.val
            node = node.left if key < node.key else node.right
        return 0

    @staticmethod
    def get_min(node: Node) -> Optional[T]:
        """Return the smallest key, or ``None`` for an empty tree."""
        if not node:
            return None
        while node.left:
            node = node.left
        return node.key

    @staticmethod
    def get_max(node: Node) -> Optional[T]:
        """Return the largest key, or ``None`` for an empty tree."""
        if not node:
            return None
        while node.right:
            node = node.right
        return node.key

    @staticmethod
    def contains(node: Node, key: T) -> bool:
        """Return whether ``key`` is stored in the tree."""
        while node:
            if key == node.key:
                return True
            node = node.left if key < node.key else node.right
        return False

    @staticmethod
    def tolist(node: Node) -> list[T]:
        """Return every element, each key repeated ``val`` times, in order.

        Uses an explicit stack for the in-order walk, so deep trees do not
        hit the recursion limit.
        """
        stack = []
        a = []
        while stack or node:
            if node:
                stack.append(node)
                node = node.left
            else:
                node = stack.pop()
                a.extend([node.key] * node.val)
                node = node.right
        return a

    @staticmethod
    def tolist_items(node: Node) -> list[tuple[T, int]]:
        """Return ``(key, multiplicity)`` pairs in ascending key order."""
        stack = []
        a = []
        while stack or node:
            if node:
                stack.append(node)
                node = node.left
            else:
                node = stack.pop()
                a.append((node.key, node.val))
                node = node.right
        return a

    @staticmethod
    def _rle(a: list[T]) -> tuple[list[T], list[int]]:
        """Run-length encode ``a`` into parallel lists of keys and run lengths.

        Equal adjacent elements are merged, so a sorted input yields each
        distinct key once together with its count.  An empty input yields
        ``([], [])`` instead of raising ``IndexError``.
        """
        keys: list[T] = []
        vals: list[int] = []
        for elm in a:
            if keys and elm == keys[-1]:
                vals[-1] += 1
            else:
                keys.append(elm)
                vals.append(1)
        return keys, vals

    @staticmethod
    def le(node: Node, key: T) -> Optional[T]:
        """Return the largest key ``<= key``, or ``None`` if there is none."""
        res = None
        while node is not None:
            if key == node.key:
                res = key
                break
            if key < node.key:
                node = node.left
            else:
                res = node.key
                node = node.right
        return res

    @staticmethod
    def lt(node: Node, key: T) -> Optional[T]:
        """Return the largest key ``< key``, or ``None`` if there is none."""
        res = None
        while node is not None:
            if key <= node.key:
                node = node.left
            else:
                res = node.key
                node = node.right
        return res

    @staticmethod
    def ge(node: Node, key: T) -> Optional[T]:
        """Return the smallest key ``>= key``, or ``None`` if there is none."""
        res = None
        while node is not None:
            if key == node.key:
                res = key
                break
            if key < node.key:
                res = node.key
                node = node.left
            else:
                node = node.right
        return res

    @staticmethod
    def gt(node: Node, key: T) -> Optional[T]:
        """Return the smallest key ``> key``, or ``None`` if there is none."""
        res = None
        while node is not None:
            if key < node.key:
                res = node.key
                node = node.left
            else:
                node = node.right
        return res

    @staticmethod
    def index(node: Node, key: T) -> int:
        """Return the number of elements (with multiplicity) ``< key``.

        Requires every node to carry ``valsize``.
        """
        k = 0
        while node:
            if key == node.key:
                if node.left:
                    k += node.left.valsize
                break
            if key < node.key:
                node = node.left
            else:
                # Everything in the left subtree plus this node is < key.
                k += node.val if node.left is None else node.left.valsize + node.val
                node = node.right
        return k

    @staticmethod
    def index_right(node: Node, key: T) -> int:
        """Return the number of elements (with multiplicity) ``<= key``.

        Requires every node to carry ``valsize``.
        """
        k = 0
        while node:
            if key == node.key:
                k += node.val if node.left is None else node.left.valsize + node.val
                break
            if key < node.key:
                node = node.left
            else:
                k += node.val if node.left is None else node.left.valsize + node.val
                node = node.right
        return k
276from typing import Generic, Iterable, TypeVar, Optional, Sequence
277
T = TypeVar("T", bound=SupportsLessThan)


class TreapMultiset(OrderedMultisetInterface, Generic[T]):
    """Ordered multiset backed by a treap.

    Every distinct key lives in exactly one node, which also stores the
    key's multiplicity in ``val``.  Nodes are kept in binary-search-tree
    order by key and in min-heap order by ``priority`` (a parent's priority
    is never larger than its children's).  ``add`` and ``discard`` restore
    the heap order with rotations, and ``_build`` establishes it up front.

    ``len(self)`` counts elements with multiplicity; ``len_elm()`` counts
    distinct keys.
    """

    class Random:
        """Marsaglia xorshift128 generator used to draw node priorities.

        The state is stored on the class, so all treaps share one sequence.
        Every value returned fits in 32 bits.
        """

        _x, _y, _z, _w = 123456789, 362436069, 521288629, 88675123

        @classmethod
        def random(cls) -> int:
            # ``&`` binds tighter than ``^``; because every state word is
            # already < 2**32, the masked results stay within 32 bits.
            t = cls._x ^ (cls._x << 11) & 0xFFFFFFFF
            cls._x, cls._y, cls._z = cls._y, cls._z, cls._w
            cls._w = (cls._w ^ (cls._w >> 19)) ^ (t ^ (t >> 8)) & 0xFFFFFFFF
            return cls._w

    class Node:
        """Treap node: one distinct key, its multiplicity and a priority."""

        def __init__(self, key: T, val: int = 1, priority: int = -1):
            self.key: T = key
            self.val: int = val  # multiplicity of ``key``
            self.left: Optional["TreapMultiset.Node"] = None
            self.right: Optional["TreapMultiset.Node"] = None
            # -1 is the sentinel for "draw a fresh random priority".
            self.priority: int = (
                TreapMultiset.Random.random() if priority == -1 else priority
            )

        def __str__(self):
            if self.left is None and self.right is None:
                return f"key:{self.key, self.priority}\n"
            return f"key:{self.key, self.priority},\n left:{self.left},\n right:{self.right}\n"

    def __init__(self, a: Iterable[T] = []):
        """Create a multiset holding the elements of ``a`` (may be empty).

        The default list is only read (``_build`` works on a sorted copy),
        so sharing it between calls is harmless.
        """
        self.root: Optional["TreapMultiset.Node"] = None
        self._len: int = 0
        self._len_elm: int = 0
        if not isinstance(a, Sequence):
            a = list(a)
        if a:
            self._build(a)

    def _build(self, a: Iterable[T]) -> None:
        """Replace the contents with the elements of ``a``.

        Builds a balanced tree over the distinct sorted keys.  It then hands
        out ascending random priorities level by level, so every parent's
        priority is <= its children's.  That is the same min-heap order that
        the rotations in ``add``/``discard`` maintain.
        """
        Node = TreapMultiset.Node

        def build(l: int, r: int) -> "TreapMultiset.Node":
            mid = (l + r) >> 1
            # Priority 0 is a placeholder (anything other than -1 avoids a
            # random draw); the real priority is assigned below.
            node = Node(keys[mid], vals[mid], 0)
            if l != mid:
                node.left = build(l, mid)
            if mid + 1 != r:
                node.right = build(mid + 1, r)
            return node

        a = sorted(a)
        if not a:
            self.root = None
            self._len = self._len_elm = 0
            return
        keys, vals = BSTMultisetNodeBase._rle(a)
        self._len = len(a)
        self._len_elm = len(keys)
        self.root = build(0, len(keys))
        priorities = iter(
            sorted(TreapMultiset.Random.random() for _ in range(self._len_elm))
        )
        # Breadth-first assignment: nodes on shallower levels receive smaller
        # priorities, which guarantees the min-heap order.
        level = [self.root]
        while level:
            next_level = []
            for node in level:
                node.priority = next(priorities)
                if node.left is not None:
                    next_level.append(node.left)
                if node.right is not None:
                    next_level.append(node.right)
            level = next_level

    def _rotate_L(self, node: Node) -> Node:
        """Lift ``node.left`` above ``node``; return the new subtree root."""
        u = node.left
        node.left = u.right
        u.right = node
        return u

    def _rotate_R(self, node: Node) -> Node:
        """Lift ``node.right`` above ``node``; return the new subtree root."""
        u = node.right
        node.right = u.left
        u.left = node
        return u

    def add(self, key: T, val: int = 1) -> None:
        """Insert ``val`` copies of ``key``."""
        self._len += val
        if self.root is None:
            self.root = TreapMultiset.Node(key, val)
            self._len_elm += 1
            return
        node = self.root
        path = []
        # Bit i of ``di`` (counting from the least significant end) records
        # the direction taken at path[-1 - i]: 1 = left, 0 = right.
        di = 0
        while node is not None:
            if key == node.key:
                node.val += val
                return
            path.append(node)
            di <<= 1
            if key < node.key:
                di |= 1
                node = node.left
            else:
                node = node.right
        self._len_elm += 1
        if di & 1:
            path[-1].left = TreapMultiset.Node(key, val)
        else:
            path[-1].right = TreapMultiset.Node(key, val)
        # Rotate the new node upwards while its priority beats its parent's.
        # Once a parent keeps its child, the ancestors above it are unchanged
        # and already heap-ordered, so the loop can stop.
        while path:
            node = path.pop()
            if di & 1:
                if node.left.priority >= node.priority:
                    break
                new_node = self._rotate_L(node)
            else:
                if node.right.priority >= node.priority:
                    break
                new_node = self._rotate_R(node)
            di >>= 1
            if path:
                if di & 1:
                    path[-1].left = new_node
                else:
                    path[-1].right = new_node
            else:
                self.root = new_node

    def discard(self, key: T, val: int = 1) -> bool:
        """Remove up to ``val`` copies of ``key``.

        Returns ``False`` (changing nothing) when ``key`` is absent, and
        ``True`` otherwise.  When the last copy goes, the node is rotated
        down until it has at most one child and is then spliced out.
        """
        node = self.root
        pnode = None
        while node is not None:
            if key == node.key:
                break
            pnode = node
            node = node.left if key < node.key else node.right
        else:
            return False
        self._len -= min(val, node.val)
        if node.val > val:
            node.val -= val
            return True
        self._len_elm -= 1
        while node.left is not None and node.right is not None:
            # Lift the child with the smaller priority to keep heap order.
            if node.left.priority < node.right.priority:
                new_node = self._rotate_L(node)
            else:
                new_node = self._rotate_R(node)
            if pnode is None:
                self.root = new_node
            elif node.key < pnode.key:
                pnode.left = new_node
            else:
                pnode.right = new_node
            pnode = new_node
        child = node.right if node.left is None else node.left
        if pnode is None:
            self.root = child
        elif node.key < pnode.key:
            pnode.left = child
        else:
            pnode.right = child
        return True

    def discard_all(self, key: T) -> bool:
        """Remove every copy of ``key``; return whether it was present."""
        return self.discard(key, self.count(key))

    def remove(self, key: T, val: int = 1) -> None:
        """Like ``discard``, but raise ``KeyError`` when ``key`` is absent."""
        if self.discard(key, val):
            return
        raise KeyError(key)

    def count(self, key: T) -> int:
        """Return the multiplicity of ``key`` (0 if absent)."""
        return BSTMultisetNodeBase.count(self.root, key)

    def le(self, key: T) -> Optional[T]:
        """Return the largest element ``<= key``, or ``None``."""
        return BSTMultisetNodeBase.le(self.root, key)

    def lt(self, key: T) -> Optional[T]:
        """Return the largest element ``< key``, or ``None``."""
        return BSTMultisetNodeBase.lt(self.root, key)

    def ge(self, key: T) -> Optional[T]:
        """Return the smallest element ``>= key``, or ``None``."""
        return BSTMultisetNodeBase.ge(self.root, key)

    def gt(self, key: T) -> Optional[T]:
        """Return the smallest element ``> key``, or ``None``."""
        return BSTMultisetNodeBase.gt(self.root, key)

    def len_elm(self) -> int:
        """Return the number of distinct keys."""
        return self._len_elm

    def show(self) -> None:
        """Print the contents as ``{key: count, ...}``."""
        print(
            "{" + ", ".join(map(lambda x: f"{x[0]}: {x[1]}", self.tolist_items())) + "}"
        )

    def tolist(self) -> list[T]:
        """Return all elements, repetitions included, in ascending order."""
        return BSTMultisetNodeBase.tolist(self.root)

    def tolist_items(self) -> list[tuple[T, int]]:
        """Return ``(key, count)`` pairs in ascending key order."""
        return BSTMultisetNodeBase.tolist_items(self.root)

    def get_min(self) -> Optional[T]:
        """Return the smallest element, or ``None`` when empty."""
        return BSTMultisetNodeBase.get_min(self.root)

    def get_max(self) -> Optional[T]:
        """Return the largest element, or ``None`` when empty."""
        return BSTMultisetNodeBase.get_max(self.root)

    def pop_min(self) -> T:
        """Remove one copy of the smallest element and return it.

        Raises ``AssertionError`` on an empty multiset.
        """
        assert self, "IndexError"
        self._len -= 1
        node = self.root
        pnode = None
        while node.left is not None:
            pnode = node
            node = node.left
        if node.val > 1:
            node.val -= 1
            return node.key
        self._len_elm -= 1
        res = node.key
        # The minimum has no left child; splicing in its right subtree keeps
        # both the BST order and the heap order.
        if pnode is None:
            self.root = self.root.right
        else:
            pnode.left = node.right
        return res

    def pop_max(self) -> T:
        """Remove one copy of the largest element and return it.

        Raises ``AssertionError`` on an empty multiset.
        """
        assert self, "IndexError"
        self._len -= 1
        node = self.root
        pnode = None
        while node.right is not None:
            pnode = node
            node = node.right
        if node.val > 1:
            node.val -= 1
            return node.key
        self._len_elm -= 1
        res = node.key
        # The maximum has no right child; splice in its left subtree.
        if pnode is None:
            self.root = self.root.left
        else:
            pnode.right = node.left
        return res

    def clear(self) -> None:
        """Remove every element."""
        self.root = None
        # Reset the counters too, otherwise len() would report stale sizes.
        self._len = 0
        self._len_elm = 0

    def __iter__(self):
        """Start an ascending iteration; the multiset is its own iterator."""
        self._it = self.get_min()
        self._cnt = 1
        return self

    def __next__(self):
        """Yield the current key until its multiplicity is exhausted, then move on."""
        if self._it is None:
            raise StopIteration
        res = self._it
        if self._cnt == self.count(self._it):
            self._it = self.gt(self._it)
            self._cnt = 1
        else:
            self._cnt += 1
        return res

    def __contains__(self, key: T):
        return BSTMultisetNodeBase.contains(self.root, key)

    def __bool__(self):
        return self.root is not None

    def __len__(self):
        return self._len

    def __str__(self):
        return "{" + ", ".join(map(str, self.tolist())) + "}"

    def __repr__(self):
        return f"TreapMultiset({self.tolist()})"

仕様

class TreapMultiset(a: Iterable[T] = [])[source]

Bases: OrderedMultisetInterface, Generic[T]

class Node(key: T, val: int = 1, priority: int = -1)[source]

Bases: object

class Random[source]

Bases: object

classmethod random() → int [source]
add(key: T, val: int = 1) → None [source]
clear() → None [source]
count(key: T) → int [source]
discard(key: T, val: int = 1) → bool [source]
discard_all(key: T) → bool [source]
ge(key: T) → T | None [source]
get_max() → T | None [source]
get_min() → T | None [source]
gt(key: T) → T | None [source]
le(key: T) → T | None [source]
len_elm() → int [source]
lt(key: T) → T | None [source]
pop_max() → T [source]
pop_min() → T [source]
remove(key: T, val: int = 1) → None [source]
show() → None [source]
tolist() → list[T] [source]
tolist_items() → list[tuple[T, int]] [source]