1import sys
2from typing import Generic, Iterable, TypeVar, Optional
3
T = TypeVar("T")


class SplayTreeMultiset2(Generic[T]):
    """Ordered multiset backed by a splay tree.

    Each distinct key is stored once, in a node that also records its
    multiplicity (``Node.val``). Every lookup splays the last node it visited
    to the root, which is the standard splay-tree access pattern. Keys must be
    mutually comparable with ``<`` and ``==``.
    """

    class Node:
        """Tree node holding one distinct key and its multiplicity ``val``."""

        __slots__ = ("key", "val", "left", "right")

        def __init__(self, key: T, val: int):
            self.key = key
            self.val = val
            self.left = None
            self.right = None

        def __str__(self):
            if self.left is None and self.right is None:
                return f"key:{self.key, self.val}\n"
            return (
                f"key:{self.key, self.val},\n left:{self.left},\n right:{self.right}\n"
            )

    def __init__(self, a: Iterable[T] = ()):
        """Create a multiset containing every element of *a* (duplicates kept)."""
        self.node = None
        self._len = 0  # total number of elements, counting duplicates
        self._len_elm = 0  # number of distinct keys (== number of nodes)
        self._build(a)

    def _build(self, a: Iterable[T]) -> None:
        """Populate the tree from *a* as a perfectly balanced BST.

        Intended for an empty instance (it is called from ``__init__``).
        Does nothing when *a* is empty.
        """
        items = sorted(a)
        if not items:
            return
        keys, counts = self._rle(items)
        Node = SplayTreeMultiset2.Node

        def build_range(lo: int, hi: int) -> "SplayTreeMultiset2.Node":
            # Balanced subtree over keys[lo:hi]; requires lo < hi.
            mid = (lo + hi) >> 1
            node = Node(keys[mid], counts[mid])
            if lo != mid:
                node.left = build_range(lo, mid)
            if mid + 1 != hi:
                node.right = build_range(mid + 1, hi)
            return node

        self._len = len(items)
        self._len_elm = len(keys)
        self.node = build_range(0, len(keys))

    def _rle(self, a: list[T]) -> tuple[list[T], list[int]]:
        """Run-length encode *a*: return (keys, counts) of runs of equal items."""
        keys: list[T] = []
        counts: list[int] = []
        for e in a:
            if keys and e == keys[-1]:
                counts[-1] += 1
            else:
                keys.append(e)
                counts.append(1)
        return keys, counts

    def _splay(self, path: list[Node], di: int) -> Node:
        """Rotate the target node up to the position currently held by ``path[0]``.

        ``path`` lists the target's ancestors from the subtree root
        (``path[0]``) downwards. The target is the child of ``path[-1]``.
        ``di`` records the step taken at each ancestor, one bit per step,
        with the most recent step in the lowest bit (1 = went left,
        0 = went right). ``path`` must be non-empty and is consumed.
        Returns the new subtree root (the target); the caller must store it
        wherever ``path[0]`` used to hang.
        """
        # Lift the target two levels at a time (zig-zig / zig-zag).
        for _ in range(len(path) >> 1):
            node = path.pop()  # parent of the target
            pnode = path.pop()  # grandparent of the target
            if (di & 1) == ((di >> 1) & 1):
                # zig-zig: both steps went the same direction.
                if di & 1 == 1:
                    tmp = node.left
                    node.left = tmp.right
                    tmp.right = node
                    pnode.left = node.right
                    node.right = pnode
                else:
                    tmp = node.right
                    node.right = tmp.left
                    tmp.left = node
                    pnode.right = node.left
                    node.left = pnode
            else:
                # zig-zag: the two steps went opposite directions.
                if di & 1 == 1:
                    tmp = node.left
                    node.left = tmp.right
                    pnode.right = tmp.left
                    tmp.right = node
                    tmp.left = pnode
                else:
                    tmp = node.right
                    node.right = tmp.left
                    pnode.left = tmp.right
                    tmp.left = node
                    tmp.right = pnode
            if not path:
                return tmp
            # Re-attach the lifted target under the next remaining ancestor,
            # on the side recorded by the next pair of direction bits.
            di >>= 2
            if di & 1 == 1:
                path[-1].left = tmp
            else:
                path[-1].right = tmp
        # Odd path length: a single rotation (zig) at path[0] finishes the job.
        gnode = path[0]
        if di & 1 == 1:
            node = gnode.left
            gnode.left = node.right
            node.right = gnode
        else:
            node = gnode.right
            gnode.right = node.left
            node.left = gnode
        return node

    def _set_search_splay(self, key: T) -> None:
        """Splay the node holding *key* to the root.

        If *key* is absent, splay the last node visited by the search instead.
        In a BST that node is *key*'s in-order predecessor or successor.
        """
        node = self.node
        if node is None or node.key == key:
            return
        path = []
        di = 0
        while True:
            if node.key == key:
                break
            elif key < node.key:
                if node.left is None:
                    break
                path.append(node)
                di <<= 1
                di |= 1
                node = node.left
            else:
                if node.right is None:
                    break
                path.append(node)
                di <<= 1
                node = node.right
        if path:
            self.node = self._splay(path, di)

    def _get_min_splay(self, node: Optional[Node]) -> Optional[Node]:
        """Splay the minimum of the subtree rooted at *node*; return the new root."""
        if node is None or node.left is None:
            return node
        path = []
        while node.left is not None:
            path.append(node)
            node = node.left
        # Every step went left, so all direction bits are 1.
        return self._splay(path, (1 << len(path)) - 1)

    def _get_max_splay(self, node: Optional[Node]) -> Optional[Node]:
        """Splay the maximum of the subtree rooted at *node*; return the new root."""
        if node is None or node.right is None:
            return node
        path = []
        while node.right is not None:
            path.append(node)
            node = node.right
        # Every step went right, so all direction bits are 0.
        return self._splay(path, 0)

    def _iter_nodes(self):
        """Yield nodes in ascending key order without splaying.

        Uses an explicit stack, so deep (degenerate) trees cannot hit
        Python's recursion limit.
        """
        stack = []
        node = self.node
        while stack or node is not None:
            while node is not None:
                stack.append(node)
                node = node.left
            node = stack.pop()
            yield node
            node = node.right

    def add(self, key: T, val: int = 1) -> None:
        """Insert *val* copies of *key*; the key's node ends up at the root."""
        self._len += val
        if self.node is None:
            self._len_elm += 1
            self.node = SplayTreeMultiset2.Node(key, val)
            return
        self._set_search_splay(key)
        if self.node.key == key:
            self.node.val += val
            return
        self._len_elm += 1
        node = SplayTreeMultiset2.Node(key, val)
        # The root is now key's predecessor or successor, so the tree splits
        # cleanly around it with the new node on top.
        if key < self.node.key:
            node.left = self.node.left
            node.right = self.node
            self.node.left = None
        else:
            node.left = self.node
            node.right = self.node.right
            self.node.right = None
        self.node = node

    def discard(self, key: T, val: int = 1) -> bool:
        """Remove *val* copies of *key*.

        If the key occurs *val* times or fewer, the key is removed entirely.
        Returns False if *key* was not present, True otherwise.
        """
        if self.node is None:
            return False
        self._set_search_splay(key)
        if self.node.key != key:
            return False
        if self.node.val > val:
            self.node.val -= val
            self._len -= val
            return True
        self._len -= self.node.val
        self._len_elm -= 1
        if self.node.left is None:
            self.node = self.node.right
        elif self.node.right is None:
            self.node = self.node.left
        else:
            # The minimum of the right subtree becomes the new root; after
            # splaying it has no left child, so the old left subtree fits there.
            node = self._get_min_splay(self.node.right)
            node.left = self.node.left
            self.node = node
        return True

    def discard_all(self, key: T) -> bool:
        """Remove every copy of *key*; return whether *key* was present."""
        return self.discard(key, self.count(key))

    def count(self, key: T) -> int:
        """Return how many copies of *key* are stored (0 if absent)."""
        if self.node is None:
            return 0
        self._set_search_splay(key)
        return self.node.val if self.node.key == key else 0

    def le(self, key: T) -> Optional[T]:
        """Return the largest stored key ``<= key``, or None if there is none."""
        node = self.node
        if node is None:
            return None
        path = []
        di = 0
        res = None
        while True:
            if node.key == key:
                res = node.key
                break
            elif key < node.key:
                if node.left is None:
                    break
                path.append(node)
                di <<= 1
                di |= 1
                node = node.left
            else:
                res = node.key
                if node.right is None:
                    break
                path.append(node)
                di <<= 1
                node = node.right
        if path:
            self.node = self._splay(path, di)
        return res

    def lt(self, key: T) -> Optional[T]:
        """Return the largest stored key ``< key``, or None if there is none."""
        node = self.node
        if node is None:
            return None
        path = []
        di = 0
        res = None
        while True:
            if key <= node.key:
                if node.left is None:
                    break
                path.append(node)
                di <<= 1
                di |= 1
                node = node.left
            else:
                res = node.key
                if node.right is None:
                    break
                path.append(node)
                di <<= 1
                node = node.right
        if path:
            self.node = self._splay(path, di)
        return res

    def ge(self, key: T) -> Optional[T]:
        """Return the smallest stored key ``>= key``, or None if there is none."""
        node = self.node
        if node is None:
            return None
        path = []
        di = 0
        res = None
        while True:
            if node.key == key:
                res = node.key
                break
            elif key < node.key:
                res = node.key
                if node.left is None:
                    break
                path.append(node)
                di <<= 1
                di |= 1
                node = node.left
            else:
                if node.right is None:
                    break
                path.append(node)
                di <<= 1
                node = node.right
        if path:
            self.node = self._splay(path, di)
        return res

    def gt(self, key: T) -> Optional[T]:
        """Return the smallest stored key ``> key``, or None if there is none."""
        node = self.node
        if node is None:
            return None
        path = []
        di = 0
        res = None
        while True:
            if key < node.key:
                res = node.key
                if node.left is None:
                    break
                path.append(node)
                di <<= 1
                di |= 1
                node = node.left
            else:
                if node.right is None:
                    break
                path.append(node)
                di <<= 1
                node = node.right
        if path:
            self.node = self._splay(path, di)
        return res

    def pop_max(self) -> T:
        """Remove one copy of the largest key and return it.

        Raises:
            IndexError: if the multiset is empty.
        """
        if self.node is None:
            raise IndexError("pop from an empty SplayTreeMultiset2")
        self.node = self._get_max_splay(self.node)
        res = self.node.key
        self.discard(res)
        return res

    def pop_min(self) -> T:
        """Remove one copy of the smallest key and return it.

        Raises:
            IndexError: if the multiset is empty.
        """
        if self.node is None:
            raise IndexError("pop from an empty SplayTreeMultiset2")
        self.node = self._get_min_splay(self.node)
        res = self.node.key
        self.discard(res)
        return res

    def get_min(self) -> Optional[T]:
        """Return the smallest key, or None if the multiset is empty."""
        if self.node is None:
            return None
        self.node = self._get_min_splay(self.node)
        return self.node.key

    def get_max(self) -> Optional[T]:
        """Return the largest key, or None if the multiset is empty."""
        if self.node is None:
            return None
        self.node = self._get_max_splay(self.node)
        return self.node.key

    def tolist(self) -> list[T]:
        """Return all elements in ascending order, with duplicates repeated."""
        a: list[T] = []
        for node in self._iter_nodes():
            a.extend([node.key] * node.val)
        return a

    def tolist_items(self) -> list[tuple[T, int]]:
        """Return ``(key, count)`` pairs in ascending key order."""
        return [(node.key, node.val) for node in self._iter_nodes()]

    def len_elm(self) -> int:
        """Return the number of distinct keys."""
        return self._len_elm

    def clear(self) -> None:
        """Remove every element, resetting both size counters."""
        self.node = None
        self._len = 0
        self._len_elm = 0

    def __getitem__(self, k):
        """Return the first or last element.

        Only the first and last elements are supported:
        ``0``/``-len`` give the minimum, ``-1``/``len - 1`` the maximum.

        Raises:
            IndexError: for any other index, or when the multiset is empty.
        """
        if self.node is not None:
            if k == -1 or k == self._len - 1:
                return self.get_max()
            if k == 0 or k == -self._len:
                return self.get_min()
        raise IndexError(
            "SplayTreeMultiset2 index out of range "
            "(only the first and last elements are supported)"
        )

    def __contains__(self, key: T) -> bool:
        self._set_search_splay(key)
        return self.node is not None and self.node.key == key

    def __len__(self):
        """Total number of elements, counting duplicates."""
        return self._len

    def __bool__(self):
        return self.node is not None

    def __str__(self):
        return "{" + ", ".join(map(str, self.tolist())) + "}"

    def __repr__(self):
        return f"SplayTreeMultiset2({self.tolist()})"