Source code for titan_pylib.data_structures.set.segki_set

  1from typing import Optional, Iterable
  2
  3
class SegkiSet:
    """Set of integers in the range ``[0, u)`` backed by a bit-OR segment tree.

    The tree is a perfect binary tree stored in a flat ``bytearray``:
    node ``1`` is the root, node ``k`` has children ``2k`` and ``2k + 1``,
    and the leaf for value ``i`` is ``data[size + i]``.  Every node holds the
    OR (0 or 1) of its descendants, i.e. whether its subtree contains any
    element.  (With SUM instead of OR it would resemble a Fenwick tree / BIT.)

    Complexity:
        space: O(u)
        add, discard, lt/le/gt/ge (predecessor / successor): O(log u)
        contains, len: O(1)
        iteration: O(n log u)
        k-th element: O(min(k, n - k) * log u)
    """

    def __init__(self, u: int, a: Iterable[int] = ()):
        """Create a set over the universe ``[0, u)``.

        Args:
            u: Exclusive upper bound of storable values.
            a: Initial elements; each must lie in ``[0, u)``.

        Raises:
            IndexError: If an element of ``a`` lies outside ``[0, u)``.
        """
        # Keep at least one internal level so the root (node 1) is never a
        # leaf itself; the predecessor walk indexes node 2, which would be out
        # of bounds of a 2-byte array when u == 1.
        self.log = max(1, (u - 1).bit_length())
        self.size = 1 << self.log
        self.u = u
        self.data = bytearray(self.size << 1)
        self.len = 0
        for _a in a:
            self.add(_a)

    def add(self, k: int) -> bool:
        """Insert ``k``; return True if newly added, False if already present.

        Raises:
            IndexError: If ``k`` is outside ``[0, u)``.
        """
        if not 0 <= k < self.u:
            # A negative k would otherwise alias an internal node and
            # corrupt the tree.
            raise IndexError(f"SegkiSet value {k} out of range [0, {self.u})")
        k += self.size
        if self.data[k]:
            return False
        self.len += 1
        self.data[k] = 1
        # Mark ancestors; stop at the first one already marked, because all
        # of its own ancestors are then marked as well.
        while k > 1:
            k >>= 1
            if self.data[k]:
                break
            self.data[k] = 1
        return True

    def discard(self, k: int) -> bool:
        """Remove ``k`` if present; return True if it was removed.

        Values outside ``[0, u)`` can never be members, so they return False.
        """
        if not 0 <= k < self.u:
            return False
        k += self.size
        if self.data[k] == 0:
            return False
        self.len -= 1
        self.data[k] = 0
        # Clear ancestors for as long as the sibling subtree is empty too.
        while k > 1:
            if self.data[k ^ 1]:
                break
            k >>= 1
            self.data[k] = 0
        return True

    def get_min(self) -> Optional[int]:
        """Return the smallest element, or None if the set is empty."""
        if self.data[1] == 0:
            return None
        k = 1
        # Descend, preferring the left child whenever it is non-empty.
        while k < self.size:
            k <<= 1
            if self.data[k] == 0:
                k |= 1
        return k - self.size

    def get_max(self) -> Optional[int]:
        """Return the largest element, or None if the set is empty."""
        if self.data[1] == 0:
            return None
        k = 1
        # Descend, preferring the right child whenever it is non-empty.
        while k < self.size:
            k <<= 1
            if self.data[k | 1]:
                k |= 1
        return k - self.size

    def lt(self, k: int) -> Optional[int]:
        """Find the largest element < k, or None if it doesn't exist. / O(log u)"""
        if k <= 0 or self.data[1] == 0:
            return None
        if k >= self.u:
            # Every stored value is < u <= k.
            return self.get_max()
        x = k
        k += self.size
        # Climb until we are a right child whose left sibling is non-empty.
        while k > 1:
            if k & 1 and self.data[k - 1]:
                k >>= 1
                break
            k >>= 1
        k <<= 1  # that left sibling (or node 2 if the climb reached the root)
        if self.data[k] == 0:
            return None
        while k < self.size:
            k <<= 1
            if self.data[k | 1]:
                k |= 1
        k -= self.size
        # When the climb reached the root without a hit, the descent may have
        # found a value >= x; such a value is not a valid answer.
        return k if k < x else None

    def gt(self, k: int) -> Optional[int]:
        """Find the smallest element > k, or None if it doesn't exist. / O(log u)"""
        if k < 0:
            # Every stored value is >= 0 > k.
            return self.get_min()
        if k >= self.u - 1 or self.data[1] == 0:
            return None
        x = k
        k += self.size
        # Climb until we are a left child whose right sibling is non-empty.
        while k > 1:
            if k & 1 == 0 and self.data[k + 1]:
                k >>= 1
                break
            k >>= 1
        k = k << 1 | 1  # that right sibling (or node 3 if the climb reached the root)
        # Without this check an empty right subtree would make the descent
        # below end on the phantom leaf ``size - 1``.
        if self.data[k] == 0:
            return None
        while k < self.size:
            k <<= 1
            if self.data[k] == 0:
                k |= 1
        k -= self.size
        # When the climb reached the root without a hit, the descent may have
        # found a value <= x; such a value is not a valid answer.
        return k if k > x else None

    def le(self, k: int) -> Optional[int]:
        """Find the largest element <= k, or None if it doesn't exist."""
        if k in self:
            return k
        return self.lt(k)

    def ge(self, k: int) -> Optional[int]:
        """Find the smallest element >= k, or None if it doesn't exist."""
        if k in self:
            return k
        return self.gt(k)

    def debug(self):
        """Print the tree level by level (root first) as rows of 0/1 flags."""
        print(
            "\n".join(
                " ".join(map(str, (self.data[(1 << i) + j] for j in range(1 << i))))
                for i in range(self.log + 1)
            )
        )

    def __contains__(self, k: int) -> bool:
        # Range check first: a negative k would index an internal node.
        return 0 <= k < self.u and self.data[k + self.size] == 1

    def __getitem__(self, k: int) -> int:
        """Return the k-th smallest element (negative k counts from the end).

        Walks from the nearer end, so indices close to the front or back are
        the cheapest: O(min(k, n - k) * log u).

        Raises:
            IndexError: If ``k`` is out of range.
        """
        n = self.len
        if k < 0:
            k += n
        if not 0 <= k < n:
            raise IndexError("SegkiSet index out of range")
        if k < n >> 1:
            key = self.get_min()
            for _ in range(k):
                key = self.gt(key)
        else:
            key = self.get_max()
            for _ in range(n - k - 1):
                key = self.lt(key)
        return key

    def __len__(self) -> int:
        return self.len

    def __iter__(self):
        # Yields elements in ascending order.
        key = self.get_min()
        while key is not None:
            yield key
            key = self.gt(key)

    def __str__(self) -> str:
        return "{" + ", ".join(map(str, self)) + "}"