parallel_tester
ソースコード
from titan_pylib.ahc.parallel_tester import ParallelTester
展開済みコード
1# from titan_pylib.ahc.parallel_tester import ParallelTester
2import multiprocessing.managers
3from typing import Iterable, Callable
4import argparse
5from logging import getLogger, basicConfig
6import subprocess
7import multiprocessing
8import time
9import math
10import os
11import shutil
12import csv
13from functools import partial
14import datetime
15from ahc_settings import AHCSettings
16
# Configure root logging once at import time. The level comes from the
# LOG_LEVEL environment variable and falls back to INFO when it is unset.
basicConfig(
    level=os.environ.get("LOG_LEVEL", "INFO"),
    datefmt="%H:%M:%S",
    format="%(asctime)s [%(levelname)s] : %(message)s",
)
# Module-level logger shared by every function and method in this file.
logger = getLogger(__name__)
23
24
def to_red(arg):
    """Return ``arg`` formatted as text wrapped in the ANSI red colour escape."""
    return "\u001b[31m{}\u001b[0m".format(arg)
27
28
def to_green(arg):
    """Return ``arg`` formatted as text wrapped in the ANSI green colour escape."""
    return "\u001b[32m{}\u001b[0m".format(arg)
31
32
def to_bold(arg):
    """Return ``arg`` formatted as text wrapped in the ANSI bold escape."""
    return "\u001b[1m{}\u001b[0m".format(arg)
