segki_set
ソースコード
from titan_pylib.data_structures.set.segki_set import SegkiSet
展開済みコード
1# from titan_pylib.data_structures.set.segki_set import SegkiSet
2from typing import Optional, Iterable
3
4
class SegkiSet:
    """Set of integers in the range [0, u), stored in a segment-tree-like bit tree.

    ``data`` is a flat binary tree of ``2 * size`` bytes, where ``size`` is ``u``
    rounded up to a power of two. Leaves live at ``data[size + x]`` for element
    ``x``. Each internal node holds the OR of its two children, so a node is 1
    iff at least one element lies in its subtree. (With SUM instead of OR the
    layout would resemble a Fenwick tree / BIT.)

    Complexity:
        space: O(u)
        add, discard, lt/le/gt/ge (predecessor / successor): O(log u)
        contains, len: O(1)
        iteration: O(n log u)
        k-th element (``s[k]``): O(min(k, n - k) * log u)
    """

    def __init__(self, u: int, a: Iterable[int] = ()) -> None:
        """Create a set over the universe [0, u), optionally filled from ``a``.

        Raises IndexError if any value of ``a`` lies outside [0, u).
        """
        self.log = (u - 1).bit_length()
        self.size = 1 << self.log
        self.u = u
        self.data = bytearray(self.size << 1)
        self.len = 0
        for x in a:
            self.add(x)

    def add(self, k: int) -> bool:
        """Insert ``k``. Return True if it was newly added, False if already present.

        Raises IndexError if ``k`` lies outside [0, u).
        """
        if not 0 <= k < self.u:
            # A negative k would otherwise land on an internal node and corrupt the tree.
            raise IndexError(f"{k} is out of range [0, {self.u})")
        data = self.data
        k += self.size
        if data[k]:
            return False
        self.len += 1
        data[k] = 1
        # Mark ancestors. Stop at the first one already marked,
        # since every node above a marked node is marked as well.
        while k > 1:
            k >>= 1
            if data[k]:
                break
            data[k] = 1
        return True

    def discard(self, k: int) -> bool:
        """Remove ``k``. Return True if it was present, False otherwise.

        Values outside [0, u) are never members, so they return False.
        """
        if not 0 <= k < self.u:
            return False
        data = self.data
        k += self.size
        if not data[k]:
            return False
        self.len -= 1
        data[k] = 0
        # Clear ancestors for as long as the sibling subtree is empty too.
        while k > 1:
            if data[k ^ 1]:
                break
            k >>= 1
            data[k] = 0
        return True

    def get_min(self) -> Optional[int]:
        """Return the smallest element, or None if the set is empty."""
        if self.data[1] == 0:
            return None
        k = 1
        while k < self.size:
            k <<= 1
            if self.data[k] == 0:
                k |= 1
        return k - self.size

    def get_max(self) -> Optional[int]:
        """Return the largest element, or None if the set is empty."""
        if self.data[1] == 0:
            return None
        k = 1
        while k < self.size:
            k <<= 1
            if self.data[k | 1]:
                k |= 1
        return k - self.size

    def lt(self, k: int) -> Optional[int]:
        """Find the largest element < k, or None if it doesn't exist. O(log u)."""
        if k <= 0:
            return None
        if k >= self.u:
            return self.get_max()
        data = self.data
        i = k + self.size
        # Climb until we are a right child whose left sibling holds an element.
        while i > 1:
            if i & 1 and data[i - 1]:
                i -= 1
                break
            i >>= 1
        else:
            # No left sibling on the path holds data: nothing is smaller than k.
            return None
        # Descend to the rightmost set leaf of that sibling subtree.
        while i < self.size:
            i <<= 1
            if data[i | 1]:
                i |= 1
        return i - self.size

    def gt(self, k: int) -> Optional[int]:
        """Find the smallest element > k, or None if it doesn't exist. O(log u)."""
        if k < 0:
            return self.get_min()
        if k >= self.u - 1:
            return None
        data = self.data
        i = k + self.size
        # Climb until we are a left child whose right sibling holds an element.
        while i > 1:
            if not i & 1 and data[i + 1]:
                i += 1
                break
            i >>= 1
        else:
            # No right sibling on the path holds data: nothing is larger than k.
            # (Descending anyway would walk an empty subtree to a phantom leaf.)
            return None
        # Descend to the leftmost set leaf of that sibling subtree.
        while i < self.size:
            i <<= 1
            if not data[i]:
                i |= 1
        return i - self.size

    def le(self, k: int) -> Optional[int]:
        """Find the largest element <= k, or None if it doesn't exist. O(log u)."""
        if k < 0:
            return None
        if k >= self.u:
            return self.get_max()
        if self.data[k + self.size]:
            return k
        return self.lt(k)

    def ge(self, k: int) -> Optional[int]:
        """Find the smallest element >= k, or None if it doesn't exist. O(log u)."""
        if k >= self.u:
            return None
        if k < 0:
            return self.get_min()
        if self.data[k + self.size]:
            return k
        return self.gt(k)

    def debug(self):
        """Print the tree level by level (root first), one row per depth."""
        print(
            "\n".join(
                " ".join(map(str, (self.data[(1 << i) + j] for j in range(1 << i))))
                for i in range(self.log + 1)
            )
        )

    def __contains__(self, k: int) -> bool:
        # Range check first: a negative k would otherwise index an internal node.
        return 0 <= k < self.u and self.data[k + self.size] == 1

    def __getitem__(self, k: int) -> Optional[int]:
        """Return the k-th smallest element (negative k counts from the end).

        Out-of-range indices (including any index on an empty set) return None.
        Cost is O(min(k, n - k) * log u): prefer k near the front or the back.
        """
        n = self.len
        if k < 0:
            k += n
        if not 0 <= k < n:
            return None
        if k == 0:
            return self.get_min()
        if k == n - 1:
            return self.get_max()
        # Walk from whichever end is closer.
        if k < n >> 1:
            key = self.get_min()
            for _ in range(k):
                key = self.gt(key)
        else:
            key = self.get_max()
            for _ in range(n - k - 1):
                key = self.lt(key)
        return key

    def __len__(self) -> int:
        return self.len

    def __iter__(self):
        """Yield the elements in increasing order."""
        key = self.get_min()
        while key is not None:
            yield key
            key = self.gt(key)

    def __str__(self) -> str:
        return "{" + ", ".join(map(str, self)) + "}"