Source code for titan_pylib.data_structures.set.sorted_set

  1# https://github.com/tatyam-prime/SortedSet/blob/main/SortedSet.py
  2import math
  3from bisect import bisect_left, bisect_right
  4from typing import Generic, Iterable, Iterator, TypeVar, Optional
  5
  6T = TypeVar("T")
  7
  8
[docs] 9class SortedSet(Generic[T]): 10 BUCKET_RATIO = 16 11 SPLIT_RATIO = 24 12 13 def __init__(self, a: Iterable[T] = []) -> None: 14 "Make a new SortedSet from iterable. / O(N) if sorted and unique / O(N log N)" 15 a = list(a) 16 n = len(a) 17 if any(a[i] > a[i + 1] for i in range(n - 1)): 18 a.sort() 19 if any(a[i] >= a[i + 1] for i in range(n - 1)): 20 a, b = [], a 21 for x in b: 22 if not a or a[-1] != x: 23 a.append(x) 24 n = self.size = len(a) 25 num_bucket = int(math.ceil(math.sqrt(n / self.BUCKET_RATIO))) 26 self.a = [ 27 a[n * i // num_bucket : n * (i + 1) // num_bucket] 28 for i in range(num_bucket) 29 ] 30 31 def __iter__(self) -> Iterator[T]: 32 for i in self.a: 33 for j in i: 34 yield j 35 36 def __reversed__(self) -> Iterator[T]: 37 for i in reversed(self.a): 38 for j in reversed(i): 39 yield j 40 41 def __eq__(self, other) -> bool: 42 return list(self) == list(other) 43 44 def __len__(self) -> int: 45 return self.size 46 47 def __repr__(self) -> str: 48 return "SortedSet" + str(self.a) 49 50 def __str__(self) -> str: 51 s = str(list(self)) 52 return "{" + s[1 : len(s) - 1] + "}" 53 54 def _position(self, x: T) -> tuple[list[T], int, int]: 55 "return the bucket, index of the bucket and position in which x should be. self must not be empty." 56 for i, a in enumerate(self.a): 57 if x <= a[-1]: 58 break 59 return (a, i, bisect_left(a, x)) 60 61 def __contains__(self, x: T) -> bool: 62 if self.size == 0: 63 return False 64 a, _, i = self._position(x) 65 return i != len(a) and a[i] == x 66
[docs] 67 def add(self, x: T) -> bool: 68 "Add an element and return True if added. / O(√N)" 69 if self.size == 0: 70 self.a = [[x]] 71 self.size = 1 72 return True 73 a, b, i = self._position(x) 74 if i != len(a) and a[i] == x: 75 return False 76 a.insert(i, x) 77 self.size += 1 78 if len(a) > len(self.a) * self.SPLIT_RATIO: 79 mid = len(a) >> 1 80 self.a[b : b + 1] = [a[:mid], a[mid:]] 81 return True
82 83 def _pop(self, a: list[T], b: int, i: int) -> T: 84 ans = a.pop(i) 85 self.size -= 1 86 if not a: 87 del self.a[b] 88 return ans 89
[docs] 90 def discard(self, x: T) -> bool: 91 "Remove an element and return True if removed. / O(√N)" 92 if self.size == 0: 93 return False 94 a, b, i = self._position(x) 95 if i == len(a) or a[i] != x: 96 return False 97 self._pop(a, b, i) 98 return True
99
[docs] 100 def lt(self, x: T) -> Optional[T]: 101 "Find the largest element < x, or None if it doesn't exist." 102 for a in reversed(self.a): 103 if a[0] < x: 104 return a[bisect_left(a, x) - 1]
105
[docs] 106 def le(self, x: T) -> Optional[T]: 107 "Find the largest element <= x, or None if it doesn't exist." 108 for a in reversed(self.a): 109 if a[0] <= x: 110 return a[bisect_right(a, x) - 1]
111
[docs] 112 def gt(self, x: T) -> Optional[T]: 113 "Find the smallest element > x, or None if it doesn't exist." 114 for a in self.a: 115 if a[-1] > x: 116 return a[bisect_right(a, x)]
117
[docs] 118 def ge(self, x: T) -> Optional[T]: 119 "Find the smallest element >= x, or None if it doesn't exist." 120 for a in self.a: 121 if a[-1] >= x: 122 return a[bisect_left(a, x)]
123
[docs] 124 def __getitem__(self, i: int) -> T: 125 "Return the i-th element." 126 if i < 0: 127 for a in reversed(self.a): 128 i += len(a) 129 if i >= 0: 130 return a[i] 131 else: 132 for a in self.a: 133 if i < len(a): 134 return a[i] 135 i -= len(a) 136 raise IndexError
137
[docs] 138 def pop(self, i: int = -1) -> T: 139 "Pop and return the i-th element." 140 if i < 0: 141 for b, a in enumerate(reversed(self.a)): 142 i += len(a) 143 if i >= 0: 144 return self._pop(a, ~b, i) 145 else: 146 for b, a in enumerate(self.a): 147 if i < len(a): 148 return self._pop(a, b, i) 149 i -= len(a) 150 raise IndexError
151
[docs] 152 def index(self, x: T) -> int: 153 "Count the number of elements < x." 154 ans = 0 155 for a in self.a: 156 if a[-1] >= x: 157 return ans + bisect_left(a, x) 158 ans += len(a) 159 return ans
160
[docs] 161 def index_right(self, x: T) -> int: 162 "Count the number of elements <= x." 163 ans = 0 164 for a in self.a: 165 if a[-1] > x: 166 return ans + bisect_right(a, x) 167 ans += len(a) 168 return ans