import sys
from typing import Iterator, Optional, Generic, Iterable, TypeVar

T = TypeVar("T")


class SplayTreeMultisetSum(Generic[T]):
    """Ordered multiset backed by a bottom-up splay tree, with subtree sums.

    Each tree node stores one distinct key together with its multiplicity
    ``cnt`` and caches three aggregates of its subtree:

    * ``size``    -- number of distinct keys,
    * ``cntsize`` -- number of elements counting multiplicity,
    * ``data``    -- ``key * cnt`` summed with ``+`` over the subtree.

    ``e`` is what :meth:`get_sum` returns for an empty multiset (presumably
    the additive identity of the key type, e.g. ``0``).  Keys must be
    mutually comparable and support ``key * int`` and ``+``.

    Most operations, including read-only queries, splay the accessed node to
    the root, so the tree shape changes on every access.
    """

    class Node:
        """A tree node holding one distinct key and its multiplicity."""

        __slots__ = ("data", "key", "size", "cnt", "cntsize", "left", "right")

        def __init__(self, key, cnt: int) -> None:
            self.data = key * cnt  # key * cnt summed over this subtree
            self.key = key
            self.size = 1  # distinct keys in this subtree
            self.cnt = cnt  # multiplicity of ``key``
            self.cntsize = cnt  # elements (with multiplicity) in this subtree
            self.left = None
            self.right = None

        def __str__(self) -> str:
            if self.left is None and self.right is None:
                return f"key:{self.key, self.size, self.cnt, self.cntsize}\n"
            return (
                f"key:{self.key, self.size, self.cnt, self.cntsize},\n"
                f" left:{self.left},\n right:{self.right}\n"
            )

        __repr__ = __str__

    def __init__(self, e: T, a: Iterable[T] = (), _node=None) -> None:
        """Create a multiset.

        Args:
            e: Value returned by ``get_sum`` when the multiset is empty.
            a: Initial elements; any iterable, duplicates allowed.
            _node: Internal use -- adopt an existing subtree as the root.
        """
        self.node = _node
        self.e = e
        # Materialize first: a generator is always truthy, so testing it
        # directly would try to build from an empty sequence and crash.
        items = [] if a is None else list(a)
        if items:
            self._build(items)

    def _build(self, a: Iterable[T]) -> None:
        """Replace the tree with a perfectly balanced one built from ``a``."""
        Node = SplayTreeMultisetSum.Node
        key, cnt = self._rle(sorted(a))
        if not key:
            self.node = None
            return

        def build(l: int, r: int) -> "SplayTreeMultisetSum.Node":
            # Subtree over key[l:r], rooted at its midpoint.
            mid = (l + r) >> 1
            node = Node(key[mid], cnt[mid])
            if l != mid:
                node.left = build(l, mid)
            if mid + 1 != r:
                node.right = build(mid + 1, r)
            self._update(node)
            return node

        self.node = build(0, len(key))

    def _rle(self, a: list[T]) -> tuple[list[T], list[int]]:
        """Run-length encode a sorted list into (distinct keys, counts)."""
        keys: list[T] = []
        cnts: list[int] = []
        for x in a:
            if keys and x == keys[-1]:
                cnts[-1] += 1
            else:
                keys.append(x)
                cnts.append(1)
        return keys, cnts

    def _update(self, node: Node) -> None:
        """Recompute ``size``, ``cntsize`` and ``data`` of ``node`` from its children."""
        if node.left is None:
            if node.right is None:
                node.size = 1
                node.cntsize = node.cnt
                node.data = node.key * node.cnt
            else:
                node.size = 1 + node.right.size
                node.cntsize = node.cnt + node.right.cntsize
                node.data = node.key * node.cnt + node.right.data
        else:
            if node.right is None:
                node.size = 1 + node.left.size
                node.cntsize = node.cnt + node.left.cntsize
                node.data = node.key * node.cnt + node.left.data
            else:
                node.size = 1 + node.left.size + node.right.size
                node.cntsize = node.cnt + node.left.cntsize + node.right.cntsize
                node.data = node.key * node.cnt + node.left.data + node.right.data

    def _splay(self, path: list[Node], d: int) -> Node:
        """Rotate the target node up to the top of ``path`` and return it.

        ``path`` lists the target's ancestors from the root downwards; the
        target itself is not in it.  Bit ``i`` of ``d`` (from the least
        significant bit) is 1 if the step out of ``path[-1 - i]`` went left
        and 0 if it went right.  ``path`` is consumed and must be non-empty.
        """
        for _ in range(len(path) >> 1):
            node = path.pop()  # parent of the target
            pnode = path.pop()  # grandparent of the target
            if d & 1 == d >> 1 & 1:
                # zig-zig: target and parent hang on the same side.
                if d & 1:
                    tmp = node.left
                    node.left = tmp.right
                    tmp.right = node
                    pnode.left = node.right
                    node.right = pnode
                else:
                    tmp = node.right
                    node.right = tmp.left
                    tmp.left = node
                    pnode.right = node.left
                    node.left = pnode
            else:
                # zig-zag: target and parent hang on opposite sides.
                if d & 1:
                    tmp = node.left
                    node.left = tmp.right
                    pnode.right = tmp.left
                    tmp.right = node
                    tmp.left = pnode
                else:
                    tmp = node.right
                    node.right = tmp.left
                    pnode.left = tmp.right
                    tmp.left = node
                    tmp.right = pnode
            # Children first, then the new subtree root.
            self._update(pnode)
            self._update(node)
            self._update(tmp)
            if not path:
                return tmp
            d >>= 2
            # Reattach the rotated subtree below the great-grandparent.
            if d & 1:
                path[-1].left = tmp
            else:
                path[-1].right = tmp
        # Odd path length: one single rotation at the root remains.
        gnode = path[0]
        if d & 1:
            node = gnode.left
            gnode.left = node.right
            node.right = gnode
        else:
            node = gnode.right
            gnode.right = node.left
            node.left = gnode
        self._update(gnode)
        self._update(node)
        return node

    def _set_search_splay(self, key: T) -> None:
        """Splay ``key``'s node, or the last node on its search path, to the root."""
        node = self.node
        if node is None or node.key == key:
            return
        path = []
        d = 0
        while True:
            if node.key == key:
                break
            elif key < node.key:
                if node.left is None:
                    break
                path.append(node)
                d <<= 1
                d |= 1
                node = node.left
            else:
                if node.right is None:
                    break
                path.append(node)
                d <<= 1
                node = node.right
        if path:
            self.node = self._splay(path, d)

    def _set_kth_elm_splay(self, k: int) -> None:
        """Splay the node holding the k-th element to the root.

        ``k`` is 0-indexed and counts multiplicity; a negative ``k`` counts
        from the end.

        Raises:
            IndexError: if ``k`` is out of range, including when empty.
        """
        n = len(self)
        if k < 0:
            k += n
        if not 0 <= k < n:
            raise IndexError("SplayTreeMultisetSum index out of range")
        d = 0
        node = self.node
        path = []
        while True:
            # Elements in this node's left subtree plus its own copies.
            t = node.cnt if node.left is None else node.cnt + node.left.cntsize
            if t - node.cnt <= k < t:
                if path:
                    self.node = self._splay(path, d)
                break
            elif t > k:
                path.append(node)
                d <<= 1
                d |= 1
                node = node.left
            else:
                path.append(node)
                d <<= 1
                node = node.right
                k -= t

    def _set_kth_elm_tree_splay(self, k: int) -> None:
        """Splay the node with the k-th smallest distinct key to the root.

        ``k`` is 0-indexed and ignores multiplicity; a negative ``k`` counts
        from the end.

        Raises:
            IndexError: if ``k`` is out of range, including when empty.
        """
        n = self.len_elm()
        if k < 0:
            k += n
        if not 0 <= k < n:
            raise IndexError("SplayTreeMultisetSum key index out of range")
        d = 0
        node = self.node
        path = []
        while True:
            t = 0 if node.left is None else node.left.size
            if t == k:
                if path:
                    self.node = self._splay(path, d)
                return
            elif t > k:
                path.append(node)
                d <<= 1
                d |= 1
                node = node.left
            else:
                path.append(node)
                d <<= 1
                node = node.right
                k -= t + 1

    def _get_min_splay(self, node: Node) -> Node:
        """Splay the minimum of the subtree at ``node`` to its top and return it."""
        if node is None or node.left is None:
            return node
        path = []
        while node.left is not None:
            path.append(node)
            node = node.left
        return self._splay(path, (1 << len(path)) - 1)

    def _get_max_splay(self, node: Node) -> Node:
        """Splay the maximum of the subtree at ``node`` to its top and return it."""
        if node is None or node.right is None:
            return node
        path = []
        while node.right is not None:
            path.append(node)
            node = node.right
        return self._splay(path, 0)

    def _inorder(self) -> Iterator["SplayTreeMultisetSum.Node"]:
        """Yield nodes in key order without splaying.

        Iterative, because a splay tree can be as tall as its number of
        distinct keys, and recursion would hit the interpreter limit.
        """
        stack = []
        node = self.node
        while stack or node is not None:
            while node is not None:
                stack.append(node)
                node = node.left
            node = stack.pop()
            yield node
            node = node.right

    def add(self, key: T, cnt: int = 1) -> None:
        """Insert ``cnt`` copies of ``key``."""
        if self.node is None:
            self.node = SplayTreeMultisetSum.Node(key, cnt)
            return
        self._set_search_splay(key)
        if self.node.key == key:
            self.node.cnt += cnt
            self._update(self.node)
            return
        # The root is now key's neighbour on its search path, so the new
        # node can take over as root with the old root on one side.
        node = SplayTreeMultisetSum.Node(key, cnt)
        if key < self.node.key:
            node.left = self.node.left
            node.right = self.node
            self.node.left = None
            self._update(node.right)
        else:
            node.left = self.node
            node.right = self.node.right
            self.node.right = None
            self._update(node.left)
        self._update(node)
        self.node = node
        return

    def discard(self, key: T, cnt: int = 1) -> bool:
        """Remove up to ``cnt`` copies of ``key``.

        Returns:
            True if ``key`` was present, False otherwise.
        """
        if self.node is None:
            return False
        self._set_search_splay(key)
        if self.node.key != key:
            return False
        if self.node.cnt > cnt:
            self.node.cnt -= cnt
            self._update(self.node)
            return True
        # Remove the whole node: join its two subtrees.
        if self.node.left is None:
            self.node = self.node.right
        elif self.node.right is None:
            self.node = self.node.left
        else:
            node = self._get_min_splay(self.node.right)
            node.left = self.node.left
            self._update(node)
            self.node = node
        return True

    def discard_all(self, key: T) -> bool:
        """Remove every copy of ``key``; return whether it was present."""
        return self.discard(key, self.count(key))

    def count(self, key: T) -> int:
        """Return the multiplicity of ``key`` (0 if absent)."""
        if self.node is None:
            return 0
        self._set_search_splay(key)
        return self.node.cnt if self.node.key == key else 0

    def le(self, key: T) -> Optional[T]:
        """Return the largest element ``<= key``, or None."""
        node = self.node
        if node is None:
            return None
        path = []
        d = 0
        res = None
        while True:
            if node.key == key:
                res = key
                break
            elif key < node.key:
                if node.left is None:
                    break
                path.append(node)
                d <<= 1
                d |= 1
                node = node.left
            else:
                res = node.key
                if node.right is None:
                    break
                path.append(node)
                d <<= 1
                node = node.right
        if path:
            self.node = self._splay(path, d)
        return res

    def lt(self, key: T) -> Optional[T]:
        """Return the largest element ``< key``, or None."""
        node = self.node
        path = []
        d = 0
        res = None
        while node is not None:
            if key <= node.key:
                path.append(node)
                d <<= 1
                d |= 1
                node = node.left
            else:
                path.append(node)
                d <<= 1
                res = node.key
                node = node.right
        # The last appended node is the one to splay: drop it and its
        # dangling direction bit so that it becomes the target.
        if path:
            path.pop()
            d >>= 1
        if path:
            self.node = self._splay(path, d)
        return res

    def ge(self, key: T) -> Optional[T]:
        """Return the smallest element ``>= key``, or None."""
        node = self.node
        if node is None:
            return None
        path = []
        d = 0
        res = None
        while True:
            if node.key == key:
                res = node.key
                break
            elif key < node.key:
                res = node.key
                if node.left is None:
                    break
                path.append(node)
                d <<= 1
                d |= 1
                node = node.left
            else:
                if node.right is None:
                    break
                path.append(node)
                d <<= 1
                node = node.right
        if path:
            self.node = self._splay(path, d)
        return res

    def gt(self, key: T) -> Optional[T]:
        """Return the smallest element ``> key``, or None."""
        node = self.node
        path = []
        d = 0
        res = None
        while node is not None:
            if key < node.key:
                path.append(node)
                d <<= 1
                d |= 1
                res = node.key
                node = node.left
            else:
                path.append(node)
                d <<= 1
                node = node.right
        # As in ``lt``: splay the last real node visited.
        if path:
            path.pop()
            d >>= 1
        if path:
            self.node = self._splay(path, d)
        return res

    def index(self, key: T) -> int:
        """Return the number of elements ``< key``, counting multiplicity."""
        if self.node is None:
            return 0
        self._set_search_splay(key)
        res = 0 if self.node.left is None else self.node.left.cntsize
        if self.node.key < key:
            res += self.node.cnt
        return res

    def index_right(self, key: T) -> int:
        """Return the number of elements ``<= key``, counting multiplicity."""
        if self.node is None:
            return 0
        self._set_search_splay(key)
        res = 0 if self.node.left is None else self.node.left.cntsize
        if self.node.key <= key:
            res += self.node.cnt
        return res

    def index_keys(self, key: T) -> int:
        """Return the number of distinct keys ``< key``."""
        if self.node is None:
            return 0
        self._set_search_splay(key)
        res = 0 if self.node.left is None else self.node.left.size
        if self.node.key < key:
            res += 1
        return res

    def index_right_keys(self, key: T) -> int:
        """Return the number of distinct keys ``<= key``."""
        if self.node is None:
            return 0
        self._set_search_splay(key)
        res = 0 if self.node.left is None else self.node.left.size
        if self.node.key <= key:
            res += 1
        return res

    def pop(self, k: int = -1) -> T:
        """Remove and return one copy of the k-th element (default: the max).

        Raises:
            IndexError: if the multiset is empty or ``k`` is out of range.
        """
        self._set_kth_elm_splay(k)
        res = self.node.key
        self.discard(res)
        return res

    def pop_max(self) -> T:
        """Remove and return one copy of the largest element."""
        return self.pop()

    def pop_min(self) -> T:
        """Remove and return one copy of the smallest element."""
        return self.pop(0)

    def tolist(self) -> list[T]:
        """Return all elements in ascending order, each key repeated ``cnt`` times."""
        a: list[T] = []
        for node in self._inorder():
            a.extend([node.key] * node.cnt)
        return a

    def tolist_items(self) -> list[tuple[T, int]]:
        """Return ``(key, cnt)`` pairs in ascending key order."""
        return [(node.key, node.cnt) for node in self._inorder()]

    def get_elm(self, k: int) -> T:
        """Return the k-th smallest distinct key (negative ``k`` counts from the end).

        Raises:
            IndexError: if ``k`` is out of range.
        """
        self._set_kth_elm_tree_splay(k)
        return self.node.key

    def items(self) -> Iterator[tuple[T, int]]:
        """Yield ``(key, cnt)`` in ascending order, splaying each node."""
        for i in range(self.len_elm()):
            self._set_kth_elm_tree_splay(i)
            yield self.node.key, self.node.cnt

    def keys(self) -> Iterator[T]:
        """Yield distinct keys in ascending order, splaying each node."""
        for i in range(self.len_elm()):
            self._set_kth_elm_tree_splay(i)
            yield self.node.key

    def cntues(self) -> Iterator[int]:
        """Yield each key's multiplicity in ascending key order."""
        for i in range(self.len_elm()):
            self._set_kth_elm_tree_splay(i)
            yield self.node.cnt

    # Correctly named alias; ``cntues`` is kept for existing callers.
    values = cntues

    def len_elm(self) -> int:
        """Return the number of distinct keys."""
        return 0 if self.node is None else self.node.size

    def show(self) -> None:
        """Print the multiset as ``{key: cnt, ...}``."""
        print(
            "{" + ", ".join(map(lambda x: f"{x[0]}: {x[1]}", self.tolist_items())) + "}"
        )

    def clear(self) -> None:
        """Remove every element."""
        self.node = None

    def get_sum(self) -> T:
        """Return ``key * cnt`` summed over all keys, or ``e`` when empty."""
        if self.node is None:
            return self.e
        return self.node.data

    def merge(self, other: "SplayTreeMultisetSum") -> None:
        """Append ``other`` to this multiset.

        Every key of ``self`` must be ``<=`` every key of ``other``.  The
        nodes of ``other`` become part of ``self`` (they are not copied), so
        ``other`` should not be used afterwards.

        Raises:
            ValueError: if the ordering precondition is violated.
        """
        if self.node is None:
            self.node = other.node
            return
        if other.node is None:
            return
        self.node = self._get_max_splay(self.node)
        other.node = other._get_min_splay(other.node)
        if other.node.key < self.node.key:
            raise ValueError("merge requires max(self) <= min(other)")
        if self.node.key == other.node.key:
            # Fold the shared boundary key into self's root.
            self.node.cnt += other.node.cnt
            self._update(self.node)
            other.discard(other.node.key, other.node.cnt)
        # self's root is its maximum, so its right side is free.
        self.node.right = other.node
        self._update(self.node)

    def split(self, k: int) -> tuple["SplayTreeMultisetSum", "SplayTreeMultisetSum"]:
        """Split into (first ``k`` elements, remaining elements).

        ``k`` counts multiplicity and is clamped to ``[0, len(self)]``.
        ``self`` is reused as one of the returned parts.
        """
        if self.node is None or k <= 0:
            return SplayTreeMultisetSum(self.e), self
        if k >= len(self):
            return self, SplayTreeMultisetSum(self.e)
        self._set_kth_elm_splay(k)
        root = self.node
        left_node = root.left
        left_cnt = 0 if left_node is None else left_node.cntsize
        if left_cnt != k:
            # The split point falls inside root's run of equal keys: move
            # the first k - left_cnt copies into a new node on the left.
            moved = k - left_cnt
            root.cnt -= moved
            piece = SplayTreeMultisetSum.Node(root.key, moved)
            piece.left = left_node
            self._update(piece)
            left_node = piece
        root.left = None
        self._update(root)
        return SplayTreeMultisetSum(self.e, _node=left_node), self

    def __iter__(self) -> Iterator[T]:
        """Iterate over the elements in ascending order, counting multiplicity.

        Each call returns an independent iterator, so nested loops over the
        same multiset work.  The length is re-read on every step.
        """
        self.__iter = 0  # reset the legacy cursor used by __next__
        return self._iter_elements()

    def _iter_elements(self) -> Iterator[T]:
        """Yield ``self[0], self[1], ...`` while the index is in range."""
        i = 0
        while i < len(self):
            yield self[i]
            i += 1

    def __next__(self) -> T:
        """Legacy cursor API: advance the cursor that ``iter(self)`` resets."""
        if self.__iter == self.__len__():
            raise StopIteration
        res = self.__getitem__(self.__iter)
        self.__iter += 1
        return res

    def __reversed__(self):
        for i in range(self.__len__()):
            yield self.__getitem__(-i - 1)

    def __contains__(self, key: T) -> bool:
        self._set_search_splay(key)
        return self.node is not None and self.node.key == key

    def __getitem__(self, k: int) -> T:
        """Return the k-th element (with multiplicity); raise IndexError if out of range."""
        self._set_kth_elm_splay(k)
        return self.node.key

    def __len__(self):
        return 0 if self.node is None else self.node.cntsize

    def __bool__(self):
        return self.node is not None

    def __str__(self):
        return "{" + ", ".join(map(str, self.tolist())) + "}"

    def __repr__(self):
        return "SplayTreeMultisetSum(" + str(self) + ")"