1# https://github.com/tatyam-prime/SortedSet/blob/main/SortedMultiset.py
2import math
3from bisect import bisect_left, bisect_right
4from typing import Generic, Iterable, Iterator, TypeVar, Optional
5
# Element type of SortedMultiset. Elements must be mutually comparable with
# <, <=, >, >= and ==, because the container sorts and bisects them.
T = TypeVar("T")


class SortedMultiset(Generic[T]):
    """Sorted multiset (duplicates allowed) stored as a list of sorted buckets.

    Invariants maintained by the methods below:
      * ``self.a`` is a list of non-empty lists; concatenated in order they
        form one non-decreasing sequence.
      * ``self.size`` is the total number of stored elements.

    Lookups scan the buckets linearly, then bisect inside a single bucket.
    """

    # Initial bucket count is ceil(sqrt(n / BUCKET_RATIO)).
    BUCKET_RATIO = 16
    # A bucket is split in half once it holds more than
    # len(self.a) * SPLIT_RATIO elements.
    SPLIT_RATIO = 24

    def __init__(self, a: Iterable[T] = ()) -> None:
        """Make a new SortedMultiset from iterable. / O(N) if sorted / O(N log N)

        Args:
            a: initial elements (any iterable; it is copied, never mutated).
        """
        a = list(a)
        n = self.size = len(a)
        # Only sort when the input is not already non-decreasing.
        if any(a[i] > a[i + 1] for i in range(n - 1)):
            a.sort()
        num_bucket = int(math.ceil(math.sqrt(n / self.BUCKET_RATIO)))
        # Cut into num_bucket contiguous, nearly equal slices. For n >= 1,
        # num_bucket <= n, so no slice is empty; for n == 0 there are none.
        self.a = [
            a[n * i // num_bucket : n * (i + 1) // num_bucket]
            for i in range(num_bucket)
        ]

    def __iter__(self) -> Iterator[T]:
        """Iterate over all elements in non-decreasing order."""
        for bucket in self.a:
            yield from bucket

    def __reversed__(self) -> Iterator[T]:
        """Iterate over all elements in non-increasing order."""
        for bucket in reversed(self.a):
            yield from reversed(bucket)

    def __eq__(self, other: object) -> bool:
        """Return True if ``other`` iterates to the same elements in the same order.

        Any iterable may be compared (e.g. a sorted list). For non-iterables
        this returns NotImplemented, so Python falls back to its default
        comparison instead of raising TypeError.
        """
        try:
            other_iter = iter(other)
        except TypeError:
            return NotImplemented
        return list(self) == list(other_iter)

    def __len__(self) -> int:
        """Return the number of stored elements (duplicates counted)."""
        return self.size

    def __repr__(self) -> str:
        # Shows the internal bucket structure, e.g. "SortedMultiset[[1, 2], [3]]".
        return "SortedMultiset" + str(self.a)

    def __str__(self) -> str:
        # Set-like rendering of the flattened elements, e.g. "{1, 2, 3}".
        s = str(list(self))
        return "{" + s[1 : len(s) - 1] + "}"

    def _position(self, x: T) -> tuple[list[T], int, int]:
        """Return (bucket, bucket index, bisect_left position) where x belongs.

        The bucket is the first one whose last element is >= x. If x is
        greater than every element, the loop finishes without breaking and
        the last bucket is returned with position len(bucket).
        self must not be empty.
        """
        for i, a in enumerate(self.a):
            if x <= a[-1]:
                break
        return (a, i, bisect_left(a, x))

    def __contains__(self, x: T) -> bool:
        """Return True if at least one element equals x."""
        if self.size == 0:
            return False
        a, _, i = self._position(x)
        return i != len(a) and a[i] == x

    def count(self, x: T) -> int:
        """Count the number of x."""
        return self.index_right(x) - self.index(x)

    def add(self, x: T) -> None:
        """Add an element (duplicates allowed). / O(√N)"""
        if self.size == 0:
            self.a = [[x]]
            self.size = 1
            return
        a, b, i = self._position(x)
        a.insert(i, x)
        self.size += 1
        # Split an oversized bucket in half. Both halves are non-empty
        # because the threshold is at least SPLIT_RATIO.
        if len(a) > len(self.a) * self.SPLIT_RATIO:
            mid = len(a) >> 1
            self.a[b : b + 1] = [a[:mid], a[mid:]]

    def _pop(self, a: list[T], b: int, i: int) -> T:
        """Remove and return a[i], where a is the bucket self.a[b].

        A bucket left empty is deleted, so every bucket stays non-empty.
        """
        ans = a.pop(i)
        self.size -= 1
        if not a:
            del self.a[b]
        return ans

    def discard(self, x: T) -> bool:
        """Remove one occurrence of x and return True if removed. / O(√N)"""
        if self.size == 0:
            return False
        a, b, i = self._position(x)
        if i == len(a) or a[i] != x:
            return False
        self._pop(a, b, i)
        return True

    def lt(self, x: T) -> Optional[T]:
        """Find the largest element < x, or None if it doesn't exist."""
        for a in reversed(self.a):
            if a[0] < x:
                # a[0] < x guarantees bisect_left(a, x) >= 1.
                return a[bisect_left(a, x) - 1]
        return None

    def le(self, x: T) -> Optional[T]:
        """Find the largest element <= x, or None if it doesn't exist."""
        for a in reversed(self.a):
            if a[0] <= x:
                return a[bisect_right(a, x) - 1]
        return None

    def gt(self, x: T) -> Optional[T]:
        """Find the smallest element > x, or None if it doesn't exist."""
        for a in self.a:
            if a[-1] > x:
                return a[bisect_right(a, x)]
        return None

    def ge(self, x: T) -> Optional[T]:
        """Find the smallest element >= x, or None if it doesn't exist."""
        for a in self.a:
            if a[-1] >= x:
                return a[bisect_left(a, x)]
        return None

    def __getitem__(self, i: int) -> T:
        """Return the i-th element in sorted order (negative i counts from the end).

        Raises:
            IndexError: if i is out of range.
        """
        if i < 0:
            for a in reversed(self.a):
                i += len(a)
                if i >= 0:
                    return a[i]
        else:
            for a in self.a:
                if i < len(a):
                    return a[i]
                i -= len(a)
        raise IndexError("SortedMultiset index out of range")

    def pop(self, i: int = -1) -> T:
        """Remove and return the i-th element in sorted order (default: the largest).

        Raises:
            IndexError: if the multiset is empty or i is out of range.
        """
        if i < 0:
            for b, a in enumerate(reversed(self.a)):
                i += len(a)
                if i >= 0:
                    # ~b == -b - 1: the position of this bucket counted from
                    # the end of self.a, matching enumerate(reversed(...)).
                    return self._pop(a, ~b, i)
        else:
            for b, a in enumerate(self.a):
                if i < len(a):
                    return self._pop(a, b, i)
                i -= len(a)
        raise IndexError("pop index out of range")

    def index(self, x: T) -> int:
        """Count the number of elements < x."""
        ans = 0
        for a in self.a:
            if a[-1] >= x:
                return ans + bisect_left(a, x)
            ans += len(a)
        return ans

    def index_right(self, x: T) -> int:
        """Count the number of elements <= x."""
        ans = 0
        for a in self.a:
            if a[-1] > x:
                return ans + bisect_right(a, x)
            ans += len(a)
        return ans