parallel_tester

ソースコード

from titan_pylib.ahc.parallel_tester import ParallelTester

View on GitHub

展開済みコード

  1# from titan_pylib.ahc.parallel_tester import ParallelTester
  2import multiprocessing.managers
  3from typing import Iterable, Callable
  4import argparse
  5from logging import getLogger, basicConfig
  6import subprocess
  7import multiprocessing
  8import time
  9import math
 10import os
 11import shutil
 12import csv
 13from functools import partial
 14import datetime
 15from ahc_settings import AHCSettings
 16
 17logger = getLogger(__name__)
 18basicConfig(
 19    format="%(asctime)s [%(levelname)s] : %(message)s",
 20    datefmt="%H:%M:%S",
 21    level=os.getenv("LOG_LEVEL", "INFO"),
 22)
 23
 24
 25def to_red(arg):
 26    return f"\u001b[31m{arg}\u001b[0m"
 27
 28
 29def to_green(arg):
 30    return f"\u001b[32m{arg}\u001b[0m"
 31
 32
 33def to_bold(arg):
 34    return f"\u001b[1m{arg}\u001b[0m"
 35
 36
 37KETA_SCORE = 10
 38KETA_TIME = 11
 39
 40
 41class ParallelTester:
 42    """テストケース並列回し屋です。
 43
 44    - 実行例(127並列)ahc_settings.pyを設定して以下のコマンドを実行
 45
 46    ``$ python3 ./parallel_tester.py -c -v -njobs 127``
 47    """
 48
 49    def __init__(
 50        self,
 51        compile_command: str,
 52        execute_command: str,
 53        input_file_names: list[str],
 54        cpu_count: int,
 55        verbose: bool,
 56        get_score: Callable[[list[float]], float],
 57        timeout: float,
 58    ) -> None:
 59        """
 60        Args:
 61            compile_command (str): コンパイルコマンドです。
 62            execute_command (str): 実行コマンドです。
 63                                    実行時引数は ``append_execute_command()`` メソッドで指定することも可能です。
 64            input_file_names (list[str]): 入力ファイル名のリストです。
 65            cpu_count (int): CPU数です。
 66            verbose (bool): ログを表示します。
 67            get_score (Callable[[list[float]], float]): スコアのリストに対して平均などを取って返してください。
 68            timeout (float): [ms]
 69        """
 70        self.compile_command = compile_command.split()
 71        self.execute_command = execute_command.split()
 72        self.input_file_names = input_file_names
 73        self.cpu_count = cpu_count
 74        self.verbose = verbose
 75        self.get_score = get_score
 76        self.timeout = (
 77            timeout / 1000 if (timeout is not None) and (timeout >= 0) else None
 78        )
 79        self.counter: multiprocessing.managers.ValueProxy
 80
 81    def show_score(self, scores: list[float]) -> float:
 82        """スコアのリストを受け取り、 ``get_score`` 関数で計算します。
 83        ついでに表示もします。
 84        """
 85        score = self.get_score(scores)
 86        logger.info(f"Ave.{score}")
 87        return score
 88
 89    def append_execute_command(self, args: Iterable[str]) -> None:
 90        """コマンドライン引数を追加します。"""
 91        for arg in args:
 92            self.execute_command.append(str(arg))
 93
 94    def compile(self) -> None:
 95        """``compile_command`` よりコンパイルします。"""
 96        logger.info("Compiling ...")
 97        subprocess.run(
 98            self.compile_command,
 99            stderr=subprocess.PIPE,
100            stdout=subprocess.PIPE,
101            text=True,
102            check=True,
103        )
104
105    def _process_file_light(self, input_file: str) -> float:
106        """入力`input_file`を処理します。
107
108        Returns:
109            float: スコア
110        """
111        with open(input_file, "r", encoding="utf-8") as f:
112            input_text = "".join(f.read())
113
114        try:
115            result = subprocess.run(
116                self.execute_command,
117                input=input_text,
118                timeout=self.timeout,
119                stderr=subprocess.PIPE,
120                stdout=subprocess.PIPE,
121                text=True,
122                check=True,
123            )
124            score_line = result.stderr.rstrip().split("\n")[-1]
125            _, score = score_line.split(" = ")
126            score = float(score)
127            return score
128        except subprocess.TimeoutExpired as e:
129            logger.error(to_red(f"TLE occured in {input_file}"))
130            return math.nan
131        except subprocess.CalledProcessError as e:
132            logger.error(to_red(f"Error occured in {input_file}"))
133            return math.nan, "ERROR", "-1"
134        except Exception as e:
135            logger.exception(e)
136            logger.error(to_red(f"!!! Error occured in {input_file}"))
137            return math.nan, "INNER_ERROR", "-1"
138
139    def run(self) -> list[float]:
140        """実行します。"""
141        pool = multiprocessing.Pool(processes=self.cpu_count)
142        result = pool.map(
143            partial(self._process_file_light), self.input_file_names, chunksize=1
144        )
145        pool.close()
146        pool.join()
147        return result
148
149    def _process_file(self, args) -> tuple[str, float, str, str]:
150        """入力`input_file`を処理します。
151
152        Returns:
153            tuple[str, float]: ファイル名、スコア、state, time
154        """
155        input_file, lock = args
156        with open(input_file, "r", encoding="utf-8") as f:
157            input_text = "".join(f.read())
158
159        filename = input_file
160        if filename.startswith("./"):
161            filename = filename[len("./") :]
162        filename = filename.split("/", 1)[1]
163
164        try:
165            start_time = time.time()
166            result = subprocess.run(
167                self.execute_command,
168                input=input_text,
169                timeout=self.timeout,
170                stderr=subprocess.PIPE,
171                stdout=subprocess.PIPE,
172                text=True,
173                check=True,
174            )
175            end_time = time.time()
176            score_line = result.stderr.rstrip().split("\n")[-1]
177            _, score = score_line.split(" = ")
178            score = float(score)
179            if self.verbose:
180                with lock:
181                    self.counter.value += 1
182                    cnt = self.counter.value
183                cnt = " " * (
184                    len(str(len(self.input_file_names))) - len(str(cnt))
185                ) + str(cnt)
186                s = str(score)
187                s = " " * (KETA_SCORE - len(s)) + s
188                t = f"{(end_time-start_time):.3f} sec"
189                t = " " * (KETA_TIME - len(t)) + t
190                logger.info(
191                    f"| {cnt} / {len(self.input_file_names)} | {input_file} | {s} | {t} |"
192                )
193
194            # stderr
195            with open(
196                f"{self.output_dir}/err/{filename}", "w", encoding="utf-8"
197            ) as out_f:
198                out_f.write(result.stderr)
199
200            # stdout
201            with open(
202                f"{self.output_dir}/out/{filename}", "w", encoding="utf-8"
203            ) as out_f:
204                out_f.write(result.stdout)
205
206            return input_file, score, "AC", f"{(end_time-start_time):.3f}"
207        except subprocess.TimeoutExpired as e:
208            if self.verbose:
209                with lock:
210                    self.counter.value += 1
211                    cnt = self.counter.value
212                cnt = " " * (
213                    len(str(len(self.input_file_names))) - len(str(cnt))
214                ) + str(cnt)
215                s = "-" * KETA_SCORE
216                t = f">{self.timeout:.3f} sec"
217                t = " " * (KETA_TIME - len(t)) + t
218                logger.info(
219                    f"| {cnt} / {len(self.input_file_names)} | {input_file} | {s} | {to_red(t)} |"
220                )
221
222            # stderr
223            with open(
224                f"{self.output_dir}/err/{filename}", "w", encoding="utf-8"
225            ) as out_f:
226                if e.stderr is not None:
227                    out_f.write(e.stderr.decode("utf-8"))
228
229            # stdout
230            with open(
231                f"{self.output_dir}/out/{filename}", "w", encoding="utf-8"
232            ) as out_f:
233                if e.stdout is not None:
234                    out_f.write(e.stdout.decode("utf-8"))
235
236            return input_file, math.nan, "TLE", f"{self.timeout:.3f}"
237        except subprocess.CalledProcessError as e:
238            with lock:
239                self.counter.value += 1
240            # logger.exception(e)
241            logger.error(to_red(f"Error occured in {input_file}"))
242            return input_file, math.nan, "ERROR", "-1"
243        except Exception as e:
244            with lock:
245                self.counter.value += 1
246            logger.exception(e)
247            logger.error(to_red(f"!!! Error occured in {input_file}"))
248            self.counter
249            return input_file, math.nan, "INNER_ERROR", "-1"
250
251    def run_record(self) -> list[tuple[str, float]]:
252        """実行します。"""
253        dt_now = datetime.datetime.now()
254
255        self.output_dir = "./alltests/"
256        if not os.path.exists(self.output_dir):
257            os.makedirs(self.output_dir)
258        self.output_dir += dt_now.strftime("%Y-%m-%d_%H-%M-%S")
259        if not os.path.exists(self.output_dir):
260            os.makedirs(self.output_dir)
261        if not os.path.exists(f"{self.output_dir}/err/"):
262            os.makedirs(f"{self.output_dir}/err/")
263        if not os.path.exists(f"{self.output_dir}/out/"):
264            os.makedirs(f"{self.output_dir}/out/")
265
266        with multiprocessing.Manager() as manager:
267            lock = manager.Lock()
268            self.counter = manager.Value("i", 0)
269            pool = multiprocessing.Pool(processes=self.cpu_count)
270            result = pool.map(
271                self._process_file,
272                [(file, lock) for file in self.input_file_names],
273                chunksize=1,
274            )
275            pool.close()
276            pool.join()
277
278        # csv
279        result.sort()
280        with open(
281            f"{self.output_dir}/result.csv", "w", encoding="utf-8", newline=""
282        ) as f:
283            writer = csv.writer(f)
284            writer.writerow(["filename", "score", "state", "time"])
285            for filename, score, state, t in result:
286                writer.writerow([filename, score, state, t])
287
288        # 出力を`./out/`へも書き出す
289        if not os.path.exists("./out/"):
290            os.makedirs("./out/")
291        for item in os.listdir(f"{self.output_dir}/out/"):
292            src_path = os.path.join(f"{self.output_dir}/out/", item)
293            dest_path = os.path.join("./out/", item)
294            if os.path.isfile(src_path):
295                shutil.copy2(src_path, dest_path)
296            elif os.path.isdir(src_path):
297                shutil.copytree(src_path, dest_path)
298
299        return result
300
301    @staticmethod
302    def get_args() -> argparse.Namespace:
303        """実行時引数を解析します。"""
304        parser = argparse.ArgumentParser()
305        parser.add_argument(
306            "-c",
307            "--compile",
308            required=False,
309            action="store_true",
310            default=False,
311            help="if compile the file. default is `False`.",
312        )
313        parser.add_argument(
314            "-v",
315            "--verbose",
316            required=False,
317            action="store_true",
318            default=False,
319            help="show logs. default is `False`.",
320        )
321        parser.add_argument(
322            "-njobs",
323            "--number_of_jobs",
324            required=False,
325            type=int,
326            action="store",
327            default=1,
328            help="set the number of cpu_count. default is `1`.",
329        )
330        return parser.parse_args()
331
332
def build_tester(
    settings: AHCSettings, njobs: int = 1, verbose: bool = False
) -> ParallelTester:
    """Build a ``ParallelTester`` from ``settings``.

    Args:
        settings (AHCSettings): Supplies ``compile_command``, ``execute_command``,
            ``input_file_names``, ``get_score`` and ``timeout``.
        njobs (int, optional): Requested number of worker processes. It is
            capped at ``multiprocessing.cpu_count() - 1`` and raised to at
            least 1, because ``multiprocessing.Pool`` rejects 0 processes
            (which the cap alone would produce on a single-CPU machine).
        verbose (bool, optional): Log per-case progress.

    Returns:
        ParallelTester: The configured tester.
    """
    cpu_count = max(1, min(njobs, multiprocessing.cpu_count() - 1))
    tester = ParallelTester(
        compile_command=settings.compile_command,
        execute_command=settings.execute_command,
        input_file_names=settings.input_file_names,
        cpu_count=cpu_count,
        verbose=verbose,
        get_score=settings.get_score,
        timeout=settings.timeout,
    )
    return tester
