Source code for titan_pylib.data_structures.wbt.wbt_multiset

  1from titan_pylib.data_structures.wbt._wbt_multiset_node import _WBTMultisetNode
  2from typing import Generic, TypeVar, Optional, Iterable, Iterator
  3
  4T = TypeVar("T")
  5
  6
[docs] 7class WBTMultiset(Generic[T]): 8 9 __slots__ = "_root", "_min", "_max" 10 11 def __init__(self, a: Iterable[T] = []) -> None: 12 self._root: Optional[_WBTMultisetNode[T]] = None 13 self._min: Optional[_WBTMultisetNode[T]] = None 14 self._max: Optional[_WBTMultisetNode[T]] = None 15 self.__build(a) 16 17 def __build(self, a: Iterable[T]) -> None: 18 def build( 19 l: int, r: int, pnode: Optional[_WBTMultisetNode[T]] = None 20 ) -> _WBTMultisetNode[T]: 21 if l == r: 22 return None 23 mid = (l + r) // 2 24 node = _WBTMultisetNode(keys[mid], vals[mid]) 25 node._left = build(l, mid, node) 26 node._right = build(mid + 1, r, node) 27 node._par = pnode 28 node._update() 29 return node 30 31 a = list(a) 32 if not a: 33 return 34 if not all(a[i] <= a[i + 1] for i in range(len(a) - 1)): 35 a.sort() 36 # RLE 37 keys, vals = [a[0]], [1] 38 for i, elm in enumerate(a): 39 if i == 0: 40 continue 41 if elm == keys[-1]: 42 vals[-1] += 1 43 continue 44 keys.append(elm) 45 vals.append(1) 46 self._root = build(0, len(keys)) 47 self._max = self._root._max() 48 self._min = self._root._min() 49
[docs] 50 def add(self, key: T, count: int = 1) -> None: 51 if not self._root: 52 self._root = _WBTMultisetNode(key, count) 53 self._max = self._root 54 self._min = self._root 55 return 56 pnode = None 57 node = self._root 58 while node: 59 node._count_size += count 60 if key == node._key: 61 node._count += count 62 return 63 pnode = node 64 node = node._left if key < node._key else node._right 65 if key < pnode._key: 66 pnode._left = _WBTMultisetNode(key, count) 67 if key < self._min._key: 68 self._min = pnode._left 69 pnode._left._par = pnode 70 else: 71 pnode._right = _WBTMultisetNode(key, count) 72 if key > self._max._key: 73 self._max = pnode._right 74 pnode._right._par = pnode 75 self._root = pnode._rebalance()
76
[docs] 77 def find_key(self, key: T) -> Optional[_WBTMultisetNode[T]]: 78 node = self._root 79 while node: 80 if key == node._key: 81 return node 82 node = node._left if key < node._key else node._right 83 return None
84
[docs] 85 def find_order(self, k: int) -> _WBTMultisetNode[T]: 86 node = self._root 87 while True: 88 t = node._left._count_size + node._count if node._left else node._count 89 if t - node._count <= k < t: 90 return node 91 if t > k: 92 node = node._left 93 else: 94 node = node._right 95 k -= t
96
[docs] 97 def count(self, key: T) -> int: 98 node = self.find_key(key) 99 return node.count if node is not None else 0
100
[docs] 101 def remove_iter(self, node: _WBTMultisetNode[T]) -> None: 102 if node is self._min: 103 self._min = self._min._next() 104 if node is self._max: 105 self._max = self._max._prev() 106 delnode = node 107 pnode, mnode = node._par, None 108 if node._left and node._right: 109 pnode, mnode = node, node._left 110 while mnode._right: 111 pnode, mnode = mnode, mnode._right 112 node._count = mnode._count 113 node = mnode 114 cnode = node._right if not node._left else node._left 115 if cnode: 116 cnode._par = pnode 117 if pnode: 118 if pnode._left is node: 119 pnode._left = cnode 120 else: 121 pnode._right = cnode 122 self._root = pnode._rebalance() 123 else: 124 self._root = cnode 125 if mnode: 126 if self._root is delnode: 127 self._root = mnode 128 mnode._copy_from(delnode) 129 del delnode
130
[docs] 131 def remove(self, key: T, count: int = 1) -> None: 132 node = self.find_key(key) 133 assert node, f"KeyError: {key} is not found." 134 if node._count <= count: 135 self.remove_iter(node) 136 else: 137 node._count -= count 138 while node: 139 node._count_size -= count 140 node = node._par
141
[docs] 142 def discard(self, key: T, count: int = 1) -> bool: 143 node = self.find_key(key) 144 if node is None: 145 return False 146 if node._count <= count: 147 self.remove_iter(node) 148 else: 149 node._count -= count 150 while node: 151 node._count_size -= count 152 node = node._par 153 return True
154
[docs] 155 def pop(self, k: int = -1) -> T: 156 node = self.find_order(k) 157 key = node._key 158 if node._count == 0: 159 self.remove_iter(node) 160 else: 161 node._count -= 1 162 while node: 163 node._count_size -= 1 164 node = node._par 165 return key
166
[docs] 167 def le_iter(self, key: T) -> Optional[_WBTMultisetNode[T]]: 168 res = None 169 node = self._root 170 while node: 171 if key == node._key: 172 res = node 173 break 174 if key < node._key: 175 node = node._left 176 else: 177 res = node 178 node = node._right 179 return res
180
[docs] 181 def lt_iter(self, key: T) -> Optional[_WBTMultisetNode[T]]: 182 res = None 183 node = self._root 184 while node: 185 if key <= node._key: 186 node = node._left 187 else: 188 res = node 189 node = node._right 190 return res
191
[docs] 192 def ge_iter(self, key: T) -> Optional[_WBTMultisetNode[T]]: 193 res = None 194 node = self._root 195 while node: 196 if key == node._key: 197 res = node 198 break 199 if key < node._key: 200 res = node 201 node = node._left 202 else: 203 node = node._right 204 return res
205
[docs] 206 def gt_iter(self, key: T) -> Optional[_WBTMultisetNode[T]]: 207 res = None 208 node = self._root 209 while node: 210 if key < node._key: 211 res = node 212 node = node._left 213 else: 214 node = node._right 215 return res
216
[docs] 217 def le(self, key: T) -> Optional[T]: 218 res = None 219 node = self._root 220 while node: 221 if key == node._key: 222 res = key 223 break 224 if key < node._key: 225 node = node._left 226 else: 227 res = node._key 228 node = node._right 229 return res
230
[docs] 231 def lt(self, key: T) -> Optional[T]: 232 res = None 233 node = self._root 234 while node: 235 if key <= node._key: 236 node = node._left 237 else: 238 res = node._key 239 node = node._right 240 return res
241
[docs] 242 def ge(self, key: T) -> Optional[T]: 243 res = None 244 node = self._root 245 while node: 246 if key == node._key: 247 res = key 248 break 249 if key < node._key: 250 res = node._key 251 node = node._left 252 else: 253 node = node._right 254 return res
255
[docs] 256 def gt(self, key: T) -> Optional[T]: 257 res = None 258 node = self._root 259 while node: 260 if key < node._key: 261 res = node._key 262 node = node._left 263 else: 264 node = node._right 265 return res
266
[docs] 267 def index(self, key: T) -> int: 268 k = 0 269 node = self._root 270 while node: 271 if key == node._key: 272 k += node._left._count_size if node._left else 0 273 break 274 if key < node._key: 275 node = node._left 276 else: 277 k += node._left._count_size + node._count if node._left else node._count 278 node = node._right 279 return k
280
[docs] 281 def index_right(self, key: T) -> int: 282 k = 0 283 node = self._root 284 while node: 285 if key == node._key: 286 k += node._left._count_size + node._count if node._left else node._count 287 break 288 if key < node._key: 289 node = node._left 290 else: 291 k += node._left._count_size + node._count if node._left else node._count 292 node = node._right 293 return k
294
[docs] 295 def tolist(self) -> list[T]: 296 return list(self)
297
[docs] 298 def get_min(self) -> T: 299 assert self._min 300 return self._min._key
301
[docs] 302 def get_max(self) -> T: 303 assert self._max 304 return self._max._key
305
[docs] 306 def pop_min(self) -> T: 307 assert self._min 308 key = self._min._key 309 self._min._count -= 1 310 if self._min._count == 0: 311 self.remove_iter(self._min) 312 return key
313
[docs] 314 def pop_max(self) -> T: 315 assert self._max 316 key = self._max._key 317 self._max._count -= 1 318 if self._max._count == 0: 319 self.remove_iter(self._max) 320 return key
321
[docs] 322 def check(self) -> None: 323 if self._root is None: 324 # print("ok. 0 (empty)") 325 return 326 327 # _size, count_size, height 328 def dfs(node: _WBTMultisetNode[T]) -> tuple[int, int, int]: 329 h = 0 330 s = 1 331 cs = node.count 332 if node._left: 333 assert node._key > node._left._key 334 ls, lcs, lh = dfs(node._left) 335 s += ls 336 cs += lcs 337 h = max(h, lh) 338 if node._right: 339 assert node._key < node._right._key 340 rs, rcs, rh = dfs(node._right) 341 s += rs 342 cs += rcs 343 h = max(h, rh) 344 assert node._size == s 345 assert node._count_size == cs 346 node._balance_check() 347 return s, cs, h + 1 348 349 _, _, h = dfs(self._root)
350 # print(f"ok. {h}") 351 352 def __contains__(self, key: T) -> bool: 353 return self.find_key(key) is not None 354 355 def __getitem__(self, k: int) -> T: 356 assert ( 357 -len(self) <= k < len(self) 358 ), f"IndexError: {self.__class__.__name__}[{k}], len={len(self)}" 359 if k < 0: 360 k += len(self) 361 if k == 0: 362 return self.get_min() 363 if k == len(self) - 1: 364 return self.get_max() 365 return self.find_order(k)._key 366 367 def __delitem__(self, k: int) -> None: 368 node = self.find_order(k) 369 node._count -= 1 370 if node._count == 0: 371 self.remove_iter(node) 372 373 def __len__(self) -> int: 374 return self._root._count_size if self._root else 0 375 376 def __iter__(self) -> Iterator[T]: 377 stack: list[_WBTMultisetNode[T]] = [] 378 node = self._root 379 while stack or node: 380 if node: 381 stack.append(node) 382 node = node._left 383 else: 384 node = stack.pop() 385 for _ in range(node._count): 386 yield node._key 387 node = node._right 388 389 def __reversed__(self) -> Iterator[T]: 390 stack: list[_WBTMultisetNode[T]] = [] 391 node = self._root 392 while stack or node: 393 if node: 394 stack.append(node) 395 node = node._right 396 else: 397 node = stack.pop() 398 for _ in range(node._count): 399 yield node._key 400 node = node._left 401 402 def __str__(self) -> str: 403 return "{" + ", ".join(map(str, self)) + "}" 404 405 def __repr__(self) -> str: 406 return ( 407 f"{self.__class__.__name__}(" 408 + "[" 409 + ", ".join(map(str, self.tolist())) 410 + "])" 411 )