Source code for titan_pylib.data_structures.treap.treap_set

  1from titan_pylib.my_class.ordered_set_interface import OrderedSetInterface
  2from titan_pylib.my_class.supports_less_than import SupportsLessThan
  3from titan_pylib.data_structures.bst_base.bst_set_node_base import BSTSetNodeBase
  4from typing import Generic, Iterable, TypeVar, Optional
  5
# Element type stored in TreapSet; bound to SupportsLessThan because keys are
# ordered with "<" when walking the tree.
T = TypeVar("T", bound=SupportsLessThan)
  7
  8
class TreapSet(OrderedSetInterface, Generic[T]):
    """Ordered set (no duplicates) implemented as a treap.

    Every node carries a key (kept in binary-search-tree order) and a random
    priority (kept in min-heap order: a node's priority is never greater than
    its children's).  Balance is maintained by these random priorities.
    Priorities come from a fixed-seed generator shared by all instances, so
    tree shapes are reproducible from run to run.
    """

    class Random:
        """xorshift128 pseudo-random generator with 32-bit state words.

        The state lives on the class itself, so every TreapSet draws from the
        same stream.
        """

        _x, _y, _z, _w = 123456789, 362436069, 521288629, 88675123

        @classmethod
        def random(cls) -> int:
            """Advance the generator and return the next value in [0, 2**32)."""
            # All state words stay below 2**32, so masking once per step is
            # enough to emulate 32-bit unsigned arithmetic.
            t = (cls._x ^ (cls._x << 11)) & 0xFFFFFFFF
            cls._x, cls._y, cls._z = cls._y, cls._z, cls._w
            cls._w = (cls._w ^ (cls._w >> 19) ^ t ^ (t >> 8)) & 0xFFFFFFFF
            return cls._w

    class Node:
        """Treap node: a key, a heap priority and left/right child links."""

        def __init__(self, key: T, priority: int = -1):
            self.key: T = key
            self.left: Optional["TreapSet.Node"] = None
            self.right: Optional["TreapSet.Node"] = None
            # -1 is a sentinel meaning "draw a fresh random priority";
            # generated priorities are always non-negative.
            self.priority: int = (
                TreapSet.Random.random() if priority == -1 else priority
            )

        def __str__(self):
            if self.left is None and self.right is None:
                return f"key:{self.key, self.priority}\n"
            return f"key:{self.key, self.priority},\n left:{self.left},\n right:{self.right}\n"

    def __init__(self, a: Iterable[T] = ()):
        """Create a set from the elements of ``a``.

        Args:
            a: initial elements, passed through
               ``BSTSetNodeBase.sort_unique`` before the tree is built.
               The default is an immutable empty tuple (not a shared list).
        """
        self.root: Optional["TreapSet.Node"] = None
        self._len: int = 0
        if not isinstance(a, list):
            a = list(a)
        if a:
            self._build(a)

    def _build(self, a: list[T]) -> None:
        """Build a perfectly balanced treap from ``a`` in O(n log n)."""
        Node = TreapSet.Node
        a = BSTSetNodeBase[T, TreapSet.Node].sort_unique(a)
        n = len(a)
        self._len = n
        if n == 0:
            self.root = None
            return
        # Ascending priorities are handed out in pre-order: a node is created
        # before any of its descendants, so it receives a priority no greater
        # than theirs.  This is the min-heap order add/discard rely on.
        priorities = iter(sorted(TreapSet.Random.random() for _ in range(n)))

        def rec(l: int, r: int) -> TreapSet.Node:
            # Build the subtree for the half-open index range [l, r).
            mid = (l + r) >> 1
            node = Node(a[mid], next(priorities))
            if l != mid:
                node.left = rec(l, mid)
            if mid + 1 != r:
                node.right = rec(mid + 1, r)
            return node

        self.root = rec(0, n)

    def _rotate_L(self, node: Node) -> Node:
        """Lift ``node.left`` above ``node``; return the new subtree root."""
        u = node.left
        node.left = u.right
        u.right = node
        return u

    def _rotate_R(self, node: Node) -> Node:
        """Lift ``node.right`` above ``node``; return the new subtree root."""
        u = node.right
        node.right = u.left
        u.left = node
        return u

    def _replace_child(
        self,
        parent: Optional["TreapSet.Node"],
        old: "TreapSet.Node",
        new: Optional["TreapSet.Node"],
    ) -> None:
        """Put ``new`` where ``old`` hangs under ``parent`` (or at the root)."""
        if parent is None:
            self.root = new
        elif old.key < parent.key:
            parent.left = new
        else:
            parent.right = new

    def add(self, key: T) -> bool:
        """Insert ``key``.

        Returns:
            True if ``key`` was inserted, False if it was already present.
        """
        if self.root is None:
            self.root = TreapSet.Node(key)
            self._len = 1
            return True
        node = self.root
        path = []
        # Bit i (from the low end) records the direction taken i steps above
        # the insertion point: 1 = went left, 0 = went right.
        di = 0
        while node is not None:
            if key == node.key:
                return False
            path.append(node)
            di <<= 1
            if key < node.key:
                di |= 1
                node = node.left
            else:
                node = node.right
        if di & 1:
            path[-1].left = TreapSet.Node(key)
        else:
            path[-1].right = TreapSet.Node(key)
        # Bubble the new node up while its priority beats its parent's.
        # Once no rotation is needed, everything above is already heap-ordered.
        while path:
            node = path.pop()
            if di & 1:
                if not node.left.priority < node.priority:
                    break
                new_node = self._rotate_L(node)
            else:
                if not node.right.priority < node.priority:
                    break
                new_node = self._rotate_R(node)
            di >>= 1
            if path:
                if di & 1:
                    path[-1].left = new_node
                else:
                    path[-1].right = new_node
            else:
                self.root = new_node
        self._len += 1
        return True

    def discard(self, key: T) -> bool:
        """Remove ``key`` if present.

        Returns:
            True if ``key`` was removed, False if it was not in the set.
        """
        node = self.root
        pnode = None
        while node is not None:
            if key == node.key:
                break
            pnode = node
            node = node.left if key < node.key else node.right
        else:
            return False
        self._len -= 1
        # Rotate the target downwards, always lifting the child with the
        # smaller priority (keeps heap order), until it has at most one child.
        while node.left is not None and node.right is not None:
            if node.left.priority < node.right.priority:
                new_node = self._rotate_L(node)
            else:
                new_node = self._rotate_R(node)
            self._replace_child(pnode, node, new_node)
            # The lifted child is now the target's parent.
            pnode = new_node
        # Splice the target out, replacing it with its only child (or None).
        child = node.right if node.left is None else node.left
        self._replace_child(pnode, node, child)
        return True

    def remove(self, key: T) -> None:
        """Remove ``key``; raise KeyError if it is not present."""
        if self.discard(key):
            return
        raise KeyError(key)

    def le(self, key: T) -> Optional[T]:
        """Largest element <= ``key``, or None."""
        return BSTSetNodeBase[T, TreapSet.Node].le(self.root, key)

    def lt(self, key: T) -> Optional[T]:
        """Largest element < ``key``, or None."""
        return BSTSetNodeBase[T, TreapSet.Node].lt(self.root, key)

    def ge(self, key: T) -> Optional[T]:
        """Smallest element >= ``key``, or None."""
        return BSTSetNodeBase[T, TreapSet.Node].ge(self.root, key)

    def gt(self, key: T) -> Optional[T]:
        """Smallest element > ``key``, or None."""
        return BSTSetNodeBase[T, TreapSet.Node].gt(self.root, key)

    def get_min(self) -> Optional[T]:
        """Smallest element, or None when empty."""
        return BSTSetNodeBase[T, TreapSet.Node].get_min(self.root)

    def get_max(self) -> Optional[T]:
        """Largest element, or None when empty."""
        return BSTSetNodeBase[T, TreapSet.Node].get_max(self.root)

    def pop_min(self) -> T:
        """Remove and return the smallest element (asserts non-empty)."""
        assert self.root, f"IndexError: pop_min() from Empty {self.__class__.__name__}."
        node = self.root
        pnode = None
        while node.left is not None:
            pnode = node
            node = node.left
        self._len -= 1
        # The minimum has no left child; its right subtree takes its place.
        self._replace_child(pnode, node, node.right)
        return node.key

    def pop_max(self) -> T:
        """Remove and return the largest element (asserts non-empty)."""
        assert self.root, f"IndexError: pop_max() from Empty {self.__class__.__name__}."
        node = self.root
        pnode = None
        while node.right is not None:
            pnode = node
            node = node.right
        self._len -= 1
        # The maximum has no right child; its left subtree takes its place.
        self._replace_child(pnode, node, node.left)
        return node.key

    def clear(self) -> None:
        """Remove all elements."""
        self.root = None
        # Previously missing: without this, len()/bool() stayed stale.
        self._len = 0

    def tolist(self) -> list[T]:
        """All elements in ascending order."""
        return BSTSetNodeBase[T, TreapSet.Node].tolist(self.root)

    def __iter__(self):
        """Yield elements in ascending order.

        Each call returns an independent generator, so nested loops over the
        same set work.  The next element is looked up with ``gt`` after every
        yield, so adding/discarding between steps does not break iteration.
        """
        key = self.get_min()
        while key is not None:
            yield key
            key = self.gt(key)

    def __next__(self):
        # Retained for interface compatibility: advances the legacy cursor in
        # ``self._it``.  ``__iter__`` no longer uses or resets this cursor.
        it = getattr(self, "_it", None)
        if it is None:
            raise StopIteration
        self._it = self.gt(it)
        return it

    def __contains__(self, key: T):
        return BSTSetNodeBase[T, TreapSet.Node].contains(self.root, key)

    def __len__(self):
        return self._len

    def __bool__(self):
        return self._len > 0

    def __str__(self):
        return "{" + ", ".join(map(str, self.tolist())) + "}"

    def __repr__(self):
        return f"{self.__class__.__name__}({self.tolist()})"