天天看點

【JZOJ5058】【GDSOI2017模拟4.13】采蘑菇

Description

A君住在魔法森林裡,魔法森林可以看做一棵n個結點的樹,結點從1~n編号。樹中的每個結點上都生長着蘑菇。蘑菇有許多不同的種類,但同一個結點上的蘑菇都是同一種類,更具體地,i号結點上生長着種類為c[i]的蘑菇。

現在A君打算出去采蘑菇,但他并不知道哪裡的蘑菇更好,是以他標明起點s後會等機率随機選擇樹中的某個結點t作為終點,之後從s沿着(s,t)間的最短路徑走到t.并且A君會采摘途中所經過的所有結點上的蘑菇。

現在A君想知道,對于每一個結點u,假如他從這個結點出發,他最後能采摘到的蘑菇種類數的期望是多少。為了友善,你告訴A君答案*n的值即可。

Data Constraint

30%的資料:n <= 2000

另有20%的資料:給出的第i條邊為{i,i+1}

另有20%的資料:蘑菇的種類最多3種

100%的資料:1 <= n <= 3*10^5 , 0 <= c[i] <= n

Solution

這又是一道樹分治的題。但我用線段樹加換根打過了。我們維護顔色棵線段樹維護區間[l,r]中每個點到目前根的路徑上有這個顔色的點的數量。再換根時,我們先撤銷根的父親原來對根的子樹的影響,在加上目前根對全圖的影響,統計答案變化值+父親作為根即可,時間複雜度O(NlogN)。

Code

#include<iostream>
#include<cmath>
#include<cstring>
#include<cstdio>
#include<algorithm>
using namespace std;
const int maxn=+,maxn1=+;
struct code{
    int l,r,bz,num;
}f[maxn*];
int first[maxn],last[maxn],next[maxn],a[maxn1],d[maxn1],b[maxn1],v[maxn1];
int n,i,t,j,l,x,y,z,num,num1,dfn[maxn1],size[maxn1],fa[maxn1];
long long k,ans[maxn1];
void lian(int x,int y){
    last[++num]=y;next[num]=first[x];first[x]=num;
}
void insert(int &v,int l,int r,int x,int y,int z){
    int mid=(l+r)/;
    if (x>y) return;
    if (!v) v=++num1,f[v].num=;
    if (l!=r){
        if (!f[v].l) f[v].l=++num1;
        if (!f[v].r) f[v].r=++num1;
    }
    if (l>=x && r<=y){
        if (l== && r==n) k=-f[v].num;
        f[v].bz+=z;
        if (f[v].bz>) f[v].num=r-l+;
        else if (l!=r)f[v].num=f[f[v].l].num+f[f[v].r].num;
        else f[v].num=;
        if (l== && r==n) k+=f[v].num;
        return;
    }
    if (l<=y && mid>=x) insert(f[v].l,l,mid,x,y,z);
    if (mid<y && r>=x) insert(f[v].r,mid+,r,x,y,z);
    if (l== && r==n) k-=f[v].num;
    if (f[v].bz>) f[v].num=r-l+;
    else f[v].num=f[f[v].l].num+f[f[v].r].num;
    if (l== && r==n) k+=f[v].num;
}
void bfs(){
    int i=,j=,x,t;v[]=;fa[]=;
    while(i<j){
        x=v[++i];size[x]=;
        for (t=first[x];t;t=next[t]){
            if (fa[last[t]])continue;
            v[++j]=last[t];fa[v[j]]=x;
        }
    }
    fa[]=;
    for (j=n;j>=;j--)
        size[fa[v[j]]]+=size[v[j]];
    dfn[]=;
    for (i=;i<=n;i++){
        x=v[i];
        k=;
        for (t=first[x];t;t=next[t])
            if (last[t]!=fa[x])dfn[last[t]]=dfn[x]+k,k+=size[last[t]];
    }
    for (j=n;j>=;j--)
        x=v[j],insert(d[a[x]],,n,dfn[x],dfn[x]+size[x]-,);
    for (x=;x<=n;x++)
        v[dfn[x]]=x;
}
int main(){
    freopen("mushroom.in","r",stdin);freopen("mushroom.out","w",stdout);
    scanf("%d",&n);
    for (i=;i<=n;i++)
        scanf("%d",&a[i]);
    for (i=;i<n;i++)
        scanf("%d%d",&x,&y),lian(x,y),lian(y,x);
    num=;
    bfs();
    for (i=;i<=n;i++)
        k=f[d[i]].num,ans[]+=k;
    for (i=;i<=n;i++){
        x=v[i];y=fa[x];b[++b[]]=x;
        if (y){
            k=;
            insert(d[a[y]],,n,dfn[y],dfn[y]+size[y]-,-);
            insert(d[a[y]],,n,,dfn[x]-,);insert(d[a[y]],,n,dfn[x]+size[x],n,);
            insert(d[a[x]],,n,,dfn[x]-,);insert(d[a[x]],,n,dfn[x]+size[x],n,);
            ans[x]=ans[y]+k;
        }
        while (dfn[b[b[]]]+size[b[b[]]]-==dfn[v[i]]){
            x=b[b[]--];y=fa[x];
            if (y){
                insert(d[a[x]],,n,,dfn[x]-,-);insert(d[a[x]],,n,dfn[x]+size[x],n,-);
                insert(d[a[y]],,n,,dfn[x]-,-);insert(d[a[y]],,n,dfn[x]+size[x],n,-);
                insert(d[a[y]],,n,dfn[y],dfn[y]+size[y]-,);
            }
        }
    }
    for (i=;i<=n;i++)
        printf("%lld\n",ans[i]);
}
           

