Source code for titan_pylib.algorithm.mo

  1from typing import Callable
  2from itertools import chain
  3from math import sqrt, ceil
  4
  5
[docs] 6class Mo: 7 """長さ `n` の列、クエリ数 `q` に対する `Mo's algorithm` です。 8 :math:`O(\\frac{n}{\\sqrt{q}})` です。 9 10 Args: 11 n (int): 列の長さです。 12 q (int): クエリの数です。 13 14 制約: 15 :math:`0 \\leq n, 0 \\leq q` 16 """ 17 18 def __init__(self, n: int, q: int) -> None: 19 assert 0 <= n and 0 <= q, f"ValueError: {n=} {q=}" 20 self.n = n 21 self.q = q 22 self.bucket_size = ceil(sqrt(3) * n / sqrt(2 * q)) if q > 0 else n 23 if self.bucket_size == 0: 24 self.bucket_size = 1 25 self.bit = max(n, q).bit_length() 26 self.msk = (1 << self.bit) - 1 27 self.bucket = [[] for _ in range(n // self.bucket_size + 1)] 28 self.cnt = 0 29
[docs] 30 def add_query(self, l: int, r: int) -> None: 31 """区間 ``[l, r)`` に対するクエリを追加します。 32 :math:`O(1)` です。 33 34 制約: 35 :math:`0 \\leq l \\leq r \\leq n` 36 """ 37 assert ( 38 0 <= l <= r <= self.n 39 ), f"IndexError: {self.__class__.__name__}.add_query({l}, {r}), self.n={self.n}" 40 self.bucket[l // self.bucket_size].append( 41 (((r << self.bit) | l) << self.bit) | self.cnt 42 ) 43 self.cnt += 1
44
[docs] 45 def run( 46 self, 47 add: Callable[[int], None], 48 delete: Callable[[int], None], 49 out: Callable[[int], None], 50 ) -> None: 51 """クエリを実行します。 52 :math:`O(q\\sqrt{n})` です。 53 54 Args: 55 add (Callable[[int], None]): 引数のインデックスに対応する要素を追加します。 56 delete (Callable[[int], None]): 引数のインデックスに対応する要素を削除します。 57 out (Callable[[int], None]): クエリ番号に対する答えを処理します。 58 59 制約: 60 ``q`` 回のクエリを ``add_query`` メソッドで追加する必要があります。 61 """ 62 assert ( 63 self.cnt == self.q 64 ), f"Not Enough Queries, now:{self.cnt}, expected:{self.q}" 65 bucket, bit, msk = self.bucket, self.bit, self.msk 66 for i, b in enumerate(bucket): 67 b.sort(reverse=i & 1) 68 nl, nr = 0, 0 69 for rli in chain(*bucket): 70 r, l = rli >> bit >> bit, rli >> bit & msk 71 while nl > l: 72 nl -= 1 73 add(nl) 74 while nr < r: 75 add(nr) 76 nr += 1 77 while nl < l: 78 delete(nl) 79 nl += 1 80 while nr > r: 81 nr -= 1 82 delete(nr) 83 out(rli & msk)
84
[docs] 85 def runrun( 86 self, 87 add_left: Callable[[int], None], 88 add_right: Callable[[int], None], 89 delete_left: Callable[[int], None], 90 delete_right: Callable[[int], None], 91 out: Callable[[int], None], 92 ) -> None: 93 """クエリを実行します。 94 95 :math:`O(q\\sqrt{n})` です。 96 97 Args: 98 add_left (Callable[[int], None]): 引数のインデックスに対応する要素を左から追加します。 99 add_right (Callable[[int], None]): 引数のインデックスに対応する要素を右から追加します。 100 delete_left (Callable[[int], None]): 引数のインデックスに対応する要素を左から削除します。 101 delete_right (Callable[[int], None]): 引数のインデックスに対応する要素を右から削除します。 102 out (Callable[[int], None]): クエリ番号に対する答えを処理します。 103 104 制約: 105 ``q`` 回のクエリを ``add_query`` メソッドで追加する必要があります。 106 """ 107 assert ( 108 self.cnt == self.q 109 ), f"Not Enough Queries, now:{self.cnt}, expected:{self.q}" 110 bucket, bit, msk = self.bucket, self.bit, self.msk 111 for i, b in enumerate(bucket): 112 b.sort(reverse=i & 1) 113 nl, nr = 0, 0 114 for rli in chain(*bucket): 115 r, l = rli >> bit >> bit, rli >> bit & msk 116 while nl > l: 117 nl -= 1 118 add_left(nl) 119 while nr < r: 120 add_right(nr) 121 nr += 1 122 while nl < l: 123 delete_left(nl) 124 nl += 1 125 while nr > r: 126 nr -= 1 127 delete_right(nr) 128 out(rli & msk)