treap_set

ソースコード

from titan_pylib.data_structures.treap.treap_set import TreapSet

view on github

展開済みコード

  1# from titan_pylib.data_structures.treap.treap_set import TreapSet
  2# from titan_pylib.my_class.ordered_set_interface import OrderedSetInterface
  3# from titan_pylib.my_class.supports_less_than import SupportsLessThan
  4from typing import Protocol
  5
  6
  7class SupportsLessThan(Protocol):
  8
  9    def __lt__(self, other) -> bool: ...
 10from abc import ABC, abstractmethod
 11from typing import Iterable, Optional, Iterator, TypeVar, Generic
 12
 13T = TypeVar("T", bound=SupportsLessThan)
 14
 15
 16class OrderedSetInterface(ABC, Generic[T]):
 17
 18    @abstractmethod
 19    def __init__(self, a: Iterable[T]) -> None:
 20        raise NotImplementedError
 21
 22    @abstractmethod
 23    def add(self, key: T) -> bool:
 24        raise NotImplementedError
 25
 26    @abstractmethod
 27    def discard(self, key: T) -> bool:
 28        raise NotImplementedError
 29
 30    @abstractmethod
 31    def remove(self, key: T) -> None:
 32        raise NotImplementedError
 33
 34    @abstractmethod
 35    def le(self, key: T) -> Optional[T]:
 36        raise NotImplementedError
 37
 38    @abstractmethod
 39    def lt(self, key: T) -> Optional[T]:
 40        raise NotImplementedError
 41
 42    @abstractmethod
 43    def ge(self, key: T) -> Optional[T]:
 44        raise NotImplementedError
 45
 46    @abstractmethod
 47    def gt(self, key: T) -> Optional[T]:
 48        raise NotImplementedError
 49
 50    @abstractmethod
 51    def get_max(self) -> Optional[T]:
 52        raise NotImplementedError
 53
 54    @abstractmethod
 55    def get_min(self) -> Optional[T]:
 56        raise NotImplementedError
 57
 58    @abstractmethod
 59    def pop_max(self) -> T:
 60        raise NotImplementedError
 61
 62    @abstractmethod
 63    def pop_min(self) -> T:
 64        raise NotImplementedError
 65
 66    @abstractmethod
 67    def clear(self) -> None:
 68        raise NotImplementedError
 69
 70    @abstractmethod
 71    def tolist(self) -> list[T]:
 72        raise NotImplementedError
 73
 74    @abstractmethod
 75    def __iter__(self) -> Iterator:
 76        raise NotImplementedError
 77
 78    @abstractmethod
 79    def __next__(self) -> T:
 80        raise NotImplementedError
 81
 82    @abstractmethod
 83    def __contains__(self, key: T) -> bool:
 84        raise NotImplementedError
 85
 86    @abstractmethod
 87    def __len__(self) -> int:
 88        raise NotImplementedError
 89
 90    @abstractmethod
 91    def __bool__(self) -> bool:
 92        raise NotImplementedError
 93
 94    @abstractmethod
 95    def __str__(self) -> str:
 96        raise NotImplementedError
 97
 98    @abstractmethod
 99    def __repr__(self) -> str:
100        raise NotImplementedError
101# from titan_pylib.my_class.supports_less_than import SupportsLessThan
102# from titan_pylib.data_structures.bst_base.bst_set_node_base import BSTSetNodeBase
103from typing import TypeVar, Generic, Optional
104
T = TypeVar("T")
Node = TypeVar("Node")
# ``Node`` is only a type variable, so the node contract is duck-typed:
# every node must expose ``key``, ``left`` and ``right``; ``index``,
# ``index_right`` and ``kth_elm`` additionally read ``size``.
# (A Protocol could be used to spell out key / left / right.)


