codeforces 616F Expensive Strings (广义后缀自动机)

https://codeforces.com/contest/616/problem/F

题意:

给你n个字符串串,串SiS_i的权值为cic_i,对于一个字符串ss有函数

F(s)=i=1ncips,isF(s)=\sum_{i=1}^{n}c_i*p_{s,i}*|s|

其中ps,ip_{s,i}为字符串ss在串SiS_i中出现的次数。
让你构造一个ss使得F(s)F(s)最大,输出F(s)F(s)即可。

题解:

先考虑单独一个串SiS_i怎么做。
对这个串SiS_i建SAM,对于每一个节点stst,实际上它的贡献就是endpos(st)cimaxlen(st)|endpos(st)|*c_i*maxlen(st )
而对于多个串,其实是一样的。直接用广义后缀自动机做就好啦

代码:

#include<bits/stdc++.h>
using namespace std;
const int maxn=1e6+5;
const int mod=1e9+7;
#define pb push_back
#define fi first
#define se second
#define all(x) (x).begin(),(x).end()
#define rep(i,a,n) for (int i=a;i<=n;i++)
#define per(i,a,n) for (int i=n;i>=a;i--)
typedef long long ll;
typedef double db;
typedef vector<int> vi;
typedef pair<int,int> pii;
ll qpow(ll a,ll b){ll ans=1;a%=mod;assert(b>=0);for(;b;b>>=1){if(b&1)ans=ans*a%mod;a=a*a%mod;}return ans;}
ll gcd(ll a,ll b){return b>0?gcd(b,a%b):a;}
int n,m,T;
string s[maxn];
int c[maxn];
struct SAM{
	int nxt[maxn<<1][30];
	int link[maxn<<1];
	int step[maxn<<1];
	ll sum[maxn<<1];
	int sz,last,rt;
	void init(){
		sz=last=rt=1;
	}
	void add(int c){
		int p=last;
		int np=++sz;
		last=np;
		step[np]=step[p]+1;
		memset(nxt[np],0,sizeof nxt[np]);
		while(!nxt[p][c]&&p){
			nxt[p][c]=np;
			p=link[p];
		}
		if(p==0){
			link[np]=rt;
		}else{
			int q=nxt[p][c];
			if(step[q]==step[p]+1){
				link[np]=q;
			}else{
				int nq=++sz;
				memcpy(nxt[nq],nxt[q],sizeof nxt[q]);
				link[nq]=link[q];
				link[np]=link[q]=nq;
				step[nq]=step[p]+1;
				while(nxt[p][c]==q&&p){
					nxt[p][c]=nq;
					p=link[p];
				}
				
			} 
		}
	}
	int a[maxn<<1],b[maxn<<1];
	void build(){
		for(int i=1;i<=sz;i++) a[step[i]]++;
		for(int i=1;i<=sz;i++) a[i]+=a[i-1];
		for(int i=1;i<=sz;i++) b[a[step[i]]--]=i;
	}
	void solve(){
		for(int i=1,p;i<=n;i++){
			p=rt;
			for(auto ch:s[i]){
				p=nxt[p][ch-'a'];
				sum[p]+=c[i];
			}
		}
		for(int i=sz,u,fa;i>=1;i--){
			u=b[i],fa=link[b[i]];
			sum[fa]+=sum[u];
		}
		ll res=0;
		for(int i=1;i<=sz;i++){
			res=max(res,1ll*sum[i]*step[i]);
		}
		cout<<res<<endl; 
	}
}sam;
int main(){
	sam.init();
	cin>>n;
	for(int i=1;i<=n;i++){
		cin>>s[i];
		sam.last=1;
		for(auto c:s[i]){
			sam.add(c-'a');
		}
	}
	for(int i=1;i<=n;i++)cin>>c[i];
	sam.build();
	sam.solve();
//	cin>>n>>m;
	
	return 0;
}
/*
5
bbbb
baaa
bbba
aaba
bbaa
17 -17 -82 47 -85
*/


Q.E.D.