treap_multiset
ソースコード (source code)
from titan_pylib.data_structures.treap.treap_multiset import TreapMultiset
展開済みコード (expanded code: dependencies inlined)
1# from titan_pylib.data_structures.treap.treap_multiset import TreapMultiset
2# from titan_pylib.my_class.ordered_multiset_interface import OrderedMultisetInterface
3# from titan_pylib.my_class.supports_less_than import SupportsLessThan
4from typing import Protocol
5
6
class SupportsLessThan(Protocol):
    """Structural type for values that can be ordered with the ``<`` operator."""

    def __lt__(self, other) -> bool:
        """Return whether ``self`` sorts strictly before ``other``."""
        ...
10from abc import ABC, abstractmethod
11from typing import Iterable, Optional, Iterator, TypeVar, Generic
12
# Element type of ordered multisets: anything comparable with ``<``.
T = TypeVar("T", bound=SupportsLessThan)


class OrderedMultisetInterface(ABC, Generic[T]):
    """Abstract API of a sorted multiset (a sorted bag of ``T`` values).

    Every method here is an abstract stub that only raises
    ``NotImplementedError``; concrete subclasses supply the behavior
    described in each docstring.
    """

    # ----- construction -------------------------------------------------

    @abstractmethod
    def __init__(self, a: Iterable[T]) -> None:
        """Create the multiset from the elements of ``a`` (repeats kept)."""
        raise NotImplementedError

    # ----- mutation -----------------------------------------------------

    @abstractmethod
    def add(self, key: T, cnt: int) -> None:
        """Insert ``cnt`` copies of ``key``."""
        raise NotImplementedError

    @abstractmethod
    def discard(self, key: T, cnt: int) -> bool:
        """Remove up to ``cnt`` copies of ``key``; return whether it was present."""
        raise NotImplementedError

    @abstractmethod
    def discard_all(self, key: T) -> bool:
        """Remove every copy of ``key``; return whether it was present."""
        raise NotImplementedError

    @abstractmethod
    def remove(self, key: T, cnt: int) -> None:
        """Like ``discard``, but report a missing ``key`` with an exception."""
        raise NotImplementedError

    @abstractmethod
    def pop_max(self) -> T:
        """Remove one copy of the largest element and return it."""
        raise NotImplementedError

    @abstractmethod
    def pop_min(self) -> T:
        """Remove one copy of the smallest element and return it."""
        raise NotImplementedError

    @abstractmethod
    def clear(self) -> None:
        """Remove all elements."""
        raise NotImplementedError

    # ----- membership and size -----------------------------------------

    @abstractmethod
    def count(self, key: T) -> int:
        """Return how many copies of ``key`` are stored (0 if none)."""
        raise NotImplementedError

    @abstractmethod
    def __contains__(self, key: T) -> bool:
        """Return whether at least one copy of ``key`` is stored."""
        raise NotImplementedError

    @abstractmethod
    def __len__(self) -> int:
        """Return the number of elements."""
        raise NotImplementedError

    @abstractmethod
    def __bool__(self) -> bool:
        """Return whether the multiset is non-empty."""
        raise NotImplementedError

    # ----- neighbor search ---------------------------------------------

    @abstractmethod
    def le(self, key: T) -> Optional[T]:
        """Return the largest element ``<= key``, or ``None``."""
        raise NotImplementedError

    @abstractmethod
    def lt(self, key: T) -> Optional[T]:
        """Return the largest element ``< key``, or ``None``."""
        raise NotImplementedError

    @abstractmethod
    def ge(self, key: T) -> Optional[T]:
        """Return the smallest element ``>= key``, or ``None``."""
        raise NotImplementedError

    @abstractmethod
    def gt(self, key: T) -> Optional[T]:
        """Return the smallest element ``> key``, or ``None``."""
        raise NotImplementedError

    # ----- extremes -----------------------------------------------------

    @abstractmethod
    def get_max(self) -> Optional[T]:
        """Return the largest element, or ``None`` when empty."""
        raise NotImplementedError

    @abstractmethod
    def get_min(self) -> Optional[T]:
        """Return the smallest element, or ``None`` when empty."""
        raise NotImplementedError

    # ----- conversion and iteration ------------------------------------

    @abstractmethod
    def tolist(self) -> list[T]:
        """Return all elements in ascending order, repeated by multiplicity."""
        raise NotImplementedError

    @abstractmethod
    def __iter__(self) -> Iterator[T]:
        """Return an iterator over the elements in ascending order."""
        raise NotImplementedError

    @abstractmethod
    def __next__(self) -> T:
        """Return the next element of the iteration started by ``__iter__``."""
        raise NotImplementedError

    @abstractmethod
    def __str__(self) -> str:
        """Return a human-readable rendering of the contents."""
        raise NotImplementedError

    @abstractmethod
    def __repr__(self) -> str:
        """Return a developer-oriented rendering of the contents."""
        raise NotImplementedError