355
356
def _log_failure_summary(results: list[tuple[str, float, str, str]]) -> None:
    """Log a bordered table of the NaN-scored cases, grouped by state, then per-state counts."""
    nan_case = [
        (filename, state) for filename, s, state, _ in results if math.isnan(s)
    ]
    if not nan_case:
        return

    # Inner table width: fits the longest file name plus one space on each
    # side, and at least the 13-character " INNER_ERROR " header.
    delta = max(13, max(len(filename) for filename, _ in nan_case)) + 2

    logger.error("=" * (delta + 2))
    logger.error(to_red(f"ErrorCount: {len(nan_case)}."))

    counts = {}
    for label in ("TLE", "ERROR", "INNER_ERROR"):
        logger.error("-" * (delta + 2))
        header = f" {label} "
        logger.error("|" + header + " " * (delta - len(header)) + "|")
        counts[label] = 0
        for f, state in nan_case:
            if state == label:
                counts[label] += 1
                # Pad before colouring so the row lines up with the borders.
                logger.error("|" + to_red(f" {f} ".ljust(delta)) + "|")

    logger.error("-" * (delta + 2))
    logger.error("=" * (delta + 2))

    logger.error(to_red(f" TLE   : {counts['TLE']} "))
    logger.error(to_red(f" Other : {counts['ERROR']} "))
    logger.error(to_red(f" Inner : {counts['INNER_ERROR']} "))


