Source code for titan_pylib.data_structures.avl_tree.avl_tree_multiset3

  1from titan_pylib.my_class.ordered_multiset_interface import OrderedMultisetInterface
  2from titan_pylib.my_class.supports_less_than import SupportsLessThan
  3from typing import Generic, Iterable, Iterator, TypeVar, Optional
  4
  5T = TypeVar("T", bound=SupportsLessThan)
  6
  7
class AVLTreeMultiset3(OrderedMultisetInterface, Generic[T]):
    """AVL tree used as an ordered multiset.

    Implemented with ``class Node()``: every distinct key is stored in
    exactly one node together with its multiplicity, and each node caches
    both the number of nodes and the total multiplicity of its subtree.
    """

    class Node:
        """A tree node holding one distinct key and its multiplicity."""

        def __init__(self, key: T, val: int):
            self.key: T = key
            # Multiplicity of ``key``.
            self.val: int = val
            # Sum of ``val`` over the subtree rooted at this node.
            self.valsize: int = val
            # Number of nodes (distinct keys) in the subtree rooted here.
            self.size: int = 1
            self.left: Optional["AVLTreeMultiset3.Node"] = None
            self.right: Optional["AVLTreeMultiset3.Node"] = None
            # Height of the left subtree minus height of the right subtree.
            self.balance: int = 0

        def __str__(self):
            if self.left is None and self.right is None:
                return f"key:{self.key, self.val, self.size, self.valsize}\n"
            return f"key:{self.key, self.val, self.size, self.valsize},\n left:{self.left},\n right:{self.right}\n"

    def __init__(self, a: Iterable[T] = ()):
        """Create the multiset, optionally filled with the elements of ``a``.

        The default is an immutable empty tuple rather than a shared list.
        """
        self.node: Optional["AVLTreeMultiset3.Node"] = None
        if a:
            self._build(a)

    def _rle(self, L: list[T]) -> tuple[list[T], list[int]]:
        """Run-length encode the non-empty list ``L``.

        Returns ``(keys, counts)`` where ``counts[i]`` is the length of the
        i-th run of equal adjacent elements.
        """
        keys, counts = [L[0]], [1]
        for i in range(1, len(L)):
            a = L[i]
            if a == keys[-1]:
                counts[-1] += 1
            else:
                keys.append(a)
                counts.append(1)
        return keys, counts

    def _build(self, a: Iterable[T]) -> None:
        """Replace the tree with one built from the elements of ``a``."""
        Node = AVLTreeMultiset3.Node

        def sort(l: int, r: int) -> tuple[Node, int]:
            # Build the subtree for keys[l:r]; returns (root, height).
            mid = (l + r) >> 1
            node = Node(x[mid], y[mid])
            h = 0
            if l != mid:
                left, hl = sort(l, mid)
                node.left = left
                node.size += left.size
                node.valsize += left.valsize
                node.balance = hl
                h = hl
            if mid + 1 != r:
                right, hr = sort(mid + 1, r)
                node.right = right
                node.size += right.size
                node.valsize += right.valsize
                node.balance -= hr
                if hr > h:
                    h = hr
            return node, h + 1

        a = sorted(a)
        if not a:
            return
        x, y = self._rle(a)
        self.node = sort(0, len(x))[0]

    def _rotate_L(self, node: Node) -> Node:
        """Promote ``node.left`` to subtree root and return it.

        ``size``, ``valsize`` and ``balance`` of both nodes are updated.
        """
        u = node.left
        u.size = node.size
        u.valsize = node.valsize
        if u.left is None:
            node.size -= 1
            node.valsize -= u.val
        else:
            node.size -= u.left.size + 1
            node.valsize -= u.left.valsize + u.val
        node.left = u.right
        u.right = node
        if u.balance == 1:
            u.balance = 0
            node.balance = 0
        else:
            u.balance = -1
            node.balance = 1
        return u

    def _rotate_R(self, node: Node) -> Node:
        """Promote ``node.right`` to subtree root and return it.

        Mirror image of ``_rotate_L``.
        """
        u = node.right
        u.size = node.size
        u.valsize = node.valsize
        if u.right is None:
            node.size -= 1
            node.valsize -= u.val
        else:
            node.size -= u.right.size + 1
            node.valsize -= u.right.valsize + u.val
        node.right = u.left
        u.left = node
        if u.balance == -1:
            u.balance = 0
            node.balance = 0
        else:
            u.balance = 1
            node.balance = -1
        return u

    def _update_balance(self, node: Node) -> None:
        """Reset balance factors of ``node`` and its children after a double rotation."""
        if node.balance == 1:
            node.right.balance = -1
            node.left.balance = 0
        elif node.balance == -1:
            node.right.balance = 0
            node.left.balance = 1
        else:
            node.right.balance = 0
            node.left.balance = 0
        node.balance = 0

    def _rotate_LR(self, node: Node) -> Node:
        """Double rotation: promote ``node.left.right`` to subtree root and return it."""
        B = node.left
        E = B.right
        E.size = node.size
        E.valsize = node.valsize
        if E.right is None:
            node.size -= B.size
            node.valsize -= B.valsize
            B.size -= 1
            B.valsize -= E.val
        else:
            node.size -= B.size - E.right.size
            node.valsize -= B.valsize - E.right.valsize
            B.size -= E.right.size + 1
            B.valsize -= E.right.valsize + E.val
        B.right = E.left
        E.left = B
        node.left = E.right
        E.right = node
        self._update_balance(E)
        return E

    def _rotate_RL(self, node: Node) -> Node:
        """Double rotation: promote ``node.right.left`` to subtree root and return it."""
        C = node.right
        D = C.left
        D.size = node.size
        D.valsize = node.valsize
        if D.left is None:
            node.size -= C.size
            node.valsize -= C.valsize
            C.size -= 1
            C.valsize -= D.val
        else:
            node.size -= C.size - D.left.size
            node.valsize -= C.valsize - D.left.valsize
            C.size -= D.left.size + 1
            C.valsize -= D.left.valsize + D.val
        C.left = D.right
        D.right = C
        node.right = D.left
        D.left = node
        self._update_balance(D)
        return D

    def _kth_elm(self, k: int) -> tuple[T, int]:
        """Return ``(key, count)`` of the node holding the k-th element.

        ``k`` indexes elements with multiplicity; negative values count
        from the end.

        Raises:
            IndexError: if ``k`` is out of range (including an empty multiset).
        """
        n = len(self)
        if k < 0:
            k += n
        if not 0 <= k < n:
            # Without this guard the walk below would dereference None.
            raise IndexError("AVLTreeMultiset3 index out of range")
        node = self.node
        while True:
            # t = number of elements in the left subtree plus this node's copies.
            t = node.val if node.left is None else node.val + node.left.valsize
            if t - node.val <= k < t:
                return node.key, node.val
            elif t > k:
                node = node.left
            else:
                node = node.right
                k -= t

    def _kth_elm_tree(self, k: int) -> tuple[T, int]:
        """Return ``(key, count)`` of the k-th distinct key.

        Negative ``k`` counts from the end.
        """
        if k < 0:
            k += self.len_elm()
        assert 0 <= k < self.len_elm()
        node = self.node
        while True:
            t = 0 if node.left is None else node.left.size
            if t == k:
                return node.key, node.val
            elif t > k:
                node = node.left
            else:
                node = node.right
                k -= t + 1

    def _discard(self, node: Node, path: list[Node], di: int) -> bool:
        """Unlink ``node`` (which must hold exactly one copy) and rebalance.

        Args:
            node: the node to remove; its ``val`` is expected to be 1.
            path: ancestors of ``node`` from the root downwards.
            di: direction bits for ``path``; the lowest bit belongs to the
                last entry and is 1 when the walk went left from it.

        Returns:
            Always True.
        """
        # fdi marks the path entries between ``node.left`` and the in-order
        # predecessor: their subtrees lose ``lmax_val`` copies, not 1.
        fdi = 0
        lmax_val = 1
        if node.left is not None and node.right is not None:
            # Two children: move the predecessor's key/count into ``node``
            # and unlink the predecessor instead.
            path.append(node)
            di <<= 1
            di |= 1
            lmax = node.left
            while lmax.right is not None:
                path.append(lmax)
                di <<= 1
                fdi <<= 1
                fdi |= 1
                lmax = lmax.right
            lmax_val = lmax.val
            node.key = lmax.key
            node.val = lmax_val
            node = lmax
        cnode = node.right if node.left is None else node.left
        if path:
            if di & 1:
                path[-1].left = cnode
            else:
                path[-1].right = cnode
        else:
            self.node = cnode
            return True
        # Walk back up, fixing sizes and balance until the height stops changing.
        while path:
            new_node = None
            pnode = path.pop()
            pnode.balance -= 1 if di & 1 else -1
            pnode.size -= 1
            pnode.valsize -= lmax_val if fdi & 1 else 1
            di >>= 1
            fdi >>= 1
            if pnode.balance == 2:
                new_node = (
                    self._rotate_LR(pnode)
                    if pnode.left.balance < 0
                    else self._rotate_L(pnode)
                )
            elif pnode.balance == -2:
                new_node = (
                    self._rotate_RL(pnode)
                    if pnode.right.balance > 0
                    else self._rotate_R(pnode)
                )
            elif pnode.balance != 0:
                break
            if new_node is not None:
                if not path:
                    self.node = new_node
                    return True
                if di & 1:
                    path[-1].left = new_node
                else:
                    path[-1].right = new_node
                if new_node.balance != 0:
                    break
        # Remaining ancestors keep their shape; only the counters change.
        while path:
            pnode = path.pop()
            pnode.size -= 1
            pnode.valsize -= lmax_val if fdi & 1 else 1
            fdi >>= 1
        return True

    def discard(self, key: T, val: int = 1) -> bool:
        """Remove up to ``val`` copies of ``key``.

        Removing at least as many copies as are stored deletes the key's
        node, so no node with a zero count is left in the tree.

        Returns:
            True if ``key`` was present, False otherwise.
        """
        path = []
        di = 0
        node = self.node
        while node is not None:
            if key == node.key:
                break
            elif key < node.key:
                path.append(node)
                di <<= 1
                di |= 1
                node = node.left
            else:
                path.append(node)
                di <<= 1
                node = node.right
        else:
            return False
        if val >= node.val:
            # Drop all but one copy here, then let _discard unlink the node
            # (it accounts for the final copy itself).
            extra = node.val - 1
            if extra:
                node.val -= extra
                node.valsize -= extra
                for p in path:
                    p.valsize -= extra
            self._discard(node, path, di)
        else:
            node.val -= val
            node.valsize -= val
            for p in path:
                p.valsize -= val
        return True

    def discard_all(self, key: T) -> None:
        """Remove every copy of ``key``; does nothing if it is absent."""
        self.discard(key, self.count(key))

    def remove(self, key: T, val: int = 1) -> None:
        """Like ``discard`` but raise ``KeyError`` when ``key`` is absent."""
        if self.discard(key, val):
            return
        raise KeyError(key)

    def add(self, key: T, val: int = 1) -> None:
        """Insert ``val`` copies of ``key``."""
        if self.node is None:
            self.node = AVLTreeMultiset3.Node(key, val)
            return
        pnode = self.node
        di = 0
        path = []
        while pnode is not None:
            if key == pnode.key:
                # Existing key: only the counters change, no rebalancing.
                pnode.val += val
                pnode.valsize += val
                for p in path:
                    p.valsize += val
                return
            elif key < pnode.key:
                path.append(pnode)
                di <<= 1
                di |= 1
                pnode = pnode.left
            else:
                path.append(pnode)
                di <<= 1
                pnode = pnode.right
        if di & 1:
            path[-1].left = AVLTreeMultiset3.Node(key, val)
        else:
            path[-1].right = AVLTreeMultiset3.Node(key, val)
        new_node = None
        while path:
            pnode = path.pop()
            pnode.size += 1
            pnode.valsize += val
            pnode.balance += 1 if di & 1 else -1
            di >>= 1
            if pnode.balance == 0:
                break
            if pnode.balance == 2:
                new_node = (
                    self._rotate_LR(pnode)
                    if pnode.left.balance < 0
                    else self._rotate_L(pnode)
                )
                break
            elif pnode.balance == -2:
                new_node = (
                    self._rotate_RL(pnode)
                    if pnode.right.balance > 0
                    else self._rotate_R(pnode)
                )
                break
        if new_node is not None:
            if path:
                if di & 1:
                    path[-1].left = new_node
                else:
                    path[-1].right = new_node
            else:
                self.node = new_node
        # Ancestors above the point where rebalancing stopped.
        for p in path:
            p.size += 1
            p.valsize += val

    def count(self, key: T) -> int:
        """Return the number of copies of ``key`` (0 if absent)."""
        node = self.node
        while node is not None:
            if node.key == key:
                return node.val
            elif key < node.key:
                node = node.left
            else:
                node = node.right
        return 0

    def le(self, key: T) -> Optional[T]:
        """Return the largest stored key ``<= key``, or None."""
        res = None
        node = self.node
        while node is not None:
            if key == node.key:
                res = key
                break
            elif key < node.key:
                node = node.left
            else:
                res = node.key
                node = node.right
        return res

    def lt(self, key: T) -> Optional[T]:
        """Return the largest stored key ``< key``, or None."""
        res = None
        node = self.node
        while node is not None:
            if key <= node.key:
                node = node.left
            else:
                res = node.key
                node = node.right
        return res

    def ge(self, key: T) -> Optional[T]:
        """Return the smallest stored key ``>= key``, or None."""
        res = None
        node = self.node
        while node is not None:
            if key == node.key:
                res = key
                break
            elif key < node.key:
                res = node.key
                node = node.left
            else:
                node = node.right
        return res

    def gt(self, key: T) -> Optional[T]:
        """Return the smallest stored key ``> key``, or None."""
        res = None
        node = self.node
        while node is not None:
            if key < node.key:
                res = node.key
                node = node.left
            else:
                node = node.right
        return res

    def index(self, key: T) -> int:
        """Return the number of elements (with multiplicity) ``< key``."""
        k = 0
        node = self.node
        while node is not None:
            if key == node.key:
                if node.left is not None:
                    k += node.left.valsize
                break
            elif key < node.key:
                node = node.left
            else:
                k += node.val if node.left is None else node.left.valsize + node.val
                node = node.right
        return k

    def index_right(self, key: T) -> int:
        """Return the number of elements (with multiplicity) ``<= key``."""
        k = 0
        node = self.node
        while node is not None:
            if key == node.key:
                k += node.val if node.left is None else node.left.valsize + node.val
                break
            elif key < node.key:
                node = node.left
            else:
                k += node.val if node.left is None else node.left.valsize + node.val
                node = node.right
        return k

    def index_keys(self, key: T) -> int:
        """Return the number of distinct keys ``< key``."""
        k = 0
        node = self.node
        while node:
            if key == node.key:
                if node.left is not None:
                    k += node.left.size
                break
            elif key < node.key:
                node = node.left
            else:
                # A passed node counts as one distinct key, whatever its
                # multiplicity.
                k += 1 if node.left is None else node.left.size + 1
                node = node.right
        return k

    def index_right_keys(self, key: T) -> int:
        """Return the number of distinct keys ``<= key``."""
        k = 0
        node = self.node
        while node:
            if key == node.key:
                k += 1 if node.left is None else node.left.size + 1
                break
            elif key < node.key:
                node = node.left
            else:
                k += 1 if node.left is None else node.left.size + 1
                node = node.right
        return k

    def get_min(self) -> Optional[T]:
        """Return the smallest key, or None when empty."""
        if self.node is None:
            return None
        node = self.node
        while node.left is not None:
            node = node.left
        return node.key

    def get_max(self) -> Optional[T]:
        """Return the largest key, or None when empty."""
        if self.node is None:
            return None
        node = self.node
        while node.right is not None:
            node = node.right
        return node.key

    def pop(self, k: int = -1) -> T:
        """Remove and return the k-th element (default: the largest).

        ``k`` indexes elements with multiplicity; negative values count
        from the end.

        Raises:
            IndexError: if the multiset is empty or ``k`` is out of range.
        """
        n = len(self)
        if k < 0:
            k += n
        if not 0 <= k < n:
            raise IndexError("pop index out of range")
        node = self.node
        path = []
        if k == n - 1:
            # Last element: follow right links only (all direction bits 0).
            while node.right is not None:
                path.append(node)
                node = node.right
            x = node.key
            if node.val == 1:
                self._discard(node, path, 0)
            else:
                node.val -= 1
                node.valsize -= 1
                for p in path:
                    p.valsize -= 1
            return x
        di = 0
        while True:
            t = node.val if node.left is None else node.val + node.left.valsize
            if t - node.val <= k < t:
                x = node.key
                break
            elif t > k:
                path.append(node)
                di <<= 1
                di |= 1
                node = node.left
            else:
                path.append(node)
                di <<= 1
                node = node.right
                k -= t
        if node.val == 1:
            self._discard(node, path, di)
        else:
            node.val -= 1
            node.valsize -= 1
            for p in path:
                p.valsize -= 1
        return x

    def pop_max(self) -> T:
        """Remove and return one copy of the largest key."""
        assert self
        return self.pop()

    def pop_min(self) -> T:
        """Remove and return one copy of the smallest key.

        Raises:
            IndexError: if the multiset is empty.
        """
        if self.node is None:
            raise IndexError("pop from an empty AVLTreeMultiset3")
        node = self.node
        path = []
        while node.left is not None:
            path.append(node)
            node = node.left
        x = node.key
        if node.val == 1:
            # Every step went left, so every direction bit is 1.
            self._discard(node, path, (1 << len(path)) - 1)
        else:
            node.val -= 1
            node.valsize -= 1
            for p in path:
                p.valsize -= 1
        return x

    def items(self) -> Iterator[tuple[T, int]]:
        """Yield ``(key, count)`` for each distinct key in ascending order."""
        for i in range(self.len_elm()):
            yield self._kth_elm_tree(i)

    def keys(self) -> Iterator[T]:
        """Yield each distinct key in ascending order."""
        for i in range(self.len_elm()):
            yield self._kth_elm_tree(i)[0]

    def values(self) -> Iterator[int]:
        """Yield the count of each distinct key, in ascending key order."""
        for i in range(self.len_elm()):
            yield self._kth_elm_tree(i)[1]

    def len_elm(self) -> int:
        """Return the number of distinct keys."""
        return 0 if self.node is None else self.node.size

    def show(self) -> None:
        """Print the contents as ``{key: count, ...}``."""
        print(
            "{" + ", ".join(f"{k}: {v}" for k, v in self.tolist_items()) + "}"
        )

    def clear(self) -> None:
        """Remove all elements."""
        self.node = None

    def get_elm(self, k: int) -> T:
        """Return the k-th distinct key."""
        return self._kth_elm_tree(k)[0]

    def tolist(self) -> list[T]:
        """Return all elements in ascending order, repeated by multiplicity."""
        a = []
        if self.node is None:
            return a

        def rec(node):
            if node.left is not None:
                rec(node.left)
            a.extend([node.key] * node.val)
            if node.right is not None:
                rec(node.right)

        rec(self.node)
        return a

    def tolist_items(self) -> list[tuple[T, int]]:
        """Return ``(key, count)`` pairs in ascending key order."""
        a = []
        if self.node is None:
            return a

        def rec(node):
            if node.left is not None:
                rec(node.left)
            a.append((node.key, node.val))
            if node.right is not None:
                rec(node.right)

        rec(self.node)
        return a

    def __getitem__(self, k: int):
        return self._kth_elm(k)[0]

    def __contains__(self, key: T):
        node = self.node
        while node:
            if node.key == key:
                return True
            node = node.left if key < node.key else node.right
        return False

    def __iter__(self):
        # The cursor lives on the instance, so only one iteration can be in
        # progress at a time.
        self.__iter = 0
        return self

    def __next__(self):
        # ``>=`` so that iteration also stops if elements were removed
        # since the previous step.
        if self.__iter >= len(self):
            raise StopIteration
        # Yield the key only, matching __getitem__ and __reversed__.
        res = self._kth_elm(self.__iter)[0]
        self.__iter += 1
        return res

    def __reversed__(self):
        for i in range(len(self)):
            yield self._kth_elm(-i - 1)[0]

    def __len__(self):
        return 0 if self.node is None else self.node.valsize

    def __bool__(self):
        return self.node is not None

    def __str__(self):
        return "{" + ", ".join(map(str, self.tolist())) + "}"

    def __repr__(self):
        return f"AVLTreeMultiset3({self.tolist()})"