天天看点

题解 CF1073G Yet Another LCP Problem

题解 CF1073G Yet Another LCP Problem

题目链接

题意描述

定义\(LCP(S_i,S_j)\)为字符串\(S\)的后缀\(i\)和后缀\(j\)的最长公共前缀长度。

给定一个长为\(n\)的字符串\(S\),\(q\)次询问,每次询问给出两个正整数集合\(A,B\),大小分别为\(k,l\),求\(\sum\limits_{i\in A,j\in B}LCP(S_i,S_j)\)

Solution:

看到单字符串和\(LCP\),马上想到后缀树,用\(SAM\)建出后缀树后,\(LCP(S_i,S_j)\)就是后缀树上的\(Dep_{LCA(S_i,S_j)}\)。

现在看到数据范围:\(\sum\limits_{i=1}^{q}k_i,\sum\limits_{i=1}^{q}l_i\leq2\times10^5\),直接明示虚树。考虑怎样\(DP\),有一个经典的树上统计每条边的经过次数问题,答案是\(\sum\limits_{(u,v)\in E}Size_v\times(n-Size_v)\)。现在转化成\(LCA\)的数量,考虑容斥,设\(Size_{u,0/1}\)表示u的子树中在\(A/B\)集合中的点数,一个节点子树内的答案就是每个儿子\(v\)中\(A\)集合内的点数乘其它儿子中\(B\)集合内的点数,再加上这个节点上的\(A\)集合内的点数乘子树中\(B\)集合内的点数,最后乘上深度。即

\[\sum\limits_{u=1}^{n}[(\sum\limits_{v\in Son(u)}Size_{v,0}\times\sum\limits_{w\in Son(u),w\not= v}Size_{w,1})+(Size_{u,0}-\sum\limits_{v\in Son(u)}Size_{v,0})\times Size_{u,1}]\times Dep_u

\]

代码如下:

#include<bits/stdc++.h>
#define int long long
#define INF 0x3f3f3f3f
#define N 400005
#define ls k<<1
#define rs k<<1|1
#define mid ((l+r)>>1)
#define mp make_pair
#define pb push_back
#define fi first
#define se second
#define pii pair<int,int>
#define il inline
#define file(x) freopen(x".in","r",stdin);freopen(x".out","w",stdout);
using namespace std;
il int read(){
	int w=0,h=1;char ch=getchar();
	while(ch<'0'||ch>'9'){if(ch=='-')h=-h;ch=getchar();}
	while(ch>='0'&&ch<='9'){w=w*10+ch-'0';ch=getchar();}
	return w*h;
}
struct Edge{
	int next,to;
}edge[N<<2];
char S[N];
int n,q,rev[N];
int head[N],num;
void add(int u,int v){
	edge[++num].next=head[u];
	edge[num].to=v;head[u]=num;
}
namespace SAM{
	struct Node{
		int len,link,ch[26];
	}t[N];
	int tot=1,cur=1;
	void extend(char c){
		int p=c-'a',u=++tot;
		t[u].len=t[cur].len+1;
		swap(u,cur);
		while(u&&!t[u].ch[p])t[u].ch[p]=cur,u=t[u].link;
		if(!u)return t[cur].link=1,void();
		int v=t[u].ch[p];
		if(t[v].len==t[u].len+1)return t[cur].link=v,void();
		int w=++tot;
		t[w].len=t[u].len+1;
		t[w].link=t[v].link;
		t[v].link=t[cur].link=w;
		for(int i=0;i<26;i++)t[w].ch[i]=t[v].ch[i];
		while(u&&t[u].ch[p]==v)t[u].ch[p]=w,u=t[u].link;
	}
	void build(){
		int u=1;
		for(int i=n;i>=1;i--){
			u=t[u].ch[S[i]-'a'];
			rev[i]=u;
		}
		for(int i=2;i<=tot;i++)add(t[i].link,i);
	}
}
namespace Tree{
	struct Node{
		int dep,fa,sz,son,seg,top,bot;
	}t[N];
	int dfn;
	void dfs1(int u,int fa){
		t[u].dep=t[fa].dep+1;
		t[u].fa=fa;t[u].sz=1;
		for(int i=head[u];i;i=edge[i].next){
			int v=edge[i].to;
			if(v==fa)continue;
			dfs1(v,u);
			t[u].sz+=t[v].sz;
			if(t[v].sz>t[t[u].son].sz)t[u].son=v;
		}
	}
	void dfs2(int u,int topf){
		t[u].top=topf;
		t[u].seg=++dfn;
		if(t[u].son)dfs2(t[u].son,topf);
		for(int i=head[u];i;i=edge[i].next){
			int v=edge[i].to;
			if(v==t[u].fa||v==t[u].son)continue;
			dfs2(v,v);
		}
		t[u].bot=dfn;
	}
	int LCA(int x,int y){
		int fx=t[x].top,fy=t[y].top;
		while(fx!=fy){
			if(t[fx].dep<t[fy].dep)swap(fx,fy),swap(x,y);
			x=t[fx].fa;fx=t[x].top;
		}
		if(t[x].dep>t[y].dep)swap(x,y);
		return x;
	}
}
namespace VirTree{
	vector<int>vir,G[N];
	int stk[N],top;
	int sum[N][2];
	int ans;
	bool cmp(int x,int y){return Tree::t[x].seg<Tree::t[y].seg;}
	void build(){
		sort(vir.begin(),vir.end(),cmp);
		for(int i=vir.size()-2;~i;i--)vir.pb(Tree::LCA(vir[i],vir[i+1]));
		sort(vir.begin(),vir.end(),cmp);
		vir.erase(unique(vir.begin(),vir.end()),vir.end());
		top=0;
		for(auto u:vir){
			while(top&&Tree::t[stk[top]].bot<Tree::t[u].seg)top--;
			if(top)G[stk[top]].pb(u);
			stk[++top]=u;
		}
	}
	void clear(){
		for(auto u:vir)G[u].clear(),sum[u][0]=sum[u][1]=0;
		vir.clear();ans=0;
	}
	void dfs(int u){
		int tmp=sum[u][0];
		for(auto v:G[u]){
			dfs(v);
			sum[u][0]+=sum[v][0];
			sum[u][1]+=sum[v][1];
		}
		for(auto v:G[u])ans+=sum[v][0]*(sum[u][1]-sum[v][1])*SAM::t[u].len;
		ans+=tmp*sum[u][1]*SAM::t[u].len;
	}
}
signed main(){
	n=read();q=read();
	scanf("%s",S+1);
	for(int i=n;i>=1;i--)SAM::extend(S[i]);
	SAM::build();
	Tree::dfs1(1,0);
	Tree::dfs2(1,1);
	while(q--){
		int kl=read(),kr=read();
		for(int i=1,u;i<=kl;i++)
			VirTree::vir.pb(u=rev[read()]),VirTree::sum[u][0]++;
		for(int i=1,u;i<=kr;i++)
			VirTree::vir.pb(u=rev[read()]),VirTree::sum[u][1]++;
		VirTree::build();
		VirTree::dfs(VirTree::stk[1]);
		printf("%lld\n",VirTree::ans);
		VirTree::clear();
	}
	return 0;
}
           

