Source code for titan_pylib.data_structures.avl_tree.avl_tree_set3

  1from titan_pylib.my_class.ordered_set_interface import OrderedSetInterface
  2from titan_pylib.my_class.supports_less_than import SupportsLessThan
  3from typing import Generic, Iterable, TypeVar, Optional, Sequence
  4
# Element type stored in the set. Keys are ordered with ``<`` and identified
# with ``==`` (both are used by add/discard/searches below).
T = TypeVar("T", bound=SupportsLessThan)
  6
  7
class AVLTreeSet3(OrderedSetInterface, Generic[T]):
    """Ordered set implemented as an AVL tree.

    Every node stores the size of its subtree, which makes rank queries
    (``s[k]``, ``index``, ``index_right``, ``pop(k)``) O(log n).
    Nodes are instances of the nested ``AVLTreeSet3.Node`` class.

    Balance convention used throughout: ``node.balance`` is
    ``height(left) - height(right)``.
    """

    class Node:
        """One tree node: a key plus the bookkeeping for balancing and ranking."""

        def __init__(self, key: T):
            self.key: T = key
            # Number of nodes in the subtree rooted here, including this one.
            self.size: int = 1
            self.left: Optional["AVLTreeSet3.Node"] = None
            self.right: Optional["AVLTreeSet3.Node"] = None
            # height(left) - height(right).
            self.balance: int = 0

        def __str__(self):
            if self.left is None and self.right is None:
                return f"key:{self.key, self.size}\n"
            return (
                f"key:{self.key, self.size},\n left:{self.left},\n right:{self.right}\n"
            )

    def __init__(self, a: Iterable[T] = ()) -> None:
        """Create the set from the elements of ``a`` (duplicates are dropped).

        Args:
            a: initial elements. A strictly increasing sequence is used as-is;
               anything else is sorted and de-duplicated first.
        """
        self.node = None
        if not isinstance(a, Sequence):
            a = list(a)
        if a:
            self._build(a)

    def _build(self, a: Sequence[T]) -> None:
        """Build a perfectly balanced tree from ``a`` in O(n) (after sorting)."""
        Node = AVLTreeSet3.Node

        def rec(l: int, r: int) -> tuple[AVLTreeSet3.Node, int]:
            # Build the subtree for a[l:r]; return (root, height).
            mid = (l + r) >> 1
            node = Node(a[mid])
            hl, hr = 0, 0
            if l != mid:
                node.left, hl = rec(l, mid)
                node.size += node.left.size
            if mid + 1 != r:
                node.right, hr = rec(mid + 1, r)
                node.size += node.right.size
            node.balance = hl - hr
            return node, max(hl, hr) + 1

        if not all(a[i] < a[i + 1] for i in range(len(a) - 1)):
            a = sorted(set(a))
        self.node = rec(0, len(a))[0]

    def _rotate_L(self, node: Node) -> Node:
        """Single rotation fixing a left-heavy ``node`` (its left child rises).

        Callers only use this when ``node.left.balance`` is 0 or 1.
        """
        u = node.left
        u.size = node.size
        node.size -= 1 if u.left is None else u.left.size + 1
        node.left = u.right
        u.right = node
        if u.balance == 1:
            u.balance = 0
            node.balance = 0
        else:
            # u.balance == 0 (only reachable from discard): height unchanged.
            u.balance = -1
            node.balance = 1
        return u

    def _rotate_R(self, node: Node) -> Node:
        """Single rotation fixing a right-heavy ``node`` (mirror of ``_rotate_L``)."""
        u = node.right
        u.size = node.size
        node.size -= 1 if u.right is None else u.right.size + 1
        node.right = u.left
        u.left = node
        if u.balance == -1:
            u.balance = 0
            node.balance = 0
        else:
            u.balance = 1
            node.balance = -1
        return u

    def _update_balance(self, node: Node) -> None:
        """Fix balances after a double rotation whose new subtree root is ``node``.

        ``node.balance`` still holds its pre-rotation value when this is called.
        """
        if node.balance == 1:
            node.right.balance = -1
            node.left.balance = 0
        elif node.balance == -1:
            node.right.balance = 0
            node.left.balance = 1
        else:
            node.right.balance = 0
            node.left.balance = 0
        node.balance = 0

    def _rotate_LR(self, node: Node) -> Node:
        """Double rotation for a left-heavy ``node`` whose left child is right-heavy."""
        B = node.left
        E = B.right
        E.size = node.size
        # node keeps E.right; B keeps E.left. Adjust sizes accordingly.
        if E.right is None:
            node.size -= B.size
            B.size -= 1
        else:
            node.size -= B.size - E.right.size
            B.size -= E.right.size + 1
        B.right = E.left
        E.left = B
        node.left = E.right
        E.right = node
        self._update_balance(E)
        return E

    def _rotate_RL(self, node: Node) -> Node:
        """Double rotation for a right-heavy ``node`` whose right child is left-heavy."""
        C = node.right
        D = C.left
        D.size = node.size
        # node keeps D.left; C keeps D.right. Adjust sizes accordingly.
        if D.left is None:
            node.size -= C.size
            C.size -= 1
        else:
            node.size -= C.size - D.left.size
            C.size -= D.left.size + 1
        C.left = D.right
        D.right = C
        node.right = D.left
        D.left = node
        self._update_balance(D)
        return D

    def _kth_elm(self, k: int) -> T:
        """Return the k-th smallest key (0-indexed; negative ``k`` counts from the end).

        Raises:
            IndexError: if the set is empty or ``k`` is out of range.
        """
        n = len(self)
        if k < 0:
            k += n
        if not 0 <= k < n:
            raise IndexError(f"{self.__class__.__name__} index out of range")
        node = self.node
        while True:
            t = 0 if node.left is None else node.left.size
            if t == k:
                return node.key
            elif t < k:
                k -= t + 1
                node = node.right
            else:
                node = node.left

    def add(self, key: T) -> bool:
        """Insert ``key``. Return True if it was inserted, False if already present."""
        if self.node is None:
            self.node = AVLTreeSet3.Node(key)
            return True
        pnode = self.node
        path = []
        # Bit i of ``di`` (LSB = deepest) records the direction taken at path[i]:
        # 1 = went left, 0 = went right.
        di = 0
        while pnode is not None:
            if key == pnode.key:
                return False
            elif key < pnode.key:
                path.append(pnode)
                di <<= 1
                di |= 1
                pnode = pnode.left
            else:
                path.append(pnode)
                di <<= 1
                pnode = pnode.right
        if di & 1:
            path[-1].left = AVLTreeSet3.Node(key)
        else:
            path[-1].right = AVLTreeSet3.Node(key)
        new_node = None
        # Walk back up while the subtree height keeps growing.
        while path:
            pnode = path.pop()
            pnode.size += 1
            pnode.balance += 1 if di & 1 else -1
            di >>= 1
            if pnode.balance == 0:
                # Height unchanged: no further rebalancing needed.
                break
            if pnode.balance == 2:
                new_node = (
                    self._rotate_LR(pnode)
                    if pnode.left.balance == -1
                    else self._rotate_L(pnode)
                )
                break
            elif pnode.balance == -2:
                new_node = (
                    self._rotate_RL(pnode)
                    if pnode.right.balance == 1
                    else self._rotate_R(pnode)
                )
                break
        if new_node is not None:
            # Re-attach the rotated subtree to its parent (or make it the root).
            if path:
                gnode = path.pop()
                gnode.size += 1
                if di & 1:
                    gnode.left = new_node
                else:
                    gnode.right = new_node
            else:
                self.node = new_node
        # Ancestors above the stopping point only need their sizes bumped.
        for p in path:
            p.size += 1
        return True

    def discard(self, key: T) -> bool:
        """Remove ``key`` if present. Return True if it was removed, else False."""
        # Same direction encoding as in ``add``: bit 1 = went left.
        di = 0
        path = []
        node = self.node
        while node is not None:
            if key == node.key:
                break
            elif key < node.key:
                path.append(node)
                di <<= 1
                di |= 1
                node = node.left
            else:
                path.append(node)
                di <<= 1
                node = node.right
        else:
            return False
        if node.left is not None and node.right is not None:
            # Two children: overwrite with the in-order predecessor's key and
            # delete the predecessor node (which has no right child) instead.
            path.append(node)
            di <<= 1
            di |= 1
            lmax = node.left
            while lmax.right is not None:
                path.append(lmax)
                di <<= 1
                lmax = lmax.right
            node.key = lmax.key
            node = lmax
        cnode = node.right if node.left is None else node.left
        if path:
            if di & 1:
                path[-1].left = cnode
            else:
                path[-1].right = cnode
        else:
            self.node = cnode
            return True
        # Walk back up while the subtree height keeps shrinking.
        while path:
            new_node = None
            pnode = path.pop()
            pnode.balance -= 1 if di & 1 else -1
            di >>= 1
            pnode.size -= 1
            if pnode.balance == 2:
                new_node = (
                    self._rotate_LR(pnode)
                    if pnode.left.balance == -1
                    else self._rotate_L(pnode)
                )
            elif pnode.balance == -2:
                new_node = (
                    self._rotate_RL(pnode)
                    if pnode.right.balance == 1
                    else self._rotate_R(pnode)
                )
            elif pnode.balance != 0:
                # Became +/-1 from 0: height unchanged, stop rebalancing.
                break
            if new_node is not None:
                if not path:
                    self.node = new_node
                    return True
                if di & 1:
                    path[-1].left = new_node
                else:
                    path[-1].right = new_node
                if new_node.balance != 0:
                    # Rotation preserved the subtree height.
                    break
        for p in path:
            p.size -= 1
        return True

    def remove(self, key: T) -> None:
        """Remove ``key``; raise KeyError if it is not present."""
        if self.discard(key):
            return
        raise KeyError(key)

    def le(self, key: T) -> Optional[T]:
        """Return the largest element <= ``key``, or None if there is none."""
        res = None
        node = self.node
        while node is not None:
            if key == node.key:
                res = key
                break
            elif key < node.key:
                node = node.left
            else:
                res = node.key
                node = node.right
        return res

    def lt(self, key: T) -> Optional[T]:
        """Return the largest element < ``key``, or None if there is none."""
        res = None
        node = self.node
        while node is not None:
            if key <= node.key:
                node = node.left
            else:
                res = node.key
                node = node.right
        return res

    def ge(self, key: T) -> Optional[T]:
        """Return the smallest element >= ``key``, or None if there is none."""
        res = None
        node = self.node
        while node is not None:
            if key == node.key:
                res = key
                break
            elif key < node.key:
                res = node.key
                node = node.left
            else:
                node = node.right
        return res

    def gt(self, key: T) -> Optional[T]:
        """Return the smallest element > ``key``, or None if there is none."""
        res = None
        node = self.node
        while node is not None:
            if key < node.key:
                res = node.key
                node = node.left
            else:
                node = node.right
        return res

    def index(self, key: T) -> int:
        """Return the number of elements strictly less than ``key``."""
        k = 0
        node = self.node
        while node is not None:
            if key == node.key:
                k += 0 if node.left is None else node.left.size
                break
            elif key < node.key:
                node = node.left
            else:
                k += 1 if node.left is None else node.left.size + 1
                node = node.right
        return k

    def index_right(self, key: T) -> int:
        """Return the number of elements less than or equal to ``key``."""
        k = 0
        node = self.node
        while node is not None:
            if key == node.key:
                k += 1 if node.left is None else node.left.size + 1
                break
            elif key < node.key:
                node = node.left
            else:
                k += 1 if node.left is None else node.left.size + 1
                node = node.right
        return k

    def pop(self, k: int = -1) -> T:
        """Remove and return the k-th smallest element (default: the largest).

        Raises:
            IndexError: if the set is empty or ``k`` is out of range.
        """
        if self.node is None:
            raise IndexError(
                f"{self.__class__.__name__}.pop({k}), pop({k}) from Empty {self.__class__.__name__}"
            )
        x = self._kth_elm(k)
        self.discard(x)
        return x

    def pop_max(self) -> T:
        """Remove and return the largest element; IndexError if empty."""
        if self.node is None:
            raise IndexError(
                f"{self.__class__.__name__}.pop_max(), pop_max from Empty {self.__class__.__name__}"
            )
        return self.pop()

    def pop_min(self) -> T:
        """Remove and return the smallest element; IndexError if empty."""
        if self.node is None:
            raise IndexError(
                f"{self.__class__.__name__}.pop_min(), pop_min from Empty {self.__class__.__name__}"
            )
        return self.pop(0)

    def get_max(self) -> Optional[T]:
        """Return the largest element, or None if the set is empty."""
        node = self.node
        if node is None:
            return None
        while node.right is not None:
            node = node.right
        return node.key

    def get_min(self) -> Optional[T]:
        """Return the smallest element, or None if the set is empty."""
        node = self.node
        if node is None:
            return None
        while node.left is not None:
            node = node.left
        return node.key

    def clear(self) -> None:
        """Remove all elements."""
        self.node = None

    def tolist(self) -> list[T]:
        """Return all elements in ascending order as a new list."""
        return list(self._iter_keys(False))

    def _iter_keys(self, reverse: bool):
        """Yield keys in ascending order (descending if ``reverse``), O(n) total.

        Uses an explicit stack instead of recursion.
        """
        stack = []
        node = self.node
        while stack or node is not None:
            while node is not None:
                stack.append(node)
                node = node.right if reverse else node.left
            node = stack.pop()
            yield node.key
            node = node.left if reverse else node.right

    def __contains__(self, key: T) -> bool:
        node = self.node
        while node is not None:
            if key == node.key:
                return True
            elif key < node.key:
                node = node.left
            else:
                node = node.right
        return False

    def __getitem__(self, k: int) -> T:
        """Return the k-th smallest element; negative ``k`` counts from the end."""
        if not -len(self) <= k < len(self):
            raise IndexError(
                f"{self.__class__.__name__}.__getitem__({k}), len={len(self)}"
            )
        return self._kth_elm(k)

    def __iter__(self):
        """Return a fresh ascending iterator over the elements.

        Each call yields an independent generator, so nested loops over the
        same set do not interfere with each other.
        """
        # Also reset the index cursor used by ``__next__`` so that driving the
        # set directly via ``iter(s); next(s)`` keeps working.
        self.__iter = 0
        return self._iter_keys(False)

    def __next__(self):
        # Index-based cursor kept for interface compatibility; O(log n) per step.
        if self.__iter == self.__len__():
            raise StopIteration
        res = self.__getitem__(self.__iter)
        self.__iter += 1
        return res

    def __reversed__(self):
        """Iterate over the elements in descending order, O(n) total."""
        return self._iter_keys(True)

    def __len__(self):
        return 0 if self.node is None else self.node.size

    def __bool__(self):
        return self.node is not None

    def __str__(self):
        return "{" + ", ".join(map(str, self.tolist())) + "}"

    def __repr__(self):
        return f"AVLTreeSet3({str(self)})"