天天看点

[LOJ]#2553. 「CTSC2018」暴力写挂 边分治+线段树合并

Solution

这题搞了好久……不过还是挺有收获的。

听说这种多棵树的题大概都是这样的套路?枚举第二棵树的 L C A ( x , y ) LCA(x,y) LCA(x,y),然后化一下式子可以发现: d e p t h x + d e p t h y − d e p t h L C A ( x , y ) = 1 2 ( d e p t h x + d e p t h y + d i s t a n c e x , y ) depth_x+depth_y-depth_{LCA(x,y)}={1\over 2}(depth_x+depth_y+distance_{x,y}) depthx​+depthy​−depthLCA(x,y)​=21​(depthx​+depthy​+distancex,y​)

这样问题就转化为在当前枚举点的两棵不同子树中各找一个点,求在第一棵树上深度和+距离的最大值。这样枚举的时候每次把一棵子树合并过来,一般用可以合并的数据结构,可以 d s u dsu dsu,但这样是 n l o g 2 n nlog^2n nlog2n的,无法通过。这时候又想到另外一个东西——线段树合并。但如何用线段树维护这些信息呢?有一个与线段树结构十分相似的东西,边分树,同样是二叉树的结构,不仅可以维护这些信息,而且可以方便地合并。

边分树的每个叶子节点代表原树的结点,而其他节点代表的是一条边,那么对于每个非叶子节点维护左右儿子的 d e p t h x + depth_x+ depthx​+到该边距离的最大值即可。

由于边分治的复杂度跟点的度数有关,所以要把原树转成二叉树,这样点数仍然是 O ( n ) O(n) O(n)级别。

我的代码全是自己YY的,所以又长又丑……

Code

