scapegoat_tree_set

ソースコード

from titan_pylib.data_structures.scapegoat_tree.scapegoat_tree_set import ScapegoatTreeSet

View on GitHub

展開済みコード

  1# from titan_pylib.data_structures.scapegoat_tree.scapegoat_tree_set import ScapegoatTreeSet
  2# from titan_pylib.my_class.ordered_set_interface import OrderedSetInterface
  3# from titan_pylib.my_class.supports_less_than import SupportsLessThan
  4from typing import Protocol
  5
  6
  7class SupportsLessThan(Protocol):
  8
  9    def __lt__(self, other) -> bool: ...
 10from abc import ABC, abstractmethod
 11from typing import Iterable, Optional, Iterator, TypeVar, Generic
 12
 13T = TypeVar("T", bound=SupportsLessThan)
 14
 15
 16class OrderedSetInterface(ABC, Generic[T]):
 17
 18    @abstractmethod
 19    def __init__(self, a: Iterable[T]) -> None:
 20        raise NotImplementedError
 21
 22    @abstractmethod
 23    def add(self, key: T) -> bool:
 24        raise NotImplementedError
 25
 26    @abstractmethod
 27    def discard(self, key: T) -> bool:
 28        raise NotImplementedError
 29
 30    @abstractmethod
 31    def remove(self, key: T) -> None:
 32        raise NotImplementedError
 33
 34    @abstractmethod
 35    def le(self, key: T) -> Optional[T]:
 36        raise NotImplementedError
 37
 38    @abstractmethod
 39    def lt(self, key: T) -> Optional[T]:
 40        raise NotImplementedError
 41
 42    @abstractmethod
 43    def ge(self, key: T) -> Optional[T]:
 44        raise NotImplementedError
 45
 46    @abstractmethod
 47    def gt(self, key: T) -> Optional[T]:
 48        raise NotImplementedError
 49
 50    @abstractmethod
 51    def get_max(self) -> Optional[T]:
 52        raise NotImplementedError
 53
 54    @abstractmethod
 55    def get_min(self) -> Optional[T]:
 56        raise NotImplementedError
 57
 58    @abstractmethod
 59    def pop_max(self) -> T:
 60        raise NotImplementedError
 61
 62    @abstractmethod
 63    def pop_min(self) -> T:
 64        raise NotImplementedError
 65
 66    @abstractmethod
 67    def clear(self) -> None:
 68        raise NotImplementedError
 69
 70    @abstractmethod
 71    def tolist(self) -> list[T]:
 72        raise NotImplementedError
 73
 74    @abstractmethod
 75    def __iter__(self) -> Iterator:
 76        raise NotImplementedError
 77
 78    @abstractmethod
 79    def __next__(self) -> T:
 80        raise NotImplementedError
 81
 82    @abstractmethod
 83    def __contains__(self, key: T) -> bool:
 84        raise NotImplementedError
 85
 86    @abstractmethod
 87    def __len__(self) -> int:
 88        raise NotImplementedError
 89
 90    @abstractmethod
 91    def __bool__(self) -> bool:
 92        raise NotImplementedError
 93
 94    @abstractmethod
 95    def __str__(self) -> str:
 96        raise NotImplementedError
 97
 98    @abstractmethod
 99    def __repr__(self) -> str:
100        raise NotImplementedError
101# from titan_pylib.my_class.supports_less_than import SupportsLessThan
102# from titan_pylib.data_structures.bst_base.bst_set_node_base import BSTSetNodeBase
103from typing import TypeVar, Generic, Optional
104
T = TypeVar("T")
Node = TypeVar("Node")
# Node is duck-typed: it must expose ``key``, ``left`` and ``right``
# (``index``, ``index_right`` and ``kth_elm`` also read ``size``).
# TODO: describe this shape with a typing.Protocol.


