hash_string

ソースコード

from titan_pylib.string.hash_string import HashStringBase
from titan_pylib.string.hash_string import HashString

view on github

展開済みコード

  1# from titan_pylib.string.hash_string import HashString
  2# ref: https://qiita.com/keymoon/items/11fac5627672a6d6a9f6
  3# from titan_pylib.data_structures.segment_tree.segment_tree import SegmentTree
  4# from titan_pylib.data_structures.segment_tree.segment_tree_interface import (
  5#     SegmentTreeInterface,
  6# )
  7from abc import ABC, abstractmethod
  8from typing import TypeVar, Generic, Union, Iterable, Callable
  9
 10T = TypeVar("T")
 11
 12
 13class SegmentTreeInterface(ABC, Generic[T]):
 14
 15    @abstractmethod
 16    def __init__(self, n_or_a: Union[int, Iterable[T]], op: Callable[[T, T], T], e: T):
 17        raise NotImplementedError
 18
 19    @abstractmethod
 20    def set(self, k: int, v: T) -> None:
 21        raise NotImplementedError
 22
 23    @abstractmethod
 24    def get(self, k: int) -> T:
 25        raise NotImplementedError
 26
 27    @abstractmethod
 28    def prod(self, l: int, r: int) -> T:
 29        raise NotImplementedError
 30
 31    @abstractmethod
 32    def all_prod(self) -> T:
 33        raise NotImplementedError
 34
 35    @abstractmethod
 36    def max_right(self, l: int, f: Callable[[T], bool]) -> int:
 37        raise NotImplementedError
 38
 39    @abstractmethod
 40    def min_left(self, r: int, f: Callable[[T], bool]) -> int:
 41        raise NotImplementedError
 42
 43    @abstractmethod
 44    def tolist(self) -> list[T]:
 45        raise NotImplementedError
 46
 47    @abstractmethod
 48    def __getitem__(self, k: int) -> T:
 49        raise NotImplementedError
 50
 51    @abstractmethod
 52    def __setitem__(self, k: int, v: T) -> None:
 53        raise NotImplementedError
 54
 55    @abstractmethod
 56    def __str__(self):
 57        raise NotImplementedError
 58
 59    @abstractmethod
 60    def __repr__(self):
 61        raise NotImplementedError
 62from typing import Generic, Iterable, TypeVar, Callable, Union
 63
 64T = TypeVar("T")
 65
 66
 67class SegmentTree(SegmentTreeInterface, Generic[T]):
 68    """セグ木です。非再帰です。"""
 69
 70    def __init__(
 71        self, n_or_a: Union[int, Iterable[T]], op: Callable[[T, T], T], e: T
 72    ) -> None:
 73        """``SegmentTree`` を構築します。
 74        :math:`O(n)` です。
 75
 76        Args:
 77            n_or_a (Union[int, Iterable[T]]): ``n: int`` のとき、 ``e`` を初期値として長さ ``n`` の ``SegmentTree`` を構築します。
 78                                              ``a: Iterable[T]`` のとき、 ``a`` から ``SegmentTree`` を構築します。
 79            op (Callable[[T, T], T]): 2項演算の関数です。
 80            e (T): 単位元です。
 81        """
 82        self._op = op
 83        self._e = e
 84        if isinstance(n_or_a, int):
 85            self._n = n_or_a
 86            self._log = (self._n - 1).bit_length()
 87            self._size = 1 << self._log
 88            self._data = [e] * (self._size << 1)
 89        else:
 90            n_or_a = list(n_or_a)
 91            self._n = len(n_or_a)
 92            self._log = (self._n - 1).bit_length()
 93            self._size = 1 << self._log
 94            _data = [e] * (self._size << 1)
 95            _data[self._size : self._size + self._n] = n_or_a
 96            for i in range(self._size - 1, 0, -1):
 97                _data[i] = op(_data[i << 1], _data[i << 1 | 1])
 98            self._data = _data
 99