35
36
# Column widths used to right-align the score and the elapsed-time text in the
# per-case progress line logged by ``ParallelTester._process_file``.
KETA_SCORE = 10
KETA_TIME = 11
39
40
class ParallelTester:
    """Runs a solver over many test cases in parallel and collects the scores.

    Usage example (127 parallel jobs): configure ``ahc_settings.py`` and run

    ``$ python3 ./parallel_tester.py -c -v -njobs 127``
    """

    def __init__(
        self,
        compile_command: str,
        execute_command: str,
        input_file_names: list[str],
        cpu_count: int,
        verbose: bool,
        get_score: Callable[[list[float]], float],
        timeout: float,
    ) -> None:
        """
        Args:
            compile_command (str): Compile command; split on whitespace.
            execute_command (str): Execute command; split on whitespace.
                Extra arguments can also be added later with
                ``append_execute_command()``.
            input_file_names (list[str]): Paths of the input files.
            cpu_count (int): Number of worker processes.
            verbose (bool): Log one progress line per test case.
            get_score (Callable[[list[float]], float]): Aggregates the list
                of scores (e.g. takes the average).
            timeout (float): Per-case time limit in milliseconds. ``None``
                or a negative value disables the limit.
        """
        self.compile_command = compile_command.split()
        self.execute_command = execute_command.split()
        self.input_file_names = input_file_names
        self.cpu_count = cpu_count
        self.verbose = verbose
        self.get_score = get_score
        # subprocess.run expects seconds; the public parameter is in ms.
        self.timeout = (
            timeout / 1000 if (timeout is not None) and (timeout >= 0) else None
        )
        # Shared progress counter; assigned in run_record().
        self.counter: multiprocessing.managers.ValueProxy

    def show_score(self, scores: list[float]) -> float:
        """Aggregate ``scores`` with ``get_score``, log the result and return it."""
        score = self.get_score(scores)
        logger.info(f"Ave.{score}")
        return score

    def append_execute_command(self, args: Iterable[str]) -> None:
        """Append command-line arguments (converted with ``str``) to the execute command."""
        self.execute_command.extend(str(arg) for arg in args)

    def compile(self) -> None:
        """Compile with ``compile_command``.

        Raises:
            subprocess.CalledProcessError: If the compiler exits non-zero.
        """
        logger.info("Compiling ...")
        subprocess.run(
            self.compile_command,
            stderr=subprocess.PIPE,
            stdout=subprocess.PIPE,
            text=True,
            check=True,
        )

    @staticmethod
    def _read_input(input_file: str) -> str:
        """Return the whole content of ``input_file`` decoded as UTF-8."""
        with open(input_file, "r", encoding="utf-8") as f:
            return f.read()

    def _execute(self, input_text: str) -> subprocess.CompletedProcess:
        """Run the execute command once, feeding ``input_text`` on stdin."""
        return subprocess.run(
            self.execute_command,
            input=input_text,
            timeout=self.timeout,
            stderr=subprocess.PIPE,
            stdout=subprocess.PIPE,
            text=True,
            check=True,
        )

    @staticmethod
    def _parse_score(stderr_text: str) -> float:
        """Parse the score from the last stderr line, expected as ``<label> = <score>``.

        Raises:
            ValueError: If the last line does not have that shape.
        """
        score_line = stderr_text.rstrip().split("\n")[-1]
        _, score = score_line.split(" = ")
        return float(score)

    @staticmethod
    def _output_name(input_file: str) -> str:
        """Return the name used for the per-case ``err/`` and ``out/`` files.

        A leading ``./`` and the first directory component are dropped
        (``./in/0000.txt`` -> ``0000.txt``). A path without any directory
        component is returned as is instead of raising ``IndexError``.
        """
        name = input_file
        if name.startswith("./"):
            name = name[len("./") :]
        _, sep, tail = name.partition("/")
        return tail if sep else name

    @staticmethod
    def _as_text(data) -> str:
        """Normalise captured output (``None``, ``bytes`` or ``str``) to ``str``.

        ``TimeoutExpired`` carries bytes on POSIX even with ``text=True``,
        but may carry ``str`` on Windows, so both are accepted.
        """
        if data is None:
            return ""
        if isinstance(data, bytes):
            return data.decode("utf-8")
        return data

    def _save_streams(self, filename: str, stderr_text: str, stdout_text: str) -> None:
        """Write the captured stderr/stdout under ``output_dir``'s ``err/`` and ``out/``."""
        with open(f"{self.output_dir}/err/{filename}", "w", encoding="utf-8") as out_f:
            out_f.write(stderr_text)
        with open(f"{self.output_dir}/out/{filename}", "w", encoding="utf-8") as out_f:
            out_f.write(stdout_text)

    def _next_count(self, lock) -> str:
        """Increment the shared counter under ``lock`` and return it right-aligned."""
        with lock:
            self.counter.value += 1
            cnt = self.counter.value
        return str(cnt).rjust(len(str(len(self.input_file_names))))

    def _process_file_light(self, input_file: str) -> float:
        """Run the solver on ``input_file`` and return only its score.

        Returns:
            float: The score, or ``math.nan`` on timeout, non-zero exit or
            any other failure (every failure path returns a plain float).
        """
        input_text = self._read_input(input_file)

        try:
            result = self._execute(input_text)
            return self._parse_score(result.stderr)
        except subprocess.TimeoutExpired:
            logger.error(to_red(f"TLE occured in {input_file}"))
        except subprocess.CalledProcessError:
            logger.error(to_red(f"Error occured in {input_file}"))
        except Exception as e:
            logger.exception(e)
            logger.error(to_red(f"!!! Error occured in {input_file}"))
        return math.nan

    def run(self) -> list[float]:
        """Run every input file in parallel and return the scores in input order."""
        # The context manager terminates the workers if map() raises.
        with multiprocessing.Pool(processes=self.cpu_count) as pool:
            result = pool.map(
                self._process_file_light, self.input_file_names, chunksize=1
            )
            pool.close()
            pool.join()
        return result

    def _process_file(self, args) -> tuple[str, float, str, str]:
        """Run the solver on one input file and record its output.

        Args:
            args: ``(input_file, lock)``; ``lock`` guards the shared counter.

        Returns:
            tuple[str, float, str, str]: file name, score (``math.nan`` on
            failure), state (``AC``/``TLE``/``ERROR``/``INNER_ERROR``) and
            elapsed seconds as text (``"-1"`` for the error states).
        """
        input_file, lock = args
        input_text = self._read_input(input_file)
        filename = self._output_name(input_file)
        total = len(self.input_file_names)

        try:
            start_time = time.time()
            result = self._execute(input_text)
            end_time = time.time()
            score = self._parse_score(result.stderr)
            elapsed = f"{(end_time-start_time):.3f}"
            if self.verbose:
                cnt = self._next_count(lock)
                s = str(score).rjust(KETA_SCORE)
                t = f"{elapsed} sec".rjust(KETA_TIME)
                logger.info(f"| {cnt} / {total} | {input_file} | {s} | {t} |")

            self._save_streams(filename, result.stderr, result.stdout)
            return input_file, score, "AC", elapsed
        except subprocess.TimeoutExpired as e:
            if self.verbose:
                cnt = self._next_count(lock)
                s = "-" * KETA_SCORE
                t = f">{self.timeout:.3f} sec".rjust(KETA_TIME)
                logger.info(
                    f"| {cnt} / {total} | {input_file} | {s} | {to_red(t)} |"
                )

            # Keep whatever partial output was captured before the kill.
            self._save_streams(
                filename, self._as_text(e.stderr), self._as_text(e.stdout)
            )
            return input_file, math.nan, "TLE", f"{self.timeout:.3f}"
        except subprocess.CalledProcessError:
            with lock:
                self.counter.value += 1
            logger.error(to_red(f"Error occured in {input_file}"))
            return input_file, math.nan, "ERROR", "-1"
        except Exception as e:
            with lock:
                self.counter.value += 1
            logger.exception(e)
            logger.error(to_red(f"!!! Error occured in {input_file}"))
            return input_file, math.nan, "INNER_ERROR", "-1"

    def run_record(self) -> list[tuple[str, float, str, str]]:
        """Run every input file in parallel and record the results on disk.

        Creates ``./alltests/<timestamp>/`` with ``err/``, ``out/`` and
        ``result.csv``, then copies the ``out/`` entries to ``./out/``.

        Returns:
            list[tuple[str, float, str, str]]: The sorted
            ``(filename, score, state, time)`` rows.
        """
        dt_now = datetime.datetime.now()

        self.output_dir = "./alltests/" + dt_now.strftime("%Y-%m-%d_%H-%M-%S")
        # makedirs also creates ./alltests/ and the timestamp directory.
        os.makedirs(f"{self.output_dir}/err/", exist_ok=True)
        os.makedirs(f"{self.output_dir}/out/", exist_ok=True)

        with multiprocessing.Manager() as manager:
            lock = manager.Lock()
            self.counter = manager.Value("i", 0)
            # The context manager terminates the workers if map() raises.
            with multiprocessing.Pool(processes=self.cpu_count) as pool:
                result = pool.map(
                    self._process_file,
                    [(file, lock) for file in self.input_file_names],
                    chunksize=1,
                )
                pool.close()
                pool.join()

        # csv
        result.sort()
        with open(
            f"{self.output_dir}/result.csv", "w", encoding="utf-8", newline=""
        ) as f:
            writer = csv.writer(f)
            writer.writerow(["filename", "score", "state", "time"])
            for filename, score, state, t in result:
                writer.writerow([filename, score, state, t])

        # Mirror the outputs into ./out/ as well.
        os.makedirs("./out/", exist_ok=True)
        for item in os.listdir(f"{self.output_dir}/out/"):
            src_path = os.path.join(f"{self.output_dir}/out/", item)
            dest_path = os.path.join("./out/", item)
            if os.path.isfile(src_path):
                shutil.copy2(src_path, dest_path)
            elif os.path.isdir(src_path):
                # ./out/ persists between runs, so the target may exist.
                shutil.copytree(src_path, dest_path, dirs_exist_ok=True)

        return result

    @staticmethod
    def get_args() -> argparse.Namespace:
        """Parse the command-line arguments (``-c``, ``-v``, ``-njobs``)."""
        parser = argparse.ArgumentParser()
        parser.add_argument(
            "-c",
            "--compile",
            required=False,
            action="store_true",
            default=False,
            help="if compile the file. default is `False`.",
        )
        parser.add_argument(
            "-v",
            "--verbose",
            required=False,
            action="store_true",
            default=False,
            help="show logs. default is `False`.",
        )
        parser.add_argument(
            "-njobs",
            "--number_of_jobs",
            required=False,
            type=int,
            action="store",
            default=1,
            help="set the number of cpu_count. default is `1`.",
        )
        return parser.parse_args()
