天天看點

zoj 3649 lca 倍增 DP

下課後直接奔實驗室,月賽都快結束了,我擦,随便挑了一道題開始搞起來,是我喜歡的圖論,心裡暗暗欣喜,可是還是有點繞人,比賽結束剛好敲完,吃晚飯後,調了調,AC,呵呵

比賽的時候A的貌似很少,其實也不太難

連結:http://acm.zju.edu.cn/onlinejudge/showProblem.do?problemCode=3649

下面是我做過的倍增法求LCA以及倍增的DP的兩個練習題

http://blog.csdn.net/haha593572013/article/details/7796497

http://blog.csdn.net/haha593572013/article/details/7855282

題意:最大生成樹的那個就不說了,兩個算法的疊加,沒啥意思,抽象一下就是給你一棵樹,每個點都有點權,然後有很多詢問

每個詢問是兩個數x   y,然後要你求出最大內插補點,最大內插補點是這樣的:在這棵樹上從x走到y會得到一個點權的序列

 c1, c2, c3, ... ,ci

 find the maximum ck-cj (ck >= cj, j <= k). 

關鍵的難點在于一定要用後面的數減去前面的數

解法:

開始的時候一直在糾結最大減最小,還要最大的在最小的後面,如果沒有後面這個限制,那就直接樹鍊剖分或者倍增都可以解決,後面的限制是本題的亮點,仔細想想,應該會需要用到求lca,求某個點到lca點權的最大值或者最小值,然後再仔細一想,想在log(n)的時間内求出u到lca的“最大內插補點”,除了資料結構還能有什麼(資料結構顯然有點無力),可能是我太弱,實在想不出用什麼資料結構可以搞定,然後我就轉向倍增DP的思路,還是和上面第二個連結的DP類似,隻不過這道題需要一系列的DP數組

p[u][i]表示u的2的i次個祖先

mx[u][i]表示u到u的2的i次個祖先之間的最大點權值

mi[u][i]表示u到u的2的i次個祖先之間的最小點權值

dp[u][i]表示u到u的2的i次個祖先之間的最大內插補點

dp2[u][i]表示u的2的i次個祖先到u之間的最大內插補點

之是以需要dp2是因為從x到lca再從lca到y的路線是兩個相反的方向

求mx mi數組很簡單,和求p數組一模一樣

dp數組的求法也不難

路徑利用二進制被分成了一段一段,除了每一段的資訊可以更新目前狀态,段與段之間也需要考慮

然後在求答案的時候也是類似的方法

具體見代碼吧,應該能看懂的,都很簡單- -

#include<cstdio>
#include<cstring>
#include<vector>
#include<algorithm>
using namespace std;
const int inf = ~0u>>2;
const int maxn = 30010;
const int POW = 16;
int mi[maxn][POW],mx[maxn][POW],p[maxn][POW];
int f[maxn];
int find(int x) {return x==f[x] ? x : f[x]=find(f[x]);}
struct EDGE{
	int s,t,w;
}e[50010];
int cmp(EDGE x,EDGE y){
	return x.w>y.w;
}
vector<int> edge[maxn];
int n,m;
int val[maxn];
bool vis[maxn];
int d[maxn];

