1# from titan_pylib.data_structures.set.sorted_multiset import SortedMultiset
2# https://github.com/tatyam-prime/SortedSet/blob/main/SortedMultiset.py
3import math
4from bisect import bisect_left, bisect_right
5from typing import Generic, Iterable, Iterator, TypeVar, Optional
6
T = TypeVar("T")


class SortedMultiset(Generic[T]):
    """Sorted multiset built on a list of sorted buckets.

    ``self.a`` is a list of non-empty sorted lists ("buckets") whose
    concatenation is the whole multiset in ascending order; ``self.size`` is
    the total number of elements. Elements must support ``<``, ``<=``, ``>``,
    ``>=`` and ``==`` against each other.
    """

    # __init__ creates about ceil(sqrt(N / BUCKET_RATIO)) buckets.
    BUCKET_RATIO = 16
    # add() splits a bucket in half once its length exceeds
    # len(self.a) * SPLIT_RATIO.
    SPLIT_RATIO = 24

    def __init__(self, a: Iterable[T] = ()) -> None:
        """Make a new SortedMultiset from iterable. / O(N) if sorted / O(N log N)

        The default is an immutable empty tuple rather than ``[]`` so that no
        mutable object is shared between calls.
        """
        a = list(a)
        n = self.size = len(a)
        # Sort only when needed, so already-sorted input is built in O(N).
        if any(a[i] > a[i + 1] for i in range(n - 1)):
            a.sort()
        num_bucket = int(math.ceil(math.sqrt(n / self.BUCKET_RATIO)))
        # For n >= 1, num_bucket <= n, so every slice below is non-empty;
        # n == 0 gives num_bucket == 0 and therefore no buckets at all.
        self.a = [
            a[n * i // num_bucket : n * (i + 1) // num_bucket]
            for i in range(num_bucket)
        ]

    def __iter__(self) -> Iterator[T]:
        """Iterate over all elements in ascending order."""
        for bucket in self.a:
            yield from bucket

    def __reversed__(self) -> Iterator[T]:
        """Iterate over all elements in descending order."""
        for bucket in reversed(self.a):
            yield from reversed(bucket)

    def __eq__(self, other: object) -> bool:
        """Compare the element sequence with any iterable.

        Another SortedMultiset, a list, etc. compares equal when it yields
        the same elements in the same order. For a non-iterable ``other``
        this returns NotImplemented, so ``==`` evaluates to False (and
        ``!=`` to True) instead of raising TypeError.
        """
        try:
            it = iter(other)
        except TypeError:
            return NotImplemented
        return list(self) == list(it)

    def __len__(self) -> int:
        """Return the number of elements, counting duplicates."""
        return self.size

    def __repr__(self) -> str:
        # Shows the internal bucket layout, which is useful when debugging.
        return "SortedMultiset" + str(self.a)

    def __str__(self) -> str:
        # Render like a set literal: "{1, 2, 2}".
        s = str(list(self))
        return "{" + s[1 : len(s) - 1] + "}"

    def _position(self, x: T) -> tuple[list[T], int, int]:
        """Return the bucket, index of the bucket and position in which x should be.

        self must not be empty. The chosen bucket is the first one whose
        maximum is >= x, or the last bucket if x exceeds every element. The
        position is ``bisect_left`` within that bucket.
        """
        for i, a in enumerate(self.a):
            if x <= a[-1]:
                break
        # If no break occurred, `a`/`i` still refer to the last bucket,
        # and bisect_left returns len(a), i.e. "append at the end".
        return (a, i, bisect_left(a, x))

    def __contains__(self, x: T) -> bool:
        """Return True if at least one element equals x."""
        if self.size == 0:
            return False
        a, _, i = self._position(x)
        return i != len(a) and a[i] == x

    def count(self, x: T) -> int:
        """Count the number of x."""
        return self.index_right(x) - self.index(x)

    def add(self, x: T) -> None:
        """Add an element. / O(√N)"""
        if self.size == 0:
            self.a = [[x]]
            self.size = 1
            return
        a, b, i = self._position(x)
        a.insert(i, x)
        self.size += 1
        # Keep buckets small relative to the bucket count: split an
        # oversized bucket into two halves in place.
        if len(a) > len(self.a) * self.SPLIT_RATIO:
            mid = len(a) >> 1
            self.a[b : b + 1] = [a[:mid], a[mid:]]

    def _pop(self, a: list[T], b: int, i: int) -> T:
        """Remove and return a[i], where a is self.a[b] (b may be negative).

        Emptied buckets are deleted, so every bucket stays non-empty.
        """
        ans = a.pop(i)
        self.size -= 1
        if not a:
            del self.a[b]
        return ans

    def discard(self, x: T) -> bool:
        """Remove one occurrence of x and return True if removed. / O(√N)"""
        if self.size == 0:
            return False
        a, b, i = self._position(x)
        if i == len(a) or a[i] != x:
            return False
        self._pop(a, b, i)
        return True

    def lt(self, x: T) -> Optional[T]:
        """Find the largest element < x, or None if it doesn't exist."""
        # The last bucket with an element < x holds the answer, just left
        # of x's bisect_left position.
        for a in reversed(self.a):
            if a[0] < x:
                return a[bisect_left(a, x) - 1]
        return None

    def le(self, x: T) -> Optional[T]:
        """Find the largest element <= x, or None if it doesn't exist."""
        for a in reversed(self.a):
            if a[0] <= x:
                return a[bisect_right(a, x) - 1]
        return None

    def gt(self, x: T) -> Optional[T]:
        """Find the smallest element > x, or None if it doesn't exist."""
        # The first bucket with an element > x holds the answer at x's
        # bisect_right position.
        for a in self.a:
            if a[-1] > x:
                return a[bisect_right(a, x)]
        return None

    def ge(self, x: T) -> Optional[T]:
        """Find the smallest element >= x, or None if it doesn't exist."""
        for a in self.a:
            if a[-1] >= x:
                return a[bisect_left(a, x)]
        return None

    def __getitem__(self, i: int) -> T:
        """Return the i-th element (negative i counts from the end).

        Raises:
            IndexError: if i is out of range.
        """
        if i < 0:
            # Walk buckets from the right until i becomes a valid index.
            for a in reversed(self.a):
                i += len(a)
                if i >= 0:
                    return a[i]
        else:
            # Walk buckets from the left, subtracting skipped lengths.
            for a in self.a:
                if i < len(a):
                    return a[i]
                i -= len(a)
        raise IndexError("SortedMultiset index out of range")

    def pop(self, i: int = -1) -> T:
        """Pop and return the i-th element (default: the largest).

        Raises:
            IndexError: if i is out of range (including popping from an
                empty multiset).
        """
        if i < 0:
            for b, a in enumerate(reversed(self.a)):
                i += len(a)
                if i >= 0:
                    # ~b == -b - 1: the b-th bucket counted from the end.
                    return self._pop(a, ~b, i)
        else:
            for b, a in enumerate(self.a):
                if i < len(a):
                    return self._pop(a, b, i)
                i -= len(a)
        raise IndexError("SortedMultiset index out of range")

    def index(self, x: T) -> int:
        """Count the number of elements < x."""
        ans = 0
        for a in self.a:
            if a[-1] >= x:
                return ans + bisect_left(a, x)
            ans += len(a)
        return ans

    def index_right(self, x: T) -> int:
        """Count the number of elements <= x."""
        ans = 0
        for a in self.a:
            if a[-1] > x:
                return ans + bisect_right(a, x)
            ans += len(a)
        return ans