1import multiprocessing.managers
2from typing import Iterable, Callable
3import argparse
4from logging import getLogger, basicConfig
5import subprocess
6import multiprocessing
7import time
8import math
9import os
10import shutil
11import csv
12from functools import partial
13import datetime
14from ahc_settings import AHCSettings
15
logger = getLogger(__name__)
basicConfig(
    format="%(asctime)s [%(levelname)s] : %(message)s",
    datefmt="%H:%M:%S",
    # logging only knows upper-case level names, so normalize LOG_LEVEL
    # (e.g. LOG_LEVEL=debug would otherwise raise "Unknown level").
    level=os.getenv("LOG_LEVEL", "INFO").upper(),
)
22
23
def to_red(arg):
    """Return ``arg`` wrapped in ANSI escape codes that render it red on a terminal."""
    return "\x1b[31m{}\x1b[0m".format(arg)
26
27
def to_green(arg):
    """Return ``arg`` wrapped in ANSI escape codes that render it green on a terminal."""
    return "".join(("\x1b[32m", format(arg), "\x1b[0m"))
30
31
def to_bold(arg):
    """Return ``arg`` wrapped in ANSI escape codes that render it bold on a terminal."""
    return "\x1b[1m" + format(arg) + "\x1b[0m"
34
35
# Widths (keta = "digits" in Japanese) that the score and time columns of the
# verbose progress rows are right-aligned to.
KETA_SCORE, KETA_TIME = 10, 11
38
39
class ParallelTester:
    """Runs test cases in parallel.

    - Example (127 parallel jobs): configure ``ahc_settings.py`` and run

      ``$ python3 ./parallel_tester.py -c -v -njobs 127``

    The program under test must print its score on the last line of stderr,
    formatted as ``<name> = <score>``.
    """

    def __init__(
        self,
        compile_command: str,
        execute_command: str,
        input_file_names: list[str],
        cpu_count: int,
        verbose: bool,
        get_score: Callable[[list[float]], float],
        timeout: float,
    ) -> None:
        """
        Args:
            compile_command (str): Compile command, split on whitespace (no shell quoting).
            execute_command (str): Execute command, split on whitespace.
                Extra run-time arguments can also be added with ``append_execute_command()``.
            input_file_names (list[str]): Paths of the input files.
            cpu_count (int): Number of worker processes.
            verbose (bool): Log one progress row per finished case.
            get_score (Callable[[list[float]], float]): Aggregates the list of
                per-case scores (e.g. their mean) and returns the result.
            timeout (float): Per-case time limit in milliseconds. ``None`` or a
                negative value disables the limit.
        """
        self.compile_command = compile_command.split()
        self.execute_command = execute_command.split()
        self.input_file_names = input_file_names
        self.cpu_count = cpu_count
        self.verbose = verbose
        self.get_score = get_score
        # subprocess expects seconds; None means "no limit".
        self.timeout = (
            timeout / 1000 if (timeout is not None) and (timeout >= 0) else None
        )
        # Shared progress counter; created through a Manager by run_record().
        self.counter: multiprocessing.managers.ValueProxy
        # Directory of the current recorded run; set by run_record().
        self.output_dir: str

    def show_score(self, scores: list[float]) -> float:
        """Aggregate ``scores`` with ``get_score``, log the result and return it."""
        score = self.get_score(scores)
        logger.info(f"Ave.{score}")
        return score

    def append_execute_command(self, args: Iterable[str]) -> None:
        """Append command-line arguments (converted with ``str``) to the execute command."""
        self.execute_command.extend(str(arg) for arg in args)

    def compile(self) -> None:
        """Compile with ``compile_command``.

        Raises:
            subprocess.CalledProcessError: if the compiler exits with a non-zero
                status. Its captured stderr is logged before re-raising.
        """
        logger.info("Compiling ...")
        try:
            subprocess.run(
                self.compile_command,
                stderr=subprocess.PIPE,
                stdout=subprocess.PIPE,
                text=True,
                check=True,
            )
        except subprocess.CalledProcessError as e:
            # Output is captured, so the diagnostics would be lost without this.
            logger.error(to_red("Compilation failed."))
            if e.stderr:
                logger.error(e.stderr)
            raise

    @staticmethod
    def _parse_score(stderr: str) -> float:
        """Parse the score from the last line of ``stderr`` (``<name> = <score>``).

        Raises:
            ValueError: if the last line is not a single ``" = "``-separated pair
                or the value is not a float.
        """
        score_line = stderr.rstrip().split("\n")[-1]
        _, score = score_line.split(" = ")
        return float(score)

    @staticmethod
    def _to_text(data) -> str:
        """Return captured process output as ``str`` (``""`` for ``None``).

        ``TimeoutExpired`` may carry raw ``bytes`` (possibly cut in the middle of
        a multi-byte character) even with ``text=True``, so decode leniently.
        """
        if data is None:
            return ""
        if isinstance(data, bytes):
            return data.decode("utf-8", errors="replace")
        return data

    @staticmethod
    def _output_name(input_file: str) -> str:
        """Return the relative path used for a case's files under ``err/`` and ``out/``.

        A leading ``./`` and the first directory component are dropped
        (``./in/0000.txt`` -> ``0000.txt``); a bare file name is kept unchanged.
        """
        name = input_file[len("./"):] if input_file.startswith("./") else input_file
        return name.split("/", 1)[-1]

    def _save_outputs(self, filename: str, stdout: str, stderr: str) -> None:
        """Write a case's stderr to ``<output_dir>/err/`` and stdout to ``<output_dir>/out/``."""
        for sub_dir, text in (("err", stderr), ("out", stdout)):
            path = f"{self.output_dir}/{sub_dir}/{filename}"
            # filename may keep sub-directories (e.g. "in/0000.txt"); create them.
            os.makedirs(os.path.dirname(path), exist_ok=True)
            with open(path, "w", encoding="utf-8") as out_f:
                out_f.write(text)

    def _report_progress(
        self, lock, input_file: str, score_col: str = None, time_col: str = None
    ) -> None:
        """Advance the shared progress counter; in verbose mode also log a row.

        The row is logged while holding ``lock`` so rows appear in counter order.
        When ``score_col`` is ``None`` the case is only counted (used by the
        error paths, which log their own message).
        """
        total = len(self.input_file_names)
        with lock:
            self.counter.value += 1
            if self.verbose and score_col is not None:
                cnt = str(self.counter.value).rjust(len(str(total)))
                logger.info(
                    f"| {cnt} / {total} | {input_file} | {score_col} | {time_col} |"
                )

    def _process_file_light(self, input_file: str) -> float:
        """Run the program on ``input_file`` without recording its output.

        Returns:
            float: the score, or ``math.nan`` on TLE, a non-zero exit status, or
            any other error.
        """
        try:
            with open(input_file, "r", encoding="utf-8") as f:
                input_text = f.read()
            result = subprocess.run(
                self.execute_command,
                input=input_text,
                timeout=self.timeout,
                stderr=subprocess.PIPE,
                stdout=subprocess.PIPE,
                text=True,
                check=True,
            )
            return self._parse_score(result.stderr)
        except subprocess.TimeoutExpired:
            logger.error(to_red(f"TLE occured in {input_file}"))
            return math.nan
        except subprocess.CalledProcessError:
            logger.error(to_red(f"Error occured in {input_file}"))
            return math.nan
        except Exception as e:
            logger.exception(e)
            logger.error(to_red(f"!!! Error occured in {input_file}"))
            return math.nan

    def run(self) -> list[float]:
        """Run every input in parallel; return the scores in input order (NaN for failures)."""
        with multiprocessing.Pool(processes=self.cpu_count) as pool:
            result = pool.map(
                self._process_file_light, self.input_file_names, chunksize=1
            )
            pool.close()
            pool.join()
        return result

    def _process_file(self, args) -> tuple[str, float, str, str]:
        """Run one case, save its stdout/stderr and report progress.

        Args:
            args: ``(input_file, lock)`` packed into one tuple because
                ``Pool.map`` passes a single argument.

        Returns:
            tuple[str, float, str, str]: ``(input_file, score, state, time)``.
            ``state`` is ``"AC"``, ``"TLE"``, ``"ERROR"`` or ``"INNER_ERROR"``;
            ``score`` is NaN unless the state is ``"AC"``; ``time`` is seconds
            formatted with 3 decimals, or ``"-1"`` when not measured.
        """
        input_file, lock = args
        filename = self._output_name(input_file)

        try:
            with open(input_file, "r", encoding="utf-8") as f:
                input_text = f.read()
            start_time = time.perf_counter()
            result = subprocess.run(
                self.execute_command,
                input=input_text,
                timeout=self.timeout,
                stderr=subprocess.PIPE,
                stdout=subprocess.PIPE,
                text=True,
                check=True,
            )
            elapsed = time.perf_counter() - start_time
            score = self._parse_score(result.stderr)
            # Save before reporting so a failed write is counted only once (below).
            self._save_outputs(filename, result.stdout, result.stderr)
            self._report_progress(
                lock,
                input_file,
                str(score).rjust(KETA_SCORE),
                f"{elapsed:.3f} sec".rjust(KETA_TIME),
            )
            return input_file, score, "AC", f"{elapsed:.3f}"
        except subprocess.TimeoutExpired as e:
            limit = f"{self.timeout:.3f}"
            self._save_outputs(
                filename, self._to_text(e.stdout), self._to_text(e.stderr)
            )
            self._report_progress(
                lock,
                input_file,
                "-" * KETA_SCORE,
                to_red(f">{limit} sec".rjust(KETA_TIME)),
            )
            return input_file, math.nan, "TLE", limit
        except subprocess.CalledProcessError as e:
            self._report_progress(lock, input_file)
            logger.error(to_red(f"Error occured in {input_file}"))
            # Keep the crashing program's output so the failure can be inspected.
            self._save_outputs(
                filename, self._to_text(e.stdout), self._to_text(e.stderr)
            )
            return input_file, math.nan, "ERROR", "-1"
        except Exception as e:
            self._report_progress(lock, input_file)
            logger.exception(e)
            logger.error(to_red(f"!!! Error occured in {input_file}"))
            return input_file, math.nan, "INNER_ERROR", "-1"

    def run_record(self) -> list[tuple[str, float, str, str]]:
        """Run every input in parallel and record the results.

        Creates ``./alltests/<YYYY-mm-dd_HH-MM-SS>/`` holding ``err/`` (stderr per
        case), ``out/`` (stdout per case) and ``result.csv``, then copies
        ``out/`` into ``./out/``.

        Returns:
            list[tuple[str, float, str, str]]: ``(input_file, score, state, time)``
            rows, sorted.
        """
        dt_now = datetime.datetime.now()
        self.output_dir = "./alltests/" + dt_now.strftime("%Y-%m-%d_%H-%M-%S")
        os.makedirs(f"{self.output_dir}/err/", exist_ok=True)
        os.makedirs(f"{self.output_dir}/out/", exist_ok=True)

        with multiprocessing.Manager() as manager:
            lock = manager.Lock()
            # Assigned before pool.map so the proxy is pickled along with self.
            self.counter = manager.Value("i", 0)
            with multiprocessing.Pool(processes=self.cpu_count) as pool:
                result = pool.map(
                    self._process_file,
                    [(path, lock) for path in self.input_file_names],
                    chunksize=1,
                )
                pool.close()
                pool.join()

        # csv
        result.sort()
        with open(
            f"{self.output_dir}/result.csv", "w", encoding="utf-8", newline=""
        ) as f:
            writer = csv.writer(f)
            writer.writerow(["filename", "score", "state", "time"])
            writer.writerows(result)

        # Also copy the outputs to ./out/
        os.makedirs("./out/", exist_ok=True)
        src_dir = f"{self.output_dir}/out/"
        for item in os.listdir(src_dir):
            src_path = os.path.join(src_dir, item)
            dest_path = os.path.join("./out/", item)
            if os.path.isfile(src_path):
                shutil.copy2(src_path, dest_path)
            elif os.path.isdir(src_path):
                # ./out/ persists between runs; overwrite instead of failing.
                shutil.copytree(src_path, dest_path, dirs_exist_ok=True)

        return result

    @staticmethod
    def get_args() -> argparse.Namespace:
        """Parse the command-line arguments (``-c``, ``-v``, ``-njobs``)."""
        parser = argparse.ArgumentParser()
        parser.add_argument(
            "-c",
            "--compile",
            required=False,
            action="store_true",
            default=False,
            help="if compile the file. default is `False`.",
        )
        parser.add_argument(
            "-v",
            "--verbose",
            required=False,
            action="store_true",
            default=False,
            help="show logs. default is `False`.",
        )
        parser.add_argument(
            "-njobs",
            "--number_of_jobs",
            required=False,
            type=int,
            action="store",
            default=1,
            help="set the number of cpu_count. default is `1`.",
        )
        return parser.parse_args()
