1from titan_pylib.my_class.ordered_set_interface import OrderedSetInterface
2from titan_pylib.my_class.supports_less_than import SupportsLessThan
3from titan_pylib.data_structures.bst_base.bst_set_node_base import BSTSetNodeBase
4from typing import Generic, Iterable, TypeVar, Optional
5
# Element type stored in TreapSet. Keys are ordered with ``<`` and tested
# with ``==``, hence the SupportsLessThan bound.
T = TypeVar("T", bound=SupportsLessThan)
7
8
[docs]
class TreapSet(OrderedSetInterface, Generic[T]):
    """An ordered set implemented as a treap.

    Every node carries a key and a pseudo-random priority. The tree is a
    binary search tree on the keys and a min-heap on the priorities: a
    parent's priority is never larger than its children's. The random
    priorities keep the tree balanced in expectation. Keys are compared with
    ``<`` and ``==``, and each key is stored at most once.

    Note:
        ``TreapSet.Random`` starts from a fixed seed, so priorities (and
        therefore tree shapes) are identical on every run. Someone who knows
        the seed could in principle craft inputs that unbalance the tree; the
        original author doubted anyone would bother to "hack" it. Author's
        note: so far only set and multiset versions are provided.
    """

    class Random:
        """Marsaglia xorshift128 generator used for node priorities.

        The state lives on the class, so all treaps draw from one shared
        stream starting from a fixed seed.
        """

        _x, _y, _z, _w = 123456789, 362436069, 521288629, 88675123

        @classmethod
        def random(cls) -> int:
            """Advance the generator and return the next 32-bit value."""
            t = (cls._x ^ ((cls._x << 11) & 0xFFFFFFFF)) & 0xFFFFFFFF
            cls._x, cls._y, cls._z = cls._y, cls._z, cls._w
            cls._w = (cls._w ^ (cls._w >> 19)) ^ (
                t ^ ((t >> 8)) & 0xFFFFFFFF
            ) & 0xFFFFFFFF
            return cls._w

    class Node:
        """Treap node holding a key, two children and a heap priority."""

        def __init__(self, key: T, priority: int = -1):
            self.key: T = key
            self.left: Optional["TreapSet.Node"] = None
            self.right: Optional["TreapSet.Node"] = None
            # -1 means "draw a fresh pseudo-random priority".
            self.priority: int = (
                TreapSet.Random.random() if priority == -1 else priority
            )

        def __str__(self):
            if self.left is None and self.right is None:
                return f"key:{self.key, self.priority}\n"
            return f"key:{self.key, self.priority},\n left:{self.left},\n right:{self.right}\n"

    def __init__(self, a: Iterable[T] = ()):
        """Create a set containing the elements of ``a``.

        Args:
            a: Initial elements. Any iterable is accepted; duplicates are
                collapsed.
        """
        self.root: Optional["TreapSet.Node"] = None
        self._len: int = 0
        if not isinstance(a, list):
            a = list(a)
        if a:
            self._build(a)

    def _build(self, a: list[T]) -> None:
        """Replace the contents with the keys of ``a``, which must be non-empty.

        The keys are put in order by ``BSTSetNodeBase.sort_unique``. That
        helper is assumed to return them sorted ascending with duplicates
        removed. A perfectly balanced tree is then built by splitting at the
        middle index.
        """
        Node = TreapSet.Node

        a = BSTSetNodeBase[T, TreapSet.Node].sort_unique(a)
        self._len = len(a)
        # Ascending priorities are handed out in pre-order (a parent before
        # its children), so every parent's priority is <= all of its
        # descendants'. This is the min-heap order that add/discard rely on.
        # (Indexing the sorted priorities by key position would give each
        # left subtree smaller priorities than its root.)
        prio = iter(sorted(TreapSet.Random.random() for _ in range(self._len)))

        def rec(l: int, r: int) -> TreapSet.Node:
            mid = (l + r) >> 1
            node = Node(a[mid], next(prio))
            if l != mid:
                node.left = rec(l, mid)
            if mid + 1 != r:
                node.right = rec(mid + 1, r)
            return node

        self.root = rec(0, self._len)

    def _rotate_L(self, node: Node) -> Node:
        """Lift ``node.left`` above ``node`` and return the new subtree root."""
        u = node.left
        node.left = u.right
        u.right = node
        return u

    def _rotate_R(self, node: Node) -> Node:
        """Lift ``node.right`` above ``node`` and return the new subtree root."""
        u = node.right
        node.right = u.left
        u.left = node
        return u

    def _replace_child(
        self, parent: Optional[Node], old: Node, new: Optional[Node]
    ) -> None:
        """Put ``new`` where ``old`` hangs under ``parent`` (None means ``old`` is the root)."""
        if parent is None:
            self.root = new
        elif parent.left is old:
            parent.left = new
        else:
            parent.right = new

    def add(self, key: T) -> bool:
        """Insert ``key``.

        Returns:
            True if ``key`` was added, False if it was already present.
        """
        if self.root is None:
            self.root = TreapSet.Node(key)
            self._len = 1
            return True
        node = self.root
        path = []
        # Bit i of ``di`` (from the least significant end) is 1 when the step
        # out of path[-1 - i] went to the left child.
        di = 0
        while node:
            if key == node.key:
                return False
            path.append(node)
            di <<= 1
            if key < node.key:
                di |= 1
                node = node.left
            else:
                node = node.right
        if di & 1:
            path[-1].left = TreapSet.Node(key)
        else:
            path[-1].right = TreapSet.Node(key)
        # Rotate the new node upward while its priority is smaller than its
        # parent's. The first level where the heap order already holds leaves
        # that subtree untouched, so nothing above it can need a rotation.
        while path:
            node = path.pop()
            if di & 1:
                if not node.left.priority < node.priority:
                    break
                new_node = self._rotate_L(node)
            else:
                if not node.right.priority < node.priority:
                    break
                new_node = self._rotate_R(node)
            di >>= 1
            if path:
                if di & 1:
                    path[-1].left = new_node
                else:
                    path[-1].right = new_node
            else:
                self.root = new_node
        self._len += 1
        return True

    def discard(self, key: T) -> bool:
        """Remove ``key`` if present.

        Returns:
            True if ``key`` was removed, False if it was not in the set.
        """
        node = self.root
        pnode = None
        while node:
            if key == node.key:
                break
            pnode = node
            node = node.left if key < node.key else node.right
        else:
            return False
        self._len -= 1
        # Rotate ``node`` down, always lifting the child with the smaller
        # priority (keeping the heap order), until it has at most one child.
        while node.left and node.right:
            if node.left.priority < node.right.priority:
                new_node = self._rotate_L(node)
            else:
                new_node = self._rotate_R(node)
            self._replace_child(pnode, node, new_node)
            pnode = new_node
        # Splice ``node`` out by linking its only child (or None) to the parent.
        child = node.right if node.left is None else node.left
        self._replace_child(pnode, node, child)
        return True

    def remove(self, key: T) -> None:
        """Remove ``key``.

        Raises:
            KeyError: if ``key`` is not in the set.
        """
        if self.discard(key):
            return
        raise KeyError(key)

    def le(self, key: T) -> Optional[T]:
        """Delegate to ``BSTSetNodeBase.le`` (expected: largest key <= ``key``, or None)."""
        return BSTSetNodeBase[T, TreapSet.Node].le(self.root, key)

    def lt(self, key: T) -> Optional[T]:
        """Delegate to ``BSTSetNodeBase.lt`` (expected: largest key < ``key``, or None)."""
        return BSTSetNodeBase[T, TreapSet.Node].lt(self.root, key)

    def ge(self, key: T) -> Optional[T]:
        """Delegate to ``BSTSetNodeBase.ge`` (expected: smallest key >= ``key``, or None)."""
        return BSTSetNodeBase[T, TreapSet.Node].ge(self.root, key)

    def gt(self, key: T) -> Optional[T]:
        """Delegate to ``BSTSetNodeBase.gt`` (expected: smallest key > ``key``, or None)."""
        return BSTSetNodeBase[T, TreapSet.Node].gt(self.root, key)

    def get_min(self) -> Optional[T]:
        """Delegate to ``BSTSetNodeBase.get_min`` (expected: smallest key, or None if empty)."""
        return BSTSetNodeBase[T, TreapSet.Node].get_min(self.root)

    def get_max(self) -> Optional[T]:
        """Delegate to ``BSTSetNodeBase.get_max`` (expected: largest key, or None if empty)."""
        return BSTSetNodeBase[T, TreapSet.Node].get_max(self.root)

    def pop_min(self) -> T:
        """Remove and return the smallest key. The set must be non-empty (asserted)."""
        assert self.root, f"IndexError: pop_min() from Empty {self.__class__.__name__}."
        node = self.root
        pnode = None
        while node.left:
            pnode = node
            node = node.left
        self._len -= 1
        res = node.key
        # The minimum has no left child, so its right subtree takes its place.
        if not pnode:
            self.root = self.root.right
        else:
            pnode.left = node.right
        return res

    def pop_max(self) -> T:
        """Remove and return the largest key. The set must be non-empty (asserted)."""
        assert self.root, f"IndexError: pop_max() from Empty {self.__class__.__name__}."
        node = self.root
        pnode = None
        while node.right:
            pnode = node
            node = node.right
        self._len -= 1
        res = node.key
        # The maximum has no right child, so its left subtree takes its place.
        if not pnode:
            self.root = self.root.left
        else:
            pnode.right = node.left
        return res

    def clear(self) -> None:
        """Remove every key."""
        self.root = None
        # The size counter must be reset too, or len()/bool() would still
        # report the old contents.
        self._len = 0

    def tolist(self) -> list[T]:
        """Delegate to ``BSTSetNodeBase.tolist`` (expected: keys in ascending order)."""
        return BSTSetNodeBase[T, TreapSet.Node].tolist(self.root)

    def __iter__(self):
        """Return an iterator over the keys in ascending order.

        Every call returns an independent iterator, so nested loops over the
        same set work. Each step is a ``gt`` lookup from the previous key, so
        keys may be added or removed between steps.
        """
        # Also (re)start the legacy cursor used by ``__next__`` so the old
        # ``iter(s); next(s)`` usage keeps working.
        self._it = self.get_min()
        return self._iter_from(self._it)

    def _iter_from(self, key: Optional[T]):
        """Yield ``key`` and then every larger key, in ascending order."""
        while key is not None:
            yield key
            key = self.gt(key)

    def __next__(self):
        """Legacy cursor: return the key at the cursor and advance it.

        The cursor is (re)started by ``__iter__``.

        Raises:
            StopIteration: when the cursor is exhausted.
        """
        if self._it is None:
            raise StopIteration
        res = self._it
        self._it = self.gt(self._it)
        return res

    def __contains__(self, key: T):
        return BSTSetNodeBase[T, TreapSet.Node].contains(self.root, key)

    def __len__(self):
        return self._len

    def __bool__(self):
        return self._len > 0

    def __str__(self):
        return "{" + ", ".join(map(str, self.tolist())) + "}"

    def __repr__(self):
        return f"{self.__class__.__name__}({self.tolist()})"