1from titan_pylib.my_class.supports_less_than import SupportsLessThan
2import sys
3from typing import Iterator, Optional, Generic, Iterable, TypeVar
4
T = TypeVar("T", bound=SupportsLessThan)


class SplayTreeMultiset(Generic[T]):
    """Ordered multiset backed by a bottom-up splay tree.

    Each distinct key is stored once, in a node that also records its
    multiplicity (``val``). Every node caches two subtree aggregates:

    * ``size``: the number of nodes (distinct keys) in its subtree;
    * ``valsize``: the total multiplicity of its subtree.

    These aggregates support order-statistics queries such as ``self[k]``,
    ``index`` and ``get_elm``.

    Most queries splay the last node they visit to the root. Even read-only
    methods therefore change the tree's shape, though never its contents.
    Keys must be mutually comparable; the implementation uses ``<``, ``<=``
    and ``==`` on them.
    """

    class Node:
        """A tree node holding one distinct key and its multiplicity."""

        # Many nodes may exist at once; __slots__ drops the per-instance __dict__.
        __slots__ = ("key", "size", "val", "valsize", "left", "right")

        def __init__(self, key: T, val: int):
            self.key: T = key
            self.size: int = 1  # number of nodes in this subtree
            self.val: int = val  # multiplicity of ``key``
            self.valsize: int = val  # total multiplicity in this subtree
            self.left: Optional["SplayTreeMultiset.Node"] = None
            self.right: Optional["SplayTreeMultiset.Node"] = None

        def __str__(self):
            if self.left is None and self.right is None:
                return f"key:{self.key, self.size, self.val, self.valsize}\n"
            return f"key:{self.key, self.size, self.val, self.valsize},\n left:{self.left},\n right:{self.right}\n"

    def __init__(self, a: Optional[Iterable[T]] = None) -> None:
        """Create a multiset, optionally filled with the elements of ``a``.

        Args:
            a: any iterable of keys, including an empty generator. Duplicates
                are counted. ``None`` (the default) yields an empty multiset.
        """
        self.root: Optional["SplayTreeMultiset.Node"] = None
        # Cursor used only by the legacy ``next(multiset)`` protocol (see __next__).
        self.__iter = 0
        if a is not None:
            self._build(a)

    def _build(self, a: Iterable[T]) -> None:
        """Set ``self.root`` to a balanced tree built from ``a``.

        Does nothing if ``a`` is empty.
        """
        Node = SplayTreeMultiset.Node
        keys, counts = self._rle(sorted(a))
        if not keys:
            return

        def build(l: int, r: int) -> "SplayTreeMultiset.Node":
            # Build the subtree for keys[l:r], rooted at its midpoint so that the
            # recursion depth is only O(log n).
            mid = (l + r) >> 1
            node = Node(keys[mid], counts[mid])
            if l != mid:
                node.left = build(l, mid)
            if mid + 1 != r:
                node.right = build(mid + 1, r)
            self._update(node)
            return node

        self.root = build(0, len(keys))

    def _rle(self, a: list[T]) -> tuple[list[T], list[int]]:
        """Run-length encode a sorted list into ``(distinct_keys, counts)``.

        Returns two empty lists for empty input.
        """
        keys: list[T] = []
        counts: list[int] = []
        for e in a:
            if keys and e == keys[-1]:
                counts[-1] += 1
            else:
                keys.append(e)
                counts.append(1)
        return keys, counts

    def _update(self, node: "SplayTreeMultiset.Node") -> None:
        """Recompute ``node.size`` and ``node.valsize`` from its children."""
        size = 1
        valsize = node.val
        if node.left is not None:
            size += node.left.size
            valsize += node.left.valsize
        if node.right is not None:
            size += node.right.size
            valsize += node.right.valsize
        node.size = size
        node.valsize = valsize

    def _splay(self, path: list["SplayTreeMultiset.Node"], d: int) -> "SplayTreeMultiset.Node":
        """Rotate the target node up to the top of ``path`` and return it.

        ``path`` holds the target's ancestors, from ``path[0]`` (the subtree
        root) down to ``path[-1]`` (the target's parent). The target itself is
        not in ``path``.

        Bit ``i`` of ``d`` (least-significant first) is 1 if the step from
        ``path[-1 - i]`` to its child on the path goes left, and 0 if it goes
        right.

        The method performs zig-zig / zig-zag double rotations bottom-up, plus a
        final single rotation when ``len(path)`` is odd, and refreshes the
        cached aggregates. ``path`` is consumed. The caller must re-link the
        returned node; callers here store it in ``self.root``. ``path`` must be
        non-empty.
        """
        for _ in range(len(path) >> 1):
            node = path.pop()
            pnode = path.pop()
            if (d & 1) == ((d >> 1) & 1):
                # zig-zig: target, parent and grandparent lie on a straight line.
                if d & 1:
                    tmp = node.left
                    node.left = tmp.right
                    tmp.right = node
                    pnode.left = node.right
                    node.right = pnode
                else:
                    tmp = node.right
                    node.right = tmp.left
                    tmp.left = node
                    pnode.right = node.left
                    node.left = pnode
            else:
                # zig-zag: the target is an "inner" grandchild.
                if d & 1:
                    tmp = node.left
                    node.left = tmp.right
                    pnode.right = tmp.left
                    tmp.right = node
                    tmp.left = pnode
                else:
                    tmp = node.right
                    node.right = tmp.left
                    pnode.left = tmp.right
                    tmp.left = node
                    tmp.right = pnode
            # Children first, so each parent sees up-to-date aggregates.
            self._update(pnode)
            self._update(node)
            self._update(tmp)
            if not path:
                return tmp
            d >>= 2
            # Re-attach the rotated subtree to the next ancestor up.
            if d & 1:
                path[-1].left = tmp
            else:
                path[-1].right = tmp
        # Odd-length path: one single rotation (zig) remains at the top.
        gnode = path[0]
        if d & 1:
            node = gnode.left
            gnode.left = node.right
            node.right = gnode
        else:
            node = gnode.right
            gnode.right = node.left
            node.left = gnode
        self._update(gnode)
        self._update(node)
        return node

    def _set_search_splay(self, key: T) -> None:
        """Splay a node found by searching for ``key`` to the root.

        If ``key`` is present, its node becomes the root. Otherwise the root
        becomes the last node visited, which is ``key``'s in-tree predecessor
        or successor. Does nothing on an empty tree.
        """
        node = self.root
        if node is None or node.key == key:
            return
        path = []
        d = 0
        while True:
            if node.key == key:
                break
            if key < node.key:
                if node.left is None:
                    break
                path.append(node)
                d <<= 1
                d |= 1
                node = node.left
            else:
                if node.right is None:
                    break
                path.append(node)
                d <<= 1
                node = node.right
        if path:
            self.root = self._splay(path, d)

    def _set_kth_elm_splay(self, k: int) -> None:
        """Splay the node holding the k-th element to the root.

        ``k`` is 0-indexed, counts multiplicity, and a negative value counts
        from the end.

        Raises:
            IndexError: if ``k`` is out of range, including on an empty tree.
        """
        n = len(self)
        if k < 0:
            k += n
        if not 0 <= k < n:
            raise IndexError("SplayTreeMultiset index out of range")
        d = 0
        node = self.root
        path = []
        while True:
            # Elements before the end of this node's run, within the subtree.
            t = node.val if node.left is None else node.val + node.left.valsize
            if t - node.val <= k < t:
                if path:
                    self.root = self._splay(path, d)
                return
            if t > k:
                path.append(node)
                d <<= 1
                d |= 1
                node = node.left
            else:
                path.append(node)
                d <<= 1
                node = node.right
                k -= t

    def _set_kth_elm_tree_splay(self, k: int) -> None:
        """Splay the node holding the k-th smallest distinct key to the root.

        ``k`` is 0-indexed; a negative value counts from the end.

        Raises:
            IndexError: if ``k`` is out of range.
        """
        n = self.len_elm()
        if k < 0:
            k += n
        if not 0 <= k < n:
            raise IndexError("SplayTreeMultiset index out of range")
        d = 0
        node = self.root
        path = []
        while True:
            t = 0 if node.left is None else node.left.size
            if t == k:
                if path:
                    self.root = self._splay(path, d)
                return
            if t > k:
                path.append(node)
                d <<= 1
                d |= 1
                node = node.left
            else:
                path.append(node)
                d <<= 1
                node = node.right
                k -= t + 1

    def _get_min_splay(self, node: Optional["SplayTreeMultiset.Node"]) -> Optional["SplayTreeMultiset.Node"]:
        """Splay the minimum of the subtree rooted at ``node`` and return it.

        The returned node is the new subtree root and has no left child. The
        caller is responsible for linking it back into the tree.
        """
        if node is None or node.left is None:
            return node
        path = []
        while node.left is not None:
            path.append(node)
            node = node.left
        # Every step went left, so all bits of d are 1.
        return self._splay(path, (1 << len(path)) - 1)

    def _get_max_splay(self, node: Optional["SplayTreeMultiset.Node"]) -> Optional["SplayTreeMultiset.Node"]:
        """Splay the maximum of the subtree rooted at ``node`` and return it.

        The returned node is the new subtree root and has no right child. The
        caller is responsible for linking it back into the tree.
        """
        if node is None or node.right is None:
            return node
        path = []
        while node.right is not None:
            path.append(node)
            node = node.right
        # Every step went right, so all bits of d are 0.
        return self._splay(path, 0)

    def add(self, key: T, val: int = 1) -> None:
        """Add ``val`` occurrences of ``key``."""
        if self.root is None:
            self.root = SplayTreeMultiset.Node(key, val)
            return
        self._set_search_splay(key)
        if self.root.key == key:
            self.root.val += val
            self._update(self.root)
            return
        # The root is now key's predecessor or successor. Split around it and
        # put the new node on top.
        node = SplayTreeMultiset.Node(key, val)
        if key < self.root.key:
            node.left = self.root.left
            node.right = self.root
            self.root.left = None
            self._update(node.right)
        else:
            node.left = self.root
            node.right = self.root.right
            self.root.right = None
            self._update(node.left)
        self._update(node)
        self.root = node

    def discard(self, key: T, val: int = 1) -> bool:
        """Remove up to ``val`` occurrences of ``key``.

        If ``key`` occurs more than ``val`` times, its count is decreased by
        ``val``. Otherwise the key is removed entirely.

        Returns:
            True if ``key`` was present, False otherwise.
        """
        if self.root is None:
            return False
        self._set_search_splay(key)
        if self.root.key != key:
            return False
        if self.root.val > val:
            self.root.val -= val
            self._update(self.root)
            return True
        if self.root.left is None:
            self.root = self.root.right
        elif self.root.right is None:
            self.root = self.root.left
        else:
            # Join the two subtrees. The right subtree's minimum has no left
            # child after splaying, so the left subtree can hang off it.
            node = self._get_min_splay(self.root.right)
            node.left = self.root.left
            self._update(node)
            self.root = node
        return True

    def discard_all(self, key: T) -> bool:
        """Remove every occurrence of ``key``.

        Returns:
            True if ``key`` was present, False otherwise.
        """
        return self.discard(key, self.count(key))

    def count(self, key: T) -> int:
        """Return the multiplicity of ``key`` (0 if absent)."""
        if self.root is None:
            return 0
        self._set_search_splay(key)
        return self.root.val if self.root.key == key else 0

    def le(self, key: T) -> Optional[T]:
        """Return the largest stored key ``<= key``, or None if there is none."""
        node = self.root
        if node is None:
            return None
        path = []
        d = 0
        res = None
        while True:
            if node.key == key:
                # Return the stored key, consistently with ge().
                res = node.key
                break
            elif key < node.key:
                if node.left is None:
                    break
                path.append(node)
                d <<= 1
                d |= 1
                node = node.left
            else:
                res = node.key
                if node.right is None:
                    break
                path.append(node)
                d <<= 1
                node = node.right
        if path:
            self.root = self._splay(path, d)
        return res

    def lt(self, key: T) -> Optional[T]:
        """Return the largest stored key ``< key``, or None if there is none."""
        node = self.root
        path = []
        d = 0
        res = None
        while node is not None:
            path.append(node)
            d <<= 1
            if key <= node.key:
                d |= 1
                node = node.left
            else:
                res = node.key
                node = node.right
        # The last appended node is the splay target, not an ancestor: drop it.
        if path:
            path.pop()
            d >>= 1
        if path:
            self.root = self._splay(path, d)
        return res

    def ge(self, key: T) -> Optional[T]:
        """Return the smallest stored key ``>= key``, or None if there is none."""
        node = self.root
        if node is None:
            return None
        path = []
        d = 0
        res = None
        while True:
            if node.key == key:
                res = node.key
                break
            elif key < node.key:
                res = node.key
                if node.left is None:
                    break
                path.append(node)
                d <<= 1
                d |= 1
                node = node.left
            else:
                if node.right is None:
                    break
                path.append(node)
                d <<= 1
                node = node.right
        if path:
            self.root = self._splay(path, d)
        return res

    def gt(self, key: T) -> Optional[T]:
        """Return the smallest stored key ``> key``, or None if there is none."""
        node = self.root
        path = []
        d = 0
        res = None
        while node is not None:
            path.append(node)
            d <<= 1
            if key < node.key:
                d |= 1
                res = node.key
                node = node.left
            else:
                node = node.right
        # The last appended node is the splay target, not an ancestor: drop it.
        if path:
            path.pop()
            d >>= 1
        if path:
            self.root = self._splay(path, d)
        return res

    def index(self, key: T) -> int:
        """Return the number of elements ``< key``, counting multiplicity."""
        if self.root is None:
            return 0
        self._set_search_splay(key)
        # The root is key itself, or its predecessor or successor, so the left
        # subtree holds only elements < key.
        res = 0 if self.root.left is None else self.root.left.valsize
        if self.root.key < key:
            res += self.root.val
        return res

    def index_right(self, key: T) -> int:
        """Return the number of elements ``<= key``, counting multiplicity."""
        if self.root is None:
            return 0
        self._set_search_splay(key)
        res = 0 if self.root.left is None else self.root.left.valsize
        if self.root.key <= key:
            res += self.root.val
        return res

    def index_keys(self, key: T) -> int:
        """Return the number of distinct keys ``< key``."""
        if self.root is None:
            return 0
        self._set_search_splay(key)
        res = 0 if self.root.left is None else self.root.left.size
        if self.root.key < key:
            res += 1
        return res

    def index_right_keys(self, key: T) -> int:
        """Return the number of distinct keys ``<= key``."""
        if self.root is None:
            return 0
        self._set_search_splay(key)
        res = 0 if self.root.left is None else self.root.left.size
        if self.root.key <= key:
            res += 1
        return res

    def pop(self, k: int = -1) -> T:
        """Remove one occurrence of the k-th element and return it.

        ``k`` counts multiplicity and defaults to the last (largest) element.

        Raises:
            IndexError: if the multiset is empty or ``k`` is out of range.
        """
        self._set_kth_elm_splay(k)
        res = self.root.key
        self.discard(res)
        return res

    def pop_max(self) -> T:
        """Remove one occurrence of the largest element and return it."""
        return self.pop()

    def pop_min(self) -> T:
        """Remove one occurrence of the smallest element and return it."""
        return self.pop(0)

    def _iter_nodes(self) -> Iterator["SplayTreeMultiset.Node"]:
        """Yield nodes in ascending key order.

        Uses an explicit stack: no recursion, so there is no recursion-limit
        issue on degenerate (deep) trees, and no splaying. The tree must not be
        modified while this generator is being consumed.
        """
        stack: list["SplayTreeMultiset.Node"] = []
        node = self.root
        while stack or node is not None:
            while node is not None:
                stack.append(node)
                node = node.left
            node = stack.pop()
            yield node
            node = node.right

    def tolist(self) -> list[T]:
        """Return all elements in ascending order, each key repeated by its multiplicity."""
        return [node.key for node in self._iter_nodes() for _ in range(node.val)]

    def tolist_items(self) -> list[tuple[T, int]]:
        """Return ``(key, multiplicity)`` pairs in ascending key order."""
        return [(node.key, node.val) for node in self._iter_nodes()]

    def get_elm(self, k: int) -> T:
        """Return the k-th smallest distinct key.

        A negative ``k`` counts from the end.

        Raises:
            IndexError: if ``k`` is out of range.
        """
        self._set_kth_elm_tree_splay(k)
        return self.root.key

    def items(self) -> Iterator[tuple[T, int]]:
        """Yield ``(key, multiplicity)`` pairs in ascending key order (splays each)."""
        for i in range(self.len_elm()):
            self._set_kth_elm_tree_splay(i)
            yield self.root.key, self.root.val

    def keys(self) -> Iterator[T]:
        """Yield the distinct keys in ascending order (splays each)."""
        for i in range(self.len_elm()):
            self._set_kth_elm_tree_splay(i)
            yield self.root.key

    def values(self) -> Iterator[int]:
        """Yield multiplicities in ascending key order (splays each)."""
        for i in range(self.len_elm()):
            self._set_kth_elm_tree_splay(i)
            yield self.root.val

    def len_elm(self) -> int:
        """Return the number of distinct keys."""
        return 0 if self.root is None else self.root.size

    def show(self) -> None:
        """Print the contents as ``{key: multiplicity, ...}``."""
        print(
            "{" + ", ".join(map(lambda x: f"{x[0]}: {x[1]}", self.tolist_items())) + "}"
        )

    def clear(self) -> None:
        """Remove all elements."""
        self.root = None

    def _iter_elements(self) -> Iterator[T]:
        # Independent cursor per iteration. Length is re-read each step, as the
        # original did, so shrinking the multiset mid-loop cannot index past the end.
        i = 0
        while i < len(self):
            yield self[i]
            i += 1

    def __iter__(self) -> Iterator[T]:
        """Return a fresh iterator over all elements, counting multiplicity.

        Each call returns its own generator, so nested loops over the same
        multiset no longer interfere with each other.
        """
        # Also reset the legacy cursor, so ``iter(ms)`` followed by ``next(ms)``
        # keeps working as before.
        self.__iter = 0
        return self._iter_elements()

    def __next__(self) -> T:
        """Advance the legacy cursor (reset by ``iter(self)``) and return the element.

        Retained for backward compatibility only; ``for`` loops use the
        independent generator returned by ``__iter__``.
        """
        if self.__iter >= len(self):
            raise StopIteration
        res = self[self.__iter]
        self.__iter += 1
        return res

    def __reversed__(self):
        for i in range(self.__len__()):
            yield self.__getitem__(-i - 1)

    def __contains__(self, key: T) -> bool:
        self._set_search_splay(key)
        return self.root is not None and self.root.key == key

    def __getitem__(self, k: int) -> T:
        """Return the k-th element, counting multiplicity.

        Raises:
            IndexError: if ``k`` is out of range.
        """
        self._set_kth_elm_splay(k)
        return self.root.key

    def __len__(self):
        """Return the total number of elements, counting multiplicity."""
        return 0 if self.root is None else self.root.valsize

    def __bool__(self):
        return self.root is not None

    def __str__(self):
        return "{" + ", ".join(map(str, self.tolist())) + "}"

    def __repr__(self):
        return f"SplayTreeMultiset({self.tolist()})"