# from titan_pylib.data_structures.wbt.wbt_multiset import WBTMultiset
# from titan_pylib.data_structures.wbt._wbt_multiset_node import _WBTMultisetNode
# from titan_pylib.data_structures.wbt._wbt_node_base import _WBTNodeBase
from typing import Generic, TypeVar, Optional, Final

T = TypeVar("T")


class _WBTNodeBase(Generic[T]):
    """Base class for weight-balanced tree (WBT) nodes.

    Holds the fields every WBT node needs: the subtree size (``_size``) and
    the parent / left / right links. Subclasses add their payload (e.g. a
    ``_key``) and may override ``_update`` and the rotations to maintain
    extra per-subtree aggregates.

    With ``weight = subtree size + 1``, the balance invariant checked by
    ``_balance_check`` is::

        weight(left) * DELTA >= weight(right)
        weight(right) * DELTA >= weight(left)
    """

    __slots__ = "_size", "_par", "_left", "_right"
    # DELTA bounds the weight ratio between the two children of a node;
    # GAMMA decides between a single and a double rotation in _rebalance.
    DELTA: Final[int] = 3
    GAMMA: Final[int] = 2

    def __init__(self) -> None:
        self._size: int = 1
        self._par: Optional[_WBTNodeBase[T]] = None
        self._left: Optional[_WBTNodeBase[T]] = None
        self._right: Optional[_WBTNodeBase[T]] = None

    def _rebalance(self) -> "_WBTNodeBase[T]":
        """Walk from this node up to the root, refreshing each node with
        ``_update`` and rotating wherever the balance invariant is broken.

        Returns:
            _WBTNodeBase[T]: the root node of the whole tree.
        """
        node = self
        while True:
            node._update()
            wl, wr = node._weight_left(), node._weight_right()
            if wl * _WBTNodeBase.DELTA < wr:
                # Right-heavy. If the right child leans left, rotate it first
                # (double rotation); otherwise one left rotation suffices.
                if (
                    node._right._weight_left()
                    >= node._right._weight_right() * _WBTNodeBase.GAMMA
                ):
                    node._right = node._right._rotate_right()
                node = node._rotate_left()
            elif wr * _WBTNodeBase.DELTA < wl:
                # Mirror image of the right-heavy case.
                if (
                    node._left._weight_right()
                    >= node._left._weight_left() * _WBTNodeBase.GAMMA
                ):
                    node._left = node._left._rotate_left()
                node = node._rotate_right()
            if not node._par:
                return node
            node = node._par

    def _copy_from(self, other: "_WBTNodeBase[T]") -> None:
        """Put ``self`` in ``other``'s position in the tree.

        Copies ``other``'s size and links, and repoints ``other``'s children
        and parent at ``self``.
        """
        self._size = other._size
        if other._left:
            other._left._par = self
        if other._right:
            other._right._par = self
        if other._par:
            if other._par._left is other:
                other._par._left = self
            else:
                other._par._right = self
        self._par = other._par
        self._left = other._left
        self._right = other._right

    def _weight_left(self) -> int:
        """Weight of the left subtree (its size + 1)."""
        return self._left._size + 1 if self._left else 1

    def _weight_right(self) -> int:
        """Weight of the right subtree (its size + 1)."""
        return self._right._size + 1 if self._right else 1

    def _update(self) -> None:
        """Recompute ``_size`` from the children."""
        self._size = (
            1
            + (self._left._size if self._left else 0)
            + (self._right._size if self._right else 0)
        )

    def _rotate_right(self) -> "_WBTNodeBase[T]":
        """Rotate right around this node.

        Sizes and the parent's child pointer are fixed up in place.

        Returns:
            _WBTNodeBase[T]: the new subtree root (the former left child).
        """
        u = self._left
        u._size = self._size
        # self keeps everything except u and u's left subtree.
        self._size -= u._left._size + 1 if u._left else 1
        u._par = self._par
        self._left = u._right
        if u._right:
            u._right._par = self
        u._right = self
        self._par = u
        if u._par:
            if u._par._left is self:
                u._par._left = u
            else:
                u._par._right = u
        return u

    def _rotate_left(self) -> "_WBTNodeBase[T]":
        """Rotate left around this node.

        Sizes and the parent's child pointer are fixed up in place.

        Returns:
            _WBTNodeBase[T]: the new subtree root (the former right child).
        """
        u = self._right
        u._size = self._size
        # self keeps everything except u and u's right subtree.
        self._size -= u._right._size + 1 if u._right else 1
        u._par = self._par
        self._right = u._left
        if u._left:
            u._left._par = self
        u._left = self
        self._par = u
        if u._par:
            if u._par._left is self:
                u._par._left = u
            else:
                u._par._right = u
        return u

    def _balance_check(self) -> None:
        """Debug helper: raise AssertionError if this node is out of balance.

        Uses an explicit ``raise`` rather than ``assert False`` so the check
        still runs under ``python -O``.
        """
        wl, wr = self._weight_left(), self._weight_right()
        if not wl * _WBTNodeBase.DELTA >= wr:
            print(wl, wr, flush=True)
            print(self)
            raise AssertionError("self._weight_left() * DELTA >= self._weight_right()")
        if not wr * _WBTNodeBase.DELTA >= wl:
            print(wl, wr, flush=True)
            print(self)
            raise AssertionError("self._weight_right() * DELTA >= self._weight_left()")

    def _min(self) -> "_WBTNodeBase[T]":
        """Leftmost node of this subtree."""
        node = self
        while node._left:
            node = node._left
        return node

    def _max(self) -> "_WBTNodeBase[T]":
        """Rightmost node of this subtree."""
        node = self
        while node._right:
            node = node._right
        return node

    def _next(self) -> Optional["_WBTNodeBase[T]"]:
        """In-order successor, or None if this is the last node."""
        if self._right:
            return self._right._min()
        # Climb while we are coming up from a right child.
        now, pre = self, None
        while now and now._right is pre:
            now, pre = now._par, now
        return now

    def _prev(self) -> Optional["_WBTNodeBase[T]"]:
        """In-order predecessor, or None if this is the first node."""
        if self._left:
            return self._left._max()
        # Climb while we are coming up from a left child.
        now, pre = self, None
        while now and now._left is pre:
            now, pre = now._par, now
        return now

    def __add__(self, other: int) -> Optional["_WBTNodeBase[T]"]:
        """Node ``other`` steps later in in-order traversal.

        Returns None once the walk runs past the last node.
        """
        node: Optional[_WBTNodeBase[T]] = self
        for _ in range(other):
            if node is None:
                break
            node = node._next()
        return node

    def __sub__(self, other: int) -> Optional["_WBTNodeBase[T]"]:
        """Node ``other`` steps earlier in in-order traversal.

        Returns None once the walk runs past the first node.
        """
        node: Optional[_WBTNodeBase[T]] = self
        for _ in range(other):
            if node is None:
                break
            node = node._prev()
        return node

    __iadd__ = __add__
    __isub__ = __sub__

    def __str__(self) -> str:
        # The base class declares no ``_key`` slot; this relies on the
        # subclass providing one.
        return str(self._key)

    __repr__ = __str__
