1from typing import Callable
2from itertools import chain
3from math import sqrt, ceil
4
5
[docs]
6class Mo:
7 """長さ `n` の列、クエリ数 `q` に対する `Mo's algorithm` です。
8 :math:`O(\\frac{n}{\\sqrt{q}})` です。
9
10 Args:
11 n (int): 列の長さです。
12 q (int): クエリの数です。
13
14 制約:
15 :math:`0 \\leq n, 0 \\leq q`
16 """
17
18 def __init__(self, n: int, q: int) -> None:
19 assert 0 <= n and 0 <= q, f"ValueError: {n=} {q=}"
20 self.n = n
21 self.q = q
22 self.bucket_size = ceil(sqrt(3) * n / sqrt(2 * q)) if q > 0 else n
23 if self.bucket_size == 0:
24 self.bucket_size = 1
25 self.bit = max(n, q).bit_length()
26 self.msk = (1 << self.bit) - 1
27 self.bucket = [[] for _ in range(n // self.bucket_size + 1)]
28 self.cnt = 0
29
[docs]
30 def add_query(self, l: int, r: int) -> None:
31 """区間 ``[l, r)`` に対するクエリを追加します。
32 :math:`O(1)` です。
33
34 制約:
35 :math:`0 \\leq l \\leq r \\leq n`
36 """
37 assert (
38 0 <= l <= r <= self.n
39 ), f"IndexError: {self.__class__.__name__}.add_query({l}, {r}), self.n={self.n}"
40 self.bucket[l // self.bucket_size].append(
41 (((r << self.bit) | l) << self.bit) | self.cnt
42 )
43 self.cnt += 1
44
[docs]
45 def run(
46 self,
47 add: Callable[[int], None],
48 delete: Callable[[int], None],
49 out: Callable[[int], None],
50 ) -> None:
51 """クエリを実行します。
52 :math:`O(q\\sqrt{n})` です。
53
54 Args:
55 add (Callable[[int], None]): 引数のインデックスに対応する要素を追加します。
56 delete (Callable[[int], None]): 引数のインデックスに対応する要素を削除します。
57 out (Callable[[int], None]): クエリ番号に対する答えを処理します。
58
59 制約:
60 ``q`` 回のクエリを ``add_query`` メソッドで追加する必要があります。
61 """
62 assert (
63 self.cnt == self.q
64 ), f"Not Enough Queries, now:{self.cnt}, expected:{self.q}"
65 bucket, bit, msk = self.bucket, self.bit, self.msk
66 for i, b in enumerate(bucket):
67 b.sort(reverse=i & 1)
68 nl, nr = 0, 0
69 for rli in chain(*bucket):
70 r, l = rli >> bit >> bit, rli >> bit & msk
71 while nl > l:
72 nl -= 1
73 add(nl)
74 while nr < r:
75 add(nr)
76 nr += 1
77 while nl < l:
78 delete(nl)
79 nl += 1
80 while nr > r:
81 nr -= 1
82 delete(nr)
83 out(rli & msk)
84
[docs]
85 def runrun(
86 self,
87 add_left: Callable[[int], None],
88 add_right: Callable[[int], None],
89 delete_left: Callable[[int], None],
90 delete_right: Callable[[int], None],
91 out: Callable[[int], None],
92 ) -> None:
93 """クエリを実行します。
94
95 :math:`O(q\\sqrt{n})` です。
96
97 Args:
98 add_left (Callable[[int], None]): 引数のインデックスに対応する要素を左から追加します。
99 add_right (Callable[[int], None]): 引数のインデックスに対応する要素を右から追加します。
100 delete_left (Callable[[int], None]): 引数のインデックスに対応する要素を左から削除します。
101 delete_right (Callable[[int], None]): 引数のインデックスに対応する要素を右から削除します。
102 out (Callable[[int], None]): クエリ番号に対する答えを処理します。
103
104 制約:
105 ``q`` 回のクエリを ``add_query`` メソッドで追加する必要があります。
106 """
107 assert (
108 self.cnt == self.q
109 ), f"Not Enough Queries, now:{self.cnt}, expected:{self.q}"
110 bucket, bit, msk = self.bucket, self.bit, self.msk
111 for i, b in enumerate(bucket):
112 b.sort(reverse=i & 1)
113 nl, nr = 0, 0
114 for rli in chain(*bucket):
115 r, l = rli >> bit >> bit, rli >> bit & msk
116 while nl > l:
117 nl -= 1
118 add_left(nl)
119 while nr < r:
120 add_right(nr)
121 nr += 1
122 while nl < l:
123 delete_left(nl)
124 nl += 1
125 while nr > r:
126 nr -= 1
127 delete_right(nr)
128 out(rli & msk)