Euler Tour Tree

ソースコード

#include <cassert>
#include <unordered_map>
#include <utility>
#include <vector>
using namespace std;
  5
  6// EulerTourTree
  7namespace titan23 {
  8
  9    /**
 10     * @brief EulerTourTree(動的木)
 11     *
 12     * @tparam T
 13     * @tparam F
 14     * @tparam (*op)(T, T)
 15     * @tparam (*mapping)(F, T)
 16     * @tparam (*composition)(F, F)
 17     * @tparam (*e)()
 18     * @tparam (*id)()
 19     */
 20    template <class T,
 21                T (*op)(T, T),
 22                T (*e)(),
 23                class F,
 24                T (*mapping)(F, T),
 25                F (*composition)(F, F),
 26                F (*id)()>
 27    class EulerTourTree {
 28      private:
 29        struct Node;
 30        using NodePtr = Node*;
 31        int n, group_numbers;
 32        vector<NodePtr> ptr_vertex;
 33        unordered_map<long long, NodePtr> ptr_edge;
 34
 35        struct Node {
 36            T key, data;
 37            F lazy;
 38            NodePtr par, left, right;
 39
 40            Node() {}
 41
 42            Node(T key, F lazy) :
 43                key(key),
 44                data(key),
 45                lazy(lazy),
 46                par(nullptr),
 47                left(nullptr),
 48                right(nullptr) {}
 49        };
 50
 51        void _init_build(vector<T> &a) {
 52            ptr_vertex.resize(n);
 53            for (int i = 0; i < n; ++i) {
 54                ptr_vertex[i] = new Node(a[i], id());
 55            }
 56        }
 57
 58        NodePtr _popleft(NodePtr v) {
 59            v = _left_splay(v);
 60            if (v->right) v->right->par = nullptr;
 61            return v->right;
 62        }
 63
 64        NodePtr _pop(NodePtr v) {
 65            v = _right_splay(v);
 66            if (v->left) v->left->par = nullptr;
 67            return v->left;
 68        }
 69
 70        pair<NodePtr, NodePtr> _split_left(NodePtr v) {
 71            _splay(v);
 72            NodePtr x = v, y = v->right;
 73            if (y) y->par = nullptr;
 74            x->right = nullptr;
 75            _update(x);
 76            return make_pair(x, y);
 77        }
 78
 79        pair<NodePtr, NodePtr> _split_right(NodePtr v) {
 80            _splay(v);
 81            NodePtr x = v->left, y = v;
 82            if (x) x->par = nullptr;
 83            y->left = nullptr;
 84            _update(y);
 85            return make_pair(x, y);
 86        }
 87
 88        void _merge(NodePtr u, NodePtr v) {
 89            if ((!u) || (!v)) return;
 90            u = _right_splay(u);
 91            _splay(v);
 92            u->right = v;
 93            v->par = u;
 94            _update(u);
 95        }
 96
 97        void _splay(NodePtr node) {
 98            _propagate(node);
 99            while (node->par && node->par->par) {
100                NodePtr pnode = node->par, gnode = pnode->par;
101                _propagate(gnode);
102                _propagate(pnode);
103                _propagate(node);
104                node->par = gnode->par;;
105                NodePtr tmp1, tmp2;
106                if ((gnode->left == pnode) == (pnode->left == node)) {
107                    if (pnode->left == node) {
108                        tmp1 = node->right;
109                        pnode->left = tmp1;
110                        node->right = pnode;
111                        pnode->par = node;
112                        tmp2 = pnode->right;
113                        gnode->left = tmp2;
114                        pnode->right = gnode;
115                        gnode->par = pnode;
116                    } else {
117                        tmp1 = node->left;
118                        pnode->right = tmp1;
119                        node->left = pnode;
120                        pnode->par = node;
121                        tmp2 = pnode->left;
122                        gnode->right = tmp2;
123                        pnode->left = gnode;
124                        gnode->par = pnode;
125                    }
126                    if (tmp1) tmp1->par = pnode;
127                    if (tmp2) tmp2->par = gnode;
128                } else {
129                    if (pnode->left == node) {
130                        tmp1 = node->right;
131                        pnode->left = tmp1;
132                        node->right = pnode;
133                        tmp2 = node->left;
134                        gnode->right = tmp2;
135                        node->left = gnode;
136                        pnode->par = node;
137                        gnode->par = node;
138                    } else {
139                        tmp1 = node->left;
140                        pnode->right = tmp1;
141                        node->left = pnode;
142                        tmp2 = node->right;
143                        gnode->left = tmp2;
144                        node->right = gnode;
145                        pnode->par = node;
146                        gnode->par = node;
147                    }
148                    if (tmp1) tmp1->par = pnode;
149                    if (tmp2) tmp2->par = gnode;
150                }
151                _update(gnode);
152                _update(pnode);
153                _update(node);
154                if (!node->par) return;
155                if (node->par->left == gnode) {
156                    node->par->left = node;
157                } else {
158                    node->par->right = node;
159                }
160            }
161            if (!node->par) return;
162            NodePtr pnode = node->par;
163            _propagate(pnode);
164            _propagate(node);
165            if (pnode->left == node) {
166                pnode->left = node->right;
167                if (pnode->left) pnode->left->par = pnode;
168                node->right = pnode;
169            } else {
170                pnode->right = node->left;
171                if (pnode->right) pnode->right->par = pnode;
172                node->left = pnode;
173            }
174            node->par = nullptr;
175            pnode->par = node;
176            _update(pnode);
177            _update(node);
178        }
179
180        NodePtr _left_splay(NodePtr node) {
181            _splay(node);
182            while (node->left) node = node->left;
183            _splay(node);
184            return node;
185        }
186
187        NodePtr _right_splay(NodePtr node) {
188            _splay(node);
189            while (node->right) node = node->right;
190            _splay(node);
191            return node;
192        }
193
194        void _propagate(NodePtr node) {
195            if ((!node) || node->lazy == id()) return;
196            if (node->left) {
197                node->left->key = mapping(node->lazy, node->left->key);
198                node->left->data = mapping(node->lazy, node->left->data);
199                node->left->lazy = composition(node->lazy, node->left->lazy);
200            }
201            if (node->right) {
202                node->right->key = mapping(node->lazy, node->right->key);
203                node->right->data = mapping(node->lazy, node->right->data);
204                node->right->lazy = composition(node->lazy, node->right->lazy);
205            }
206            node->lazy = id();
207        }
208
209        void _update(NodePtr node) {
210            _propagate(node->left);
211            _propagate(node->right);
212            node->data = node->key;
213            if (node->left)  node->data = op(node->left->data, node->data);
214            if (node->right) node->data = op(node->data, node->right->data);
215        }
216
217      public:
218        EulerTourTree(int n) : n(0), group_numbers(0) {}
219
220        EulerTourTree(int n) : n(n), group_numbers(n) {
221            vector<T> a(n, e());
222            _init_build(a);
223        }
224
225        EulerTourTree(vector<T> a) : n((int)a.size()), group_numbers((int)a.size()) {
226            _init_build(a);
227        }
228
229        //! 隣接リストから構築する / `O(logn)`
230        void build(vector<vector<int>> &G) {
231            vector<int> seen(n, 0);
232            vector<long long> a;
233            vector<NodePtr> pool;
234
235            auto dfs = [&] (auto &&dfs, int v, int p) -> void {
236                a.emplace_back((long long)v*n+v);
237                for (const int &x: G[v]) {
238                    if (x == p) continue;
239                    a.emplace_back((long long)v*n+x);
240                    dfs(dfs, x, v);
241                    a.emplace_back((long long)x*n+v);
242                }
243            };
244
245            auto rec = [&] (auto &&rec, int l, int r) -> NodePtr {
246                int mid = (l + r) >> 1;
247                int u = a[mid]/n, v = a[mid]%n;
248                NodePtr node;
249                if (u == v) {
250                    node = ptr_vertex[u];
251                    seen[u] = 1;
252                } else {
253                    node = new Node(e(), id());
254                    ptr_edge[a[mid]] = node;
255                }
256
257                if (l != mid) {
258                    node->left = rec(rec, l, mid);
259                    node->left->par = node;
260                }
261                if (mid+1 != r) {
262                    node->right = rec(rec, mid+1, r);
263                    node->right->par = node;
264                }
265                _update(node);
266                return node;
267            };
268
269            for (int root = 0; root < n; ++root) {
270                if (seen[root]) continue;
271                a.clear();
272                dfs(dfs, root, -1);
273                rec(rec, 0, (int)a.size());
274            }
275        }
276
277        //! 辺 `{u, v}` を追加する / `O(logn)`
278        void link(const int u, const int v) {
279            reroot(u);
280            reroot(v);
281            NodePtr uv_node = new Node(e(), id());
282            NodePtr vu_node = new Node(e(), id());
283            ptr_edge[(long long)u*n+v] = uv_node;
284            ptr_edge[(long long)v*n+u] = vu_node;
285            NodePtr u_node = ptr_vertex[u];
286            NodePtr v_node = ptr_vertex[v];
287            _merge(u_node, uv_node);
288            _merge(uv_node, v_node);
289            _merge(v_node, vu_node);
290            --group_numbers;
291        }
292
293        //! 辺 `{u, v}` を削除する / `O(logn)`
294        void cut(const int u, const int v) {
295            reroot(v);
296            reroot(u);
297            NodePtr uv_node = ptr_edge[(long long)u*n+v];
298            NodePtr vu_node = ptr_edge[(long long)v*n+u];
299            ptr_edge.erase((long long)u*n+v);
300            ptr_edge.erase((long long)v*n+u);
301            NodePtr a, c, _;
302            tie(a, _) = _split_left(uv_node);
303            tie(_, c) = _split_right(vu_node);
304            a = _pop(a);
305            c = _popleft(c);
306            _merge(a, c);
307            ++group_numbers;
308        }
309
310        //! 辺 `{u, v}` がなければ追加する / `O(logn)`
311        bool merge(const int u, const int v) {
312            if (same(u, v)) return false;
313            link(u, v);
314            return true;
315        }
316
317        //! 辺 `{u, v}` があれば削除する / `O(logn)`
318        bool split(const int u, const int v) {
319            if (ptr_edge.find((long long)u*n+v) == ptr_edge.end() || ptr_edge.find((long long)v*n+u) == ptr_edge.end()) {
320                return false;
321            }
322            cut(u, v);
323            return true;
324        }
325
326        //! 代表元? / `O(logn)`
327        NodePtr leader(const int v) {
328            return _left_splay(ptr_vertex[v]);
329        }
330
331        //! 根を `v` にする / `O(logn)`
332        void reroot(const int v) {
333            NodePtr node = ptr_vertex[v];
334            auto[x, y] = _split_right(node);
335            _merge(y, x);
336            _splay(node);
337        }
338
339        //! 連結判定 / `O(logn)`
340        bool same(const int u, const int v) {
341            NodePtr u_node = ptr_vertex[u];
342            NodePtr v_node = ptr_vertex[v];
343            _splay(u_node);
344            _splay(v_node);
345            return (u_node->par != nullptr || u_node == v_node);
346        }
347
348        //! `v` を根とする部分木に `f` を作用、ただし `v` の親は `p(or -1)` / `O(logn)`
349        void subtree_apply(const int v, const int p, const F f) {
350            NodePtr v_node = ptr_vertex[v];
351            reroot(v);
352            if (p == -1) {
353                _splay(v_node);
354                v_node->key = mapping(f, v_node->key);
355                v_node->data = mapping(f, v_node->data);
356                v_node->lazy = composition(f, v_node->lazy);
357                return;
358            }
359            reroot(p);
360            NodePtr a, b, d;
361            tie(a, b) = _split_right(ptr_edge[(long long)p*n+v]);
362            tie(b, d) = _split_left(ptr_edge[(long long)v*n+p]);
363            _splay(v_node);
364            v_node->key = mapping(f, v_node->key);
365            v_node->data = mapping(f, v_node->data);
366            v_node->lazy = composition(f, v_node->lazy);
367            _propagate(v_node);
368            _merge(a, b);
369            _merge(b, d);
370        }
371
372        //! `v` を根とする部分木の総和、ただし `v` の親は `p(or -1)` / `O(logn)`
373        T subtree_sum(const int v, const int p) {
374            NodePtr v_node = ptr_vertex[v];
375            reroot(v);
376            if (p == -1) {
377                _splay(v_node);
378                return v_node->data;
379            }
380            reroot(p);
381            NodePtr a, b, d;
382            tie(a, b) = _split_right(ptr_edge[(long long)p*n+v]);
383            tie(b, d) = _split_left(ptr_edge[(long long)v*n+p]);
384            _splay(v_node);
385            T res = v_node->data;
386            _merge(a, b);
387            _merge(b, d);
388            return res;
389        }
390
391        //! 連結成分の個数を返す / `(1)`
392        int group_count() const {
393            return group_numbers;
394        }
395
396        //! `v` の値を取得 / `O(logn)`
397        T get_vertex(const int v) {
398            NodePtr node = ptr_vertex[v];
399            _splay(node);
400            return node->key;
401        }
402
403        //! `v` の値を `val` に変更 / `O(logn)`
404        void set_vertex(const int v, const T val) {
405            NodePtr node = ptr_vertex[v];
406            _splay(node);
407            node->key = val;
408            _update(node);
409        }
410    };
411}  // namespace titan23

仕様

Warning

doxygenfile: Cannot find file “titan_cpplib/data_structures/euler_tour_tree.cpp”