beam search euler

ソースコード

  1#include <iostream>
  2#include <vector>
  3#include <cassert>
  4#include <algorithm>
  5
  6#include "titan_cpplib/algorithm/random.cpp"
  7#include "titan_cpplib/ahc/state_pool.cpp"
  8#include "titan_cpplib/ahc/timer.cpp"
  9#include "titan_cpplib/data_structures/hash_set.cpp"
 10#include "titan_cpplib/others/print.cpp"
 11
 12using namespace std;
 13
 14#define rep(i, n) for (int i = 0; i < (n); ++i)
 15
 16//! 木上のビームサーチライブラリ
 17namespace beam_search_with_tree {
 18
 19using ScoreType = long long;
 20using HashType = unsigned long long;
 21const ScoreType INF = 1e9;
 22
 23// Action
 24struct Action {
 25    char d;
 26    ScoreType pre_score, nxt_score;
 27    HashType pre_hash, nxt_hash;
 28
 29    Action() {}
 30    Action(const char d) : d(d), pre_score(INF), nxt_score(INF), pre_hash(0), nxt_hash(0) {}
 31};
 32ostream& operator<<(ostream& os, const Action &action) {
 33    os << action.d;
 34    return os;
 35}
 36
 37class State {
 38  private:
 39
 40  public:
 41    ScoreType score;
 42    HashType hash;
 43
 44    State() {}
 45
 46    // TODO Stateを初期化する
 47    void init() {}
 48
 49    // TODO
 50    //! `action` をしたときの評価値とハッシュ値を返す
 51    //! ロールバックに必要な情報はすべてactionにメモしておく
 52    pair<ScoreType, HashType> try_op(Action &action) const {}
 53
 54    bool is_done() const {}
 55
 56    // TODO
 57    //! `action` をする
 58    void apply_op(const Action &action) {}
 59
 60    // TODO
 61    //! `action` を戻す
 62    void rollback(const Action &action) {}
 63
 64    // TODO
 65    //! 現状態から遷移可能な `Action` の `vector` を返す
 66    vector<Action> get_actions() const {}
 67
 68    void print() const {}
 69};
 70
 //! Parameters of one beam search run.
 //! Both fields are zero-initialized so that `BeamParam param;` never exposes
 //! indeterminate values; aggregate initialization (`BeamParam{turns, width}`)
 //! keeps working.
 struct BeamParam {
     int MAX_TURN = 0;    // number of turns (search depth)
     int BEAM_WIDTH = 0;  // maximum number of candidates kept per turn
 };
 75
 76class BeamSearchWithTree {
 77    // ref: https://eijirou-kyopro.hatenablog.com/entry/2024/02/01/115639
 78
 79  private:
 80    static constexpr const int PRE_ORDER = -1;
 81    static constexpr const int POST_ORDER = -2;
 82
 83    titan23::HashSet seen;
 84    using ActionIDType = int;
 85    ActionIDType ActionID;
 86    vector<Action> result;
 87
 88    // ビームサーチの過程を表す木
 89    // <dir or id, action, action_id>
 90    // dir or id := 葉のとき、leaf_id
 91    //              そうでないとき、行きがけなら-1、帰りがけなら-2
 92    vector<tuple<int, Action, ActionIDType>> tree;
 93
 94    // 次のビーム候補を保持する配列
 95    vector<tuple<int, ScoreType, Action, ActionIDType>> next_beam; // <par, score, action, action_id>
 96
 97    vector<vector<int>> next_beam_data;
 98
 99    void get_next_beam(State* state, const int turn) {
100        next_beam.clear();
101        next_beam.reserve(tree.size() * 4); // TODO
102        seen.clear();
103
104        if (turn == 0) {
105            vector<Action> actions = state->get_actions();
106            for (Action &action : actions) {
107                auto [score, hash] = state->try_op(action);
108                if (seen.contains_insert(hash)) continue;
109                next_beam.emplace_back(PRE_ORDER, score, action, ActionID);
110                ActionID++;
111            }
112            return;
113        }
114
115        int leaf_id = 0;
116        for (int i = 0; i < tree.size(); ++i) {
117            auto [dir_or_leaf_id, action, _] = tree[i];
118            if (dir_or_leaf_id >= 0) {
119                state->apply_op(action);
120                vector<Action> actions = state->get_actions();
121                std::get<0>(tree[i]) = leaf_id;
122                for (Action &action : actions) {
123                    auto [score, hash] = state->try_op(action);
124                    if (seen.contains_insert(hash)) continue;
125                    next_beam.emplace_back(leaf_id, score, action, ActionID);
126                    ActionID++;
127                }
128                ++leaf_id;
129                state->rollback(action);
130            } else if (dir_or_leaf_id == PRE_ORDER) {
131                state->apply_op(action);
132            } else {
133                state->rollback(action);
134            }
135        }
136    }
137
138    //! 不要なNodeを削除し、木を更新する
139    int update_tree(State* state, const int turn) {
140        vector<tuple<int, Action, ActionIDType>> new_tree;
141        new_tree.reserve(tree.size());
142        if (turn == 0) {
143            for (auto &[par, _, new_action, action_id] : next_beam) {
144                assert(par == -1);
145                new_tree.emplace_back(0, new_action, action_id);
146            }
147            swap(tree, new_tree);
148            return 0;
149        }
150
151        int i = 0;
152        int apply_only_turn = 0;
153        while (true) {
154            const auto &[dir_or_leaf_id, action, action_id] = tree[i];
155            // 行きがけかつ帰りがけのaction_idが一致しているなら、一本道なので行くだけ
156            if (dir_or_leaf_id == PRE_ORDER && action_id == std::get<2>(tree.back())) {
157                ++i;
158                result.emplace_back(action);
159                state->apply_op(action);
160                tree.pop_back();
161                apply_only_turn++;
162            } else {
163                break;
164            }
165        }
166
167        for (; i < tree.size(); ++i) {
168            const auto &[dir_or_leaf_id, action, action_id] = tree[i];
169            if (dir_or_leaf_id >= 0) {
170                if (next_beam_data[dir_or_leaf_id].empty()) continue;
171                new_tree.emplace_back(PRE_ORDER, action, action_id);
172                for (const int beam_idx : next_beam_data[dir_or_leaf_id]) {
173                    auto &[_, __, new_action, new_action_id] = next_beam[beam_idx];
174                    new_tree.emplace_back(dir_or_leaf_id, new_action, new_action_id);
175                }
176                new_tree.emplace_back(POST_ORDER, action, action_id);
177                next_beam_data[dir_or_leaf_id].clear();
178            } else if (dir_or_leaf_id == PRE_ORDER) {
179                new_tree.emplace_back(PRE_ORDER, action, action_id);
180            } else {
181                int pre_dir = std::get<0>(new_tree.back());
182                if (pre_dir == PRE_ORDER) {
183                    new_tree.pop_back(); // 一つ前が行きがけなら、削除して追加しない
184                } else {
185                    new_tree.emplace_back(POST_ORDER, action, action_id);
186                }
187            }
188        }
189        swap(tree, new_tree);
190        return apply_only_turn;
191    }
192
193    void get_result() {
194        int best_id = -1;
195        ScoreType best_score = 0;
196        for (auto [par, score, _, __] : next_beam) {
197            if (best_id == -1 || score < best_score) {
198                best_score = score;
199                best_id = par;
200            }
201        }
202        assert(best_id != -1);
203        for (const auto &[dir_or_leaf_id, action, _] : tree) {
204            if (dir_or_leaf_id >= 0) {
205                if (best_id == dir_or_leaf_id) {
206                    result.emplace_back(action);
207                    return;
208                }
209            } else if (dir_or_leaf_id == PRE_ORDER) {
210                result.emplace_back(action);
211            } else {
212                result.pop_back();
213            }
214        }
215        cerr << PRINT_RED << "Error: 解が見つかりませんでした" << PRINT_NONE << endl;
216        assert(false);
217    }
218
219  public:
220
221    /**
222     * @brief ビームサーチをする
223     *
224     * @param param ターン数、ビーム幅を指定するパラメータ構造体
225     * @param verbose ログ出力するかどうか
226     * @return vector<Action>
227     */
228    vector<Action> search(const BeamParam &param, const bool verbose = false) {
229        if (verbose) cerr << PRINT_GREEN << "Info: start search()" << PRINT_NONE << endl;
230
231        ActionID = 0;
232        State* state = new State;
233        state->init();
234
235        this->seen = titan23::HashSet(param.BEAM_WIDTH * 4); // TODO
236
237        int now_turn = 0;
238        for (int turn = 0; turn < param.MAX_TURN; ++turn) {
239            if (verbose) cerr << "Info: # turn : " << turn+1 << endl;
240
241            // 次のビーム候補を求める
242            get_next_beam(state, turn-now_turn);
243
244            if (next_beam.empty()) {
245                cerr << PRINT_RED << "Error: \t次の候補が見つかりませんでした" << PRINT_NONE << endl;
246                assert(!next_beam.empty());
247            }
248
249            // ビームを絞る
250            int beam_width = min(param.BEAM_WIDTH, (int)next_beam.size());
251            assert(beam_width <= param.BEAM_WIDTH);
252
253            nth_element(next_beam.begin(), next_beam.begin() + beam_width, next_beam.end(), [&] (const tuple<int, ScoreType, Action, ActionIDType> &left, const tuple<int, ScoreType, Action, ActionIDType> &right) {
254                return std::get<1>(left) < std::get<1>(right);
255            });
256
257            tuple<int, ScoreType, Action, ActionIDType> bests = *min_element(next_beam.begin(), next_beam.begin() + beam_width, [&] (const tuple<int, ScoreType, Action, ActionIDType> &left, const tuple<int, ScoreType, Action, ActionIDType> &right) {
258                return std::get<1>(left) < std::get<1>(right);
259            });
260            if (verbose) cerr << "Info: \tbest_score = " << std::get<1>(bests) << endl;
261            if (std::get<1>(bests) == 0) { // TODO 終了条件
262                cerr << PRINT_GREEN << "Info: find valid solution." << PRINT_NONE << endl;
263                get_result();
264                result.emplace_back(std::get<2>(bests));
265                return result;
266            }
267
268            // 探索木の更新
269            if (turn != 0) {
270                if (next_beam_data.size() < next_beam.size()) {
271                    next_beam_data.resize(next_beam.size());
272                }
273                for (int i = 0; i < beam_width; ++i) {
274                    auto &[par, _, new_action, new_action_id] = next_beam[i];
275                    next_beam_data[par].emplace_back(i);
276                }
277            }
278            int apply_only_turn = update_tree(state, turn);
279            now_turn += apply_only_turn;
280        }
281
282        // 答えを復元する
283        if (verbose) cerr << PRINT_GREEN << "Info: MAX_TURN finished." << PRINT_NONE << endl;
284        get_result();
285        return result;
286    }
287};
288} // namespace beam_search_with_tree
289
290// int main() {
291//     beam_search_with_tree::BeamParam param;
292//     param.MAX_TURN = 1000;
293//     param.BEAM_WIDTH = 100;
294//     beam_search_with_tree::init_zhs();
295//     beam_search_with_tree::BeamSearchWithTree bs;
296//     vector<beam_search_with_tree::Action> ans = bs.search(param, true);
297//     for (const beam_search_with_tree::Action &action : ans) {
298//         cout << action;
299//     }
300//     cout << endl;
301//     cerr << "Score = " << ans.size() << endl;
302//     return 0;
303// }

仕様

Warning

doxygenfile: Cannot find file “titan_cpplib/ahc/beam_search_euler.cpp”