Source code for titan_pylib.data_structures.scapegoat_tree.scapegoat_tree_multiset

  1from titan_pylib.my_class.ordered_multiset_interface import OrderedMultisetInterface
  2from titan_pylib.my_class.supports_less_than import SupportsLessThan
  3from titan_pylib.data_structures.bst_base.bst_multiset_node_base import (
  4    BSTMultisetNodeBase,
  5)
  6import math
  7from typing import Final, TypeVar, Generic, Iterable, Optional, Iterator
  8
  9T = TypeVar("T", bound=SupportsLessThan)
 10
 11
class ScapegoatTreeMultiset(OrderedMultisetInterface, Generic[T]):
    """Ordered multiset stored in a scapegoat tree.

    Every node holds one distinct key together with its multiplicity
    (``val``).  Each node also caches two subtree aggregates:

    * ``size``    -- number of nodes (distinct keys) in the subtree,
    * ``valsize`` -- sum of multiplicities in the subtree.

    ``len(self)`` is the number of elements counted with multiplicity,
    ``len_elm()`` is the number of distinct keys.
    """

    # Weight-balance factor used when searching for the subtree to rebuild.
    ALPHA: Final[float] = 0.75
    # log2(1 / ALPHA); multiplied by the insertion depth in ``add``.
    BETA: Final[float] = math.log2(1 / ALPHA)

    class Node:
        """One distinct key plus the cached aggregates of its subtree."""

        def __init__(self, key: T, val: int):
            self.key: T = key
            # Multiplicity of ``key``.
            self.val: int = val
            # Number of nodes in the subtree rooted here (a fresh node is a leaf).
            self.size: int = 1
            # Sum of ``val`` over the subtree rooted here.
            self.valsize: int = val
            self.left: Optional[ScapegoatTreeMultiset.Node] = None
            self.right: Optional[ScapegoatTreeMultiset.Node] = None

        def __str__(self):
            if self.left is None and self.right is None:
                return f"key:{self.key, self.val, self.size, self.valsize}\n"
            return f"key:{self.key, self.val, self.size, self.valsize},\n left:{self.left},\n right:{self.right}\n"

    def __init__(self, a: Iterable[T] = ()):
        """Build the multiset from the elements of ``a`` (empty by default)."""
        self.root = None
        if not isinstance(a, list):
            a = list(a)
        self._build(a)

    def _build(self, a: list[T]) -> None:
        """Replace the tree with a balanced one holding the elements of ``a``.

        ``a`` is sorted first unless it is already in non-decreasing order.
        Leaves ``self.root`` untouched when ``a`` is empty.
        """
        Node = ScapegoatTreeMultiset.Node

        def rec(lo: int, hi: int) -> ScapegoatTreeMultiset.Node:
            # Build a subtree from keys[lo:hi], using the middle entry as root.
            mid = (lo + hi) >> 1
            node = Node(keys[mid], counts[mid])
            if lo != mid:
                node.left = rec(lo, mid)
                node.size += node.left.size
                node.valsize += node.left.valsize
            if mid + 1 != hi:
                node.right = rec(mid + 1, hi)
                node.size += node.right.size
                node.valsize += node.right.valsize
            return node

        # Only ``<`` is used here because T is bound to SupportsLessThan.
        if any(a[i + 1] < a[i] for i in range(len(a) - 1)):
            a = sorted(a)
        if not a:
            return
        # presumably _rle run-length-encodes the sorted list into parallel
        # lists of distinct keys and their counts -- verify against
        # BSTMultisetNodeBase.
        keys, counts = BSTMultisetNodeBase[T, ScapegoatTreeMultiset.Node]._rle(a)
        self.root = rec(0, len(keys))

    def _rebuild(self, node: Node) -> Node:
        """Rebuild the subtree rooted at ``node`` into a balanced one.

        The existing node objects are reused; their child links, ``size`` and
        ``valsize`` are recomputed.  Returns the new subtree root.
        """

        def rec(lo: int, hi: int) -> ScapegoatTreeMultiset.Node:
            mid = (lo + hi) >> 1
            cur = nodes[mid]
            cur.size = 1
            cur.valsize = cur.val
            if lo != mid:
                cur.left = rec(lo, mid)
                cur.size += cur.left.size
                cur.valsize += cur.left.valsize
            else:
                cur.left = None
            if mid + 1 != hi:
                cur.right = rec(mid + 1, hi)
                cur.size += cur.right.size
                cur.valsize += cur.right.valsize
            else:
                cur.right = None
            return cur

        # Iterative in-order traversal collects the nodes in key order.
        nodes = []
        stack = []
        while stack or node:
            if node:
                stack.append(node)
                node = node.left
            else:
                node = stack.pop()
                nodes.append(node)
                node = node.right
        return rec(0, len(nodes))

    def _kth_elm(self, k: int) -> tuple[T, int]:
        """Return ``(key, count)`` of the k-th element counted with multiplicity.

        Negative ``k`` counts from the end.  Returns ``None`` implicitly when
        ``k`` is out of range; callers are expected to validate ``k`` first.
        """
        if k < 0:
            k += len(self)
        node = self.root
        while node:
            # t = number of elements in the left subtree plus this node's copies.
            t = (node.val + node.left.valsize) if node.left else node.val
            if t - node.val <= k and k < t:
                return node.key, node.val
            elif t > k:
                node = node.left
            else:
                node = node.right
                k -= t

    def _kth_elm_tree(self, k: int) -> tuple[T, int]:
        """Return ``(key, count)`` of the k-th distinct key.

        Negative ``k`` counts from the end.  Fails an assertion when ``k`` is
        out of range.
        """
        if k < 0:
            k += self.len_elm()
        node = self.root
        while node:
            t = node.left.size if node.left else 0
            if t == k:
                return node.key, node.val
            if t > k:
                node = node.left
            else:
                node = node.right
                k -= t + 1
        assert False, "IndexError"

    def add(self, key: T, val: int = 1) -> None:
        """Insert ``val`` copies of ``key``; a non-positive ``val`` is a no-op."""
        if val <= 0:
            return
        if not self.root:
            self.root = ScapegoatTreeMultiset.Node(key, val)
            return
        node = self.root
        path = []
        while node:
            path.append(node)
            if key == node.key:
                # Key already present: only the multiplicities change.
                node.val += val
                for p in path:
                    p.valsize += val
                return
            node = node.left if key < node.key else node.right
        if key < path[-1].key:
            path[-1].left = ScapegoatTreeMultiset.Node(key, val)
        else:
            path[-1].right = ScapegoatTreeMultiset.Node(key, val)
        # NOTE(review): BETA is a base-2 logarithm while math.log is the
        # natural logarithm; since ln(n) <= log2(n) this only makes the
        # rebuild trigger earlier -- confirm whether math.log2 was intended.
        if len(path) * ScapegoatTreeMultiset.BETA > math.log(self.len_elm()):
            # Walk up from the new leaf until an ancestor is found whose child
            # on the path is too heavy; ancestor sizes are not yet updated,
            # hence the ``+ 1``.
            node_size = 1
            while path:
                pnode = path.pop()
                pnode_size = pnode.size + 1
                if ScapegoatTreeMultiset.ALPHA * pnode_size < node_size:
                    break
                node_size = pnode_size
            # _rebuild recomputes size/valsize for everything below ``pnode``,
            # including the freshly attached leaf.
            new_node = self._rebuild(pnode)
            if not path:
                self.root = new_node
                return
            if new_node.key < path[-1].key:
                path[-1].left = new_node
            else:
                path[-1].right = new_node
        # Remaining ancestors (above any rebuilt subtree) gain one node.
        for p in path:
            p.size += 1
            p.valsize += val

    def _discard(self, key: T) -> bool:
        """Unlink the node holding ``key`` from the tree.

        The aggregate updates assume the node's multiplicity is exactly 1 at
        the time of the call (``discard`` guarantees this).  Returns ``False``
        when ``key`` is not present.
        """
        path = []
        node = self.root
        di, cnt = 1, 0
        while node:
            if key == node.key:
                break
            path.append(node)
            di = key < node.key
            node = node.left if di else node.right
        if node is None:
            return False
        if node.left and node.right:
            # Two children: move the in-order predecessor (max of the left
            # subtree) into this node and unlink the predecessor instead.
            path.append(node)
            lmax = node.left
            di = 0 if lmax.right else 1
            while lmax.right:
                cnt += 1
                path.append(lmax)
                lmax = lmax.right
            lmax_val = lmax.val
            node.key = lmax.key
            node.val = lmax_val
            node = lmax
        cnode = node.left if node.left else node.right
        if path:
            if di == 1:
                path[-1].left = cnode
            else:
                path[-1].right = cnode
        else:
            self.root = cnode
            return True
        # Nodes strictly between the target and its predecessor lose the
        # predecessor's whole multiplicity.
        for _ in range(cnt):
            p = path.pop()
            p.size -= 1
            p.valsize -= lmax_val
        # The target's ancestors (and the target itself when it was refilled
        # with the predecessor) lose one node and one element.
        for p in path:
            p.size -= 1
            p.valsize -= 1
        return True

    def discard(self, key: T, val: int = 1) -> bool:
        """Remove up to ``val`` copies of ``key``.

        Returns ``False`` when ``key`` is absent, ``True`` otherwise (including
        the no-op case ``val <= 0``).
        """
        if val <= 0:
            return True
        path = []
        node = self.root
        while node:
            path.append(node)
            if key == node.key:
                break
            node = node.left if key < node.key else node.right
        else:
            return False
        if val >= node.val:
            # Every copy goes away: lower the multiplicity to 1, then unlink
            # the node so no zero-count node is left behind.
            val = node.val - 1
            if val > 0:
                node.val -= val
                for p in path:
                    p.valsize -= val
            self._discard(key)
        else:
            node.val -= val
            for p in path:
                p.valsize -= val
        return True

    def remove(self, key: T, val: int = 1) -> None:
        """Remove ``val`` copies of ``key``.

        Raises:
            KeyError: if fewer than ``val`` copies of ``key`` are stored.
        """
        c = self.count(key)
        if c < val:
            raise KeyError(key)
        self.discard(key, val)

    def count(self, key: T) -> int:
        """Return the multiplicity of ``key`` (0 when absent)."""
        node = self.root
        while node:
            if key == node.key:
                return node.val
            node = node.left if key < node.key else node.right
        return 0

    def discard_all(self, key: T) -> bool:
        """Remove every copy of ``key``.

        Returns ``True`` even when ``key`` is absent, because ``discard``
        treats a non-positive count as a successful no-op.
        """
        return self.discard(key, self.count(key))

    def le(self, key: T) -> Optional[T]:
        """Delegate to ``BSTMultisetNodeBase.le``.

        Presumably the largest stored key ``<= key`` or ``None`` -- verify
        against the helper.
        """
        return BSTMultisetNodeBase[T, ScapegoatTreeMultiset.Node].le(self.root, key)

    def lt(self, key: T) -> Optional[T]:
        """Delegate to ``BSTMultisetNodeBase.lt``.

        Presumably the largest stored key ``< key`` or ``None`` -- verify
        against the helper.
        """
        return BSTMultisetNodeBase[T, ScapegoatTreeMultiset.Node].lt(self.root, key)

    def ge(self, key: T) -> Optional[T]:
        """Delegate to ``BSTMultisetNodeBase.ge``.

        Presumably the smallest stored key ``>= key`` or ``None`` -- verify
        against the helper.
        """
        return BSTMultisetNodeBase[T, ScapegoatTreeMultiset.Node].ge(self.root, key)

    def gt(self, key: T) -> Optional[T]:
        """Delegate to ``BSTMultisetNodeBase.gt``.

        Presumably the smallest stored key ``> key`` or ``None`` -- verify
        against the helper.
        """
        return BSTMultisetNodeBase[T, ScapegoatTreeMultiset.Node].gt(self.root, key)

    def index(self, key: T) -> int:
        """Delegate to ``BSTMultisetNodeBase.index``.

        Presumably the number of stored elements ``< key`` -- verify against
        the helper.
        """
        return BSTMultisetNodeBase[T, ScapegoatTreeMultiset.Node].index(self.root, key)

    def index_right(self, key: T) -> int:
        """Delegate to ``BSTMultisetNodeBase.index_right``.

        Presumably the number of stored elements ``<= key`` -- verify against
        the helper.
        """
        return BSTMultisetNodeBase[T, ScapegoatTreeMultiset.Node].index_right(
            self.root, key
        )

    def index_keys(self, key: T) -> int:
        """Return the number of distinct keys strictly less than ``key``."""
        k = 0
        node = self.root
        while node:
            if key == node.key:
                if node.left:
                    k += node.left.size
                break
            if key < node.key:
                node = node.left
            else:
                # Skip the left subtree plus this node: one distinct key each,
                # regardless of multiplicity.
                k += 1 if node.left is None else node.left.size + 1
                node = node.right
        return k

    def index_right_keys(self, key: T) -> int:
        """Return the number of distinct keys less than or equal to ``key``."""
        k = 0
        node = self.root
        while node:
            if key == node.key:
                k += 1 if node.left is None else node.left.size + 1
                break
            if key < node.key:
                node = node.left
            else:
                k += 1 if node.left is None else node.left.size + 1
                node = node.right
        return k

    def pop(self, k: int = -1) -> T:
        """Remove one copy of the k-th element (with multiplicity) and return it.

        Negative ``k`` counts from the end.  An out-of-range ``k`` (including
        any ``k`` on an empty multiset) fails the assertion in ``__getitem__``.
        """
        if k < 0:
            k += len(self)
        x = self[k]
        self.discard(x)
        return x

    def pop_min(self) -> T:
        """Remove and return one copy of the smallest element."""
        return self.pop(0)

    def pop_max(self) -> T:
        """Remove and return one copy of the largest element."""
        return self.pop(-1)

    def items(self) -> Iterator[tuple[T, int]]:
        """Yield ``(key, count)`` for each distinct key in ascending order."""
        for i in range(self.len_elm()):
            yield self._kth_elm_tree(i)

    def keys(self) -> Iterator[T]:
        """Yield each distinct key in ascending order."""
        for i in range(self.len_elm()):
            yield self._kth_elm_tree(i)[0]

    def values(self) -> Iterator[int]:
        """Yield the multiplicity of each distinct key, in ascending key order."""
        for i in range(self.len_elm()):
            yield self._kth_elm_tree(i)[1]

    def show(self) -> None:
        """Print the contents as ``{key: count, ...}``."""
        print(
            "{" + ", ".join(map(lambda x: f"{x[0]}: {x[1]}", self.tolist_items())) + "}"
        )

    def get_elm(self, k: int) -> T:
        """Return the k-th distinct key; negative ``k`` counts from the end."""
        assert (
            -self.len_elm() <= k < self.len_elm()
        ), f"IndexError: {self.__class__.__name__}.get_elm({k}), len_elm=({self.len_elm()})"
        return self._kth_elm_tree(k)[0]

    def len_elm(self) -> int:
        """Return the number of distinct keys."""
        return self.root.size if self.root else 0

    def tolist(self) -> list[T]:
        """Delegate to ``BSTMultisetNodeBase.tolist``.

        Presumably all elements in ascending order, repeated by multiplicity
        -- verify against the helper.
        """
        return BSTMultisetNodeBase[T, ScapegoatTreeMultiset.Node].tolist(self.root)

    def tolist_items(self) -> list[tuple[T, int]]:
        """Delegate to ``BSTMultisetNodeBase.tolist_items``.

        Used by ``show`` as a sequence of ``(key, count)`` pairs.
        """
        return BSTMultisetNodeBase[T, ScapegoatTreeMultiset.Node].tolist_items(
            self.root
        )

    def clear(self) -> None:
        """Remove every element."""
        self.root = None

    def get_max(self) -> T:
        """Return the largest key; fails an assertion when empty."""
        return self._kth_elm_tree(-1)[0]

    def get_min(self) -> T:
        """Return the smallest key; fails an assertion when empty."""
        return self._kth_elm_tree(0)[0]

    def __contains__(self, key: T):
        return BSTMultisetNodeBase[T, ScapegoatTreeMultiset.Node].contains(
            self.root, key
        )

    def __getitem__(self, k: int) -> T:
        """Return the k-th element counted with multiplicity."""
        assert (
            -len(self) <= k < len(self)
        ), f"IndexError: {self.__class__.__name__}[{k}], len={len(self)}"
        return self._kth_elm(k)[0]

    def __iter__(self):
        # The cursor lives on the container itself, so two simultaneous
        # iterations over the same instance share (and reset) one position.
        self.__iter = 0
        return self

    def __next__(self):
        if self.__iter == len(self):
            raise StopIteration
        res = self._kth_elm(self.__iter)[0]
        self.__iter += 1
        return res

    def __reversed__(self):
        for i in range(len(self)):
            yield self._kth_elm(-i - 1)[0]

    def __len__(self):
        """Return the number of elements counted with multiplicity."""
        return self.root.valsize if self.root else 0

    def __bool__(self):
        return self.root is not None

    def __str__(self):
        return "{" + ", ".join(map(str, self.tolist())) + "}"

    def __repr__(self):
        # Call tolist(); interpolating the bound method would print its repr.
        return f"{self.__class__.__name__}({self.tolist()})"