【LOJ6681】yww 与树上的回文串(点分治)(AC自动机)(字符串哈希)(回文串broder理论)

it2022-05-05  146

传送门


社论(题解):

首先长剖重剖都考虑过了,并没有办法支持快速合并,边分更不用说了,权值在边上怎么边分怎么蛋疼。

考虑点分,我们知道如果一个回文串过了重心,他要么就是重心延伸出去的回文前缀,要么它被重心分成两段,短的一定是长的后缀。

很显然我们考虑用AC自动机求出这种后缀关系。

那么现在问题变成了,当前分治重心到这个点的串(设长度为 l l l)如果存在长度为 d d d的回文前缀,我们需要求它在fail树上长度为 l − d l-d ld的祖先出现了多少次。

窒息的是,回文前缀可能有一大堆。

但是根据broder理论我们知道,这些回文前缀可以被表示为不超过 log ⁡ \log log个等差数列。

假设我们有某一个公差为 d d d,首项为 a 0 a_0 a0,末项为 a t a_t at的等差数列,那么相当于我们需要求fail树上,长度 % d \%d %d a 0 a_0 a0相同,且长度在 l − a 0 l-a_0 la0 l − a t l-a_t lat之间的祖先个数。

对于 d ≤ n d\leq \sqrt n dn 的我们可以考虑用一个桶来记录,也就是记录 % i = j \%i=j %i=j的长度在祖先上出现了多少次,对于公差大的,直接暴力询问就行了。

但是,实际上由于在计算不同重心的时候本身回文串的切割会非常鬼畜,所以实际上我们只对公差小于等于 4 4 4的记录就能够达到最优效率了。

回文前缀直接用哈希求就行了。


代码:

