1# from titan_pylib.data_structures.splay_tree.splay_tree_multiset2 import SplayTreeMultiset2
2import sys
3from typing import Generic, Iterable, TypeVar, Optional
4
T = TypeVar("T")


class SplayTreeMultiset2(Generic[T]):
    """Ordered multiset backed by a splay tree with run-length-encoded nodes.

    Every node stores one distinct key together with its multiplicity
    (``val``). The tree therefore has ``len_elm()`` nodes, while ``len()``
    reports the total number of stored elements including duplicates.

    Most queries (``count``, ``le``, ``in``, ...) splay the last visited node
    to the root, so even "read-only" methods change the tree's shape.
    """

    class Node:
        """A tree node holding one distinct key and its multiplicity."""

        __slots__ = ("key", "val", "left", "right")

        def __init__(self, key: T, val: int):
            self.key = key
            self.val = val  # number of copies of ``key`` in the multiset
            self.left = None
            self.right = None

        def __str__(self):
            if self.left is None and self.right is None:
                return f"key:{self.key, self.val}\n"
            return (
                f"key:{self.key, self.val},\n left:{self.left},\n right:{self.right}\n"
            )

    def __init__(self, a: Iterable[T] = ()):
        """Create a multiset holding every element of ``a`` (duplicates kept).

        ``a`` may be any iterable and is consumed once.
        """
        self.node = None  # root of the splay tree (None when empty)
        self._len = 0  # total element count, counting multiplicity
        self._len_elm = 0  # number of distinct keys (= number of nodes)
        self._build(a)

    def _build(self, a: Iterable[T]) -> None:
        """Replace the tree with a perfectly balanced one built from ``a``."""
        Node = SplayTreeMultiset2.Node
        a = sorted(a)
        self._len = len(a)
        if not a:
            self.node = None
            self._len_elm = 0
            return
        key, val = self._rle(a)

        def build(lo: int, hi: int) -> "SplayTreeMultiset2.Node":
            # Builds the subtree for key[lo:hi]; recursion depth is O(log n).
            mid = (lo + hi) >> 1
            node = Node(key[mid], val[mid])
            if lo != mid:
                node.left = build(lo, mid)
            if mid + 1 != hi:
                node.right = build(mid + 1, hi)
            return node

        self._len_elm = len(key)
        self.node = build(0, len(key))

    def _rle(self, a: list[T]) -> tuple[list[T], list[int]]:
        """Run-length encode ``a`` into (distinct keys, their counts).

        Adjacent equal elements are merged, so ``a`` should be sorted.
        An empty ``a`` yields two empty lists.
        """
        keys: list[T] = []
        counts: list[int] = []
        for e in a:
            if keys and e == keys[-1]:
                counts[-1] += 1
            else:
                keys.append(e)
                counts.append(1)
        return keys, counts

    def _splay(self, path: list[Node], di: int) -> Node:
        """Rotate the target node up to the root of ``path[0]``'s subtree.

        ``path`` holds the target's ancestors from the top downwards (the
        target itself is not included) and is mutated. Bit ``i`` of ``di``
        (counting from the least significant bit) records the step leaving
        ``path[-1 - i]``: 1 means it went to the left child, 0 to the right.
        Returns the target, which is the new subtree root.
        """
        for _ in range(len(path) >> 1):
            node = path.pop()  # parent of the target
            pnode = path.pop()  # grandparent of the target
            if di & 1 == di >> 1 & 1:
                # zig-zig: both steps went the same direction
                if di & 1 == 1:
                    tmp = node.left
                    node.left = tmp.right
                    tmp.right = node
                    pnode.left = node.right
                    node.right = pnode
                else:
                    tmp = node.right
                    node.right = tmp.left
                    tmp.left = node
                    pnode.right = node.left
                    node.left = pnode
            else:
                # zig-zag: the two steps went opposite directions
                if di & 1 == 1:
                    tmp = node.left
                    node.left = tmp.right
                    pnode.right = tmp.left
                    tmp.right = node
                    tmp.left = pnode
                else:
                    tmp = node.right
                    node.right = tmp.left
                    pnode.left = tmp.right
                    tmp.left = node
                    tmp.right = pnode
            if not path:
                return tmp
            di >>= 2
            # Reattach the rotated subtree to the next remaining ancestor.
            if di & 1 == 1:
                path[-1].left = tmp
            else:
                path[-1].right = tmp
        # Odd path length: a single rotation (zig) at the top finishes the job.
        gnode = path[0]
        if di & 1 == 1:
            node = gnode.left
            gnode.left = node.right
            node.right = gnode
        else:
            node = gnode.right
            gnode.right = node.left
            node.left = gnode
        return node

    def _set_search_splay(self, key: T) -> None:
        """Search for ``key`` and splay the last visited node to the root.

        Afterwards the root holds ``key`` if present; otherwise it holds a
        key adjacent to ``key`` in sorted order. Does nothing on an empty tree.
        """
        node = self.node
        if node is None or node.key == key:
            return
        path = []
        di = 0
        while True:
            if node.key == key:
                break
            elif key < node.key:
                if node.left is None:
                    break
                path.append(node)
                di <<= 1
                di |= 1
                node = node.left
            else:
                if node.right is None:
                    break
                path.append(node)
                di <<= 1
                node = node.right
        if path:
            self.node = self._splay(path, di)

    def _get_min_splay(self, node: Node) -> Node:
        """Splay the minimum of ``node``'s subtree to its top and return it."""
        if node is None or node.left is None:
            return node
        path = []
        while node.left is not None:
            path.append(node)
            node = node.left
        return self._splay(path, (1 << len(path)) - 1)

    def _get_max_splay(self, node: Node) -> Node:
        """Splay the maximum of ``node``'s subtree to its top and return it."""
        if node is None or node.right is None:
            return node
        path = []
        while node.right is not None:
            path.append(node)
            node = node.right
        return self._splay(path, 0)

    def add(self, key: T, val: int = 1) -> None:
        """Insert ``val`` copies of ``key``."""
        self._len += val
        if self.node is None:
            self._len_elm += 1
            self.node = SplayTreeMultiset2.Node(key, val)
            return
        self._set_search_splay(key)
        if self.node.key == key:
            self.node.val += val
            return
        self._len_elm += 1
        node = SplayTreeMultiset2.Node(key, val)
        # The root is now a neighbour of ``key``; split it around the new node.
        if key < self.node.key:
            node.left = self.node.left
            node.right = self.node
            self.node.left = None
        else:
            node.left = self.node
            node.right = self.node.right
            self.node.right = None
        self.node = node

    def discard(self, key: T, val: int = 1) -> bool:
        """Remove up to ``val`` copies of ``key``.

        Returns:
            True if ``key`` was present, False otherwise.
        """
        if self.node is None:
            return False
        self._set_search_splay(key)
        if self.node.key != key:
            return False
        if self.node.val > val:
            self.node.val -= val
            self._len -= val
            return True
        # Removing every remaining copy: unlink the root node.
        self._len -= self.node.val
        self._len_elm -= 1
        if self.node.left is None:
            self.node = self.node.right
        elif self.node.right is None:
            self.node = self.node.left
        else:
            # The min of the right subtree has no left child once splayed up.
            node = self._get_min_splay(self.node.right)
            node.left = self.node.left
            self.node = node
        return True

    def discard_all(self, key: T) -> bool:
        """Remove every copy of ``key``; return True if ``key`` was present."""
        return self.discard(key, self.count(key))

    def count(self, key: T) -> int:
        """Return the multiplicity of ``key`` (0 if absent)."""
        if self.node is None:
            return 0
        self._set_search_splay(key)
        return self.node.val if self.node.key == key else 0

    def le(self, key: T) -> Optional[T]:
        """Return the largest element ``<= key``, or None if there is none."""
        node = self.node
        if node is None:
            return None
        path = []
        di = 0
        res = None
        while True:
            if node.key == key:
                res = key
                break
            elif key < node.key:
                if node.left is None:
                    break
                path.append(node)
                di <<= 1
                di |= 1
                node = node.left
            else:
                res = node.key
                if node.right is None:
                    break
                path.append(node)
                di <<= 1
                node = node.right
        if path:
            self.node = self._splay(path, di)
        return res

    def lt(self, key: T) -> Optional[T]:
        """Return the largest element ``< key``, or None if there is none."""
        node = self.node
        if node is None:
            return None
        path = []
        di = 0
        res = None
        while True:
            if key <= node.key:
                if node.left is None:
                    break
                path.append(node)
                di <<= 1
                di |= 1
                node = node.left
            else:
                res = node.key
                if node.right is None:
                    break
                path.append(node)
                di <<= 1
                node = node.right
        if path:
            self.node = self._splay(path, di)
        return res

    def ge(self, key: T) -> Optional[T]:
        """Return the smallest element ``>= key``, or None if there is none."""
        node = self.node
        if node is None:
            return None
        path = []
        di = 0
        res = None
        while True:
            if node.key == key:
                res = node.key
                break
            elif key < node.key:
                res = node.key
                if node.left is None:
                    break
                path.append(node)
                di <<= 1
                di |= 1
                node = node.left
            else:
                if node.right is None:
                    break
                path.append(node)
                di <<= 1
                node = node.right
        if path:
            self.node = self._splay(path, di)
        return res

    def gt(self, key: T) -> Optional[T]:
        """Return the smallest element ``> key``, or None if there is none."""
        node = self.node
        if node is None:
            return None
        path = []
        di = 0
        res = None
        while True:
            if key < node.key:
                res = node.key
                if node.left is None:
                    break
                path.append(node)
                di <<= 1
                di |= 1
                node = node.left
            else:
                if node.right is None:
                    break
                path.append(node)
                di <<= 1
                node = node.right
        if path:
            self.node = self._splay(path, di)
        return res

    def pop_max(self) -> T:
        """Remove one copy of the largest element and return it.

        Raises:
            IndexError: if the multiset is empty.
        """
        if self.node is None:
            raise IndexError("pop from an empty SplayTreeMultiset2")
        self.node = self._get_max_splay(self.node)
        res = self.node.key
        self.discard(res)
        return res

    def pop_min(self) -> T:
        """Remove one copy of the smallest element and return it.

        Raises:
            IndexError: if the multiset is empty.
        """
        if self.node is None:
            raise IndexError("pop from an empty SplayTreeMultiset2")
        self.node = self._get_min_splay(self.node)
        res = self.node.key
        self.discard(res)
        return res

    def get_min(self) -> Optional[T]:
        """Return the smallest element, or None if empty."""
        if self.node is None:
            return
        self.node = self._get_min_splay(self.node)
        return self.node.key

    def get_max(self) -> Optional[T]:
        """Return the largest element, or None if empty."""
        if self.node is None:
            return
        self.node = self._get_max_splay(self.node)
        return self.node.key

    def _iter_nodes(self):
        """Yield nodes in ascending key order without splaying.

        Uses an explicit stack: a splay tree can degenerate into a path as
        deep as ``len_elm()``, which would overflow Python's recursion limit.
        The tree must not be modified while this generator is being consumed.
        """
        stack = []
        node = self.node
        while stack or node is not None:
            while node is not None:
                stack.append(node)
                node = node.left
            node = stack.pop()
            yield node
            node = node.right

    def tolist(self) -> list[T]:
        """Return all elements in ascending order, duplicates repeated."""
        a = []
        for node in self._iter_nodes():
            a.extend([node.key] * node.val)
        return a

    def tolist_items(self) -> list[tuple[T, int]]:
        """Return ``(key, multiplicity)`` pairs in ascending key order."""
        return [(node.key, node.val) for node in self._iter_nodes()]

    def len_elm(self) -> int:
        """Return the number of distinct keys."""
        return self._len_elm

    def clear(self) -> None:
        """Remove every element."""
        self.node = None
        self._len = 0
        self._len_elm = 0

    def __getitem__(self, k):
        """Return the element at sorted index ``k``; only the ends are supported.

        ``0`` / ``-len`` give the minimum and ``-1`` / ``len - 1`` the maximum.

        Raises:
            IndexError: for any other index, or for any index on an empty set.
        """
        if self.node is not None:
            if k == -1 or k == self._len - 1:
                return self.get_max()
            if k == 0 or k == -self._len:
                return self.get_min()
        raise IndexError("SplayTreeMultiset2 index out of range")

    def __iter__(self) -> Iterable[T]:
        # Iterate over a snapshot: lookups such as ``in``/``count`` re-splay the
        # tree, so a lazy traversal would break if the loop body queried it.
        return iter(self.tolist())

    def __contains__(self, key: T) -> bool:
        self._set_search_splay(key)
        return self.node is not None and self.node.key == key

    def __len__(self):
        return self._len

    def __bool__(self):
        return self.node is not None

    def __str__(self):
        return "{" + ", ".join(map(str, self.tolist())) + "}"

    def __repr__(self):
        return f"SplayTreeMultiset2({self.tolist()})"