from titan_pylib.my_class.ordered_multiset_interface import OrderedMultisetInterface
from titan_pylib.my_class.supports_less_than import SupportsLessThan
from titan_pylib.data_structures.bst_base.bst_multiset_array_base import (
    BSTMultisetArrayBase,
)
from typing import Generic, Iterable, Iterator, TypeVar, Optional
from array import array

T = TypeVar("T", bound=SupportsLessThan)


class AVLTreeMultiset(OrderedMultisetInterface, Generic[T]):
    """AVL tree used as an ordered multiset.

    Nodes live in parallel arrays and are addressed by integer index.
    Index 0 is a sentinel meaning "no node" (an empty tree has ``root == 0``).
    For a node ``i``:

    * ``key[i]``     - the stored key
    * ``val[i]``     - multiplicity of that key
    * ``size[i]``    - number of nodes (distinct keys) in the subtree of ``i``
    * ``valsize[i]`` - number of elements, with multiplicity, in the subtree
    * ``left[i]`` / ``right[i]`` - child indices (0 when absent)
    * ``balance[i]`` - height(left subtree) - height(right subtree)

    ``end`` is the index of the next unused slot.
    """

    def __init__(self, a: Iterable[T] = ()):
        """Create a multiset containing the elements of ``a``.

        Args:
            a: Initial elements; they do not have to be sorted.
        """
        self.root = 0
        # Slot 0 is the sentinel; all of its fields are zero.
        self.key = [0]
        self.val = [0]
        self.valsize = [0]
        self.size = array("I", [0])
        self.left = array("I", [0])
        self.right = array("I", [0])
        self.balance = array("b", [0])
        self.end = 1
        if not isinstance(a, list):
            a = list(a)
        if a:
            self._build(a)

    def _make_node(self, key: T, val: int) -> int:
        """Allocate a leaf holding ``key`` with multiplicity ``val``.

        Returns:
            The index of the new node.
        """
        end = self.end
        if end >= len(self.key):
            self.key.append(key)
            self.val.append(val)
            self.valsize.append(val)
            self.size.append(1)
            self.left.append(0)
            self.right.append(0)
            self.balance.append(0)
        else:
            # Slot was pre-allocated by reserve(), which already set
            # size=1, left=right=0 and balance=0.
            self.key[end] = key
            self.val[end] = val
            self.valsize[end] = val
        self.end += 1
        return end

    def reserve(self, n: int) -> None:
        """Append ``n`` unused slots to every node array.

        New slots get ``size`` 1 and zero in every other field, which is the
        state ``_make_node`` relies on when it reuses a slot.
        """
        zeros = [0] * n
        self.key += zeros
        self.val += zeros
        self.valsize += zeros
        # Built from an element list so the length does not depend on the
        # platform's byte width of a C unsigned int.
        links = array("I", [0]) * n
        self.left += links
        self.right += links
        self.size += array("I", [1]) * n
        self.balance += array("b", bytes(n))

    def _build(self, a: list[T]) -> None:
        """Build the tree from ``a``, appending the nodes after ``self.end``.

        ``a`` is sorted first if it is not already non-decreasing, then
        run-length encoded into (keys, counts) by
        ``BSTMultisetArrayBase._rle`` and laid out in consecutive slots.
        """
        left, right, size, valsize, balance = (
            self.left,
            self.right,
            self.size,
            self.valsize,
            self.balance,
        )

        def sort(l: int, r: int) -> tuple[int, int]:
            """Link slots ``[l, r)`` into a subtree rooted at the middle slot.

            Returns:
                ``(root index, height)`` of the subtree.
            """
            mid = (l + r) >> 1
            node = mid
            hl, hr = 0, 0
            if l != mid:
                left[node], hl = sort(l, mid)
                size[node] += size[left[node]]
                valsize[node] += valsize[left[node]]
            if mid + 1 != r:
                right[node], hr = sort(mid + 1, r)
                size[node] += size[right[node]]
                valsize[node] += valsize[right[node]]
            balance[node] = hl - hr
            return node, max(hl, hr) + 1

        if not all(a[i] <= a[i + 1] for i in range(len(a) - 1)):
            a = sorted(a)
        if not a:
            return
        x, y = BSTMultisetArrayBase[AVLTreeMultiset, T]._rle(a)
        n = len(x)
        end = self.end
        self.end += n
        self.reserve(n)
        self.key[end : end + n] = x
        self.val[end : end + n] = y
        self.valsize[end : end + n] = y
        self.root = sort(end, n + end)[0]

    def _rotate_L(self, node: int) -> int:
        """Single rotation that lifts the left child of ``node``.

        Used when ``node`` is left-heavy. Updates ``size``, ``valsize`` and
        ``balance`` of the two nodes involved.

        Returns:
            The index of the new subtree root (the former left child).
        """
        left, right, size, valsize, balance = (
            self.left,
            self.right,
            self.size,
            self.valsize,
            self.balance,
        )
        u = left[node]
        # u takes over the whole subtree, node loses u and u's left subtree.
        size[u] = size[node]
        valsize[u] = valsize[node]
        if left[u] == 0:
            size[node] -= 1
            valsize[node] -= self.val[u]
        else:
            size[node] -= size[left[u]] + 1
            valsize[node] -= valsize[left[u]] + self.val[u]
        left[node] = right[u]
        right[u] = node
        if balance[u] == 1:
            balance[u] = 0
            balance[node] = 0
        else:
            balance[u] = -1
            balance[node] = 1
        return u

    def _rotate_R(self, node: int) -> int:
        """Single rotation that lifts the right child of ``node``.

        Mirror image of ``_rotate_L``; used when ``node`` is right-heavy.

        Returns:
            The index of the new subtree root (the former right child).
        """
        left, right, size, valsize, balance = (
            self.left,
            self.right,
            self.size,
            self.valsize,
            self.balance,
        )
        u = right[node]
        size[u] = size[node]
        valsize[u] = valsize[node]
        if right[u] == 0:
            size[node] -= 1
            valsize[node] -= self.val[u]
        else:
            size[node] -= size[right[u]] + 1
            valsize[node] -= valsize[right[u]] + self.val[u]
        right[node] = left[u]
        left[u] = node
        if balance[u] == -1:
            balance[u] = 0
            balance[node] = 0
        else:
            balance[u] = 1
            balance[node] = -1
        return u

    def _update_balance(self, node: int) -> None:
        """Reset balance factors after a double rotation rooted at ``node``.

        The children's new factors are derived from ``node``'s factor before
        the rotation; ``node`` itself always ends up with factor 0.
        """
        left, right, balance = self.left, self.right, self.balance
        if balance[node] == 1:
            balance[right[node]] = -1
            balance[left[node]] = 0
        elif balance[node] == -1:
            balance[right[node]] = 0
            balance[left[node]] = 1
        else:
            balance[right[node]] = 0
            balance[left[node]] = 0
        balance[node] = 0

    def _rotate_LR(self, node: int) -> int:
        """Double rotation lifting the right child of ``node``'s left child.

        Returns:
            The index of the new subtree root.
        """
        left, right, size, valsize = self.left, self.right, self.size, self.valsize
        B = left[node]
        E = right[B]
        size[E] = size[node]
        valsize[E] = valsize[node]
        if right[E] == 0:
            size[node] -= size[B]
            valsize[node] -= valsize[B]
            size[B] -= 1
            valsize[B] -= self.val[E]
        else:
            size[node] -= size[B] - size[right[E]]
            valsize[node] -= valsize[B] - valsize[right[E]]
            size[B] -= size[right[E]] + 1
            valsize[B] -= valsize[right[E]] + self.val[E]
        right[B] = left[E]
        left[E] = B
        left[node] = right[E]
        right[E] = node
        self._update_balance(E)
        return E

    def _rotate_RL(self, node: int) -> int:
        """Double rotation lifting the left child of ``node``'s right child.

        Mirror image of ``_rotate_LR``.

        Returns:
            The index of the new subtree root.
        """
        left, right, size, valsize = self.left, self.right, self.size, self.valsize
        C = right[node]
        D = left[C]
        size[D] = size[node]
        valsize[D] = valsize[node]
        if left[D] == 0:
            size[node] -= size[C]
            valsize[node] -= valsize[C]
            size[C] -= 1
            valsize[C] -= self.val[D]
        else:
            size[node] -= size[C] - size[left[D]]
            valsize[node] -= valsize[C] - valsize[left[D]]
            size[C] -= size[left[D]] + 1
            valsize[C] -= valsize[left[D]] + self.val[D]
        left[C] = right[D]
        right[D] = C
        right[node] = left[D]
        left[D] = node
        self._update_balance(D)
        return D

    def _kth_elm(self, k: int) -> tuple[T, int]:
        """Delegate to ``BSTMultisetArrayBase._kth_elm``.

        Callers in this class use element ``[0]`` of the result as the key.
        """
        return BSTMultisetArrayBase[AVLTreeMultiset, T]._kth_elm(self, k)

    def _kth_elm_tree(self, k: int) -> tuple[T, int]:
        """Return ``(key, count)`` of the ``k``-th smallest distinct key.

        Negative ``k`` counts from the largest key.
        """
        left, right, vals, size = self.left, self.right, self.val, self.size
        if k < 0:
            k += self.len_elm()
        assert 0 <= k < self.len_elm()
        node = self.root
        while True:
            t = size[left[node]]
            if t == k:
                return self.key[node], vals[node]
            if t > k:
                node = left[node]
            else:
                node = right[node]
                k -= t + 1

    def _discard(self, node: int, path: list[int], di: int) -> bool:
        """Unlink ``node`` (whose count must be 1) and rebalance.

        Args:
            node: Index of the node to remove.
            path: Ancestors of ``node`` from the root downwards; consumed.
            di: Bit field of the directions taken along ``path``; the lowest
                bit belongs to ``path[-1]`` and 1 means "went left".

        Returns:
            Always ``True``.
        """
        left, right, keys, vals = self.left, self.right, self.key, self.val
        balance, size, valsize = self.balance, self.size, self.valsize
        # fdi marks path entries below the replaced node: their subtrees lose
        # the whole predecessor (lmax_val elements) instead of one element.
        fdi = 0
        lmax_val = 0
        if left[node] and right[node]:
            # Two children: overwrite node with its in-order predecessor and
            # physically remove the predecessor instead.
            path.append(node)
            di <<= 1
            di |= 1
            lmax = left[node]
            while right[lmax]:
                path.append(lmax)
                di <<= 1
                fdi <<= 1
                fdi |= 1
                lmax = right[lmax]
            lmax_val = vals[lmax]
            keys[node] = keys[lmax]
            vals[node] = lmax_val
            node = lmax
        cnode = right[node] if left[node] == 0 else left[node]
        if path:
            if di & 1:
                left[path[-1]] = cnode
            else:
                right[path[-1]] = cnode
        else:
            self.root = cnode
            return True
        # Walk back up while the subtree height keeps shrinking.
        while path:
            new_node = 0
            pnode = path.pop()
            balance[pnode] -= 1 if di & 1 else -1
            size[pnode] -= 1
            valsize[pnode] -= lmax_val if fdi & 1 else 1
            di >>= 1
            fdi >>= 1
            if balance[pnode] == 2:
                new_node = (
                    self._rotate_LR(pnode)
                    if balance[left[pnode]] < 0
                    else self._rotate_L(pnode)
                )
            elif balance[pnode] == -2:
                new_node = (
                    self._rotate_RL(pnode)
                    if balance[right[pnode]] > 0
                    else self._rotate_R(pnode)
                )
            elif balance[pnode] != 0:
                # Height of this subtree is unchanged; no more rebalancing.
                break
            if new_node:
                if not path:
                    self.root = new_node
                    return True
                if di & 1:
                    left[path[-1]] = new_node
                else:
                    right[path[-1]] = new_node
                if balance[new_node] != 0:
                    break
        # Remaining ancestors only need their counters adjusted.
        while path:
            pnode = path.pop()
            size[pnode] -= 1
            valsize[pnode] -= lmax_val if fdi & 1 else 1
            fdi >>= 1
        return True

    def discard(self, key: T, val: int = 1) -> bool:
        """Remove up to ``val`` occurrences of ``key``.

        If ``val`` is at least the current count, the key is removed
        completely. A non-positive ``val`` removes nothing.

        Returns:
            ``True`` if ``key`` was present, ``False`` otherwise.
        """
        keys, vals, left, right, valsize = (
            self.key,
            self.val,
            self.left,
            self.right,
            self.valsize,
        )
        path = []
        di = 0
        node = self.root
        while node:
            if key == keys[node]:
                break
            path.append(node)
            di <<= 1
            if key < keys[node]:
                di |= 1
                node = left[node]
            else:
                node = right[node]
        else:
            return False
        if val <= 0:
            return True
        count = vals[node]
        if val >= count:
            # Drop all copies but one, then let _discard unlink the node
            # (it expects a count of exactly 1). Covers val == count too, so
            # no zero-count node is ever left in the tree.
            surplus = count - 1
            if surplus:
                vals[node] = 1
                valsize[node] -= surplus
                for p in path:
                    valsize[p] -= surplus
            self._discard(node, path, di)
        else:
            vals[node] -= val
            valsize[node] -= val
            for p in path:
                valsize[p] -= val
        return True

    def discard_all(self, key: T) -> None:
        """Remove every occurrence of ``key`` (no-op if absent)."""
        self.discard(key, self.count(key))

    def remove(self, key: T, val: int = 1) -> None:
        """Like ``discard`` but raise if ``key`` is absent.

        Raises:
            KeyError: If ``key`` is not in the multiset.
        """
        if self.discard(key, val):
            return
        raise KeyError(key)

    def add(self, key: T, val: int = 1) -> None:
        """Insert ``val`` occurrences of ``key``."""
        if self.root == 0:
            self.root = self._make_node(key, val)
            return
        left, right, keys, valsize = self.left, self.right, self.key, self.valsize
        size, balance = self.size, self.balance
        node = self.root
        di = 0  # direction bits for `path`; lowest bit = path[-1], 1 = left
        path = []
        while node:
            if key == keys[node]:
                # Existing key: only the counters change, no restructuring.
                self.val[node] += val
                valsize[node] += val
                for p in path:
                    valsize[p] += val
                return
            path.append(node)
            di <<= 1
            if key < keys[node]:
                di |= 1
                node = left[node]
            else:
                node = right[node]
        if di & 1:
            left[path[-1]] = self._make_node(key, val)
        else:
            right[path[-1]] = self._make_node(key, val)
        new_node = 0
        # Walk back up until the height increase has been absorbed.
        while path:
            node = path.pop()
            size[node] += 1
            valsize[node] += val
            balance[node] += 1 if di & 1 else -1
            di >>= 1
            if balance[node] == 0:
                break
            if balance[node] == 2:
                new_node = (
                    self._rotate_LR(node)
                    if balance[left[node]] < 0
                    else self._rotate_L(node)
                )
                break
            elif balance[node] == -2:
                new_node = (
                    self._rotate_RL(node)
                    if balance[right[node]] > 0
                    else self._rotate_R(node)
                )
                break
        if new_node:
            if path:
                if di & 1:
                    left[path[-1]] = new_node
                else:
                    right[path[-1]] = new_node
            else:
                self.root = new_node
        # Ancestors above the stopping point only need their counters bumped.
        for p in path:
            size[p] += 1
            valsize[p] += val

    def count(self, key: T) -> int:
        """Delegate to ``BSTMultisetArrayBase.count``."""
        return BSTMultisetArrayBase[AVLTreeMultiset, T].count(self, key)

    def le(self, key: T) -> Optional[T]:
        """Delegate to ``BSTMultisetArrayBase.le``."""
        return BSTMultisetArrayBase[AVLTreeMultiset, T].le(self, key)

    def lt(self, key: T) -> Optional[T]:
        """Delegate to ``BSTMultisetArrayBase.lt``."""
        return BSTMultisetArrayBase[AVLTreeMultiset, T].lt(self, key)

    def ge(self, key: T) -> Optional[T]:
        """Delegate to ``BSTMultisetArrayBase.ge``."""
        return BSTMultisetArrayBase[AVLTreeMultiset, T].ge(self, key)

    def gt(self, key: T) -> Optional[T]:
        """Delegate to ``BSTMultisetArrayBase.gt``."""
        return BSTMultisetArrayBase[AVLTreeMultiset, T].gt(self, key)

    def index(self, key: T) -> int:
        """Delegate to ``BSTMultisetArrayBase.index``."""
        return BSTMultisetArrayBase[AVLTreeMultiset, T].index(self, key)

    def index_right(self, key: T) -> int:
        """Delegate to ``BSTMultisetArrayBase.index_right``."""
        return BSTMultisetArrayBase[AVLTreeMultiset, T].index_right(self, key)

    def index_keys(self, key: T) -> int:
        """Return the number of distinct keys strictly less than ``key``."""
        keys, left, right, size = self.key, self.left, self.right, self.size
        k = 0
        node = self.root
        while node:
            if key == keys[node]:
                k += size[left[node]]
                break
            if key < keys[node]:
                node = left[node]
            else:
                # Skip the left subtree plus this node: one distinct key each,
                # regardless of the node's multiplicity.
                k += size[left[node]] + 1
                node = right[node]
        return k

    def index_right_keys(self, key: T) -> int:
        """Return the number of distinct keys less than or equal to ``key``."""
        keys, left, right, size = self.key, self.left, self.right, self.size
        k = 0
        node = self.root
        while node:
            if key == keys[node]:
                k += size[left[node]] + 1
                break
            if key < keys[node]:
                node = left[node]
            else:
                k += size[left[node]] + 1
                node = right[node]
        return k

    def get_min(self) -> Optional[T]:
        """Return the smallest key, or ``None`` if the multiset is empty."""
        if self.root == 0:
            return None
        left = self.left
        node = self.root
        while left[node]:
            node = left[node]
        return self.key[node]

    def get_max(self) -> Optional[T]:
        """Return the largest key, or ``None`` if the multiset is empty."""
        if self.root == 0:
            return None
        right = self.right
        node = self.root
        while right[node]:
            node = right[node]
        return self.key[node]

    def pop(self, k: int = -1) -> T:
        """Remove and return the ``k``-th smallest element (with multiplicity).

        Negative ``k`` counts from the end; the default removes the maximum.

        Raises:
            IndexError: If the multiset is empty or ``k`` is out of range.
        """
        left, right, vals, valsize = self.left, self.right, self.val, self.valsize
        keys = self.key
        node = self.root
        total = valsize[node]
        if k < 0:
            k += total
        if not 0 <= k < total:
            # Without this check the descent would run onto the sentinel and
            # either modify it or never terminate.
            raise IndexError("pop index out of range")
        path = []
        if k == total - 1:
            # Fast path for the maximum: follow right links only, so the
            # direction bit field is all zeros.
            while right[node]:
                path.append(node)
                node = right[node]
            x = keys[node]
            if vals[node] == 1:
                self._discard(node, path, 0)
            else:
                vals[node] -= 1
                valsize[node] -= 1
                for p in path:
                    valsize[p] -= 1
            return x
        di = 0
        while True:
            t = vals[node] + valsize[left[node]]
            if t - vals[node] <= k < t:
                x = keys[node]
                break
            path.append(node)
            di <<= 1
            if t > k:
                di |= 1
                node = left[node]
            else:
                node = right[node]
                k -= t
        if vals[node] == 1:
            self._discard(node, path, di)
        else:
            vals[node] -= 1
            valsize[node] -= 1
            for p in path:
                valsize[p] -= 1
        return x

    def pop_max(self) -> T:
        """Remove and return the largest element."""
        assert self
        return self.pop()

    def pop_min(self) -> T:
        """Remove and return the smallest element."""
        assert self
        return self.pop(0)

    def items(self) -> Iterator[tuple[T, int]]:
        """Yield ``(key, count)`` for each distinct key in ascending order."""
        for i in range(self.len_elm()):
            yield self._kth_elm_tree(i)

    def keys(self) -> Iterator[T]:
        """Yield each distinct key in ascending order."""
        for i in range(self.len_elm()):
            yield self._kth_elm_tree(i)[0]

    def values(self) -> Iterator[int]:
        """Yield the count of each distinct key, in ascending key order."""
        for i in range(self.len_elm()):
            yield self._kth_elm_tree(i)[1]

    def len_elm(self) -> int:
        """Return the number of distinct keys."""
        return self.size[self.root]

    def show(self) -> None:
        """Print the contents as ``{key: count, ...}``."""
        print(
            "{" + ", ".join(f"{k}: {v}" for k, v in self.tolist_items()) + "}"
        )

    def clear(self) -> None:
        """Make the multiset empty.

        Only the root is reset; the node arrays and ``end`` are left as they
        are, so previously used slots are not reused.
        """
        self.root = 0

    def get_elm(self, k: int) -> T:
        """Return the ``k``-th smallest distinct key."""
        return self._kth_elm_tree(k)[0]

    def tolist(self) -> list[T]:
        """Delegate to ``BSTMultisetArrayBase.tolist``."""
        return BSTMultisetArrayBase[AVLTreeMultiset, T].tolist(self)

    def tolist_items(self) -> list[tuple[T, int]]:
        """Return ``(key, count)`` pairs in ascending key order."""
        left, right, keys, vals = self.left, self.right, self.key, self.val
        node = self.root
        stack, a = [], []
        # Iterative in-order traversal.
        while stack or node:
            if node:
                stack.append(node)
                node = left[node]
            else:
                node = stack.pop()
                a.append((keys[node], vals[node]))
                node = right[node]
        return a

    def __getitem__(self, k: int) -> T:
        """Return the ``k``-th smallest element (with multiplicity)."""
        return self._kth_elm(k)[0]

    def __contains__(self, key: T):
        """Delegate to ``BSTMultisetArrayBase.contains``."""
        return BSTMultisetArrayBase[AVLTreeMultiset, T].contains(self, key)

    def __iter__(self):
        """Start iterating over the elements in ascending order.

        The cursor is stored on the instance, so two simultaneous iterations
        over the same object share it.
        """
        self.__iter = 0
        return self

    def __next__(self):
        """Return the next element (a key, consistent with ``__getitem__``)."""
        if self.__iter == len(self):
            raise StopIteration
        res = self._kth_elm(self.__iter)
        self.__iter += 1
        # _kth_elm yields (key, count); iteration exposes the key only.
        return res[0]

    def __reversed__(self):
        """Yield the elements in descending order."""
        for i in range(len(self)):
            yield self._kth_elm(-i - 1)[0]

    def __len__(self):
        """Return the number of elements, counted with multiplicity."""
        return self.valsize[self.root]

    def __bool__(self):
        """Return ``True`` when the multiset is non-empty."""
        return self.root != 0

    def __str__(self):
        return "{" + ", ".join(map(str, self.tolist())) + "}"

    def __repr__(self):
        return f"{self.__class__.__name__}({self.tolist()})"