Source code for titan_pylib.data_structures.avl_tree.avl_tree_set2

  1from titan_pylib.my_class.supports_less_than import SupportsLessThan
  2from titan_pylib.my_class.ordered_set_interface import OrderedSetInterface
  3from array import array
  4from typing import Generic, Iterable, TypeVar, Optional
  5
# Element type stored in the set; the bound requires elements to be
# comparable with ``<`` (see SupportsLessThan).
T = TypeVar("T", bound=SupportsLessThan)
  7
  8
class AVLTreeSet2(OrderedSetInterface, Generic[T]):
    """AVLTreeSet2

    An ordered set backed by an AVL tree.

    Nodes are stored in parallel arrays indexed by node id: ``key[i]``,
    ``left[i]``, ``right[i]`` and ``balance[i]``. Id 0 is a sentinel meaning
    "no node". ``balance[i]`` is the height of the left subtree minus the
    height of the right subtree and is kept in ``{-1, 0, 1}`` between
    operations.

    Subtree sizes are not stored, which keeps the structure lightweight (and
    means there is no index-based access).

    Node slots released by ``discard``/``pop_min``/``pop_max`` are not
    recycled; ``clear`` makes every allocated slot reusable.
    """

    def __init__(self, a: Iterable[T] = ()) -> None:
        """Create a set holding the distinct elements of ``a``.

        Args:
            a: Initial elements in any order; duplicates are dropped.
        """
        self.root = 0  # id of the root node; 0 means the tree is empty
        self._len = 0
        # Slot 0 of every array is the sentinel "null" node.
        self.key = [0]
        self.left = array("I", [0])
        self.right = array("I", [0])
        self.balance = array("b", [0])
        self.end = 1  # next unused node id
        if not isinstance(a, list):
            a = list(a)
        if a:
            self._build(a)

    def reserve(self, n: int) -> None:
        """Append ``n`` zeroed node slots so later inserts reuse them instead
        of growing the arrays one element at a time."""
        self.key += [0] * n
        zeros = array("I", [0]) * n  # independent of the platform's itemsize
        self.left += zeros
        self.right += zeros
        self.balance += array("b", [0]) * n

    def _build(self, a: list[T]) -> None:
        """Bulk-load ``a`` into the (empty) tree as a height-balanced BST.

        ``a`` is sorted and de-duplicated first unless it is already strictly
        increasing.
        """
        left, right, balance = self.left, self.right, self.balance

        def build(lo: int, hi: int) -> tuple[int, int]:
            # Turn node ids [lo, hi) into a subtree rooted at the middle id.
            # Returns (root id, subtree height).
            mid = (lo + hi) >> 1
            hl = hr = 0
            if lo != mid:
                left[mid], hl = build(lo, mid)
            if mid + 1 != hi:
                right[mid], hr = build(mid + 1, hi)
            balance[mid] = hl - hr
            return mid, max(hl, hr) + 1

        n = len(a)
        if n == 0:
            return
        if not all(a[i] < a[i + 1] for i in range(n - 1)):
            b = sorted(a)
            a = [b[0]]
            for i in range(1, n):
                if b[i] != a[-1]:
                    a.append(b[i])
        n = len(a)
        self._len = n
        start = self.end
        self.end += n
        self.reserve(n)
        self.key[start : start + n] = a
        self.root = build(start, start + n)[0]

    def _rotate_L(self, node: int) -> int:
        """Lift ``node``'s left child into its place (a right rotation).

        Used when ``balance[node] == 2`` and the left child is not right-heavy.
        Returns the id of the new subtree root.
        """
        left, right, balance = self.left, self.right, self.balance
        u = left[node]
        left[node] = right[u]
        right[u] = node
        if balance[u] == 1:
            balance[u] = 0
            balance[node] = 0
        else:
            # balance[u] == 0 (reachable only after a deletion): the subtree
            # keeps its height and both nodes stay slightly unbalanced.
            balance[u] = -1
            balance[node] = 1
        return u

    def _rotate_R(self, node: int) -> int:
        """Lift ``node``'s right child into its place (a left rotation).

        Mirror image of ``_rotate_L``. Returns the id of the new subtree root.
        """
        left, right, balance = self.left, self.right, self.balance
        u = right[node]
        right[node] = left[u]
        left[u] = node
        if balance[u] == -1:
            balance[u] = 0
            balance[node] = 0
        else:
            balance[u] = 1
            balance[node] = -1
        return u

    def _update_balance(self, node: int) -> None:
        """Fix balances after a double rotation that made ``node`` the root.

        Before the call ``balance[node]`` still holds the pre-rotation value
        of the lifted grandchild; it decides which of the two new children
        ends up one level shorter.
        """
        balance = self.balance
        if balance[node] == 1:
            balance[self.right[node]] = -1
            balance[self.left[node]] = 0
        elif balance[node] == -1:
            balance[self.right[node]] = 0
            balance[self.left[node]] = 1
        else:
            balance[self.right[node]] = 0
            balance[self.left[node]] = 0
        balance[node] = 0

    def _rotate_LR(self, node: int) -> int:
        """Double rotation for a left child that is right-heavy.

        The left child's right child becomes the new subtree root, whose id
        is returned.
        """
        left, right = self.left, self.right
        B = left[node]
        E = right[B]
        right[B] = left[E]
        left[E] = B
        left[node] = right[E]
        right[E] = node
        self._update_balance(E)
        return E

    def _rotate_RL(self, node: int) -> int:
        """Double rotation for a right child that is left-heavy.

        The right child's left child becomes the new subtree root, whose id
        is returned.
        """
        left, right = self.left, self.right
        C = right[node]
        D = left[C]
        left[C] = right[D]
        right[D] = C
        right[node] = left[D]
        left[D] = node
        self._update_balance(D)
        return D

    def _rebalance(self, node: int) -> int:
        """Restore the AVL invariant at ``node`` (balance +2 or -2).

        Returns the id of the new subtree root; the caller must link it into
        the parent (or make it the tree root).
        """
        balance = self.balance
        if balance[node] == 2:
            if balance[self.left[node]] == -1:
                return self._rotate_LR(node)
            return self._rotate_L(node)
        if balance[self.right[node]] == 1:
            return self._rotate_RL(node)
        return self._rotate_R(node)

    def _make_node(self, key: T) -> int:
        """Allocate a leaf node holding ``key`` and return its id."""
        end = self.end
        if end >= len(self.key):
            self.key.append(key)
            self.left.append(0)
            self.right.append(0)
            self.balance.append(0)
        else:
            # Reusing a preallocated slot (from reserve() or a previous
            # clear()): reset the links so stale children cannot leak in.
            self.key[end] = key
            self.left[end] = 0
            self.right[end] = 0
            self.balance[end] = 0
        self.end = end + 1
        return end

    def _retrace_after_removal(self, path: list[int], di: int) -> None:
        """Update balances bottom-up after one subtree on ``path`` shrank.

        ``path`` lists the ancestors of the spliced-out position, root first,
        and is consumed. Bit 0 of ``di`` is 1 if ``path[-1]`` lost height on
        its left side (0 for the right side), bit 1 describes ``path[-2]``,
        and so on.
        """
        left, right, balance = self.left, self.right, self.balance
        while path:
            pnode = path.pop()
            balance[pnode] -= 1 if di & 1 else -1
            di >>= 1
            b = balance[pnode]
            if b == 1 or b == -1:
                return  # it was balanced: height unchanged, stop here
            if b == 0:
                continue  # this subtree got one level shorter; keep going
            new_node = self._rebalance(pnode)
            if not path:
                self.root = new_node
                return
            if di & 1:
                left[path[-1]] = new_node
            else:
                right[path[-1]] = new_node
            if balance[new_node]:
                return  # single rotation over a balanced child keeps height

    def add(self, key: T) -> bool:
        """Insert ``key``.

        Returns:
            True if ``key`` was inserted, False if it was already present.
        """
        if self.root == 0:
            self.root = self._make_node(key)
            self._len = 1
            return True
        left, right, balance, keys = self.left, self.right, self.balance, self.key
        node = self.root
        path = []
        # Direction bits: bit 0 is for path[-1], 1 = went left, 0 = right.
        di = 0
        while node:
            if key == keys[node]:
                return False
            di <<= 1
            path.append(node)
            if key < keys[node]:
                di |= 1
                node = left[node]
            else:
                node = right[node]
        self._len += 1
        if di & 1:
            left[path[-1]] = self._make_node(key)
        else:
            right[path[-1]] = self._make_node(key)
        # Retrace: the subtree on the insertion side grew by one level.
        while path:
            pnode = path.pop()
            balance[pnode] += 1 if di & 1 else -1
            di >>= 1
            b = balance[pnode]
            if b == 0:
                break  # heights evened out; ancestors are unaffected
            if b == 2 or b == -2:
                new_node = self._rebalance(pnode)
                if not path:
                    self.root = new_node
                elif di & 1:
                    left[path[-1]] = new_node
                else:
                    right[path[-1]] = new_node
                break  # after an insertion, one rotation restores the height
        return True

    def remove(self, key: T) -> bool:
        """Remove ``key``; raise ``KeyError`` if it is not present."""
        if self.discard(key):
            return True
        raise KeyError(key)

    def discard(self, key: T) -> bool:
        """Remove ``key`` if present.

        Returns:
            True if ``key`` was removed, False if it was not in the set.
        """
        left, right, keys = self.left, self.right, self.key
        di = 0
        path = []
        node = self.root
        while node:
            if key == keys[node]:
                break
            path.append(node)
            di <<= 1
            if key < keys[node]:
                di |= 1
                node = left[node]
            else:
                node = right[node]
        else:
            return False
        self._len -= 1
        if left[node] and right[node]:
            # Two children: copy in the in-order predecessor (max of the left
            # subtree) and splice out that node instead.
            path.append(node)
            di = (di << 1) | 1
            lmax = left[node]
            while right[lmax]:
                path.append(lmax)
                di <<= 1
                lmax = right[lmax]
            keys[node] = keys[lmax]
            node = lmax
        cnode = left[node] or right[node]  # at most one child remains
        if not path:
            self.root = cnode
            return True
        if di & 1:
            left[path[-1]] = cnode
        else:
            right[path[-1]] = cnode
        self._retrace_after_removal(path, di)
        return True

    def le(self, key: T) -> Optional[T]:
        """Return the largest element ``<= key``, or None if there is none."""
        keys, left, right = self.key, self.left, self.right
        res = None
        node = self.root
        while node:
            if key == keys[node]:
                res = key
                break
            if key < keys[node]:
                node = left[node]
            else:
                res = keys[node]
                node = right[node]
        return res

    def lt(self, key: T) -> Optional[T]:
        """Return the largest element ``< key``, or None if there is none."""
        keys, left, right = self.key, self.left, self.right
        res = None
        node = self.root
        while node:
            # Only `<` is required of T, so test "keys[node] < key" directly.
            if keys[node] < key:
                res = keys[node]
                node = right[node]
            else:
                node = left[node]
        return res

    def ge(self, key: T) -> Optional[T]:
        """Return the smallest element ``>= key``, or None if there is none."""
        keys, left, right = self.key, self.left, self.right
        res = None
        node = self.root
        while node:
            if key == keys[node]:
                res = key
                break
            if key < keys[node]:
                res = keys[node]
                node = left[node]
            else:
                node = right[node]
        return res

    def gt(self, key: T) -> Optional[T]:
        """Return the smallest element ``> key``, or None if there is none."""
        keys, left, right = self.key, self.left, self.right
        res = None
        node = self.root
        while node:
            if key < keys[node]:
                res = keys[node]
                node = left[node]
            else:
                node = right[node]
        return res

    def get_max(self) -> Optional[T]:
        """Return the largest element, or None if the set is empty."""
        if not self:
            return None
        right = self.right
        node = self.root
        while right[node]:
            node = right[node]
        return self.key[node]

    def get_min(self) -> Optional[T]:
        """Return the smallest element, or None if the set is empty."""
        if not self:
            return None
        left = self.left
        node = self.root
        while left[node]:
            node = left[node]
        return self.key[node]

    def pop_min(self) -> T:
        """Remove and return the smallest element.

        Raises:
            IndexError: if the set is empty.
        """
        if self.root == 0:
            raise IndexError("pop_min from an empty AVLTreeSet2")
        left, right = self.left, self.right
        path = []
        node = self.root
        while left[node]:
            path.append(node)
            node = left[node]
        self._len -= 1
        res = self.key[node]
        if not path:
            self.root = right[node]
            return res
        left[path[-1]] = right[node]
        # Every step went left, so every direction bit is 1.
        self._retrace_after_removal(path, (1 << len(path)) - 1)
        return res

    def pop_max(self) -> T:
        """Remove and return the largest element.

        Raises:
            IndexError: if the set is empty.
        """
        if self.root == 0:
            raise IndexError("pop_max from an empty AVLTreeSet2")
        left, right = self.left, self.right
        path = []
        node = self.root
        while right[node]:
            path.append(node)
            node = right[node]
        self._len -= 1
        res = self.key[node]
        if not path:
            self.root = left[node]
            return res
        right[path[-1]] = left[node]
        # Every step went right, so every direction bit is 0.
        self._retrace_after_removal(path, 0)
        return res

    def clear(self) -> None:
        """Remove all elements.

        Allocated node slots are kept and reused by later insertions
        (``_make_node`` resets their links); old keys stay referenced until
        their slot is overwritten.
        """
        self.root = 0
        self._len = 0
        self.end = 1

    def tolist(self) -> list[T]:
        """Return all elements in ascending order."""
        left, right, keys = self.left, self.right, self.key
        node = self.root
        stack, a = [], []
        while stack or node:
            if node:
                stack.append(node)
                node = left[node]
            else:
                node = stack.pop()
                a.append(keys[node])
                node = right[node]
        return a

    def __contains__(self, key: T) -> bool:
        keys, left, right = self.key, self.left, self.right
        node = self.root
        while node:
            if key == keys[node]:
                return True
            node = left[node] if key < keys[node] else right[node]
        return False

    def __iter__(self):
        """Iterate over the elements in ascending order.

        Every call returns an independent iterator, so nested loops over the
        same set work. Each successor is found with ``gt`` from the last
        yielded element, so the set may be modified while iterating.
        """
        # Keep the legacy shared cursor primed for direct ``next(s)`` use.
        self.it = self.get_min()
        return self._iter_from(self.it)

    def _iter_from(self, x: Optional[T]):
        # Yield x and then its successors until none is left.
        while x is not None:
            yield x
            x = self.gt(x)

    def __next__(self):
        """Legacy cursor: advance the shared cursor primed by ``iter(self)``.

        Prefer ``for x in s`` / ``iter(s)``, which are independent of it.
        """
        if self.it is None:
            raise StopIteration
        res = self.it
        self.it = self.gt(res)
        return res

    def __len__(self):
        return self._len

    def __bool__(self):
        return self.root != 0

    def __str__(self):
        return "{" + ", ".join(map(str, self.tolist())) + "}"

    def __repr__(self):
        return f"AVLTreeSet2({self})"