100    def set(self, k: int, v: T) -> None:
101        """一点更新です。
102        :math:`O(\\log{n})` です。
103
104        Args:
105            k (int): 更新するインデックスです。
106            v (T): 更新する値です。
107
108        制約:
109            :math:`-n \\leq n \\leq k < n`
110        """
111        assert (
112            -self._n <= k < self._n
113        ), f"IndexError: {self.__class__.__name__}.set({k}, {v}), n={self._n}"
114        if k < 0:
115            k += self._n
116        k += self._size
117        self._data[k] = v
118        for _ in range(self._log):
119            k >>= 1
120            self._data[k] = self._op(self._data[k << 1], self._data[k << 1 | 1])
121
122    def get(self, k: int) -> T:
123        """一点取得です。
124        :math:`O(1)` です。
125
126        Args:
127            k (int): インデックスです。
128
129        制約:
130            :math:`-n \\leq n \\leq k < n`
131        """
132        assert (
133            -self._n <= k < self._n
134        ), f"IndexError: {self.__class__.__name__}.get({k}), n={self._n}"
135        if k < 0:
136            k += self._n
137        return self._data[k + self._size]
138
139    def prod(self, l: int, r: int) -> T:
140        """区間 ``[l, r)`` の総積を返します。
141        :math:`O(\\log{n})` です。
142
143        Args:
144            l (int): インデックスです。
145            r (int): インデックスです。
146
147        制約:
148            :math:`0 \\leq l \\leq r \\leq n`
149        """
150        assert (
151            0 <= l <= r <= self._n
152        ), f"IndexError: {self.__class__.__name__}.prod({l}, {r})"
153        l += self._size
154        r += self._size
155        lres = self._e
156        rres = self._e
157        while l < r:
158            if l & 1:
159                lres = self._op(lres, self._data[l])
160                l += 1
161            if r & 1:
162                rres = self._op(self._data[r ^ 1], rres)
163            l >>= 1
164            r >>= 1
165        return self._op(lres, rres)
166
167    def all_prod(self) -> T:
168        """区間 ``[0, n)`` の総積を返します。
169        :math:`O(1)` です。
170        """
171        return self._data[1]
172
173    def max_right(self, l: int, f: Callable[[T], bool]) -> int:
174        """Find the largest index R s.t. f([l, R)) == True. / O(\\log{n})"""
175        assert (
176            0 <= l <= self._n
177        ), f"IndexError: {self.__class__.__name__}.max_right({l}, f) index out of range"
178        # assert f(self._e), \
179        #     f'{self.__class__.__name__}.max_right({l}, f), f({self._e}) must be true.'
180        if l == self._n:
181            return self._n
182        l += self._size
183        s = self._e
184        while True:
185            while l & 1 == 0:
186                l >>= 1
187            if not f(self._op(s, self._data[l])):
188                while l < self._size:
189                    l <<= 1
190                    if f(self._op(s, self._data[l])):
191                        s = self._op(s, self._data[l])
192                        l |= 1
193                return l - self._size
194            s = self._op(s, self._data[l])
195            l += 1
196            if l & -l == l:
197                break
198        return self._n
199
200    def min_left(self, r: int, f: Callable[[T], bool]) -> int:
201        """Find the smallest index L s.t. f([L, r)) == True. / O(\\log{n})"""
202        assert (
203            0 <= r <= self._n
204        ), f"IndexError: {self.__class__.__name__}.min_left({r}, f) index out of range"
205        # assert f(self._e), \
206        #     f'{self.__class__.__name__}.min_left({r}, f), f({self._e}) must be true.'
207        if r == 0:
208            return 0
209        r += self._size
210        s = self._e
211        while True:
212            r -= 1
213            while r > 1 and r & 1:
214                r >>= 1
215            if not f(self._op(self._data[r], s)):
216                while r < self._size:
217                    r = r << 1 | 1
218                    if f(self._op(self._data[r], s)):
219                        s = self._op(self._data[r], s)
220                        r ^= 1
221                return r + 1 - self._size
222            s = self._op(self._data[r], s)
223            if r & -r == r:
224                break
225        return 0
226
227    def tolist(self) -> list[T]:
228        """リストにして返します。
229        :math:`O(n)` です。
230        """
231        return [self.get(i) for i in range(self._n)]
232
233    def show(self) -> None:
234        """デバッグ用のメソッドです。"""
235        print(
236            f"<{self.__class__.__name__}> [\n"
237            + "\n".join(
238                [
239                    "  "
240                    + " ".join(
241                        map(str, [self._data[(1 << i) + j] for j in range(1 << i)])
242                    )
243                    for i in range(self._log + 1)
244                ]
245            )
246            + "\n]"
247        )
248
249    def __getitem__(self, k: int) -> T:
250        assert (
251            -self._n <= k < self._n
252        ), f"IndexError: {self.__class__.__name__}.__getitem__({k}), n={self._n}"
253        return self.get(k)
254
255    def __setitem__(self, k: int, v: T):
256        assert (
257            -self._n <= k < self._n
258        ), f"IndexError: {self.__class__.__name__}.__setitem__{k}, {v}), n={self._n}"
259        self.set(k, v)
260
261    def __len__(self) -> int:
262        return self._n
263
264    def __str__(self) -> str:
265        return str(self.tolist())
266
267    def __repr__(self) -> str:
268        return f"{self.__class__.__name__}({self})"
269from typing import Optional, Final
270import random
271import string
272
273_titan_pylib_HashString_MOD: Final[int] = (1 << 61) - 1
274_titan_pylib_HashString_DIC: Final[dict[str, int]] = {
275    c: i for i, c in enumerate(string.ascii_lowercase, 1)
276}
277_titan_pylib_HashString_MASK30: Final[int] = (1 << 30) - 1
278_titan_pylib_HashString_MASK31: Final[int] = (1 << 31) - 1
279_titan_pylib_HashString_MASK61: Final[int] = _titan_pylib_HashString_MOD
280
281
282class HashStringBase:
283    """HashStringのベースクラスです。"""
284
285    def __init__(self, n: int = 0, base: int = -1, seed: Optional[int] = None) -> None:
286        """
287        :math:`O(n)` です。
288
289        Args:
290            n (int): 文字列の長さの上限です。上限を超えても問題ありません。
291            base (int, optional): Defaults to -1.
292            seed (Optional[int], optional): Defaults to None.
293        """
294        rand = random.Random(seed)
295        base = rand.randint(37, 10**9) if base < 0 else base
296        powb = [1] * (n + 1)
297        invb = [1] * (n + 1)
298        invbpow = pow(base, -1, _titan_pylib_HashString_MOD)
299        for i in range(1, n + 1):
300            powb[i] = HashStringBase.get_mul(powb[i - 1], base)
301            invb[i] = HashStringBase.get_mul(invb[i - 1], invbpow)
302        self.n = n
303        self.base = base
304        self.invpow = invbpow
305        self.powb = powb
306        self.invb = invb
307
308    @staticmethod
309    def get_mul(a: int, b: int) -> int:
310        au = a >> 31
311        ad = a & _titan_pylib_HashString_MASK31
312        bu = b >> 31
313        bd = b & _titan_pylib_HashString_MASK31
314        mid = ad * bu + au * bd
315        midu = mid >> 30
316        midd = mid & _titan_pylib_HashString_MASK30
317        return HashStringBase.get_mod(au * bu * 2 + midu + (midd << 31) + ad * bd)
318
319    @staticmethod
320    def get_mod(x: int) -> int:
321        xu = x >> 61
322        xd = x & _titan_pylib_HashString_MASK61
323        res = xu + xd
324        if res >= _titan_pylib_HashString_MOD:
325            res -= _titan_pylib_HashString_MOD
326        return res
327
328    def extend(self, cap: int) -> None:
329        pre_cap = len(self.powb)
330        powb, invb = self.powb, self.invb
331        powb += [0] * cap
332        invb += [0] * cap
333        invbpow = pow(self.base, -1, _titan_pylib_HashString_MOD)
334        for i in range(pre_cap, pre_cap + cap):
335            powb[i] = HashStringBase.get_mul(powb[i - 1], self.base)
336            invb[i] = HashStringBase.get_mul(invb[i - 1], invbpow)
337
338    def get_cap(self) -> int:
339        return len(self.powb)
340
341    def unite(self, h1: int, h2: int, k: int) -> int:
342        # len(h2) == k
343        # h1 <- h2
344        if k >= self.get_cap():
345            self.extend(k - self.get_cap() + 1)
346        return self.get_mod(self.get_mul(h1, self.powb[k]) + h2)
347
348
349class HashString:
350
351    def __init__(self, hsb: HashStringBase, s: str, update: bool = False) -> None:
352        """ロリハを構築します。
353        :math:`O(n)` です。
354
355        Args:
356            hsb (HashStringBase): ベースクラスです。
357            s (str): ロリハを構築する文字列です。
358            update (bool, optional): ``update=True`` のとき、1点更新が可能になります。
359        """
360        n = len(s)
361        data = [0] * n
362        acc = [0] * (n + 1)
363        if n >= hsb.get_cap():
364            hsb.extend(n - hsb.get_cap() + 1)
365        powb = hsb.powb
366        for i, c in enumerate(s):
367            data[i] = hsb.get_mul(powb[n - i - 1], _titan_pylib_HashString_DIC[c])
368            acc[i + 1] = hsb.get_mod(acc[i] + data[i])
369        self.hsb = hsb
370        self.n = n
371        self.acc = acc
372        self.used_seg = False
373        if update:
374            self.seg = SegmentTree(
375                data, lambda s, t: (s + t) % _titan_pylib_HashString_MOD, 0
376            )
377
378    def get(self, l: int, r: int) -> int:
379        """``s[l, r)`` のハッシュ値を返します。
380        1点更新処理後は :math:`O(\\log{n})` 、そうでなければ :math:`O(1)` です。
381
382        Args:
383            l (int): インデックスです。
384            r (int): インデックスです。
385
386        Returns:
387            int: ハッシュ値です。
388        """
389        assert 0 <= l <= r <= self.n
390        if self.used_seg:
391            return self.hsb.get_mul(self.seg.prod(l, r), self.hsb.invb[self.n - r])
392        return self.hsb.get_mul(
393            self.hsb.get_mod(self.acc[r] - self.acc[l]), self.hsb.invb[self.n - r]
394        )
395
396    def __getitem__(self, k: int) -> int:
397        """``s[k]`` のハッシュ値を返します。
398        1点更新処理後は :math:`O(\\log{n})` 、そうでなければ :math:`O(1)` です。
399
400        Args:
401            k (int): インデックスです。
402
403        Returns:
404            int: ハッシュ値です。
405        """
406        return self.get(k, k + 1)
407
408    def set(self, k: int, c: str) -> None:
409        """`k` 番目の文字を `c` に更新します。
410        :math:`O(\\log{n})` です。また、今後の ``get()`` が :math:`O(\\log{n})` になります。
411
412        Args:
413            k (int): インデックスです。
414            c (str): 更新する文字です。
415        """
416        self.used_seg = True
417        self.seg[k] = self.hsb.get_mul(
418            self.hsb.powb[self.n - k - 1], _titan_pylib_HashString_DIC[c]
419        )
420
421    def __setitem__(self, k: int, c: str) -> None:
422        return self.set(k, c)
423
424    def __len__(self):
425        return self.n
426
427    def get_lcp(self) -> list[int]:
428        """lcp配列を返します。
429        :math:`O(n\\log{n})` です。
430        """
431        a = [0] * self.n
432        memo = [-1] * (self.n + 1)
433        for i in range(self.n):
434            ok, ng = 0, self.n - i + 1
435            while ng - ok > 1:
436                mid = (ok + ng) >> 1
437                if memo[mid] == -1:
438                    memo[mid] = self.get(0, mid)
439                if memo[mid] == self.get(i, i + mid):
440                    ok = mid
441                else:
442                    ng = mid
443            a[i] = ok
444        return a

