hash_string¶
ソースコード¶
from titan_pylib.string.hash_string import HashStringBase
from titan_pylib.string.hash_string import HashString
展開済みコード¶
1# from titan_pylib.string.hash_string import HashString
2# ref: https://qiita.com/keymoon/items/11fac5627672a6d6a9f6
3# from titan_pylib.data_structures.segment_tree.segment_tree import SegmentTree
4# from titan_pylib.data_structures.segment_tree.segment_tree_interface import (
5# SegmentTreeInterface,
6# )
7from abc import ABC, abstractmethod
8from typing import TypeVar, Generic, Union, Iterable, Callable
9
10T = TypeVar("T")
11
12
13class SegmentTreeInterface(ABC, Generic[T]):
14
15 @abstractmethod
16 def __init__(self, n_or_a: Union[int, Iterable[T]], op: Callable[[T, T], T], e: T):
17 raise NotImplementedError
18
19 @abstractmethod
20 def set(self, k: int, v: T) -> None:
21 raise NotImplementedError
22
23 @abstractmethod
24 def get(self, k: int) -> T:
25 raise NotImplementedError
26
27 @abstractmethod
28 def prod(self, l: int, r: int) -> T:
29 raise NotImplementedError
30
31 @abstractmethod
32 def all_prod(self) -> T:
33 raise NotImplementedError
34
35 @abstractmethod
36 def max_right(self, l: int, f: Callable[[T], bool]) -> int:
37 raise NotImplementedError
38
39 @abstractmethod
40 def min_left(self, r: int, f: Callable[[T], bool]) -> int:
41 raise NotImplementedError
42
43 @abstractmethod
44 def tolist(self) -> list[T]:
45 raise NotImplementedError
46
47 @abstractmethod
48 def __getitem__(self, k: int) -> T:
49 raise NotImplementedError
50
51 @abstractmethod
52 def __setitem__(self, k: int, v: T) -> None:
53 raise NotImplementedError
54
55 @abstractmethod
56 def __str__(self):
57 raise NotImplementedError
58
59 @abstractmethod
60 def __repr__(self):
61 raise NotImplementedError
62from typing import Generic, Iterable, TypeVar, Callable, Union
63
64T = TypeVar("T")
65
66
67class SegmentTree(SegmentTreeInterface, Generic[T]):
68 """セグ木です。非再帰です。"""
69
70 def __init__(
71 self, n_or_a: Union[int, Iterable[T]], op: Callable[[T, T], T], e: T
72 ) -> None:
73 """``SegmentTree`` を構築します。
74 :math:`O(n)` です。
75
76 Args:
77 n_or_a (Union[int, Iterable[T]]): ``n: int`` のとき、 ``e`` を初期値として長さ ``n`` の ``SegmentTree`` を構築します。
78 ``a: Iterable[T]`` のとき、 ``a`` から ``SegmentTree`` を構築します。
79 op (Callable[[T, T], T]): 2項演算の関数です。
80 e (T): 単位元です。
81 """
82 self._op = op
83 self._e = e
84 if isinstance(n_or_a, int):
85 self._n = n_or_a
86 self._log = (self._n - 1).bit_length()
87 self._size = 1 << self._log
88 self._data = [e] * (self._size << 1)
89 else:
90 n_or_a = list(n_or_a)
91 self._n = len(n_or_a)
92 self._log = (self._n - 1).bit_length()
93 self._size = 1 << self._log
94 _data = [e] * (self._size << 1)
95 _data[self._size : self._size + self._n] = n_or_a
96 for i in range(self._size - 1, 0, -1):
97 _data[i] = op(_data[i << 1], _data[i << 1 | 1])
98 self._data = _data
99
100 def set(self, k: int, v: T) -> None:
101 """一点更新です。
102 :math:`O(\\log{n})` です。
103
104 Args:
105 k (int): 更新するインデックスです。
106 v (T): 更新する値です。
107
108 制約:
109 :math:`-n \\leq n \\leq k < n`
110 """
111 assert (
112 -self._n <= k < self._n
113 ), f"IndexError: {self.__class__.__name__}.set({k}, {v}), n={self._n}"
114 if k < 0:
115 k += self._n
116 k += self._size
117 self._data[k] = v
118 for _ in range(self._log):
119 k >>= 1
120 self._data[k] = self._op(self._data[k << 1], self._data[k << 1 | 1])
121
122 def get(self, k: int) -> T:
123 """一点取得です。
124 :math:`O(1)` です。
125
126 Args:
127 k (int): インデックスです。
128
129 制約:
130 :math:`-n \\leq n \\leq k < n`
131 """
132 assert (
133 -self._n <= k < self._n
134 ), f"IndexError: {self.__class__.__name__}.get({k}), n={self._n}"
135 if k < 0:
136 k += self._n
137 return self._data[k + self._size]
138
139 def prod(self, l: int, r: int) -> T:
140 """区間 ``[l, r)`` の総積を返します。
141 :math:`O(\\log{n})` です。
142
143 Args:
144 l (int): インデックスです。
145 r (int): インデックスです。
146
147 制約:
148 :math:`0 \\leq l \\leq r \\leq n`
149 """
150 assert (
151 0 <= l <= r <= self._n
152 ), f"IndexError: {self.__class__.__name__}.prod({l}, {r})"
153 l += self._size
154 r += self._size
155 lres = self._e
156 rres = self._e
157 while l < r:
158 if l & 1:
159 lres = self._op(lres, self._data[l])
160 l += 1
161 if r & 1:
162 rres = self._op(self._data[r ^ 1], rres)
163 l >>= 1
164 r >>= 1
165 return self._op(lres, rres)
166
167 def all_prod(self) -> T:
168 """区間 ``[0, n)`` の総積を返します。
169 :math:`O(1)` です。
170 """
171 return self._data[1]
172
173 def max_right(self, l: int, f: Callable[[T], bool]) -> int:
174 """Find the largest index R s.t. f([l, R)) == True. / O(\\log{n})"""
175 assert (
176 0 <= l <= self._n
177 ), f"IndexError: {self.__class__.__name__}.max_right({l}, f) index out of range"
178 # assert f(self._e), \
179 # f'{self.__class__.__name__}.max_right({l}, f), f({self._e}) must be true.'
180 if l == self._n:
181 return self._n
182 l += self._size
183 s = self._e
184 while True:
185 while l & 1 == 0:
186 l >>= 1
187 if not f(self._op(s, self._data[l])):
188 while l < self._size:
189 l <<= 1
190 if f(self._op(s, self._data[l])):
191 s = self._op(s, self._data[l])
192 l |= 1
193 return l - self._size
194 s = self._op(s, self._data[l])
195 l += 1
196 if l & -l == l:
197 break
198 return self._n
199
200 def min_left(self, r: int, f: Callable[[T], bool]) -> int:
201 """Find the smallest index L s.t. f([L, r)) == True. / O(\\log{n})"""
202 assert (
203 0 <= r <= self._n
204 ), f"IndexError: {self.__class__.__name__}.min_left({r}, f) index out of range"
205 # assert f(self._e), \
206 # f'{self.__class__.__name__}.min_left({r}, f), f({self._e}) must be true.'
207 if r == 0:
208 return 0
209 r += self._size
210 s = self._e
211 while True:
212 r -= 1
213 while r > 1 and r & 1:
214 r >>= 1
215 if not f(self._op(self._data[r], s)):
216 while r < self._size:
217 r = r << 1 | 1
218 if f(self._op(self._data[r], s)):
219 s = self._op(self._data[r], s)
220 r ^= 1
221 return r + 1 - self._size
222 s = self._op(self._data[r], s)
223 if r & -r == r:
224 break
225 return 0
226
227 def tolist(self) -> list[T]:
228 """リストにして返します。
229 :math:`O(n)` です。
230 """
231 return [self.get(i) for i in range(self._n)]
232
233 def show(self) -> None:
234 """デバッグ用のメソッドです。"""
235 print(
236 f"<{self.__class__.__name__}> [\n"
237 + "\n".join(
238 [
239 " "
240 + " ".join(
241 map(str, [self._data[(1 << i) + j] for j in range(1 << i)])
242 )
243 for i in range(self._log + 1)
244 ]
245 )
246 + "\n]"
247 )
248
249 def __getitem__(self, k: int) -> T:
250 assert (
251 -self._n <= k < self._n
252 ), f"IndexError: {self.__class__.__name__}.__getitem__({k}), n={self._n}"
253 return self.get(k)
254
255 def __setitem__(self, k: int, v: T):
256 assert (
257 -self._n <= k < self._n
258 ), f"IndexError: {self.__class__.__name__}.__setitem__{k}, {v}), n={self._n}"
259 self.set(k, v)
260
261 def __len__(self) -> int:
262 return self._n
263
264 def __str__(self) -> str:
265 return str(self.tolist())
266
267 def __repr__(self) -> str:
268 return f"{self.__class__.__name__}({self})"
269from typing import Optional, Final
270import random
271import string
272
273_titan_pylib_HashString_MOD: Final[int] = (1 << 61) - 1
274_titan_pylib_HashString_DIC: Final[dict[str, int]] = {
275 c: i for i, c in enumerate(string.ascii_lowercase, 1)
276}
277_titan_pylib_HashString_MASK30: Final[int] = (1 << 30) - 1
278_titan_pylib_HashString_MASK31: Final[int] = (1 << 31) - 1
279_titan_pylib_HashString_MASK61: Final[int] = _titan_pylib_HashString_MOD
280
281
282class HashStringBase:
283 """HashStringのベースクラスです。"""
284
285 def __init__(self, n: int = 0, base: int = -1, seed: Optional[int] = None) -> None:
286 """
287 :math:`O(n)` です。
288
289 Args:
290 n (int): 文字列の長さの上限です。上限を超えても問題ありません。
291 base (int, optional): Defaults to -1.
292 seed (Optional[int], optional): Defaults to None.
293 """
294 rand = random.Random(seed)
295 base = rand.randint(37, 10**9) if base < 0 else base
296 powb = [1] * (n + 1)
297 invb = [1] * (n + 1)
298 invbpow = pow(base, -1, _titan_pylib_HashString_MOD)
299 for i in range(1, n + 1):
300 powb[i] = HashStringBase.get_mul(powb[i - 1], base)
301 invb[i] = HashStringBase.get_mul(invb[i - 1], invbpow)
302 self.n = n
303 self.base = base
304 self.invpow = invbpow
305 self.powb = powb
306 self.invb = invb
307
308 @staticmethod
309 def get_mul(a: int, b: int) -> int:
310 au = a >> 31
311 ad = a & _titan_pylib_HashString_MASK31
312 bu = b >> 31
313 bd = b & _titan_pylib_HashString_MASK31
314 mid = ad * bu + au * bd
315 midu = mid >> 30
316 midd = mid & _titan_pylib_HashString_MASK30
317 return HashStringBase.get_mod(au * bu * 2 + midu + (midd << 31) + ad * bd)
318
319 @staticmethod
320 def get_mod(x: int) -> int:
321 xu = x >> 61
322 xd = x & _titan_pylib_HashString_MASK61
323 res = xu + xd
324 if res >= _titan_pylib_HashString_MOD:
325 res -= _titan_pylib_HashString_MOD
326 return res
327
328 def extend(self, cap: int) -> None:
329 pre_cap = len(self.powb)
330 powb, invb = self.powb, self.invb
331 powb += [0] * cap
332 invb += [0] * cap
333 invbpow = pow(self.base, -1, _titan_pylib_HashString_MOD)
334 for i in range(pre_cap, pre_cap + cap):
335 powb[i] = HashStringBase.get_mul(powb[i - 1], self.base)
336 invb[i] = HashStringBase.get_mul(invb[i - 1], invbpow)
337
338 def get_cap(self) -> int:
339 return len(self.powb)
340
341 def unite(self, h1: int, h2: int, k: int) -> int:
342 # len(h2) == k
343 # h1 <- h2
344 if k >= self.get_cap():
345 self.extend(k - self.get_cap() + 1)
346 return self.get_mod(self.get_mul(h1, self.powb[k]) + h2)
347
348
349class HashString:
350
351 def __init__(self, hsb: HashStringBase, s: str, update: bool = False) -> None:
352 """ロリハを構築します。
353 :math:`O(n)` です。
354
355 Args:
356 hsb (HashStringBase): ベースクラスです。
357 s (str): ロリハを構築する文字列です。
358 update (bool, optional): ``update=True`` のとき、1点更新が可能になります。
359 """
360 n = len(s)
361 data = [0] * n
362 acc = [0] * (n + 1)
363 if n >= hsb.get_cap():
364 hsb.extend(n - hsb.get_cap() + 1)
365 powb = hsb.powb
366 for i, c in enumerate(s):
367 data[i] = hsb.get_mul(powb[n - i - 1], _titan_pylib_HashString_DIC[c])
368 acc[i + 1] = hsb.get_mod(acc[i] + data[i])
369 self.hsb = hsb
370 self.n = n
371 self.acc = acc
372 self.used_seg = False
373 if update:
374 self.seg = SegmentTree(
375 data, lambda s, t: (s + t) % _titan_pylib_HashString_MOD, 0
376 )
377
378 def get(self, l: int, r: int) -> int:
379 """``s[l, r)`` のハッシュ値を返します。
380 1点更新処理後は :math:`O(\\log{n})` 、そうでなければ :math:`O(1)` です。
381
382 Args:
383 l (int): インデックスです。
384 r (int): インデックスです。
385
386 Returns:
387 int: ハッシュ値です。
388 """
389 assert 0 <= l <= r <= self.n
390 if self.used_seg:
391 return self.hsb.get_mul(self.seg.prod(l, r), self.hsb.invb[self.n - r])
392 return self.hsb.get_mul(
393 self.hsb.get_mod(self.acc[r] - self.acc[l]), self.hsb.invb[self.n - r]
394 )
395
396 def __getitem__(self, k: int) -> int:
397 """``s[k]`` のハッシュ値を返します。
398 1点更新処理後は :math:`O(\\log{n})` 、そうでなければ :math:`O(1)` です。
399
400 Args:
401 k (int): インデックスです。
402
403 Returns:
404 int: ハッシュ値です。
405 """
406 return self.get(k, k + 1)
407
408 def set(self, k: int, c: str) -> None:
409 """`k` 番目の文字を `c` に更新します。
410 :math:`O(\\log{n})` です。また、今後の ``get()`` が :math:`O(\\log{n})` になります。
411
412 Args:
413 k (int): インデックスです。
414 c (str): 更新する文字です。
415 """
416 self.used_seg = True
417 self.seg[k] = self.hsb.get_mul(
418 self.hsb.powb[self.n - k - 1], _titan_pylib_HashString_DIC[c]
419 )
420
421 def __setitem__(self, k: int, c: str) -> None:
422 return self.set(k, c)
423
424 def __len__(self):
425 return self.n
426
427 def get_lcp(self) -> list[int]:
428 """lcp配列を返します。
429 :math:`O(n\\log{n})` です。
430 """
431 a = [0] * self.n
432 memo = [-1] * (self.n + 1)
433 for i in range(self.n):
434 ok, ng = 0, self.n - i + 1
435 while ng - ok > 1:
436 mid = (ok + ng) >> 1
437 if memo[mid] == -1:
438 memo[mid] = self.get(0, mid)
439 if memo[mid] == self.get(i, i + mid):
440 ok = mid
441 else:
442 ng = mid
443 a[i] = ok
444 return a
仕様¶
- class HashString(hsb: HashStringBase, s: str, update: bool = False)[source]¶
Bases:
object
- __getitem__(k: int) int [source]¶
s[k]
のハッシュ値を返します。 1点更新処理後は \(O(\log{n})\) 、そうでなければ \(O(1)\) です。- Parameters:
k (int) – インデックスです。
- Returns:
ハッシュ値です。
- Return type:
int