optimizer

ソースコード

from titan_pylib.ahc.optimizer import Optimizer

View on GitHub

展開済みコード

  1# from titan_pylib.ahc.optimizer import Optimizer
  2import optuna
  3import time
  4from logging import getLogger, basicConfig
  5import os
  6import multiprocessing
  7from parallel_tester import ParallelTester, build_tester
  8from ahc_settings import AHCSettings
  9
 10logger = getLogger(__name__)
 11basicConfig(
 12    format="%(asctime)s [%(levelname)s] : %(message)s",
 13    datefmt="%H:%M:%S",
 14    level=os.getenv("LOG_LEVEL", "INFO"),
 15)
 16
 17
 18def to_red(arg):
 19    return f"\u001b[31m{arg}\u001b[0m"
 20
 21
 22def to_blue(arg):
 23    return f"\u001b[94m{arg}\u001b[0m"
 24
 25
 26def to_green(arg):
 27    return f"\u001b[32m{arg}\u001b[0m"
 28
 29
 30def to_bold(arg):
 31    return f"\u001b[1m{arg}\u001b[0m"
 32
 33
 34class Optimizer:
 35
 36    def __init__(self, settings: AHCSettings) -> None:
 37        self.settings: AHCSettings = settings
 38        logger.info(f"------------------------------------------")
 39        logger.info(to_blue(f"Optimizer settings:"))
 40        logger.info(f"- study_name : {to_bold(settings.study_name)}")
 41        logger.info(f"- direction  : {to_bold(settings.direction)}")
 42        logger.info(f"- n_trials   : {to_bold(settings.n_trials)}")
 43        logger.info(f"------------------------------------------")
 44        self.path = f"./optimizer_results/{self.settings.study_name}"
 45
 46    def optimize(self) -> None:
 47        if not os.path.exists(self.path):
 48            os.makedirs(self.path)
 49
 50        tester: ParallelTester = build_tester(
 51            self.settings, njobs=self.settings.n_jobs_parallel_tester
 52        )
 53        tester.compile()
 54
 55        start = time.time()
 56
 57        study: optuna.Study = optuna.create_study(
 58            direction=self.settings.direction,
 59            study_name=self.settings.study_name,
 60            storage=f"sqlite:///{self.path}/{self.settings.study_name}.db",
 61            load_if_exists=True,
 62        )
 63
 64        def _objective(trial: optuna.trial.Trial) -> float:
 65            tester: ParallelTester = build_tester(
 66                self.settings, njobs=self.settings.n_jobs_parallel_tester
 67            )
 68            args = self.settings.objective(trial)
 69            tester.append_execute_command(args)
 70            scores = tester.run()
 71            score = tester.get_score(scores)
 72            return score
 73
 74        study.optimize(
 75            _objective,
 76            n_trials=self.settings.n_trials,
 77            n_jobs=min(self.settings.n_jobs_optuna, multiprocessing.cpu_count() - 1),
 78        )
 79
 80        logger.info(study.best_trial)
 81        logger.info("writing results ...")
 82        self.output(study)
 83        logger.info(f"Finish parameter seraching. Time: {time.time() - start:.2f}sec.")
 84
 85    def output(self, study: optuna.Study) -> None:
 86        if not os.path.exists(self.path):
 87            os.makedirs(self.path)
 88        with open(f"{self.path}/result.txt", "w", encoding="utf-8") as f:
 89            print(study.best_trial, file=f)
 90
 91        img_path = self.path + "/images"
 92        if not os.path.exists(img_path):
 93            os.makedirs(img_path)
 94
 95        fig = optuna.visualization.plot_contour(study)
 96        fig.write_html(f"{img_path}/contour.html")
 97        fig.write_image(f"{img_path}/contour.png")
 98        fig = optuna.visualization.plot_edf(study)
 99        fig.write_html(f"{img_path}/edf.html")
100        fig.write_image(f"{img_path}/edf.png")
101        fig = optuna.visualization.plot_optimization_history(study)
102        fig.write_html(f"{img_path}/optimization_history.html")
103        fig.write_image(f"{img_path}/optimization_history.png")
104        fig = optuna.visualization.plot_parallel_coordinate(study)
105        fig.write_html(f"{img_path}/parallel_coordinate.html")
106        fig.write_image(f"{img_path}/parallel_coordinate.png")
107        fig = optuna.visualization.plot_slice(study)
108        fig.write_html(f"{img_path}/slice.html")
109        fig.write_image(f"{img_path}/slice.png")
110
111
if __name__ == "__main__":
    # Script entry point: build the optimizer from AHCSettings and run the
    # search. The AHCSettings class object itself is what gets passed in.
    optimizer = Optimizer(AHCSettings)
    optimizer.optimize()

仕様

class Optimizer(settings: AHCSettings)[source]

Bases: object

optimize() → None[source]
output(study: Study) → None[source]
to_blue(arg)[source]
to_bold(arg)[source]
to_green(arg)[source]
to_red(arg)[source]