int dp[maxn][POW],dp2[maxn][POW];
inline int max(int a,int b) {
	return  a>b?a:b;
}
inline int min(int a,int b) {
	return a<b?a:b;
}
void dfs(int u,int f){  
	d[u]=d[f]+1;
	vis[u]=true;
    int sz=edge[u].size(),j;  
    for(int i=0;i<sz;i++){  
        int v=edge[u][i];  
		if(vis[v])  continue;
        p[v][0]=u; 
		mi[v][0]=min(val[v],val[u]);
        mx[v][0]=max(val[v],val[u]);  
		dp[v][0]=val[u]-val[v];
		dp2[v][0]=val[v]-val[u];
        for(j=1;j<POW;j++) {
			p[v][j]=p[p[v][j-1]][j-1]; 
			mx[v][j]=max(mx[v][j-1],mx[p[v][j-1]][j-1]);  
			mi[v][j]=min(mi[v][j-1],mi[p[v][j-1]][j-1]);

			dp[v][j]=max(dp[v][j-1],dp[p[v][j-1]][j-1]);
			dp[v][j]=max(dp[v][j]  ,mx[p[v][j-1]][j-1]-mi[v][j-1]) ;

			dp2[v][j]=max(dp2[v][j-1],dp2[p[v][j-1]][j-1]);
			dp2[v][j]=max(dp2[v][j]  ,mx[v][j-1] - mi[p[v][j-1]][j-1]) ;
		}
        dfs(v,u);  
    }  
}  
int LCA( int a, int b ){  
	int i;
    if( d[a] > d[b] ) a ^= b, b ^= a, a ^= b;  
    if( d[a] < d[b] ){  
        int del = d[b] - d[a];  
        for(i = 0; i < POW; i++ ) if(del&(1<<i)) b=p[b][i];  
    }  
    if( a != b ){  
        for(i = POW-1; i >= 0; i-- )   
            if( p[a][i] != p[b][i] )   
                 a = p[a][i] , b = p[b][i];  
        a = p[a][0], b = p[b][0];  
    }  
    return a;  
}  
void init(int n)  {
	 memset(vis,false,sizeof(vis));
	 fill(p[0],p[n+1],0);
	 fill(mx[0],mx[n+1],-inf);
	 fill(mi[0],mi[n+1],inf);
	 fill(dp[0],dp[n+1],-inf);
	 fill(dp2[0],dp2[n+1],-inf);
	 d[0]=0;
     dfs(1,0);
}
int getmin(int u,int lca,int dp[][POW]) {
	int ans=inf;
    int del=d[u] - d[lca];
    for(int i=POW-1;i>=0;i--) if(del & (1<<i)) {
		ans=min(ans,dp[u][i]);
		u=p[u][i];
	}
    return ans;
}
int getmax(int x,int lca,int dp[][POW]){
    int ans=0;
	int del=d[x]-d[lca];
	for(int i=POW-1;i>=0;i--) if(del & (1<<i)){
		ans=max(ans,dp[x][i]);
		x=p[x][i];
	}
	return ans;
}
int gao1(int x,int lca,int dp[][POW]) {
	int ans=0,tmp=0;
	int del=d[x]-d[lca];
	for(int i=POW-1;i>=0;i--) if(del & (1<<i)) {
		ans=max(ans,dp[x][i]);
		ans=max(ans,tmp-mi[x][i]);
		tmp=max(tmp,mx[x][i]);
		x=p[x][i];
	}
	return ans;
}
int gao2(int x,int lca,int dp[][POW]) {
		int ans=0,tmp=inf;
	int del=d[x]-d[lca];
	for(int i=POW-1;i>=0;i--) if(del & (1<<i)) {
		ans=max(ans,dp[x][i]);
		ans=max(ans,-(tmp-mx[x][i]));
		tmp=min(tmp,mi[x][i]);
		x=p[x][i];
	}
	return ans;
}
void solve(int x,int y) {
     int lca=LCA(x,y);
	 int a,b,c,d;
	 a=gao2(x,lca,dp);
	 b=gao1(y,lca,dp2);
	 c=getmax(y,lca,mx);
	 d=getmin(x,lca,mi);
	 int ans=max(max(a,b),c-d);
	 printf("%d\n",ans);
}
int main() {
	int i,j,k,x,y,q;
	while(scanf("%d",&n)!=EOF){
		for(i=1;i<=n;i++) scanf("%d",&val[i]),f[i]=i,edge[i].clear();
		scanf("%d",&m);
		for(i=0;i<m;i++)
			scanf("%d%d%d",&e[i].s,&e[i].t,&e[i].w);
		sort(e,e+m,cmp);
		int sum=0;
		for(i=0;i<m;i++)	{
             x=find(e[i].s);
			 y=find(e[i].t);
             if(x!=y) {
				 edge[e[i].s].push_back(e[i].t);
				 edge[e[i].t].push_back(e[i].s);
				 f[x]=y;
				 sum+=e[i].w;
			 }
		}
		printf("%d\n",sum);
		init(n);
		scanf("%d",&q);
		while(q--)	{
			scanf("%d%d",&x,&y);
            solve(x,y);
		}
	}
	return 0;
}
           

繼續閱讀