天天看点

CF204E Little Elephant and Strings(后缀数组+尺取,区间max)

LINK

题意

给出

n

n

n个串,求每个串有多少

[

l

,

r

]

[l,r]

[l,r]的区间

使得这段区间在这

n个串中的至少

k

k

k个串包含。

n个串穿起来,中间用不同的字符连接

对于串

s

s

s如何计算答案…

考虑对串

s的每个后缀计算一遍答案

相当于固定了左端点

l

l,要求最大的右端点

r

r

相当于找到

k个来自不同字符串的后缀(包括串s的后缀),使得它们的

c

p

lcp

lcp最大

那么串

s的答案加上

lcp

于是我们按照

a

sa

sa的顺序尺取一些恰好包括

k个不同串的区间

尺取的规则是对于每个

l,尺取一个

r满足

[l,r]刚好有

k个不同子串

区间最小值

lcp就是贡献,拿这个

lcp给这段区间每个点取

m

x

max

max

但是其他点怎么办??这段区间的贡献还可以往左往右扩展啊,只不过

lcp会变小

很可能前面尺取的区间

lcp,还不如现在这个区间的

lcp往左边扩展

或者甚至后面的某些点根本没有被长度为

k的包含进来

我们这样,让每个区间的贡献往右扩展,这样就不需要考虑往左边扩展

想一下一个区间的

lcp是

x

x,往右边扩展

lcp变成

i

(

h

e

g

t

)

min(lcp,height[i])

min(lcp,height[i])

所以我们最后做一个前缀最大值即可,也就是

m

x

[

i

]

=

a

(

,

n

1

h

e

g

t

)

mx[i]=max(mx[i],min(mx[i-1],height[i]))

mx[i]=max(mx[i],min(mx[i−1],height[i]))

因为和上一个后缀最长是

height[i]

height[i],所以需要取

min

min

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 1e6+10;
char a[maxn];
int s[maxn],x[maxn],y[maxn],c[maxn],sa[maxn],rk[maxn],height[maxn],n,m;
void get_sa(int n)
{
//	m = 30000;
	for(int i=1;i<=n;i++)	++c[x[i]=s[i]];
	for(int i=2;i<=m;i++)	c[i] += c[i-1];
	for(int i=n;i>=1;i--)	sa[c[x[i]]--] = i;
	for(int k=1;k<=n;k<<=1)
	{
		int num = 0;
		for(int i=n-k+1;i<=n;i++)	y[++num] = i;
		for(int i=1;i<=n;i++)	if( sa[i]>k )	y[++num] = sa[i]-k;
		for(int i=0;i<=m;i++)	c[i] = 0;
		for(int i=1;i<=n;i++)	++c[x[i]];
		for(int i=1;i<=m;i++)	c[i] += c[i-1]; 
		for(int i=n;i>=1;i--)	sa[c[x[y[i]]]--] = y[i], y[i] = 0;
		swap(x,y);
		x[sa[1]] = 1, num = 1;
		for(int i=2;i<=n;i++)
			x[sa[i]] = ( y[sa[i]]==y[sa[i-1]] && y[sa[i]+k]==y[sa[i-1]+k] )?num:++num; 
		if( num==n )	break;
		m = num;
	}
	for(int i=1;i<=n;i++)	rk[sa[i]] = i;
	for(int k=0,i=1;i<=n;i++)
	{
		if( rk[i]==1 )	continue;
		if( k )	k--;
		int j = sa[rk[i]-1];
		while( i+k<=n&&j+k<=n&&s[i+k]==s[j+k] )	k++;
		height[rk[i]] = k;
	}	
}
int vis[maxn],now,id[maxn],mx[maxn];//尺取 
void del(int l,int &now)
{
	vis[id[sa[l]]]--;
	if( vis[id[sa[l]]]==0 )	now--;
}
void add(int r,int &now)
{
	vis[id[sa[r]]]++;
	if( vis[id[sa[r]]]==1 )	now++;
}
int st[maxn][22],top,k;
void init()
{
	for(int i=1;i<=top;i++)	st[i][0] = height[i];
	for(int i=1;i<=20;i++)
	for(int j=1;j+(1<<i)-1<=top;j++)
		st[j][i] = min( st[j][i-1],st[j+(1<<(i-1))][i-1] );
}
int get(int l,int r)
{
	int k = log2(r-l+1);
	return min( st[l][k],st[r-(1<<k)+1][k] );
}
ll ans[maxn];
int main()
{
	cin >> n >> k;
    if(k==1)//特判k=1的情况
    {
    	for(long long i=1;i<=n;i++)
    	{
    		scanf("%s",a+1);
    		long long len=strlen(a+1);
    		printf("%lld ",(len*(len+1))>>1);
        }
        return 0;
    }
    m = n+200;
	for(int i=1;i<=n;i++)
	{
		scanf("%s",a+1); int len = strlen( a+1 );
		for(int j=1;j<=len;j++)	s[++top] = a[j]+n,id[top] = i;
		s[++top] = i;	
	}
	get_sa(top); init();
	int now = 0;
	for(int r=n,l=n+1;l<=top;l++)
	{
		if( l>n+1 )	del(l-1,now);
		while( r<l )	add(++r,now);
		while( r<top&&now<k )	add(++r,now);
		if( now<k )	break;
		//求区间最小值
		int lcp = get(l+1,r);
		//更新这段区间的每个点 
		for(int j=l;j<=r;j++)	mx[j] = max( mx[j],lcp );
	}
	for(int i=n+2;i<=top;i++)	mx[i] = max( mx[i],min( mx[i-1],height[i] ) );
	for(int i=n+1;i<=top;i++)	ans[id[sa[i]]] += mx[i];	
	for(int i=1;i<=n;i++)	cout << ans[i] << " ";
}
           

