from array import array
from typing import Generic, Iterable, Iterator, Optional, TypeVar

from titan_pylib.my_class.ordered_set_interface import OrderedSetInterface
from titan_pylib.my_class.supports_less_than import SupportsLessThan

# Element type stored in SplayTreeSet; keys are ordered with the comparison
# operators (<, >, <=, >=, ==).
T = TypeVar("T", bound=SupportsLessThan)


class SplayTreeSet(OrderedSetInterface, Generic[T]):
    """Ordered set of distinct keys backed by an array-based splay tree.

    Node ``i`` (``i >= 1``) keeps its key in ``keys[i]``, its subtree size in
    ``size[i]`` and its children in ``child[2*i]`` (left) and
    ``child[2*i + 1]`` (right).  Index 0 is a null sentinel with size 0 and no
    children, so ``size[child[...]]`` needs no emptiness check.  ``self.node``
    is the root index (0 when the set is empty).  Most operations, including
    lookups such as ``le``/``__contains__``/``__getitem__``, splay the accessed
    node to the root, so they restructure the tree.
    """

    def __init__(self, a: Iterable[T] = (), e: T = 0):
        """Build the set from ``a``; duplicate keys are dropped.

        ``e`` is the placeholder key stored in the null slot 0 and in slots
        pre-allocated by ``reserve``; it is not a member of the set.
        """
        self.keys: list[T] = [e]
        self.size = array("I", bytes(4))
        self.child = array("I", bytes(8))
        self.end = 1  # next unused node index
        self.node = 0  # root index; 0 means the set is empty
        self.e = e
        if not isinstance(a, list):
            a = list(a)
        if a:
            self._build(a)

    def _build(self, a: list[T]) -> None:
        """Load ``a`` into a perfectly balanced tree (used by ``__init__`` on a fresh set)."""

        def build(l: int, r: int) -> int:
            # Root the half-open node range [l, r) at its midpoint.
            mid = (l + r) >> 1
            if l != mid:
                child[mid << 1] = build(l, mid)
            if mid + 1 != r:
                child[mid << 1 | 1] = build(mid + 1, r)
            size[mid] = 1 + size[child[mid << 1]] + size[child[mid << 1 | 1]]
            return mid

        # Sort and deduplicate unless the input is already strictly increasing.
        if not all(a[i] < a[i + 1] for i in range(len(a) - 1)):
            a = sorted(set(a))
        n = len(a)
        keys, child, size = self.keys, self.child, self.size
        self.reserve(n - len(keys) + 2)
        self.end += n
        keys[1 : n + 1] = a
        self.node = build(1, n + 1)

    def _make_node(self, key: T) -> int:
        """Allocate a fresh leaf node holding ``key`` and return its index."""
        end = self.end
        if end >= len(self.keys):
            self.keys.append(key)
            self.size.append(1)
            self.child.append(0)
            self.child.append(0)
        else:
            # Slot pre-allocated by reserve() or released by clear():
            # reinitialise it so no stale size/children leak into the tree.
            self.keys[end] = key
            self.size[end] = 1
            self.child[end << 1] = 0
            self.child[end << 1 | 1] = 0
        self.end = end + 1
        return end

    def _update(self, node: int) -> None:
        """Recompute ``size[node]`` from its children."""
        self.size[node] = (
            1 + self.size[self.child[node << 1]] + self.size[self.child[node << 1 | 1]]
        )

    def _update_triple(self, x: int, y: int, z: int) -> None:
        """Fix sizes after a double rotation that moved ``z`` to where ``x`` was.

        ``z`` inherits ``x``'s old subtree size; ``x`` is recomputed before ``y``
        because in the zig-zig case ``x`` ends up as a child of ``y``.
        """
        size, child = self.size, self.child
        size[z] = size[x]
        size[x] = 1 + size[child[x << 1]] + size[child[x << 1 | 1]]
        size[y] = 1 + size[child[y << 1]] + size[child[y << 1 | 1]]

    def _splay(self, path: list[int], d: int) -> int:
        """Rotate the child of ``path[-1]`` up to the position of ``path[0]``.

        ``path`` (non-empty, modified in place) lists node indices from the
        subtree root down to the parent of the target.  Bit ``j`` of ``d``
        (from the least significant bit) is the direction taken from
        ``path[-1 - j]``: 1 = left child, 0 = right child.  Returns the target,
        which is the new subtree root; the caller must re-link it.
        """
        child = self.child
        g = d & 1  # direction from the parent to the target
        while len(path) > 1:
            # Double rotation lifting the target ``tmp`` over its parent
            # ``node`` and grandparent ``pnode`` (zig-zig if g == f, else zig-zag).
            node = path.pop()
            pnode = path.pop()
            f = d >> 1 & 1  # direction from the grandparent to the parent
            tmp = child[node << 1 | g ^ 1]
            child[node << 1 | g ^ 1] = child[tmp << 1 | g]
            child[tmp << 1 | f ^ (g != f)] = node
            child[pnode << 1 | f ^ 1] = child[(node if g == f else tmp) << 1 | f]
            child[(node if g == f else tmp) << 1 | f] = pnode
            self._update_triple(pnode, node, tmp)
            if not path:
                return tmp
            # Hang the lifted target where pnode used to be and keep climbing.
            d >>= 2
            g = d & 1
            child[path[-1] << 1 | g ^ 1] = tmp
        # Odd path length: finish with a single rotation over the top node.
        pnode = path[0]
        node = child[pnode << 1 | g ^ 1]
        child[pnode << 1 | g ^ 1] = child[node << 1 | g]
        child[node << 1 | g] = pnode
        self._update(pnode)
        self._update(node)
        return node

    def _set_search_splay(self, key: T) -> None:
        """Splay ``key``'s node, or the last node on its search path, to the root."""
        node = self.node
        keys, child = self.keys, self.child
        if (not node) or keys[node] == key:
            return
        path: list[int] = []
        d = 0
        while keys[node] != key:
            nxt = child[node << 1 | (key > keys[node])]
            if not nxt:
                break
            path.append(node)
            d, node = d << 1 | (key < keys[node]), nxt
        if path:
            self.node = self._splay(path, d)

    def _set_kth_elm_splay(self, k: int) -> None:
        """Splay the ``k``-th smallest element (0-indexed, negative counts from the end) to the root.

        Raises IndexError if ``k`` is out of range, including on an empty set.
        """
        size, child = self.size, self.child
        node = self.node
        n = size[node]
        i = k + n if k < 0 else k
        if not 0 <= i < n:
            # Without this check the descent reaches the null node 0 and either
            # loops forever (i < 0) or splays through the sentinel, corrupting
            # the tree (i >= n).
            raise IndexError(f"SplayTreeSet index out of range: {k}")
        d = 0
        path: list[int] = []
        while True:
            t = size[child[node << 1]]
            if t == i:
                if path:
                    self.node = self._splay(path, d)
                return
            path.append(node)
            d, node = d << 1 | (t > i), child[node << 1 | (t < i)]
            if t < i:
                i -= t + 1

    def _get_min_splay(self, node: int) -> int:
        """Splay the minimum of the subtree rooted at ``node`` to its top; return it (0 if empty)."""
        if not node:
            return 0
        child = self.child
        if not child[node << 1]:
            return node
        path: list[int] = []
        while child[node << 1]:
            path.append(node)
            node = child[node << 1]
        return self._splay(path, (1 << len(path)) - 1)

    def _get_max_splay(self, node: int) -> int:
        """Splay the maximum of the subtree rooted at ``node`` to its top; return it (0 if empty)."""
        if not node:
            return 0
        child = self.child
        if not child[node << 1 | 1]:
            return node
        path: list[int] = []
        while child[node << 1 | 1]:
            path.append(node)
            node = child[node << 1 | 1]
        return self._splay(path, 0)

    def _remove_root(self) -> None:
        """Unlink the current (non-null) root, joining its two subtrees."""
        child = self.child
        root = self.node
        left, right = child[root << 1], child[root << 1 | 1]
        if not left:
            self.node = right
        elif not right:
            self.node = left
        else:
            # The right subtree's minimum has no left child once splayed to the
            # top of that subtree, so the whole left subtree can hang there.
            node = self._get_min_splay(right)
            child[node << 1] = left
            self._update(node)
            self.node = node

    def reserve(self, n: int) -> None:
        """Pre-allocate storage for ``n`` more nodes (``n`` must be >= 0)."""
        assert n >= 0, f"ValueError: SplayTreeSet.reserve({n})"
        self.keys += [self.e] * n
        self.size += array("I", [1] * n)
        self.child += array("I", bytes(8 * n))

    def add(self, key: T) -> bool:
        """Insert ``key``; return True if it was added, False if already present."""
        keys, child = self.keys, self.child
        self._set_search_splay(key)
        root = self.node
        # Check ``root`` first: on an empty set keys[0] is the placeholder
        # ``e``, which must not be mistaken for a stored key.
        if root and keys[root] == key:
            return False
        node = self._make_node(key)
        if not root:
            self.node = node
            return True
        # ``root`` is key's neighbour; split its subtree around the new node.
        f = key > keys[root]
        child[node << 1 | f ^ 1] = root
        child[node << 1 | f] = child[root << 1 | f]
        child[root << 1 | f] = 0
        self._update(root)
        self._update(node)
        self.node = node
        return True

    def discard(self, key: T) -> bool:
        """Remove ``key`` if present; return True if it was removed."""
        if not self.node:
            return False
        self._set_search_splay(key)
        if self.keys[self.node] != key:
            return False
        self._remove_root()
        return True

    def remove(self, key: T) -> None:
        """Remove ``key``; raise KeyError if it is not present."""
        if self.discard(key):
            return
        raise KeyError(key)

    def le(self, key: T) -> Optional[T]:
        """Return the largest element ``<= key``, or None if there is none."""
        node = self.node
        keys, child = self.keys, self.child
        path: list[int] = []
        d = 0
        res = None
        while node:
            if keys[node] == key:
                res = keys[node]
                break
            path.append(node)
            d = d << 1 | (key < keys[node])
            if key > keys[node]:
                res = keys[node]
            node = child[node << 1 | (key > keys[node])]
        else:
            # Fell off the tree: splay the last real node on the path instead.
            if path:
                path.pop()
                d >>= 1
        if path:
            self.node = self._splay(path, d)
        return res

    def lt(self, key: T) -> Optional[T]:
        """Return the largest element ``< key``, or None if there is none."""
        node = self.node
        keys, child = self.keys, self.child
        path: list[int] = []
        d = 0
        res = None
        while node:
            path.append(node)
            d = d << 1 | (key <= keys[node])
            if key > keys[node]:
                res = keys[node]
            node = child[node << 1 | (key > keys[node])]
        else:
            if path:
                path.pop()
                d >>= 1
        if path:
            self.node = self._splay(path, d)
        return res

    def ge(self, key: T) -> Optional[T]:
        """Return the smallest element ``>= key``, or None if there is none."""
        node = self.node
        keys, child = self.keys, self.child
        path: list[int] = []
        d = 0
        res = None
        while node:
            if keys[node] == key:
                res = keys[node]
                break
            path.append(node)
            d = d << 1 | (key < keys[node])
            if key < keys[node]:
                res = keys[node]
            node = child[node << 1 | (key >= keys[node])]
        else:
            if path:
                path.pop()
                d >>= 1
        if path:
            self.node = self._splay(path, d)
        return res

    def gt(self, key: T) -> Optional[T]:
        """Return the smallest element ``> key``, or None if there is none."""
        node = self.node
        keys, child = self.keys, self.child
        path: list[int] = []
        d = 0
        res = None
        while node:
            path.append(node)
            d = d << 1 | (key < keys[node])
            if key < keys[node]:
                res = keys[node]
            node = child[node << 1 | (key >= keys[node])]
        else:
            if path:
                path.pop()
                d >>= 1
        if path:
            self.node = self._splay(path, d)
        return res

    def index(self, key: T) -> int:
        """Return the number of elements strictly less than ``key``."""
        if not self.node:
            return 0
        self._set_search_splay(key)
        res = self.size[self.child[self.node << 1]]
        if self.keys[self.node] < key:
            res += 1
        return res

    def index_right(self, key: T) -> int:
        """Return the number of elements less than or equal to ``key``."""
        if not self.node:
            return 0
        self._set_search_splay(key)
        res = self.size[self.child[self.node << 1]]
        if self.keys[self.node] <= key:
            res += 1
        return res

    def get_min(self) -> T:
        """Return the smallest element; raise IndexError if the set is empty."""
        return self.__getitem__(0)

    def get_max(self) -> T:
        """Return the largest element; raise IndexError if the set is empty."""
        return self.__getitem__(-1)

    def pop(self, k: int = -1) -> T:
        """Remove and return the element of rank ``k`` (default: the largest).

        Fails the assertion on an empty set; raises IndexError if ``k`` is out
        of range.
        """
        assert self.node, f"IndexError: SplayTreeSet.pop({k})"
        keys, child = self.keys, self.child
        if k == -1:
            # Fast path: the splayed maximum has no right child.
            node = self._get_max_splay(self.node)
            self.node = child[node << 1]
            return keys[node]
        self._set_kth_elm_splay(k)
        res = keys[self.node]
        self._remove_root()
        return res

    def pop_max(self) -> T:
        """Remove and return the largest element."""
        return self.pop()

    def pop_min(self) -> T:
        """Remove and return the smallest element."""
        assert self.node, "IndexError: SplayTreeSet.pop_min()"
        node = self._get_min_splay(self.node)
        # The splayed minimum has no left child.
        self.node = self.child[node << 1 | 1]
        return self.keys[node]

    def clear(self) -> None:
        """Remove all elements and make their node slots reusable."""
        self.node = 0
        # _make_node reinitialises reused slots, so storage can be recycled.
        self.end = 1

    def tolist(self) -> list[T]:
        """Return all elements in ascending order (iterative in-order walk)."""
        node = self.node
        child, keys = self.child, self.keys
        stack: list[int] = []
        res: list[T] = []
        while stack or node:
            if node:
                stack.append(node)
                node = child[node << 1]
            else:
                node = stack.pop()
                res.append(keys[node])
                node = child[node << 1 | 1]
        return res

    def __iter__(self) -> Iterator[T]:
        """Iterate over a snapshot of the elements in ascending order.

        Each call returns an independent iterator, so nested loops over the same
        set work.  The legacy ``__next__`` cursor is also reset so code that
        calls ``iter(s)`` and then ``next(s)`` keeps working.
        """
        self.__iter = 0
        return iter(self.tolist())

    def __next__(self) -> T:
        """Legacy cursor iteration: return the next element by rank."""
        if self.__iter == self.__len__():
            raise StopIteration
        self.__iter += 1
        return self.__getitem__(self.__iter - 1)

    def __reversed__(self) -> Iterator[T]:
        """Iterate over a snapshot of the elements in descending order."""
        return reversed(self.tolist())

    def __contains__(self, key: T) -> bool:
        # An empty set must not report the placeholder keys[0] as a member.
        if not self.node:
            return False
        self._set_search_splay(key)
        return self.keys[self.node] == key

    def __getitem__(self, k: int) -> T:
        """Return the ``k``-th smallest element; raise IndexError if out of range."""
        self._set_kth_elm_splay(k)
        return self.keys[self.node]

    def __len__(self) -> int:
        return self.size[self.node]

    def __bool__(self) -> bool:
        return self.node != 0

    def __str__(self) -> str:
        return "{" + ", ".join(map(str, self.tolist())) + "}"

    def __repr__(self) -> str:
        return f"SplayTreeSet({self})"