Source code for titan_pylib.data_structures.splay_tree.splay_tree_multiset

  1from titan_pylib.my_class.supports_less_than import SupportsLessThan
  2import sys
  3from typing import Iterator, Optional, Generic, Iterable, TypeVar
  4
  5T = TypeVar("T", bound=SupportsLessThan)
  6
  7
[docs] 8class SplayTreeMultiset(Generic[T]): 9
[docs] 10 class Node: 11 12 def __init__(self, key: T, val: int): 13 self.key: T = key 14 self.size: int = 1 15 self.val: int = val 16 self.valsize: int = val 17 self.left: Optional["SplayTreeMultiset.Node"] = None 18 self.right: Optional["SplayTreeMultiset.Node"] = None 19 20 def __str__(self): 21 if self.left is None and self.right is None: 22 return f"key:{self.key, self.size, self.val, self.valsize}\n" 23 return f"key:{self.key, self.size, self.val, self.valsize},\n left:{self.left},\n right:{self.right}\n"
24 25 def __init__(self, a: Iterable[T] = []) -> None: 26 self.root: Optional["SplayTreeMultiset.Node"] = None 27 if a: 28 self._build(a) 29 30 def _build(self, a: Iterable[T]) -> None: 31 Node = SplayTreeMultiset.Node 32 33 def sort(l: int, r: int) -> SplayTreeMultiset.Node: 34 mid = (l + r) >> 1 35 node = Node(key[mid], val[mid]) 36 if l != mid: 37 node.left = sort(l, mid) 38 if mid + 1 != r: 39 node.right = sort(mid + 1, r) 40 self._update(node) 41 return node 42 43 key, val = self._rle(sorted(a)) 44 if len(key) == 0: 45 return 46 self.root = sort(0, len(key)) 47 48 def _rle(self, a: list[T]) -> tuple[list[T], list[int]]: 49 x = [] 50 y = [] 51 x.append(a[0]) 52 y.append(1) 53 for i, e in enumerate(a): 54 if i == 0: 55 continue 56 if e == x[-1]: 57 y[-1] += 1 58 continue 59 x.append(e) 60 y.append(1) 61 return x, y 62 63 def _update(self, node: Node) -> None: 64 if node.left is None: 65 if node.right is None: 66 node.size = 1 67 node.valsize = node.val 68 else: 69 node.size = 1 + node.right.size 70 node.valsize = node.val + node.right.valsize 71 else: 72 if node.right is None: 73 node.size = 1 + node.left.size 74 node.valsize = node.val + node.left.valsize 75 else: 76 node.size = 1 + node.left.size + node.right.size 77 node.valsize = node.val + node.left.valsize + node.right.valsize 78 79 def _splay(self, path: list[Node], d: int) -> Node: 80 for _ in range(len(path) >> 1): 81 node = path.pop() 82 pnode = path.pop() 83 if d & 1 == d >> 1 & 1: 84 if d & 1: 85 tmp = node.left 86 node.left = tmp.right 87 tmp.right = node 88 pnode.left = node.right 89 node.right = pnode 90 else: 91 tmp = node.right 92 node.right = tmp.left 93 tmp.left = node 94 pnode.right = node.left 95 node.left = pnode 96 else: 97 if d & 1: 98 tmp = node.left 99 node.left = tmp.right 100 pnode.right = tmp.left 101 tmp.right = node 102 tmp.left = pnode 103 else: 104 tmp = node.right 105 node.right = tmp.left 106 pnode.left = tmp.right 107 tmp.left = node 108 tmp.right = pnode 109 self._update(pnode) 110 self._update(node) 111 self._update(tmp) 112 if not path: 113 return tmp 114 d >>= 2 115 if d & 1: 116 path[-1].left = tmp 117 else: 118 path[-1].right = tmp 119 gnode = path[0] 120 if d & 1: 121 node = gnode.left 122 gnode.left = node.right 123 node.right = gnode 124 else: 125 node = gnode.right 126 gnode.right = node.left 127 node.left = gnode 128 self._update(gnode) 129 self._update(node) 130 return node 131 132 def _set_search_splay(self, key: T) -> None: 133 node = self.root 134 if node is None or node.key == key: 135 return 136 path = [] 137 d = 0 138 while True: 139 if node.key == key: 140 break 141 if key < node.key: 142 if node.left is None: 143 break 144 path.append(node) 145 d <<= 1 146 d |= 1 147 node = node.left 148 else: 149 if node.right is None: 150 break 151 path.append(node) 152 d <<= 1 153 node = node.right 154 if path: 155 self.root = self._splay(path, d) 156 157 def _set_kth_elm_splay(self, k: int) -> None: 158 if k < 0: 159 k += self.__len__() 160 d = 0 161 node = self.root 162 path = [] 163 while True: 164 t = node.val if node.left is None else node.val + node.left.valsize 165 if t - node.val <= k < t: 166 if path: 167 self.root = self._splay(path, d) 168 break 169 elif t > k: 170 path.append(node) 171 d <<= 1 172 d |= 1 173 node = node.left 174 else: 175 path.append(node) 176 d <<= 1 177 node = node.right 178 k -= t 179 180 def _set_kth_elm_tree_splay(self, k: int) -> None: 181 if k < 0: 182 k += self.len_elm() 183 assert 0 <= k < self.len_elm() 184 d = 0 185 node = self.root 186 path = [] 187 while True: 188 t = 0 if node.left is None else node.left.size 189 if t == k: 190 if path: 191 self.root = self._splay(path, d) 192 return 193 elif t > k: 194 path.append(node) 195 d <<= 1 196 d |= 1 197 node = node.left 198 else: 199 path.append(node) 200 d <<= 1 201 node = node.right 202 k -= t + 1 203 204 def _get_min_splay(self, node: Node) -> Node: 205 if node is None or node.left is None: 206 return node 207 path = [] 208 while node.left is not None: 209 path.append(node) 210 node = node.left 211 return self._splay(path, (1 << len(path)) - 1) 212 213 def _get_max_splay(self, node: Node) -> Node: 214 if node is None or node.right is None: 215 return node 216 path = [] 217 while node.right is not None: 218 path.append(node) 219 node = node.right 220 return self._splay(path, 0) 221
[docs] 222 def add(self, key: T, val: int = 1) -> None: 223 if self.root is None: 224 self.root = SplayTreeMultiset.Node(key, val) 225 return 226 self._set_search_splay(key) 227 if self.root.key == key: 228 self.root.val += val 229 self._update(self.root) 230 return 231 node = SplayTreeMultiset.Node(key, val) 232 if key < self.root.key: 233 node.left = self.root.left 234 node.right = self.root 235 self.root.left = None 236 self._update(node.right) 237 else: 238 node.left = self.root 239 node.right = self.root.right 240 self.root.right = None 241 self._update(node.left) 242 self._update(node) 243 self.root = node 244 return
245
[docs] 246 def discard(self, key: T, val: int = 1) -> bool: 247 if self.root is None: 248 return False 249 self._set_search_splay(key) 250 if self.root.key != key: 251 return False 252 if self.root.val > val: 253 self.root.val -= val 254 self._update(self.root) 255 return True 256 if self.root.left is None: 257 self.root = self.root.right 258 elif self.root.right is None: 259 self.root = self.root.left 260 else: 261 node = self._get_min_splay(self.root.right) 262 node.left = self.root.left 263 self._update(node) 264 self.root = node 265 return True
266
[docs] 267 def discard_all(self, key: T) -> bool: 268 return self.discard(key, self.count(key))
269
[docs] 270 def count(self, key: T) -> int: 271 if self.root is None: 272 return 0 273 self._set_search_splay(key) 274 return self.root.val if self.root.key == key else 0
275
[docs] 276 def le(self, key: T) -> Optional[T]: 277 node = self.root 278 if node is None: 279 return None 280 path = [] 281 d = 0 282 res = None 283 while True: 284 if node.key == key: 285 res = key 286 break 287 elif key < node.key: 288 if node.left is None: 289 break 290 path.append(node) 291 d <<= 1 292 d |= 1 293 node = node.left 294 else: 295 res = node.key 296 if node.right is None: 297 break 298 path.append(node) 299 d <<= 1 300 node = node.right 301 if path: 302 self.root = self._splay(path, d) 303 return res
304
[docs] 305 def lt(self, key: T) -> Optional[T]: 306 node = self.root 307 path = [] 308 d = 0 309 res = None 310 while node is not None: 311 if key <= node.key: 312 path.append(node) 313 d <<= 1 314 d |= 1 315 node = node.left 316 else: 317 path.append(node) 318 d <<= 1 319 res = node.key 320 node = node.right 321 else: 322 if path: 323 path.pop() 324 d >>= 1 325 if path: 326 self.root = self._splay(path, d) 327 return res
328
[docs] 329 def ge(self, key: T) -> Optional[T]: 330 node = self.root 331 if node is None: 332 return None 333 path = [] 334 d = 0 335 res = None 336 while True: 337 if node.key == key: 338 res = node.key 339 break 340 elif key < node.key: 341 res = node.key 342 if node.left is None: 343 break 344 path.append(node) 345 d <<= 1 346 d |= 1 347 node = node.left 348 else: 349 if node.right is None: 350 break 351 path.append(node) 352 d <<= 1 353 node = node.right 354 if path: 355 self.root = self._splay(path, d) 356 return res
357
[docs] 358 def gt(self, key: T) -> Optional[T]: 359 node = self.root 360 path = [] 361 d = 0 362 res = None 363 while node is not None: 364 if key < node.key: 365 path.append(node) 366 d <<= 1 367 d |= 1 368 res = node.key 369 node = node.left 370 else: 371 path.append(node) 372 d <<= 1 373 node = node.right 374 else: 375 if path: 376 path.pop() 377 d >>= 1 378 if path: 379 self.root = self._splay(path, d) 380 return res
381
[docs] 382 def index(self, key: T) -> int: 383 if self.root is None: 384 return 0 385 self._set_search_splay(key) 386 res = 0 if self.root.left is None else self.root.left.valsize 387 if self.root.key < key: 388 res += self.root.val 389 return res
390
[docs] 391 def index_right(self, key: T) -> int: 392 if self.root is None: 393 return 0 394 self._set_search_splay(key) 395 res = 0 if self.root.left is None else self.root.left.valsize 396 if self.root.key <= key: 397 res += self.root.val 398 return res
399
[docs] 400 def index_keys(self, key: T) -> int: 401 if self.root is None: 402 return 0 403 self._set_search_splay(key) 404 res = 0 if self.root.left is None else self.root.left.size 405 if self.root.key < key: 406 res += 1 407 return res
408
[docs] 409 def index_right_keys(self, key: T) -> int: 410 if self.root is None: 411 return 0 412 self._set_search_splay(key) 413 res = 0 if self.root.left is None else self.root.left.size 414 if self.root.key <= key: 415 res += 1 416 return res
417
[docs] 418 def pop(self, k: int = -1) -> T: 419 self._set_kth_elm_splay(k) 420 res = self.root.key 421 self.discard(res) 422 return res
423
[docs] 424 def pop_max(self) -> T: 425 return self.pop()
426
[docs] 427 def pop_min(self) -> T: 428 return self.pop(0)
429
[docs] 430 def tolist(self) -> list[T]: 431 a = [] 432 if self.root is None: 433 return a 434 if sys.getrecursionlimit() < self.len_elm(): 435 sys.setrecursionlimit(self.len_elm() + 1) 436 437 def rec(node): 438 if node.left is not None: 439 rec(node.left) 440 for _ in range(node.val): 441 a.append(node.key) 442 if node.right is not None: 443 rec(node.right) 444 445 rec(self.root) 446 return a
447
[docs] 448 def tolist_items(self) -> list[tuple[T, int]]: 449 a = [] 450 if self.root is None: 451 return a 452 if sys.getrecursionlimit() < self.len_elm(): 453 sys.setrecursionlimit(self.len_elm() + 1) 454 455 def rec(node): 456 if node.left is not None: 457 rec(node.left) 458 a.append((node.key, node.val)) 459 if node.right is not None: 460 rec(node.right) 461 462 rec(self.root) 463 return a
464
[docs] 465 def get_elm(self, k: int) -> T: 466 assert -self.len_elm() <= k < self.len_elm() 467 self._set_kth_elm_tree_splay(k) 468 return self.root.key
469
[docs] 470 def items(self) -> Iterator[tuple[T, int]]: 471 for i in range(self.len_elm()): 472 self._set_kth_elm_tree_splay(i) 473 yield self.root.key, self.root.val
474
[docs] 475 def keys(self) -> Iterator[T]: 476 for i in range(self.len_elm()): 477 self._set_kth_elm_tree_splay(i) 478 yield self.root.key
479
[docs] 480 def values(self) -> Iterator[int]: 481 for i in range(self.len_elm()): 482 self._set_kth_elm_tree_splay(i) 483 yield self.root.val
484
[docs] 485 def len_elm(self) -> int: 486 return 0 if self.root is None else self.root.size
487
[docs] 488 def show(self) -> None: 489 print( 490 "{" + ", ".join(map(lambda x: f"{x[0]}: {x[1]}", self.tolist_items())) + "}" 491 )
492
[docs] 493 def clear(self) -> None: 494 self.root = None
495 496 def __iter__(self): 497 self.__iter = 0 498 return self 499 500 def __next__(self): 501 if self.__iter == self.__len__(): 502 raise StopIteration 503 res = self.__getitem__(self.__iter) 504 self.__iter += 1 505 return res 506 507 def __reversed__(self): 508 for i in range(self.__len__()): 509 yield self.__getitem__(-i - 1) 510 511 def __contains__(self, key: T) -> bool: 512 self._set_search_splay(key) 513 return self.root is not None and self.root.key == key 514 515 def __getitem__(self, k: int) -> T: 516 self._set_kth_elm_splay(k) 517 return self.root.key 518 519 def __len__(self): 520 return 0 if self.root is None else self.root.valsize 521 522 def __bool__(self): 523 return self.root is not None 524 525 def __str__(self): 526 return "{" + ", ".join(map(str, self.tolist())) + "}" 527 528 def __repr__(self): 529 return f"SplayTreeMultiset({self.tolist()})"