树上主席树 + LCA SPOJ - COT【Count on a tree】

it2022-05-09  30

树上主席树 + LCA SPOJ - COT【Count on a tree】

https://cn.vjudge.net/contest/304073#problem/B

题意

给定 n 个数、m 次询问以及 n-1 条边(保证建立成一棵树),查询树上 节点x 到 节点y 的路径中第 K 小的结点值是多少。

分析

这题还是花了不少功夫的,树上虽然看上去恐怖,其实做起来和普通主席树没太大区别。

普通主席树:每次建立都是基于前一个进行的。树上主席树:每次建立都是基于父节点进行的。

我们可以发现这棵主席树是包括当前节点的所有祖先结点的。需要注意的是查询的时候不是减去两次到 \(lca(x,y)\) 的值,因为公共节点也在路径之中,所以在查询的时候:

\[ sum = T[T[x].l].sum + T[T[y].l].sum - T[T[z].l].sum - T[T[fz].l].sum\]

sum 用于和 k 值进行比较,其中 \(z = lca(x, y)\)\(fz = pre[z]\)

然后 lca 的求法我用的是倍增,写起来相对比较好理解,不过感觉有空还得多做点 lca 的题。不看板子敲起来还是有点虚...

代码

#include <map> #include <set> #include <list> #include <cmath> #include <ctime> #include <deque> #include <stack> #include <queue> #include <bitset> #include <cctype> #include <cstdio> #include <vector> #include <string> #include <cstdlib> #include <cstring> #include <fstream> #include <iomanip> #include <numeric> #include <iostream> #include <algorithm> using namespace std; typedef long long ll; typedef unsigned long long ull; const double PI = acos(-1.0); const double eps = 1e-6; const int inf = 0x3f3f3f3f; const int mod = 1e9 + 7; const int maxn = 1e5 + 5; int n, m, cnt; int a[maxn]; int root[maxn]; vector<int> v; vector<int> G[maxn]; int pre[maxn]; int dep[maxn]; int f[maxn][20]; struct node { int l, r, sum; } T[maxn * 40]; void init() { v.clear(); for (int i = 0; i <= n; i++) { G[i].clear(); } cnt = 0; for (int i = 0; i <= n * 40; i++) { T[i].l = T[i].r = T[i].sum = 0; } memset(dep, 0, sizeof(dep)); memset(pre, 0, sizeof(pre)); memset(f, 0, sizeof(f)); memset(root, 0, sizeof(root)); } int getid(int x) { return lower_bound(v.begin(), v.end(), x) - v.begin() + 1; } void update(int l, int r, int &x, int y, int pos) { T[++cnt] = T[y]; T[cnt].sum ++; x = cnt; if (l >= r) { return ; } int mid = (l + r) / 2; if (mid >= pos) { update(l, mid, T[x].l, T[y].l, pos); } else { update(mid + 1, r, T[x].r, T[y].r, pos); } } void dfs(int u, int fa, int n) { pre[u] = fa; dep[u] = dep[fa] + 1; f[u][0] = fa; for (int i = 1; i < 20; i++) { f[u][i] = f[f[u][i - 1]][i - 1]; } update(1, n, root[u], root[fa], getid(a[u])); for (int i = 0; i < G[u].size(); i++) { if (G[u][i] == fa) continue; dfs(G[u][i], u, n); } } int lca(int x, int y) { if (dep[x] < dep[y]) { swap(x, y); } for (int i = 19; i >= 0; i--) { if (dep[x] - (1 << i) >= dep[y]) { x = f[x][i]; } } if (x == y) { return x; } for (int i = 19; i >= 0; i--) { if (f[x][i] != f[y][i]) { x = f[x][i]; y = f[y][i]; } } return f[x][0]; } int query(int l, int r, int x, int y, int z, int fz, int k) { if (l == r) { return l; } int mid = (l + r) / 2; int sum = T[T[x].l].sum + T[T[y].l].sum - T[T[z].l].sum - T[T[fz].l].sum; if (sum >= k) { return query(l, mid, T[x].l, T[y].l, T[z].l, T[fz].l, k); } else { return query(mid + 1, r, T[x].r, T[y].r, T[z].r, T[fz].r, k - sum); } } int main() { while (~scanf("%d%d", &n, &m)) { init(); for (int i = 1; i <= n; i++) { scanf("%d", &a[i]); v.push_back(a[i]); } sort(v.begin(), v.end()); v.erase(unique(v.begin(), v.end()), v.end()); int new_n = v.size(); // 其实直接用 n 也可以 for (int i = 1; i < n; i++) { int u, v; scanf("%d%d", &u, &v); G[u].push_back(v); G[v].push_back(u); } dfs(1, 0, new_n); for (int i = 1; i <= m; i++) { int x, y, k; scanf("%d%d%d", &x, &y, &k); int fa = lca(x, y); int ans = query(1, new_n, root[x], root[y], root[fa], root[pre[fa]], k) - 1; printf("%d\n", v[ans]); } } return 0; }

转载于:https://www.cnblogs.com/Decray/p/10927790.html


最新回复(0)