还是类似的做法

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 1e6+10;
char a[maxn];
int s[maxn],x[maxn],y[maxn],c[maxn],sa[maxn],rk[maxn],height[maxn],n,m;
void get_sa(int n)
{
	m = 300;
	for(int i=1;i<=n;i++)	++c[x[i]=s[i]];
	for(int i=2;i<=m;i++)	c[i] += c[i-1];
	for(int i=n;i>=1;i--)	sa[c[x[i]]--] = i;
	for(int k=1;k<=n;k<<=1)
	{
		int num = 0;
		for(int i=n-k+1;i<=n;i++)	y[++num] = i;
		for(int i=1;i<=n;i++)	if( sa[i]>k )	y[++num] = sa[i]-k;
		for(int i=1;i<=m;i++)	c[i] = 0;
		for(int i=1;i<=n;i++)	++c[x[i]];
		for(int i=2;i<=m;i++)	c[i] += c[i-1]; 
		for(int i=n;i>=1;i--)	sa[c[x[y[i]]]--] = y[i], y[i] = 0;
		swap(x,y);
		x[sa[1]] = 1, num = 1;
		for(int i=2;i<=n;i++)
			x[sa[i]] = ( y[sa[i]]==y[sa[i-1]] && y[sa[i]+k]==y[sa[i-1]+k] )?num:++num; 
		if( num==n )	break;
		m = num;
	}
	for(int i=1;i<=n;i++)	rk[sa[i]] = i;
	for(int k=0,i=1;i<=n;i++)
	{
		if( rk[i]==1 )	continue;
		if( k )	k--;
		int j = sa[rk[i]-1];
		while( i+k<=n&&j+k<=n&&s[i+k]==s[j+k] )	k++;
		height[rk[i]] = k;
	}	
}
int vis[maxn],now,id[maxn],mx[maxn];//尺取 
void del(int l,int &now)
{
	vis[id[sa[l]]]--;
	if( vis[id[sa[l]]]==0 )	now--;
}
void add(int r,int &now)
{
	vis[id[sa[r]]]++;
	if( vis[id[sa[r]]]==1 )	now++;
}
int st[maxn][22],top,k;
void init()
{
	for(int i=1;i<=top;i++)	st[i][0] = height[i];
	for(int i=1;i<=20;i++)
	for(int j=1;j+(1<<i)-1<=top;j++)
		st[j][i] = min( st[j][i-1],st[j+(1<<(i-1))][i-1] );
}
int get(int l,int r)
{
	int k = log2(r-l+1);
	return min( st[l][k],st[r-(1<<k)+1][k] );
}
ll ans[maxn];
int main()
{
	cin >> n >> k;
	for(int i=1;i<=n;i++)
	{
		scanf("%s",a+1); int len = strlen( a+1 );
		for(int j=1;j<=len;j++)	s[++top] = a[j]-'a'+n,id[top] = i;
		s[++top] = i;	
	}
	get_sa(top); init();
	int now = 0;
	for(int l=n+1,r=n+1;r<=top;r++)
	{
		add(r,now);//加入右端点
		while( now>=k )
		{
			if( now==k )
			{
				int lcp = get(l+1,r);
				cout << l << " " << r << " " << lcp << endl;
				for(int j=l;j<=r;j++)	mx[j] = max( mx[j],lcp );				
			}
			if( vis[id[sa[l]]]==1 )
			{
				if( now==k )	break;
				now--;
			}
			vis[id[sa[l]]]--; l++;
		} 
	}
//	for(int i=n+2;i<=top;i++)	mx[i] = max( mx[i],min( mx[i-1],height[i] ) );
	for(int i=n+1;i<=top;i++)	ans[id[sa[i]]] += mx[i];	
	for(int i=1;i<=n;i++)	cout << ans[i] << " ";
}