109# from titan_pylib.my_class.supports_less_than import SupportsLessThan
110# from titan_pylib.data_structures.bst_base.bst_multiset_node_base import (
111# BSTMultisetNodeBase,
112# )
113from typing import TypeVar, Generic, Optional
114
T = TypeVar("T")
Node = TypeVar("Node")
# TODO: declare the required node shape (key, val, left, right) as a typing.Protocol.


class BSTMultisetNodeBase(Generic[T, Node]):
    """Stateless helpers shared by binary-search-tree multisets.

    Every helper receives the root of a tree whose nodes expose ``key`` (the
    element), ``val`` (its multiplicity), ``left`` and ``right`` (children or
    ``None``).  Keys that compare less than a node's key live in its left
    subtree; all others live in its right subtree.  ``index`` and
    ``index_right`` additionally read ``valsize`` on nodes (assumed to be the
    summed multiplicity of that node's subtree -- confirm for each node type).
    """

    @staticmethod
    def count(node: Node, key: T) -> int:
        """Return the multiplicity of ``key`` in the tree rooted at ``node`` (0 if absent)."""
        while node is not None:
            if node.key == key:
                return node.val
            node = node.left if key < node.key else node.right
        return 0

    @staticmethod
    def get_min(node: Node) -> Optional[T]:
        """Return the smallest key, or ``None`` for an empty tree."""
        if not node:
            return None
        while node.left:
            node = node.left
        return node.key

    @staticmethod
    def get_max(node: Node) -> Optional[T]:
        """Return the largest key, or ``None`` for an empty tree."""
        if not node:
            return None
        while node.right:
            node = node.right
        return node.key

    @staticmethod
    def contains(node: Node, key: T) -> bool:
        """Return whether a node with ``key`` exists in the tree."""
        while node:
            if key == node.key:
                return True
            node = node.left if key < node.key else node.right
        return False

    @staticmethod
    def tolist(node: Node) -> list[T]:
        """Return all elements in ascending order, each repeated ``val`` times."""
        stack = []
        a = []
        # Iterative in-order traversal, so deep trees cannot hit the recursion limit.
        while stack or node:
            if node:
                stack.append(node)
                node = node.left
            else:
                node = stack.pop()
                a.extend([node.key] * node.val)
                node = node.right
        return a

    @staticmethod
    def tolist_items(node: Node) -> list[tuple[T, int]]:
        """Return ``(key, multiplicity)`` pairs in ascending key order."""
        stack = []
        a = []
        while stack or node:
            if node:
                stack.append(node)
                node = node.left
            else:
                node = stack.pop()
                a.append((node.key, node.val))
                node = node.right
        return a

    @staticmethod
    def _rle(a: list[T]) -> tuple[list[T], list[int]]:
        """Run-length encode ``a`` into (keys, run lengths).

        Adjacent equal elements are merged, so a sorted input yields every
        distinct key once together with its multiplicity.  An empty input
        yields two empty lists (previously it raised ``IndexError``).
        """
        keys: list[T] = []
        vals: list[int] = []
        for elm in a:
            if keys and elm == keys[-1]:
                vals[-1] += 1
            else:
                keys.append(elm)
                vals.append(1)
        return keys, vals

    @staticmethod
    def le(node: Node, key: T) -> Optional[T]:
        """Return the largest key ``<= key``, or ``None`` if there is none."""
        res = None
        while node is not None:
            if key == node.key:
                res = key
                break
            if key < node.key:
                node = node.left
            else:
                res = node.key
                node = node.right
        return res

    @staticmethod
    def lt(node: Node, key: T) -> Optional[T]:
        """Return the largest key ``< key``, or ``None`` if there is none."""
        res = None
        while node is not None:
            if key <= node.key:
                node = node.left
            else:
                res = node.key
                node = node.right
        return res

    @staticmethod
    def ge(node: Node, key: T) -> Optional[T]:
        """Return the smallest key ``>= key``, or ``None`` if there is none."""
        res = None
        while node is not None:
            if key == node.key:
                res = key
                break
            if key < node.key:
                res = node.key
                node = node.left
            else:
                node = node.right
        return res

    @staticmethod
    def gt(node: Node, key: T) -> Optional[T]:
        """Return the smallest key ``> key``, or ``None`` if there is none."""
        res = None
        while node is not None:
            if key < node.key:
                res = node.key
                node = node.left
            else:
                node = node.right
        return res

    @staticmethod
    def index(node: Node, key: T) -> int:
        """Return the number of elements (with multiplicity) strictly less than ``key``."""
        k = 0
        while node:
            if key == node.key:
                if node.left:
                    k += node.left.valsize
                break
            if key < node.key:
                node = node.left
            else:
                # Everything in the left subtree plus this node precedes key.
                k += node.val if node.left is None else node.left.valsize + node.val
                node = node.right
        return k

    @staticmethod
    def index_right(node: Node, key: T) -> int:
        """Return the number of elements (with multiplicity) less than or equal to ``key``."""
        k = 0
        while node:
            if key == node.key:
                k += node.val if node.left is None else node.left.valsize + node.val
                break
            if key < node.key:
                node = node.left
            else:
                k += node.val if node.left is None else node.left.valsize + node.val
                node = node.right
        return k