331
332
def build_tester(
    settings: AHCSettings, njobs: int = 1, verbose: bool = False
) -> ParallelTester:
    """Build a ``ParallelTester`` from ``settings``.

    Args:
        settings (AHCSettings): Provides ``compile_command``,
            ``execute_command``, ``input_file_names``, ``get_score`` and
            ``timeout``.
        njobs (int, optional): Requested number of worker processes. It is
            capped at ``cpu_count() - 1`` and raised to at least 1.
        verbose (bool, optional): Log one progress line per test case.

    Returns:
        ParallelTester: The configured tester.
    """
    # Leave one core free, but never go below one worker:
    # multiprocessing.Pool raises ValueError for processes < 1, which
    # happened on single-core machines and for non-positive njobs.
    cpu_count = max(1, min(njobs, multiprocessing.cpu_count() - 1))
    return ParallelTester(
        compile_command=settings.compile_command,
        execute_command=settings.execute_command,
        input_file_names=settings.input_file_names,
        cpu_count=cpu_count,
        verbose=verbose,
        get_score=settings.get_score,
        timeout=settings.timeout,
    )
355
356
def main():
    """Build a tester from the command-line arguments, run it and report.

    Every case is run through ``run_record()``. Cases whose score is NaN
    are listed in a table grouped by state, then the aggregated score is
    logged via ``show_score()``.

    Returns:
        The aggregated score returned by ``tester.show_score``.
    """
    args = ParallelTester.get_args()
    # Leave one core free but keep at least one worker; the uncapped
    # expression is 0 on a single-core machine and Pool rejects that.
    njobs = max(1, min(args.number_of_jobs, multiprocessing.cpu_count() - 1))
    logger.info(f"{njobs=}")

    tester = build_tester(AHCSettings, njobs, args.verbose)

    if args.compile:
        tester.compile()

    logger.info("Start.")

    start = time.time()
    scores = tester.run_record()

    nan_case = [
        (filename, state) for filename, s, state, _ in scores if math.isnan(s)
    ]
    if nan_case:
        # Inner width of the table: the longest file name, 13 at minimum.
        delta = max(13, max(len(filename) for filename, _ in nan_case)) + 2
        frame = "=" * (delta + 2)
        rule = "-" * (delta + 2)
        counts = {}

        logger.error(frame)
        logger.error(to_red(f"ErrorCount: {len(nan_case)}."))
        logger.error(rule)

        for label in ("TLE", "ERROR", "INNER_ERROR"):
            logger.error("|" + f" {label} ".ljust(delta) + "|")
            names = [f for f, state in nan_case if state == label]
            counts[label] = len(names)
            for f in names:
                # Pad each row to the header width so the right border aligns.
                logger.error("|" + to_red(f" {f} ".ljust(delta)) + "|")
            logger.error(rule)

        logger.error(frame)

        logger.error(to_red(f" TLE : {counts['TLE']} "))
        logger.error(to_red(f" Other : {counts['ERROR']} "))
        logger.error(to_red(f" Inner : {counts['INNER_ERROR']} "))

    score = tester.show_score([s for _, s, _, _ in scores])
    logger.info(to_green(f"Finished in {time.time() - start:.4f} sec."))
    return score
420
421
# Run the tester only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
仕様
- class ParallelTester(compile_command: str, execute_command: str, input_file_names: list[str], cpu_count: int, verbose: bool, get_score: Callable[[list[float]], float], timeout: float)[source]
Bases:
object
テストケース並列回し屋です。
実行例(127並列)ahc_settings.pyを設定して以下のコマンドを実行
$ python3 ./parallel_tester.py -c -v -njobs 127
- build_tester(settings: AHCSettings, njobs: int = 1, verbose: bool = False) ParallelTester [source]
ParallelTester を返します
- Parameters:
njobs (int, optional) – cpu_count です。
verbose (bool, optional) – ログを表示します。
- Returns:
テスターです。
- Return type: