treap_set
ソースコード
from titan_pylib.data_structures.treap.treap_set import TreapSet
展開済みコード
1# from titan_pylib.data_structures.treap.treap_set import TreapSet
2# from titan_pylib.my_class.ordered_set_interface import OrderedSetInterface
3# from titan_pylib.my_class.supports_less_than import SupportsLessThan
4from typing import Protocol
5
6
class SupportsLessThan(Protocol):
    """Structural type for values that can be ordered with ``<``.

    Serves as the upper bound of the element type variable used by the
    ordered-set classes, so their keys are guaranteed to be comparable.
    """

    def __lt__(self, other) -> bool:
        ...
10from abc import ABC, abstractmethod
11from typing import Iterable, Optional, Iterator, TypeVar, Generic
12
T = TypeVar("T", bound=SupportsLessThan)


class OrderedSetInterface(ABC, Generic[T]):
    """Abstract interface for an ordered set of mutually comparable keys.

    Besides ordinary set operations, implementations answer
    predecessor/successor queries (``le``/``lt``/``ge``/``gt``) and expose
    the minimum and maximum key.
    """

    @abstractmethod
    def __init__(self, a: Iterable[T]) -> None:
        """Build the set from the elements of ``a``."""
        raise NotImplementedError

    @abstractmethod
    def add(self, key: T) -> bool:
        """Insert ``key``; return True if it was not already present."""
        raise NotImplementedError

    @abstractmethod
    def discard(self, key: T) -> bool:
        """Remove ``key`` if present; return True if it was removed."""
        raise NotImplementedError

    @abstractmethod
    def remove(self, key: T) -> None:
        """Remove ``key``; raise ``KeyError`` if it is not present."""
        raise NotImplementedError

    @abstractmethod
    def le(self, key: T) -> Optional[T]:
        """Return the largest stored key ``<= key``, or None if there is none."""
        raise NotImplementedError

    @abstractmethod
    def lt(self, key: T) -> Optional[T]:
        """Return the largest stored key ``< key``, or None if there is none."""
        raise NotImplementedError

    @abstractmethod
    def ge(self, key: T) -> Optional[T]:
        """Return the smallest stored key ``>= key``, or None if there is none."""
        raise NotImplementedError

    @abstractmethod
    def gt(self, key: T) -> Optional[T]:
        """Return the smallest stored key ``> key``, or None if there is none."""
        raise NotImplementedError

    @abstractmethod
    def get_max(self) -> Optional[T]:
        """Return the largest key, or None if the set is empty."""
        raise NotImplementedError

    @abstractmethod
    def get_min(self) -> Optional[T]:
        """Return the smallest key, or None if the set is empty."""
        raise NotImplementedError

    @abstractmethod
    def pop_max(self) -> T:
        """Remove and return the largest key."""
        raise NotImplementedError

    @abstractmethod
    def pop_min(self) -> T:
        """Remove and return the smallest key."""
        raise NotImplementedError

    @abstractmethod
    def clear(self) -> None:
        """Remove every key."""
        raise NotImplementedError

    @abstractmethod
    def tolist(self) -> list[T]:
        """Return the keys as a list in ascending order."""
        raise NotImplementedError

    @abstractmethod
    def __iter__(self) -> Iterator[T]:
        """Iterate over the keys in ascending order."""
        raise NotImplementedError

    @abstractmethod
    def __next__(self) -> T:
        """Return the next key of the iteration started by ``__iter__``."""
        raise NotImplementedError

    @abstractmethod
    def __contains__(self, key: T) -> bool:
        """Return True if ``key`` is stored."""
        raise NotImplementedError

    @abstractmethod
    def __len__(self) -> int:
        """Return the number of stored keys."""
        raise NotImplementedError

    @abstractmethod
    def __bool__(self) -> bool:
        """Return True if the set is non-empty."""
        raise NotImplementedError

    @abstractmethod
    def __str__(self) -> str:
        raise NotImplementedError

    @abstractmethod
    def __repr__(self) -> str:
        raise NotImplementedError
101# from titan_pylib.my_class.supports_less_than import SupportsLessThan
102# from titan_pylib.data_structures.bst_base.bst_set_node_base import BSTSetNodeBase
103from typing import TypeVar, Generic, Optional
104
T = TypeVar("T")
Node = TypeVar("Node")
# TODO: declare a Protocol for Node that requires ``key``, ``left`` and
# ``right`` (plus ``size`` for index/index_right/kth_elm).


