贴一下以前写的代码。比赛前我准备着重看的内容：主席树、树形 DP、字符串。
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <utility>
#include <vector>

// Rooted-tree order-statistics solution (contest code).
//
// Input, repeated until EOF: N, then A[1..N], then for i = 2..N the parent of
// node i. Node 1 is the root.
// Output per test case: the sum over all nodes x of val[x][0], plus the value
// returned by dfs2(1) (see below).
//
// NOTE(review): assumes N < MAXN and 1 <= A[i] <= MAXV. The input bounds are
// not visible here -- confirm them against the problem statement.

using ll = long long;

const int MAXN = 1e5 + 5;
// Upper end of the value domain [1, MAXV]. Query() also returns MAXV when the
// requested order statistic does not exist.
const int MAXV = 100000;

int N;
int A[MAXN];
std::vector<int> children[MAXN];  // children[p] = nodes whose parent is p
// val[x][0]: the ((size+1)/2)-th smallest value in x's subtree (lower median).
// val[x][1]: the next order statistic, or MAXV when x's subtree has one node.
int val[MAXN][2];

/********** Mergeable value-domain segment tree **********/
// The original header called this a "President (persistent) tree". However,
// Merge() updates the nodes of its first argument in place, so the trees are
// NOT persistent. Once a child's tree has been merged into its parent's, the
// child's tree must not be queried again.
struct Node {
    int ls, rs, cc;  // left child, right child, count of values in this range
};
// Each Build() over [1, MAXV] allocates at most 18 nodes (tree height 17), so
// MAXN * 20 nodes suffice for N < MAXN. Node 0 is the shared empty node
// (cc == 0) and is never written.
Node tree[MAXN * 20];
int tot;             // next free node index; reset to 1 per test case
int seg_root[MAXN];  // seg_root[x] = tree holding the values of x's subtree

// Allocates a fresh root-to-leaf chain for the single value `pos` over [l, r]
// and returns its root.
int Build(int pos, int l, int r) {
    const int rt = tot++;
    tree[rt].cc = 1;
    tree[rt].ls = tree[rt].rs = 0;
    if (l == r) return rt;
    const int mid = (l + r) >> 1;
    if (pos <= mid) tree[rt].ls = Build(pos, l, mid);
    else tree[rt].rs = Build(pos, mid + 1, r);
    return rt;
}

// Merges tree y into tree x (both covering [l, r]) by reusing x's nodes, and
// returns the merged root. If either tree is empty (0), returns the other one.
int Merge(int x, int y, int l, int r) {
    if (x == 0 || y == 0) return x + y;
    tree[x].cc += tree[y].cc;
    if (l == r) {
        tree[x].ls = tree[x].rs = 0;
        return x;
    }
    const int mid = (l + r) >> 1;
    tree[x].ls = Merge(tree[x].ls, tree[y].ls, l, mid);
    tree[x].rs = Merge(tree[x].rs, tree[y].rs, mid + 1, r);
    return x;
}

// Returns the K-th smallest value (1-based) stored in the tree rooted at rt,
// which covers [l, r]. Returns MAXV if the tree holds fewer than K values.
int Query(int rt, int K, int l, int r) {
    if (tree[rt].cc < K) return MAXV;
    if (l == r) return l;
    const int mid = (l + r) >> 1;
    const int left_count = tree[tree[rt].ls].cc;
    if (left_count >= K) return Query(tree[rt].ls, K, l, mid);
    return Query(tree[rt].rs, K - left_count, mid + 1, r);
}

// Computes val[x][0..1] for every node reachable from `root`.
// The traversal is iterative: reversing the BFS order processes every child
// before its parent. This prevents a path-shaped tree of ~1e5 nodes from
// overflowing the call stack, which the recursive version could do on judges
// with small default stacks. Children are merged in the same order as before.
void dfs(int root) {
    std::vector<int> order(1, root);
    for (std::size_t i = 0; i < order.size(); ++i) {
        // Ranges over children[...], not `order`, so push_back is safe here.
        for (int y : children[order[i]]) order.push_back(y);
    }
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        const int x = *it;
        seg_root[x] = Build(A[x], 1, MAXV);
        for (int y : children[x]) {
            seg_root[x] = Merge(seg_root[x], seg_root[y], 1, MAXV);
        }
        const int half = (tree[seg_root[x]].cc + 1) >> 1;
        val[x][0] = Query(seg_root[x], half, 1, MAXV);
        val[x][1] = Query(seg_root[x], half + 1, 1, MAXV);
    }
}

/********** Suffix-sum Fenwick tree (BIT) **********/
// Add() walks downward and Sum() walks upward, so Sum(p) returns the total of
// all amounts added at positions >= p. Positions must lie in [1, MAXV];
// Sum(0) would never terminate.
ll bitt[MAXN];

void Add(int pos, int num) {
    for (int i = pos; i > 0; i -= i & -i) bitt[i] += num;
}

ll Sum(int pos) {
    ll ans = 0;
    for (int i = pos; i <= MAXV; i += i & -i) ans += bitt[i];
    return ans;
}

// Returns the maximum, over nodes x reachable from `root`, of Sum(val[x][0])
// evaluated while the BIT holds exactly the path root..x (inclusive). That is
// the sum of (val[y][1] - val[y][0]) over nodes y on the path with
// val[y][0] >= val[x][0]. The BIT is left exactly as it was found.
// Uses an explicit (node, next-child-index) stack instead of recursion so that
// deep trees cannot overflow the call stack.
ll dfs2(int root) {
    Add(val[root][0], val[root][1] - val[root][0]);
    ll best = Sum(val[root][0]);
    std::vector<std::pair<int, std::size_t>> stk;
    stk.emplace_back(root, 0);
    while (!stk.empty()) {
        const int x = stk.back().first;
        const std::size_t k = stk.back().second;
        if (k < children[x].size()) {
            stk.back().second = k + 1;
            const int y = children[x][k];
            // Entering y: its contribution is active for y and its descendants.
            Add(val[y][0], val[y][1] - val[y][0]);
            best = std::max(best, Sum(val[y][0]));
            stk.emplace_back(y, 0);
        } else {
            // Leaving x: undo its contribution.
            Add(val[x][0], val[x][0] - val[x][1]);
            stk.pop_back();
        }
    }
    return best;
}

int main() {
    // `== 1` rather than `~scanf(...)`: a return of 0 (malformed input) would
    // otherwise loop forever.
    while (std::scanf("%d", &N) == 1) {
        std::memset(bitt, 0, sizeof(bitt));
        tot = 1;
        for (int i = 1; i <= N; ++i) std::scanf("%d", &A[i]);
        for (int i = 1; i <= N; ++i) children[i].clear();
        for (int i = 2; i <= N; ++i) {
            int a;
            std::scanf("%d", &a);
            children[a].push_back(i);
        }
        dfs(1);
        ll sum = 0;
        for (int i = 1; i <= N; ++i) sum += val[i][0];
        std::printf("%lld\n", sum + dfs2(1));
    }
    return 0;
}

// Reposted from: https://www.cnblogs.com/Basasuya/p/8433740.html