Source code for titan_pylib.data_structures.wbt.wbt_set

  1from titan_pylib.data_structures.wbt._wbt_set_node import _WBTSetNode
  2from typing import Generic, TypeVar, Optional, Iterable, Iterator
  3
  4T = TypeVar("T")
  5
  6
[docs] 7class WBTSet(Generic[T]): 8 """重み平衡木で実装された順序付き集合""" 9 10 __slots__ = "_root", "_min", "_max" 11 12 def __init__(self, a: Iterable[T] = []) -> None: 13 """イテラブル ``a`` から ``WBTSet`` を構築します。 14 15 Args: 16 a (Iterable[T], optional): 構築元のイテラブルです。 17 18 計算量: 19 20 ソート済みなら :math:`O(n)` 、そうでないなら :math:`O(n \\log{n})` 21 """ 22 self._root: Optional[_WBTSetNode[T]] = None 23 self._min: Optional[_WBTSetNode[T]] = None 24 self._max: Optional[_WBTSetNode[T]] = None 25 self.__build(a) 26 27 def __build(self, a: Iterable[T]) -> None: 28 """再帰的に構築する関数""" 29 30 def build( 31 l: int, r: int, pnode: Optional[_WBTSetNode[T]] = None 32 ) -> _WBTSetNode[T]: 33 if l == r: 34 return None 35 mid = (l + r) // 2 36 node = _WBTSetNode(a[mid]) 37 node._left = build(l, mid, node) 38 node._right = build(mid + 1, r, node) 39 node._par = pnode 40 node._update() 41 return node 42 43 a = list(a) 44 if not a: 45 return 46 if not all(a[i] < a[i + 1] for i in range(len(a) - 1)): 47 a.sort() 48 new_a = [a[0]] 49 for elm in a: 50 if new_a[-1] == elm: 51 continue 52 new_a.append(elm) 53 a = new_a 54 self._root = build(0, len(a)) 55 self._max = self._root._max() 56 self._min = self._root._min() 57
[docs] 58 def add(self, key: T) -> bool: 59 """既に ``key`` が存在していれば何もせず ``False`` を返し、 60 存在していれば ``key`` を 1 つ追加して ``True`` を返します。 61 62 Args: 63 key (T): 追加するキーです。 64 65 Returns: 66 bool: ``key`` を追加したら ``True`` 、そうでなければ ``False`` を返します。 67 68 計算量: 69 :math:`O(\\log{n})` 70 """ 71 if not self._root: 72 self._root = _WBTSetNode(key) 73 self._max = self._root 74 self._min = self._root 75 return True 76 pnode = None 77 node = self._root 78 while node: 79 if key == node._key: 80 return False 81 pnode = node 82 node = node._left if key < node._key else node._right 83 if key < pnode._key: 84 pnode._left = _WBTSetNode(key) 85 if key < self._min._key: 86 self._min = pnode._left 87 pnode._left._par = pnode 88 else: 89 pnode._right = _WBTSetNode(key) 90 if key > self._max._key: 91 self._max = pnode._right 92 pnode._right._par = pnode 93 self._root = pnode._rebalance() 94 return True
95
[docs] 96 def find_key(self, key: T) -> Optional[_WBTSetNode[T]]: 97 """``key`` が存在すれば ``key`` を指すノードを返します。 98 そうでなければ ``None`` を返します。 99 100 Args: 101 key (T): 102 103 Returns: 104 Optional[_WBTSetNode[T]]: 105 106 計算量: 107 :math:`O(\\log{n})` 108 """ 109 node = self._root 110 while node: 111 if key == node._key: 112 return node 113 node = node._left if key < node._key else node._right 114 return None
115
[docs] 116 def find_order(self, k: int) -> _WBTSetNode[T]: 117 """昇順 ``k`` 番目のノードを返します。 118 119 Args: 120 k (int): 121 122 Returns: 123 _WBTSetNode[T]: 124 125 計算量: 126 :math:`O(\\log{n})` 127 128 制約: 129 :math:`-n \\leq k \\le n` 130 """ 131 if k < 0: 132 k += len(self) 133 node = self._root 134 while True: 135 t = node._left._size if node._left else 0 136 if t == k: 137 return node 138 if t < k: 139 k -= t + 1 140 node = node._right 141 else: 142 node = node._left
143
[docs] 144 def count(self, key: T) -> int: 145 return 1 if self.find_key(key) is not None else 0
146
[docs] 147 def remove_iter(self, node: _WBTSetNode[T]) -> None: 148 """``node`` を削除します。 149 150 Args: 151 node (_WBTSetNode[T]): 152 153 計算量: 154 :math:`O(\\log{n})` 155 """ 156 if node is self._min: 157 self._min = self._min._next() 158 if node is self._max: 159 self._max = self._max._prev() 160 delnode = node 161 pnode, mnode = node._par, None 162 if node._left and node._right: 163 pnode, mnode = node, node._left 164 while mnode._right: 165 pnode, mnode = mnode, mnode._right 166 node = mnode 167 cnode = node._right if not node._left else node._left 168 if cnode: 169 cnode._par = pnode 170 if pnode: 171 if pnode._left is node: 172 pnode._left = cnode 173 else: 174 pnode._right = cnode 175 self._root = pnode._rebalance() 176 else: 177 self._root = cnode 178 if mnode: 179 if self._root is delnode: 180 self._root = mnode 181 mnode._copy_from(delnode) 182 del delnode
183
[docs] 184 def remove(self, key: T) -> None: 185 """``key`` を削除します。 186 187 Args: 188 key (T): 削除する ``key`` です。 189 190 計算量: 191 :math:`O(\\log{n})` 192 193 Note: 194 ``key`` が存在しない場合、 ``AssertionError`` を出します。 195 """ 196 node = self.find_key(key) 197 assert node, f"KeyError: {key} is not exist." 198 self.remove_iter(node)
199
[docs] 200 def discard(self, key: T) -> bool: 201 """``key`` が存在すれば削除して ``True`` を返します。 202 存在しなければなにもせず ``False`` を返します。 203 204 Args: 205 key (T): 削除する ``key`` です。 206 207 Returns: 208 bool: ``key`` が存在したかどうか 209 210 計算量: 211 :math:`O(\\log{n})` 212 """ 213 node = self.find_key(key) 214 if node is None: 215 return False 216 self.remove_iter(node) 217 return True
218
[docs] 219 def pop(self, k: int = -1) -> T: 220 """``k`` 番目の値を削除して返します。 221 引数指定がない場合は最大の値を削除して返します。 222 223 Args: 224 k (int, optional): 削除するインデックスです。 225 226 Returns: 227 T: ``k`` 番目の値です。 228 229 計算量: 230 :math:`O(\\log{n})` 231 """ 232 node = self.find_order(k) 233 key = node._key 234 self.remove_iter(node) 235 return key
236
[docs] 237 def le_iter(self, key: T) -> Optional[_WBTSetNode[T]]: 238 """``key`` 以下で最大のノードを返します。存在しないときは ``None`` を返します。 239 240 計算量: 241 :math:`O(\\log{n})` 242 """ 243 res = None 244 node = self._root 245 while node: 246 if key == node._key: 247 res = node 248 break 249 if key < node._key: 250 node = node._left 251 else: 252 res = node 253 node = node._right 254 return res
255
[docs] 256 def lt_iter(self, key: T) -> Optional[_WBTSetNode[T]]: 257 """``key`` より小さい値で最大のノードを返します。存在しないときは ``None`` を返します。 258 259 計算量: 260 :math:`O(\\log{n})` 261 """ 262 res = None 263 node = self._root 264 while node: 265 if key <= node._key: 266 node = node._left 267 else: 268 res = node 269 node = node._right 270 return res
271
[docs] 272 def ge_iter(self, key: T) -> Optional[_WBTSetNode[T]]: 273 """``key`` 以上で最小のノードを返します。存在しないときは ``None`` を返します。 274 275 計算量: 276 :math:`O(\\log{n})` 277 """ 278 res = None 279 node = self._root 280 while node: 281 if key == node._key: 282 res = node 283 break 284 if key < node._key: 285 res = node 286 node = node._left 287 else: 288 node = node._right 289 return res
290
[docs] 291 def gt_iter(self, key: T) -> Optional[_WBTSetNode[T]]: 292 """``key`` より大きい値で最小のノードを返します。存在しないときは ``None`` を返します。 293 294 計算量: 295 :math:`O(\\log{n})` 296 """ 297 res = None 298 node = self._root 299 while node: 300 if key < node._key: 301 res = node 302 node = node._left 303 else: 304 node = node._right 305 return res
306
[docs] 307 def le(self, key: T) -> Optional[T]: 308 """``key`` 以下で最大の要素を返します。存在しないときは ``None`` を返します。 309 310 計算量: 311 :math:`O(\\log{n})` 312 """ 313 res = None 314 node = self._root 315 while node: 316 if key == node._key: 317 res = key 318 break 319 if key < node._key: 320 node = node._left 321 else: 322 res = node._key 323 node = node._right 324 return res
325
[docs] 326 def lt(self, key: T) -> Optional[T]: 327 """``key`` より小さい値で最大の要素を返します。存在しないときは ``None`` を返します。 328 329 計算量: 330 :math:`O(\\log{n})` 331 """ 332 res = None 333 node = self._root 334 while node: 335 if key <= node._key: 336 node = node._left 337 else: 338 res = node._key 339 node = node._right 340 return res
341
[docs] 342 def ge(self, key: T) -> Optional[T]: 343 """``key`` 以上で最小の要素を返します。存在しないときは ``None`` を返します。 344 345 計算量: 346 :math:`O(\\log{n})` 347 """ 348 res = None 349 node = self._root 350 while node: 351 if key == node._key: 352 res = key 353 break 354 if key < node._key: 355 res = node._key 356 node = node._left 357 else: 358 node = node._right 359 return res
360
[docs] 361 def gt(self, key: T) -> Optional[T]: 362 """``key`` より大きい値で最小の要素を返します。存在しないときは ``None`` を返します。 363 364 計算量: 365 :math:`O(\\log{n})` 366 """ 367 res = None 368 node = self._root 369 while node: 370 if key < node._key: 371 res = node._key 372 node = node._left 373 else: 374 node = node._right 375 return res
376
[docs] 377 def index(self, key: T) -> int: 378 """``key`` より小さい値を個数を返します。 379 380 Args: 381 key (T): 382 383 Returns: 384 int: 385 386 計算量: 387 :math:`O(\\log{n})` 388 """ 389 k = 0 390 node = self._root 391 while node: 392 if key == node._key: 393 k += node._left._size if node._left else 0 394 break 395 if key < node._key: 396 node = node._left 397 else: 398 k += node._left._size + 1 if node._left else 1 399 node = node._right 400 return k
401
[docs] 402 def index_right(self, key: T) -> int: 403 """``key`` 以下の値を個数を返します。 404 405 Args: 406 key (T): 407 408 Returns: 409 int: 410 411 計算量: 412 :math:`O(\\log{n})` 413 """ 414 k = 0 415 node = self._root 416 while node: 417 if key == node._key: 418 k += node._left._size + 1 if node._left else 1 419 break 420 if key < node._key: 421 node = node._left 422 else: 423 k += node._left._size + 1 if node._left else 1 424 node = node._right 425 return k
426
[docs] 427 def get_min(self) -> T: 428 """最小の要素を返します。 429 430 Returns: 431 T: 432 433 計算量: 434 :math:`O(1)` 435 436 制約: 437 :math:`0 < n` 438 """ 439 assert self._min 440 return self._min._key
441
[docs] 442 def get_max(self) -> T: 443 """最大の要素を返します。 444 445 Returns: 446 T: 447 448 計算量: 449 :math:`O(1)` 450 451 制約: 452 :math:`0 < n` 453 """ 454 assert self._max 455 return self._max._key
456
[docs] 457 def pop_min(self) -> T: 458 """最小の要素を削除して返します。 459 460 Returns: 461 T: 462 463 計算量: 464 :math:`O(\\log{n})` 465 466 制約: 467 :math:`0 < n` 468 """ 469 assert self._min 470 key = self._min._key 471 self.remove_iter(self._min) 472 return key
473
[docs] 474 def pop_max(self) -> T: 475 """最大の要素を削除して返します。 476 477 Returns: 478 T: 479 480 計算量: 481 :math:`O(\\log{n})` 482 483 制約: 484 :math:`0 < n` 485 """ 486 assert self._max 487 key = self._max._key 488 self.remove_iter(self._max) 489 return key
490 491 def _check(self) -> int: 492 """作業用デバック関数 493 size,key,balanceをチェックして、正しければ高さを表示する 494 """ 495 if self._root is None: 496 # print("ok. 0 (empty)") 497 return 0 498 499 # _size, height 500 def dfs(node: _WBTSetNode[T]) -> tuple[int, int]: 501 h = 0 502 s = 1 503 if node._left: 504 assert node._key > node._left._key 505 ls, lh = dfs(node._left) 506 s += ls 507 h = max(h, lh) 508 if node._right: 509 assert node._key < node._right._key 510 rs, rh = dfs(node._right) 511 s += rs 512 h = max(h, rh) 513 assert node._size == s 514 node._balance_check() 515 return s, h + 1 516 517 _, h = dfs(self._root) 518 # print(f"ok. {h}") 519 return h 520
[docs] 521 def __contains__(self, key: T) -> bool: 522 """``key`` が存在すれば ``True`` 、そうでなければ ``False`` を返します。 523 524 Args: 525 key (T): 526 527 Returns: 528 bool: 529 530 計算量: 531 :math:`O(\\log{n})` 532 """ 533 return self.find_key(key) is not None
534
[docs] 535 def __getitem__(self, k: int) -> T: 536 """昇順 ``k`` 番目の値を返します。 537 538 Args: 539 k (int): 540 541 Returns: 542 T: 543 544 計算量: 545 k = 0 または k = n-1 の場合: :math:`O(1)` 546 そうでない場合: :math:`O(\\log{n})` 547 """ 548 assert ( 549 -len(self) <= k < len(self) 550 ), f"IndexError: {self.__class__.__name__}[{k}], len={len(self)}" 551 if k < 0: 552 k += len(self) 553 if k == 0: 554 return self.get_min() 555 if k == len(self) - 1: 556 return self.get_max() 557 return self.find_order(k)._key
558 559 # def __delitem__(self, k: int) -> None: 560 # self.remove_iter(self.find_order(k)) 561
[docs] 562 def __len__(self) -> int: 563 """要素数を返します。 564 565 Returns: 566 int: 567 568 計算量: 569 :math:`O(1)` 570 """ 571 return self._root._size if self._root else 0
572
[docs] 573 def __iter__(self) -> Iterator[T]: 574 """昇順に値を返します。 575 576 Yields: 577 Iterator[T]: 578 579 計算量: 580 全体で :math:`O(n)` 581 """ 582 stack: list[_WBTSetNode[T]] = [] 583 node = self._root 584 while stack or node: 585 if node: 586 stack.append(node) 587 node = node._left 588 else: 589 node = stack.pop() 590 yield node._key 591 node = node._right
592
[docs] 593 def __reversed__(self) -> Iterator[T]: 594 """降順に値を返します。 595 596 Yields: 597 Iterator[T]: 598 599 計算量: 600 全体で :math:`O(n)` 601 """ 602 stack: list[_WBTSetNode[T]] = [] 603 node = self._root 604 while stack or node: 605 if node: 606 stack.append(node) 607 node = node._right 608 else: 609 node = stack.pop() 610 yield node._key 611 node = node._left
612 613 def __str__(self) -> str: 614 return "{" + ", ".join(map(str, self)) + "}" 615 616 def __repr__(self) -> str: 617 return f"{self.__class__.__name__}(" + "{" + ", ".join(map(str, self)) + "})"