class BSTSetNodeBase(Generic[T, Node]):
    """Stateless helpers shared by binary-search-tree set implementations.

    Every helper is a ``staticmethod`` that takes a subtree root. A node is
    any object with ``key``, ``left`` and ``right`` attributes, where a
    missing child is ``None``. ``index``, ``index_right`` and ``kth_elm``
    also read ``size`` (the subtree node count) from child nodes. The
    helpers are written for trees whose keys are distinct and ordered by
    ``<``.
    """

    @staticmethod
    def sort_unique(a: list[T]) -> list[T]:
        """Return a new ascending list of the distinct elements of ``a``.

        Sorting is skipped when ``a`` is already strictly increasing.
        ``a`` itself is never modified. An empty input yields ``[]``.
        """
        if not a:
            # The dedup loop below seeds its output with a[0].
            return []
        if not all(x < y for x, y in zip(a, a[1:])):
            a = sorted(a)
        res = [a[0]]
        for elm in a:
            if res[-1] == elm:
                continue
            res.append(elm)
        return res

    @staticmethod
    def contains(node: Node, key: T) -> bool:
        """Return True if ``key`` is stored in the subtree rooted at ``node``."""
        while node:
            if key == node.key:
                return True
            node = node.left if key < node.key else node.right
        return False

    @staticmethod
    def get_min(node: Node) -> Optional[T]:
        """Return the smallest key under ``node``, or None for an empty tree."""
        if not node:
            return None
        while node.left:
            node = node.left
        return node.key

    @staticmethod
    def get_max(node: Node) -> Optional[T]:
        """Return the largest key under ``node``, or None for an empty tree."""
        if not node:
            return None
        while node.right:
            node = node.right
        return node.key

    @staticmethod
    def le(node: Node, key: T) -> Optional[T]:
        """Return the largest stored key ``<= key``, or None if there is none."""
        res = None
        while node is not None:
            if key == node.key:
                res = key
                break
            if key < node.key:
                node = node.left
            else:
                # node.key < key: a candidate; look for a closer one on the right.
                res = node.key
                node = node.right
        return res

    @staticmethod
    def lt(node: Node, key: T) -> Optional[T]:
        """Return the largest stored key ``< key``, or None if there is none."""
        res = None
        while node is not None:
            if key <= node.key:
                node = node.left
            else:
                res = node.key
                node = node.right
        return res

    @staticmethod
    def ge(node: Node, key: T) -> Optional[T]:
        """Return the smallest stored key ``>= key``, or None if there is none."""
        res = None
        while node is not None:
            if key == node.key:
                res = key
                break
            if key < node.key:
                # node.key > key: a candidate; look for a closer one on the left.
                res = node.key
                node = node.left
            else:
                node = node.right
        return res

    @staticmethod
    def gt(node: Node, key: T) -> Optional[T]:
        """Return the smallest stored key ``> key``, or None if there is none."""
        res = None
        while node is not None:
            if key < node.key:
                res = node.key
                node = node.left
            else:
                node = node.right
        return res

    @staticmethod
    def index(node: Node, key: T) -> int:
        """Return the number of stored keys strictly less than ``key``."""
        k = 0
        while node is not None:
            if key == node.key:
                if node.left is not None:
                    k += node.left.size
                break
            if key < node.key:
                node = node.left
            else:
                # The whole left subtree and this node are smaller than key.
                k += 1 if node.left is None else node.left.size + 1
                node = node.right
        return k

    @staticmethod
    def index_right(node: Node, key: T) -> int:
        """Return the number of stored keys less than or equal to ``key``."""
        k = 0
        while node is not None:
            if key == node.key:
                k += 1 if node.left is None else node.left.size + 1
                break
            if key < node.key:
                node = node.left
            else:
                k += 1 if node.left is None else node.left.size + 1
                node = node.right
        return k

    @staticmethod
    def tolist(node: Node) -> list[T]:
        """Return all keys under ``node`` in ascending (in-order) order."""
        stack = []
        res = []
        while stack or node:
            if node:
                stack.append(node)
                node = node.left
            else:
                node = stack.pop()
                res.append(node.key)
                node = node.right
        return res

    @staticmethod
    def kth_elm(node: Node, k: int, _len: int) -> T:
        """Return the ``k``-th smallest key (0-indexed) under ``node``.

        A negative ``k`` counts from the end; ``_len`` must be the number of
        nodes in the tree for that adjustment to be correct.

        Raises:
            IndexError: if ``k`` is out of range.
        """
        if k < 0:
            k += _len
        while node is not None:
            t = 0 if node.left is None else node.left.size
            if t == k:
                return node.key
            if t > k:
                node = node.left
            else:
                # Skip the left subtree and this node.
                node = node.right
                k -= t + 1
        raise IndexError("kth_elm index out of range")
254from typing import Generic, Iterable, TypeVar, Optional
255
T = TypeVar("T", bound=SupportsLessThan)


