1from titan_pylib.my_class.ordered_multiset_interface import OrderedMultisetInterface
2from titan_pylib.my_class.supports_less_than import SupportsLessThan
3from titan_pylib.data_structures.bst_base.bst_multiset_node_base import (
4 BSTMultisetNodeBase,
5)
6from typing import Generic, Iterable, TypeVar, Optional, Sequence
7
# Key type for TreapMultiset; bounded by SupportsLessThan because keys are
# ordered with ``<`` when walking the tree.
T = TypeVar("T", bound=SupportsLessThan)
9
10
class TreapMultiset(OrderedMultisetInterface, Generic[T]):
    """Ordered multiset backed by a treap (randomized binary search tree).

    Each distinct key is stored in exactly one node together with its
    multiplicity (``Node.val``). Node priorities are kept in min-heap order:
    a child is rotated above its parent when its priority is smaller
    (see ``add``), and deletion lifts the child with the smaller priority
    (see ``discard``).
    """

    class Random:
        """Xorshift128 pseudo-random generator whose state is shared class-wide."""

        _x, _y, _z, _w = 123456789, 362436069, 521288629, 88675123

        @classmethod
        def random(cls) -> int:
            """Advance the shared state and return the next 32-bit value."""
            # Parentheses only make the existing precedence explicit:
            # ``&`` binds tighter than ``^`` in Python.
            t = cls._x ^ ((cls._x << 11) & 0xFFFFFFFF)
            cls._x, cls._y, cls._z = cls._y, cls._z, cls._w
            cls._w = (cls._w ^ (cls._w >> 19)) ^ ((t ^ (t >> 8)) & 0xFFFFFFFF)
            return cls._w

    class Node:
        """Treap node holding a key, its multiplicity and a heap priority."""

        def __init__(self, key: T, val: int = 1, priority: int = -1):
            self.key: T = key
            self.val: int = val  # multiplicity of ``key``
            self.left: Optional["TreapMultiset.Node"] = None
            self.right: Optional["TreapMultiset.Node"] = None
            # ``priority == -1`` is a sentinel meaning "draw a random priority".
            self.priority: int = (
                TreapMultiset.Random.random() if priority == -1 else priority
            )

        def __str__(self):
            if self.left is None and self.right is None:
                return f"key:{self.key, self.priority}\n"
            return f"key:{self.key, self.priority},\n left:{self.left},\n right:{self.right}\n"

    def __init__(self, a: Iterable[T] = ()):
        """Create a multiset, optionally bulk-loading the items of ``a``.

        Args:
            a: Initial items; duplicates are allowed. Non-sequence iterables
                are materialized into a list first.
        """
        self.root: Optional["TreapMultiset.Node"] = None
        self._len: int = 0  # total number of items, counting duplicates
        self._len_elm: int = 0  # number of distinct keys (= number of nodes)
        if not isinstance(a, Sequence):
            a = list(a)
        if a:
            self._build(a)

    def _build(self, a: Iterable[T]) -> None:
        """Replace the tree with a height-balanced treap built from ``a``."""
        Node = TreapMultiset.Node

        def rec(l: int, r: int) -> "TreapMultiset.Node":
            # Build the subtree for key[l:r], rooted at its middle element.
            mid = (l + r) >> 1
            # Priorities are drawn from an ascending sequence in pre-order,
            # so every node gets a smaller priority than all of its
            # descendants. This is the min-heap order assumed by add/discard.
            node = Node(key[mid], val[mid], next(prio))
            if l != mid:
                node.left = rec(l, mid)
            if mid + 1 != r:
                node.right = rec(mid + 1, r)
            return node

        a = sorted(a)
        # Run-length encode the sorted items into distinct keys and counts.
        key, val = BSTMultisetNodeBase[T, TreapMultiset.Node]._rle(a)
        self._len = len(a)
        self._len_elm = len(key)
        prio = iter(
            sorted(TreapMultiset.Random.random() for _ in range(self._len_elm))
        )
        self.root = rec(0, len(key))

    def _rotate_L(self, node: Node) -> Node:
        """Lift ``node.left`` above ``node`` and return the new subtree root.

        The caller must re-attach the returned node to ``node``'s old parent.
        """
        u = node.left
        node.left = u.right
        u.right = node
        return u

    def _rotate_R(self, node: Node) -> Node:
        """Lift ``node.right`` above ``node`` and return the new subtree root.

        The caller must re-attach the returned node to ``node``'s old parent.
        """
        u = node.right
        node.right = u.left
        u.left = node
        return u

    def add(self, key: T, val: int = 1) -> None:
        """Insert ``val`` occurrences of ``key``."""
        self._len += val
        if self.root is None:
            self.root = TreapMultiset.Node(key, val)
            self._len_elm += 1
            return
        node = self.root
        path = []
        # Bit-encoded search path: one bit per step, 1 = went left. The most
        # recent step is the least significant bit.
        di = 0
        while node is not None:
            if key == node.key:
                node.val += val
                return
            path.append(node)
            di <<= 1
            if key < node.key:
                di |= 1
                node = node.left
            else:
                node = node.right
        self._len_elm += 1
        if di & 1:
            path[-1].left = TreapMultiset.Node(key, val)
        else:
            path[-1].right = TreapMultiset.Node(key, val)
        # Walk back up, rotating the new node above any ancestor whose
        # priority is larger (restores min-heap order).
        while path:
            new_node = None
            node = path.pop()
            if di & 1:
                if node.left.priority < node.priority:
                    new_node = self._rotate_L(node)
            else:
                if node.right.priority < node.priority:
                    new_node = self._rotate_R(node)
            di >>= 1
            if new_node is not None:
                if path:
                    if di & 1:
                        path[-1].left = new_node
                    else:
                        path[-1].right = new_node
                else:
                    self.root = new_node
        # Note: the size was already updated at the top by ``+= val``. The
        # former extra ``self._len += 1`` here double-counted new keys.

    def discard(self, key: T, val: int = 1) -> bool:
        """Remove up to ``val`` occurrences of ``key``.

        Returns:
            True if ``key`` was present, False otherwise.
        """
        node = self.root
        pnode = None
        while node is not None:
            if key == node.key:
                break
            pnode = node
            node = node.left if key < node.key else node.right
        else:
            return False
        self._len -= min(val, node.val)
        if node.val > val:
            node.val -= val
            return True
        self._len_elm -= 1
        # Rotate the doomed node downward, lifting the child with the smaller
        # priority, until it has at most one child. ``pnode`` tracks its parent.
        while node.left is not None and node.right is not None:
            if node.left.priority < node.right.priority:
                if pnode is None:
                    pnode = self._rotate_L(node)
                    self.root = pnode
                    continue
                new_node = self._rotate_L(node)
                if node.key < pnode.key:
                    pnode.left = new_node
                else:
                    pnode.right = new_node
            else:
                if pnode is None:
                    pnode = self._rotate_R(node)
                    self.root = pnode
                    continue
                new_node = self._rotate_R(node)
                if node.key < pnode.key:
                    pnode.left = new_node
                else:
                    pnode.right = new_node
            pnode = new_node
        # Splice out the node, which now has at most one child.
        child = node.right if node.left is None else node.left
        if pnode is None:
            self.root = child
        elif node.key < pnode.key:
            pnode.left = child
        else:
            pnode.right = child
        return True

    def discard_all(self, key: T) -> bool:
        """Remove every occurrence of ``key``; return whether it was present."""
        return self.discard(key, self.count(key))

    def remove(self, key: T, val: int = 1) -> None:
        """Like ``discard`` but raise ``KeyError`` if ``key`` is absent."""
        if self.discard(key, val):
            return
        raise KeyError(key)

    def count(self, key: T) -> int:
        """Return the multiplicity of ``key`` (0 if it is absent)."""
        # Plain BST search. The previous delegation omitted ``key`` and
        # raised TypeError.
        node = self.root
        while node is not None:
            if key == node.key:
                return node.val
            node = node.left if key < node.key else node.right
        return 0

    def le(self, key: T) -> Optional[T]:
        """Delegate to BSTMultisetNodeBase.le (presumably the largest key <= ``key``)."""
        return BSTMultisetNodeBase[T, TreapMultiset.Node].le(self.root, key)

    def lt(self, key: T) -> Optional[T]:
        """Delegate to BSTMultisetNodeBase.lt (presumably the largest key < ``key``)."""
        return BSTMultisetNodeBase[T, TreapMultiset.Node].lt(self.root, key)

    def ge(self, key: T) -> Optional[T]:
        """Delegate to BSTMultisetNodeBase.ge (presumably the smallest key >= ``key``)."""
        return BSTMultisetNodeBase[T, TreapMultiset.Node].ge(self.root, key)

    def gt(self, key: T) -> Optional[T]:
        """Delegate to BSTMultisetNodeBase.gt (presumably the smallest key > ``key``)."""
        return BSTMultisetNodeBase[T, TreapMultiset.Node].gt(self.root, key)

    def len_elm(self) -> int:
        """Return the number of distinct keys."""
        return self._len_elm

    def show(self) -> None:
        """Print the multiset as ``{key: count, ...}``."""
        print(
            "{" + ", ".join(map(lambda x: f"{x[0]}: {x[1]}", self.tolist_items())) + "}"
        )

    def tolist(self) -> list[T]:
        """Delegate to BSTMultisetNodeBase.tolist on the root."""
        return BSTMultisetNodeBase[T, TreapMultiset.Node].tolist(self.root)

    def tolist_items(self) -> list[tuple[T, int]]:
        """Delegate to BSTMultisetNodeBase.tolist_items on the root."""
        return BSTMultisetNodeBase[T, TreapMultiset.Node].tolist_items(self.root)

    def get_min(self) -> Optional[T]:
        """Delegate to BSTMultisetNodeBase.get_min on the root."""
        # Single subscription: the former ``Base[T, Node][T, Node]``
        # re-subscripted an already-parameterized alias and raised TypeError.
        return BSTMultisetNodeBase[T, TreapMultiset.Node].get_min(self.root)

    def get_max(self) -> Optional[T]:
        """Delegate to BSTMultisetNodeBase.get_max on the root."""
        return BSTMultisetNodeBase[T, TreapMultiset.Node].get_max(self.root)

    def pop_min(self) -> T:
        """Remove and return one occurrence of the smallest key.

        Raises:
            AssertionError: if the multiset is empty.
        """
        assert self, "IndexError"
        self._len -= 1
        node = self.root
        pnode = None
        while node.left is not None:
            pnode = node
            node = node.left
        if node.val > 1:
            node.val -= 1
            return node.key
        self._len_elm -= 1
        res = node.key
        # The leftmost node has no left child; replace it by its right subtree.
        if pnode is None:
            self.root = self.root.right
        else:
            pnode.left = node.right
        return res

    def pop_max(self) -> T:
        """Remove and return one occurrence of the largest key.

        Raises:
            AssertionError: if the multiset is empty.
        """
        assert self, "IndexError"
        self._len -= 1
        node = self.root
        pnode = None
        while node.right is not None:
            pnode = node
            node = node.right
        if node.val > 1:
            node.val -= 1
            return node.key
        self._len_elm -= 1
        res = node.key
        # The rightmost node has no right child; replace it by its left subtree.
        if pnode is None:
            self.root = self.root.left
        else:
            pnode.right = node.left
        return res

    def clear(self) -> None:
        """Remove all elements."""
        self.root = None
        # Reset the counters too; otherwise len() is stale after clear().
        self._len = 0
        self._len_elm = 0

    def __iter__(self):
        # NOTE: iteration state lives on the instance (``_it``/``_cnt``), so
        # two simultaneous iterations over the same multiset interfere.
        self._it = self.get_min()
        self._cnt = 1
        return self

    def __next__(self):
        # Yield the current key once per occurrence, then advance to the
        # next larger key.
        if self._it is None:
            raise StopIteration
        res = self._it
        if self._cnt == self.count(self._it):
            self._it = self.gt(self._it)
            self._cnt = 1
        else:
            self._cnt += 1
        return res

    def __contains__(self, key: T):
        return BSTMultisetNodeBase[T, TreapMultiset.Node].contains(self.root, key)

    def __bool__(self):
        return self.root is not None

    def __len__(self):
        # Total number of items, counting duplicates.
        return self._len

    def __str__(self):
        return "{" + ", ".join(map(str, self.tolist())) + "}"

    def __repr__(self):
        return f"TreapMultiset({self.tolist()})"