Source code for titan_pylib.data_structures.avl_tree.avl_tree_multiset

  1from titan_pylib.my_class.ordered_multiset_interface import OrderedMultisetInterface
  2from titan_pylib.my_class.supports_less_than import SupportsLessThan
  3from titan_pylib.data_structures.bst_base.bst_multiset_array_base import (
  4    BSTMultisetArrayBase,
  5)
  6from typing import Generic, Iterable, Iterator, TypeVar, Optional
  7from array import array
  8
# Key type of the multiset, bounded by the project's SupportsLessThan protocol.
T = TypeVar("T", bound=SupportsLessThan)
 10
 11
class AVLTreeMultiset(OrderedMultisetInterface, Generic[T]):
    """Multiset backed by an array-based AVL tree.

    Every distinct key occupies one node together with its multiplicity.
    Nodes live in parallel arrays indexed by node id. Id 0 is a sentinel
    meaning "no node"; its ``size`` and ``valsize`` entries stay 0, so
    ``size[left[x]]`` can be read without checking for a missing child.

    Per-node arrays:
        key[i]            -- key stored at node i
        val[i]            -- multiplicity of key[i]
        valsize[i]        -- total multiplicity in the subtree rooted at i
        size[i]           -- number of nodes (distinct keys) in the subtree at i
        left[i], right[i] -- child ids (0 means no child)
        balance[i]        -- height(left subtree) - height(right subtree)
    """

    def __init__(self, a: Iterable[T] = ()):
        """Create a multiset holding the elements of ``a`` (duplicates kept)."""
        self._init_storage()
        if not isinstance(a, list):
            a = list(a)
        if a:
            self._build(a)

    def _init_storage(self) -> None:
        """Reset all node storage to an empty tree containing only the sentinel."""
        self.root = 0
        self.key = [0]
        self.val = [0]
        self.valsize = [0]
        # array("I", [0]) rather than array("I", bytes(4)): the item size of
        # typecode "I" is platform dependent, so 4 bytes must not be assumed.
        self.size = array("I", [0])
        self.left = array("I", [0])
        self.right = array("I", [0])
        self.balance = array("b", [0])
        self.end = 1  # next unused node id

    def _make_node(self, key: T, val: int) -> int:
        """Allocate a leaf node holding ``val`` copies of ``key``; return its id."""
        end = self.end
        if end >= len(self.key):
            self.key.append(key)
            self.val.append(val)
            self.valsize.append(val)
            self.size.append(1)
            self.left.append(0)
            self.right.append(0)
            self.balance.append(0)
        else:
            # Slot pre-allocated by reserve(): overwrite every field so the
            # new node never inherits stale links or balance information.
            self.key[end] = key
            self.val[end] = val
            self.valsize[end] = val
            self.size[end] = 1
            self.left[end] = 0
            self.right[end] = 0
            self.balance[end] = 0
        self.end += 1
        return end

    def reserve(self, n: int) -> None:
        """Pre-allocate ``n`` unused node slots.

        Reserved slots are initialised as empty leaves (size 1, no children,
        balance 0). A non-positive ``n`` reserves nothing.
        """
        zeros = [0] * n
        self.key += zeros
        self.val += zeros
        self.valsize += zeros
        self.left += array("I", [0]) * n
        self.right += array("I", [0]) * n
        self.size += array("I", [1]) * n
        self.balance += array("b", [0]) * n

    def _build(self, a: list[T]) -> None:
        """Build a balanced tree from ``a`` (sorted first if necessary).

        Equal keys are run-length encoded into single nodes via
        ``BSTMultisetArrayBase._rle``.
        """
        left, right, size, valsize, balance = (
            self.left,
            self.right,
            self.size,
            self.valsize,
            self.balance,
        )

        def build(lo: int, hi: int) -> tuple[int, int]:
            # Link node ids [lo, hi) into a perfectly balanced subtree.
            # Returns (subtree root id, subtree height).
            mid = (lo + hi) >> 1
            hl = hr = 0
            if lo != mid:
                left[mid], hl = build(lo, mid)
                size[mid] += size[left[mid]]
                valsize[mid] += valsize[left[mid]]
            if mid + 1 != hi:
                right[mid], hr = build(mid + 1, hi)
                size[mid] += size[right[mid]]
                valsize[mid] += valsize[right[mid]]
            balance[mid] = hl - hr
            return mid, max(hl, hr) + 1

        # Only `<` is used, matching the SupportsLessThan bound on T.
        if any(a[i + 1] < a[i] for i in range(len(a) - 1)):
            a = sorted(a)
        if not a:
            return
        x, y = BSTMultisetArrayBase[AVLTreeMultiset, T]._rle(a)
        n = len(x)
        end = self.end
        # reserve() fills new slots with size 1 / no children / balance 0,
        # which is exactly the leaf state build() starts from.
        self.reserve(n)
        self.end += n
        self.key[end : end + n] = x
        self.val[end : end + n] = y
        self.valsize[end : end + n] = y
        self.root = build(end, end + n)[0]

    def _rotate_L(self, node: int) -> int:
        """Fix a left-heavy (balance +2) ``node`` by lifting its left child.

        Returns the new subtree root (the former left child).
        """
        left, right, size, valsize, balance = (
            self.left,
            self.right,
            self.size,
            self.valsize,
            self.balance,
        )
        u = left[node]
        # u takes over node's whole subtree.
        size[u] = size[node]
        valsize[u] = valsize[node]
        # node loses u itself and u's left subtree.
        if left[u] == 0:
            size[node] -= 1
            valsize[node] -= self.val[u]
        else:
            size[node] -= size[left[u]] + 1
            valsize[node] -= valsize[left[u]] + self.val[u]
        left[node] = right[u]
        right[u] = node
        if balance[u] == 1:
            balance[u] = 0
            balance[node] = 0
        else:
            # u was balanced (reachable from _discard).
            balance[u] = -1
            balance[node] = 1
        return u

    def _rotate_R(self, node: int) -> int:
        """Fix a right-heavy (balance -2) ``node`` by lifting its right child.

        Returns the new subtree root (the former right child).
        """
        left, right, size, valsize, balance = (
            self.left,
            self.right,
            self.size,
            self.valsize,
            self.balance,
        )
        u = right[node]
        size[u] = size[node]
        valsize[u] = valsize[node]
        if right[u] == 0:
            size[node] -= 1
            valsize[node] -= self.val[u]
        else:
            size[node] -= size[right[u]] + 1
            valsize[node] -= valsize[right[u]] + self.val[u]
        right[node] = left[u]
        left[u] = node
        if balance[u] == -1:
            balance[u] = 0
            balance[node] = 0
        else:
            # u was balanced (reachable from _discard).
            balance[u] = 1
            balance[node] = -1
        return u

    def _update_balance(self, node: int) -> None:
        """Set balances after a double rotation whose new root is ``node``.

        The children's new balances depend on ``node``'s balance before the
        rotation; ``node`` itself always ends up balanced.
        """
        left, right, balance = self.left, self.right, self.balance
        if balance[node] == 1:
            balance[right[node]] = -1
            balance[left[node]] = 0
        elif balance[node] == -1:
            balance[right[node]] = 0
            balance[left[node]] = 1
        else:
            balance[right[node]] = 0
            balance[left[node]] = 0
        balance[node] = 0

    def _rotate_LR(self, node: int) -> int:
        """Left-right double rotation for a left-heavy ``node``.

        Used when node's left child B is right-heavy: B's right child E
        becomes the new subtree root. Returns E.
        """
        left, right, size, valsize = self.left, self.right, self.size, self.valsize
        B = left[node]
        E = right[B]
        size[E] = size[node]
        valsize[E] = valsize[node]
        if right[E] == 0:
            size[node] -= size[B]
            valsize[node] -= valsize[B]
            size[B] -= 1
            valsize[B] -= self.val[E]
        else:
            size[node] -= size[B] - size[right[E]]
            valsize[node] -= valsize[B] - valsize[right[E]]
            size[B] -= size[right[E]] + 1
            valsize[B] -= valsize[right[E]] + self.val[E]
        right[B] = left[E]
        left[E] = B
        left[node] = right[E]
        right[E] = node
        self._update_balance(E)
        return E

    def _rotate_RL(self, node: int) -> int:
        """Right-left double rotation for a right-heavy ``node``.

        Used when node's right child C is left-heavy: C's left child D
        becomes the new subtree root. Returns D.
        """
        left, right, size, valsize = self.left, self.right, self.size, self.valsize
        C = right[node]
        D = left[C]
        size[D] = size[node]
        valsize[D] = valsize[node]
        if left[D] == 0:
            size[node] -= size[C]
            valsize[node] -= valsize[C]
            size[C] -= 1
            valsize[C] -= self.val[D]
        else:
            size[node] -= size[C] - size[left[D]]
            valsize[node] -= valsize[C] - valsize[left[D]]
            size[C] -= size[left[D]] + 1
            valsize[C] -= valsize[left[D]] + self.val[D]
        left[C] = right[D]
        right[D] = C
        right[node] = left[D]
        left[D] = node
        self._update_balance(D)
        return D

    def _kth_elm(self, k: int) -> tuple[T, int]:
        """Delegate to ``BSTMultisetArrayBase._kth_elm``.

        Presumably returns (key, count) of the k-th element counting
        multiplicity; confirm against the base helper.
        """
        return BSTMultisetArrayBase[AVLTreeMultiset, T]._kth_elm(self, k)

    def _kth_elm_tree(self, k: int) -> tuple[T, int]:
        """Return (key, count) of the k-th smallest *distinct* key (0-based).

        Negative ``k`` counts from the largest distinct key.

        Raises:
            IndexError: if ``k`` is out of range.
        """
        left, right, vals, size = self.left, self.right, self.val, self.size
        n = self.len_elm()
        if k < 0:
            k += n
        if not 0 <= k < n:
            raise IndexError("index out of range")
        node = self.root
        while True:
            t = size[left[node]]
            if t == k:
                return self.key[node], vals[node]
            if t > k:
                node = left[node]
            else:
                node = right[node]
                k -= t + 1

    def _discard(self, node: int, path: list[int], di: int) -> bool:
        """Unlink ``node`` from the tree and rebalance.

        Preconditions: ``vals[node] == 1``. ``path`` lists the ancestors of
        ``node`` from the root down; it is consumed. Bit i of ``di``
        (counting from the least significant) is 1 when the step from
        ``path[-1 - i]`` to its child went left.
        """
        left, right, keys, vals = self.left, self.right, self.key, self.val
        balance, size, valsize = self.balance, self.size, self.valsize
        # fdi marks (with 1-bits) path entries that lie between `node` and its
        # in-order predecessor: they lose the predecessor's whole count.
        fdi = 0
        lmax_val = 0
        if left[node] and right[node]:
            # Two children: replace node's payload by its in-order predecessor
            # (max of the left subtree), then delete that predecessor instead.
            path.append(node)
            di <<= 1
            di |= 1
            lmax = left[node]
            while right[lmax]:
                path.append(lmax)
                di <<= 1
                fdi <<= 1
                fdi |= 1
                lmax = right[lmax]
            lmax_val = vals[lmax]
            keys[node] = keys[lmax]
            vals[node] = lmax_val
            node = lmax
        # node now has at most one child; splice it out.
        cnode = right[node] if left[node] == 0 else left[node]
        if path:
            if di & 1:
                left[path[-1]] = cnode
            else:
                right[path[-1]] = cnode
        else:
            self.root = cnode
            return True
        # Walk back up while the subtree height keeps shrinking.
        while path:
            new_node = 0
            pnode = path.pop()
            balance[pnode] -= 1 if di & 1 else -1
            size[pnode] -= 1
            valsize[pnode] -= lmax_val if fdi & 1 else 1
            di >>= 1
            fdi >>= 1
            if balance[pnode] == 2:
                new_node = (
                    self._rotate_LR(pnode)
                    if balance[left[pnode]] < 0
                    else self._rotate_L(pnode)
                )
            elif balance[pnode] == -2:
                new_node = (
                    self._rotate_RL(pnode)
                    if balance[right[pnode]] > 0
                    else self._rotate_R(pnode)
                )
            elif balance[pnode] != 0:
                # Height unchanged: no further rebalancing needed.
                break
            if new_node:
                if not path:
                    self.root = new_node
                    return True
                if di & 1:
                    left[path[-1]] = new_node
                else:
                    right[path[-1]] = new_node
                if balance[new_node] != 0:
                    break
        # Remaining ancestors only need their aggregates adjusted.
        while path:
            pnode = path.pop()
            size[pnode] -= 1
            valsize[pnode] -= lmax_val if fdi & 1 else 1
            fdi >>= 1
        return True

    def discard(self, key: T, val: int = 1) -> bool:
        """Remove up to ``val`` copies of ``key``.

        If ``val`` is at least the current count, the key is removed
        entirely. A non-positive ``val`` removes nothing.

        Returns:
            True if ``key`` was present, False otherwise.
        """
        keys, vals, left, right, valsize = (
            self.key,
            self.val,
            self.left,
            self.right,
            self.valsize,
        )
        path = []
        di = 0
        node = self.root
        while node:
            if key == keys[node]:
                break
            path.append(node)
            di <<= 1
            if key < keys[node]:
                di |= 1
                node = left[node]
            else:
                node = right[node]
        else:
            return False
        if val <= 0:
            return True
        cnt = vals[node]
        if val >= cnt:
            # Remove the whole node. Strip all but one copy here; _discard
            # accounts for the final copy while unlinking the node.
            extra = cnt - 1
            if extra:
                vals[node] = 1
                valsize[node] -= extra
                for p in path:
                    valsize[p] -= extra
            self._discard(node, path, di)
        else:
            vals[node] -= val
            valsize[node] -= val
            for p in path:
                valsize[p] -= val
        return True

    def discard_all(self, key: T) -> None:
        """Remove every copy of ``key`` (no-op if absent)."""
        self.discard(key, self.count(key))

    def remove(self, key: T, val: int = 1) -> None:
        """Like :meth:`discard`, but raise KeyError if ``key`` is absent."""
        if self.discard(key, val):
            return
        raise KeyError(key)

    def add(self, key: T, val: int = 1) -> None:
        """Insert ``val`` copies of ``key``. A non-positive ``val`` is a no-op."""
        if val <= 0:
            # Guard against zero/negative-count nodes, which would corrupt
            # the size and valsize aggregates.
            return
        if self.root == 0:
            self.root = self._make_node(key, val)
            return
        left, right, keys, valsize = self.left, self.right, self.key, self.valsize
        size, balance = self.size, self.balance
        node = self.root
        di = 0
        path = []
        while node:
            if key == keys[node]:
                # Existing key: only multiplicities change, shape is unchanged.
                self.val[node] += val
                valsize[node] += val
                for p in path:
                    valsize[p] += val
                return
            path.append(node)
            di <<= 1
            if key < keys[node]:
                di |= 1
                node = left[node]
            else:
                node = right[node]
        if di & 1:
            left[path[-1]] = self._make_node(key, val)
        else:
            right[path[-1]] = self._make_node(key, val)
        # Walk back up while the subtree height keeps growing.
        new_node = 0
        while path:
            node = path.pop()
            size[node] += 1
            valsize[node] += val
            balance[node] += 1 if di & 1 else -1
            di >>= 1
            if balance[node] == 0:
                break
            if balance[node] == 2:
                new_node = (
                    self._rotate_LR(node)
                    if balance[left[node]] < 0
                    else self._rotate_L(node)
                )
                break
            elif balance[node] == -2:
                new_node = (
                    self._rotate_RL(node)
                    if balance[right[node]] > 0
                    else self._rotate_R(node)
                )
                break
        if new_node:
            if path:
                if di & 1:
                    left[path[-1]] = new_node
                else:
                    right[path[-1]] = new_node
            else:
                self.root = new_node
        # Ancestors above the stopping point only need aggregates updated.
        for p in path:
            size[p] += 1
            valsize[p] += val

    def count(self, key: T) -> int:
        """Return the multiplicity of ``key`` (delegates to BSTMultisetArrayBase)."""
        return BSTMultisetArrayBase[AVLTreeMultiset, T].count(self, key)

    def le(self, key: T) -> Optional[T]:
        """Delegate to BSTMultisetArrayBase.le (presumably largest key <= ``key``, else None)."""
        return BSTMultisetArrayBase[AVLTreeMultiset, T].le(self, key)

    def lt(self, key: T) -> Optional[T]:
        """Delegate to BSTMultisetArrayBase.lt (presumably largest key < ``key``, else None)."""
        return BSTMultisetArrayBase[AVLTreeMultiset, T].lt(self, key)

    def ge(self, key: T) -> Optional[T]:
        """Delegate to BSTMultisetArrayBase.ge (presumably smallest key >= ``key``, else None)."""
        return BSTMultisetArrayBase[AVLTreeMultiset, T].ge(self, key)

    def gt(self, key: T) -> Optional[T]:
        """Delegate to BSTMultisetArrayBase.gt (presumably smallest key > ``key``, else None)."""
        return BSTMultisetArrayBase[AVLTreeMultiset, T].gt(self, key)

    def index(self, key: T) -> int:
        """Delegate to BSTMultisetArrayBase.index (presumably #elements < ``key``, with multiplicity)."""
        return BSTMultisetArrayBase[AVLTreeMultiset, T].index(self, key)

    def index_right(self, key: T) -> int:
        """Delegate to BSTMultisetArrayBase.index_right (presumably #elements <= ``key``)."""
        return BSTMultisetArrayBase[AVLTreeMultiset, T].index_right(self, key)

    def index_keys(self, key: T) -> int:
        """Return the number of distinct keys strictly less than ``key``."""
        keys, left, right, size = self.key, self.left, self.right, self.size
        k = 0
        node = self.root
        while node:
            if key == keys[node]:
                # size[0] == 0, so a missing left child contributes nothing.
                k += size[left[node]]
                break
            if key < keys[node]:
                node = left[node]
            else:
                # Count node itself once: this rank is over distinct keys,
                # matching the node counts held in `size`.
                k += size[left[node]] + 1
                node = right[node]
        return k

    def index_right_keys(self, key: T) -> int:
        """Return the number of distinct keys less than or equal to ``key``."""
        keys, left, right, size = self.key, self.left, self.right, self.size
        k = 0
        node = self.root
        while node:
            if key == keys[node]:
                k += size[left[node]] + 1
                break
            if key < keys[node]:
                node = left[node]
            else:
                k += size[left[node]] + 1
                node = right[node]
        return k

    def get_min(self) -> Optional[T]:
        """Return the smallest key, or None if the multiset is empty."""
        if self.root == 0:
            return None
        left = self.left
        node = self.root
        while left[node]:
            node = left[node]
        return self.key[node]

    def get_max(self) -> Optional[T]:
        """Return the largest key, or None if the multiset is empty."""
        if self.root == 0:
            return None
        right = self.right
        node = self.root
        while right[node]:
            node = right[node]
        return self.key[node]

    def pop(self, k: int = -1) -> T:
        """Remove and return one copy of the k-th smallest element.

        ``k`` is 0-based and counts multiplicity; a negative ``k`` counts
        from the largest element (default: pop the maximum).

        Raises:
            IndexError: if the multiset is empty or ``k`` is out of range.
        """
        left, right, vals, valsize = self.left, self.right, self.val, self.valsize
        keys = self.key
        node = self.root
        n = valsize[node]  # valsize[0] == 0 for an empty tree
        if k < 0:
            k += n
        if not 0 <= k < n:
            # Without this check an empty tree would decrement the sentinel's
            # count and an out-of-range k would loop forever.
            raise IndexError("pop index out of range")
        path = []
        if k == n - 1:
            # Fast path for the maximum: always walk right (di stays 0).
            while right[node]:
                path.append(node)
                node = right[node]
            x = keys[node]
            if vals[node] == 1:
                self._discard(node, path, 0)
            else:
                vals[node] -= 1
                valsize[node] -= 1
                for p in path:
                    valsize[p] -= 1
            return x
        di = 0
        while True:
            # Elements in the left subtree plus this node's copies.
            t = vals[node] + valsize[left[node]]
            if t - vals[node] <= k < t:
                x = keys[node]
                break
            path.append(node)
            di <<= 1
            if t > k:
                di |= 1
                node = left[node]
            else:
                node = right[node]
                k -= t
        if vals[node] == 1:
            self._discard(node, path, di)
        else:
            vals[node] -= 1
            valsize[node] -= 1
            for p in path:
                valsize[p] -= 1
        return x

    def pop_max(self) -> T:
        """Remove and return one copy of the largest element."""
        assert self
        return self.pop()

    def pop_min(self) -> T:
        """Remove and return one copy of the smallest element."""
        assert self
        return self.pop(0)

    def _iter_items(self, reverse: bool = False) -> Iterator[tuple[T, int]]:
        """Yield (key, count) pairs in ascending (descending if ``reverse``) order.

        An O(n) iterative in-order traversal.
        """
        first, second = (self.right, self.left) if reverse else (self.left, self.right)
        keys, vals = self.key, self.val
        stack = []
        node = self.root
        while stack or node:
            if node:
                stack.append(node)
                node = first[node]
            else:
                node = stack.pop()
                yield keys[node], vals[node]
                node = second[node]

    def items(self) -> Iterator[tuple[T, int]]:
        """Yield (key, count) pairs in ascending key order."""
        yield from self._iter_items()

    def keys(self) -> Iterator[T]:
        """Yield each distinct key once, in ascending order."""
        for key, _ in self._iter_items():
            yield key

    def values(self) -> Iterator[int]:
        """Yield the count of each distinct key, in ascending key order."""
        for _, cnt in self._iter_items():
            yield cnt

    def len_elm(self) -> int:
        """Return the number of distinct keys."""
        return self.size[self.root]

    def show(self) -> None:
        """Print the multiset as ``{key: count, ...}``."""
        print(
            "{" + ", ".join(map(lambda x: f"{x[0]}: {x[1]}", self.tolist_items())) + "}"
        )

    def clear(self) -> None:
        """Remove all elements and release node storage."""
        # Reinitialise the arrays instead of only resetting the root, so that
        # the storage of removed nodes is actually reclaimed.
        self._init_storage()

    def get_elm(self, k: int) -> T:
        """Return the k-th smallest distinct key (0-based; negative from the end)."""
        return self._kth_elm_tree(k)[0]

    def tolist(self) -> list[T]:
        """Delegate to BSTMultisetArrayBase.tolist (presumably all elements, sorted, with repeats)."""
        return BSTMultisetArrayBase[AVLTreeMultiset, T].tolist(self)

    def tolist_items(self) -> list[tuple[T, int]]:
        """Return all (key, count) pairs in ascending key order."""
        return list(self._iter_items())

    def __getitem__(self, k: int) -> T:
        """Return the k-th smallest element counting multiplicity."""
        return self._kth_elm(k)[0]

    def __contains__(self, key: T) -> bool:
        return BSTMultisetArrayBase[AVLTreeMultiset, T].contains(self, key)

    def __iter__(self):
        """Iterate over all elements (with repeats) in ascending order.

        The multiset is its own iterator, so the cursor is shared: nested
        iteration over the same instance is not supported.
        """
        self.__iter = 0
        return self

    def __next__(self) -> T:
        if self.__iter == len(self):
            raise StopIteration
        # _kth_elm returns (key, count); yield only the key, consistent with
        # __getitem__ and __reversed__.
        res = self._kth_elm(self.__iter)[0]
        self.__iter += 1
        return res

    def __reversed__(self):
        """Yield all elements (with repeats) in descending order."""
        for key, cnt in self._iter_items(reverse=True):
            for _ in range(cnt):
                yield key

    def __len__(self):
        """Return the total number of elements, counting multiplicity."""
        return self.valsize[self.root]

    def __bool__(self):
        return self.root != 0

    def __str__(self):
        return "{" + ", ".join(map(str, self.tolist())) + "}"

    def __repr__(self):
        return f"{self.__class__.__name__}({self.tolist()})"