class BSTSetNodeBase(Generic[T, Node]):
    """Stateless helpers shared by binary-search-tree set implementations.

    Every method is a ``staticmethod`` that walks a tree given its root node
    (``None`` for an empty tree). Keys are compared with ``<`` and ``==``.
    """

    @staticmethod
    def sort_unique(a: list[T]) -> list[T]:
        """Return the elements of ``a`` in ascending order without duplicates.

        If ``a`` is already strictly increasing it is returned as-is (no copy);
        otherwise a new list is built and ``a`` is left untouched.
        """
        if all(a[i] < a[i + 1] for i in range(len(a) - 1)):
            return a
        res: list[T] = []
        for elm in sorted(a):
            # After sorting, duplicates are adjacent.
            if res and res[-1] == elm:
                continue
            res.append(elm)
        return res

    @staticmethod
    def contains(node: Node, key: T) -> bool:
        """Return whether ``key`` is stored in the tree rooted at ``node``."""
        while node:
            if key == node.key:
                return True
            node = node.left if key < node.key else node.right
        return False

    @staticmethod
    def get_min(node: Node) -> Optional[T]:
        """Return the smallest key, or ``None`` for an empty tree."""
        if not node:
            return None
        while node.left:
            node = node.left
        return node.key

    @staticmethod
    def get_max(node: Node) -> Optional[T]:
        """Return the largest key, or ``None`` for an empty tree."""
        if not node:
            return None
        while node.right:
            node = node.right
        return node.key

    @staticmethod
    def le(node: Node, key: T) -> Optional[T]:
        """Return the largest stored key ``<= key``, or ``None`` if none exists."""
        res = None
        while node is not None:
            if key == node.key:
                # Return the stored element, not the (merely equal) query key.
                return node.key
            if key < node.key:
                node = node.left
            else:
                res = node.key
                node = node.right
        return res

    @staticmethod
    def lt(node: Node, key: T) -> Optional[T]:
        """Return the largest stored key ``< key``, or ``None`` if none exists."""
        res = None
        while node is not None:
            if key <= node.key:
                node = node.left
            else:
                res = node.key
                node = node.right
        return res

    @staticmethod
    def ge(node: Node, key: T) -> Optional[T]:
        """Return the smallest stored key ``>= key``, or ``None`` if none exists."""
        res = None
        while node is not None:
            if key == node.key:
                # Return the stored element, not the (merely equal) query key.
                return node.key
            if key < node.key:
                res = node.key
                node = node.left
            else:
                node = node.right
        return res

    @staticmethod
    def gt(node: Node, key: T) -> Optional[T]:
        """Return the smallest stored key ``> key``, or ``None`` if none exists."""
        res = None
        while node is not None:
            if key < node.key:
                res = node.key
                node = node.left
            else:
                node = node.right
        return res

    @staticmethod
    def index(node: Node, key: T) -> int:
        """Return the number of stored keys strictly less than ``key``."""
        k = 0
        while node is not None:
            if key == node.key:
                if node.left is not None:
                    k += node.left.size
                break
            if key < node.key:
                node = node.left
            else:
                # node and its whole left subtree are < key.
                k += 1 if node.left is None else node.left.size + 1
                node = node.right
        return k

    @staticmethod
    def index_right(node: Node, key: T) -> int:
        """Return the number of stored keys less than or equal to ``key``."""
        k = 0
        while node is not None:
            if key == node.key:
                k += 1 if node.left is None else node.left.size + 1
                break
            if key < node.key:
                node = node.left
            else:
                k += 1 if node.left is None else node.left.size + 1
                node = node.right
        return k

    @staticmethod
    def tolist(node: Node) -> list[T]:
        """Return all keys in ascending order (iterative in-order traversal)."""
        stack = []
        res = []
        while stack or node:
            if node:
                stack.append(node)
                node = node.left
            else:
                node = stack.pop()
                res.append(node.key)
                node = node.right
        return res

    @staticmethod
    def kth_elm(node: Node, k: int, _len: int) -> T:
        """Return the ``k``-th smallest key (0-indexed) of a tree holding ``_len`` keys.

        Negative ``k`` counts from the end, as with list indexing.

        Raises:
            IndexError: if ``k`` is out of range, including for an empty tree.
        """
        if k < 0:
            k += _len
        if not 0 <= k < _len:
            raise IndexError("index out of range")
        while True:
            t = 0 if node.left is None else node.left.size
            if t == k:
                return node.key
            if t > k:
                node = node.left
            else:
                node = node.right
                k -= t + 1
