1# from titan_pylib.data_structures.avl_tree.avl_tree_multiset2 import AVLTreeMultiset2
2# from titan_pylib.my_class.supports_less_than import SupportsLessThan
3from typing import Protocol
4
5
class SupportsLessThan(Protocol):
    """Structural type: anything whose instances can be compared with ``<``."""

    def __lt__(self, other) -> bool:
        """Return whether ``self`` orders strictly before ``other``."""
9# from titan_pylib.data_structures.bst_base.bst_multiset_array_base import (
10# BSTMultisetArrayBase,
11# )
from typing import TypeVar, Generic, Optional

T = TypeVar("T")
BST = TypeVar("BST")
# TODO: describe the fields a BST must provide (key, val, left, right, ...)
# with a Protocol.


class BSTMultisetArrayBase(Generic[BST, T]):
    """Shared query routines for array-backed BST multisets.

    Every routine is a ``staticmethod`` that takes the tree ``bst`` and reads
    its parallel node arrays:

    - ``bst.root``: id of the root node (0 means the tree is empty)
    - ``bst.key[node]``: key stored at ``node``
    - ``bst.val[node]``: multiplicity of that key
    - ``bst.left[node]`` / ``bst.right[node]``: child ids, 0 meaning "no child"

    ``index``, ``index_right`` and ``_kth_elm`` also read
    ``bst.valsize[node]``, the total multiplicity of the subtree rooted at
    ``node``. These routines presumably rely on ``valsize[0] == 0``.
    """

    @staticmethod
    def _rle(a: list[T]) -> tuple[list[T], list[int]]:
        """Run-length encode ``a``.

        Equal adjacent elements are merged. The result is two parallel lists:
        the distinct run values and their run lengths. Empty input gives
        ``([], [])``.
        """
        keys: list[T] = []
        vals: list[int] = []
        for elm in a:
            if keys and elm == keys[-1]:
                vals[-1] += 1
            else:
                keys.append(elm)
                vals.append(1)
        return keys, vals

    @staticmethod
    def count(bst: BST, key: T) -> int:
        """Return the multiplicity of ``key`` (0 if absent)."""
        keys, left, right = bst.key, bst.left, bst.right
        node = bst.root
        while node:
            if keys[node] == key:
                return bst.val[node]
            node = left[node] if key < keys[node] else right[node]
        return 0

    @staticmethod
    def le(bst: BST, key: T) -> Optional[T]:
        """Return the largest stored key ``<= key``, or None if there is none."""
        keys, left, right = bst.key, bst.left, bst.right
        res = None
        node = bst.root
        while node:
            if key == keys[node]:
                res = key
                break
            if key < keys[node]:
                node = left[node]
            else:
                res = keys[node]
                node = right[node]
        return res

    @staticmethod
    def lt(bst: BST, key: T) -> Optional[T]:
        """Return the largest stored key ``< key``, or None if there is none."""
        keys, left, right = bst.key, bst.left, bst.right
        res = None
        node = bst.root
        while node:
            if key <= keys[node]:
                node = left[node]
            else:
                res = keys[node]
                node = right[node]
        return res

    @staticmethod
    def ge(bst: BST, key: T) -> Optional[T]:
        """Return the smallest stored key ``>= key``, or None if there is none."""
        keys, left, right = bst.key, bst.left, bst.right
        res = None
        node = bst.root
        while node:
            if key == keys[node]:
                res = key
                break
            if key < keys[node]:
                res = keys[node]
                node = left[node]
            else:
                node = right[node]
        return res

    @staticmethod
    def gt(bst: BST, key: T) -> Optional[T]:
        """Return the smallest stored key ``> key``, or None if there is none."""
        keys, left, right = bst.key, bst.left, bst.right
        res = None
        node = bst.root
        while node:
            if key < keys[node]:
                res = keys[node]
                node = left[node]
            else:
                node = right[node]
        return res

    @staticmethod
    def index(bst: BST, key: T) -> int:
        """Return the number of elements strictly less than ``key``.

        Copies are counted with multiplicity. Requires ``bst.valsize``.
        """
        keys, left, right, vals, valsize = (
            bst.key,
            bst.left,
            bst.right,
            bst.val,
            bst.valsize,
        )
        k = 0
        node = bst.root
        while node:
            if key == keys[node]:
                if left[node]:
                    k += valsize[left[node]]
                break
            if key < keys[node]:
                node = left[node]
            else:
                # Everything in the left subtree, plus this node, is < key.
                k += valsize[left[node]] + vals[node]
                node = right[node]
        return k

    @staticmethod
    def index_right(bst: BST, key: T) -> int:
        """Return the number of elements ``<= key``.

        Copies are counted with multiplicity. Requires ``bst.valsize``.
        """
        keys, left, right, vals, valsize = (
            bst.key,
            bst.left,
            bst.right,
            bst.val,
            bst.valsize,
        )
        k = 0
        node = bst.root
        while node:
            if key == keys[node]:
                k += valsize[left[node]] + vals[node]
                break
            if key < keys[node]:
                node = left[node]
            else:
                k += valsize[left[node]] + vals[node]
                node = right[node]
        return k

    @staticmethod
    def _kth_elm(bst: BST, k: int) -> tuple[T, int]:
        """Return ``(key, multiplicity)`` of the ``k``-th element.

        ``k`` is 0-indexed and counts copies. A negative ``k`` counts from
        the end (``len(bst)`` is added). Requires ``bst.valsize``.

        Raises:
            IndexError: if ``k`` is out of range, including on an empty tree.
        """
        left, right, vals, valsize = bst.left, bst.right, bst.val, bst.valsize
        if k < 0:
            k += len(bst)
        node = bst.root
        # Falling off the tree (node == 0) means k was out of range. Looping
        # on the sentinel instead would never terminate.
        while node:
            t = vals[node] + valsize[left[node]]
            if t - vals[node] <= k < t:
                return bst.key[node], vals[node]
            if t > k:
                node = left[node]
            else:
                node = right[node]
                k -= t
        raise IndexError("index out of range")

    @staticmethod
    def contains(bst: BST, key: T) -> bool:
        """Return True if ``key`` is stored in the tree."""
        left, right, keys = bst.left, bst.right, bst.key
        node = bst.root
        while node:
            if keys[node] == key:
                return True
            node = left[node] if key < keys[node] else right[node]
        return False

    @staticmethod
    def tolist(bst: BST) -> list[T]:
        """Return all elements in sorted order, each repeated by its multiplicity."""
        left, right, keys, vals = bst.left, bst.right, bst.key, bst.val
        node = bst.root
        stack, a = [], []
        # Iterative in-order traversal.
        while stack or node:
            if node:
                stack.append(node)
                node = left[node]
            else:
                node = stack.pop()
                a.extend([keys[node]] * vals[node])
                node = right[node]
        return a