class BSTSetNodeBase(Generic[T, Node]):
    """Stateless helpers shared by binary-search-tree based set classes.

    Every helper is a ``staticmethod`` that receives the root of the
    (sub)tree to work on; ``None`` stands for an empty tree.
    """

    @staticmethod
    def sort_unique(a: list[T]) -> list[T]:
        """Return the elements of ``a`` in strictly increasing order.

        When ``a`` is already strictly increasing (this includes empty and
        one-element lists) the very same list object is returned.
        Otherwise a new sorted list without duplicates is built and ``a``
        itself is left untouched.
        """
        if all(a[i] < a[i + 1] for i in range(len(a) - 1)):
            return a
        # Reaching here implies len(a) >= 2, so ordered[0] exists.
        ordered = sorted(a)
        unique = [ordered[0]]
        for elm in ordered:
            if unique[-1] == elm:
                continue
            unique.append(elm)
        return unique

    @staticmethod
    def contains(node: Node, key: T) -> bool:
        """Return ``True`` if ``key`` is stored in the tree rooted at ``node``."""
        while node:
            if key == node.key:
                return True
            node = node.left if key < node.key else node.right
        return False

    @staticmethod
    def get_min(node: Node) -> Optional[T]:
        """Return the smallest key of the tree, or ``None`` if it is empty."""
        if not node:
            return None
        while node.left:
            node = node.left
        return node.key

    @staticmethod
    def get_max(node: Node) -> Optional[T]:
        """Return the largest key of the tree, or ``None`` if it is empty."""
        if not node:
            return None
        while node.right:
            node = node.right
        return node.key

    @staticmethod
    def le(node: Node, key: T) -> Optional[T]:
        """Return the largest stored key ``<= key``, or ``None``.

        On an exact match the ``key`` argument itself is returned.
        """
        res = None
        while node is not None:
            if key == node.key:
                res = key
                break
            if key < node.key:
                node = node.left
            else:
                # node.key < key: a candidate; anything better lies to the right.
                res = node.key
                node = node.right
        return res

    @staticmethod
    def lt(node: Node, key: T) -> Optional[T]:
        """Return the largest stored key ``< key``, or ``None``."""
        res = None
        while node is not None:
            if key <= node.key:
                node = node.left
            else:
                res = node.key
                node = node.right
        return res

    @staticmethod
    def ge(node: Node, key: T) -> Optional[T]:
        """Return the smallest stored key ``>= key``, or ``None``.

        On an exact match the ``key`` argument itself is returned.
        """
        res = None
        while node is not None:
            if key == node.key:
                res = key
                break
            if key < node.key:
                # node.key > key: a candidate; anything better lies to the left.
                res = node.key
                node = node.left
            else:
                node = node.right
        return res

    @staticmethod
    def gt(node: Node, key: T) -> Optional[T]:
        """Return the smallest stored key ``> key``, or ``None``."""
        res = None
        while node is not None:
            if key < node.key:
                res = node.key
                node = node.left
            else:
                node = node.right
        return res

    @staticmethod
    def index(node: Node, key: T) -> int:
        """Return the number of stored keys strictly smaller than ``key``.

        Reads ``size`` from the left child of each visited node.
        """
        k = 0
        while node is not None:
            if key == node.key:
                if node.left is not None:
                    k += node.left.size
                break
            if key < node.key:
                node = node.left
            else:
                # Skip this node and its whole left subtree.
                k += 1 if node.left is None else node.left.size + 1
                node = node.right
        return k

    @staticmethod
    def index_right(node: Node, key: T) -> int:
        """Return the number of stored keys smaller than or equal to ``key``.

        Reads ``size`` from the left child of each visited node.
        """
        k = 0
        while node is not None:
            if key == node.key:
                k += 1 if node.left is None else node.left.size + 1
                break
            if key < node.key:
                node = node.left
            else:
                k += 1 if node.left is None else node.left.size + 1
                node = node.right
        return k

    @staticmethod
    def tolist(node: Node) -> list[T]:
        """Return all keys of the tree in in-order, using an explicit stack."""
        stack = []
        res = []
        while stack or node:
            if node:
                stack.append(node)
                node = node.left
            else:
                node = stack.pop()
                res.append(node.key)
                node = node.right
        return res

    @staticmethod
    def kth_elm(node: Node, k: int, _len: int) -> T:
        """Return the ``k``-th smallest key (0-indexed) of the tree.

        Args:
            node: Root of the tree; nodes must expose ``size``.
            k: Position to fetch. A negative value has ``_len`` added to it
                once, as with list indexing.
            _len: Number of keys in the tree (only used for negative ``k``).

        Raises:
            IndexError: If the position does not exist in the tree.
        """
        if k < 0:
            k += _len
        # Stop when the walk leaves the tree instead of dereferencing None,
        # which is what happens for an out-of-range position.
        while node is not None:
            t = 0 if node.left is None else node.left.size
            if t == k:
                return node.key
            if t > k:
                node = node.left
            else:
                node = node.right
                k -= t + 1
        raise IndexError("kth_elm index out of range")
