HDU 6589 Sequence【NTT】

it2022-05-09  36

题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=6589

AC代码:

#include<bits/stdc++.h>

typedef long long ll;

// NTT-friendly prime (998244353 = 119 * 2^23 + 1) and its primitive root.
const ll mod = 998244353;
const ll G = 3;

/// Computes base^exp modulo `mod` by binary exponentiation (exp >= 0).
static ll power(ll base, ll exp)
{
    ll result = 1;
    base %= mod;
    while (exp > 0) {
        if (exp & 1) result = result * base % mod;
        base = base * base % mod;
        exp >>= 1;
    }
    return result;
}

/// In-place iterative number-theoretic transform over Z/mod.
///
/// @param a        coefficients, all in [0, mod); a.size() must be a power of two
/// @param inverse  false for the forward transform, true for the inverse
///                 transform (which also divides by the length)
static void ntt(std::vector<ll>& a, bool inverse)
{
    const int len = static_cast<int>(a.size());

    // Bit-reversal permutation so the butterflies can run bottom-up.
    for (int i = 1, j = 0; i < len; ++i) {
        int bit = len >> 1;
        for (; j & bit; bit >>= 1) j ^= bit;
        j ^= bit;
        if (i < j) std::swap(a[i], a[j]);
    }

    for (int half = 1; half < len; half <<= 1) {
        const ll step = (mod - 1) / (2 * half);
        // The inverse transform uses the inverse root, i.e. exponent (mod-1) - step.
        const ll wn = power(G, inverse ? (mod - 1) - step : step);
        for (int start = 0; start < len; start += 2 * half) {
            ll w = 1;
            for (int k = 0; k < half; ++k) {
                const ll x = a[start + k];
                const ll y = w * a[start + k + half] % mod;
                a[start + k] = (x + y) % mod;
                a[start + k + half] = (x - y + mod) % mod;
                w = w * wn % mod;
            }
        }
    }

    if (inverse) {
        const ll inv_len = power(len, mod - 2);
        for (int i = 0; i < len; ++i) a[i] = a[i] * inv_len % mod;
    }
}

/// Returns the first `keep` coefficients of the product lhs * rhs modulo `mod`.
/// Both inputs must be non-empty with entries in [0, mod).
static std::vector<ll> multiply_prefix(const std::vector<ll>& lhs,
                                       const std::vector<ll>& rhs,
                                       int keep)
{
    const std::size_t need = lhs.size() + rhs.size() - 1;
    std::size_t len = 1;
    while (len < need) len <<= 1;

    std::vector<ll> fa(lhs.begin(), lhs.end());
    std::vector<ll> fb(rhs.begin(), rhs.end());
    fa.resize(len, 0);
    fb.resize(len, 0);

    ntt(fa, false);
    ntt(fb, false);
    for (std::size_t i = 0; i < len; ++i) fa[i] = fa[i] * fb[i] % mod;
    ntt(fa, true);

    fa.resize(keep);
    return fa;
}

/// Applies the operation a[i] += a[i - stride] (for increasing i) `times` times
/// to the 1-indexed array a[1..n], modulo `mod`.
///
/// Repeating a stride-`stride` prefix sum `times` times multiplies each residue
/// class of indices by the series sum_j C(times - 1 + j, j) x^j, so each class
/// is handled with a single truncated convolution.
///
/// @param a          values in [0, mod), index 0 unused, size n + 1
/// @param n          number of elements
/// @param stride     distance between coupled indices
/// @param times      how many times the operation is applied
/// @param inv_table  inv_table[j] = j^{-1} mod `mod` for 1 <= j <= n
static void apply_operation(std::vector<ll>& a, int n, int stride, ll times,
                            const std::vector<ll>& inv_table)
{
    if (times == 0 || n <= 0) return;  // zero applications leave a[] untouched

    // c_0 = 1, c_j = c_{j-1} * (times - 1 + j) / j. Only as many terms as the
    // longest residue class are needed, so no factorial table is required.
    const int max_len = (n + stride - 1) / stride;
    std::vector<ll> coeff(max_len);
    coeff[0] = 1;
    for (int j = 1; j < max_len; ++j) {
        coeff[j] = coeff[j - 1] * ((times - 1 + j) % mod) % mod * inv_table[j] % mod;
    }

    for (int first = 1; first <= stride && first <= n; ++first) {
        std::vector<ll> sub;
        for (int i = first; i <= n; i += stride) sub.push_back(a[i]);

        const int sub_len = static_cast<int>(sub.size());
        const std::vector<ll> kernel(coeff.begin(), coeff.begin() + sub_len);
        const std::vector<ll> prod = multiply_prefix(sub, kernel, sub_len);

        for (int idx = 0, i = first; idx < sub_len; ++idx, i += stride) a[i] = prod[idx];
    }
}

/// Reads T test cases. Each case gives n, m, the n array values and m operation
/// types (1, 2 or 3 = stride of the prefix-sum operation). Prints, per case,
/// the XOR over i of i * a[i] after all operations, with a[i] taken modulo `mod`.
int main()
{
    int t;
    if (scanf("%d", &t) != 1) return 0;

    while (t-- > 0) {
        int n, m;
        if (scanf("%d %d", &n, &m) != 2 || n < 0 || m < 0) break;

        std::vector<ll> a(n + 1, 0);
        for (int i = 1; i <= n; ++i) {
            // %lld matches the long long element type.
            if (scanf("%lld", &a[i]) != 1) return 0;
            a[i] = (a[i] % mod + mod) % mod;
        }

        // The operations commute, so only the count of each type matters.
        ll cnt[4] = {0, 0, 0, 0};
        for (int i = 0; i < m; ++i) {
            int type;
            if (scanf("%d", &type) != 1) return 0;
            if (type >= 1 && type <= 3) ++cnt[type];
        }

        // Linear-time modular inverses of 1..n.
        std::vector<ll> inv_table(n + 2, 0);
        inv_table[1] = 1;
        for (int i = 2; i <= n; ++i) {
            inv_table[i] = (mod - mod / i) * inv_table[mod % i] % mod;
        }

        for (int stride = 1; stride <= 3; ++stride) {
            apply_operation(a, n, stride, cnt[stride], inv_table);
        }

        ll ans = 0;
        for (int i = 1; i <= n; ++i) ans ^= static_cast<ll>(i) * a[i];
        printf("%lld\n", ans);
    }
    return 0;
}

/*
Sample input:
2
5 2
2 4 2 1 1
1 1
5 2
3 2 2 4 1
2 2

Sample output:
233
121
*/

参考:https://www.cnblogs.com/xusirui/p/11229450.html


最新回复(0)