给一个串,问这个串里所有本质不同的回文子串,有多少对满足一个串是另一个的子串。
这题现场过的人很少啊,题解也给了个蛮复杂我还没看懂的带log的做法,其实了解回文树的话特别好想,我们现场写了一个O(n)的做法(在牛客跑了72ms)。回文树还算是个新东西,还没有被玩坏,我以前刷的回文树套题基本都算是板子题,最近多校有几道回文树就进入了灵活运用的范畴了,出题人开始准备玩坏这个算法了,以后这都是基操。
见本质不同,想回文自动机,然后就开始梳理在回文树上,怎样的两个节点表示的回文串有子串关系?
首先是next边,很显然,一条next链上,除奇根偶根,所有父亲都是儿子的子串。 其次是fail边,又又又很显然,一条fail链上,父亲是儿子的后缀,所以除奇根偶根所有父亲都是儿子的子串。
好了这题就快做完了。 对于每个节点,答案就是其next链上除根外的父亲个数+fail链上除根外的父亲个数。 然而仅仅这样还是存在问题,因为next链和fail链有时候存在交叉,也就是说一个回文串既是另一个串的后缀,又在那个串中间出现,比如样例的aaaa,next链上的aa与fail链的aa是同一个节点。
于是还要想办法去个重。在dfs并且对每个节点跳fail树统计的时候,打个vis标记(当前节点也要标记,next链和fail链重的时候就可能会跳到这里),如果下次跳的时候发现了vis,就可以停止了,这样就简单的保证了不交叉,并且防止了fail树重复跳带来的复杂度。
最后,对于跑出来的东西做一个很简单的dp:当前节点贡献 = 父亲贡献 + fail链贡献 + 1。
ac代码:
#include <bits/stdc++.h> using namespace std; typedef long long ll; const int maxn = 1e5 + 5; int _, kase = 1; char s[maxn]; ll ans = 0; struct Pam { int next[maxn][26]; int fail[maxn]; int len[maxn];// 当前节点表示回文串的长度 int S[maxn]; int dp[maxn]; bool vis[maxn]; int last, n, p; int newNode(int l) { memset(next[p], 0, sizeof(next[p])); len[p] = l; dp[p] = 0; return p++; } void init() { ans = 0; n = last = p = 0; newNode(0); newNode(-1); S[n] = -1; fail[0] = 1; } int getFail(int x) { while (S[n - len[x] - 1] != S[n]) { x = fail[x]; } return x; } void add(int c) { S[++n] = c; int cur = getFail(last); if (!next[cur][c]) { int now = newNode(len[cur] + 2); fail[now] = next[getFail(fail[cur])][c]; next[cur][c] = now; } last = next[cur][c]; } int jump(int x) { int cnt = 0; vis[x] = 1; while (fail[x] != 0 && fail[x] != 1 && !vis[fail[x]]) { x = fail[x]; vis[x] = 1, ++cnt; } return cnt; } void clearJump(int x, int cnt) { vis[x] = 0; while (cnt--) { x = fail[x]; vis[x] = 0; } } void dfs(int x, int fa) { int jp = jump(x); dp[x] = jp; if (x != 1 && x != 0 && fa != 0 && fa != 1) { dp[x] = dp[fa] + jp + 1; } ans += dp[x]; for (int i = 0; i < 26; ++i) { if (next[x][i]) { dfs(next[x][i], x); } } clearJump(x, jp); } void build() { init(); for (int i = 1; s[i]; i++) { add(s[i] - 'a'); } } } pam; int main() { scanf("%d", &_); while (_--) { scanf("%s", s + 1); pam.build(); printf("Case #%d: ", kase++); pam.dfs(1, 1); // printf("%lld\n", ans); pam.dfs(0, 0); printf("%lld\n", ans); } return 0; }