# https://github.com/tatyam-prime/SortedSet/blob/main/SortedSet.py
"""Bucketed sorted-set container (``SortedSet``)."""
import math
from bisect import bisect_left, bisect_right
from typing import Generic, Iterable, Iterator, Optional, TypeVar

T = TypeVar("T")  # element type stored in SortedSet


class SortedSet(Generic[T]):
    """Sorted set of unique elements backed by a list of sorted buckets.

    ``self.a`` holds non-empty sorted lists ("buckets") whose concatenation
    is every element in ascending order; ``self.size`` is the total count.
    Lookups scan the buckets linearly and bisect inside a single bucket, so
    add/discard cost roughly O(sqrt(N)). Elements must be mutually comparable.
    """

    # Building from an iterable uses ceil(sqrt(N / BUCKET_RATIO)) buckets.
    BUCKET_RATIO = 16
    # add() halves a bucket once it holds more than len(self.a) * SPLIT_RATIO
    # elements, so bucket size grows with the number of buckets.
    SPLIT_RATIO = 24

    def __init__(self, a: Iterable[T] = ()) -> None:
        """Make a new SortedSet from iterable. / O(N) if sorted and unique / O(N log N)

        Duplicates in ``a`` are collapsed. The default is an immutable empty
        tuple rather than a (shared) mutable list.
        """
        a = list(a)
        n = len(a)
        # Only sort / deduplicate when needed, so sorted unique input is O(N).
        if any(a[i] > a[i + 1] for i in range(n - 1)):
            a.sort()
        if any(a[i] >= a[i + 1] for i in range(n - 1)):
            a, b = [], a
            for x in b:
                if not a or a[-1] != x:
                    a.append(x)
        n = self.size = len(a)
        num_bucket = int(math.ceil(math.sqrt(n / self.BUCKET_RATIO)))
        # Split evenly into num_bucket contiguous slices ([] when n == 0).
        self.a = [
            a[n * i // num_bucket : n * (i + 1) // num_bucket]
            for i in range(num_bucket)
        ]

    def __iter__(self) -> Iterator[T]:
        """Yield the elements in ascending order."""
        for bucket in self.a:
            yield from bucket

    def __reversed__(self) -> Iterator[T]:
        """Yield the elements in descending order."""
        for bucket in reversed(self.a):
            yield from reversed(bucket)

    def __eq__(self, other: object) -> bool:
        """Return True if ``other`` iterates the same elements in the same order.

        Any iterable is accepted (e.g. a sorted list). For non-iterables,
        NotImplemented is returned so Python falls back to its default
        comparison instead of raising TypeError.
        """
        try:
            other_items = list(other)
        except TypeError:
            return NotImplemented
        return list(self) == other_items

    def __len__(self) -> int:
        """Return the number of elements."""
        return self.size

    def __repr__(self) -> str:
        # Shows the internal bucket layout, which is useful for debugging.
        return "SortedSet" + str(self.a)

    def __str__(self) -> str:
        s = str(list(self))
        return "{" + s[1 : len(s) - 1] + "}"

    def _position(self, x: T) -> tuple[list[T], int, int]:
        """Return (bucket, bucket index, position in bucket) where x is or would be.

        Picks the first bucket whose maximum is >= x, or the last bucket when x
        exceeds every element. self must not be empty.
        """
        for i, a in enumerate(self.a):
            if x <= a[-1]:
                break
        return (a, i, bisect_left(a, x))

    def __contains__(self, x: T) -> bool:
        if self.size == 0:
            return False
        a, _, i = self._position(x)
        return i != len(a) and a[i] == x

    def add(self, x: T) -> bool:
        """Add an element and return True if added. / O(√N)"""
        if self.size == 0:
            self.a = [[x]]
            self.size = 1
            return True
        a, b, i = self._position(x)
        if i != len(a) and a[i] == x:
            return False
        a.insert(i, x)
        self.size += 1
        # Halve an oversized bucket in place to keep per-bucket work bounded.
        if len(a) > len(self.a) * self.SPLIT_RATIO:
            mid = len(a) >> 1
            self.a[b : b + 1] = [a[:mid], a[mid:]]
        return True

    def _pop(self, a: list[T], b: int, i: int) -> T:
        """Remove and return a[i], where a is self.a[b]; drop the bucket if emptied."""
        ans = a.pop(i)
        self.size -= 1
        if not a:
            del self.a[b]
        return ans

    def discard(self, x: T) -> bool:
        """Remove an element and return True if removed. / O(√N)"""
        if self.size == 0:
            return False
        a, b, i = self._position(x)
        if i == len(a) or a[i] != x:
            return False
        self._pop(a, b, i)
        return True

    def lt(self, x: T) -> Optional[T]:
        """Find the largest element < x, or None if it doesn't exist."""
        for a in reversed(self.a):
            if a[0] < x:
                return a[bisect_left(a, x) - 1]
        return None

    def le(self, x: T) -> Optional[T]:
        """Find the largest element <= x, or None if it doesn't exist."""
        for a in reversed(self.a):
            if a[0] <= x:
                return a[bisect_right(a, x) - 1]
        return None

    def gt(self, x: T) -> Optional[T]:
        """Find the smallest element > x, or None if it doesn't exist."""
        for a in self.a:
            if a[-1] > x:
                return a[bisect_right(a, x)]
        return None

    def ge(self, x: T) -> Optional[T]:
        """Find the smallest element >= x, or None if it doesn't exist."""
        for a in self.a:
            if a[-1] >= x:
                return a[bisect_left(a, x)]
        return None

    def __getitem__(self, i: int) -> T:
        """Return the i-th element (negative i counts from the end).

        Raises IndexError when i is out of range.
        """
        if i < 0:
            for a in reversed(self.a):
                i += len(a)
                if i >= 0:
                    return a[i]
        else:
            for a in self.a:
                if i < len(a):
                    return a[i]
                i -= len(a)
        raise IndexError("SortedSet index out of range")

    def pop(self, i: int = -1) -> T:
        """Pop and return the i-th element (default: the largest).

        Raises IndexError when i is out of range, including on an empty set.
        """
        if i < 0:
            for b, a in enumerate(reversed(self.a)):
                i += len(a)
                if i >= 0:
                    # ~b == -b - 1 addresses the same bucket from the end of self.a.
                    return self._pop(a, ~b, i)
        else:
            for b, a in enumerate(self.a):
                if i < len(a):
                    return self._pop(a, b, i)
                i -= len(a)
        raise IndexError("pop index out of range")

    def index(self, x: T) -> int:
        """Count the number of elements < x."""
        ans = 0
        for a in self.a:
            if a[-1] >= x:
                return ans + bisect_left(a, x)
            ans += len(a)
        return ans

    def index_right(self, x: T) -> int:
        """Count the number of elements <= x."""
        ans = 0
        for a in self.a:
            if a[-1] > x:
                return ans + bisect_right(a, x)
            ans += len(a)
        return ans