#include<bits/stdc++.h> #define ll long long #define re register #define gc get_char #define cs const namespace IO{ inline char get_char(){ static cs int Rlen=1<<22|1; static char buf[Rlen],*p1,*p2; return (p1==p2)&&(p2=(p1=buf)+fread(buf,1,Rlen,stdin),p1==p2)?EOF:*p1++; } template<typename T> inline T get(){ char c; while(!isdigit(c=gc()));T num=c^48; while(isdigit(c=gc()))num=(num+(num<<2)<<1)+(c^48); return num; } inline int getint(){return get<int>();} } using namespace IO; using std::cerr; using std::cout; cs int N=50004; cs int SqrtN=4; int n; cs ll mod=1e9+7; cs ll B=47; ll bas[N]; inline void init_bas(){ bas[0]=1; for(int re i=1;i<N;++i)bas[i]=bas[i-1]*B%mod; } namespace AC{ int son[N][2],fail[N],now; int len[N],cnt[N]; inline int newnode(int x){ ++now; len[now]=x; return now; } std::vector<int> G[N]; inline void build_fail(){ std::queue<int> q; for(int re i=0;i<2;++i)if(son[1][i])q.push(son[1][i]),fail[son[1][i]]=1; else son[1][i]=1; while(!q.empty()){ int u=q.front();q.pop(); for(int re i=0;i<2;++i) if(son[u][i])fail[son[u][i]]=son[fail[u]][i],q.push(son[u][i]); else son[u][i]=son[fail[u]][i]; } for(int re i=2;i<=now;++i)G[fail[i]].push_back(i); } struct prefix{ int l,r,d; prefix(){} prefix(int _l,int _r,int _d):l(_l),r(_r),d(_d){} }; std::vector<prefix> vec[N]; void dfs_pal(int u,ll h1,ll h2,cs std::vector<prefix> &cur){ vec[u]=cur; if(len[u]>0&&h1==h2){ if(cur.empty())vec[u].push_back(prefix(len[u],len[u],len[u])); else { auto &p=vec[u].back(); if(p.d==len[u]-p.r)p.r=len[u]; else vec[u].push_back(prefix(len[u],len[u],len[u]-p.r)); } } for(int re i=0;i<2;++i)if(son[u][i]) dfs_pal(son[u][i],(h1*B+i)%mod,(h2+bas[len[u]]*i)%mod,vec[u]); } int ans[N]; int st[N],top; inline int find(int x){ int l=1,r=top,mid; while(l<=r)len[st[mid=l+r>>1]]<=x?l=mid+1:r=mid-1; return st[r]; } struct Query{ int a,b,id,tag; Query(int _a,int _b,int _id,int _tag):a(_a),b(_b),id(_id),tag(_tag){} }; std::vector<Query> q[N]; int c[SqrtN+3][SqrtN+3]; int d[N]; void dfs(int u){ for(int re i=1;i<=SqrtN;++i)c[i][len[u]%i]+=cnt[u]; d[len[u]]+=cnt[u]; st[++top]=u; for(cs auto &p:vec[u]){ if(p.d<=SqrtN){ int x=find(len[u]-p.r-1); int y=find(len[u]-p.l); int k=(len[u]-p.l)%p.d; if(x!=y){ if(x>0)q[x].push_back(Query(p.d,k,u,-1)); if(y>0)q[y].push_back(Query(p.d,k,u,1)); } } else { for(int a=p.l;a<=p.r;a+=p.d) q[u].push_back(Query(0,len[u]-a,u,+1)); } } for(int re v:G[u])dfs(v); for(auto &q:(AC::q[u])){ if(q.a)ans[q.id]+=q.tag*c[q.a][q.b]; else ans[q.id]+=q.tag*d[q.b]; } for(int re i=1;i<=SqrtN;++i)c[i][len[u]%i]-=cnt[u]; d[len[u]]-=cnt[u]; st[top--]=0; } ll solve(){ dfs(1);ll res=0; for(int re i=2;i<=now;++i)res+=(ll)cnt[i]*ans[i]+(ll)cnt[i]*(cnt[i]-1)/2; return res; } ll calc(){ dfs_pal(1,0,0,std::vector<prefix>(0)); build_fail(); return solve(); } inline void clear(){ while(now){ son[now][0]=son[now][1]=fail[now]=len[now]=cnt[now]=ans[now]=0; vec[now].clear(); G[now].clear(); q[now].clear(); --now; } newnode(0); } } namespace TDC{ struct edge{ int to,w; }; std::vector<edge> G[N]; inline void addedge(int u,int v,int w){ G[u].push_back((edge){v,w}); G[v].push_back((edge){u,w}); } bool ban[N]; int siz[N]; int tot,mx,g; ll ans; void get_siz(int u,int fa){ siz[u]=1; for(edge &e:G[u])if(!ban[e.to]&&e.to!=fa)get_siz(e.to,u),siz[u]+=siz[e.to]; } void find_G(int u,int fa){ int mx_u=tot-siz[u]; for(edge &e:G[u])if(!ban[e.to]&&e.to!=fa)find_G(e.to,u),mx_u=std::max(mx_u,siz[e.to]); if(mx_u<mx)mx=mx_u,g=u; } int get_G(int u){ get_siz(u,0); tot=siz[u]; g=-1,mx=0x3f3f3f3f; find_G(u,0); assert(~g); return g; } void dfs(int u,int p,int nd){ AC::cnt[nd]++; for(edge &e:G[u])if(e.to!=p&&!ban[e.to]){ if(!AC::son[nd][e.w])AC::son[nd][e.w]=AC::newnode(AC::len[nd]+1); dfs(e.to,u,AC::son[nd][e.w]); } } inline ll calc(int u,int w=-1){ AC::clear(); if(~w){ AC::son[1][w]=AC::newnode(1); dfs(u,0,AC::son[1][w]); } else dfs(u,0,1); return AC::calc(); } inline void solve_G(int u){ ban[u]=true; ans+=calc(u); for(edge &e:G[u])if(!ban[e.to]){ ans-=calc(e.to,e.w); int t=get_G(e.to); solve_G(t); } } inline void solve(){ int u=get_G(1); solve_G(u); cout<<ans<<"\n"; } } signed main(){ init_bas(); // freopen("pal.in","r",stdin);freopen("pal.out","w",stdout); n=getint(); for(int re i=1;i<n;++i){ int u=getint(),v=getint(),w=getint(); TDC::addedge(u,v,w); } TDC::solve(); return 0; }

最新回复(0)