还没完,这题还有另一个做法:

看到\(\sum\)和\(LCA\),马上想到[GXOI/GZOI2019]旧词的套路,将\(A\)集合中的点到根的路径区间\(+1\),然后区间查询\(B\)集合中的点到根的路径上权的和。

代码如下:

#include<bits/stdc++.h>
#define int long long
#define INF 0x3f3f3f3f
#define N 400005
#define ls k<<1
#define rs k<<1|1
#define mid ((l+r)>>1)
#define mp make_pair
#define pb push_back
#define fi first
#define se second
#define pii pair<int,int>
#define il inline
#define file(x) freopen(x".in","r",stdin);freopen(x".out","w",stdout);
using namespace std;
il int read(){
	int w=0,h=1;char ch=getchar();
	while(ch<'0'||ch>'9'){if(ch=='-')h=-h;ch=getchar();}
	while(ch>='0'&&ch<='9'){w=w*10+ch-'0';ch=getchar();}
	return w*h;
}
struct Edge{
	int next,to;
}edge[N<<2];
char S[N];
int n,q,rev[N],pos[N];
int kl,A[N],kr,B[N],ans;
int head[N],num;
void add(int u,int v){
	edge[++num].next=head[u];
	edge[num].to=v;head[u]=num;
}
namespace SAM{
	struct Node{
		int len,link,ch[26];
	}t[N];
	int tot=1,cur=1;
	void extend(char c){
		int p=c-'a',u=++tot;
		t[u].len=t[cur].len+1;
		swap(u,cur);
		while(u&&!t[u].ch[p])t[u].ch[p]=cur,u=t[u].link;
		if(!u)return t[cur].link=1,void();
		int v=t[u].ch[p];
		if(t[v].len==t[u].len+1)return t[cur].link=v,void();
		int w=++tot;
		t[w].len=t[u].len+1;
		t[w].link=t[v].link;
		t[v].link=t[cur].link=w;
		for(int i=0;i<26;i++)t[w].ch[i]=t[v].ch[i];
		while(u&&t[u].ch[p]==v)t[u].ch[p]=w,u=t[u].link;
	}
	void build(){
		int u=1;
		for(int i=n;i>=1;i--){
			u=t[u].ch[S[i]-'a'];
			rev[i]=u;
		}
		for(int i=2;i<=tot;i++)add(t[i].link,i);
	}
}
namespace SGT{
	int Sum[N<<2],Val[N<<2],Tag[N<<2],Clr[N<<2];
	void pushup(int k){Sum[k]=Sum[ls]+Sum[rs];}
	void pushdown(int k,int l,int r){
		if(Clr[k]){
			Sum[ls]=Sum[rs]=0;
			Tag[ls]=Tag[rs]=0;
			Clr[ls]=Clr[rs]=1;
			Clr[k]=0;
		}
		if(!Tag[k])return;
		Sum[ls]=Sum[ls]+Tag[k]*Val[ls];
		Sum[rs]=Sum[rs]+Tag[k]*Val[rs];
		Tag[ls]+=Tag[k];
		Tag[rs]+=Tag[k];
		Tag[k]=0;
	}
	void build(int k,int l,int r){
		if(l==r){
			Val[k]=pos[l];
			return;
		}
		build(ls,l,mid);
		build(rs,mid+1,r);
		Val[k]=Val[ls]+Val[rs];
	}
	void modify(int k,int l,int r,int x,int y){
		if(l>y||r<x||l>r||x>y)return;
		if(l>=x&&r<=y){
			Sum[k]+=Val[k];
			Tag[k]+=1;
			return;
		}
		pushdown(k,l,r);
		if(x<=mid)modify(ls,l,mid,x,y);
		if(mid<y)modify(rs,mid+1,r,x,y);
		pushup(k);
	}
	int query(int k,int l,int r,int x,int y){
		if(l>y||r<x||l>r||x>y)return 0;
		if(l>=x&&r<=y)return Sum[k];
		pushdown(k,l,r);
		if(y<=mid)return query(ls,l,mid,x,y);
		if(mid<x)return query(rs,mid+1,r,x,y);
		return query(ls,l,mid,x,y)+query(rs,mid+1,r,x,y);
	}
}
namespace Tree{
	struct Node{
		int dep,fa,sz,son,seg,top,bot;
	}t[N];
	int dfn;
	void dfs1(int u,int fa){
		t[u].dep=t[fa].dep+1;
		t[u].fa=fa;t[u].sz=1;
		for(int i=head[u];i;i=edge[i].next){
			int v=edge[i].to;
			if(v==fa)continue;
			dfs1(v,u);
			t[u].sz+=t[v].sz;
			if(t[v].sz>t[t[u].son].sz)t[u].son=v;
		}
	}
	void dfs2(int u,int topf){
		t[u].top=topf;
		t[u].seg=++dfn;
		pos[dfn]=SAM::t[u].len-SAM::t[t[u].fa].len;
		if(t[u].son)dfs2(t[u].son,topf);
		for(int i=head[u];i;i=edge[i].next){
			int v=edge[i].to;
			if(v==t[u].fa||v==t[u].son)continue;
			dfs2(v,v);
		}
		t[u].bot=dfn;
	}
	int LCA(int x,int y){
		int fx=t[x].top,fy=t[y].top;
		while(fx!=fy){
			if(t[fx].dep<t[fy].dep)swap(fx,fy),swap(x,y);
			x=t[fx].fa;fx=t[x].top;
		}
		if(t[x].dep>t[y].dep)swap(x,y);
		return x;
	}
	void ask_modify(int x,int y){
		int fx=t[x].top,fy=t[y].top;
		while(fx!=fy){
			if(t[fx].dep<t[fy].dep)swap(fx,fy),swap(x,y);
			SGT::modify(1,1,SAM::tot,t[fx].seg,t[x].seg);
			x=t[fx].fa;fx=t[x].top;
		}
		if(t[x].dep>t[y].dep)swap(x,y);
		SGT::modify(1,1,SAM::tot,t[x].seg,t[y].seg);
	}
	int ask_query(int x,int y){
		int fx=t[x].top,fy=t[y].top,res=0;
		while(fx!=fy){
			if(t[fx].dep<t[fy].dep)swap(fx,fy),swap(x,y);
			res+=SGT::query(1,1,SAM::tot,t[fx].seg,t[x].seg);
			x=t[fx].fa;fx=t[x].top;
		}
		if(t[x].dep>t[y].dep)swap(x,y);
		res+=SGT::query(1,1,SAM::tot,t[x].seg,t[y].seg);
		return res;
	}
}
void solve(){
	for(int i=1;i<=kl;i++)
		Tree::ask_modify(1,A[i]);
	for(int i=1;i<=kr;i++)
		ans+=Tree::ask_query(1,B[i]);
}
signed main(){
	n=read();q=read();
	scanf("%s",S+1);
	for(int i=n;i>=1;i--)SAM::extend(S[i]);
	SAM::build();
	Tree::dfs1(1,0);
	Tree::dfs2(1,1);
	SGT::build(1,1,SAM::tot);
	while(q--){
		kl=read(),kr=read();
		for(int i=1;i<=kl;i++)A[i]=rev[read()];
		for(int i=1;i<=kr;i++)B[i]=rev[read()];
		SGT::Clr[1]=1;
		SGT::Sum[1]=SGT::Tag[1]=ans=0;
		solve();
		printf("%lld\n",ans);
	}
	return 0;
}
           

一些细节: