1# from titan_pylib.data_structures.splay_tree.splay_tree_multiset import SplayTreeMultiset
2# from titan_pylib.my_class.supports_less_than import SupportsLessThan
3from typing import Protocol
4
5
class SupportsLessThan(Protocol):
    """Structural type for values that can be ordered with the ``<`` operator."""

    def __lt__(self, other) -> bool:
        """Return whether ``self`` sorts strictly before ``other``."""
9import sys
10from typing import Iterator, Optional, Generic, Iterable, TypeVar
11
# Element (key) type of SplayTreeMultiset: keys must be orderable with `<`.
# NOTE(review): the tree also compares keys with `==`, which the protocol does
# not declare (every object has __eq__, but it should be consistent with `<`).
T = TypeVar("T", bound=SupportsLessThan)
13
14
class SplayTreeMultiset(Generic[T]):
    """Ordered multiset backed by a splay tree.

    The tree has one node per *distinct* key; each node stores the key's
    multiplicity (``val``), the number of nodes in its subtree (``size``) and
    the total multiplicity of its subtree (``valsize``).  Element-level
    operations (``len``, ``s[k]``, ``index``, ``pop``) count multiplicities;
    the ``*_elm`` / ``*_keys`` variants count distinct keys.

    Most operations -- including lookups such as ``count``, ``le``, ``index``
    and ``s[k]`` -- splay the accessed node to the root, so they restructure
    the tree even though they do not change its contents.
    """

    class Node:
        """Tree node holding one distinct key and its multiplicity."""

        __slots__ = ("key", "size", "val", "valsize", "left", "right")

        def __init__(self, key: T, val: int):
            self.key: T = key
            # Number of nodes (distinct keys) in the subtree rooted here.
            self.size: int = 1
            # Multiplicity of ``key``.
            self.val: int = val
            # Sum of multiplicities over the subtree rooted here.
            self.valsize: int = val
            self.left: Optional["SplayTreeMultiset.Node"] = None
            self.right: Optional["SplayTreeMultiset.Node"] = None

        def __str__(self):
            if self.left is None and self.right is None:
                return f"key:{self.key, self.size, self.val, self.valsize}\n"
            return f"key:{self.key, self.size, self.val, self.valsize},\n left:{self.left},\n right:{self.right}\n"

    def __init__(self, a: Iterable[T] = ()) -> None:
        """Create a multiset containing every element of ``a`` (duplicates counted).

        ``a`` may be any iterable, including a one-shot or empty iterator.
        ``None`` is treated as empty, as the original ``if a:`` guard did.
        """
        self.root: Optional["SplayTreeMultiset.Node"] = None
        if a is not None:
            self._build(a)

    def _build(self, a: Iterable[T]) -> None:
        """Build a balanced tree from ``a`` and install it as the root.

        Does nothing when ``a`` is empty.
        """
        Node = SplayTreeMultiset.Node

        def build(lo: int, hi: int) -> SplayTreeMultiset.Node:
            # Subtree for key[lo:hi], rooted at the midpoint, so depth is O(log n).
            mid = (lo + hi) >> 1
            node = Node(key[mid], val[mid])
            if lo != mid:
                node.left = build(lo, mid)
            if mid + 1 != hi:
                node.right = build(mid + 1, hi)
            self._update(node)
            return node

        key, val = self._rle(sorted(a))
        if not key:
            return
        self.root = build(0, len(key))

    def _rle(self, a: list[T]) -> tuple[list[T], list[int]]:
        """Run-length encode sorted list ``a`` into (distinct keys, counts).

        Returns two empty lists when ``a`` is empty.
        """
        keys: list[T] = []
        counts: list[int] = []
        for e in a:
            if keys and e == keys[-1]:
                counts[-1] += 1
            else:
                keys.append(e)
                counts.append(1)
        return keys, counts

    def _update(self, node: Node) -> None:
        """Recompute ``node.size`` and ``node.valsize`` from its children."""
        if node.left is None:
            if node.right is None:
                node.size = 1
                node.valsize = node.val
            else:
                node.size = 1 + node.right.size
                node.valsize = node.val + node.right.valsize
        else:
            if node.right is None:
                node.size = 1 + node.left.size
                node.valsize = node.val + node.left.valsize
            else:
                node.size = 1 + node.left.size + node.right.size
                node.valsize = node.val + node.left.valsize + node.right.valsize

    def _splay(self, path: list[Node], d: int) -> Node:
        """Rotate the node below ``path[-1]`` up to the top of ``path[0]``'s subtree.

        ``path`` lists the target's ancestors from ``path[0]`` (the subtree
        root) down to its parent ``path[-1]``.  It must be non-empty and is
        consumed.  Bit ``i`` of ``d`` (LSB first) is the direction taken from
        ``path[-1 - i]``: 1 means left child, 0 means right child.

        Returns the target node, which is now the subtree root.  The caller
        must re-link it, e.g. ``self.root = self._splay(...)``.
        """
        for _ in range(len(path) >> 1):
            node = path.pop()  # parent of the target
            pnode = path.pop()  # grandparent of the target
            if d & 1 == d >> 1 & 1:
                # zig-zig: target and parent hang on the same side.
                if d & 1:
                    tmp = node.left
                    node.left = tmp.right
                    tmp.right = node
                    pnode.left = node.right
                    node.right = pnode
                else:
                    tmp = node.right
                    node.right = tmp.left
                    tmp.left = node
                    pnode.right = node.left
                    node.left = pnode
            else:
                # zig-zag: target and parent hang on opposite sides.
                if d & 1:
                    tmp = node.left
                    node.left = tmp.right
                    pnode.right = tmp.left
                    tmp.right = node
                    tmp.left = pnode
                else:
                    tmp = node.right
                    node.right = tmp.left
                    pnode.left = tmp.right
                    tmp.left = node
                    tmp.right = pnode
            # Children before parents: tmp now sits above pnode and node.
            self._update(pnode)
            self._update(node)
            self._update(tmp)
            if not path:
                return tmp
            d >>= 2
            # Re-attach the lifted target under the next remaining ancestor.
            if d & 1:
                path[-1].left = tmp
            else:
                path[-1].right = tmp
        # Odd path length: one single rotation (zig) with the last ancestor.
        gnode = path[0]
        if d & 1:
            node = gnode.left
            gnode.left = node.right
            node.right = gnode
        else:
            node = gnode.right
            gnode.right = node.left
            node.left = gnode
        self._update(gnode)
        self._update(node)
        return node

    def _set_search_splay(self, key: T) -> None:
        """Splay ``key``'s node to the root.

        If ``key`` is absent, the last node visited by the search (its in-order
        predecessor or successor) is splayed instead.  Does nothing on an
        empty tree.
        """
        node = self.root
        if node is None or node.key == key:
            return
        path = []
        d = 0
        while True:
            if node.key == key:
                break
            if key < node.key:
                if node.left is None:
                    break
                path.append(node)
                d <<= 1
                d |= 1
                node = node.left
            else:
                if node.right is None:
                    break
                path.append(node)
                d <<= 1
                node = node.right
        if path:
            self.root = self._splay(path, d)

    def _set_kth_elm_splay(self, k: int) -> None:
        """Splay the node holding the ``k``-th element to the root.

        ``k`` is 0-indexed and counts multiplicities; a negative ``k`` counts
        from the end.

        Raises:
            IndexError: if ``k`` is out of range (always, when empty).
        """
        n = self.__len__()
        if k < 0:
            k += n
        if not 0 <= k < n:
            raise IndexError("SplayTreeMultiset index out of range")
        d = 0
        node = self.root
        path = []
        while True:
            # t = elements in the left subtree plus this node's copies.
            t = node.val if node.left is None else node.val + node.left.valsize
            if t - node.val <= k < t:
                if path:
                    self.root = self._splay(path, d)
                return
            elif t > k:
                path.append(node)
                d <<= 1
                d |= 1
                node = node.left
            else:
                path.append(node)
                d <<= 1
                node = node.right
                k -= t

    def _set_kth_elm_tree_splay(self, k: int) -> None:
        """Splay the node of the ``k``-th distinct key to the root.

        ``k`` is 0-indexed; a negative ``k`` counts from the end.

        Raises:
            IndexError: if ``k`` is out of range.
        """
        n = self.len_elm()
        if k < 0:
            k += n
        if not 0 <= k < n:
            raise IndexError("SplayTreeMultiset index out of range")
        d = 0
        node = self.root
        path = []
        while True:
            t = 0 if node.left is None else node.left.size
            if t == k:
                if path:
                    self.root = self._splay(path, d)
                return
            elif t > k:
                path.append(node)
                d <<= 1
                d |= 1
                node = node.left
            else:
                path.append(node)
                d <<= 1
                node = node.right
                k -= t + 1

    def _get_min_splay(self, node: Node) -> Node:
        """Splay the minimum of ``node``'s subtree to its top and return it.

        The returned node has no left child.  ``None`` is returned unchanged.
        """
        if node is None or node.left is None:
            return node
        path = []
        while node.left is not None:
            path.append(node)
            node = node.left
        # Every step went left, so every direction bit is 1.
        return self._splay(path, (1 << len(path)) - 1)

    def _get_max_splay(self, node: Node) -> Node:
        """Splay the maximum of ``node``'s subtree to its top and return it.

        The returned node has no right child.  ``None`` is returned unchanged.
        """
        if node is None or node.right is None:
            return node
        path = []
        while node.right is not None:
            path.append(node)
            node = node.right
        # Every step went right, so every direction bit is 0.
        return self._splay(path, 0)

    def _iter_nodes(self) -> Iterator["SplayTreeMultiset.Node"]:
        """Yield nodes in ascending key order without splaying.

        Uses an explicit stack, so there is no recursion-depth limit even
        when the tree degenerates into a chain.
        """
        stack = []
        node = self.root
        while stack or node is not None:
            while node is not None:
                stack.append(node)
                node = node.left
            node = stack.pop()
            yield node
            node = node.right

    def add(self, key: T, val: int = 1) -> None:
        """Insert ``val`` copies of ``key``; ``key``'s node ends up at the root.

        NOTE(review): ``val`` is not validated.  A non-positive value is stored
        as-is.
        """
        if self.root is None:
            self.root = SplayTreeMultiset.Node(key, val)
            return
        self._set_search_splay(key)
        if self.root.key == key:
            self.root.val += val
            self._update(self.root)
            return
        # Root is now key's predecessor or successor: split around it.
        node = SplayTreeMultiset.Node(key, val)
        if key < self.root.key:
            node.left = self.root.left
            node.right = self.root
            self.root.left = None
            self._update(node.right)
        else:
            node.left = self.root
            node.right = self.root.right
            self.root.right = None
            self._update(node.left)
        self._update(node)
        self.root = node

    def discard(self, key: T, val: int = 1) -> bool:
        """Remove up to ``val`` copies of ``key``.

        The key's node is deleted once ``val`` reaches its multiplicity.

        Returns:
            False if ``key`` is not present, True otherwise.
        """
        if self.root is None:
            return False
        self._set_search_splay(key)
        if self.root.key != key:
            return False
        if self.root.val > val:
            self.root.val -= val
            self._update(self.root)
            return True
        if self.root.left is None:
            self.root = self.root.right
        elif self.root.right is None:
            self.root = self.root.left
        else:
            # Join: the right subtree's minimum has no left child; hang the
            # left subtree there.
            node = self._get_min_splay(self.root.right)
            node.left = self.root.left
            self._update(node)
            self.root = node
        return True

    def discard_all(self, key: T) -> bool:
        """Remove every copy of ``key``; return False if it was absent."""
        return self.discard(key, self.count(key))

    def count(self, key: T) -> int:
        """Return the multiplicity of ``key`` (0 if absent)."""
        if self.root is None:
            return 0
        self._set_search_splay(key)
        return self.root.val if self.root.key == key else 0

    def le(self, key: T) -> Optional[T]:
        """Return the largest stored key ``<= key``, or None if there is none."""
        node = self.root
        if node is None:
            return None
        path = []
        d = 0
        res = None
        while True:
            if node.key == key:
                # Return the stored key (consistent with ge), not the query.
                res = node.key
                break
            elif key < node.key:
                if node.left is None:
                    break
                path.append(node)
                d <<= 1
                d |= 1
                node = node.left
            else:
                res = node.key
                if node.right is None:
                    break
                path.append(node)
                d <<= 1
                node = node.right
        if path:
            self.root = self._splay(path, d)
        return res

    def lt(self, key: T) -> Optional[T]:
        """Return the largest stored key ``< key``, or None if there is none."""
        node = self.root
        path = []
        d = 0
        res = None
        while node is not None:
            if key <= node.key:
                path.append(node)
                d <<= 1
                d |= 1
                node = node.left
            else:
                path.append(node)
                d <<= 1
                res = node.key
                node = node.right
        # The last pushed node is the one to splay: drop it (and its bit) so
        # that path holds only its ancestors.
        if path:
            path.pop()
            d >>= 1
        if path:
            self.root = self._splay(path, d)
        return res

    def ge(self, key: T) -> Optional[T]:
        """Return the smallest stored key ``>= key``, or None if there is none."""
        node = self.root
        if node is None:
            return None
        path = []
        d = 0
        res = None
        while True:
            if node.key == key:
                res = node.key
                break
            elif key < node.key:
                res = node.key
                if node.left is None:
                    break
                path.append(node)
                d <<= 1
                d |= 1
                node = node.left
            else:
                if node.right is None:
                    break
                path.append(node)
                d <<= 1
                node = node.right
        if path:
            self.root = self._splay(path, d)
        return res

    def gt(self, key: T) -> Optional[T]:
        """Return the smallest stored key ``> key``, or None if there is none."""
        node = self.root
        path = []
        d = 0
        res = None
        while node is not None:
            if key < node.key:
                path.append(node)
                d <<= 1
                d |= 1
                res = node.key
                node = node.left
            else:
                path.append(node)
                d <<= 1
                node = node.right
        # Drop the last pushed node (and its bit); it is the splay target.
        if path:
            path.pop()
            d >>= 1
        if path:
            self.root = self._splay(path, d)
        return res

    def index(self, key: T) -> int:
        """Return the number of elements ``< key`` (multiplicities counted)."""
        if self.root is None:
            return 0
        self._set_search_splay(key)
        res = 0 if self.root.left is None else self.root.left.valsize
        if self.root.key < key:
            res += self.root.val
        return res

    def index_right(self, key: T) -> int:
        """Return the number of elements ``<= key`` (multiplicities counted)."""
        if self.root is None:
            return 0
        self._set_search_splay(key)
        res = 0 if self.root.left is None else self.root.left.valsize
        if self.root.key <= key:
            res += self.root.val
        return res

    def index_keys(self, key: T) -> int:
        """Return the number of distinct keys ``< key``."""
        if self.root is None:
            return 0
        self._set_search_splay(key)
        res = 0 if self.root.left is None else self.root.left.size
        if self.root.key < key:
            res += 1
        return res

    def index_right_keys(self, key: T) -> int:
        """Return the number of distinct keys ``<= key``."""
        if self.root is None:
            return 0
        self._set_search_splay(key)
        res = 0 if self.root.left is None else self.root.left.size
        if self.root.key <= key:
            res += 1
        return res

    def pop(self, k: int = -1) -> T:
        """Remove and return one copy of the ``k``-th element (default: largest).

        Raises:
            IndexError: if the multiset is empty or ``k`` is out of range.
        """
        self._set_kth_elm_splay(k)
        res = self.root.key
        self.discard(res)
        return res

    def pop_max(self) -> T:
        """Remove and return one copy of the largest element."""
        return self.pop()

    def pop_min(self) -> T:
        """Remove and return one copy of the smallest element."""
        return self.pop(0)

    def tolist(self) -> list[T]:
        """Return all elements in ascending order, each repeated by its count."""
        a: list[T] = []
        for node in self._iter_nodes():
            a.extend([node.key] * node.val)
        return a

    def tolist_items(self) -> list[tuple[T, int]]:
        """Return ``(key, multiplicity)`` pairs in ascending key order."""
        return [(node.key, node.val) for node in self._iter_nodes()]

    def get_elm(self, k: int) -> T:
        """Return the ``k``-th distinct key (0-indexed; negative counts from end).

        Raises:
            IndexError: if ``k`` is out of range.
        """
        self._set_kth_elm_tree_splay(k)
        return self.root.key

    def items(self) -> Iterator[tuple[T, int]]:
        """Yield ``(key, multiplicity)`` in ascending key order.

        Each step splays.  The number of steps is fixed when iteration starts.
        """
        for i in range(self.len_elm()):
            self._set_kth_elm_tree_splay(i)
            yield self.root.key, self.root.val

    def keys(self) -> Iterator[T]:
        """Yield distinct keys in ascending order (each step splays)."""
        for i in range(self.len_elm()):
            self._set_kth_elm_tree_splay(i)
            yield self.root.key

    def values(self) -> Iterator[int]:
        """Yield multiplicities in ascending key order (each step splays)."""
        for i in range(self.len_elm()):
            self._set_kth_elm_tree_splay(i)
            yield self.root.val

    def len_elm(self) -> int:
        """Return the number of distinct keys."""
        return 0 if self.root is None else self.root.size

    def show(self) -> None:
        """Print the multiset as ``{key: count, ...}``."""
        print(
            "{" + ", ".join(map(lambda x: f"{x[0]}: {x[1]}", self.tolist_items())) + "}"
        )

    def clear(self) -> None:
        """Remove all elements."""
        self.root = None

    def __iter__(self) -> Iterator[T]:
        """Yield elements in ascending order, repeating each by its count.

        Each call returns an independent iterator, so nested iteration over
        the same multiset works.  Each step splays.
        """
        i = 0
        while i < self.__len__():
            yield self.__getitem__(i)
            i += 1

    def __reversed__(self):
        """Yield elements in descending order (multiplicities counted)."""
        for i in range(self.__len__()):
            yield self.__getitem__(-i - 1)

    def __contains__(self, key: T) -> bool:
        self._set_search_splay(key)
        return self.root is not None and self.root.key == key

    def __getitem__(self, k: int) -> T:
        """Return the ``k``-th element (multiplicities counted).

        Raises:
            IndexError: if ``k`` is out of range.
        """
        self._set_kth_elm_splay(k)
        return self.root.key

    def __len__(self):
        """Return the total number of elements, counting multiplicities."""
        return 0 if self.root is None else self.root.valsize

    def __bool__(self):
        return self.root is not None

    def __str__(self):
        return "{" + ", ".join(map(str, self.tolist())) + "}"

    def __repr__(self):
        return f"SplayTreeMultiset({self.tolist()})"