from typing import TypeVar, Optional

T = TypeVar("T")


class _WBTMultisetNode(_WBTNodeBase[T]):
    """WBT node used by WBTMultiset: one distinct key plus its multiplicity.

    Fields added on top of the base node:
      * ``_key``: the stored key.
      * ``_count``: how many copies of ``_key`` this node represents.
      * ``_count_size``: sum of ``_count`` over this node's whole subtree.

    ``_count_size`` is kept consistent by ``_update`` and both rotations.
    """

    # Only the new fields. "_size", "_par", "_left" and "_right" already live
    # in the base-class __slots__. Re-declaring them would create shadowed
    # duplicate slots, which waste memory and are undefined per the data model.
    __slots__ = "_key", "_count", "_count_size"

    def __init__(self, key: T, count: int) -> None:
        super().__init__()
        self._key: T = key
        self._count: int = count
        self._count_size: int = count
        self._par: Optional[_WBTMultisetNode[T]]
        self._left: Optional[_WBTMultisetNode[T]]
        self._right: Optional[_WBTMultisetNode[T]]

    @property
    def key(self) -> T:
        """The key stored in this node."""
        return self._key

    @property
    def count(self) -> int:
        """Multiplicity of ``key`` in the multiset."""
        return self._count

    def _update(self) -> None:
        """Recompute ``_size`` (via the base class) and ``_count_size``."""
        super()._update()
        self._count_size = (
            self._count
            + (self._left._count_size if self._left else 0)
            + (self._right._count_size if self._right else 0)
        )

    def _rotate_right(self) -> "_WBTMultisetNode[T]":
        """Rotate right, maintaining both ``_size`` and ``_count_size``.

        Returns:
            _WBTMultisetNode[T]: the new subtree root (the former left child).
        """
        u = self._left
        u._size = self._size
        u._count_size = self._count_size
        # self loses u itself and u's left subtree.
        self._size -= u._left._size + 1 if u._left else 1
        self._count_size -= u._left._count_size + u._count if u._left else u._count
        u._par = self._par
        self._left = u._right
        if u._right:
            u._right._par = self
        u._right = self
        self._par = u
        if u._par:
            if u._par._left is self:
                u._par._left = u
            else:
                u._par._right = u
        return u

    def _rotate_left(self) -> "_WBTMultisetNode[T]":
        """Rotate left, maintaining both ``_size`` and ``_count_size``.

        Returns:
            _WBTMultisetNode[T]: the new subtree root (the former right child).
        """
        u = self._right
        u._size = self._size
        u._count_size = self._count_size
        # self loses u itself and u's right subtree.
        self._size -= u._right._size + 1 if u._right else 1
        self._count_size -= u._right._count_size + u._count if u._right else u._count
        u._par = self._par
        self._right = u._left
        if u._left:
            u._left._par = self
        u._left = self
        self._par = u
        if u._par:
            if u._par._left is self:
                u._par._left = u
            else:
                u._par._right = u
        return u

    def _copy_from(self, other: "_WBTMultisetNode[T]") -> None:
        """Take ``other``'s tree position, count and count-size.

        ``self`` keeps its own ``_key``.
        """
        super()._copy_from(other)
        self._count = other._count
        self._count_size = other._count_size
