天天看点

[Nowcoder 2018ACM多校第一场H] Longest Path

题目大意:

给你一棵n个节点的树, 带有边权c[i], 定义路径{ e1,e2…ek e 1 , e 2 … e k }的费用是 (e1−e2)2+(e2−e3)2+⋯+(ek−1−ek)2 ( e 1 − e 2 ) 2 + ( e 2 − e 3 ) 2 + ⋯ + ( e k − 1 − e k ) 2 。 求每个节点距自身最远点的距离。 (n≤105,ci≤105,∑n≤106) ( n ≤ 10 5 , c i ≤ 10 5 , ∑ n ≤ 10 6 )

题目思路:

类似于求树的直径做树形dp, 先选定1为根, 用f[i]表示向下走的答案, g[i]表示向上走的答案。 由于费用的更新需要用到两条边, 故扩展一下用f_ch[u][v]表示从u往下走第一步到v的答案, v是u的孩子, 这样复杂度还是O(n)的解决f。

然后考虑向上走的情况g。 考虑已经求出了g[u], 现在要用u来求出他的所有孩子g[v], 对于一个点v来说,

g[v]=max(g[u]+(e(u,fa[u])−e(u,v))2,maxx!=v,x∈son[u]{f_ch[u][x]+(e(v,u)−e(u,x))2}) g [ v ] = max ( g [ u ] + ( e ( u , f a [ u ] ) − e ( u , v ) ) 2 , max x ! = v , x ∈ s o n [ u ] { f _ c h [ u ] [ x ] + ( e ( v , u ) − e ( u , x ) ) 2 } )

对于第二个max是个经典的dp斜率优化的问题, 将e(u, v)排序后, 维护上凸包+单调队列, 正着做一遍反着做一遍即可。

PS: 关于dp斜率优化

考虑dp: f[i]=max{f[j]+(e[i]−e[j])2} f [ i ] = max { f [ j ] + ( e [ i ] − e [ j ] ) 2 }

对与某个转移j, 将式子移项, 分离变量, 只和i有关的部分、 只和j有关的部分、 和i,j均有关的部分。

(f[i]−e[i]2)=(f[j]+e[j]2)−(2∗e[i])∗e[j] ( f [ i ] − e [ i ] 2 ) = ( f [ j ] + e [ j ] 2 ) − ( 2 ∗ e [ i ] ) ∗ e [ j ]

将 f[i]−e[i]2 f [ i ] − e [ i ] 2 看作截距b, f[j]+e[j]2 f [ j ] + e [ j ] 2 看作y, 2∗e[i] 2 ∗ e [ i ] 看作斜率k, e[j] e [ j ] 看作x。

上式可以看作线性函数b = y - kx。

每个j对应一个坐标(x,y), 一系列的j在图上就是一些点, 对于一个i就是一个询问, 每个i对应一个斜率k, 每个i求一个斜率为k的经过图中某个点的最大截距。

这里是取最大值故维护上凸包(取min则维护下凸包), 在本题中, 考虑将e[i]从小到打排序, 先正过来求一遍, 即每个i都会考虑一遍小于它的j。 维护一个单调队列, 对于询问i, 由于询问的斜率是递增的, 按上凸包顺时针方向看, 相邻点构成的斜率递减, 询问i的取最大值的点满足其向下一个点斜率小于询问i的斜率,向上一个点的斜率大于询问i的斜率, 又考虑到询问i的斜率是递增的来询问的, 凸包上的点也是按x坐标递增来加入的, 故应从单调队列的尾端扫描, 根据斜率的比较关系, 斜率越大的询问取最大值的点越靠前, 如果队尾的上一个点由于队尾, 说明对于一个更大的斜率也会优于队尾的, 故弹出队尾元素。 再将i对应的坐标点加入凸包中, 可以用向量叉积判断凸包走向来决定是否删除队尾元素 。 反向求一遍同理。

Code:

#include <map>
#include <set>
#include <map>
#include <bitset>
#include <cmath>
#include <queue>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <algorithm>

#define ll long long
#define db double
#define pw(x) ((x) * (x))
#define fi first
#define se second
#define mp(x, y) make_pair(x, y)

using namespace std;

const int N = (int) + ;

