Source code for titan_pylib.data_structures.set.fenwick_tree_set

  1from titan_pylib.my_class.supports_less_than import SupportsLessThan
  2from titan_pylib.data_structures.fenwick_tree.fenwick_tree import FenwickTree
  3from typing import Iterable, TypeVar, Generic, Union, Optional
  4
# Key type stored in the set. Keys are sorted when building the universe, hence
# the SupportsLessThan bound.
T = TypeVar("T", bound=SupportsLessThan)
  6
  7
class FenwickTreeSet(Generic[T]):
    """Ordered set over a universe of keys fixed at construction time.

    Each universe value is mapped to a dense index. A Fenwick tree stores a
    per-index occupancy count. Order queries (``index``, ``le``/``lt``/
    ``ge``/``gt``, ``__getitem__``, ``pop``) are answered with Fenwick-tree
    prefix sums and ``bisect_right``. ``bisect_right(p)`` is used here to map
    a 0-based rank ``p`` to the dense index of the element with that rank.
    """

    def __init__(
        self,
        _used: Union[int, Iterable[T]],
        _a: Iterable[T] = (),
        compress=True,
        _multi=False,
    ) -> None:
        """Build the set.

        Args:
            _used: The universe. An int ``n`` means ``range(n)``. Any other
                iterable is de-duplicated and sorted.
            _a: Initial elements. Each one must belong to the universe.
            compress: If True, keys are mapped to indices through a dict.
                If False, the sorted universe list itself is indexed with the
                key. That is only meaningful when the universe is ``range(n)``
                and keys are ints in that range.
            _multi: If True, duplicates in ``_a`` are counted individually
                (``_cnt`` entries may exceed 1 and ``len`` equals
                ``len(_a)``). Presumably intended for a multiset subclass;
                TODO confirm.
        """
        self._len = 0
        if isinstance(_used, int):
            self._to_origin = list(range(_used))
        elif isinstance(_used, set):
            # A set is already duplicate-free, so skip the extra set() copy.
            self._to_origin = sorted(_used)
        else:
            self._to_origin = sorted(set(_used))
        # Key -> dense index. With compress=False this is the universe list
        # itself, so ``self._to_zaatsu[key]`` is ``self._to_origin[key]``.
        self._to_zaatsu: Union[dict[T, int], list[T]] = (
            {key: i for i, key in enumerate(self._to_origin)}
            if compress
            else self._to_origin
        )
        self._size = len(self._to_origin)
        # Occupancy count per dense index (0 or 1 unless _multi was used).
        self._cnt = [0] * self._size
        _a = list(_a)
        if _a:
            a_ = [0] * self._size
            if _multi:
                self._len = len(_a)
                for v in _a:
                    i = self._to_zaatsu[v]
                    a_[i] += 1
                    self._cnt[i] += 1
            else:
                for v in _a:
                    i = self._to_zaatsu[v]
                    if self._cnt[i] == 0:  # ignore duplicate initial values
                        self._len += 1
                        a_[i] = 1
                        self._cnt[i] = 1
            self._fw = FenwickTree(a_)
        else:
            self._fw = FenwickTree(self._size)

    def add(self, key: T) -> bool:
        """Insert ``key``. Return True if it was newly added, False if already present.

        Looking up a key outside the universe raises (``KeyError`` for the
        dict mapping).
        """
        i = self._to_zaatsu[key]
        if self._cnt[i]:
            return False
        self._len += 1
        self._cnt[i] = 1
        self._fw.add(i, 1)
        return True

    def remove(self, key: T) -> None:
        """Remove ``key``. Raise ``KeyError(key)`` if it is not in the set."""
        if not self.discard(key):
            raise KeyError(key)

    def discard(self, key: T) -> bool:
        """Remove ``key`` if present. Return whether it was removed."""
        i = self._to_zaatsu[key]
        if self._cnt[i]:
            self._len -= 1
            self._cnt[i] = 0
            self._fw.add(i, -1)
            return True
        return False

    def le(self, key: T) -> Optional[T]:
        """Return the largest element ``<= key``, or None if there is none."""
        i = self._to_zaatsu[key]
        if self._cnt[i]:
            return key
        # pref(i) counts elements < key. The largest of them has rank pref(i) - 1.
        pref = self._fw.pref(i) - 1
        return None if pref < 0 else self._to_origin[self._fw.bisect_right(pref)]

    def lt(self, key: T) -> Optional[T]:
        """Return the largest element ``< key``, or None if there is none."""
        pref = self._fw.pref(self._to_zaatsu[key]) - 1
        return None if pref < 0 else self._to_origin[self._fw.bisect_right(pref)]

    def ge(self, key: T) -> Optional[T]:
        """Return the smallest element ``>= key``, or None if there is none."""
        i = self._to_zaatsu[key]
        if self._cnt[i]:
            return key
        # pref(i + 1) counts elements <= key, which is the rank of the next one.
        pref = self._fw.pref(i + 1)
        return (
            None if pref >= self._len else self._to_origin[self._fw.bisect_right(pref)]
        )

    def gt(self, key: T) -> Optional[T]:
        """Return the smallest element ``> key``, or None if there is none."""
        pref = self._fw.pref(self._to_zaatsu[key] + 1)
        return (
            None if pref >= self._len else self._to_origin[self._fw.bisect_right(pref)]
        )

    def index(self, key: T) -> int:
        """Return the number of elements strictly less than ``key``."""
        return self._fw.pref(self._to_zaatsu[key])

    def index_right(self, key: T) -> int:
        """Return the number of elements less than or equal to ``key``."""
        return self._fw.pref(self._to_zaatsu[key] + 1)

    def pop(self, k: int = -1) -> T:
        """Remove and return the ``k``-th smallest element (0-based).

        A negative ``k`` counts from the largest element.
        """
        assert (
            -self._len <= k < self._len
        ), f"IndexError: FenwickTreeSet.pop({k}), Index out of range."
        if k < 0:
            k += self._len
        self._len -= 1
        # FenwickTree._pop is private. It is used here to find the dense index
        # of the k-th element and remove it from the tree in one pass; see
        # FenwickTree for its exact contract.
        x = self._fw._pop(k)
        self._cnt[x] = 0
        return self._to_origin[x]

    def pop_min(self) -> T:
        """Remove and return the smallest element."""
        assert (
            self._len > 0
        ), f"IndexError: pop_min() from empty {self.__class__.__name__}."
        return self.pop(0)

    def pop_max(self) -> T:
        """Remove and return the largest element."""
        assert (
            self._len > 0
        ), f"IndexError: pop_max() from empty {self.__class__.__name__}."
        return self.pop(-1)

    def get_min(self) -> Optional[T]:
        """Return the smallest element, or None if the set is empty."""
        return self[0] if self else None

    def get_max(self) -> Optional[T]:
        """Return the largest element, or None if the set is empty."""
        return self[-1] if self else None

    def __getitem__(self, k):
        """Return the ``k``-th smallest element (0-based; negative counts from the end)."""
        assert (
            -self._len <= k < self._len
        ), f"IndexError: FenwickTreeSet[{k}], Index out of range."
        if k < 0:
            k += self._len
        return self._to_origin[self._fw.bisect_right(k)]

    def __iter__(self):
        """Yield the elements in ascending order.

        Each call returns an independent generator. Nested or interleaved
        iteration over the same set (including ``str(self)`` inside a loop)
        therefore no longer clobbers a shared cursor.
        """
        to_origin = self._to_origin
        bisect_right = self._fw.bisect_right
        for k in range(self._len):
            yield to_origin[bisect_right(k)]

    def __reversed__(self):
        """Yield the elements in descending order."""
        _to_origin = self._to_origin
        for i in range(self._len):
            yield _to_origin[self._fw.bisect_right(self._len - i - 1)]

    def __len__(self) -> int:
        return self._len

    def __contains__(self, key: T) -> bool:
        """Return True iff ``key`` is currently in the set.

        Keys outside the universe are reported as absent rather than raising,
        which matches the usual semantics of the ``in`` operator.
        """
        try:
            i = self._to_zaatsu[key]
        except (KeyError, IndexError):
            return False
        # With compress=False the lookup table is a plain list, so a negative
        # int key would wrap around to its end. The round-trip check rejects
        # such aliases; it always holds for the dict mapping.
        return self._cnt[i] > 0 and self._to_origin[i] == key

    def __bool__(self) -> bool:
        return self._len > 0

    def __str__(self) -> str:
        return "{" + ", ".join(map(str, self)) + "}"

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self})"