254from typing import Generic, Iterable, TypeVar, Optional
255
256T = TypeVar("T", bound=SupportsLessThan)
257
258
class TreapSet(OrderedSetInterface, Generic[T]):
    """Ordered set backed by a treap.

    Balance is maintained with pseudo-random node priorities. (Original
    author's remark: can this actually get hacked, I wonder? For now only
    the set and multiset variants exist.)

    Keys must be mutually comparable with ``<`` and ``==``.
    """

    class Random:
        """xorshift128 pseudo-random generator producing node priorities.

        The state lives on the class, so it is shared by every treap and the
        produced sequence is the same on every program run.
        """

        _x, _y, _z, _w = 123456789, 362436069, 521288629, 88675123

        @classmethod
        def random(cls) -> int:
            """Advance the shared state and return the next value."""
            t = (cls._x ^ ((cls._x << 11) & 0xFFFFFFFF)) & 0xFFFFFFFF
            cls._x, cls._y, cls._z = cls._y, cls._z, cls._w
            cls._w = (cls._w ^ (cls._w >> 19)) ^ (
                t ^ ((t >> 8)) & 0xFFFFFFFF
            ) & 0xFFFFFFFF
            return cls._w

    class Node:
        """Tree node holding a key, two child links and a priority."""

        def __init__(self, key: T, priority: int = -1):
            """Create a leaf; ``priority == -1`` means "draw a random one"."""
            self.key: T = key
            self.left: Optional["TreapSet.Node"] = None
            self.right: Optional["TreapSet.Node"] = None
            self.priority: int = (
                TreapSet.Random.random() if priority == -1 else priority
            )

        def __str__(self):
            """Render the subtree rooted here (debugging aid)."""
            if self.left is None and self.right is None:
                return f"key:{self.key, self.priority}\n"
            return f"key:{self.key, self.priority},\n left:{self.left},\n right:{self.right}\n"

    def __init__(self, a: Iterable[T] = []):
        """Create a set holding the distinct elements of ``a``.

        The shared default list is never mutated: an empty input skips
        ``_build`` and ``_build`` does not modify its argument.
        """
        self.root: Optional["TreapSet.Node"] = None
        self._len: int = 0
        if not isinstance(a, list):
            a = list(a)
        if a:
            self._build(a)

    def _build(self, a: list[T]) -> None:
        """Replace the content with a balanced tree over the keys of ``a``."""
        Node = TreapSet.Node

        def rec(l: int, r: int) -> TreapSet.Node:
            # Build the subtree for the half-open index range [l, r).
            mid = (l + r) >> 1
            node = Node(a[mid], rand[mid])
            if l != mid:
                node.left = rec(l, mid)
            if mid + 1 != r:
                node.right = rec(mid + 1, r)
            return node

        a = BSTSetNodeBase[T, TreapSet.Node].sort_unique(a)
        self._len = len(a)
        # NOTE(review): priorities are handed out in ascending order by
        # position, so a left child gets a priority no larger than its
        # parent's, whereas add() rotates whenever a child's priority is
        # smaller than its parent's. Key ordering is unaffected -- confirm
        # whether this assignment is intended.
        rand = sorted(TreapSet.Random.random() for _ in range(self._len))
        self.root = rec(0, self._len)

    def _rotate_L(self, node: Node) -> Node:
        """Lift ``node.left`` above ``node`` and return the new subtree root."""
        u = node.left
        node.left = u.right
        u.right = node
        return u

    def _rotate_R(self, node: Node) -> Node:
        """Lift ``node.right`` above ``node`` and return the new subtree root."""
        u = node.right
        node.right = u.left
        u.left = node
        return u

    def add(self, key: T) -> bool:
        """Insert ``key``.

        Returns:
            ``True`` if the key was inserted, ``False`` if it was already
            present.
        """
        if not self.root:
            self.root = TreapSet.Node(key)
            self._len = 1
            return True
        node = self.root
        path = []
        # Bit stack of the turns taken on the way down: 1 = left, 0 = right;
        # the lowest bit belongs to the deepest node of ``path``.
        di = 0
        while node:
            if key == node.key:
                return False
            path.append(node)
            di <<= 1
            if key < node.key:
                di |= 1
                node = node.left
            else:
                node = node.right
        if di & 1:
            path[-1].left = TreapSet.Node(key)
        else:
            path[-1].right = TreapSet.Node(key)
        # Walk back up; rotate wherever the child on the path has a smaller
        # priority than its parent, then re-link the rotated subtree.
        while path:
            new_node = None
            node = path.pop()
            if di & 1:
                if node.left.priority < node.priority:
                    new_node = self._rotate_L(node)
            else:
                if node.right.priority < node.priority:
                    new_node = self._rotate_R(node)
            di >>= 1
            if new_node:
                if path:
                    if di & 1:
                        path[-1].left = new_node
                    else:
                        path[-1].right = new_node
                else:
                    self.root = new_node
        self._len += 1
        return True

    def discard(self, key: T) -> bool:
        """Remove ``key`` if present.

        Returns:
            ``True`` if the key was removed, ``False`` if it was not found.
        """
        node = self.root
        pnode = None
        while node:
            if key == node.key:
                break
            pnode = node
            node = node.left if key < node.key else node.right
        else:
            return False
        self._len -= 1
        # Rotate the doomed node downwards, always lifting the child with the
        # smaller priority, until it has at most one child.
        while node.left and node.right:
            if node.left.priority < node.right.priority:
                new_node = self._rotate_L(node)
            else:
                new_node = self._rotate_R(node)
            if not pnode:
                self.root = new_node
            elif node.key < pnode.key:
                pnode.left = new_node
            else:
                pnode.right = new_node
            pnode = new_node
        # Splice the node out by linking its only child (or None) in its place.
        child = node.right if node.left is None else node.left
        if not pnode:
            self.root = child
        elif node.key < pnode.key:
            pnode.left = child
        else:
            pnode.right = child
        return True

    def remove(self, key: T) -> None:
        """Remove ``key``.

        Raises:
            KeyError: If ``key`` is not in the set.
        """
        if self.discard(key):
            return
        raise KeyError(key)

    def le(self, key: T) -> Optional[T]:
        """Return the largest element ``<= key``, or ``None``."""
        return BSTSetNodeBase[T, TreapSet.Node].le(self.root, key)

    def lt(self, key: T) -> Optional[T]:
        """Return the largest element ``< key``, or ``None``."""
        return BSTSetNodeBase[T, TreapSet.Node].lt(self.root, key)

    def ge(self, key: T) -> Optional[T]:
        """Return the smallest element ``>= key``, or ``None``."""
        return BSTSetNodeBase[T, TreapSet.Node].ge(self.root, key)

    def gt(self, key: T) -> Optional[T]:
        """Return the smallest element ``> key``, or ``None``."""
        return BSTSetNodeBase[T, TreapSet.Node].gt(self.root, key)

    def get_min(self) -> Optional[T]:
        """Return the smallest element, or ``None`` if the set is empty."""
        return BSTSetNodeBase[T, TreapSet.Node].get_min(self.root)

    def get_max(self) -> Optional[T]:
        """Return the largest element, or ``None`` if the set is empty."""
        return BSTSetNodeBase[T, TreapSet.Node].get_max(self.root)

    def pop_min(self) -> T:
        """Remove and return the smallest element (the set must be non-empty)."""
        assert self.root, f"IndexError: pop_min() from Empty {self.__class__.__name__}."
        node = self.root
        pnode = None
        while node.left:
            pnode = node
            node = node.left
        self._len -= 1
        res = node.key
        if not pnode:
            self.root = self.root.right
        else:
            pnode.left = node.right
        return res

    def pop_max(self) -> T:
        """Remove and return the largest element (the set must be non-empty)."""
        assert self.root, f"IndexError: pop_max() from Empty {self.__class__.__name__}."
        node = self.root
        pnode = None
        while node.right:
            pnode = node
            node = node.right
        self._len -= 1
        res = node.key
        if not pnode:
            self.root = self.root.left
        else:
            pnode.right = node.left
        return res

    def clear(self) -> None:
        """Remove every element."""
        self.root = None
        # Keep the cached size in sync; otherwise len() and bool() would keep
        # reporting the size from before the clear.
        self._len = 0

    def tolist(self) -> list[T]:
        """Return the elements as a list in ascending order."""
        return BSTSetNodeBase[T, TreapSet.Node].tolist(self.root)

    def __iter__(self):
        """Start an ascending iteration; the cursor is stored on the set itself."""
        self._it = self.get_min()
        return self

    def __next__(self):
        """Return the element under the cursor and advance to its successor."""
        if self._it is None:
            raise StopIteration
        res = self._it
        self._it = self.gt(self._it)
        return res

    def __contains__(self, key: T):
        """Return whether ``key`` is an element of the set."""
        return BSTSetNodeBase[T, TreapSet.Node].contains(self.root, key)

    def __len__(self):
        """Return the number of elements."""
        return self._len

    def __bool__(self):
        """Return ``True`` if the set holds at least one element."""
        return self._len > 0

    def __str__(self):
        """Render the elements as ``{a, b, c}`` in ascending order."""
        return "{" + ", ".join(map(str, self.tolist())) + "}"

    def __repr__(self):
        """Render as ``ClassName([a, b, c])``."""
        return f"{self.__class__.__name__}({self.tolist()})"

仕様

class TreapSet(a: Iterable[T] = [])[source]

Bases: OrderedSetInterface, Generic[T]

treap です。

乱数を使用して平衡を保っています。Hackされることなんてあるんですかね。今のところ集合と多重集合しかないです。

class Node(key: T, priority: int = -1)[source]

Bases: object

class Random[source]

Bases: object

classmethod random() → int[source]
add(key: T) → bool[source]
clear() → None[source]
discard(key: T) → bool[source]
ge(key: T) → T | None[source]
get_max() → T | None[source]
get_min() → T | None[source]
gt(key: T) → T | None[source]
le(key: T) → T | None[source]
lt(key: T) → T | None[source]
pop_max() → T[source]
pop_min() → T[source]
remove(key: T) → None[source]
tolist() → list[T][source]