fork download
#include <bits/stdc++.h>
  2. using namespace std;
  3. using ll = long long;
  4. const int MAXN = 200500;
  5. const int LOG = 20;
  6. int n, q;
  7. vector<int> g[MAXN];
  8. int parent_[LOG][MAXN];
  9. int depth_[MAXN];
  10. int tin[MAXN], tout[MAXN], tim = 0;
  11.  
  12. struct Cmp { bool operator()(int a, int b) const { return tin[a] < tin[b]; } };
  13.  
  14. void build_dfs(int root = 1){
  15. // iterative DFS to compute tin/tout, depth and parent_[0]
  16. vector<pair<int,int>> st; // (v, next_child_index)
  17. st.reserve(n);
  18. for (int i=1;i<=n;i++) tin[i]=tout[i]=0;
  19. tim = 0;
  20. depth_[root] = 0;
  21. parent_[0][root] = 0;
  22. st.emplace_back(root, 0);
  23. while(!st.empty()){
  24. int v = st.back().first;
  25. int &ci = st.back().second;
  26. if (ci == 0){
  27. tin[v] = ++tim;
  28. }
  29. if (ci < (int)g[v].size()){
  30. int u = g[v][ci++];
  31. if (u == parent_[0][v]) continue;
  32. parent_[0][u] = v;
  33. depth_[u] = depth_[v] + 1;
  34. st.emplace_back(u, 0);
  35. } else {
  36. tout[v] = ++tim;
  37. st.pop_back();
  38. }
  39. }
  40. }
  41.  
  42. bool is_ancestor(int u, int v){ // u ancestor of v
  43. return tin[u] <= tin[v] && tout[u] >= tout[v];
  44. }
  45.  
  46. int lca(int u, int v){
  47. if (u==0) return v;
  48. if (v==0) return u;
  49. if (is_ancestor(u,v)) return u;
  50. if (is_ancestor(v,u)) return v;
  51. for (int k = LOG-1; k >= 0; --k){
  52. int pu = parent_[k][u];
  53. if (pu && !is_ancestor(pu, v)) u = pu;
  54. }
  55. return parent_[0][u];
  56. }
  57.  
  58. int dist(int u, int v){
  59. int w = lca(u,v);
  60. return depth_[u] + depth_[v] - 2*depth_[w];
  61. }
  62.  
  63. struct Query{int l,r,idx,block;};
  64.  
  65. int main(){
  66. ios::sync_with_stdio(false);
  67. cin.tie(nullptr);
  68.  
  69. // file IO as problem statement
  70. freopen("kingdom.inp","r",stdin);
  71. freopen("kingdom.out","w",stdout);
  72.  
  73. if (!(cin >> n >> q)) return 0;
  74. for (int i=1;i<=n;i++) g[i].clear();
  75. for (int i=0;i<n-1;i++){
  76. int u,v; cin >> u >> v;
  77. g[u].push_back(v);
  78. g[v].push_back(u);
  79. }
  80. // build dfs for lca
  81. build_dfs(1);
  82. // binary lifting
  83. for (int k=1;k<LOG;k++){
  84. for (int v=1; v<=n; v++){
  85. int p = parent_[k-1][v];
  86. parent_[k][v] = p ? parent_[k-1][p] : 0;
  87. }
  88. }
  89.  
  90. vector<Query> qs(q);
  91. for (int i=0;i<q;i++){
  92. int l,r; cin >> l >> r;
  93. qs[i].l = l; qs[i].r = r; qs[i].idx = i;
  94. }
  95. int BLOCK = max(1, (int)sqrt(n));
  96. for (int i=0;i<q;i++) qs[i].block = qs[i].l / BLOCK;
  97. sort(qs.begin(), qs.end(), [&](const Query &a, const Query &b){
  98. if (a.block != b.block) return a.block < b.block;
  99. if (a.block & 1) return a.r > b.r;
  100. return a.r < b.r;
  101. });
  102.  
  103. set<int, Cmp> st; // stores node ids, ordered by tin
  104. ll sum = 0; // sum of pairwise adjacent distances in cyclic order
  105. auto add_node = [&](int v){
  106. if (st.empty()){
  107. st.insert(v);
  108. return;
  109. }
  110. auto it = st.lower_bound(v);
  111. int s = (it == st.end() ? *st.begin() : *it);
  112. int p = (it == st.begin() ? *st.rbegin() : *prev(it));
  113. sum += (ll)dist(p, v) + dist(v, s) - dist(p, s);
  114. st.insert(it, v);
  115. };
  116. auto remove_node = [&](int v){
  117. if (st.size() == 1){
  118. st.erase(v);
  119. sum = 0;
  120. return;
  121. }
  122. auto it = st.find(v);
  123. auto it_s = next(it);
  124. if (it_s == st.end()) it_s = st.begin();
  125. int s = *it_s;
  126. int p = (it == st.begin() ? *st.rbegin() : *prev(it));
  127. sum -= (ll)dist(p, v) + dist(v, s) - dist(p, s);
  128. st.erase(it);
  129. };
  130.  
  131. vector<long long> ans(q);
  132. int curL = 1, curR = 0;
  133. for (auto &qu : qs){
  134. int L = qu.l, R = qu.r;
  135. while (curL > L) { add_node(--curL); }
  136. while (curR < R) { add_node(++curR); }
  137. while (curL < L) { remove_node(curL++); }
  138. while (curR > R) { remove_node(curR--); }
  139. if (st.empty()) ans[qu.idx] = 0;
  140. else ans[qu.idx] = sum/2 + 1; // number of vertices = edges + 1
  141. }
  142.  
  143. for (int i=0;i<q;i++) cout << ans[i] << '\n';
  144. return 0;
  145. }
  146.  
Success #stdin #stdout 0.01s 10124KB
stdin
Standard input is empty
stdout
Standard output is empty