Source code for titan_pylib.data_structures.binary_trie.binary_trie_set

  1from titan_pylib.my_class.ordered_set_interface import OrderedSetInterface
  2from typing import Optional, Iterable
  3from array import array
  4
  5
class BinaryTrieSet(OrderedSetInterface):
    """Ordered set of non-negative integers backed by a binary trie.

    Keys must lie in ``[0, lim)`` where ``lim = 1 << bit`` and
    ``bit = (u - 1).bit_length()``. :meth:`all_xor` XORs every element with a
    common mask in :math:`O(1)`: elements are stored physically as
    ``value ^ self.xor`` and every operation translates between physical and
    logical (user-visible) values.

    Node layout: node ``0`` is a null sentinel whose ``size`` is always 0 and
    node ``1`` is the root. A node holds elements iff its ``size`` is positive.
    Removal only decrements sizes and keeps child links in place, so a later
    :meth:`add` on the same prefix reuses those nodes instead of allocating.
    """

    def __init__(self, u: int, a: Iterable[int] = ()) -> None:
        """Build a set for keys in ``[0, lim)`` and insert every element of ``a``.

        :math:`O(n\\log{u})`.
        """
        # Index 0 is the null sentinel, index 1 is the root.
        self.left = array("I", [0, 0])
        self.right = array("I", [0, 0])
        self.par = array("I", [0, 0])
        self.size = array("I", [0, 0])
        self.end = 2  # next unused node index
        self.root = 1
        self.bit = (u - 1).bit_length()
        self.lim = 1 << self.bit
        self.xor = 0
        self.it = 0  # cursor used only by the legacy __next__ protocol
        for e in a:
            self.add(e)

    def _make_node(self) -> int:
        """Return the index of a fresh node (links, parent and size all 0)."""
        end = self.end
        if end >= len(self.left):
            self.left.append(0)
            self.right.append(0)
            self.par.append(0)
            self.size.append(0)
        # Slots below len(self.left) come from reserve() and are still zero.
        self.end += 1
        return end

    def _find(self, key: int) -> int:
        """Return the leaf node holding logical ``key``, or ``-1`` if absent."""
        left, right, size = self.left, self.right, self.size
        key ^= self.xor
        node = self.root
        for i in range(self.bit - 1, -1, -1):
            node = right[node] if key >> i & 1 else left[node]
            # A stale link (removed subtree) or the sentinel has size 0.
            if not size[node]:
                return -1
        # Needed when bit == 0: the root itself is the only leaf.
        return node if size[node] else -1

    def _descend(self, key: int) -> tuple[int, int]:
        """Walk toward logical ``key``.

        Returns ``(count, node)``. ``count`` is the number of elements strictly
        less than ``key``. ``node`` is the leaf for ``key`` when ``key`` is
        present, and otherwise a node whose ``size`` is 0.
        """
        left, right, size, xor = self.left, self.right, self.size, self.xor
        count = 0
        node = self.root
        for i in range(self.bit - 1, -1, -1):
            # Physical children holding logical bit 0 (lo) / 1 (hi) here.
            if xor >> i & 1:
                lo, hi = right[node], left[node]
            else:
                lo, hi = left[node], right[node]
            if key >> i & 1:
                count += size[lo]
                node = hi
            else:
                node = lo
            if not size[node]:
                break
        return count, node

    def _kth(self, k: int) -> tuple[int, int]:
        """Return ``(leaf node, logical value)`` of the ``k``-th smallest element.

        ``k`` must satisfy ``0 <= k < len(self)``.
        """
        left, right, size, xor = self.left, self.right, self.size, self.xor
        node = self.root
        res = 0
        for i in range(self.bit - 1, -1, -1):
            # Physical children holding logical bit 0 (lo) / 1 (hi) here.
            if xor >> i & 1:
                lo, hi = right[node], left[node]
            else:
                lo, hi = left[node], right[node]
            t = size[lo]
            if t <= k:
                k -= t
                node = hi
                res |= 1 << i
            else:
                node = lo
        return node, res

    def reserve(self, n: int) -> None:
        """Pre-allocate storage for ``n`` additional nodes.

        :math:`O(n)`.
        """
        assert n >= 0, f"ValueError: BinaryTrieSet.reserve({n})"
        a = array("I", [0]) * n
        self.left += a
        self.right += a
        self.par += a
        self.size += a

    def add(self, key: int) -> bool:
        """Insert ``key``. Return ``True`` if it was not already present."""
        assert (
            0 <= key < self.lim
        ), f"ValueError: BinaryTrieSet.add({key}), lim={self.lim}"
        left, right, par, size = self.left, self.right, self.par, self.size
        key ^= self.xor
        node = self.root
        for i in range(self.bit - 1, -1, -1):
            # Existing (possibly emptied) children are reused as-is.
            if key >> i & 1:
                if not right[node]:
                    right[node] = self._make_node()
                    par[right[node]] = node
                node = right[node]
            else:
                if not left[node]:
                    left[node] = self._make_node()
                    par[left[node]] = node
                node = left[node]
        if size[node]:
            return False
        size[node] = 1
        for _ in range(self.bit):
            node = par[node]
            size[node] += 1
        return True

    def _remove_leaf(self, node: int) -> None:
        """Decrement ``size`` on ``node`` and every ancestor up to the root."""
        size, par = self.size, self.par
        # par[root] == 0 terminates the walk.
        while node:
            size[node] -= 1
            node = par[node]

    # Backward-compatible alias for the previous (misspelled) helper name.
    _rmeove = _remove_leaf

    def discard(self, key: int) -> bool:
        """Remove ``key`` if present. Return ``True`` if it was removed."""
        assert (
            0 <= key < self.lim
        ), f"ValueError: BinaryTrieSet.discard({key}), lim={self.lim}"
        node = self._find(key)
        if node == -1:
            return False
        self._remove_leaf(node)
        return True

    def remove(self, key: int) -> None:
        """Remove ``key``.

        Raises:
            KeyError: ``key`` is not in the set.
        """
        if self.discard(key):
            return
        raise KeyError(key)

    def pop(self, k: int = -1) -> int:
        """Remove and return the ``k``-th smallest element (negative ``k`` counts from the end)."""
        assert (
            -len(self) <= k < len(self)
        ), f"IndexError: BinaryTrieSet.pop({k}), len={len(self)}"
        if k < 0:
            k += len(self)
        node, res = self._kth(k)
        self._remove_leaf(node)
        return res

    def pop_min(self) -> int:
        """Remove and return the smallest element."""
        assert self, f"IndexError: BinaryTrieSet.pop_min(), len={len(self)}"
        return self.pop(0)

    def pop_max(self) -> int:
        """Remove and return the largest element."""
        assert self, f"IndexError: BinaryTrieSet.pop_max(), len={len(self)}"
        return self.pop()

    def all_xor(self, x: int) -> None:
        """XOR every element with ``x``.

        :math:`O(1)`.
        """
        assert (
            0 <= x < self.lim
        ), f"ValueError: BinaryTrieSet.all_xor({x}), lim={self.lim}"
        self.xor ^= x

    def get_min(self) -> Optional[int]:
        """Return the smallest element, or ``None`` if the set is empty."""
        return self._kth(0)[1] if self else None

    def get_max(self) -> Optional[int]:
        """Return the largest element, or ``None`` if the set is empty."""
        return self._kth(len(self) - 1)[1] if self else None

    def index(self, key: int) -> int:
        """Return the number of elements strictly less than ``key``."""
        assert (
            0 <= key < self.lim
        ), f"ValueError: BinaryTrieSet.index({key}), lim={self.lim}"
        return self._descend(key)[0]

    def index_right(self, key: int) -> int:
        """Return the number of elements less than or equal to ``key``."""
        assert (
            0 <= key < self.lim
        ), f"ValueError: BinaryTrieSet.index_right({key}), lim={self.lim}"
        count, node = self._descend(key)
        # size[node] is 1 if key is present, 0 otherwise.
        return count + self.size[node]

    def clear(self) -> None:
        """Remove every element and release node storage.

        The XOR mask set by :meth:`all_xor` is kept.
        """
        self.left = array("I", [0, 0])
        self.right = array("I", [0, 0])
        self.par = array("I", [0, 0])
        self.size = array("I", [0, 0])
        self.end = 2
        self.root = 1

    def gt(self, key: int) -> Optional[int]:
        """Return the smallest element ``> key``, or ``None``."""
        assert (
            0 <= key < self.lim
        ), f"ValueError: BinaryTrieSet.gt({key}), lim={self.lim}"
        i = self.index_right(key)
        return None if i >= len(self) else self[i]

    def lt(self, key: int) -> Optional[int]:
        """Return the largest element ``< key``, or ``None``."""
        assert (
            0 <= key < self.lim
        ), f"ValueError: BinaryTrieSet.lt({key}), lim={self.lim}"
        i = self.index(key) - 1
        return None if i < 0 else self[i]

    def ge(self, key: int) -> Optional[int]:
        """Return the smallest element ``>= key``, or ``None``."""
        assert (
            0 <= key < self.lim
        ), f"ValueError: BinaryTrieSet.ge({key}), lim={self.lim}"
        i = self.index(key)
        return None if i >= len(self) else self[i]

    def le(self, key: int) -> Optional[int]:
        """Return the largest element ``<= key``, or ``None``."""
        assert (
            0 <= key < self.lim
        ), f"ValueError: BinaryTrieSet.le({key}), lim={self.lim}"
        # index_right avoids evaluating index(key + 1), which fails for key == lim - 1.
        i = self.index_right(key) - 1
        return None if i < 0 else self[i]

    def tolist(self) -> list[int]:
        """Return all elements in ascending order."""
        return list(self)

    def __contains__(self, key: int):
        assert (
            0 <= key < self.lim
        ), f"ValueError: {key} in BinaryTrieSet, lim={self.lim}"
        return self._find(key) != -1

    def __getitem__(self, k: int):
        """Return the ``k``-th smallest element (negative ``k`` counts from the end)."""
        assert (
            -len(self) <= k < len(self)
        ), f"IndexError: BinaryTrieSet[{k}], len={len(self)}"
        if k < 0:
            k += len(self)
        return self._kth(k)[1]

    def __bool__(self):
        return self.size[self.root] != 0

    def __iter__(self):
        """Yield elements in ascending order.

        Each call returns an independent iterator, so nested iteration and
        calling str() while iterating are safe. The length is re-read on every
        step, matching the previous behaviour.
        """
        i = 0
        while i < len(self):
            yield self._kth(i)[1]
            i += 1

    def __next__(self):
        """Legacy cursor-based iteration, kept for backward compatibility."""
        if self.it == len(self):
            raise StopIteration
        self.it += 1
        return self.__getitem__(self.it - 1)

    def __len__(self):
        return self.size[self.root]

    def __str__(self):
        return "{" + ", ".join(map(str, self)) + "}"

    def __repr__(self):
        # Passing lim as ``u`` reproduces the same ``bit`` on reconstruction.
        return f"BinaryTrieSet({self.lim}, {self})"