330
331
def build_tester(
    settings: AHCSettings, njobs: int = 1, verbose: bool = False
) -> ParallelTester:
    """Return a ``ParallelTester`` configured from ``settings``.

    Args:
        settings (AHCSettings): Provides ``compile_command``, ``execute_command``,
            ``input_file_names``, ``get_score`` and ``timeout``.
        njobs (int, optional): Requested number of worker processes. Capped at
            ``cpu_count() - 1`` (one core left free) but never below 1.
        verbose (bool, optional): Log a progress row per case.

    Returns:
        ParallelTester: the configured tester.
    """
    # min() alone gives 0 on a single-core machine (or for njobs <= 0), and
    # multiprocessing.Pool rejects processes < 1.
    cpu_count = max(1, min(njobs, multiprocessing.cpu_count() - 1))
    return ParallelTester(
        compile_command=settings.compile_command,
        execute_command=settings.execute_command,
        input_file_names=settings.input_file_names,
        cpu_count=cpu_count,
        verbose=verbose,
        get_score=settings.get_score,
        timeout=settings.timeout,
    )
354
355
def _log_failure_table(nan_case: list[tuple[str, str]]) -> None:
    """Log the failed cases as a table grouped by state, then per-state counts.

    Args:
        nan_case: ``(input_file, state)`` pairs of every case whose score is NaN.
    """
    # Inner table width: the longest file name (at least 13, the length of
    # " INNER_ERROR ") plus one space of padding on each side.
    delta = max(13, max(len(filename) for filename, _ in nan_case)) + 2
    rule = "-" * (delta + 2)
    counts = {}

    logger.error("=" * (delta + 2))
    logger.error(to_red(f"ErrorCount: {len(nan_case)}."))
    for state in ("TLE", "ERROR", "INNER_ERROR"):
        logger.error(rule)
        header = f" {state} "
        logger.error("|" + header + " " * (delta - len(header)) + "|")
        files = [f for f, s in nan_case if s == state]
        for f in files:
            # Pad so the right border lines up with the header row.
            logger.error("|" + to_red(f" {f.ljust(delta - 2)} ") + "|")
        counts[state] = len(files)
    logger.error(rule)
    logger.error("=" * (delta + 2))

    logger.error(to_red(f" TLE : {counts['TLE']} "))
    logger.error(to_red(f" Other : {counts['ERROR']} "))
    logger.error(to_red(f" Inner : {counts['INNER_ERROR']} "))


def main():
    """Build a tester from the command-line arguments and run it.

    Returns:
        float: the aggregated score returned by ``show_score``.
    """
    args = ParallelTester.get_args()
    # At least one worker: cpu_count() - 1 is 0 on a single-core machine and
    # -njobs may be <= 0; multiprocessing.Pool rejects processes < 1.
    njobs = max(1, min(args.number_of_jobs, multiprocessing.cpu_count() - 1))
    logger.info(f"{njobs=}")

    tester = build_tester(AHCSettings, njobs, args.verbose)

    if args.compile:
        tester.compile()

    logger.info("Start.")

    start = time.time()
    scores = tester.run_record()

    nan_case = [
        (filename, state) for filename, s, state, _ in scores if math.isnan(s)
    ]
    if nan_case:
        _log_failure_table(nan_case)

    score = tester.show_score([s for _, s, _, _ in scores])
    logger.info(to_green(f"Finished in {time.time() - start:.4f} sec."))
    return score
419
420
if __name__ == "__main__":
    # Run only when executed as a script, not when this module is imported.
    main()