Source code for titan_pylib.ahc.parallel_tester

  1import multiprocessing.managers
  2from typing import Iterable, Callable
  3import argparse
  4from logging import getLogger, basicConfig
  5import subprocess
  6import multiprocessing
  7import time
  8import math
  9import os
 10import shutil
 11import csv
 12from functools import partial
 13import datetime
 14from ahc_settings import AHCSettings
 15
 16logger = getLogger(__name__)
 17basicConfig(
 18    format="%(asctime)s [%(levelname)s] : %(message)s",
 19    datefmt="%H:%M:%S",
 20    level=os.getenv("LOG_LEVEL", "INFO"),
 21)
 22
 23
def to_red(arg):
    """Return ``arg`` formatted as text wrapped in ANSI red / reset escape codes."""
    colored = f"\u001b[31m{arg}\u001b[0m"
    return colored
26 27
def to_green(arg):
    """Return ``arg`` formatted as text wrapped in ANSI green / reset escape codes."""
    colored = f"\u001b[32m{arg}\u001b[0m"
    return colored
30 31
def to_bold(arg):
    """Return ``arg`` formatted as text wrapped in ANSI bold / reset escape codes."""
    emphasized = f"\u001b[1m{arg}\u001b[0m"
    return emphasized
# Column widths for the verbose progress rows: the score text and the elapsed
# time text are right-aligned to these many characters.
KETA_SCORE = 10  # width of the score column
KETA_TIME = 11  # width of the time column
class ParallelTester:
    """Runs a solver over many test-case input files in parallel.

    - Example (127 parallel jobs): configure ``ahc_settings.py`` and run

    ``$ python3 ./parallel_tester.py -c -v -njobs 127``
    """

    def __init__(
        self,
        compile_command: str,
        execute_command: str,
        input_file_names: list[str],
        cpu_count: int,
        verbose: bool,
        get_score: Callable[[list[float]], float],
        timeout: float,
    ) -> None:
        """
        Args:
            compile_command (str): Compile command. Split on whitespace into argv.
            execute_command (str): Execute command. Split on whitespace into argv.
              Extra arguments can also be added with ``append_execute_command()``.
            input_file_names (list[str]): List of input file names.
            cpu_count (int): Number of worker processes.
            verbose (bool): Log one progress row per finished case.
            get_score (Callable[[list[float]], float]): Reduces the list of
              scores to a single value (for example the mean).
            timeout (float): Per-case limit in milliseconds. ``None`` or a
              negative value disables the limit.
        """
        self.compile_command = compile_command.split()
        self.execute_command = execute_command.split()
        self.input_file_names = input_file_names
        self.cpu_count = cpu_count
        self.verbose = verbose
        self.get_score = get_score
        # ``subprocess.run`` takes seconds; the public parameter is milliseconds.
        self.timeout = (
            timeout / 1000 if (timeout is not None) and (timeout >= 0) else None
        )
        # Shared progress counter; assigned in ``run_record``.
        self.counter: multiprocessing.managers.ValueProxy

    def show_score(self, scores: list[float]) -> float:
        """Reduce ``scores`` with ``get_score``, log the result and return it."""
        score = self.get_score(scores)
        logger.info(f"Ave.{score}")
        return score

    def append_execute_command(self, args: Iterable[str]) -> None:
        """Append command-line arguments (converted with ``str``) to the execute command."""
        self.execute_command.extend(str(arg) for arg in args)

    def compile(self) -> None:
        """Compile using ``compile_command``.

        Raises:
            subprocess.CalledProcessError: The compile command exited with a
                non-zero status.
        """
        logger.info("Compiling ...")
        try:
            subprocess.run(
                self.compile_command,
                stderr=subprocess.PIPE,
                stdout=subprocess.PIPE,
                text=True,
                check=True,
            )
        except subprocess.CalledProcessError as e:
            # The output is captured, so without this the compiler's
            # diagnostics would be lost when the exception propagates.
            if e.stderr:
                logger.error(e.stderr.rstrip())
            raise

    @staticmethod
    def _parse_score(stderr_text: str) -> float:
        """Parse the score from the last stderr line, shaped ``<label> = <number>``."""
        score_line = stderr_text.rstrip().split("\n")[-1]
        _, score = score_line.split(" = ")
        return float(score)

    @staticmethod
    def _as_text(data) -> str:
        """Convert captured output (``None``, ``bytes`` or ``str``) to ``str``."""
        if data is None:
            return ""
        if isinstance(data, bytes):
            # Output captured at a timeout can stop in the middle of a
            # multi-byte character, so never fail on undecodable bytes.
            return data.decode("utf-8", errors="replace")
        return data

    def _run_solver(self, input_text: str) -> subprocess.CompletedProcess:
        """Run the execute command with ``input_text`` on stdin, capturing both streams."""
        return subprocess.run(
            self.execute_command,
            input=input_text,
            timeout=self.timeout,
            stderr=subprocess.PIPE,
            stdout=subprocess.PIPE,
            text=True,
            check=True,
        )

    def _process_file_light(self, input_file: str) -> float:
        """Run the solver on ``input_file`` without recording any output.

        Returns:
            float: The score, or ``math.nan`` on timeout or any failure.
        """
        with open(input_file, "r", encoding="utf-8") as f:
            input_text = f.read()

        try:
            result = self._run_solver(input_text)
            return self._parse_score(result.stderr)
        except subprocess.TimeoutExpired:
            logger.error(to_red(f"TLE occured in {input_file}"))
            return math.nan
        except subprocess.CalledProcessError:
            logger.error(to_red(f"Error occured in {input_file}"))
            # Always a plain float so the list returned by ``run`` stays
            # homogeneous.
            return math.nan
        except Exception as e:
            logger.exception(e)
            logger.error(to_red(f"!!! Error occured in {input_file}"))
            return math.nan

    def run(self) -> list[float]:
        """Run every input file in parallel and return the scores in input order."""
        # The context manager releases the workers even if ``map`` raises.
        with multiprocessing.Pool(processes=self.cpu_count) as pool:
            result = pool.map(
                self._process_file_light, self.input_file_names, chunksize=1
            )
        return result

    def _next_count(self, lock) -> int:
        """Increment the shared progress counter under ``lock`` and return the new value."""
        with lock:
            self.counter.value += 1
            return self.counter.value

    def _log_row(self, cnt: int, input_file: str, score_text: str, time_text: str) -> None:
        """Log one progress row; ``cnt`` is right-aligned to the width of the total."""
        total = str(len(self.input_file_names))
        logger.info(
            f"| {str(cnt).rjust(len(total))} / {total} | {input_file} | {score_text} | {time_text} |"
        )

    def _write_outputs(self, filename: str, stdout_text: str, stderr_text: str) -> None:
        """Write the captured streams to ``<output_dir>/err`` and ``<output_dir>/out``."""
        # stderr
        with open(
            f"{self.output_dir}/err/{filename}", "w", encoding="utf-8"
        ) as out_f:
            out_f.write(stderr_text)

        # stdout
        with open(
            f"{self.output_dir}/out/{filename}", "w", encoding="utf-8"
        ) as out_f:
            out_f.write(stdout_text)

    def _process_file(self, args) -> tuple[str, float, str, str]:
        """Run the solver on one input file and record its output.

        Args:
            args: Pair ``(input_file, lock)``; ``lock`` guards the shared
                progress counter.

        Returns:
            tuple[str, float, str, str]: file name, score (``math.nan`` unless
            the state is ``"AC"``), state (``"AC"``, ``"TLE"``, ``"ERROR"`` or
            ``"INNER_ERROR"``) and the time as text.
        """
        input_file, lock = args
        with open(input_file, "r", encoding="utf-8") as f:
            input_text = f.read()

        # Name of the recorded files: the input path without a leading "./"
        # and without its first directory component. A path that has no
        # directory part is used unchanged.
        filename = input_file.removeprefix("./")
        _, sep, rest = filename.partition("/")
        if sep:
            filename = rest

        try:
            start_time = time.time()
            result = self._run_solver(input_text)
            end_time = time.time()
            elapsed = end_time - start_time
            score = self._parse_score(result.stderr)
            if self.verbose:
                self._log_row(
                    self._next_count(lock),
                    input_file,
                    str(score).rjust(KETA_SCORE),
                    f"{elapsed:.3f} sec".rjust(KETA_TIME),
                )
            self._write_outputs(filename, result.stdout, result.stderr)
            return input_file, score, "AC", f"{elapsed:.3f}"
        except subprocess.TimeoutExpired as e:
            if self.verbose:
                self._log_row(
                    self._next_count(lock),
                    input_file,
                    "-" * KETA_SCORE,
                    to_red(f">{self.timeout:.3f} sec".rjust(KETA_TIME)),
                )
            # Keep whatever partial output was captured before the timeout.
            self._write_outputs(
                filename, self._as_text(e.stdout), self._as_text(e.stderr)
            )
            return input_file, math.nan, "TLE", f"{self.timeout:.3f}"
        except subprocess.CalledProcessError:
            self._next_count(lock)
            logger.error(to_red(f"Error occured in {input_file}"))
            return input_file, math.nan, "ERROR", "-1"
        except Exception as e:
            self._next_count(lock)
            logger.exception(e)
            logger.error(to_red(f"!!! Error occured in {input_file}"))
            return input_file, math.nan, "INNER_ERROR", "-1"

    def run_record(self) -> list[tuple[str, float, str, str]]:
        """Run every input file in parallel and record the results.

        Creates ``./alltests/<timestamp>/`` containing ``err/``, ``out/`` and
        ``result.csv``, then copies the contents of ``out/`` into ``./out/``.

        Returns:
            list[tuple[str, float, str, str]]: ``(file name, score, state,
            time)`` for each input, sorted.
        """
        dt_now = datetime.datetime.now()

        self.output_dir = "./alltests/" + dt_now.strftime("%Y-%m-%d_%H-%M-%S")
        os.makedirs(f"{self.output_dir}/err/", exist_ok=True)
        os.makedirs(f"{self.output_dir}/out/", exist_ok=True)

        with multiprocessing.Manager() as manager:
            lock = manager.Lock()
            self.counter = manager.Value("i", 0)
            # The context manager releases the workers even if ``map`` raises.
            with multiprocessing.Pool(processes=self.cpu_count) as pool:
                result = pool.map(
                    self._process_file,
                    [(file, lock) for file in self.input_file_names],
                    chunksize=1,
                )

        # csv
        result.sort()
        with open(
            f"{self.output_dir}/result.csv", "w", encoding="utf-8", newline=""
        ) as f:
            writer = csv.writer(f)
            writer.writerow(["filename", "score", "state", "time"])
            for filename, score, state, t in result:
                writer.writerow([filename, score, state, t])

        # Also copy the outputs into ``./out/``.
        os.makedirs("./out/", exist_ok=True)
        for item in os.listdir(f"{self.output_dir}/out/"):
            src_path = os.path.join(f"{self.output_dir}/out/", item)
            dest_path = os.path.join("./out/", item)
            if os.path.isfile(src_path):
                shutil.copy2(src_path, dest_path)
            elif os.path.isdir(src_path):
                # ``./out/`` persists between runs, so the destination
                # directory may already exist.
                shutil.copytree(src_path, dest_path, dirs_exist_ok=True)

        return result

    @staticmethod
    def get_args() -> argparse.Namespace:
        """Parse the command-line arguments (``-c``, ``-v``, ``-njobs``)."""
        parser = argparse.ArgumentParser()
        parser.add_argument(
            "-c",
            "--compile",
            required=False,
            action="store_true",
            default=False,
            help="if compile the file. default is `False`.",
        )
        parser.add_argument(
            "-v",
            "--verbose",
            required=False,
            action="store_true",
            default=False,
            help="show logs. default is `False`.",
        )
        parser.add_argument(
            "-njobs",
            "--number_of_jobs",
            required=False,
            type=int,
            action="store",
            default=1,
            help="set the number of cpu_count. default is `1`.",
        )
        return parser.parse_args()
