Source code for titan_pylib.data_structures.splay_tree.splay_tree_multiset_sum

  1import sys
  2from typing import Iterator, Optional, Generic, Iterable, TypeVar
  3
  4T = TypeVar("T")
  5
  6
[docs] 7class SplayTreeMultisetSum(Generic[T]): 8
[docs] 9 class Node: 10 11 def __init__(self, key, cnt: int) -> None: 12 self.data = key * cnt 13 self.key = key 14 self.size = 1 15 self.cnt = cnt 16 self.cntsize = cnt 17 self.left = None 18 self.right = None 19 20 def __str__(self) -> str: 21 if self.left is None and self.right is None: 22 return f"key:{self.key, self.size, self.cnt, self.cntsize}\n" 23 return f"key:{self.key, self.size, self.cnt, self.cntsize},\n left:{self.left},\n right:{self.right}\n" 24 25 __repr__ = __str__
26 27 def __init__(self, e: T, a: Iterable[T] = [], _node=None) -> None: 28 self.node = _node 29 self.e = e 30 if a: 31 self._build(a) 32 33 def _build(self, a: Iterable[T]) -> None: 34 Node = SplayTreeMultisetSum.Node 35 36 def sort(l: int, r: int) -> SplayTreeMultisetSum.Node: 37 mid = (l + r) >> 1 38 node = Node(key[mid], cnt[mid]) 39 if l != mid: 40 node.left = sort(l, mid) 41 if mid + 1 != r: 42 node.right = sort(mid + 1, r) 43 self._update(node) 44 return node 45 46 key, cnt = self._rle(sorted(a)) 47 self.node = sort(0, len(key)) 48 49 def _rle(self, a: list[T]) -> tuple[list[T], list[int]]: 50 x = [] 51 y = [] 52 x.append(a[0]) 53 y.append(1) 54 for i, e in enumerate(a): 55 if i == 0: 56 continue 57 if e == x[-1]: 58 y[-1] += 1 59 continue 60 x.append(e) 61 y.append(1) 62 return x, y 63 64 def _update(self, node: Node) -> None: 65 if node.left is None: 66 if node.right is None: 67 node.size = 1 68 node.cntsize = node.cnt 69 node.data = node.key * node.cnt 70 else: 71 node.size = 1 + node.right.size 72 node.cntsize = node.cnt + node.right.cntsize 73 node.data = node.key * node.cnt + node.right.data 74 else: 75 if node.right is None: 76 node.size = 1 + node.left.size 77 node.cntsize = node.cnt + node.left.cntsize 78 node.data = node.key * node.cnt + node.left.data 79 else: 80 node.size = 1 + node.left.size + node.right.size 81 node.cntsize = node.cnt + node.left.cntsize + node.right.cntsize 82 node.data = node.key * node.cnt + node.left.data + node.right.data 83 84 def _splay(self, path: list[Node], d: int) -> Node: 85 for _ in range(len(path) >> 1): 86 node = path.pop() 87 pnode = path.pop() 88 if d & 1 == d >> 1 & 1: 89 if d & 1: 90 tmp = node.left 91 node.left = tmp.right 92 tmp.right = node 93 pnode.left = node.right 94 node.right = pnode 95 else: 96 tmp = node.right 97 node.right = tmp.left 98 tmp.left = node 99 pnode.right = node.left 100 node.left = pnode 101 else: 102 if d & 1: 103 tmp = node.left 104 node.left = tmp.right 105 pnode.right = tmp.left 106 tmp.right = node 107 tmp.left = pnode 108 else: 109 tmp = node.right 110 node.right = tmp.left 111 pnode.left = tmp.right 112 tmp.left = node 113 tmp.right = pnode 114 self._update(pnode) 115 self._update(node) 116 self._update(tmp) 117 if not path: 118 return tmp 119 d >>= 2 120 if d & 1: 121 path[-1].left = tmp 122 else: 123 path[-1].right = tmp 124 gnode = path[0] 125 if d & 1: 126 node = gnode.left 127 gnode.left = node.right 128 node.right = gnode 129 else: 130 node = gnode.right 131 gnode.right = node.left 132 node.left = gnode 133 self._update(gnode) 134 self._update(node) 135 return node 136 137 def _set_search_splay(self, key: T) -> None: 138 node = self.node 139 if node is None or node.key == key: 140 return 141 path = [] 142 d = 0 143 while True: 144 if node.key == key: 145 break 146 elif key < node.key: 147 if node.left is None: 148 break 149 path.append(node) 150 d <<= 1 151 d |= 1 152 node = node.left 153 else: 154 if node.right is None: 155 break 156 path.append(node) 157 d <<= 1 158 node = node.right 159 if path: 160 self.node = self._splay(path, d) 161 162 def _set_kth_elm_splay(self, k: int) -> None: 163 if k < 0: 164 k += self.__len__() 165 d = 0 166 node = self.node 167 path = [] 168 while True: 169 t = node.cnt if node.left is None else node.cnt + node.left.cntsize 170 if t - node.cnt <= k < t: 171 if path: 172 self.node = self._splay(path, d) 173 break 174 elif t > k: 175 path.append(node) 176 d <<= 1 177 d |= 1 178 node = node.left 179 else: 180 path.append(node) 181 d <<= 1 182 node = node.right 183 k -= t 184 185 def _set_kth_elm_tree_splay(self, k: int) -> None: 186 if k < 0: 187 k += self.len_elm() 188 assert 0 <= k < self.len_elm() 189 d = 0 190 node = self.node 191 path = [] 192 while True: 193 t = 0 if node.left is None else node.left.size 194 if t == k: 195 if path: 196 self.node = self._splay(path, d) 197 return 198 elif t > k: 199 path.append(node) 200 d <<= 1 201 d |= 1 202 node = node.left 203 else: 204 path.append(node) 205 d <<= 1 206 node = node.right 207 k -= t + 1 208 209 def _get_min_splay(self, node: Node) -> Node: 210 if node is None or node.left is None: 211 return node 212 path = [] 213 while node.left is not None: 214 path.append(node) 215 node = node.left 216 return self._splay(path, (1 << len(path)) - 1) 217 218 def _get_max_splay(self, node: Node) -> Node: 219 if node is None or node.right is None: 220 return node 221 path = [] 222 while node.right is not None: 223 path.append(node) 224 node = node.right 225 return self._splay(path, 0) 226
[docs] 227 def add(self, key: T, cnt: int = 1) -> None: 228 if self.node is None: 229 self.node = SplayTreeMultisetSum.Node(key, cnt) 230 return 231 self._set_search_splay(key) 232 if self.node.key == key: 233 self.node.cnt += cnt 234 self._update(self.node) 235 return 236 node = SplayTreeMultisetSum.Node(key, cnt) 237 if key < self.node.key: 238 node.left = self.node.left 239 node.right = self.node 240 self.node.left = None 241 self._update(node.right) 242 else: 243 node.left = self.node 244 node.right = self.node.right 245 self.node.right = None 246 self._update(node.left) 247 self._update(node) 248 self.node = node 249 return
250
[docs] 251 def discard(self, key: T, cnt: int = 1) -> bool: 252 if self.node is None: 253 return False 254 self._set_search_splay(key) 255 if self.node.key != key: 256 return False 257 if self.node.cnt > cnt: 258 self.node.cnt -= cnt 259 self._update(self.node) 260 return True 261 if self.node.left is None: 262 self.node = self.node.right 263 elif self.node.right is None: 264 self.node = self.node.left 265 else: 266 node = self._get_min_splay(self.node.right) 267 node.left = self.node.left 268 self._update(node) 269 self.node = node 270 return True
271
[docs] 272 def discard_all(self, key: T) -> bool: 273 return self.discard(key, self.count(key))
274
[docs] 275 def count(self, key: T) -> int: 276 if self.node is None: 277 return 0 278 self._set_search_splay(key) 279 return self.node.cnt if self.node.key == key else 0
280
[docs] 281 def le(self, key: T) -> Optional[T]: 282 node = self.node 283 if node is None: 284 return None 285 path = [] 286 d = 0 287 res = None 288 while True: 289 if node.key == key: 290 res = key 291 break 292 elif key < node.key: 293 if node.left is None: 294 break 295 path.append(node) 296 d <<= 1 297 d |= 1 298 node = node.left 299 else: 300 res = node.key 301 if node.right is None: 302 break 303 path.append(node) 304 d <<= 1 305 node = node.right 306 if path: 307 self.node = self._splay(path, d) 308 return res
309
[docs] 310 def lt(self, key: T) -> Optional[T]: 311 node = self.node 312 path = [] 313 d = 0 314 res = None 315 while node is not None: 316 if key <= node.key: 317 path.append(node) 318 d <<= 1 319 d |= 1 320 node = node.left 321 else: 322 path.append(node) 323 d <<= 1 324 res = node.key 325 node = node.right 326 else: 327 if path: 328 path.pop() 329 d >>= 1 330 if path: 331 self.node = self._splay(path, d) 332 return res
333
[docs] 334 def ge(self, key: T) -> Optional[T]: 335 node = self.node 336 if node is None: 337 return None 338 path = [] 339 d = 0 340 res = None 341 while True: 342 if node.key == key: 343 res = node.key 344 break 345 elif key < node.key: 346 res = node.key 347 if node.left is None: 348 break 349 path.append(node) 350 d <<= 1 351 d |= 1 352 node = node.left 353 else: 354 if node.right is None: 355 break 356 path.append(node) 357 d <<= 1 358 node = node.right 359 if path: 360 self.node = self._splay(path, d) 361 return res
362
[docs] 363 def gt(self, key: T) -> Optional[T]: 364 node = self.node 365 path = [] 366 d = 0 367 res = None 368 while node is not None: 369 if key < node.key: 370 path.append(node) 371 d <<= 1 372 d |= 1 373 res = node.key 374 node = node.left 375 else: 376 path.append(node) 377 d <<= 1 378 node = node.right 379 else: 380 if path: 381 path.pop() 382 d >>= 1 383 if path: 384 self.node = self._splay(path, d) 385 return res
386
[docs] 387 def index(self, key: T) -> int: 388 if self.node is None: 389 return 0 390 self._set_search_splay(key) 391 res = 0 if self.node.left is None else self.node.left.cntsize 392 if self.node.key < key: 393 res += self.node.cnt 394 return res
395
[docs] 396 def index_right(self, key: T) -> int: 397 if self.node is None: 398 return 0 399 self._set_search_splay(key) 400 res = 0 if self.node.left is None else self.node.left.cntsize 401 if self.node.key <= key: 402 res += self.node.cnt 403 return res
404
[docs] 405 def index_keys(self, key: T) -> int: 406 if self.node is None: 407 return 0 408 self._set_search_splay(key) 409 res = 0 if self.node.left is None else self.node.left.size 410 if self.node.key < key: 411 res += 1 412 return res
413
[docs] 414 def index_right_keys(self, key: T) -> int: 415 if self.node is None: 416 return 0 417 self._set_search_splay(key) 418 res = 0 if self.node.left is None else self.node.left.size 419 if self.node.key <= key: 420 res += 1 421 return res
422
[docs] 423 def pop(self, k: int = -1) -> T: 424 self._set_kth_elm_splay(k) 425 res = self.node.key 426 self.discard(res) 427 return res
428
[docs] 429 def pop_max(self) -> T: 430 return self.pop()
431
[docs] 432 def pop_min(self) -> T: 433 return self.pop(0)
434
[docs] 435 def tolist(self) -> list[T]: 436 a = [] 437 if self.node is None: 438 return a 439 if sys.getrecursionlimit() < self.len_elm(): 440 sys.setrecursionlimit(self.len_elm() + 1) 441 442 def rec(node): 443 if node.left is not None: 444 rec(node.left) 445 for _ in range(node.cnt): 446 a.append(node.key) 447 if node.right is not None: 448 rec(node.right) 449 450 rec(self.node) 451 return a
452
[docs] 453 def tolist_items(self) -> list[tuple[T, int]]: 454 a = [] 455 if self.node is None: 456 return a 457 if sys.getrecursionlimit() < self.len_elm(): 458 sys.setrecursionlimit(self.len_elm() + 1) 459 460 def rec(node): 461 if node.left is not None: 462 rec(node.left) 463 a.append((node.key, node.cnt)) 464 if node.right is not None: 465 rec(node.right) 466 467 rec(self.node) 468 return a
469
[docs] 470 def get_elm(self, k: int) -> T: 471 assert -self.len_elm() <= k < self.len_elm() 472 self._set_kth_elm_tree_splay(k) 473 return self.node.key
474
[docs] 475 def items(self) -> Iterator[tuple[T, int]]: 476 for i in range(self.len_elm()): 477 self._set_kth_elm_tree_splay(i) 478 yield self.node.key, self.node.cnt
479
[docs] 480 def keys(self) -> Iterator[T]: 481 for i in range(self.len_elm()): 482 self._set_kth_elm_tree_splay(i) 483 yield self.node.key
484
[docs] 485 def cntues(self) -> Iterator[int]: 486 for i in range(self.len_elm()): 487 self._set_kth_elm_tree_splay(i) 488 yield self.node.cnt
489
[docs] 490 def len_elm(self) -> int: 491 return 0 if self.node is None else self.node.size
492
[docs] 493 def show(self) -> None: 494 print( 495 "{" + ", ".join(map(lambda x: f"{x[0]}: {x[1]}", self.tolist_items())) + "}" 496 )
497
[docs] 498 def clear(self) -> None: 499 self.node = None
500
[docs] 501 def get_sum(self) -> T: 502 if self.node is None: 503 return self.e 504 return self.node.data
505
[docs] 506 def merge(self, other: "SplayTreeMultisetSum") -> None: 507 if self.node is None: 508 self.node = other.node 509 return 510 if other.node is None: 511 return 512 self.node = self._get_max_splay(self.node) 513 other.node = other._get_min_splay(other.node) 514 assert self.node.key <= other.node.key 515 if self.node.key == other.node.key: 516 self.node.cnt += other.node.cnt 517 self._update(self.node) 518 other.discard(other.node.key, other.node.cnt) 519 self.node.right = other.node 520 self._update(self.node)
521
[docs] 522 def split(self, k: int) -> tuple["SplayTreeMultisetSum", "SplayTreeMultisetSum"]: 523 if self.node is None: 524 return SplayTreeMultisetSum(self.e), self 525 if k >= len(self): 526 return self, SplayTreeMultisetSum(self.e) 527 self._set_kth_elm_splay(k) 528 if self.node.left is None: 529 left = SplayTreeMultisetSum(self.e) 530 else: 531 left = SplayTreeMultisetSum(self.e, _node=self.node.left) 532 if len(left) != k: 533 assert k - len(left) > 0 534 self.node.cnt -= k - len(left) 535 self._update(self.node) 536 root = SplayTreeMultisetSum.Node(self.node.key, k - len(left)) 537 root.left = left.node 538 left = SplayTreeMultisetSum(self.e, _node=root) 539 left._update(left.node) 540 if self.node: 541 self.node.left = None 542 self._update(self.node) 543 return left, self
544 545 def __iter__(self): 546 self.__iter = 0 547 return self 548 549 def __next__(self): 550 if self.__iter == self.__len__(): 551 raise StopIteration 552 res = self.__getitem__(self.__iter) 553 self.__iter += 1 554 return res 555 556 def __reversed__(self): 557 for i in range(self.__len__()): 558 yield self.__getitem__(-i - 1) 559 560 def __contains__(self, key: T) -> bool: 561 self._set_search_splay(key) 562 return self.node is not None and self.node.key == key 563 564 def __getitem__(self, k: int) -> T: 565 self._set_kth_elm_splay(k) 566 return self.node.key 567 568 def __len__(self): 569 return 0 if self.node is None else self.node.cntsize 570 571 def __bool__(self): 572 return self.node is not None 573 574 def __str__(self): 575 return "{" + ", ".join(map(str, self.tolist())) + "}" 576 577 def __repr__(self): 578 return "SplayTreeMultisetSum(" + str(self) + ")"