Source code for titan_pylib.data_structures.set.sorted_multiset

  1# https://github.com/tatyam-prime/SortedSet/blob/main/SortedMultiset.py
  2import math
  3from bisect import bisect_left, bisect_right
  4from typing import Generic, Iterable, Iterator, TypeVar, Optional
  5
# Generic element type for SortedMultiset. The class sorts and bisects its
# elements and compares them with <, <=, > and ==, so values of T must be
# mutually orderable.
T = TypeVar("T")
  7
  8
class SortedMultiset(Generic[T]):
    """Sorted multiset (duplicates allowed) built on a list of sorted buckets.

    Elements live in ``self.a``: a list of non-empty lists whose
    concatenation is in ascending order. ``self.size`` caches the total
    element count. Most queries scan bucket boundaries and then bisect
    inside a single bucket.

    Elements must be mutually orderable with ``<``, ``<=``, ``>`` and ``==``.

    Based on https://github.com/tatyam-prime/SortedSet (SortedMultiset.py).
    """

    # Average bucket size aimed for when building from an iterable.
    BUCKET_RATIO = 16
    # A bucket is split in two once its length exceeds
    # (current number of buckets) * SPLIT_RATIO.
    SPLIT_RATIO = 24

    def __init__(self, a: Iterable[T] = ()) -> None:
        """Make a new SortedMultiset from iterable. / O(N) if sorted / O(N log N)

        Args:
            a: Initial elements. The iterable is copied into a new list and
                never mutated. The default is an immutable empty tuple rather
                than a shared mutable list.
        """
        a = list(a)
        n = self.size = len(a)
        # Sort only when the input is not already non-decreasing, so sorted
        # input is built in linear time.
        if any(a[i] > a[i + 1] for i in range(n - 1)):
            a.sort()
        # ceil(sqrt(n / BUCKET_RATIO)) buckets. That is zero buckets when
        # n == 0, and never more buckets than elements, so no slice below
        # is empty.
        num_bucket = int(math.ceil(math.sqrt(n / self.BUCKET_RATIO)))
        self.a = [
            a[n * i // num_bucket : n * (i + 1) // num_bucket]
            for i in range(num_bucket)
        ]

    def __iter__(self) -> Iterator[T]:
        """Yield every element in ascending order."""
        for bucket in self.a:
            yield from bucket

    def __reversed__(self) -> Iterator[T]:
        """Yield every element in descending order."""
        for bucket in reversed(self.a):
            yield from reversed(bucket)

    def __eq__(self, other) -> bool:
        """Return True if ``other`` iterates to the same elements in sorted order.

        Any iterable is accepted, e.g. ``SortedMultiset([2, 1]) == [1, 2]``.
        A non-iterable ``other`` yields ``NotImplemented`` instead of raising
        ``TypeError``. Python then falls back to ``==`` -> False and
        ``!=`` -> True.
        """
        try:
            other_items = list(other)
        except TypeError:
            return NotImplemented
        return list(self) == other_items

    def __len__(self) -> int:
        """Return the number of stored elements, counting duplicates."""
        return self.size

    def __repr__(self) -> str:
        """Show the internal bucket layout, e.g. ``SortedMultiset[[1, 2], [3]]``."""
        return "SortedMultiset" + str(self.a)

    def __str__(self) -> str:
        """Render the elements in ascending order as ``{e1, e2, ...}``."""
        s = str(list(self))
        return "{" + s[1 : len(s) - 1] + "}"

    def _position(self, x: T) -> tuple[list[T], int, int]:
        "return the bucket, index of the bucket and position in which x should be. self must not be empty."
        # Pick the first bucket whose maximum is >= x. If x exceeds every
        # element, the loop runs out and leaves the last bucket selected.
        for i, a in enumerate(self.a):
            if x <= a[-1]:
                break
        return (a, i, bisect_left(a, x))

    def __contains__(self, x: T) -> bool:
        """Return True if at least one element equals ``x``."""
        if self.size == 0:
            return False
        a, _, i = self._position(x)
        return i != len(a) and a[i] == x

    def count(self, x: T) -> int:
        "Count the number of x."
        # (#elements <= x) - (#elements < x) == #elements == x
        return self.index_right(x) - self.index(x)

    def add(self, x: T) -> None:
        "Add an element. / O(√N)"
        if self.size == 0:
            self.a = [[x]]
            self.size = 1
            return
        a, b, i = self._position(x)
        a.insert(i, x)
        self.size += 1
        # Keep buckets from growing without bound: split an oversized bucket
        # in half in place, which preserves global order.
        if len(a) > len(self.a) * self.SPLIT_RATIO:
            mid = len(a) >> 1
            self.a[b : b + 1] = [a[:mid], a[mid:]]

    def _pop(self, a: list[T], b: int, i: int) -> T:
        """Remove and return ``a[i]``, where ``a`` is the bucket ``self.a[b]``.

        The emptied bucket is deleted so every bucket stays non-empty, which
        ``_position`` relies on (it reads ``a[-1]``). ``b`` may be negative.
        """
        ans = a.pop(i)
        self.size -= 1
        if not a:
            del self.a[b]
        return ans

    def discard(self, x: T) -> bool:
        "Remove an element and return True if removed. / O(√N)"
        if self.size == 0:
            return False
        a, b, i = self._position(x)
        if i == len(a) or a[i] != x:
            return False
        # Only one copy of x is removed.
        self._pop(a, b, i)
        return True

    def lt(self, x: T) -> Optional[T]:
        "Find the largest element < x, or None if it doesn't exist."
        # The last bucket (scanning from the right) that starts below x
        # holds the answer.
        for a in reversed(self.a):
            if a[0] < x:
                return a[bisect_left(a, x) - 1]
        return None

    def le(self, x: T) -> Optional[T]:
        "Find the largest element <= x, or None if it doesn't exist."
        for a in reversed(self.a):
            if a[0] <= x:
                return a[bisect_right(a, x) - 1]
        return None

    def gt(self, x: T) -> Optional[T]:
        "Find the smallest element > x, or None if it doesn't exist."
        # The first bucket (scanning from the left) that ends above x holds
        # the answer.
        for a in self.a:
            if a[-1] > x:
                return a[bisect_right(a, x)]
        return None

    def ge(self, x: T) -> Optional[T]:
        "Find the smallest element >= x, or None if it doesn't exist."
        for a in self.a:
            if a[-1] >= x:
                return a[bisect_left(a, x)]
        return None

    def __getitem__(self, i: int) -> T:
        """Return the i-th element in ascending order.

        Negative indices count from the end, as with lists.

        Raises:
            IndexError: if ``i`` is out of range.
        """
        if i < 0:
            for a in reversed(self.a):
                i += len(a)
                if i >= 0:
                    return a[i]
        else:
            for a in self.a:
                if i < len(a):
                    return a[i]
                i -= len(a)
        raise IndexError

    def pop(self, i: int = -1) -> T:
        """Pop and return the i-th element (default: the largest).

        Negative indices count from the end, as with lists.

        Raises:
            IndexError: if ``i`` is out of range, including on an empty multiset.
        """
        if i < 0:
            for b, a in enumerate(reversed(self.a)):
                i += len(a)
                if i >= 0:
                    # ~b == -(b + 1) is the bucket's index counted from the end.
                    return self._pop(a, ~b, i)
        else:
            for b, a in enumerate(self.a):
                if i < len(a):
                    return self._pop(a, b, i)
                i -= len(a)
        raise IndexError

    def index(self, x: T) -> int:
        "Count the number of elements < x."
        ans = 0
        for a in self.a:
            if a[-1] >= x:
                return ans + bisect_left(a, x)
            ans += len(a)
        return ans

    def index_right(self, x: T) -> int:
        "Count the number of elements <= x."
        ans = 0
        for a in self.a:
            if a[-1] > x:
                return ans + bisect_right(a, x)
            ans += len(a)
        return ans