順便附上我打的樹分治

#include<iostream>
#include<cstring>
#include<cstdio>
#include<cmath>
#include<algorithm>
#define ll long long
using namespace std;
const int maxn=e5+;
ll aw[maxn],an[maxn],size[maxn],ans[maxn],mx[maxn];
int first[maxn],last[maxn],next[maxn],a[maxn];
int n,i,t,j,k,l,x,y,z,num,p,bz[maxn],bz1[maxn],cnt[maxn],q,sum;
void lian(int x,int y){
    last[++num]=y;next[num]=first[x];first[x]=num;
}
void dg1(int x,int y){int t;size[x]=;mx[x]=;aw[a[x]]=;
    for (t=first[x];t;t=next[t])
        if (last[t]!=y && !bz[last[t]])dg1(last[t],x),size[x]+=size[last[t]],mx[x]=max(mx[x],size[last[t]]);
}
int find(int x,int y){
    if (max(p-size[x],mx[x])*2<=p) return x;int t;
    for (t=first[x];t;t=next[t]){
        if (last[t]==y || bz[last[t]]) continue;
        k=find(last[t],x);
        if (k) return k;
    }return ;
}
void dg2(int x,int y,ll sum){int t;
    if (bz1[a[x]]!=p) bz1[a[x]]=p,cnt[a[x]]=,sum++,aw[a[x]]+=size[x];
    cnt[a[x]]++;an[x]=sum;
    for (t=first[x];t;t=next[t]){
        if (last[t]==y || bz[last[t]]) continue;
        dg2(last[t],x,sum);
        an[x]+=an[last[t]];
    }
    cnt[a[x]]--;if (!cnt[a[x]]) bz1[a[x]]=-;
}
void dg3(int x,int y,ll c){int t;
    if (bz1[a[x]]!=p) bz1[a[x]]=p,cnt[a[x]]=,num++,c+=aw[a[x]];
    cnt[a[x]]++;ans[x]+=z*(num*(q-size[x])+(sum-an[x])-(c-size[x]*num));
    for (t=first[x];t;t=next[t]){
        if (last[t]==y || bz[last[t]]) continue;
        dg3(last[t],x,c);
    }
    cnt[a[x]]--;if (!cnt[a[x]]) bz1[a[x]]=-,num--;
}
void dg(int x){
    int t;bz[x]=;p=x;
    dg1(x,);
    dg2(x,,);
    ans[x]+=an[x];cnt[a[x]]=,num=;z=;
    for (t=first[x];t;t=next[t]){
        if (bz[last[t]]) continue;
        q=size[x];p=last[t];sum=an[x];bz1[a[x]]=p;
        dg3(last[t],x,aw[a[x]]);
    }
    z=-;
    for (t=first[x];t;t=next[t]){
        if (bz[last[t]]) continue;p=last[t];
        dg1(last[t],);bz1[a[x]]=p,cnt[a[x]]=,aw[a[x]]=size[last[t]];
        dg2(last[t],,);
        q=size[last[t]];sum=an[last[t]];num=;
        dg3(last[t],x,);
    }
    bz1[a[x]]=-;
    for (t=first[x];t;t=next[t])
        if (!bz[last[t]])p=size[last[t]],k=find(last[t],),dg(k);
}
int main(){
    freopen("mushroom.in","r",stdin);freopen("mushroom.out","w",stdout);
    scanf("%d",&n);
    for (i=;i<=n;i++)
        scanf("%d",&a[i]);
    for (i=;i<n;i++)
        scanf("%d%d",&x,&y),lian(x,y),lian(y,x);
    dg(n/);
    for (i=;i<=n;i++)
        printf("%lld\n",ans[i]);
}