Source code for titan_pylib.data_structures.binary_trie.binary_trie_multiset

  1from titan_pylib.my_class.ordered_multiset_interface import OrderedMultisetInterface
  2from typing import Optional, Iterable
  3from array import array
  4
  5
class BinaryTrieMultiset(OrderedMultisetInterface):
    """Ordered multiset of integers in ``[0, lim)`` stored in a binary trie.

    Nodes live in four parallel ``array("I")`` buffers indexed by node id:
    ``left`` / ``right`` hold child ids, ``par`` holds the parent id and
    ``size`` holds the number of stored elements in the subtree (for a leaf,
    the multiplicity of its key).  Node id ``0`` is the null node and its
    ``size`` entry is never modified, so it always reads ``0``; node id ``1``
    is the root.

    ``all_xor`` is lazy: a key is stored as ``key ^ self.xor`` and every query
    translates between stored and logical values while walking the trie.
    """

    def __init__(self, u: int, a: Iterable[int] = ()) -> None:
        """Create a multiset and insert every element of ``a``.

        Args:
            u: Size of the key universe.  Keys must satisfy
                ``0 <= key < lim`` where ``lim = 1 << (u - 1).bit_length()``.
            a: Initial elements, inserted one by one with ``add``.
        """
        # Slot 0 is the null node, slot 1 is the root.
        self.left = array("I", [0, 0])
        self.right = array("I", [0, 0])
        self.par = array("I", [0, 0])
        self.size = array("I", [0, 0])
        self.end: int = 2  # id that the next created node will receive
        self.root: int = 1
        self.bit: int = (u - 1).bit_length()
        self.lim: int = 1 << self.bit
        self.xor: int = 0  # pending mask applied lazily to every key
        for e in a:
            self.add(e)

    def _make_node(self) -> int:
        """Return the id of a fresh zeroed node, growing the buffers if needed."""
        if self.end >= len(self.left):
            self.left.append(0)
            self.right.append(0)
            self.par.append(0)
            self.size.append(0)
        self.end += 1
        return self.end - 1

    def _find(self, key: int) -> int:
        """Return the leaf id holding ``key``, or ``-1`` when it is absent."""
        assert (
            0 <= key < self.lim
        ), f"ValueError: BinaryTrieMultiset._find({key}), lim={self.lim}"
        left, right = self.left, self.right
        key ^= self.xor  # translate the logical key to its stored form
        node = self.root
        for i in range(self.bit - 1, -1, -1):
            node = right[node] if key >> i & 1 else left[node]
            if not node:
                return -1
        return node

    def _sub_path(self, node: int, cnt: int) -> None:
        """Subtract ``cnt`` from ``size`` of ``node`` and all of its ancestors."""
        par, size = self.par, self.size
        while node:
            size[node] -= cnt
            node = par[node]

    def _kth(self, k: int) -> tuple[int, int]:
        """Locate the ``k``-th smallest element (0-indexed).

        The caller must guarantee ``0 <= k < len(self)``.

        Returns:
            ``(leaf id, logical value)`` of the selected element.
        """
        left, right, size = self.left, self.right, self.size
        xor = self.xor
        node = self.root
        res = 0
        for i in range(self.bit - 1, -1, -1):
            # A set mask bit flips which physical child holds the smaller
            # logical values.
            if xor >> i & 1:
                lo, hi = right, left
            else:
                lo, hi = left, right
            res <<= 1
            t = size[lo[node]]  # size[0] == 0 covers a missing child
            if t <= k:
                k -= t
                res |= 1
                node = hi[node]
            else:
                node = lo[node]
        return node, res

    def _rank(self, key: int, inclusive: bool) -> int:
        """Count elements ``< key`` (or ``<= key`` when ``inclusive``)."""
        left, right, size = self.left, self.right, self.size
        xor = self.xor
        node = self.root
        k = 0
        for i in range(self.bit - 1, -1, -1):
            if xor >> i & 1:
                lo, hi = right, left
            else:
                lo, hi = left, right
            if key >> i & 1:
                # Everything in the logically-smaller sibling is below key.
                k += size[lo[node]]
                node = hi[node]
            else:
                node = lo[node]
            if not node:
                return k
        # The walk ended on the leaf of key itself.
        return k + size[node] if inclusive else k

    def reserve(self, n: int) -> None:
        """Pre-allocate room for ``n`` additional nodes."""
        assert n >= 0, f"ValueError: BinaryTrieMultiset.reserve({n})"
        zeros = array("I", [0]) * n
        self.left.extend(zeros)
        self.right.extend(zeros)
        self.par.extend(zeros)
        self.size.extend(zeros)

    def add(self, key: int, cnt: int = 1) -> None:
        """Insert ``cnt`` copies of ``key``; ``cnt == 0`` is a no-op."""
        assert (
            0 <= key < self.lim
        ), f"ValueError: BinaryTrieMultiset.add({key}), lim={self.lim}"
        assert cnt >= 0, f"ValueError: BinaryTrieMultiset.add({key}, {cnt})"
        if cnt == 0:
            # Creating a zero-count path would make `key in self` true.
            return
        par, size = self.par, self.size
        key ^= self.xor
        node = self.root
        for i in range(self.bit - 1, -1, -1):
            size[node] += cnt
            child = self.right if key >> i & 1 else self.left
            nxt = child[node]
            if not nxt:
                nxt = self._make_node()
                child[node] = nxt
                par[nxt] = node
            node = nxt
        size[node] += cnt

    def _remove(self, node: int) -> None:
        """Delete every copy stored at leaf ``node`` and prune empty ancestors."""
        left, right, par, size = self.left, self.right, self.par, self.size
        cnt = size[node]
        for _ in range(self.bit):
            size[node] -= cnt
            parent = par[node]
            if left[parent] == node:
                left[parent] = 0
                sibling = right[parent]
            else:
                right[parent] = 0
                sibling = left[parent]
            node = parent
            if sibling:
                # This ancestor still has content: stop unlinking and only
                # fix the counts from here up.
                break
        while node:
            size[node] -= cnt
            node = par[node]

    def discard(self, key: int, cnt: int = 1) -> bool:
        """Remove up to ``cnt`` copies of ``key``.

        Returns:
            ``True`` if ``key`` was present, ``False`` otherwise.
        """
        assert (
            0 <= key < self.lim
        ), f"ValueError: BinaryTrieMultiset.discard({key}), lim={self.lim}"
        node = self._find(key)
        if node == -1:
            return False
        if self.size[node] <= cnt:
            self._remove(node)
        else:
            self._sub_path(node, cnt)
        return True

    def discard_all(self, key: int) -> bool:
        """Remove every copy of ``key``; return whether it was present."""
        node = self._find(key)
        if node == -1:
            return False
        self._remove(node)
        return True

    def remove(self, key: int, cnt: int = 1) -> None:
        """Remove exactly ``cnt`` copies of ``key``.

        Raises:
            AssertionError: If ``key`` is out of range or fewer than ``cnt``
                copies are stored.
        """
        node = self._find(key)
        # Raised explicitly so the check survives ``python -O``; with the
        # check stripped the code below would corrupt the counts.
        if node == -1:
            raise AssertionError(f"{key} is not found.")
        c = self.size[node]
        if c < cnt:
            raise AssertionError(f"{key} is not found.")
        if c == cnt:
            self._remove(node)
        else:
            self._sub_path(node, cnt)

    def count(self, key: int) -> int:
        """Return the number of copies of ``key`` in the multiset."""
        node = self._find(key)
        return 0 if node == -1 else self.size[node]

    def pop(self, k: int = -1) -> int:
        """Remove and return the ``k``-th smallest element.

        Negative ``k`` counts from the largest element, as with lists.
        """
        n = len(self)
        assert -n <= k < n, f"IndexError: BinaryTrieMultiset.pop({k}), len={n}"
        if k < 0:
            k += n
        node, res = self._kth(k)
        if self.size[node] == 1:
            self._remove(node)
        else:
            self._sub_path(node, 1)
        # `res` is already the logical value; xoring it again would be wrong.
        return res

    def pop_min(self) -> int:
        """Remove and return the smallest element."""
        assert self, f"IndexError: BinaryTrieMultiset.pop_min(), len={len(self)}"
        return self.pop(0)

    def pop_max(self) -> int:
        """Remove and return the largest element."""
        return self.pop()

    def all_xor(self, x: int) -> None:
        """Replace every element ``e`` by ``e ^ x`` by updating the lazy mask."""
        assert (
            0 <= x < self.lim
        ), f"ValueError: BinaryTrieMultiset.all_xor({x}), lim={self.lim}"
        self.xor ^= x

    def get_min(self) -> Optional[int]:
        """Return the smallest element, or ``None`` when empty."""
        if not self:
            return None
        return self._kth(0)[1]

    def get_max(self) -> Optional[int]:
        """Return the largest element, or ``None`` when empty."""
        if not self:
            return None
        return self._kth(len(self) - 1)[1]

    def index(self, key: int) -> int:
        """Return the number of elements strictly less than ``key``."""
        assert (
            0 <= key < self.lim
        ), f"ValueError: BinaryTrieMultiset.index({key}), lim={self.lim}"
        return self._rank(key, False)

    def index_right(self, key: int) -> int:
        """Return the number of elements less than or equal to ``key``."""
        assert (
            0 <= key < self.lim
        ), f"ValueError: BinaryTrieMultiset.index_right({key}), lim={self.lim}"
        return self._rank(key, True)

    def gt(self, key: int) -> Optional[int]:
        """Return the smallest element ``> key``, or ``None``."""
        assert (
            0 <= key < self.lim
        ), f"ValueError: BinaryTrieMultiset.gt({key}), lim={self.lim}"
        i = self.index_right(key)
        return None if i >= len(self) else self[i]

    def lt(self, key: int) -> Optional[int]:
        """Return the largest element ``< key``, or ``None``."""
        assert (
            0 <= key < self.lim
        ), f"ValueError: BinaryTrieMultiset.lt({key}), lim={self.lim}"
        i = self.index(key) - 1
        return None if i < 0 else self[i]

    def ge(self, key: int) -> Optional[int]:
        """Return the smallest element ``>= key``, or ``None``."""
        assert (
            0 <= key < self.lim
        ), f"ValueError: BinaryTrieMultiset.ge({key}), lim={self.lim}"
        i = self.index(key)
        return None if i >= len(self) else self[i]

    def le(self, key: int) -> Optional[int]:
        """Return the largest element ``<= key``, or ``None``."""
        assert (
            0 <= key < self.lim
        ), f"ValueError: BinaryTrieMultiset.le({key}), lim={self.lim}"
        # index_right(key) avoids querying key + 1, which is out of range
        # when key == lim - 1.
        i = self.index_right(key) - 1
        return None if i < 0 else self[i]

    def tolist(self) -> list[int]:
        """Return all elements in ascending order, duplicates included."""
        a: list[int] = []
        val = self.get_min()
        while val is not None:
            a.extend([val] * self.count(val))
            val = self.gt(val)
        return a

    def clear(self) -> None:
        """Remove every element while keeping the allocated capacity.

        The pending xor mask is left untouched; it only affects how keys
        added later are encoded internally.
        """
        capacity = len(self.left)
        self.left = array("I", [0]) * capacity
        self.right = array("I", [0]) * capacity
        self.par = array("I", [0]) * capacity
        self.size = array("I", [0]) * capacity
        self.end = 2
        self.root = 1

    def __contains__(self, key: int) -> bool:
        assert (
            0 <= key < self.lim
        ), f"ValueError: BinaryTrieMultiset.__contains__({key}), lim={self.lim}"
        return self._find(key) != -1

    def __getitem__(self, k: int) -> int:
        """Return the ``k``-th smallest element (negative ``k`` allowed)."""
        n = len(self)
        assert -n <= k < n, f"IndexError: BinaryTrieMultiset[{k}], len={n}"
        if k < 0:
            k += n
        return self._kth(k)[1]

    def __bool__(self) -> bool:
        return self.size[self.root] != 0

    def __iter__(self) -> "BinaryTrieMultiset":
        # The cursor lives on the instance, so two simultaneous iterations
        # over the same object share (and disturb) each other's position.
        self.it = 0
        return self

    def __next__(self) -> int:
        if self.it == len(self):
            raise StopIteration
        self.it += 1
        return self.__getitem__(self.it - 1)

    def __len__(self) -> int:
        return self.size[self.root]

    def __str__(self) -> str:
        return "{" + ", ".join(map(str, self.tolist())) + "}"

    def __repr__(self) -> str:
        # self.lim passed back as `u` reproduces the same bit width for every
        # value of self.bit, including 0 and 1.
        return f"BinaryTrieMultiset({self.lim}, {self.tolist()})"