330 331
def build_tester(
    settings: AHCSettings, njobs: int = 1, verbose: bool = False
) -> ParallelTester:
    """Return a ``ParallelTester`` configured from ``settings``.

    Args:
        settings (AHCSettings): Supplies ``compile_command``,
            ``execute_command``, ``input_file_names``, ``get_score`` and
            ``timeout``.
        njobs (int, optional): Requested number of worker processes. Capped at
            ``multiprocessing.cpu_count() - 1`` and never lower than 1.
        verbose (bool, optional): Show logs.

    Returns:
        ParallelTester: The tester.
    """
    # Leave one core free, but always keep at least one worker:
    # ``cpu_count() - 1`` is 0 on a single-core machine and
    # ``multiprocessing.Pool`` rejects ``processes=0``.
    cpu_count = max(1, min(njobs, multiprocessing.cpu_count() - 1))
    tester = ParallelTester(
        compile_command=settings.compile_command,
        execute_command=settings.execute_command,
        input_file_names=settings.input_file_names,
        cpu_count=cpu_count,
        verbose=verbose,
        get_score=settings.get_score,
        timeout=settings.timeout,
    )
    return tester
354 355
def main():
    """Build a ``tester`` from the command-line arguments, run it and report.

    Runs ``run_record`` over every input, logs a table of the cases whose
    score is NaN grouped by state, then logs and returns the aggregate score.

    Returns:
        The value returned by ``tester.show_score``.
    """
    args = ParallelTester.get_args()
    # Leave one core free, but never request fewer than one worker:
    # ``multiprocessing.Pool`` rejects ``processes=0``.
    njobs = max(1, min(args.number_of_jobs, multiprocessing.cpu_count() - 1))
    logger.info(f"{njobs=}")

    tester = build_tester(AHCSettings, njobs, args.verbose)

    if args.compile:
        tester.compile()

    logger.info("Start.")

    start = time.time()
    scores = tester.run_record()

    # Cases without a usable score, as (file name, state) pairs.
    nan_case = [
        (filename, state) for filename, s, state, _ in scores if math.isnan(s)
    ]
    if nan_case:
        # Inner width of the table: the longest file name (at least 13) plus
        # one space of padding on each side.
        delta = max(13, max(len(filename) for filename, _ in nan_case)) + 2
        rule = "-" * (delta + 2)
        border = "=" * (delta + 2)
        counts = {"TLE": 0, "ERROR": 0, "INNER_ERROR": 0}

        logger.error(border)
        logger.error(to_red(f"ErrorCount: {len(nan_case)}."))
        logger.error(rule)

        # One section per state: a header row, then one row per failing file.
        for target in counts:
            header = f" {target} "
            logger.error("|" + header + " " * (delta - len(header)) + "|")
            for f, state in nan_case:
                if state == target:
                    counts[target] += 1
                    # Pad to the table width so the closing "|" lines up
                    # with the header row.
                    logger.error("|" + to_red(f" {f} ".ljust(delta)) + "|")
            logger.error(rule)

        logger.error(border)

        logger.error(to_red(f" TLE : {counts['TLE']} "))
        logger.error(to_red(f" Other : {counts['ERROR']} "))
        logger.error(to_red(f" Inner : {counts['INNER_ERROR']} "))

    score = tester.show_score([s for _, s, _, _ in scores])
    logger.info(to_green(f"Finished in {time.time() - start:.4f} sec."))
    return score
# Run the tester only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()