int n;
int cnt, lst[N], nxt[N * ], to[N * ]; ll c[N * ], pre[N];
map<int, ll> f_ch[N];
map<int, ll> :: iterator it;
ll f[N], g[N];

void add(int u, int v, int w){
    nxt[++ cnt] = lst[u]; lst[u] = cnt; to[cnt] = v; c[cnt] = w;
    nxt[++ cnt] = lst[v]; lst[v] = cnt; to[cnt] = u; c[cnt] = w;
}

void dfs(int u, int fa){
    for (int j = lst[u]; j; j = nxt[j]){
        int v = to[j];
        if (v == fa) continue;
        pre[v] = c[j];
        dfs(v, u);

        ll &x = f_ch[u][v];
        x = ;
        for (it = f_ch[v].begin(); it != f_ch[v].end(); it ++){
            x = max(x, it->se + pw(c[j] - pre[it->fi]));
            f[u] = max(f[u], x);
        }

    }
}

pair <ll, int > tmp[N]; int sz;
pair <ll, ll > que[N]; int head, tail;

pair <ll, ll> operator-(pair<ll, ll> a, pair<ll, ll>  b){
    return mp(a.fi-b.fi, a.se-b.se);
}
ll operator*(pair<ll, ll> a, pair<ll, ll>  b){
    return a.fi * b.se - a.se * b.fi;
}


ll cross(pair<ll, ll> a, pair<ll, ll> b, pair<ll, ll> c){
    return (a - b) * (b - c);
}

ll count(pair<ll, ll > x, ll e){
    return x.se-*e*x.fi+pw(e);
}

void dfs2(int u, int fa){
    sz = ;
    for (int j = lst[u]; j; j = nxt[j]){
        int v = to[j];
        if (v == fa) continue;
        tmp[++ sz] = mp(pre[v], v);
    }

    sort(tmp + , tmp + sz + );
    que[head = tail = ] = mp(tmp[].fi, pw(tmp[].fi) + f_ch[u][tmp[].se]);

    for (int i = ; i <= sz; i ++){
        int v = tmp[i].se; ll e = tmp[i].fi;
        while (head < tail && count(que[tail], e) <= count(que[tail - ], e)) tail --;
        g[v] = max(g[v], count(que[tail], e));

        pair<ll, ll> p = mp(e, f_ch[u][v] + pw(e));
        while (head < tail && cross(p, que[tail], que[tail - ]) <= )
            tail --;
        que[++ tail] = p;
    }

    que[head = tail = ] = mp(tmp[sz].fi, pw(tmp[sz].fi) + f_ch[u][tmp[sz].se]);
    for (int i = sz - ; i >= ; i --){
        int v = tmp[i].se; ll e = tmp[i].fi;
        while (head < tail && count(que[tail], e) <= count(que[tail - ], e)) tail --;
        g[v] = max(g[v], count(que[tail], e));

        pair<ll, ll> p = mp(e, f_ch[u][v] + pw(e));
        while (head < tail && cross(p, que[tail], que[tail - ]) >= )
            tail --;
        que[++ tail] = p;
    }

    if (fa){
        for (int i = ; i <= sz; i ++){
            int v = tmp[i].se; ll e = tmp[i].fi;
            g[v] = max(g[v], g[u] + pw(pre[u] - e));
        }
    }

    for (int j = lst[u]; j; j = nxt[j]){
        int v = to[j];
        if (v == fa) continue;
        dfs2(v, u);
    }

}

int getint(){
    int ret = ; char c = getchar();
    while (c > '9' || c < '0') c = getchar();
    while (c <= '9' && c >= '0'){
        ret = ret *  + c - '0';
        c = getchar();
    }
    return ret;
}

int main(){
    while (scanf("%d", &n) != EOF){
        for (int i = , u, v, w; i <= n; i ++){
            u = getint(), v = getint(), w = getint();

            add(u, v, w);
        }

        dfs(, );

        dfs2(, );

        for (int i = ; i <= n; i ++){
            printf("%lld\n", max(f[i], g[i]));
            lst[i] = ;
            f[i] = g[i] = ;
            f_ch[i].clear();
        }
        cnt = ;
    }

    return ;
}