class TreapSet(OrderedSetInterface, Generic[T]):
    """Ordered set implemented as a treap.

    The tree is a binary search tree on ``key`` and a min-heap on a random
    ``priority``: a node's priority is never greater than its children's.
    Insertions and deletions rotate nodes to restore that heap order, which
    is how randomness keeps the tree balanced. Priorities come from a
    fixed-seed xorshift generator whose state is shared by all instances,
    so they are predictable. Only set and multiset treap variants exist so
    far.
    """

    class Random:
        """Xorshift128 pseudo-random generator with class-level (shared) state."""

        _x, _y, _z, _w = 123456789, 362436069, 521288629, 88675123

        @classmethod
        def random(cls) -> int:
            """Advance the shared state and return the next 32-bit value."""
            t = (cls._x ^ ((cls._x << 11) & 0xFFFFFFFF)) & 0xFFFFFFFF
            cls._x, cls._y, cls._z = cls._y, cls._z, cls._w
            cls._w = (cls._w ^ (cls._w >> 19)) ^ (
                t ^ ((t >> 8)) & 0xFFFFFFFF
            ) & 0xFFFFFFFF
            return cls._w

    class Node:
        """Treap node: a key, two children and a heap priority."""

        def __init__(self, key: T, priority: int = -1):
            self.key: T = key
            self.left: Optional["TreapSet.Node"] = None
            self.right: Optional["TreapSet.Node"] = None
            # -1 is a sentinel meaning "draw a fresh random priority".
            self.priority: int = (
                TreapSet.Random.random() if priority == -1 else priority
            )

        def __str__(self):
            if self.left is None and self.right is None:
                return f"key:{self.key, self.priority}\n"
            return f"key:{self.key, self.priority},\n left:{self.left},\n right:{self.right}\n"

    def __init__(self, a: Iterable[T] = ()):
        """Build the set from the elements of ``a``; duplicates are dropped."""
        self.root: Optional["TreapSet.Node"] = None
        self._len: int = 0
        if not isinstance(a, list):
            a = list(a)
        if a:
            self._build(a)

    def _build(self, a: list[T]) -> None:
        """Replace the contents with the distinct elements of non-empty ``a``.

        The tree shape comes from recursive midpoint selection, so it is
        perfectly balanced. Priorities are then handed out breadth-first
        from an ascending list, so every parent's priority is <= its
        children's. That matches the min-heap order kept by ``add`` and
        ``discard``.
        """
        Node = TreapSet.Node

        def rec(l: int, r: int) -> TreapSet.Node:
            mid = (l + r) >> 1
            # Placeholder priority (not -1, so no random draw); the real
            # value is assigned breadth-first below.
            node = Node(a[mid], 0)
            if l != mid:
                node.left = rec(l, mid)
            if mid + 1 != r:
                node.right = rec(mid + 1, r)
            return node

        a = BSTSetNodeBase[T, TreapSet.Node].sort_unique(a)
        self._len = len(a)
        rand = sorted(TreapSet.Random.random() for _ in range(self._len))
        self.root = rec(0, self._len)
        # BFS visits every parent before its children, so taking priorities
        # from the ascending list in visit order yields a valid min-heap.
        queue = [self.root]
        i = 0
        while i < len(queue):
            node = queue[i]
            node.priority = rand[i]
            if node.left is not None:
                queue.append(node.left)
            if node.right is not None:
                queue.append(node.right)
            i += 1

    def _rotate_L(self, node: Node) -> Node:
        """Lift ``node.left`` into ``node``'s place; return the new subtree root."""
        u = node.left
        node.left = u.right
        u.right = node
        return u

    def _rotate_R(self, node: Node) -> Node:
        """Lift ``node.right`` into ``node``'s place; return the new subtree root."""
        u = node.right
        node.right = u.left
        u.left = node
        return u

    def add(self, key: T) -> bool:
        """Insert ``key``; return True if it was added, False if already present."""
        if not self.root:
            self.root = TreapSet.Node(key)
            self._len = 1
            return True
        node = self.root
        path = []
        # di records the branch taken at each step, one bit per step, with
        # the most recent step in the lowest bit; 1 means "went left".
        di = 0
        while node:
            if key == node.key:
                return False
            path.append(node)
            if key < node.key:
                di <<= 1
                di |= 1
                node = node.left
            else:
                di <<= 1
                node = node.right
        if di & 1:
            path[-1].left = TreapSet.Node(key)
        else:
            path[-1].right = TreapSet.Node(key)
        # Walk back up, rotating the new node above any ancestor with a
        # larger priority and re-linking the rotated subtree to its parent.
        while path:
            new_node = None
            node = path.pop()
            if di & 1:
                if node.left.priority < node.priority:
                    new_node = self._rotate_L(node)
            else:
                if node.right.priority < node.priority:
                    new_node = self._rotate_R(node)
            di >>= 1
            if new_node:
                if path:
                    if di & 1:
                        path[-1].left = new_node
                    else:
                        path[-1].right = new_node
                else:
                    self.root = new_node
        self._len += 1
        return True

    def discard(self, key: T) -> bool:
        """Remove ``key`` if present; return True if it was removed."""
        node = self.root
        pnode = None
        while node:
            if key == node.key:
                break
            pnode = node
            node = node.left if key < node.key else node.right
        else:
            return False
        self._len -= 1
        # Rotate the target down, always lifting the child with the smaller
        # priority, until the target has at most one child.
        while node.left and node.right:
            if node.left.priority < node.right.priority:
                if not pnode:
                    pnode = self._rotate_L(node)
                    self.root = pnode
                    continue
                new_node = self._rotate_L(node)
                if node.key < pnode.key:
                    pnode.left = new_node
                else:
                    pnode.right = new_node
            else:
                if not pnode:
                    pnode = self._rotate_R(node)
                    self.root = pnode
                    continue
                new_node = self._rotate_R(node)
                if node.key < pnode.key:
                    pnode.left = new_node
                else:
                    pnode.right = new_node
            # The lifted child is now the target's parent.
            pnode = new_node
        # Splice the target out by linking its only child (or None) upward.
        child = node.right if node.left is None else node.left
        if not pnode:
            self.root = child
        elif node.key < pnode.key:
            pnode.left = child
        else:
            pnode.right = child
        return True

    def remove(self, key: T) -> None:
        """Remove ``key``; raise ``KeyError`` if it is not present."""
        if self.discard(key):
            return
        raise KeyError(key)

    def le(self, key: T) -> Optional[T]:
        """Return the largest key ``<= key``, or None."""
        return BSTSetNodeBase[T, TreapSet.Node].le(self.root, key)

    def lt(self, key: T) -> Optional[T]:
        """Return the largest key ``< key``, or None."""
        return BSTSetNodeBase[T, TreapSet.Node].lt(self.root, key)

    def ge(self, key: T) -> Optional[T]:
        """Return the smallest key ``>= key``, or None."""
        return BSTSetNodeBase[T, TreapSet.Node].ge(self.root, key)

    def gt(self, key: T) -> Optional[T]:
        """Return the smallest key ``> key``, or None."""
        return BSTSetNodeBase[T, TreapSet.Node].gt(self.root, key)

    def get_min(self) -> Optional[T]:
        """Return the smallest key, or None if empty."""
        return BSTSetNodeBase[T, TreapSet.Node].get_min(self.root)

    def get_max(self) -> Optional[T]:
        """Return the largest key, or None if empty."""
        return BSTSetNodeBase[T, TreapSet.Node].get_max(self.root)

    def pop_min(self) -> T:
        """Remove and return the smallest key (asserts the set is non-empty)."""
        assert self.root, f"IndexError: pop_min() from Empty {self.__class__.__name__}."
        node = self.root
        pnode = None
        while node.left:
            pnode = node
            node = node.left
        self._len -= 1
        res = node.key
        # The minimum has no left child; splice its right subtree upward.
        if not pnode:
            self.root = self.root.right
        else:
            pnode.left = node.right
        return res

    def pop_max(self) -> T:
        """Remove and return the largest key (asserts the set is non-empty)."""
        assert self.root, f"IndexError: pop_max() from Empty {self.__class__.__name__}."
        node = self.root
        pnode = None
        while node.right:
            pnode = node
            node = node.right
        self._len -= 1
        res = node.key
        # The maximum has no right child; splice its left subtree upward.
        if not pnode:
            self.root = self.root.left
        else:
            pnode.right = node.left
        return res

    def clear(self) -> None:
        """Remove every key."""
        self.root = None
        # Keep the size in sync so len()/bool() report an empty set.
        self._len = 0

    def tolist(self) -> list[T]:
        """Return the keys in ascending order."""
        return BSTSetNodeBase[T, TreapSet.Node].tolist(self.root)

    def __iter__(self):
        """Start ascending iteration and return ``self`` as the iterator.

        NOTE: the cursor lives on the instance (``self._it``), so nested
        loops over the same set share and disturb one another's position.
        """
        self._it = self.get_min()
        return self

    def __next__(self):
        """Return the current key and advance the cursor to its successor."""
        if self._it is None:
            raise StopIteration
        res = self._it
        self._it = self.gt(self._it)
        return res

    def __contains__(self, key: T):
        return BSTSetNodeBase[T, TreapSet.Node].contains(self.root, key)

    def __len__(self):
        return self._len

    def __bool__(self):
        return self._len > 0

    def __str__(self):
        return "{" + ", ".join(map(str, self.tolist())) + "}"

    def __repr__(self):
        return f"{self.__class__.__name__}({self.tolist()})"
仕様
- class TreapSet(a: Iterable[T] = [])
  Bases: OrderedSetInterface, Generic[T]

  treap です。
  乱数を使用して平衡を保っています。Hackされることなんてあるんですかね。今のところ集合と多重集合しかないです。