276from typing import Generic, Iterable, TypeVar, Optional, Sequence
277
T = TypeVar("T", bound=SupportsLessThan)


class TreapMultiset(OrderedMultisetInterface, Generic[T]):
    """Ordered multiset backed by a treap.

    Each distinct key is stored once, in a node that also records its
    multiplicity ``val``.  Nodes are ordered as a binary search tree on
    ``key`` and as a min-heap on a random ``priority`` (a smaller priority
    sits closer to the root); rotations in ``add``/``discard`` maintain that
    heap order, which keeps the tree balanced in expectation.

    ``len(self)`` counts elements with multiplicity; ``len_elm()`` counts
    distinct keys.
    """

    class Random:
        """Deterministic xorshift generator (Marsaglia's xor128) with class-wide state.

        The state lives on the class, so every TreapMultiset shares one stream.
        """

        _x, _y, _z, _w = 123456789, 362436069, 521288629, 88675123

        @classmethod
        def random(cls) -> int:
            """Advance the shared state and return the next value in ``[0, 2**32)``."""
            t = cls._x ^ ((cls._x << 11) & 0xFFFFFFFF)
            cls._x, cls._y, cls._z = cls._y, cls._z, cls._w
            cls._w = (cls._w ^ (cls._w >> 19)) ^ ((t ^ (t >> 8)) & 0xFFFFFFFF)
            return cls._w

    class Node:
        """A treap node holding one distinct key, its multiplicity and a priority."""

        def __init__(self, key: T, val: int = 1, priority: int = -1):
            self.key: T = key
            self.val: int = val  # multiplicity of ``key``
            self.left: Optional["TreapMultiset.Node"] = None
            self.right: Optional["TreapMultiset.Node"] = None
            # -1 is a sentinel meaning "draw a fresh random priority".
            self.priority: int = (
                TreapMultiset.Random.random() if priority == -1 else priority
            )

        def __str__(self):
            if self.left is None and self.right is None:
                return f"key:{self.key, self.priority}\n"
            return f"key:{self.key, self.priority},\n left:{self.left},\n right:{self.right}\n"

    def __init__(self, a: Iterable[T] = []):
        """Create a multiset holding every element of ``a`` (repeats kept).

        The default list is only read, never mutated, so sharing it is safe.
        """
        self.root: Optional["TreapMultiset.Node"] = None
        self._len: int = 0  # number of elements, counting multiplicity
        self._len_elm: int = 0  # number of distinct keys
        if not isinstance(a, Sequence):
            a = list(a)
        if a:
            self._build(a)

    def _build(self, a: Iterable[T]) -> None:
        """Replace the contents with a balanced treap built from ``a``.

        Sorts ``a``, run-length encodes it and builds a perfectly balanced
        tree over the distinct keys.
        """
        Node = TreapMultiset.Node
        a = sorted(a)
        key, val = BSTMultisetNodeBase[T, TreapMultiset.Node]._rle(a)
        self._len = len(a)
        self._len_elm = len(key)
        # Hand out ascending priorities in pre-order: each node receives its
        # priority before any of its descendants do, so it is smaller than all
        # of theirs -- the min-heap order that add/discard rely on.
        rand = iter(sorted(TreapMultiset.Random.random() for _ in range(self._len_elm)))

        def build(l: int, r: int) -> TreapMultiset.Node:
            # Build the subtree for key[l:r] rooted at its middle element.
            mid = (l + r) >> 1
            node = Node(key[mid], val[mid], next(rand))
            if l != mid:
                node.left = build(l, mid)
            if mid + 1 != r:
                node.right = build(mid + 1, r)
            return node

        self.root = build(0, len(key)) if key else None

    def _rotate_L(self, node: Node) -> Node:
        """Lift ``node.left`` above ``node`` and return the new subtree root."""
        u = node.left
        node.left = u.right
        u.right = node
        return u

    def _rotate_R(self, node: Node) -> Node:
        """Lift ``node.right`` above ``node`` and return the new subtree root."""
        u = node.right
        node.right = u.left
        u.left = node
        return u

    def _set_child(self, parent: Optional[Node], key: T, child: Optional[Node]) -> None:
        """Link ``child`` under ``parent`` on the side where ``key`` belongs (root if no parent)."""
        if parent is None:
            self.root = child
        elif key < parent.key:
            parent.left = child
        else:
            parent.right = child

    def add(self, key: T, val: int = 1) -> None:
        """Insert ``val`` copies of ``key``."""
        self._len += val
        if self.root is None:
            self.root = TreapMultiset.Node(key, val)
            self._len_elm += 1
            return
        node = self.root
        path = []
        # Directions taken while descending, newest in the lowest bit
        # (1 = went left, 0 = went right).
        di = 0
        while node is not None:
            if key == node.key:
                node.val += val
                return
            path.append(node)
            di <<= 1
            if key < node.key:
                di |= 1
                node = node.left
            else:
                node = node.right
        self._len_elm += 1
        if di & 1:
            path[-1].left = TreapMultiset.Node(key, val)
        else:
            path[-1].right = TreapMultiset.Node(key, val)
        # Rotate the new node upwards while its priority is smaller than its
        # parent's.  Once no rotation is needed, the rest of the path already
        # satisfies the heap order, so we can stop.
        while path:
            node = path.pop()
            if di & 1:
                if node.left.priority >= node.priority:
                    break
                new_node = self._rotate_L(node)
            else:
                if node.right.priority >= node.priority:
                    break
                new_node = self._rotate_R(node)
            di >>= 1
            if path:
                if di & 1:
                    path[-1].left = new_node
                else:
                    path[-1].right = new_node
            else:
                self.root = new_node

    def discard(self, key: T, val: int = 1) -> bool:
        """Remove up to ``val`` copies of ``key``.

        Returns:
            ``False`` if ``key`` is not present, otherwise ``True``.  When
            ``val`` is at least the stored multiplicity the key is removed
            entirely.
        """
        node = self.root
        pnode = None
        while node is not None:
            if key == node.key:
                break
            pnode = node
            node = node.left if key < node.key else node.right
        else:
            return False
        self._len -= min(val, node.val)
        if node.val > val:
            node.val -= val
            return True
        self._len_elm -= 1
        # Rotate the node downwards, always lifting the child with the smaller
        # priority so the heap order survives, until it has at most one child.
        while node.left is not None and node.right is not None:
            if node.left.priority < node.right.priority:
                new_node = self._rotate_L(node)
            else:
                new_node = self._rotate_R(node)
            self._set_child(pnode, node.key, new_node)
            pnode = new_node
        # Splice the node out by linking its only child (or None) to its parent.
        self._set_child(pnode, node.key, node.right if node.left is None else node.left)
        return True

    def discard_all(self, key: T) -> bool:
        """Remove every copy of ``key``; return whether it was present."""
        return self.discard(key, self.count(key))

    def remove(self, key: T, val: int = 1) -> None:
        """Remove up to ``val`` copies of ``key``.

        Raises:
            KeyError: if ``key`` is not present.
        """
        if self.discard(key, val):
            return
        raise KeyError(key)

    def count(self, key: T) -> int:
        """Return the multiplicity of ``key`` (0 if absent)."""
        return BSTMultisetNodeBase[T, TreapMultiset.Node].count(self.root, key)

    def le(self, key: T) -> Optional[T]:
        """Return the largest element ``<= key``, or ``None``."""
        return BSTMultisetNodeBase[T, TreapMultiset.Node].le(self.root, key)

    def lt(self, key: T) -> Optional[T]:
        """Return the largest element ``< key``, or ``None``."""
        return BSTMultisetNodeBase[T, TreapMultiset.Node].lt(self.root, key)

    def ge(self, key: T) -> Optional[T]:
        """Return the smallest element ``>= key``, or ``None``."""
        return BSTMultisetNodeBase[T, TreapMultiset.Node].ge(self.root, key)

    def gt(self, key: T) -> Optional[T]:
        """Return the smallest element ``> key``, or ``None``."""
        return BSTMultisetNodeBase[T, TreapMultiset.Node].gt(self.root, key)

    def len_elm(self) -> int:
        """Return the number of distinct keys."""
        return self._len_elm

    def show(self) -> None:
        """Print the contents as ``{key: count, ...}`` in ascending key order."""
        print("{" + ", ".join(f"{k}: {v}" for k, v in self.tolist_items()) + "}")

    def tolist(self) -> list[T]:
        """Return all elements in ascending order, repeated by multiplicity."""
        return BSTMultisetNodeBase[T, TreapMultiset.Node].tolist(self.root)

    def tolist_items(self) -> list[tuple[T, int]]:
        """Return ``(key, multiplicity)`` pairs in ascending key order."""
        return BSTMultisetNodeBase[T, TreapMultiset.Node].tolist_items(self.root)

    def get_min(self) -> Optional[T]:
        """Return the smallest element, or ``None`` when empty."""
        return BSTMultisetNodeBase[T, TreapMultiset.Node].get_min(self.root)

    def get_max(self) -> Optional[T]:
        """Return the largest element, or ``None`` when empty."""
        return BSTMultisetNodeBase[T, TreapMultiset.Node].get_max(self.root)

    def pop_min(self) -> T:
        """Remove one copy of the smallest element and return it.

        The emptiness check is an ``assert`` (AssertionError); it is skipped
        when Python runs with ``-O``.
        """
        assert self
        self._len -= 1
        node = self.root
        pnode = None
        while node.left is not None:
            pnode = node
            node = node.left
        if node.val > 1:
            node.val -= 1
            return node.key
        self._len_elm -= 1
        res = node.key
        # The minimum has no left child, so its right child takes its place.
        if pnode is None:
            self.root = self.root.right
        else:
            pnode.left = node.right
        return res

    def pop_max(self) -> T:
        """Remove one copy of the largest element and return it.

        The emptiness check is an ``assert`` (AssertionError); it is skipped
        when Python runs with ``-O``.
        """
        assert self, "IndexError"
        self._len -= 1
        node = self.root
        pnode = None
        while node.right is not None:
            pnode = node
            node = node.right
        if node.val > 1:
            node.val -= 1
            return node.key
        self._len_elm -= 1
        res = node.key
        # The maximum has no right child, so its left child takes its place.
        if pnode is None:
            self.root = self.root.left
        else:
            pnode.right = node.left
        return res

    def clear(self) -> None:
        """Remove all elements and reset both size counters."""
        self.root = None
        self._len = 0
        self._len_elm = 0

    def __iter__(self):
        """Reset the internal cursor and return ``self``.

        The cursor is stored on the instance, so nested iterations over the
        same multiset share (and disturb) one another.
        """
        self._it = self.get_min()
        self._cnt = 1
        return self

    def __next__(self):
        """Return the next element of the iteration started by ``__iter__``."""
        if self._it is None:
            raise StopIteration
        res = self._it
        # Emit the current key once per copy, then advance to the next larger key.
        if self._cnt == self.count(self._it):
            self._it = self.gt(self._it)
            self._cnt = 1
        else:
            self._cnt += 1
        return res

    def __contains__(self, key: T):
        return BSTMultisetNodeBase[T, TreapMultiset.Node].contains(self.root, key)

    def __bool__(self):
        return self.root is not None

    def __len__(self):
        return self._len

    def __str__(self):
        return "{" + ", ".join(map(str, self.tolist())) + "}"

    def __repr__(self):
        return f"TreapMultiset({self.tolist()})"
仕様 (specification)
- class TreapMultiset(a: Iterable[T] = [])
  Bases: OrderedMultisetInterface, Generic[T]