beam search with tree

ソースコード

#include <algorithm>
#include <cassert>
#include <iostream>
#include <memory>
#include <tuple>
#include <utility>
#include <vector>

#include "titan_cpplib/others/print.cpp"
#include "titan_cpplib/algorithm/random.cpp"
#include "titan_cpplib/ahc/state_pool.cpp"
#include "titan_cpplib/ahc/timer.cpp"
#include "titan_cpplib/data_structures/hash_set.cpp"
using namespace std;
 10
 11//! 木上のビームサーチライブラリ
 12namespace beam_search_with_tree {
 13
 14using ScoreType = long long;
 15using HashType = unsigned long long;
 16
 17struct Action {
 18    Action() {}
 19
 20    friend ostream& operator<<(ostream& os, const Action &action) {
 21        return os;
 22    }
 23};
 24
 25class State {
 26  private:
 27    titan23::Random srand;
 28    ScoreType score;
 29    HashType hash;
 30
 31  public:
 32    // TODO Stateを初期化する
 33    void init() {
 34        this->score = 0;
 35    }
 36
 37    // TODO
 38    //! `action` をしたときの評価値とハッシュ値を返す
 39    //! ロールバックに必要な情報はすべてactionにメモしておく
 40    pair<ScoreType, HashType> try_op(Action &action) const {
 41    }
 42
 43    // TODO
 44    //! `action` をする
 45    void apply_op(const Action &action) {
 46    }
 47
 48    // TODO
 49    //! `action` を戻す
 50    void rollback(const Action &action) {
 51    }
 52
 53    // TODO
 54    //! 現状態から遷移可能な `Action` の `vector` を返す
 55    vector<Action> get_actions() const {
 56    }
 57
 58    ScoreType get_score() {
 59        return this->score;
 60    }
 61
 62    void print() const {
 63    }
 64};
 65
 66struct BeamParam {
 67    int MAX_TURN;
 68    int BEAM_WIDTH;
 69};
 70
 71using TreeNodeID = int;
 72using SubStateID = int;
 73
 74//! try_opした結果をメモしておく構造体
 75struct SubState {
 76    TreeNodeID par;
 77    Action action;
 78    ScoreType score;
 79
 80    SubState() : par(-1) {}
 81    SubState(TreeNodeID par, const Action &action, ScoreType score) : par(par), action(action), score(score) {}
 82};
 83
 84//! ビームサーチの過程を表す木のノード
 85struct TreeNode {
 86    TreeNodeID par;
 87    Action pre_action;
 88    ScoreType score;
 89    vector<TreeNodeID> child;
 90
 91    TreeNode() : par(-1) {}
 92
 93    bool is_leaf() const {
 94        return child.empty();
 95    }
 96};
 97
// Global pools shared by the search.
// Ids obtained from gen() are used as TreeNodeID / SubStateID and resolved with get().
// treenode_pool holds the search tree; substate_pool holds the per-turn
// candidates and is cleared at the end of every turn in search().
titan23::StatePool<TreeNode> treenode_pool;
titan23::StatePool<SubState> substate_pool;
100
101
102class BeamSearchWithTree {
103  private:
104    ScoreType best_score;
105    TreeNodeID best_id;
106    titan23::HashSet seen;
107
108    void get_next_beam_recursion(State* state, TreeNodeID node, vector<SubStateID> &next_beam, int depth, const int beam_width) {
109        if (depth == 0) { // 葉
110            vector<Action> actions = state->get_actions();
111            for (Action &action : actions) {
112                auto [score, hash] = state->try_op(action);
113                if (seen.contains_insert(hash)) continue;
114                SubStateID substate = substate_pool.gen();
115                substate_pool.get(substate)->par = node;
116                substate_pool.get(substate)->action = action;
117                substate_pool.get(substate)->score = score;
118                next_beam.emplace_back(substate);
119            }
120            return;
121        }
122        for (const TreeNodeID nxt_node : treenode_pool.get(node)->child) {
123            state->apply_op(treenode_pool.get(nxt_node)->pre_action);
124            get_next_beam_recursion(state, nxt_node, next_beam, depth-1, beam_width);
125            state->rollback(treenode_pool.get(nxt_node)->pre_action);
126        }
127    }
128
129    tuple<int, TreeNodeID, vector<SubStateID>> get_next_beam(State* state, TreeNodeID node, int turn, const int beam_width) {
130        int cnt = 0;
131        while (true) { // 一本道は行くだけ
132            if (treenode_pool.get(node)->child.size() != 1) break;
133            ++cnt;
134            node = treenode_pool.get(node)->child[0];
135            state->apply_op(treenode_pool.get(node)->pre_action);
136        }
137        vector<SubStateID> next_beam;
138        seen.clear();
139        get_next_beam_recursion(state, node, next_beam, turn-cnt, beam_width);
140        return make_tuple(cnt, node, next_beam);
141    }
142
143    //! 不要なNodeを削除し、木を更新する
144    bool update_tree(const TreeNodeID node, const int depth) {
145        if (treenode_pool.get(node)->is_leaf()) return depth == 0;
146        int idx = 0;
147        while (idx < treenode_pool.get(node)->child.size()) {
148            TreeNodeID nxt_node = treenode_pool.get(node)->child[idx];
149            if (!update_tree(nxt_node, depth-1)) {
150                treenode_pool.del(nxt_node);
151                treenode_pool.get(node)->child.erase(treenode_pool.get(node)->child.begin() + idx);
152                continue;
153            }
154            ++idx;
155        }
156        return idx > 0;
157    }
158
159    //! node以上のパスを返す
160    vector<Action> get_path(TreeNodeID node) {
161        vector<Action> result;
162        while (node != -1 && treenode_pool.get(node)->par != -1) {
163            result.emplace_back(treenode_pool.get(node)->pre_action);
164            node = treenode_pool.get(node)->par;
165        }
166        reverse(result.begin(), result.end());
167        return result;
168    }
169
170    //! for debug
171    void print_tree(State* state, const TreeNodeID node, int depth) {
172    }
173
174    //! node以下で、葉かつ最も評価値の良いノードを見るける / 葉はターン数からは判断していないので注意
175    void get_best_node(TreeNodeID node) {
176        if (treenode_pool.get(node)->is_leaf()) {
177            if (best_id == -1 || treenode_pool.get(node)->score < best_score) {
178                best_score = treenode_pool.get(node)->score;
179                best_id = node;
180            }
181            return;
182        }
183        for (TreeNodeID nxt_node : treenode_pool.get(node)->child) {
184            get_best_node(nxt_node);
185        }
186    }
187
188    vector<Action> get_result(TreeNodeID root) {
189        best_id = -1; // 更新
190        get_best_node(root);
191        TreeNodeID node = best_id;
192        vector<Action> result = get_path(node);
193        cerr << treenode_pool.get_size() << endl;
194        return result;
195    }
196
197  public:
198    /**
199     * @brief ビームサーチをする
200     *
201     * @param param ターン数、ビーム幅を指定するパラメータ構造体
202     * @param verbose 途中結果のスコアを標準エラー出力するかどうか
203     * @return vector<Action>
204     */
205    vector<Action> search(const BeamParam &param, const bool verbose = false) {
206        TreeNodeID root = treenode_pool.gen();
207        treenode_pool.get(root)->child.clear();
208        treenode_pool.get(root)->par = -1;
209
210        State* state = new State;
211        state->init();
212
213        this->seen = titan23::HashSet(param.BEAM_WIDTH * 4); // TODO
214
215        int now_turn = 0;
216
217        for (int turn = 0; turn < param.MAX_TURN; ++turn) {
218            if (verbose) cerr << "# turn : " << turn << endl;
219
220            // 次のビーム候補を求める
221            auto [apply_only_turn, next_root, next_beam] = get_next_beam(state, root, turn-now_turn, param.BEAM_WIDTH);
222            root = next_root;
223            now_turn += apply_only_turn;
224            assert(!next_beam.empty());
225            if (verbose) {
226                cerr << "min_score=" << substate_pool.get((*min_element(next_beam.begin(), next_beam.end(), [] (const SubStateID &left, const SubStateID &right) {
227                    return substate_pool.get(left)->score < substate_pool.get(right)->score;
228                })))->score << endl;
229            }
230
231            // ビームを絞る
232            int beam_width = min(param.BEAM_WIDTH, (int)next_beam.size());
233            nth_element(next_beam.begin(), next_beam.begin() + beam_width, next_beam.end(), [&] (const SubStateID &left, const SubStateID &right) {
234                return substate_pool.get(left)->score < substate_pool.get(right)->score;
235            });
236
237            // 探索木の更新
238            for (int i = 0; i < beam_width; ++i) {
239                SubStateID s = next_beam[i];
240                TreeNodeID new_node = treenode_pool.gen();
241                treenode_pool.get(substate_pool.get(s)->par)->child.emplace_back(new_node);
242                treenode_pool.get(new_node)->par = substate_pool.get(s)->par;
243                treenode_pool.get(new_node)->pre_action = substate_pool.get(s)->action;
244                treenode_pool.get(new_node)->score = substate_pool.get(s)->score;
245            }
246            substate_pool.clear();
247            update_tree(root, turn-now_turn+1);
248        }
249
250        // 答えを復元する
251        vector<Action> result = get_result(root);
252        return result;
253    }
254};
255} // namespace beam_search
256
257// int main() {
258//     beam_search_with_tree::BeamParam param;
259//     param.MAX_TURN = 2500;
260//     param.BEAM_WIDTH = 1000;
261//     beam_search_with_tree::BeamSearchWithTree bs;
262//     vector<beam_search_with_tree::Action> ans = bs.search(param, true);
263// }

仕様

Warning

doxygenfile: Cannot find file “titan_cpplib/ahc/beam_search_with_tree.cpp”