#include<bits/stdc++.h>
using namespace std;
#define LL long long
#define pa pair<int,int>
const int Maxn=2200000;
const LL inf=(1LL<<60);
int read()
{
 int x=0,f=1;char ch=getchar();
 while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
 while(ch>='0'&&ch<='9')x=(x<<3)+(x<<1)+(ch^48),ch=getchar();
 return x*f;
}
int n,ls[Maxn],rs[Maxn],fa[Maxn],fir[Maxn],g[Maxn][20],gg=0,dep[Maxn],Log[Maxn];LL dis[Maxn];
int root[Maxn],lc[Maxn*20],rc[Maxn*20],cnt=0;LL lx[Maxn*20],rx[Maxn*20],ans=-inf;
struct Edge{int x,y,d,next;};
struct Graph
{
 Edge e[Maxn<<1];
 int last[Maxn],len;
 void ins(int x,int y,int d)
 {
  int t=++len;
  e[t].x=x;e[t].y=y;e[t].d=d;
  e[t].next=last[x];last[x]=t;
 }
 void addedge(int x,int y,int d){ins(x,y,d),ins(y,x,d);}
}G1,G2,G3;
int tmp[Maxn],lt,id,V[Maxn];
int build(int l,int r)
{
 if(l==r)return tmp[l];
 int x=++id,mid=l+r>>1,lc=build(l,mid),rc=build(mid+1,r);
 G2.addedge(x,lc,(lc<=n)?V[lc]:0),G2.addedge(x,rc,(rc<=n)?V[rc]:0);
 return x;
} 
void dfs1(int x,int ff)
{
 lt=0;
 for(int i=G1.last[x];i;i=G1.e[i].next)
 {
  int y=G1.e[i].y;
  if(y==ff)continue;
  tmp[++lt]=y;V[y]=G1.e[i].d;
 }
 if(!lt)return;
 int rt=build(1,lt);
 if(rt>n)G2.addedge(x,rt,0);else G2.addedge(x,rt,V[rt]);
 for(int i=G1.last[x];i;i=G1.e[i].next)
 {
  int y=G1.e[i].y;
  if(y==ff)continue;
  dfs1(y,x);
 }
} 
int LCA(int x,int y)
{
 if(fir[x]>fir[y])swap(x,y);
 int t=Log[fir[y]-fir[x]+1];
 int p1=g[fir[x]][t],p2=g[fir[y]-(1<<t)+1][t];
 if(dep[p1]<dep[p2])return p1;
 return p2;
}
LL Dis(int x,int y){return dis[x]+dis[y]-dis[LCA(x,y)]*2LL;}
int pp,pd,sz[Maxn],tot,mn,rt;bool ban[Maxn];
int get_size(int x,int ff)
{
 int re=1;
 for(int i=G2.last[x];i;i=G2.e[i].next)
 {
  int y=G2.e[i].y;
  if(y==ff||ban[i])continue;
  re+=get_size(y,x);
 }
 return re;
}
void get_edge(int x,int ff)
{
 sz[x]=1;
 for(int i=G2.last[x];i;i=G2.e[i].next)
 {
  int y=G2.e[i].y;
  if(y==ff||ban[i])continue;
  get_edge(y,x);
  int t=max(sz[y],tot-sz[y]);
  if(t<mn)mn=t,pp=i;
  sz[x]+=sz[y];
 }
}
void dfs2(int x,int y,int z)
{
 mn=2147483647;tot=get_size(x,0);get_edge(x,0);
 if(sz[x]==1)ls[z]=x,fa[x]=z;
 else
 {
  ls[z]=id+pp;fa[id+pp]=z;ban[pp]=ban[pp^1]=true;
  dfs2(G2.e[pp].x,G2.e[pp].y,id+pp);
 }
 mn=2147483647;tot=get_size(y,0);get_edge(y,0);
 if(sz[y]==1)rs[z]=y,fa[y]=z;
 else
 {
  rs[z]=id+pp;fa[id+pp]=z;ban[pp]=ban[pp^1]=true;
  dfs2(G2.e[pp].x,G2.e[pp].y,id+pp);
 }
}
void dfs4(int x,int ff)
{
 fir[x]=++gg;g[gg][0]=x;dep[x]=dep[ff]+1;
 for(int i=G2.last[x];i;i=G2.e[i].next)
 {
  int y=G2.e[i].y;
  if(y==ff)continue;
  dis[y]=dis[x]+G2.e[i].d;
  dfs4(y,x);g[++gg][0]=x;
 }
}
void merge(int &u1,int u2,int ID,LL dd)
{
 if(!u1){u1=u2;return;}
 if(!u2)return;
 ans=max(ans,G2.e[ID-id].d+lx[u1]+rx[u2]-2LL*dd);
 ans=max(ans,G2.e[ID-id].d+lx[u2]+rx[u1]-2LL*dd);
 lx[u1]=max(lx[u1],lx[u2]);
 rx[u1]=max(rx[u1],rx[u2]);
 merge(lc[u1],lc[u2],ls[ID],dd),merge(rc[u1],rc[u2],rs[ID],dd);
}
void dfs3(int x,int ff,LL dd)
{
 ans=max(ans,2LL*dis[x]-2LL*dd);
 for(int i=G3.last[x];i;i=G3.e[i].next)
 {
  int y=G3.e[i].y;
  if(y==ff)continue;
  dfs3(y,x,dd+G3.e[i].d);
  merge(root[x],root[y],rt,dd);
 }
}
int main()
{
 G1.len=0,G2.len=1,G3.len=0;
 n=read();id=n;
 for(int i=1;i<n;i++)
 {
  int x=read(),y=read(),d=read();
  G1.addedge(x,y,d);
 }
 dfs1(1,0);
 dep[0]=-1;dfs4(1,0);
 Log[1]=0;for(int i=2;i<=gg;i++)Log[i]=Log[i>>1]+1; 
 for(int j=1;(1<<j)<=gg;j++)
 for(int i=1;i+(1<<j)-1<=gg;i++)
 {
  int p1=g[i][j-1],p2=g[i+(1<<(j-1))][j-1];
  if(dep[p1]<dep[p2])g[i][j]=p1;
  else g[i][j]=p2;
 }
 tot=id;mn=2147483647;get_edge(1,0);
 rt=id+pp;ban[pp]=ban[pp^1]=true; 
 dfs2(G2.e[pp].x,G2.e[pp].y,rt);
 for(int i=1;i<=n;i++)
 {
  int l1=-1,l2,x=i;
  while(1)
  {
   int t=++cnt; 
   if(l1!=-1)
   {
    if(ls[x]==l1)lc[t]=l2,lx[t]=Dis(G2.e[x-id].x,i)+dis[i],rx[t]=-inf;
    else rc[t]=l2,rx[t]=Dis(G2.e[x-id].y,i)+dis[i],lx[t]=-inf;
   }
   else lx[t]=rx[t]=-inf;
   l1=x,l2=t;
   if(x==rt){root[i]=t;break;}
   x=fa[x];
  }
 }
 for(int i=1;i<n;i++)
 {
  int x=read(),y=read(),d=read();
  G3.ins(x,y,d),G3.ins(y,x,d);
 }
 dfs3(1,0,0);
 printf("%lld",ans/2LL); 
}