254import math
255from typing import Final, Iterator, TypeVar, Generic, Iterable, Optional
256
T = TypeVar("T", bound=SupportsLessThan)


class ScapegoatTreeSet(OrderedSetInterface, Generic[T]):
    """Ordered set of unique keys backed by a scapegoat tree.

    Balance is restored lazily on insertion. When a new node lands deeper than
    ``log_{1/ALPHA}(n)``, the nearest ancestor whose child on the insertion
    path holds more than ``ALPHA`` of its nodes (the "scapegoat") is rebuilt
    into a perfectly balanced subtree. The root is rebuilt if no such ancestor
    is found. Deletions do not rebalance.

    Keys must be mutually comparable with ``<`` and ``==``.
    """

    ALPHA: Final[float] = 0.75
    # depth * BETA > log2(n)  <=>  depth > log_{1/ALPHA}(n)
    BETA: Final[float] = math.log2(1 / ALPHA)

    class Node:
        """Tree node holding one key and the size of its subtree."""

        __slots__ = ("key", "left", "right", "size")

        def __init__(self, key: T):
            self.key: T = key
            self.left: Optional["ScapegoatTreeSet.Node"] = None
            self.right: Optional["ScapegoatTreeSet.Node"] = None
            self.size: int = 1  # number of nodes in the subtree rooted here

        def __str__(self):
            if self.left is None and self.right is None:
                return f"key:{self.key, self.size}\n"
            return (
                f"key:{self.key, self.size},\n left:{self.left},\n right:{self.right}\n"
            )

    def __init__(self, a: Iterable[T] = ()):
        """Create a set from the keys in ``a``; duplicates are dropped."""
        self.root: Optional["ScapegoatTreeSet.Node"] = None
        if not isinstance(a, list):
            a = list(a)
        if a:
            self._build(a)

    def _build(self, a: list[T]) -> None:
        """Replace the tree with a perfectly balanced one built from ``a``."""
        Node = ScapegoatTreeSet.Node

        def rec(l: int, r: int) -> ScapegoatTreeSet.Node:
            # a[mid] becomes the root of a[l:r]; recurse on both halves.
            mid = (l + r) >> 1
            node = Node(a[mid])
            if l != mid:
                node.left = rec(l, mid)
                node.size += node.left.size
            if mid + 1 != r:
                node.right = rec(mid + 1, r)
                node.size += node.right.size
            return node

        a = BSTSetNodeBase[T, ScapegoatTreeSet.Node].sort_unique(a)
        self.root = rec(0, len(a))

    def _rebuild(self, node: Node) -> Node:
        """Rebalance the subtree rooted at ``node`` and return its new root.

        The existing node objects are collected in key order and relinked into
        a perfectly balanced shape; their ``size`` fields are recomputed.
        """

        def rec(l: int, r: int) -> "ScapegoatTreeSet.Node":
            mid = (l + r) >> 1
            node = a[mid]
            node.size = 1
            if l != mid:
                node.left = rec(l, mid)
                node.size += node.left.size
            else:
                node.left = None
            if mid + 1 != r:
                node.right = rec(mid + 1, r)
                node.size += node.right.size
            else:
                node.right = None
            return node

        # Iterative in-order traversal collecting the nodes themselves.
        a = []
        stack = []
        while stack or node:
            if node:
                stack.append(node)
                node = node.left
            else:
                node = stack.pop()
                a.append(node)
                node = node.right
        return rec(0, len(a))

    def add(self, key: T) -> bool:
        """Insert ``key``. Return ``True`` if it was added, ``False`` if already present."""
        Node = ScapegoatTreeSet.Node
        node = self.root
        if node is None:
            self.root = Node(key)
            return True
        path = []  # ancestors of the new node, root first
        while node is not None:
            path.append(node)
            if key == node.key:
                return False
            node = node.left if key < node.key else node.right
        if key < path[-1].key:
            path[-1].left = Node(key)
        else:
            path[-1].right = Node(key)
        # The new node sits at depth len(path). Rebuild only if that exceeds
        # log_{1/ALPHA}(n) for the post-insertion size n. Both sides use base-2
        # logarithms, because BETA is log2(1/ALPHA).
        if len(path) * ScapegoatTreeSet.BETA > math.log2(self.root.size + 1):
            node_size = 1
            while path:
                pnode = path.pop()
                pnode_size = pnode.size + 1  # size including the new node
                if ScapegoatTreeSet.ALPHA * pnode_size < node_size:
                    break  # pnode is alpha-weight-unbalanced: the scapegoat
                node_size = pnode_size
            new_node = self._rebuild(pnode)
            if not path:
                self.root = new_node
                return True
            if new_node.key < path[-1].key:
                path[-1].left = new_node
            else:
                path[-1].right = new_node
        # Sizes below the scapegoat were recomputed by _rebuild; only the
        # remaining ancestors still need to count the new node.
        for p in path:
            p.size += 1
        return True

    def _unlink(self, node: Node, path: list, is_left: bool) -> None:
        """Remove ``node`` from the tree and fix subtree sizes.

        ``path`` holds the ancestors of ``node`` (root first). ``is_left`` says
        whether ``node`` is the left child of ``path[-1]``; it is ignored when
        ``path`` is empty.
        """
        if node.left is not None and node.right is not None:
            # Two children: move the in-order predecessor's key into node and
            # unlink the predecessor instead, which has no right child.
            path.append(node)
            lmax = node.left
            is_left = lmax.right is None
            while lmax.right is not None:
                path.append(lmax)
                lmax = lmax.right
            node.key = lmax.key
            node = lmax
        child = node.right if node.left is None else node.left
        if not path:
            self.root = child
        elif is_left:
            path[-1].left = child
        else:
            path[-1].right = child
        for p in path:
            p.size -= 1

    def discard(self, key: T) -> bool:
        """Remove ``key`` if present. Return whether it was removed."""
        node = self.root
        path = []
        is_left = True
        while node is not None:
            if key == node.key:
                break
            path.append(node)
            is_left = key < node.key
            node = node.left if is_left else node.right
        else:
            return False
        self._unlink(node, path, is_left)
        return True

    def remove(self, key: T) -> None:
        """Remove ``key``.

        Raises:
            KeyError: if ``key`` is not in the set.
        """
        if self.discard(key):
            return
        raise KeyError(key)

    def le(self, key: T) -> Optional[T]:
        """Return the largest key ``<= key``, or ``None``."""
        return BSTSetNodeBase[T, ScapegoatTreeSet.Node].le(self.root, key)

    def lt(self, key: T) -> Optional[T]:
        """Return the largest key ``< key``, or ``None``."""
        return BSTSetNodeBase[T, ScapegoatTreeSet.Node].lt(self.root, key)

    def ge(self, key: T) -> Optional[T]:
        """Return the smallest key ``>= key``, or ``None``."""
        return BSTSetNodeBase[T, ScapegoatTreeSet.Node].ge(self.root, key)

    def gt(self, key: T) -> Optional[T]:
        """Return the smallest key ``> key``, or ``None``."""
        return BSTSetNodeBase[T, ScapegoatTreeSet.Node].gt(self.root, key)

    def index(self, key: T) -> int:
        """Return the number of keys strictly less than ``key``."""
        return BSTSetNodeBase[T, ScapegoatTreeSet.Node].index(self.root, key)

    def index_right(self, key: T) -> int:
        """Return the number of keys less than or equal to ``key``."""
        return BSTSetNodeBase[T, ScapegoatTreeSet.Node].index_right(self.root, key)

    def pop(self, k: int = -1) -> T:
        """Remove and return the ``k``-th smallest key (negative ``k`` counts from the end).

        Raises:
            IndexError: if the set is empty or ``k`` is out of range.
        """
        n = len(self)
        if k < 0:
            k += n
        if not 0 <= k < n:
            raise IndexError("pop index out of range")
        node = self.root
        path = []
        is_left = True
        while True:
            t = 0 if node.left is None else node.left.size
            if t == k:
                break
            path.append(node)
            if t < k:
                node = node.right
                k -= t + 1
                is_left = False
            else:
                node = node.left
                is_left = True
        res = node.key  # read before _unlink may overwrite node.key
        self._unlink(node, path, is_left)
        return res

    def pop_min(self) -> T:
        """Remove and return the smallest key (``IndexError`` if empty)."""
        return self.pop(0)

    def pop_max(self) -> T:
        """Remove and return the largest key (``IndexError`` if empty)."""
        return self.pop(-1)

    def clear(self) -> None:
        """Remove every key."""
        self.root = None

    def tolist(self) -> list[T]:
        """Return the keys as a list in ascending order."""
        return BSTSetNodeBase[T, ScapegoatTreeSet.Node].tolist(self.root)

    def get_min(self) -> Optional[T]:
        """Return the smallest key, or ``None`` if the set is empty."""
        return BSTSetNodeBase[T, ScapegoatTreeSet.Node].get_min(self.root)

    def get_max(self) -> Optional[T]:
        """Return the largest key, or ``None`` if the set is empty."""
        return BSTSetNodeBase[T, ScapegoatTreeSet.Node].get_max(self.root)

    def __contains__(self, key: T) -> bool:
        node = self.root
        while node is not None:
            if key == node.key:
                return True
            node = node.left if key < node.key else node.right
        return False

    def __getitem__(self, k: int) -> T:
        """Return the ``k``-th smallest key (negative ``k`` counts from the end).

        Raises:
            IndexError: if ``k`` is out of range.
        """
        n = len(self)
        if k < 0:
            k += n
        if not 0 <= k < n:
            raise IndexError("ScapegoatTreeSet index out of range")
        return BSTSetNodeBase[T, ScapegoatTreeSet.Node].kth_elm(self.root, k, n)

    def __iter__(self) -> Iterator[T]:
        """Return an independent iterator over a snapshot of the keys, ascending.

        Each call yields a fresh iterator, so nested loops over the same set
        work, and a full pass is O(n).
        """
        # Reset the cursor so the legacy ``iter(s); next(s)`` usage keeps working.
        self.__iter = 0
        return iter(self.tolist())

    def __next__(self) -> T:
        """Legacy cursor-based iteration, reset by ``iter(self)``."""
        if self.__iter >= self.__len__():
            raise StopIteration
        res = self[self.__iter]
        self.__iter += 1
        return res

    def __reversed__(self) -> Iterator[T]:
        yield from reversed(self.tolist())

    def __len__(self):
        return 0 if self.root is None else self.root.size

    def __bool__(self):
        return self.root is not None

    def __str__(self):
        return "{" + ", ".join(map(str, self.tolist())) + "}"

    def __repr__(self):
        return f"{self.__class__.__name__}({self})"

仕様

class ScapegoatTreeSet(a: Iterable[T] = [])[source]

Bases: OrderedSetInterface, Generic[T]

ALPHA: Final[float] = 0.75
BETA: Final[float] = 0.41503749927884376
class Node(key: T)[source]

Bases: object

add(key: T) bool[source]
clear() None[source]
discard(key: T) bool[source]
ge(key: T) T | None[source]
get_max() T[source]
get_min() T[source]
gt(key: T) T | None[source]
index(key: T) int[source]
index_right(key: T) int[source]
le(key: T) T | None[source]
lt(key: T) T | None[source]
pop(k: int = -1) T[source]
pop_max() T[source]
pop_min() T[source]
remove(key: T) None[source]
tolist() list[T][source]