Source code for titan_pylib.data_structures.splay_tree.splay_tree_multiset2

  1import sys
  2from typing import Generic, Iterable, TypeVar, Optional
  3
  4T = TypeVar("T")
  5
  6
[docs] 7class SplayTreeMultiset2(Generic[T]): 8
[docs] 9 class Node: 10 11 def __init__(self, key: T, val: int): 12 self.key = key 13 self.val = val 14 self.left = None 15 self.right = None 16 17 def __str__(self): 18 if self.left is None and self.right is None: 19 return f"key:{self.key, self.val}\n" 20 return ( 21 f"key:{self.key, self.val},\n left:{self.left},\n right:{self.right}\n" 22 )
23 24 def __init__(self, a: Iterable[T] = []): 25 self.node = None 26 self._len = 0 27 self._len_elm = 0 28 if not (hasattr(a, "__getitem__") and hasattr(a, "__len__")): 29 a = list(a) 30 if a: 31 self._build(a) 32 33 def _build(self, a: Iterable[T]) -> None: 34 Node = SplayTreeMultiset2.Node 35 36 def sort(l: int, r: int) -> SplayTreeMultiset2.Node: 37 mid = (l + r) >> 1 38 node = Node(key[mid], val[mid]) 39 if l != mid: 40 node.left = sort(l, mid) 41 if mid + 1 != r: 42 node.right = sort(mid + 1, r) 43 return node 44 45 a = sorted(a) 46 self._len = len(a) 47 key, val = self._rle(sorted(a)) 48 self._len_elm = len(key) 49 self.node = sort(0, len(key)) 50 51 def _rle(self, a: list[T]) -> tuple[list[T], list[int]]: 52 x = [] 53 y = [] 54 x.append(a[0]) 55 y.append(1) 56 for i, e in enumerate(a): 57 if i == 0: 58 continue 59 if e == x[-1]: 60 y[-1] += 1 61 continue 62 x.append(e) 63 y.append(1) 64 return x, y 65 66 def _splay(self, path: list[Node], di: int) -> Node: 67 for _ in range(len(path) >> 1): 68 node = path.pop() 69 pnode = path.pop() 70 if di & 1 == di >> 1 & 1: 71 if di & 1 == 1: 72 tmp = node.left 73 node.left = tmp.right 74 tmp.right = node 75 pnode.left = node.right 76 node.right = pnode 77 else: 78 tmp = node.right 79 node.right = tmp.left 80 tmp.left = node 81 pnode.right = node.left 82 node.left = pnode 83 else: 84 if di & 1 == 1: 85 tmp = node.left 86 node.left = tmp.right 87 pnode.right = tmp.left 88 tmp.right = node 89 tmp.left = pnode 90 else: 91 tmp = node.right 92 node.right = tmp.left 93 pnode.left = tmp.right 94 tmp.left = node 95 tmp.right = pnode 96 if not path: 97 return tmp 98 di >>= 2 99 if di & 1 == 1: 100 path[-1].left = tmp 101 else: 102 path[-1].right = tmp 103 gnode = path[0] 104 if di & 1 == 1: 105 node = gnode.left 106 gnode.left = node.right 107 node.right = gnode 108 else: 109 node = gnode.right 110 gnode.right = node.left 111 node.left = gnode 112 return node 113 114 def _set_search_splay(self, key: T) -> None: 115 node = self.node 116 if node is None or node.key == key: 117 return 118 path = [] 119 di = 0 120 while True: 121 if node.key == key: 122 break 123 elif key < node.key: 124 if node.left is None: 125 break 126 path.append(node) 127 di <<= 1 128 di |= 1 129 node = node.left 130 else: 131 if node.right is None: 132 break 133 path.append(node) 134 di <<= 1 135 node = node.right 136 if path: 137 self.node = self._splay(path, di) 138 139 def _get_min_splay(self, node: Node) -> Node: 140 if node is None or node.left is None: 141 return node 142 path = [] 143 while node.left is not None: 144 path.append(node) 145 node = node.left 146 return self._splay(path, (1 << len(path)) - 1) 147 148 def _get_max_splay(self, node: Node) -> Node: 149 if node is None or node.right is None: 150 return node 151 path = [] 152 while node.right is not None: 153 path.append(node) 154 node = node.right 155 return self._splay(path, 0) 156
[docs] 157 def add(self, key: T, val: int = 1) -> None: 158 self._len += val 159 if self.node is None: 160 self._len_elm += 1 161 self.node = SplayTreeMultiset2.Node(key, val) 162 return 163 self._set_search_splay(key) 164 if self.node.key == key: 165 self.node.val += val 166 return 167 self._len_elm += 1 168 node = SplayTreeMultiset2.Node(key, val) 169 if key < self.node.key: 170 node.left = self.node.left 171 node.right = self.node 172 self.node.left = None 173 else: 174 node.left = self.node 175 node.right = self.node.right 176 self.node.right = None 177 self.node = node 178 return
179
[docs] 180 def discard(self, key: T, val: int = 1) -> bool: 181 if self.node is None: 182 return False 183 self._set_search_splay(key) 184 if self.node.key != key: 185 return False 186 if self.node.val > val: 187 self.node.val -= val 188 self._len -= val 189 return True 190 self._len -= self.node.val 191 self._len_elm -= 1 192 if self.node.left is None: 193 self.node = self.node.right 194 elif self.node.right is None: 195 self.node = self.node.left 196 else: 197 node = self._get_min_splay(self.node.right) 198 node.left = self.node.left 199 self.node = node 200 return True
201
[docs] 202 def discard_all(self, key: T) -> bool: 203 return self.discar(key, self.count(key))
204
[docs] 205 def count(self, key: T) -> int: 206 if self.node is None: 207 return 0 208 self._set_search_splay(key) 209 return self.node.val if self.node.key == key else 0
210
[docs] 211 def le(self, key: T) -> Optional[T]: 212 node = self.node 213 if node is None: 214 return None 215 path = [] 216 di = 0 217 res = None 218 while True: 219 if node.key == key: 220 res = key 221 break 222 elif key < node.key: 223 if node.left is None: 224 break 225 path.append(node) 226 di <<= 1 227 di |= 1 228 node = node.left 229 else: 230 res = node.key 231 if node.right is None: 232 break 233 path.append(node) 234 di <<= 1 235 node = node.right 236 if path: 237 self.node = self._splay(path, di) 238 return res
239
[docs] 240 def lt(self, key: T) -> Optional[T]: 241 node = self.node 242 if node is None: 243 return None 244 path = [] 245 di = 0 246 res = None 247 while True: 248 if key <= node.key: 249 if node.left is None: 250 break 251 path.append(node) 252 di <<= 1 253 di |= 1 254 node = node.left 255 else: 256 res = node.key 257 if node.right is None: 258 break 259 path.append(node) 260 di <<= 1 261 node = node.right 262 if path: 263 self.node = self._splay(path, di) 264 return res
265
[docs] 266 def ge(self, key: T) -> Optional[T]: 267 node = self.node 268 if node is None: 269 return None 270 path = [] 271 di = 0 272 res = None 273 while True: 274 if node.key == key: 275 res = node.key 276 break 277 elif key < node.key: 278 res = node.key 279 if node.left is None: 280 break 281 path.append(node) 282 di <<= 1 283 di |= 1 284 node = node.left 285 else: 286 if node.right is None: 287 break 288 path.append(node) 289 di <<= 1 290 node = node.right 291 if path: 292 self.node = self._splay(path, di) 293 return res
294
[docs] 295 def gt(self, key: T) -> Optional[T]: 296 node = self.node 297 if node is None: 298 return None 299 path = [] 300 di = 0 301 res = None 302 while True: 303 if key < node.key: 304 res = node.key 305 if node.left is None: 306 break 307 path.append(node) 308 di <<= 1 309 di |= 1 310 node = node.left 311 else: 312 if node.right is None: 313 break 314 path.append(node) 315 di <<= 1 316 node = node.right 317 if path: 318 self.node = self._splay(path, di) 319 return res
320
[docs] 321 def pop_max(self) -> T: 322 self.node = self._get_max_splay(self.node) 323 res = self.node.key 324 self.discard(res) 325 return res
326
[docs] 327 def pop_min(self) -> T: 328 self.node = self._get_min_splay(self.node) 329 res = self.node.key 330 self.discard(res) 331 return res
332
[docs] 333 def get_min(self) -> Optional[T]: 334 if self.node is None: 335 return 336 self.node = self._get_min_splay(self.node) 337 return self.node.key
338
[docs] 339 def get_max(self) -> Optional[T]: 340 if self.node is None: 341 return 342 self.node = self._get_max_splay(self.node) 343 return self.node.key
344
[docs] 345 def tolist(self) -> list[T]: 346 a = [] 347 if self.node is None: 348 return a 349 if sys.getrecursionlimit() < self.len_elm(): 350 sys.setrecursionlimit(self.len_elm() + 1) 351 352 def rec(node): 353 if node.left is not None: 354 rec(node.left) 355 a.extend([node.key] * node.val) 356 if node.right is not None: 357 rec(node.right) 358 359 rec(self.node) 360 return a
361
[docs] 362 def tolist_items(self) -> list[tuple[T, int]]: 363 a = [] 364 if self.node is None: 365 return a 366 if sys.getrecursionlimit() < self._len_elm(): 367 sys.setrecursionlimit(self._len_elm() + 1) 368 369 def rec(node): 370 if node.left is not None: 371 rec(node.left) 372 a.append((node.key, node.val)) 373 if node.right is not None: 374 rec(node.right) 375 376 rec(self.node) 377 return a
378
[docs] 379 def len_elm(self) -> int: 380 return self._len_elm
381
[docs] 382 def clear(self) -> None: 383 self.node = None
384 385 def __getitem__(self, k): # 先s頭と末尾しか対応していない 386 if k == -1 or k == self._len - 1: 387 return self.get_max() 388 elif k == 0: 389 return self.get_min() 390 raise IndexError 391 392 def __contains__(self, key: T) -> bool: 393 self._set_search_splay(key) 394 return self.node is not None and self.node.key == key 395 396 def __len__(self): 397 return self._len 398 399 def __bool__(self): 400 return self.node is not None 401 402 def __str__(self): 403 return "{" + ", ".join(map(str, self.tolist())) + "}" 404 405 def __repr__(self): 406 return f"SplayTreeMultiset2({self.tolist()})"