Source code for titan_pylib.data_structures.avl_tree.avl_tree_set

  1from titan_pylib.my_class.supports_less_than import SupportsLessThan
  2from titan_pylib.my_class.ordered_set_interface import OrderedSetInterface
  3from array import array
  4from typing import Generic, Iterable, TypeVar, Optional
  5
  6T = TypeVar("T", bound=SupportsLessThan)
  7
  8
class AVLTreeSet(OrderedSetInterface, Generic[T]):
    """AVLTreeSet

    An AVL tree used as an ordered set of distinct keys.

    Nodes live in parallel arrays indexed by node id (``key``, ``size``,
    ``left``, ``right``, ``balance``). Node id 0 is a sentinel meaning
    "no node"; its size is always 0, so ``size[left[node]]`` is valid even
    when the child is missing. Every node stores the size of its subtree,
    which makes order-statistic operations (``self[k]``, ``index``,
    ``pop(k)``) O(log n).

    ``balance[node]`` is ``height(left subtree) - height(right subtree)``.
    """

    def __init__(self, a: Iterable[T] = ()) -> None:
        """Create a set containing the elements of ``a`` (duplicates dropped).

        Args:
            a: Initial elements. A list is read but not mutated; any other
                iterable is first materialised into a list.
        """
        self._reset_storage()
        if not isinstance(a, list):
            a = list(a)
        if a:
            self._build(a)

    def _reset_storage(self) -> None:
        """(Re)initialise the node arrays so they hold only the sentinel node 0."""
        self.root = 0
        self.key = [0]
        self.size = array("I", [0])
        self.left = array("I", [0])
        self.right = array("I", [0])
        self.balance = array("b", [0])
        # Next unused node id; ids are never recycled after a deletion.
        self.end = 1

    def reserve(self, n: int) -> None:
        """Append storage for ``n`` more nodes to the node arrays.

        Reserved slots are pre-initialised as detached leaves (size 1, no
        children, balance 0), which is what ``_make_node`` and ``_build``
        expect when they claim a slot. A non-positive ``n`` is a no-op.
        """
        if n <= 0:
            return
        self.key += [0] * n
        # array("I", [0]) * n does not assume a particular itemsize for "I".
        zeros = array("I", [0]) * n
        self.left += zeros
        self.right += zeros
        self.size += array("I", [1]) * n
        self.balance += array("b", [0]) * n

    def _build(self, a: list[T]) -> None:
        """Fill an empty tree with the distinct elements of ``a``.

        The elements are sorted and deduplicated if they are not already
        strictly increasing. A perfectly balanced tree is then built over
        consecutive node ids.
        """
        left, right, size, balance = self.left, self.right, self.size, self.balance

        def build_range(lo: int, hi: int) -> tuple[int, int]:
            # Link node ids [lo, hi) into a balanced subtree rooted at the
            # middle id; return (subtree root, subtree height).
            mid = (lo + hi) >> 1
            node = mid
            hl = hr = 0
            if lo != mid:
                left[node], hl = build_range(lo, mid)
                size[node] += size[left[node]]
            if mid + 1 != hi:
                right[node], hr = build_range(mid + 1, hi)
                size[node] += size[right[node]]
            balance[node] = hl - hr
            return node, max(hl, hr) + 1

        n = len(a)
        if n == 0:
            return
        if not all(a[i] < a[i + 1] for i in range(n - 1)):
            b = sorted(a)
            a = [b[0]]
            for i in range(1, n):
                if b[i] != a[-1]:
                    a.append(b[i])
            n = len(a)
        end = self.end
        self.end += n
        # Reserved slots start with size 1 and no children; build_range
        # accumulates child sizes on top of that.
        self.reserve(n)
        self.key[end : end + n] = a
        self.root = build_range(end, n + end)[0]

    def _rotate_L(self, node: int) -> int:
        """Single right rotation around ``node`` (its left child rises).

        Used when ``node`` is left-heavy by 2 and its left child is not
        right-heavy. Returns the new subtree root.
        """
        left, right, size, balance = self.left, self.right, self.size, self.balance
        u = left[node]
        size[u] = size[node]
        size[node] -= size[left[u]] + 1
        left[node] = right[u]
        right[u] = node
        if balance[u] == 1:
            balance[u] = 0
            balance[node] = 0
        else:
            # balance[u] == 0 (only possible after a deletion): height is kept.
            balance[u] = -1
            balance[node] = 1
        return u

    def _rotate_R(self, node: int) -> int:
        """Single left rotation around ``node`` (its right child rises).

        Mirror of ``_rotate_L``. Returns the new subtree root.
        """
        left, right, size, balance = self.left, self.right, self.size, self.balance
        u = right[node]
        size[u] = size[node]
        size[node] -= size[right[u]] + 1
        right[node] = left[u]
        left[u] = node
        if balance[u] == -1:
            balance[u] = 0
            balance[node] = 0
        else:
            balance[u] = 1
            balance[node] = -1
        return u

    def _update_balance(self, node: int) -> None:
        """Fix balance factors after a double rotation that lifted ``node``.

        ``balance[node]`` still holds its pre-rotation value. That value
        decides which of the two new children ends up one level shorter.
        """
        balance = self.balance
        if balance[node] == 1:
            balance[self.right[node]] = -1
            balance[self.left[node]] = 0
        elif balance[node] == -1:
            balance[self.right[node]] = 0
            balance[self.left[node]] = 1
        else:
            balance[self.right[node]] = 0
            balance[self.left[node]] = 0
        balance[node] = 0

    def _rotate_LR(self, node: int) -> int:
        """Left-right double rotation: the right child of ``node``'s left child rises."""
        left, right, size = self.left, self.right, self.size
        B = left[node]
        E = right[B]
        size[E] = size[node]
        size[node] -= size[B] - size[right[E]]
        size[B] -= size[right[E]] + 1
        right[B] = left[E]
        left[E] = B
        left[node] = right[E]
        right[E] = node
        self._update_balance(E)
        return E

    def _rotate_RL(self, node: int) -> int:
        """Right-left double rotation: the left child of ``node``'s right child rises."""
        left, right, size = self.left, self.right, self.size
        C = right[node]
        D = left[C]
        size[D] = size[node]
        size[node] -= size[C] - size[left[D]]
        size[C] -= size[left[D]] + 1
        left[C] = right[D]
        right[D] = C
        right[node] = left[D]
        left[D] = node
        self._update_balance(D)
        return D

    def _make_node(self, key: T) -> int:
        """Allocate a leaf holding ``key`` and return its node id.

        A slot pre-initialised by ``reserve`` is used when available;
        otherwise every array grows by one.
        """
        end = self.end
        if end >= len(self.key):
            self.key.append(key)
            self.size.append(1)
            self.left.append(0)
            self.right.append(0)
            self.balance.append(0)
        else:
            self.key[end] = key
        self.end += 1
        return end

    def add(self, key: T) -> bool:
        """Insert ``key``. Return True if it was added, False if already present."""
        if self.root == 0:
            self.root = self._make_node(key)
            return True
        left, right, size, balance, keys = (
            self.left,
            self.right,
            self.size,
            self.balance,
            self.key,
        )
        node = self.root
        path = []
        # Bit i of di (from the LSB) is 1 iff the step out of path[-1 - i]
        # went to the left child.
        di = 0
        while node:
            if key == keys[node]:
                return False
            di <<= 1
            path.append(node)
            if key < keys[node]:
                di |= 1
                node = left[node]
            else:
                node = right[node]
        if di & 1:
            left[path[-1]] = self._make_node(key)
        else:
            right[path[-1]] = self._make_node(key)
        new_node = 0
        while path:
            node = path.pop()
            size[node] += 1
            balance[node] += 1 if di & 1 else -1
            di >>= 1
            if balance[node] == 0:
                # Subtree height unchanged: no further rebalancing needed.
                break
            if balance[node] == 2:
                new_node = (
                    self._rotate_LR(node)
                    if balance[left[node]] == -1
                    else self._rotate_L(node)
                )
                break
            elif balance[node] == -2:
                new_node = (
                    self._rotate_RL(node)
                    if balance[right[node]] == 1
                    else self._rotate_R(node)
                )
                break
        if new_node:
            # Hook the rotated subtree back under its parent (or as root).
            if path:
                node = path.pop()
                size[node] += 1
                if di & 1:
                    left[node] = new_node
                else:
                    right[node] = new_node
            else:
                self.root = new_node
        # Remaining ancestors only need their subtree sizes bumped.
        for p in path:
            size[p] += 1
        return True

    def remove(self, key: T) -> bool:
        """Remove ``key``; raise KeyError if it is not present."""
        if self.discard(key):
            return True
        raise KeyError(key)

    def _remove_node(self, node: int, path: list[int], di: int) -> None:
        """Unlink ``node`` from the tree, rebalance, and update subtree sizes.

        Args:
            node: Id of the node whose key is being removed.
            path: Ancestors of ``node`` from the root down. Consumed.
            di: Bit i (from the LSB) is 1 iff the step out of
                ``path[-1 - i]`` went to the left child.
        """
        left, right, size, balance, keys = (
            self.left,
            self.right,
            self.size,
            self.balance,
            self.key,
        )
        if left[node] and right[node]:
            # Two children: copy the in-order predecessor (maximum of the
            # left subtree) into ``node`` and physically remove the
            # predecessor instead. The predecessor has no right child.
            path.append(node)
            di = (di << 1) | 1
            lmax = left[node]
            while right[lmax]:
                path.append(lmax)
                di <<= 1
                lmax = right[lmax]
            keys[node] = keys[lmax]
            node = lmax
        # ``node`` now has at most one child; splice it out.
        cnode = right[node] if left[node] == 0 else left[node]
        if not path:
            self.root = cnode
            return
        if di & 1:
            left[path[-1]] = cnode
        else:
            right[path[-1]] = cnode
        while path:
            new_node = 0
            node = path.pop()
            balance[node] -= 1 if di & 1 else -1
            di >>= 1
            size[node] -= 1
            if balance[node] == 2:
                new_node = (
                    self._rotate_LR(node)
                    if balance[left[node]] == -1
                    else self._rotate_L(node)
                )
            elif balance[node] == -2:
                new_node = (
                    self._rotate_RL(node)
                    if balance[right[node]] == 1
                    else self._rotate_R(node)
                )
            elif balance[node]:
                # Became +-1: subtree height unchanged, stop rebalancing.
                break
            if new_node:
                if not path:
                    self.root = new_node
                    return
                if di & 1:
                    left[path[-1]] = new_node
                else:
                    right[path[-1]] = new_node
                if balance[new_node]:
                    # The rotation preserved the subtree height.
                    break
        # Ancestors above the point where rebalancing stopped still lose one node.
        for p in path:
            size[p] -= 1

    def discard(self, key: T) -> bool:
        """Remove ``key`` if present. Return True if it was removed."""
        left, right, keys = self.left, self.right, self.key
        di = 0
        path = []
        node = self.root
        while node:
            if key == keys[node]:
                break
            path.append(node)
            di <<= 1
            if key < keys[node]:
                di |= 1
                node = left[node]
            else:
                node = right[node]
        else:
            return False
        self._remove_node(node, path, di)
        return True

    def le(self, key: T) -> Optional[T]:
        """Return the largest element <= ``key``, or None if there is none."""
        keys, left, right = self.key, self.left, self.right
        res = None
        node = self.root
        while node:
            if key == keys[node]:
                return keys[node]
            if key < keys[node]:
                node = left[node]
            else:
                res = keys[node]
                node = right[node]
        return res

    def lt(self, key: T) -> Optional[T]:
        """Return the largest element < ``key``, or None if there is none."""
        keys, left, right = self.key, self.left, self.right
        res = None
        node = self.root
        while node:
            if key <= keys[node]:
                node = left[node]
            else:
                res = keys[node]
                node = right[node]
        return res

    def ge(self, key: T) -> Optional[T]:
        """Return the smallest element >= ``key``, or None if there is none."""
        keys, left, right = self.key, self.left, self.right
        res = None
        node = self.root
        while node:
            if key == keys[node]:
                return keys[node]
            if key < keys[node]:
                res = keys[node]
                node = left[node]
            else:
                node = right[node]
        return res

    def gt(self, key: T) -> Optional[T]:
        """Return the smallest element > ``key``, or None if there is none."""
        keys, left, right = self.key, self.left, self.right
        res = None
        node = self.root
        while node:
            if key < keys[node]:
                res = keys[node]
                node = left[node]
            else:
                node = right[node]
        return res

    def index(self, key: T) -> int:
        """Return the number of elements strictly less than ``key``."""
        keys, left, right, size = self.key, self.left, self.right, self.size
        k = 0
        node = self.root
        while node:
            if key == keys[node]:
                k += size[left[node]]
                break
            if key < keys[node]:
                node = left[node]
            else:
                k += size[left[node]] + 1
                node = right[node]
        return k

    def index_right(self, key: T) -> int:
        """Return the number of elements less than or equal to ``key``."""
        keys, left, right, size = self.key, self.left, self.right, self.size
        k, node = 0, self.root
        while node:
            if key == keys[node]:
                k += size[left[node]] + 1
                break
            if key < keys[node]:
                node = left[node]
            else:
                k += size[left[node]] + 1
                node = right[node]
        return k

    def get_max(self) -> Optional[T]:
        """Return the largest element, or None if the set is empty."""
        node = self.root
        if not node:
            return None
        right = self.right
        while right[node]:
            node = right[node]
        return self.key[node]

    def get_min(self) -> Optional[T]:
        """Return the smallest element, or None if the set is empty."""
        node = self.root
        if not node:
            return None
        left = self.left
        while left[node]:
            node = left[node]
        return self.key[node]

    def pop(self, k: int = -1) -> T:
        """Remove and return the ``k``-th smallest element (default: the largest).

        Negative ``k`` counts from the end, as with ``list.pop``.

        Raises:
            IndexError: if ``k`` is out of range (including on an empty set).
        """
        left, right, size, keys = self.left, self.right, self.size, self.key
        n = size[self.root]
        i = k + n if k < 0 else k
        if not 0 <= i < n:
            raise IndexError(f"AVLTreeSet.pop({k}), len={n}")
        path = []
        di = 0
        node = self.root
        while True:
            t = size[left[node]]
            if t == i:
                break
            path.append(node)
            di <<= 1
            if t < i:
                i -= t + 1
                node = right[node]
            else:
                di |= 1
                node = left[node]
        # Capture the key first: _remove_node may overwrite keys[node] with
        # the predecessor's key.
        res = keys[node]
        self._remove_node(node, path, di)
        return res

    def pop_max(self) -> T:
        """Remove and return the largest element (IndexError if empty)."""
        return self.pop()

    def pop_min(self) -> T:
        """Remove and return the smallest element (IndexError if empty)."""
        return self.pop(0)

    def clear(self) -> None:
        """Remove all elements and release the node storage."""
        self._reset_storage()

    def tolist(self) -> list[T]:
        """Return all elements in ascending order (iterative in-order walk)."""
        left, right, keys = self.left, self.right, self.key
        node = self.root
        stack, a = [], []
        while stack or node:
            if node:
                stack.append(node)
                node = left[node]
            else:
                node = stack.pop()
                a.append(keys[node])
                node = right[node]
        return a

    def ave_height(self) -> float:
        """Return the mean node depth (the root has depth 1); 0 if empty."""
        if not self.root:
            return 0
        left, right = self.left, self.right
        total = 0
        stack = [(self.root, 1)]
        while stack:
            node, dep = stack.pop()
            total += dep
            if left[node]:
                stack.append((left[node], dep + 1))
            if right[node]:
                stack.append((right[node], dep + 1))
        return total / len(self)

    def _get_height(self) -> int:
        """Debug helper: validate the tree invariants and return its height.

        Checks that every stored subtree size and balance factor matches the
        actual shape, that every balance factor is within [-1, 1], and that
        keys are strictly increasing in order. Raises AssertionError on a
        violation; returns 0 for an empty tree.
        """
        if self.root == 0:
            return 0
        left, right, size, balance = self.left, self.right, self.size, self.balance

        def dfs(node: int) -> tuple[int, int]:
            # Return (subtree size, subtree height) of ``node``.
            s, lh, rh = 1, 0, 0
            if left[node]:
                ls, lh = dfs(left[node])
                s += ls
            if right[node]:
                rs, rh = dfs(right[node])
                s += rs
            assert size[node] == s
            assert balance[node] == lh - rh and -1 <= lh - rh <= 1
            return s, max(lh, rh) + 1

        h = dfs(self.root)[1]
        a = self.tolist()
        assert all(a[i] < a[i + 1] for i in range(len(a) - 1))
        return h

    def __contains__(self, key: T) -> bool:
        keys, left, right = self.key, self.left, self.right
        node = self.root
        while node:
            if key == keys[node]:
                return True
            node = left[node] if key < keys[node] else right[node]
        return False

    def __getitem__(self, k: int) -> T:
        """Return the ``k``-th smallest element; negative ``k`` counts from the end.

        Raises:
            IndexError: if ``k`` is out of range.
        """
        left, right, size, key = self.left, self.right, self.size, self.key
        n = size[self.root]
        i = k + n if k < 0 else k
        if not 0 <= i < n:
            raise IndexError(f"AVLTreeSet[{k}], len={n}")
        node = self.root
        while True:
            t = size[left[node]]
            if t == i:
                return key[node]
            if t < i:
                i -= t + 1
                node = right[node]
            else:
                node = left[node]

    def __iter__(self):
        # Return an independent iterator over an ascending snapshot so that
        # nested loops over the same set do not share a cursor. Iteration is
        # O(n) overall instead of O(log n) per element.
        self.__iter = 0  # keeps legacy ``iter(s); next(s)`` usage working
        return iter(self.tolist())

    def __next__(self):
        # Legacy protocol for callers that call next() on the set itself
        # after iter(s); ordinary for-loops use the iterator from __iter__.
        if self.__iter == self.__len__():
            raise StopIteration
        res = self[self.__iter]
        self.__iter += 1
        return res

    def __reversed__(self):
        # Descending iteration over a snapshot, O(n) overall.
        return reversed(self.tolist())

    def __len__(self) -> int:
        return self.size[self.root]

    def __bool__(self) -> bool:
        return self.root != 0

    def __str__(self) -> str:
        return "{" + ", ".join(map(str, self.tolist())) + "}"

    def __repr__(self) -> str:
        return f"AVLTreeSet({self})"