from typing import Generic, TypeVar, Optional, Iterable, Iterator

T = TypeVar("T")


class WBTMultiset(Generic[T]):
    """Ordered multiset backed by a weight-balanced tree.

    Every tree node (``_WBTMultisetNode``) holds one distinct key together
    with its multiplicity ``_count``. ``_count_size`` aggregates the
    multiplicities of a subtree, so ``len(self)`` counts copies while node
    ``_size`` counts distinct keys. ``_min`` / ``_max`` cache the leftmost /
    rightmost node.
    """

    __slots__ = "_root", "_min", "_max"

    def __init__(self, a: Iterable[T] = ()) -> None:
        """Create a multiset holding every element of ``a`` (duplicates kept)."""
        self._root: Optional[_WBTMultisetNode[T]] = None
        self._min: Optional[_WBTMultisetNode[T]] = None
        self._max: Optional[_WBTMultisetNode[T]] = None
        self.__build(a)

    def __build(self, a: Iterable[T]) -> None:
        """Build a perfectly balanced tree from ``a`` into the empty multiset."""

        def build(
            lo: int, hi: int, pnode: Optional[_WBTMultisetNode[T]] = None
        ) -> Optional[_WBTMultisetNode[T]]:
            # Subtree over keys[lo:hi], rooted at the middle key.
            if lo == hi:
                return None
            mid = (lo + hi) // 2
            node = _WBTMultisetNode(keys[mid], vals[mid])
            node._left = build(lo, mid, node)
            node._right = build(mid + 1, hi, node)
            node._par = pnode
            node._update()
            return node

        a = list(a)
        if not a:
            return
        # Skip the sort when the input is already sorted.
        if not all(a[i] <= a[i + 1] for i in range(len(a) - 1)):
            a.sort()
        # Run-length encode: one node per distinct key, vals = multiplicity.
        keys: list[T] = []
        vals: list[int] = []
        for elm in a:
            if keys and elm == keys[-1]:
                vals[-1] += 1
            else:
                keys.append(elm)
                vals.append(1)
        self._root = build(0, len(keys))
        self._max = self._root._max()
        self._min = self._root._min()

    def add(self, key: T, count: int = 1) -> None:
        """Insert ``count`` copies of ``key``."""
        if not self._root:
            self._root = _WBTMultisetNode(key, count)
            self._max = self._root
            self._min = self._root
            return
        pnode = None
        node = self._root
        while node:
            # Every node on the search path gains ``count`` elements.
            node._count_size += count
            if key == node._key:
                node._count += count
                return
            pnode = node
            node = node._left if key < node._key else node._right
        new_node = _WBTMultisetNode(key, count)
        new_node._par = pnode
        if key < pnode._key:
            pnode._left = new_node
            if key < self._min._key:
                self._min = new_node
        else:
            pnode._right = new_node
            if key > self._max._key:
                self._max = new_node
        self._root = pnode._rebalance()

    def find_key(self, key: T) -> Optional[_WBTMultisetNode[T]]:
        """Return the node holding ``key``, or None if absent."""
        node = self._root
        while node:
            if key == node._key:
                return node
            node = node._left if key < node._key else node._right
        return None

    def find_order(self, k: int) -> _WBTMultisetNode[T]:
        """Return the node holding the k-th smallest element.

        ``k`` is 0-indexed and counts multiplicities; a negative ``k`` counts
        from the end.

        Raises:
            IndexError: if ``k`` is out of range.
        """
        size = len(self)
        if k < 0:
            k += size
        if not 0 <= k < size:
            raise IndexError(
                f"{self.__class__.__name__} index out of range (len={size})"
            )
        node = self._root
        while True:
            left = node._left._count_size if node._left else 0
            if k < left:
                node = node._left
            elif k < left + node._count:
                return node
            else:
                k -= left + node._count
                node = node._right

    def count(self, key: T) -> int:
        """Number of copies of ``key`` (0 if absent)."""
        node = self.find_key(key)
        return node.count if node is not None else 0

    def remove_iter(self, node: _WBTMultisetNode[T]) -> None:
        """Unlink ``node`` (with all of its copies) and rebalance the tree."""
        if node is self._min:
            self._min = self._min._next()
        if node is self._max:
            self._max = self._max._prev()
        delnode = node
        pnode, mnode = node._par, None
        if node._left and node._right:
            # Two children: physically unlink the in-order predecessor
            # (mnode) instead, and let it take delnode's place afterwards.
            # delnode borrows mnode's count so the _update calls in the
            # rebalance below already compute the final aggregates.
            pnode, mnode = node, node._left
            while mnode._right:
                pnode, mnode = mnode, mnode._right
            node._count = mnode._count
            node = mnode
        cnode = node._right if not node._left else node._left
        if cnode:
            cnode._par = pnode
        if pnode:
            if pnode._left is node:
                pnode._left = cnode
            else:
                pnode._right = cnode
            self._root = pnode._rebalance()
        else:
            self._root = cnode
        if mnode:
            if self._root is delnode:
                self._root = mnode
            mnode._copy_from(delnode)

    def _decrement(self, node: _WBTMultisetNode[T], count: int) -> None:
        """Remove ``count`` copies held by ``node``.

        The node is unlinked entirely when it holds ``count`` or fewer copies.
        Otherwise ``_count_size`` is updated on the whole path to the root.
        """
        if node._count <= count:
            self.remove_iter(node)
            return
        node._count -= count
        while node:
            node._count_size -= count
            node = node._par

    def remove(self, key: T, count: int = 1) -> None:
        """Remove ``count`` copies of ``key`` (all of them if fewer exist).

        Asserts that ``key`` is present.
        """
        node = self.find_key(key)
        assert node, f"KeyError: {key} is not found."
        self._decrement(node, count)

    def discard(self, key: T, count: int = 1) -> bool:
        """Like ``remove``, but return False instead of failing when absent."""
        node = self.find_key(key)
        if node is None:
            return False
        self._decrement(node, count)
        return True

    def pop(self, k: int = -1) -> T:
        """Remove and return one copy of the k-th smallest element.

        The default ``k=-1`` pops the largest element.
        """
        node = self.find_order(k)
        key = node._key
        self._decrement(node, 1)
        return key

    def le_iter(self, key: T) -> Optional[_WBTMultisetNode[T]]:
        """Node with the largest key <= ``key``, or None."""
        res = None
        node = self._root
        while node:
            if key == node._key:
                res = node
                break
            if key < node._key:
                node = node._left
            else:
                res = node
                node = node._right
        return res

    def lt_iter(self, key: T) -> Optional[_WBTMultisetNode[T]]:
        """Node with the largest key < ``key``, or None."""
        res = None
        node = self._root
        while node:
            if key <= node._key:
                node = node._left
            else:
                res = node
                node = node._right
        return res

    def ge_iter(self, key: T) -> Optional[_WBTMultisetNode[T]]:
        """Node with the smallest key >= ``key``, or None."""
        res = None
        node = self._root
        while node:
            if key == node._key:
                res = node
                break
            if key < node._key:
                res = node
                node = node._left
            else:
                node = node._right
        return res

    def gt_iter(self, key: T) -> Optional[_WBTMultisetNode[T]]:
        """Node with the smallest key > ``key``, or None."""
        res = None
        node = self._root
        while node:
            if key < node._key:
                res = node
                node = node._left
            else:
                node = node._right
        return res

    def le(self, key: T) -> Optional[T]:
        """Largest stored key <= ``key``, or None."""
        res = None
        node = self._root
        while node:
            if key == node._key:
                res = key
                break
            if key < node._key:
                node = node._left
            else:
                res = node._key
                node = node._right
        return res

    def lt(self, key: T) -> Optional[T]:
        """Largest stored key < ``key``, or None."""
        res = None
        node = self._root
        while node:
            if key <= node._key:
                node = node._left
            else:
                res = node._key
                node = node._right
        return res

    def ge(self, key: T) -> Optional[T]:
        """Smallest stored key >= ``key``, or None."""
        res = None
        node = self._root
        while node:
            if key == node._key:
                res = key
                break
            if key < node._key:
                res = node._key
                node = node._left
            else:
                node = node._right
        return res

    def gt(self, key: T) -> Optional[T]:
        """Smallest stored key > ``key``, or None."""
        res = None
        node = self._root
        while node:
            if key < node._key:
                res = node._key
                node = node._left
            else:
                node = node._right
        return res

    def index(self, key: T) -> int:
        """Number of elements (with multiplicity) strictly less than ``key``."""
        k = 0
        node = self._root
        while node:
            if key == node._key:
                k += node._left._count_size if node._left else 0
                break
            if key < node._key:
                node = node._left
            else:
                k += node._left._count_size + node._count if node._left else node._count
                node = node._right
        return k

    def index_right(self, key: T) -> int:
        """Number of elements (with multiplicity) less than or equal to ``key``."""
        k = 0
        node = self._root
        while node:
            if key == node._key:
                k += node._left._count_size + node._count if node._left else node._count
                break
            if key < node._key:
                node = node._left
            else:
                k += node._left._count_size + node._count if node._left else node._count
                node = node._right
        return k

    def tolist(self) -> list[T]:
        """All elements in ascending order, duplicates repeated."""
        return list(self)

    def get_min(self) -> T:
        """Smallest key (asserts the multiset is non-empty)."""
        assert self._min
        return self._min._key

    def get_max(self) -> T:
        """Largest key (asserts the multiset is non-empty)."""
        assert self._max
        return self._max._key

    def pop_min(self) -> T:
        """Remove and return one copy of the smallest key."""
        assert self._min
        key = self._min._key
        self._decrement(self._min, 1)
        return key

    def pop_max(self) -> T:
        """Remove and return one copy of the largest key."""
        assert self._max
        key = self._max._key
        self._decrement(self._max, 1)
        return key

    def check(self) -> None:
        """Debug helper: verify key order, sizes, count sizes and balance."""
        if self._root is None:
            return

        # Returns (_size, _count_size, height) of the subtree.
        def dfs(node: _WBTMultisetNode[T]) -> tuple[int, int, int]:
            h = 0
            s = 1
            cs = node.count
            if node._left:
                assert node._key > node._left._key
                ls, lcs, lh = dfs(node._left)
                s += ls
                cs += lcs
                h = max(h, lh)
            if node._right:
                assert node._key < node._right._key
                rs, rcs, rh = dfs(node._right)
                s += rs
                cs += rcs
                h = max(h, rh)
            assert node._size == s
            assert node._count_size == cs
            node._balance_check()
            return s, cs, h + 1

        dfs(self._root)

    def __contains__(self, key: T) -> bool:
        return self.find_key(key) is not None

    def __getitem__(self, k: int) -> T:
        """k-th smallest element (0-indexed, negative counts from the end)."""
        assert (
            -len(self) <= k < len(self)
        ), f"IndexError: {self.__class__.__name__}[{k}], len={len(self)}"
        if k < 0:
            k += len(self)
        if k == 0:
            return self.get_min()
        if k == len(self) - 1:
            return self.get_max()
        return self.find_order(k)._key

    def __delitem__(self, k: int) -> None:
        """Remove one copy of the k-th smallest element."""
        self._decrement(self.find_order(k), 1)

    def __len__(self) -> int:
        """Total number of elements, counting multiplicities."""
        return self._root._count_size if self._root else 0

    def __iter__(self) -> Iterator[T]:
        """Ascending in-order traversal; each key is yielded ``count`` times."""
        stack: list[_WBTMultisetNode[T]] = []
        node = self._root
        while stack or node:
            if node:
                stack.append(node)
                node = node._left
            else:
                node = stack.pop()
                for _ in range(node._count):
                    yield node._key
                node = node._right

    def __reversed__(self) -> Iterator[T]:
        """Descending traversal; each key is yielded ``count`` times."""
        stack: list[_WBTMultisetNode[T]] = []
        node = self._root
        while stack or node:
            if node:
                stack.append(node)
                node = node._right
            else:
                node = stack.pop()
                for _ in range(node._count):
                    yield node._key
                node = node._left

    def __str__(self) -> str:
        return "{" + ", ".join(map(str, self)) + "}"

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            + "["
            + ", ".join(map(str, self.tolist()))
            + "])"
        )