1# from titan_pylib.data_structures.set.sorted_set import SortedSet
2# https://github.com/tatyam-prime/SortedSet/blob/main/SortedSet.py
3import math
4from bisect import bisect_left, bisect_right
5from typing import Generic, Iterable, Iterator, TypeVar, Optional
6
# Element type stored in SortedSet; elements must support <, <= and ==
# among themselves, since all ordering/bisect logic below relies on them.
T = TypeVar("T")


class SortedSet(Generic[T]):
    """An ordered set of unique, mutually comparable elements.

    Elements live in ``self.a``: a list of sorted, non-empty buckets whose
    concatenation is strictly increasing. Most operations scan the bucket
    list linearly and then bisect inside a single bucket (O(√N) while the
    buckets stay balanced). ``self.size`` caches the total element count.

    Based on https://github.com/tatyam-prime/SortedSet/blob/main/SortedSet.py
    """

    # When building, aim for roughly BUCKET_RATIO * (number of buckets)
    # elements per bucket.
    BUCKET_RATIO = 16
    # A bucket is split in half once its length exceeds
    # (number of buckets) * SPLIT_RATIO.
    SPLIT_RATIO = 24

    def __init__(self, a: Iterable[T] = ()) -> None:
        """Build a SortedSet from the iterable *a*, dropping duplicates.

        O(N) if *a* is already sorted and unique, O(N log N) otherwise.
        """
        # An immutable default avoids the shared-mutable-default pitfall;
        # the input is copied into a fresh list in every case.
        a = list(a)
        n = len(a)
        if any(a[i] > a[i + 1] for i in range(n - 1)):
            a.sort()
        if any(a[i] >= a[i + 1] for i in range(n - 1)):
            # Sorted but containing duplicates: keep one element per run.
            a, b = [], a
            for x in b:
                if not a or a[-1] != x:
                    a.append(x)
        n = self.size = len(a)
        num_bucket = int(math.ceil(math.sqrt(n / self.BUCKET_RATIO)))
        # Cut the sorted list into num_bucket contiguous slices; because
        # num_bucket <= n for n >= 1, every slice is non-empty.
        self.a = [
            a[n * i // num_bucket : n * (i + 1) // num_bucket]
            for i in range(num_bucket)
        ]

    def __iter__(self) -> Iterator[T]:
        """Yield elements in ascending order."""
        for bucket in self.a:
            for x in bucket:
                yield x

    def __reversed__(self) -> Iterator[T]:
        """Yield elements in descending order."""
        for bucket in reversed(self.a):
            for x in reversed(bucket):
                yield x

    def __eq__(self, other) -> bool:
        """Return True if *other* iterates to the same sequence as self.

        Any iterable is compared element-wise in its iteration order, so a
        sorted list of the same elements compares equal. Non-iterables
        return NotImplemented (so ``s == None`` is False rather than
        raising TypeError). Defining __eq__ makes instances unhashable.
        """
        try:
            other_items = list(other)
        except TypeError:
            return NotImplemented
        return list(self) == other_items

    def __len__(self) -> int:
        """Return the number of elements."""
        return self.size

    def __repr__(self) -> str:
        """Show the internal bucket layout (useful for debugging)."""
        return "SortedSet" + str(self.a)

    def __str__(self) -> str:
        """Render as a set literal, e.g. ``{1, 2, 3}``."""
        s = str(list(self))
        return "{" + s[1 : len(s) - 1] + "}"

    def _position(self, x: T) -> tuple[list[T], int, int]:
        """Locate where *x* is or would be inserted.

        Returns ``(bucket, bucket_index, index_in_bucket)``. If *x* is greater
        than every element, this is the end of the last bucket.
        self must not be empty.
        """
        for i, a in enumerate(self.a):
            if x <= a[-1]:
                break
        return (a, i, bisect_left(a, x))

    def __contains__(self, x: T) -> bool:
        """Return True if *x* is in the set. / O(√N)"""
        if self.size == 0:
            return False
        a, _, i = self._position(x)
        return i != len(a) and a[i] == x

    def add(self, x: T) -> bool:
        """Add *x*; return True if it was added, False if already present. / O(√N)"""
        if self.size == 0:
            self.a = [[x]]
            self.size = 1
            return True
        a, b, i = self._position(x)
        if i != len(a) and a[i] == x:
            return False
        a.insert(i, x)
        self.size += 1
        # Keep buckets short relative to the bucket count so that the
        # per-bucket list.insert / list.pop cost stays bounded.
        if len(a) > len(self.a) * self.SPLIT_RATIO:
            mid = len(a) >> 1
            self.a[b : b + 1] = [a[:mid], a[mid:]]
        return True

    def _pop(self, a: list[T], b: int, i: int) -> T:
        """Remove and return ``a[i]``, where ``a is self.a[b]``.

        A bucket that becomes empty is deleted, preserving the invariant that
        all buckets are non-empty. *b* may be negative.
        """
        ans = a.pop(i)
        self.size -= 1
        if not a:
            del self.a[b]
        return ans

    def discard(self, x: T) -> bool:
        """Remove *x*; return True if it was removed, False if absent. / O(√N)"""
        if self.size == 0:
            return False
        a, b, i = self._position(x)
        if i == len(a) or a[i] != x:
            return False
        self._pop(a, b, i)
        return True

    def lt(self, x: T) -> Optional[T]:
        """Return the largest element < x, or None if there is none."""
        for a in reversed(self.a):
            if a[0] < x:
                return a[bisect_left(a, x) - 1]
        return None

    def le(self, x: T) -> Optional[T]:
        """Return the largest element <= x, or None if there is none."""
        for a in reversed(self.a):
            if a[0] <= x:
                return a[bisect_right(a, x) - 1]
        return None

    def gt(self, x: T) -> Optional[T]:
        """Return the smallest element > x, or None if there is none."""
        for a in self.a:
            if a[-1] > x:
                return a[bisect_right(a, x)]
        return None

    def ge(self, x: T) -> Optional[T]:
        """Return the smallest element >= x, or None if there is none."""
        for a in self.a:
            if a[-1] >= x:
                return a[bisect_left(a, x)]
        return None

    def __getitem__(self, i: int) -> T:
        """Return the i-th smallest element; negative *i* counts from the end.

        Raises IndexError if *i* is out of range.
        """
        if i < 0:
            for a in reversed(self.a):
                i += len(a)
                if i >= 0:
                    return a[i]
        else:
            for a in self.a:
                if i < len(a):
                    return a[i]
                i -= len(a)
        raise IndexError("SortedSet index out of range")

    def pop(self, i: int = -1) -> T:
        """Remove and return the i-th element (default: the largest).

        Raises IndexError if the set is empty or *i* is out of range.
        """
        if i < 0:
            for b, a in enumerate(reversed(self.a)):
                i += len(a)
                if i >= 0:
                    # b counts from the end, so ~b (== -b - 1) is the
                    # matching negative index into self.a.
                    return self._pop(a, ~b, i)
        else:
            for b, a in enumerate(self.a):
                if i < len(a):
                    return self._pop(a, b, i)
                i -= len(a)
        raise IndexError("SortedSet index out of range")

    def index(self, x: T) -> int:
        """Return the number of elements < x."""
        ans = 0
        for a in self.a:
            if a[-1] >= x:
                return ans + bisect_left(a, x)
            ans += len(a)
        return ans

    def index_right(self, x: T) -> int:
        """Return the number of elements <= x."""
        ans = 0
        for a in self.a:
            if a[-1] > x:
                return ans + bisect_right(a, x)
            ans += len(a)
        return ans