仕様

class HashString(hsb: HashStringBase, s: str, update: bool = False)[source]

Bases: object

__getitem__(k: int) int[source]

s[k] のハッシュ値を返します。 1点更新処理後は \(O(\log{n})\) 、そうでなければ \(O(1)\) です。

Parameters:

k (int) – インデックスです。

Returns:

ハッシュ値です。

Return type:

int

get(l: int, r: int) int[source]

s[l, r) のハッシュ値を返します。 1点更新処理後は \(O(\log{n})\) 、そうでなければ \(O(1)\) です。

Parameters:
  • l (int) – インデックスです。

  • r (int) – インデックスです。

Returns:

ハッシュ値です。

Return type:

int

get_lcp() list[int][source]

lcp配列を返します。 \(O(n\log{n})\) です。

set(k: int, c: str) None[source]

k 番目の文字を c に更新します。 \(O(\log{n})\) です。また、今後の get()\(O(\log{n})\) になります。

Parameters:
  • k (int) – インデックスです。

  • c (str) – 更新する文字です。

class HashStringBase(n: int = 0, base: int = -1, seed: int | None = None)[source]

Bases: object

HashStringのベースクラスです。

extend(cap: int) None[source]
get_cap() int[source]
static get_mod(x: int) int[source]
static get_mul(a: int, b: int) int[source]
unite(h1: int, h2: int, k: int) int[source]