Source code for titan_pylib.data_structures.avl_tree.avl_tree_multiset2

  1from titan_pylib.my_class.supports_less_than import SupportsLessThan
  2from titan_pylib.data_structures.bst_base.bst_multiset_array_base import (
  3    BSTMultisetArrayBase,
  4)
  5from typing import Generic, Iterable, TypeVar, Optional
  6from array import array
  7
  8T = TypeVar("T", bound=SupportsLessThan)  # key type of the multiset; keys are compared with < and ==
  9
 10
class AVLTreeMultiset2(Generic[T]):
    """AVL tree used as a multiset.

    Nodes are stored in parallel arrays indexed by node id: ``key``, ``val``
    (number of copies of that key), ``left``, ``right`` and ``balance``.
    Node id 0 is a sentinel meaning "no node". Equal keys share one node
    whose ``val`` holds the multiplicity. Subtree sizes are not stored,
    which keeps the structure comparatively lightweight.

    ``balance[node]`` is ``height(left subtree) - height(right subtree)``.
    """

    def __init__(self, a: Iterable[T] = ()) -> None:
        """Create a multiset, optionally bulk-loading the elements of ``a``."""
        self.root = 0
        self._len = 0  # number of elements, counting multiplicity
        self.key = [0]
        self.val = [0]
        # array("I", [0]) is portable; bytes(4) assumed a 4-byte unsigned int.
        self.left = array("I", [0])
        self.right = array("I", [0])
        self.balance = array("b", [0])
        self.end = 1  # next unused node id
        if not isinstance(a, list):
            a = list(a)
        if a:
            self._build(a)

    def _make_node(self, key: T, val: int) -> int:
        """Allocate a fresh leaf holding ``key`` with multiplicity ``val``; return its id."""
        end = self.end
        if end >= len(self.key):
            self.key.append(key)
            self.val.append(val)
            self.left.append(0)
            self.right.append(0)
            self.balance.append(0)
        else:
            # Slot already exists (from reserve() or reused after clear());
            # reset its links so stale children/balance can't leak in.
            self.key[end] = key
            self.val[end] = val
            self.left[end] = 0
            self.right[end] = 0
            self.balance[end] = 0
        self.end = end + 1
        return end

    def reserve(self, n: int) -> None:
        """Append ``n`` zero-initialised node slots to every backing array."""
        zeros = [0] * n
        self.key += zeros
        self.val += zeros
        self.left += array("I", [0]) * n
        self.right += array("I", [0]) * n
        self.balance += array("b", [0]) * n

    def _build(self, a: list[T]) -> None:
        """Bulk-load ``a`` into a perfectly balanced tree (called on an empty tree)."""
        left, right, balance = self.left, self.right, self.balance

        def build(lo: int, hi: int) -> tuple[int, int]:
            # Turn node ids [lo, hi) into a balanced subtree; return (root, height).
            mid = (lo + hi) >> 1
            hl = hr = 0
            left[mid] = right[mid] = 0  # leaves must not inherit stale links
            if lo != mid:
                left[mid], hl = build(lo, mid)
            if mid + 1 != hi:
                right[mid], hr = build(mid + 1, hi)
            balance[mid] = hl - hr
            return mid, max(hl, hr) + 1

        self._len = len(a)
        if not all(a[i] < a[i + 1] for i in range(len(a) - 1)):
            a = sorted(a)
        # Run-length encode the sorted input into (distinct keys, counts).
        x, y = BSTMultisetArrayBase[AVLTreeMultiset2, T]._rle(a)
        n = len(x)
        end = self.end
        self.end += n
        self.reserve(n)
        self.key[end : end + n] = x
        self.val[end : end + n] = y
        self.root = build(end, end + n)[0]

    def _rotate_L(self, node: int) -> int:
        """Right-rotate a left-heavy ``node`` (balance 2); return the new subtree root."""
        left, right, balance = self.left, self.right, self.balance
        u = left[node]
        left[node] = right[u]
        right[u] = node
        if balance[u] == 1:
            balance[u] = 0
            balance[node] = 0
        else:
            # Left child was balanced: subtree height is unchanged by the rotation.
            balance[u] = -1
            balance[node] = 1
        return u

    def _rotate_R(self, node: int) -> int:
        """Left-rotate a right-heavy ``node`` (balance -2); return the new subtree root."""
        left, right, balance = self.left, self.right, self.balance
        u = right[node]
        right[node] = left[u]
        left[u] = node
        if balance[u] == -1:
            balance[u] = 0
            balance[node] = 0
        else:
            # Right child was balanced: subtree height is unchanged by the rotation.
            balance[u] = 1
            balance[node] = -1
        return u

    def _update_balance(self, node: int) -> None:
        """Fix balance factors of ``node`` and its two children after a double rotation.

        ``balance[node]`` still holds the pivot's pre-rotation balance when called.
        """
        left, right, balance = self.left, self.right, self.balance
        if balance[node] == 1:
            balance[right[node]] = -1
            balance[left[node]] = 0
        elif balance[node] == -1:
            balance[right[node]] = 0
            balance[left[node]] = 1
        else:
            balance[right[node]] = 0
            balance[left[node]] = 0
        balance[node] = 0

    def _rotate_LR(self, node: int) -> int:
        """Left-right double rotation for a left-heavy ``node`` whose left child is right-heavy."""
        left, right = self.left, self.right
        B = left[node]
        E = right[B]
        right[B] = left[E]
        left[E] = B
        left[node] = right[E]
        right[E] = node
        self._update_balance(E)
        return E

    def _rotate_RL(self, node: int) -> int:
        """Right-left double rotation for a right-heavy ``node`` whose right child is left-heavy."""
        left, right = self.left, self.right
        C = right[node]
        D = left[C]
        left[C] = right[D]
        right[D] = C
        right[node] = left[D]
        left[D] = node
        self._update_balance(D)
        return D

    def _discard(self, node: int, path: list[int], di: int) -> bool:
        """Unlink ``node`` from the tree and restore the AVL invariant.

        ``path`` lists the ancestors of ``node`` from the root down. Bit ``k``
        of ``di`` (from the least significant end) is 1 when the step from
        ``path[-1 - k]`` to its child went left, 0 when it went right.
        Always returns True.
        """
        left, right, keys, vals, balance = (
            self.left,
            self.right,
            self.key,
            self.val,
            self.balance,
        )
        if left[node] and right[node]:
            # Two children: move the in-order predecessor (max of the left
            # subtree) into ``node`` and physically remove the predecessor.
            path.append(node)
            di = (di << 1) | 1
            lmax = left[node]
            while right[lmax]:
                path.append(lmax)
                di <<= 1
                lmax = right[lmax]
            keys[node] = keys[lmax]
            vals[node] = vals[lmax]
            node = lmax
        # ``node`` now has at most one child; splice that child into its place.
        cnode = left[node] or right[node]
        if not path:
            self.root = cnode
            return True
        if di & 1:
            left[path[-1]] = cnode
        else:
            right[path[-1]] = cnode
        # Walk back toward the root; a subtree got one level shorter.
        while path:
            pnode = path.pop()
            balance[pnode] -= 1 if di & 1 else -1
            di >>= 1
            b = balance[pnode]
            if b == 2:
                new_node = (
                    self._rotate_LR(pnode)
                    if balance[left[pnode]] < 0
                    else self._rotate_L(pnode)
                )
            elif b == -2:
                new_node = (
                    self._rotate_RL(pnode)
                    if balance[right[pnode]] > 0
                    else self._rotate_R(pnode)
                )
            elif b != 0:
                # Went from 0 to +-1: this subtree's height did not change.
                break
            else:
                # Went from +-1 to 0: height shrank, keep propagating upward.
                continue
            if not path:
                self.root = new_node
                return True
            if di & 1:
                left[path[-1]] = new_node
            else:
                right[path[-1]] = new_node
            if balance[new_node] != 0:
                # Single rotation over a balanced child keeps the height.
                break
        return True

    def discard(self, key: T, val: int = 1) -> bool:
        """Remove up to ``val`` copies of ``key``.

        Returns False if ``key`` is not present, True otherwise. When ``val``
        is at least the current multiplicity, the node is removed entirely,
        so no zero-count node is ever left in the tree.
        """
        keys, vals, left, right = self.key, self.val, self.left, self.right
        path: list[int] = []
        di = 0
        node = self.root
        while node:
            if key == keys[node]:
                break
            path.append(node)
            di <<= 1
            if key < keys[node]:
                di |= 1
                node = left[node]
            else:
                node = right[node]
        else:
            return False
        cnt = vals[node]
        if val >= cnt:
            self._len -= cnt
            self._discard(node, path, di)
        else:
            self._len -= val
            vals[node] = cnt - val
        return True

    def discard_all(self, key: T) -> None:
        """Remove every copy of ``key`` (no-op if absent)."""
        self.discard(key, self.count(key))

    def remove(self, key: T, val: int = 1) -> None:
        """Like :meth:`discard`, but raise ``KeyError`` if ``key`` is absent."""
        if self.discard(key, val):
            return
        raise KeyError(key)

    def add(self, key: T, val: int = 1) -> None:
        """Insert ``val`` copies of ``key``."""
        self._len += val
        if self.root == 0:
            self.root = self._make_node(key, val)
            return
        left, right, keys, balance = self.left, self.right, self.key, self.balance
        node = self.root
        di = 0  # bit per step, 1 = went left (most recent step in the low bit)
        path: list[int] = []
        while node:
            if key == keys[node]:
                self.val[node] += val
                return
            path.append(node)
            di <<= 1
            if key < keys[node]:
                di |= 1
                node = left[node]
            else:
                node = right[node]
        leaf = self._make_node(key, val)
        if di & 1:
            left[path[-1]] = leaf
        else:
            right[path[-1]] = leaf
        # Walk back up; at most one (single or double) rotation is needed.
        new_node = 0
        while path:
            node = path.pop()
            balance[node] += 1 if di & 1 else -1
            di >>= 1
            b = balance[node]
            if b == 0:
                break
            if b == 2:
                new_node = (
                    self._rotate_LR(node)
                    if balance[left[node]] < 0
                    else self._rotate_L(node)
                )
                break
            if b == -2:
                new_node = (
                    self._rotate_RL(node)
                    if balance[right[node]] > 0
                    else self._rotate_R(node)
                )
                break
        if new_node:
            if path:
                if di & 1:
                    left[path[-1]] = new_node
                else:
                    right[path[-1]] = new_node
            else:
                self.root = new_node

    def count(self, key: T) -> int:
        """Delegates to ``BSTMultisetArrayBase.count`` (presumably the multiplicity of ``key``)."""
        return BSTMultisetArrayBase[AVLTreeMultiset2, T].count(self, key)

    def le(self, key: T) -> Optional[T]:
        """Delegates to ``BSTMultisetArrayBase.le`` (presumably the largest key <= ``key``, or None)."""
        return BSTMultisetArrayBase[AVLTreeMultiset2, T].le(self, key)

    def lt(self, key: T) -> Optional[T]:
        """Delegates to ``BSTMultisetArrayBase.lt`` (presumably the largest key < ``key``, or None)."""
        return BSTMultisetArrayBase[AVLTreeMultiset2, T].lt(self, key)

    def ge(self, key: T) -> Optional[T]:
        """Delegates to ``BSTMultisetArrayBase.ge`` (presumably the smallest key >= ``key``, or None)."""
        return BSTMultisetArrayBase[AVLTreeMultiset2, T].ge(self, key)

    def gt(self, key: T) -> Optional[T]:
        """Delegates to ``BSTMultisetArrayBase.gt`` (presumably the smallest key > ``key``, or None)."""
        return BSTMultisetArrayBase[AVLTreeMultiset2, T].gt(self, key)

    def get_min(self) -> Optional[T]:
        """Return the smallest key, or None if the multiset is empty."""
        if self.root == 0:
            return None
        left = self.left
        node = self.root
        while left[node]:
            node = left[node]
        return self.key[node]

    def get_max(self) -> Optional[T]:
        """Return the largest key, or None if the multiset is empty."""
        if self.root == 0:
            return None
        right = self.right
        node = self.root
        while right[node]:
            node = right[node]
        return self.key[node]

    def pop_min(self) -> T:
        """Remove one copy of the smallest key and return it.

        Raises:
            IndexError: if the multiset is empty (previously this corrupted
                the sentinel node and returned a garbage value).
        """
        if self.root == 0:
            raise IndexError("pop from an empty multiset")
        left, vals, keys = self.left, self.val, self.key
        self._len -= 1
        node = self.root
        path: list[int] = []
        while left[node]:
            path.append(node)
            node = left[node]
        x = keys[node]
        if vals[node] == 1:
            # Every step on the path went left, so all direction bits are 1.
            self._discard(node, path, (1 << len(path)) - 1)
        else:
            vals[node] -= 1
        return x

    def pop_max(self) -> T:
        """Remove one copy of the largest key and return it.

        Raises:
            IndexError: if the multiset is empty (previously this corrupted
                the sentinel node and returned a garbage value).
        """
        if self.root == 0:
            raise IndexError("pop from an empty multiset")
        right, vals, keys = self.right, self.val, self.key
        self._len -= 1
        node = self.root
        path: list[int] = []
        while right[node]:
            path.append(node)
            node = right[node]
        x = keys[node]
        if vals[node] == 1:
            # Every step on the path went right, so all direction bits are 0.
            self._discard(node, path, 0)
        else:
            vals[node] -= 1
        return x

    def clear(self) -> None:
        """Remove all elements.

        Backing arrays keep their capacity; slots are reused by ``_make_node``,
        which resets each reused slot's links.
        """
        self.root = 0
        self._len = 0
        self.end = 1

    def tolist(self) -> list[T]:
        """Delegates to ``BSTMultisetArrayBase.tolist``."""
        return BSTMultisetArrayBase[AVLTreeMultiset2, T].tolist(self)

    def tolist_items(self) -> list[tuple[T, int]]:
        """Return ``(key, multiplicity)`` pairs in ascending key order (iterative in-order walk)."""
        left, right, keys, vals = self.left, self.right, self.key, self.val
        node = self.root
        stack: list[int] = []
        a: list[tuple[T, int]] = []
        while stack or node:
            if node:
                stack.append(node)
                node = left[node]
            else:
                node = stack.pop()
                a.append((keys[node], vals[node]))
                node = right[node]
        return a

    def __contains__(self, key: T) -> bool:
        return BSTMultisetArrayBase[AVLTreeMultiset2, T].contains(self, key)

    def __len__(self) -> int:
        return self._len

    def __bool__(self) -> bool:
        return self.root != 0

    def __str__(self) -> str:
        return "{" + ", ".join(map(str, self.tolist())) + "}"

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.tolist()})"