def main():
    """Build a tester from the command-line arguments, run every case and return the aggregate score."""
    args = ParallelTester.get_args()
    # At least one worker: cpu_count() - 1 is 0 on a single-CPU machine.
    njobs = max(1, min(args.number_of_jobs, multiprocessing.cpu_count() - 1))
    logger.info(f"{njobs=}")

    tester = build_tester(AHCSettings, njobs, args.verbose)

    if args.compile:
        tester.compile()

    logger.info("Start.")

    start = time.time()
    results = tester.run_record()
    _log_failure_summary(results)

    score = tester.show_score([s for _, s, _, _ in results])
    logger.info(to_green(f"Finished in {time.time() - start:.4f} sec."))
    return score


if __name__ == "__main__":
    main()

仕様

class ParallelTester(compile_command: str, execute_command: str, input_file_names: list[str], cpu_count: int, verbose: bool, get_score: Callable[[list[float]], float], timeout: float)[source]

Bases: object

テストケース並列回し屋です。

  • 実行例(127並列)ahc_settings.pyを設定して以下のコマンドを実行

$ python3 ./parallel_tester.py -c -v -njobs 127

append_execute_command(args: Iterable[str]) → None [source]

コマンドライン引数を追加します。

compile() → None [source]

compile_command よりコンパイルします。

static get_args() → Namespace [source]

実行時引数を解析します。

run() → list[float] [source]

全テストケースを並列実行し、入力順のスコアのリストを返します。

run_record() → list[tuple[str, float, str, str]] [source]

全テストケースを並列実行し、結果（ファイル名、スコア、state、実行時間）を ./alltests/<日時>/ 以下に記録して返します。

show_score(scores: list[float]) → float [source]

スコアのリストを受け取り、 get_score 関数で計算します。 ついでに表示もします。

build_tester(settings: AHCSettings, njobs: int = 1, verbose: bool = False) → ParallelTester [source]

ParallelTester を返します

Parameters:
  • njobs (int, optional) – cpu_count です。

  • verbose (bool, optional) – ログを表示します。

Returns:

テスターです。

Return type:

ParallelTester

main()[source]

実行時引数をもとに、 tester を立ち上げ実行します。

to_bold(arg)[source]
to_green(arg)[source]
to_red(arg)[source]