from typing import Generic, Iterable, TypeVar, Optional
from array import array

T = TypeVar("T", bound=SupportsLessThan)


class AVLTreeMultiset2(Generic[T]):
    """AVL tree used as a multiset.

    Nodes are stored in parallel arrays indexed by node id. Id 0 is a
    sentinel meaning "no node". Each node holds one distinct key and its
    multiplicity:

    - ``key[i]``: the key stored at node ``i``
    - ``val[i]``: how many copies of ``key[i]`` are in the multiset
    - ``left[i]`` / ``right[i]``: child ids (0 if none)
    - ``balance[i]``: height(left subtree) - height(right subtree)

    Subtree sizes are not kept, so updates are comparatively light. There
    is no ``valsize`` array, so the ``index`` and ``_kth_elm`` helpers of
    ``BSTMultisetArrayBase`` cannot be used with this class.
    """

    def __init__(self, a: Iterable[T] = ()):
        """Create a multiset holding the elements of ``a``, in any order."""
        self.root = 0
        self._len = 0
        self.key = [0]
        self.val = [0]
        self.left = array("I", bytes(4))
        self.right = array("I", bytes(4))
        self.balance = array("b", bytes(1))
        self.end = 1  # next unused node id
        if not isinstance(a, list):
            a = list(a)
        if a:
            self._build(a)

    def _make_node(self, key: T, val: int) -> int:
        """Allocate a leaf holding ``key`` with multiplicity ``val``; return its id."""
        end = self.end
        if end >= len(self.key):
            self.key.append(key)
            self.val.append(val)
            self.left.append(0)
            self.right.append(0)
            self.balance.append(0)
        else:
            # Reuse a pre-allocated slot (from reserve() or after clear()).
            # Reset the link and balance fields in case the slot held an
            # old node.
            self.key[end] = key
            self.val[end] = val
            self.left[end] = 0
            self.right[end] = 0
            self.balance[end] = 0
        self.end += 1
        return end

    def reserve(self, n: int) -> None:
        """Append ``n`` zero-initialised node slots to every node array."""
        a = [0] * n
        self.key += a
        self.val += a
        a = array("I", bytes(4 * n))
        self.left += a
        self.right += a
        self.balance += array("b", bytes(n))

    def _build(self, a: list[T]) -> None:
        """Build a perfectly balanced tree from the non-empty list ``a``."""
        left, right, balance = self.left, self.right, self.balance

        def link(l: int, r: int) -> tuple[int, int]:
            # Link node ids [l, r) into a balanced subtree rooted at the middle
            # id, and return (root id, height).
            mid = (l + r) >> 1
            hl, hr = 0, 0
            if l != mid:
                left[mid], hl = link(l, mid)
            if mid + 1 != r:
                right[mid], hr = link(mid + 1, r)
            balance[mid] = hl - hr
            return mid, max(hl, hr) + 1

        self._len = len(a)
        # Sort only if the input is not already non-decreasing. The check
        # only uses `<`, so sorted input with duplicates is not re-sorted.
        if any(a[i + 1] < a[i] for i in range(len(a) - 1)):
            a = sorted(a)
        x, y = BSTMultisetArrayBase[AVLTreeMultiset2, T]._rle(a)
        n = len(x)
        end = self.end
        self.end += n
        self.reserve(n)
        self.key[end : end + n] = x
        self.val[end : end + n] = y
        self.root = link(end, n + end)[0]

    def _rotate_L(self, node: int) -> int:
        """Rotate right around ``node``.

        The left child becomes the subtree root and its id is returned.
        """
        left, right, balance = self.left, self.right, self.balance
        u = left[node]
        left[node] = right[u]
        right[u] = node
        if balance[u] == 1:
            balance[u] = 0
            balance[node] = 0
        else:
            # u was balanced (a deletion-time case): the height is unchanged.
            balance[u] = -1
            balance[node] = 1
        return u

    def _rotate_R(self, node: int) -> int:
        """Rotate left around ``node``.

        The right child becomes the subtree root and its id is returned.
        """
        left, right, balance = self.left, self.right, self.balance
        u = right[node]
        right[node] = left[u]
        left[u] = node
        if balance[u] == -1:
            balance[u] = 0
            balance[node] = 0
        else:
            balance[u] = 1
            balance[node] = -1
        return u

    def _update_balance(self, node: int) -> None:
        """Fix balances after a double rotation made ``node`` the subtree root.

        The children's new balances depend on ``node``'s balance before the
        rotation. ``node`` itself ends up balanced.
        """
        left, right, balance = self.left, self.right, self.balance
        if balance[node] == 1:
            balance[right[node]] = -1
            balance[left[node]] = 0
        elif balance[node] == -1:
            balance[right[node]] = 0
            balance[left[node]] = 1
        else:
            balance[right[node]] = 0
            balance[left[node]] = 0
        balance[node] = 0

    def _rotate_LR(self, node: int) -> int:
        """Left-right double rotation around ``node``; return the new subtree root."""
        left, right = self.left, self.right
        B = left[node]
        E = right[B]
        right[B] = left[E]
        left[E] = B
        left[node] = right[E]
        right[E] = node
        self._update_balance(E)
        return E

    def _rotate_RL(self, node: int) -> int:
        """Right-left double rotation around ``node``; return the new subtree root."""
        left, right = self.left, self.right
        C = right[node]
        D = left[C]
        left[C] = right[D]
        right[D] = C
        right[node] = left[D]
        left[D] = node
        self._update_balance(D)
        return D

    def _discard(self, node: int, path: list[int], di: int) -> bool:
        """Unlink ``node`` from the tree and rebalance. Always returns True.

        Args:
            node: id of the node to remove.
            path: ancestors of ``node`` from the root down, excluding ``node``.
            di: direction bits for ``path``. The lowest bit is the last step;
                1 means a step to the left child, 0 a step to the right child.
        """
        left, right, keys, vals, balance = (
            self.left,
            self.right,
            self.key,
            self.val,
            self.balance,
        )
        if left[node] and right[node]:
            # Two children: copy the in-order predecessor (the max of the left
            # subtree) into node, then physically remove the predecessor.
            path.append(node)
            di <<= 1
            di |= 1
            lmax = left[node]
            while right[lmax]:
                path.append(lmax)
                di <<= 1
                lmax = right[lmax]
            keys[node] = keys[lmax]
            vals[node] = vals[lmax]
            node = lmax
        # node now has at most one child; splice that child into node's place.
        cnode = right[node] if left[node] == 0 else left[node]
        if path:
            if di & 1:
                left[path[-1]] = cnode
            else:
                right[path[-1]] = cnode
        else:
            self.root = cnode
            return True
        # Walk back up, updating balances. Stop once a subtree's height is
        # unchanged.
        while path:
            new_node = 0
            pnode = path.pop()
            balance[pnode] -= 1 if di & 1 else -1
            di >>= 1
            if balance[pnode] == 2:
                new_node = (
                    self._rotate_LR(pnode)
                    if balance[left[pnode]] < 0
                    else self._rotate_L(pnode)
                )
            elif balance[pnode] == -2:
                new_node = (
                    self._rotate_RL(pnode)
                    if balance[right[pnode]] > 0
                    else self._rotate_R(pnode)
                )
            elif balance[pnode] != 0:
                # The balance went from 0 to +-1: the subtree height is the same.
                break
            if new_node:
                if not path:
                    self.root = new_node
                    return True
                if di & 1:
                    left[path[-1]] = new_node
                else:
                    right[path[-1]] = new_node
                if balance[new_node] != 0:
                    # A rotation that leaves the new root unbalanced keeps the
                    # subtree height, so nothing above changes.
                    break
        return True

    def discard(self, key: T, val: int = 1) -> bool:
        """Remove up to ``val`` copies of ``key``.

        If ``val`` is at least the stored multiplicity, the key is removed
        entirely. Its node is unlinked rather than left with a count of 0.

        Returns:
            True if ``key`` was present, otherwise False.
        """
        keys, vals, left, right = self.key, self.val, self.left, self.right
        path = []
        di = 0
        node = self.root
        while node:
            if key == keys[node]:
                break
            path.append(node)
            di <<= 1
            if key < keys[node]:
                di |= 1
                node = left[node]
            else:
                node = right[node]
        else:
            return False
        cnt = vals[node]
        if val >= cnt:
            self._len -= cnt
            self._discard(node, path, di)
        else:
            self._len -= val
            vals[node] = cnt - val
        return True

    def discard_all(self, key: T) -> None:
        """Remove every copy of ``key``; do nothing if it is absent."""
        self.discard(key, self.count(key))

    def remove(self, key: T, val: int = 1) -> None:
        """Like ``discard``, but raise KeyError if ``key`` is absent."""
        if self.discard(key, val):
            return
        raise KeyError(key)

    def add(self, key: T, val: int = 1) -> None:
        """Insert ``val`` copies of ``key``."""
        self._len += val
        if self.root == 0:
            self.root = self._make_node(key, val)
            return
        left, right, keys, balance = self.left, self.right, self.key, self.balance
        node = self.root
        di = 0
        path = []
        while node:
            if key == keys[node]:
                # Key already present: only its multiplicity changes.
                self.val[node] += val
                return
            path.append(node)
            di <<= 1
            if key < keys[node]:
                di |= 1
                node = left[node]
            else:
                node = right[node]
        if di & 1:
            left[path[-1]] = self._make_node(key, val)
        else:
            right[path[-1]] = self._make_node(key, val)
        # Walk back up, updating balances. At most one rotation is needed.
        new_node = 0
        while path:
            node = path.pop()
            balance[node] += 1 if di & 1 else -1
            di >>= 1
            if balance[node] == 0:
                break
            if balance[node] == 2:
                new_node = (
                    self._rotate_LR(node)
                    if balance[left[node]] < 0
                    else self._rotate_L(node)
                )
                break
            elif balance[node] == -2:
                new_node = (
                    self._rotate_RL(node)
                    if balance[right[node]] > 0
                    else self._rotate_R(node)
                )
                break
        if new_node:
            if path:
                if di & 1:
                    left[path[-1]] = new_node
                else:
                    right[path[-1]] = new_node
            else:
                self.root = new_node

    def count(self, key: T) -> int:
        """Return the multiplicity of ``key`` (0 if absent)."""
        return BSTMultisetArrayBase[AVLTreeMultiset2, T].count(self, key)

    def le(self, key: T) -> Optional[T]:
        """Return the largest key ``<= key``, or None."""
        return BSTMultisetArrayBase[AVLTreeMultiset2, T].le(self, key)

    def lt(self, key: T) -> Optional[T]:
        """Return the largest key ``< key``, or None."""
        return BSTMultisetArrayBase[AVLTreeMultiset2, T].lt(self, key)

    def ge(self, key: T) -> Optional[T]:
        """Return the smallest key ``>= key``, or None."""
        return BSTMultisetArrayBase[AVLTreeMultiset2, T].ge(self, key)

    def gt(self, key: T) -> Optional[T]:
        """Return the smallest key ``> key``, or None."""
        return BSTMultisetArrayBase[AVLTreeMultiset2, T].gt(self, key)

    def get_min(self) -> Optional[T]:
        """Return the smallest key, or None if the multiset is empty."""
        if self.root == 0:
            return None
        left = self.left
        node = self.root
        while left[node]:
            node = left[node]
        return self.key[node]

    def get_max(self) -> Optional[T]:
        """Return the largest key, or None if the multiset is empty."""
        if self.root == 0:
            return None
        right = self.right
        node = self.root
        while right[node]:
            node = right[node]
        return self.key[node]

    def pop_min(self) -> T:
        """Remove one copy of the smallest key and return it.

        Raises:
            IndexError: if the multiset is empty.
        """
        if self.root == 0:
            raise IndexError("pop from an empty AVLTreeMultiset2")
        left, vals, keys = self.left, self.val, self.key
        self._len -= 1
        node = self.root
        path = []
        while left[node]:
            path.append(node)
            node = left[node]
        x = keys[node]
        if vals[node] == 1:
            # Every step was to the left, so all direction bits are 1.
            self._discard(node, path, (1 << len(path)) - 1)
        else:
            vals[node] -= 1
        return x

    def pop_max(self) -> T:
        """Remove one copy of the largest key and return it.

        Raises:
            IndexError: if the multiset is empty.
        """
        if self.root == 0:
            raise IndexError("pop from an empty AVLTreeMultiset2")
        right, vals, keys = self.right, self.val, self.key
        self._len -= 1
        node = self.root
        path = []
        while right[node]:
            path.append(node)
            node = right[node]
        x = keys[node]
        if vals[node] == 1:
            # Every step was to the right, so all direction bits are 0.
            self._discard(node, path, 0)
        else:
            vals[node] -= 1
        return x

    def clear(self) -> None:
        """Remove all elements.

        Allocated node slots are kept and reused by later ``add`` calls.
        """
        self.root = 0
        self._len = 0
        self.end = 1

    def tolist(self) -> list[T]:
        """Return all elements in sorted order, repeated by multiplicity."""
        return BSTMultisetArrayBase[AVLTreeMultiset2, T].tolist(self)

    def tolist_items(self) -> list[tuple[T, int]]:
        """Return ``(key, multiplicity)`` pairs in ascending key order."""
        left, right, keys, vals = self.left, self.right, self.key, self.val
        node = self.root
        stack: list[int] = []
        a: list[tuple[T, int]] = []
        while stack or node:
            if node:
                stack.append(node)
                node = left[node]
            else:
                node = stack.pop()
                a.append((keys[node], vals[node]))
                node = right[node]
        return a

    def __contains__(self, key: T) -> bool:
        return BSTMultisetArrayBase[AVLTreeMultiset2, T].contains(self, key)

    def __len__(self) -> int:
        return self._len

    def __bool__(self) -> bool:
        return self.root != 0

    def __str__(self) -> str:
        return "{" + ", ".join(map(str, self.tolist())) + "}"

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.tolist()})"