# Source code for titan_pylib.data_structures.avl_tree.wbt_set

from array import array
from typing import Final, Generic, Iterable, Iterator, Optional, TypeVar

  4T = TypeVar("T")
  5
  6DELTA: Final[int] = 3
  7GAMMA: Final[int] = 2
  8
  9
[docs] 10class WBTSet(Generic[T]): 11 12 def __init__(self, a: Iterable[T] = []) -> None: 13 self.root = 0 14 self.key = [0] 15 self.size = array("I", bytes(4)) 16 self.left = array("I", bytes(4)) 17 self.right = array("I", bytes(4)) 18 self.end = 1 19 if not isinstance(a, list): 20 a = list(a) 21 if a: 22 self._build(a) 23
[docs] 24 def reserve(self, n: int) -> None: 25 if n <= 0: 26 return 27 self.key += [0] * n 28 a = array("I", bytes(4 * n)) 29 self.left += a 30 self.right += a 31 self.size += array("I", [1] * n)
32 33 def _build(self, a: list[T]) -> None: 34 left, right, size = self.left, self.right, self.size 35 36 def sort(l: int, r: int) -> int: 37 mid = (l + r) >> 1 38 node = mid 39 if l != mid: 40 left[node] = sort(l, mid) 41 size[node] += size[left[node]] 42 if mid + 1 != r: 43 right[node] = sort(mid + 1, r) 44 size[node] += size[right[node]] 45 return node 46 47 n = len(a) 48 if n == 0: 49 return 50 if not all(a[i] < a[i + 1] for i in range(n - 1)): 51 b = sorted(a) 52 a = [b[0]] 53 for i in range(1, n): 54 if b[i] != a[-1]: 55 a.append(b[i]) 56 n = len(a) 57 end = self.end 58 self.end += n 59 self.reserve(n) 60 self.key[end : end + n] = a 61 self.root = sort(end, n + end) 62 63 def _rotate_left(self, node: int) -> int: 64 left, right, size = self.left, self.right, self.size 65 u = right[node] 66 size[u] = size[node] 67 size[node] -= size[right[u]] + 1 68 right[node] = left[u] 69 left[u] = node 70 return u 71 72 def _rotate_right(self, node: int) -> int: 73 left, right, size = self.left, self.right, self.size 74 u = left[node] 75 size[u] = size[node] 76 size[node] -= size[left[u]] + 1 77 left[node] = right[u] 78 right[u] = node 79 return u 80 81 def _make_node(self, key: T) -> int: 82 end = self.end 83 if end >= len(self.key): 84 self.key.append(key) 85 self.size.append(1) 86 self.left.append(0) 87 self.right.append(0) 88 else: 89 self.key[end] = key 90 self.end += 1 91 return end 92 93 def _weight_left(self, node: int) -> int: 94 return self.size[self.left[node]] + 1 95 96 def _weight_right(self, node: int) -> int: 97 return self.size[self.right[node]] + 1 98
[docs] 99 def ave_height(self): 100 if not self.root: 101 return 0 102 left, right, size, keys = self.left, self.right, self.size, self.key 103 ans = 0 104 105 def dfs(node, dep): 106 nonlocal ans 107 ans += dep 108 if left[node]: 109 dfs(left[node], dep + 1) 110 if right[node]: 111 dfs(right[node], dep + 1) 112 113 dfs(self.root, 1) 114 ans /= len(self) 115 return ans
116
[docs] 117 def debug(self, root): 118 left, right, size, keys = self.left, self.right, self.size, self.key 119 120 def dfs(node, indent): 121 if not node: 122 return 123 s = " " * indent 124 print(f"{s}key={keys[node]}, idx={node}") 125 if left[node]: 126 print(f"{s}left: {keys[left[node]]}, idx={left[node]}") 127 dfs(left[node], indent + 2) 128 if right[node]: 129 print(f"{s}righ: {keys[right[node]]}, idx={right[node]}") 130 dfs(right[node], indent + 2) 131 132 dfs(root, 0)
133
[docs] 134 def add(self, key: T) -> bool: 135 if self.root == 0: 136 self.root = self._make_node(key) 137 return True 138 left, right, size, keys = self.left, self.right, self.size, self.key 139 node = self.root 140 path = [] 141 di = 0 142 while node: 143 if key == keys[node]: 144 return False 145 path.append(node) 146 di <<= 1 147 if key < keys[node]: 148 di |= 1 149 node = left[node] 150 else: 151 node = right[node] 152 # self.debug(self.root) 153 if di & 1: 154 left[path[-1]] = self._make_node(key) 155 else: 156 right[path[-1]] = self._make_node(key) 157 while path: 158 node = path.pop() 159 size[node] += 1 160 di >>= 1 161 wl = self._weight_left(node) 162 wr = self._weight_right(node) 163 if wl * DELTA < wr: 164 # print("wl * DELTA < wr") 165 # self.debug(node) 166 if ( 167 self._weight_left(right[node]) 168 >= self._weight_right(right[node]) * GAMMA 169 ): 170 right[node] = self._rotate_right(right[node]) 171 node = self._rotate_left(node) 172 # self.debug(node) 173 # assert node 174 elif wr * DELTA < wl: 175 # print("wr * DELTA < wl") 176 if ( 177 self._weight_right(left[node]) 178 >= self._weight_left(left[node]) * GAMMA 179 ): 180 # print("left") 181 left[node] = self._rotate_left(left[node]) 182 # print("right") 183 node = self._rotate_right(node) 184 # assert node 185 if path: 186 if di & 1: 187 left[path[-1]] = node 188 else: 189 right[path[-1]] = node 190 else: 191 self.root = node 192 return True
193
[docs] 194 def remove(self, key: T) -> bool: 195 if self.discard(key): 196 return True 197 raise KeyError(key)
198
[docs] 199 def discard(self, key: T) -> bool: 200 left, right, size, keys = self.left, self.right, self.size, self.key 201 di = 0 202 path = [] 203 node = self.root 204 while node: 205 if key == keys[node]: 206 break 207 path.append(node) 208 di <<= 1 209 if key < keys[node]: 210 di |= 1 211 node = left[node] 212 else: 213 node = right[node] 214 else: 215 return False 216 if left[node] and right[node]: 217 path.append(node) 218 di <<= 1 219 di |= 1 220 lmax = left[node] 221 while right[lmax]: 222 path.append(lmax) 223 di <<= 1 224 lmax = right[lmax] 225 keys[node] = keys[lmax] 226 node = lmax 227 cnode = right[node] if left[node] == 0 else left[node] 228 if path: 229 if di & 1: 230 left[path[-1]] = cnode 231 else: 232 right[path[-1]] = cnode 233 else: 234 self.root = cnode 235 return True 236 while path: 237 node = path.pop() 238 size[node] -= 1 239 di >>= 1 240 wl = self._weight_left(node) 241 wr = self._weight_right(node) 242 if wl * DELTA < wr: 243 if ( 244 self._weight_left(right[node]) 245 >= self._weight_right(right[node]) * GAMMA 246 ): 247 right[node] = self._rotate_right(right[node]) 248 node = self._rotate_left(node) 249 elif wr * DELTA < wl: 250 if ( 251 self._weight_right(left[node]) 252 >= self._weight_left(left[node]) * GAMMA 253 ): 254 left[node] = self._rotate_left(left[node]) 255 node = self._rotate_right(node) 256 if path: 257 if di & 1: 258 left[path[-1]] = node 259 else: 260 right[path[-1]] = node 261 else: 262 self.root = node 263 return True
264
[docs] 265 def le(self, key: T) -> Optional[T]: 266 keys, left, right = self.key, self.left, self.right 267 res = None 268 node = self.root 269 while node: 270 if key == keys[node]: 271 return keys[node] 272 if key < keys[node]: 273 node = left[node] 274 else: 275 res = keys[node] 276 node = right[node] 277 return res
278
[docs] 279 def lt(self, key: T) -> Optional[T]: 280 keys, left, right = self.key, self.left, self.right 281 res = None 282 node = self.root 283 while node: 284 if key <= keys[node]: 285 node = left[node] 286 else: 287 res = keys[node] 288 node = right[node] 289 return res
290
[docs] 291 def ge(self, key: T) -> Optional[T]: 292 keys, left, right = self.key, self.left, self.right 293 res = None 294 node = self.root 295 while node: 296 if key == keys[node]: 297 return keys[node] 298 if key < keys[node]: 299 res = keys[node] 300 node = left[node] 301 else: 302 node = right[node] 303 return res
304
[docs] 305 def gt(self, key: T) -> Optional[T]: 306 keys, left, right = self.key, self.left, self.right 307 res = None 308 node = self.root 309 while node: 310 if key < keys[node]: 311 res = keys[node] 312 node = left[node] 313 else: 314 node = right[node] 315 return res
316
[docs] 317 def index(self, key: T) -> int: 318 keys, left, right, size = self.key, self.left, self.right, self.size 319 k = 0 320 node = self.root 321 while node: 322 if key == keys[node]: 323 k += size[left[node]] 324 break 325 if key < keys[node]: 326 node = left[node] 327 else: 328 k += size[left[node]] + 1 329 node = right[node] 330 return k
331
[docs] 332 def index_right(self, key: T) -> int: 333 keys, left, right, size = self.key, self.left, self.right, self.size 334 k, node = 0, self.root 335 while node: 336 if key == keys[node]: 337 k += size[left[node]] + 1 338 break 339 if key < keys[node]: 340 node = left[node] 341 else: 342 k += size[left[node]] + 1 343 node = right[node] 344 return k
345
[docs] 346 def get_max(self) -> Optional[T]: 347 if not self: 348 return 349 return self[len(self) - 1]
350
[docs] 351 def get_min(self) -> Optional[T]: 352 if not self: 353 return 354 return self[0]
355
[docs] 356 def pop(self, k: int = -1) -> T: 357 left, right, size, key = self.left, self.right, self.size, self.key 358 if k < 0: 359 k += size[self.root] 360 assert 0 <= k and k < size[self.root], "IndexError" 361 path = [] 362 di = 0 363 node = self.root 364 while True: 365 t = size[left[node]] 366 if t == k: 367 res = key[node] 368 break 369 path.append(node) 370 di <<= 1 371 if t < k: 372 k -= t + 1 373 node = right[node] 374 else: 375 di |= 1 376 node = left[node] 377 if left[node] and right[node]: 378 path.append(node) 379 di <<= 1 380 di |= 1 381 lmax = left[node] 382 while right[lmax]: 383 path.append(lmax) 384 di <<= 1 385 lmax = right[lmax] 386 key[node] = key[lmax] 387 node = lmax 388 cnode = right[node] if left[node] == 0 else left[node] 389 if path: 390 if di & 1: 391 left[path[-1]] = cnode 392 else: 393 right[path[-1]] = cnode 394 else: 395 self.root = cnode 396 return res 397 while path: 398 node = path.pop() 399 size[node] -= 1 400 di >>= 1 401 wl = self._weight_left(node) 402 wr = self._weight_right(node) 403 if wl * DELTA < wr: 404 if ( 405 self._weight_left(right[node]) 406 >= self._weight_right(right[node]) * GAMMA 407 ): 408 right[node] = self._rotate_right(right[node]) 409 node = self._rotate_left(node) 410 elif wr * DELTA < wl: 411 if ( 412 self._weight_right(left[node]) 413 >= self._weight_left(left[node]) * GAMMA 414 ): 415 left[node] = self._rotate_left(left[node]) 416 node = self._rotate_right(node) 417 if path: 418 if di & 1: 419 left[path[-1]] = node 420 else: 421 right[path[-1]] = node 422 else: 423 self.root = node 424 return res
425
[docs] 426 def pop_max(self) -> T: 427 return self.pop()
428
[docs] 429 def pop_min(self) -> T: 430 return self.pop(0)
431
[docs] 432 def clear(self) -> None: 433 self.root = 0
434
[docs] 435 def tolist(self) -> list[T]: 436 left, right, keys = self.left, self.right, self.key 437 node = self.root 438 stack, a = [], [] 439 while stack or node: 440 if node: 441 stack.append(node) 442 node = left[node] 443 else: 444 node = stack.pop() 445 a.append(keys[node]) 446 node = right[node] 447 return a
448
[docs] 449 def check(self) -> int: 450 """作業用デバック関数""" 451 if self.root == 0: 452 return 0 453 454 def _balance_check(node: int) -> None: 455 if node == 0: 456 return 457 if not self._weight_left(node) * DELTA >= self._weight_right(node): 458 print(self._weight_left(node), self._weight_right(node), flush=True) 459 assert False, f"self._weight_left() * DELTA >= self._weight_right()" 460 if not self._weight_right(node) * DELTA >= self._weight_left(node): 461 print(self._weight_left(node), self._weight_right(node), flush=True) 462 assert False, f"self._weight_right() * DELTA >= self._weight_left()" 463 464 keys = self.key 465 466 # _size, height 467 def dfs(node) -> tuple[int, int]: 468 _balance_check(node) 469 h = 0 470 s = 1 471 if self.left[node]: 472 assert keys[self.left[node]] < keys[node] 473 ls, lh = dfs(self.left[node]) 474 s += ls 475 h = max(h, lh) 476 if self.right[node]: 477 assert keys[node] < keys[self.right[node]] 478 rs, rh = dfs(self.right[node]) 479 s += rs 480 h = max(h, rh) 481 assert self.size[node] == s 482 return s, h + 1 483 484 _, h = dfs(self.root) 485 return h
486 487 def __contains__(self, key: T) -> bool: 488 keys, left, right = self.key, self.left, self.right 489 node = self.root 490 while node: 491 if key == keys[node]: 492 return True 493 node = left[node] if key < keys[node] else right[node] 494 return False 495 496 def __getitem__(self, k: int) -> T: 497 left, right, size, key = self.left, self.right, self.size, self.key 498 if k < 0: 499 k += size[self.root] 500 assert ( 501 0 <= k and k < size[self.root] 502 ), f"IndexError: WBTSet[{k}], len={len(self)}" 503 node = self.root 504 while True: 505 t = size[left[node]] 506 if t == k: 507 return key[node] 508 if t < k: 509 k -= t + 1 510 node = right[node] 511 else: 512 node = left[node] 513 514 def __iter__(self): 515 self.__iter = 0 516 return self 517 518 def __next__(self): 519 if self.__iter == self.__len__(): 520 raise StopIteration 521 res = self[self.__iter] 522 self.__iter += 1 523 return res 524 525 def __reversed__(self): 526 for i in range(self.__len__()): 527 yield self[-i - 1] 528 529 def __len__(self): 530 return self.size[self.root] 531 532 def __bool__(self): 533 return self.root != 0 534 535 def __str__(self): 536 return "{" + ", ".join(map(str, self.tolist())) + "}" 537 538 def __repr__(self): 539 return f"WBTSet({self})"