Source code for titan_pylib.data_structures.set.segki_set
1from typing import Optional, Iterable
2
3
[docs]
class SegkiSet:
    """Set of integers in ``[0, u)`` backed by a segment-tree-like byte array.

    ``data`` is a complete binary tree stored 1-indexed in a flat
    ``bytearray``: the leaves ``data[size:size + u]`` mark membership, and
    every internal node ``data[i]`` is the OR of its children ``data[2i]`` and
    ``data[2i + 1]`` (with SUM instead of OR this would resemble a Fenwick
    tree / BIT).

    Complexity:
        space: O(u)
        add, discard, get_min, get_max, lt, le, gt, ge: O(log u)
        contains, len: O(1)
        iteration: O(n log u)
        k-th element (``s[k]``): O(min(k, n - k) log u)
    """

    def __init__(self, u: int, a: Iterable[int] = ()):
        """Create a set over the universe ``[0, u)``, filled from ``a``."""
        self.log = (u - 1).bit_length()
        # Number of leaves: the smallest power of two >= u (for u >= 1).
        self.size = 1 << self.log
        self.u = u
        # Index 0 is unused; the root is data[1], leaves start at data[size].
        self.data = bytearray(self.size << 1)
        self.len = 0
        for x in a:
            self.add(x)

    def add(self, k: int) -> bool:
        """Insert ``k``; return True if it was not already present.

        Raises:
            IndexError: if ``k`` is outside ``[0, u)``.
        """
        if not 0 <= k < self.u:
            # A negative key would otherwise map onto an internal node and
            # corrupt the OR invariant.
            raise IndexError(f"SegkiSet.add: {k} is out of range [0, {self.u})")
        data = self.data
        i = k + self.size
        if data[i]:
            return False
        self.len += 1
        data[i] = 1
        i >>= 1
        # Set ancestors until one is already set (everything above it is too).
        while i and not data[i]:
            data[i] = 1
            i >>= 1
        return True

    def discard(self, k: int) -> bool:
        """Remove ``k`` if present; return True if it was removed.

        Keys outside ``[0, u)`` can never be members, so they return False.
        """
        if not 0 <= k < self.u:
            return False
        data = self.data
        i = k + self.size
        if not data[i]:
            return False
        self.len -= 1
        data[i] = 0
        # Clear ancestors as long as the sibling subtree is empty as well.
        while i > 1 and not data[i ^ 1]:
            i >>= 1
            data[i] = 0
        return True

    def get_min(self) -> Optional[int]:
        """Return the smallest element, or None if the set is empty."""
        if self.data[1] == 0:
            return None
        k = 1
        while k < self.size:
            k <<= 1
            if self.data[k] == 0:
                k |= 1
        return k - self.size

    def get_max(self) -> Optional[int]:
        """Return the largest element, or None if the set is empty."""
        if self.data[1] == 0:
            return None
        k = 1
        while k < self.size:
            k <<= 1
            if self.data[k | 1]:
                k |= 1
        return k - self.size

    def lt(self, k: int) -> Optional[int]:
        """Find the largest element < key, or None if it doesn't exist. / O(log u)"""
        if k >= self.u:
            return self.get_max()
        if k <= 0:
            return None
        data = self.data
        i = k + self.size
        # Climb until ``i`` is a right child whose left sibling is non-empty;
        # that sibling is the nearest subtree lying entirely left of k.
        while i > 1:
            if i & 1 and data[i - 1]:
                i -= 1
                break
            i >>= 1
        else:
            return None  # no element anywhere to the left of k
        # Descend to the rightmost (largest) set leaf of that subtree.
        while i < self.size:
            i <<= 1
            if data[i | 1]:
                i |= 1
        return i - self.size

    def gt(self, k: int) -> Optional[int]:
        """Find the smallest element > key, or None if it doesn't exist. / O(log u)"""
        if k < 0:
            return self.get_min()
        if k >= self.u - 1:
            return None
        data = self.data
        i = k + self.size
        # Climb until ``i`` is a left child whose right sibling is non-empty;
        # that sibling is the nearest subtree lying entirely right of k.
        while i > 1:
            if (i & 1) == 0 and data[i + 1]:
                i += 1
                break
            i >>= 1
        else:
            # Without this early exit the descent below would wander into an
            # empty subtree and report a non-member leaf.
            return None
        # Descend to the leftmost (smallest) set leaf of that subtree.
        while i < self.size:
            i <<= 1
            if not data[i]:
                i |= 1
        return i - self.size

    def le(self, k: int) -> Optional[int]:
        """Find the largest element <= key, or None if it doesn't exist."""
        if k >= self.u:
            return self.get_max()
        if k < 0:
            return None
        if self.data[k + self.size]:
            return k
        return self.lt(k)

    def ge(self, k: int) -> Optional[int]:
        """Find the smallest element >= key, or None if it doesn't exist."""
        if k < 0:
            return self.get_min()
        if k >= self.u:
            return None
        if self.data[k + self.size]:
            return k
        return self.gt(k)

    def debug(self):
        """Print the tree level by level, root first."""
        print(
            "\n".join(
                " ".join(map(str, self.data[1 << i : 2 << i]))
                for i in range(self.log + 1)
            )
        )

    def __contains__(self, k: int) -> bool:
        return 0 <= k < self.u and self.data[k + self.size] == 1

    def __getitem__(self, k: int) -> int:
        """Return the k-th smallest element (negative k counts from the end).

        Walks from whichever end is nearer, costing O(min(k, n - k) log u),
        so indices near the front or back are recommended.

        Raises:
            IndexError: if ``k`` is out of range.
        """
        n = self.len
        if k < 0:
            k += n
        if not 0 <= k < n:
            raise IndexError("SegkiSet index out of range")
        if k < n >> 1:
            key = self.get_min()
            for _ in range(k):
                key = self.gt(key)
        else:
            key = self.get_max()
            for _ in range(n - k - 1):
                key = self.lt(key)
        return key

    def __len__(self) -> int:
        return self.len

    def __iter__(self):
        key = self.get_min()
        while key is not None:
            yield key
            key = self.gt(key)

    def __str__(self) -> str:
        return "{" + ", ".join(map(str, self)) + "}"