天天看點

[Nowcoder 2018ACM多校第一場H] Longest Path

題目大意:

給你一棵n個節點的樹, 帶有邊權c[i], 定義路徑{ e1,e2…ek e 1 , e 2 … e k }的費用是 (e1−e2)2+(e2−e3)2+⋯+(ek−1−ek)2 ( e 1 − e 2 ) 2 + ( e 2 − e 3 ) 2 + ⋯ + ( e k − 1 − e k ) 2 。 求每個節點距自身最遠點的距離。 (n≤105,ci≤105,∑n≤106) ( n ≤ 10 5 , c i ≤ 10 5 , ∑ n ≤ 10 6 )

題目思路:

類似于求樹的直徑做樹形dp, 先標明1為根, 用f[i]表示向下走的答案, g[i]表示向上走的答案。 由于費用的更新需要用到兩條邊, 故擴充一下用f_ch[u][v]表示從u往下走第一步到v的答案, v是u的孩子, 這樣複雜度還是O(n)的解決f。

然後考慮向上走的情況g。 考慮已經求出了g[u], 現在要用u來求出他的所有孩子g[v], 對于一個點v來說,

g[v]=max(g[u]+(e(u,fa[u])−e(u,v))2,maxx!=v,x∈son[u]{f_ch[u][x]+(e(v,u)−e(u,x))2}) g [ v ] = max ( g [ u ] + ( e ( u , f a [ u ] ) − e ( u , v ) ) 2 , max x ! = v , x ∈ s o n [ u ] { f _ c h [ u ] [ x ] + ( e ( v , u ) − e ( u , x ) ) 2 } )

對于第二個max是個經典的dp斜率優化的問題, 将e(u, v)排序後, 維護上凸包+單調隊列, 正着做一遍反着做一遍即可。

PS: 關于dp斜率優化

考慮dp: f[i]=max{f[j]+(e[i]−e[j])2} f [ i ] = max { f [ j ] + ( e [ i ] − e [ j ] ) 2 }

對與某個轉移j, 将式子移項, 分離變量, 隻和i有關的部分、 隻和j有關的部分、 和i,j均有關的部分。

(f[i]−e[i]2)=(f[j]+e[j]2)−(2∗e[i])∗e[j] ( f [ i ] − e [ i ] 2 ) = ( f [ j ] + e [ j ] 2 ) − ( 2 ∗ e [ i ] ) ∗ e [ j ]

将 f[i]−e[i]2 f [ i ] − e [ i ] 2 看作截距b, f[j]+e[j]2 f [ j ] + e [ j ] 2 看作y, 2∗e[i] 2 ∗ e [ i ] 看作斜率k, e[j] e [ j ] 看作x。

上式可以看作線性函數b = y - kx。

每個j對應一個坐标(x,y), 一系列的j在圖上就是一些點, 對于一個i就是一個詢問, 每個i對應一個斜率k, 每個i求一個斜率為k的經過圖中某個點的最大截距。

這裡是取最大值故維護上凸包(取min則維護下凸包), 在本題中, 考慮将e[i]從小到打排序, 先正過來求一遍, 即每個i都會考慮一遍小于它的j。 維護一個單調隊列, 對于詢問i, 由于詢問的斜率是遞增的, 按上凸包順時針方向看, 相鄰點構成的斜率遞減, 詢問i的取最大值的點滿足其向下一個點斜率小于詢問i的斜率,向上一個點的斜率大于詢問i的斜率, 又考慮到詢問i的斜率是遞增的來詢問的, 凸包上的點也是按x坐标遞增來加入的, 故應從單調隊列的尾端掃描, 根據斜率的比較關系, 斜率越大的詢問取最大值的點越靠前, 如果隊尾的上一個點由于隊尾, 說明對于一個更大的斜率也會優于隊尾的, 故彈出隊尾元素。 再将i對應的坐标點加入凸包中, 可以用向量叉積判斷凸包走向來決定是否删除隊尾元素 。 反向求一遍同理。

Code:

#include <map>
#include <set>
#include <map>
#include <bitset>
#include <cmath>
#include <queue>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <algorithm>

#define ll long long
#define db double
#define pw(x) ((x) * (x))
#define fi first
#define se second
#define mp(x, y) make_pair(x, y)

using namespace std;

const int N = (int) + ;

int n;
int cnt, lst[N], nxt[N * ], to[N * ]; ll c[N * ], pre[N];
map<int, ll> f_ch[N];
map<int, ll> :: iterator it;
ll f[N], g[N];

void add(int u, int v, int w){
    nxt[++ cnt] = lst[u]; lst[u] = cnt; to[cnt] = v; c[cnt] = w;
    nxt[++ cnt] = lst[v]; lst[v] = cnt; to[cnt] = u; c[cnt] = w;
}

void dfs(int u, int fa){
    for (int j = lst[u]; j; j = nxt[j]){
        int v = to[j];
        if (v == fa) continue;
        pre[v] = c[j];
        dfs(v, u);

        ll &x = f_ch[u][v];
        x = ;
        for (it = f_ch[v].begin(); it != f_ch[v].end(); it ++){
            x = max(x, it->se + pw(c[j] - pre[it->fi]));
            f[u] = max(f[u], x);
        }

    }
}

pair <ll, int > tmp[N]; int sz;
pair <ll, ll > que[N]; int head, tail;

pair <ll, ll> operator-(pair<ll, ll> a, pair<ll, ll>  b){
    return mp(a.fi-b.fi, a.se-b.se);
}
ll operator*(pair<ll, ll> a, pair<ll, ll>  b){
    return a.fi * b.se - a.se * b.fi;
}


ll cross(pair<ll, ll> a, pair<ll, ll> b, pair<ll, ll> c){
    return (a - b) * (b - c);
}

ll count(pair<ll, ll > x, ll e){
    return x.se-*e*x.fi+pw(e);
}

void dfs2(int u, int fa){
    sz = ;
    for (int j = lst[u]; j; j = nxt[j]){
        int v = to[j];
        if (v == fa) continue;
        tmp[++ sz] = mp(pre[v], v);
    }

    sort(tmp + , tmp + sz + );
    que[head = tail = ] = mp(tmp[].fi, pw(tmp[].fi) + f_ch[u][tmp[].se]);

    for (int i = ; i <= sz; i ++){
        int v = tmp[i].se; ll e = tmp[i].fi;
        while (head < tail && count(que[tail], e) <= count(que[tail - ], e)) tail --;
        g[v] = max(g[v], count(que[tail], e));

        pair<ll, ll> p = mp(e, f_ch[u][v] + pw(e));
        while (head < tail && cross(p, que[tail], que[tail - ]) <= )
            tail --;
        que[++ tail] = p;
    }

    que[head = tail = ] = mp(tmp[sz].fi, pw(tmp[sz].fi) + f_ch[u][tmp[sz].se]);
    for (int i = sz - ; i >= ; i --){
        int v = tmp[i].se; ll e = tmp[i].fi;
        while (head < tail && count(que[tail], e) <= count(que[tail - ], e)) tail --;
        g[v] = max(g[v], count(que[tail], e));

        pair<ll, ll> p = mp(e, f_ch[u][v] + pw(e));
        while (head < tail && cross(p, que[tail], que[tail - ]) >= )
            tail --;
        que[++ tail] = p;
    }

    if (fa){
        for (int i = ; i <= sz; i ++){
            int v = tmp[i].se; ll e = tmp[i].fi;
            g[v] = max(g[v], g[u] + pw(pre[u] - e));
        }
    }

    for (int j = lst[u]; j; j = nxt[j]){
        int v = to[j];
        if (v == fa) continue;
        dfs2(v, u);
    }

}

int getint(){
    int ret = ; char c = getchar();
    while (c > '9' || c < '0') c = getchar();
    while (c <= '9' && c >= '0'){
        ret = ret *  + c - '0';
        c = getchar();
    }
    return ret;
}

int main(){
    while (scanf("%d", &n) != EOF){
        for (int i = , u, v, w; i <= n; i ++){
            u = getint(), v = getint(), w = getint();

            add(u, v, w);
        }

        dfs(, );

        dfs2(, );

        for (int i = ; i <= n; i ++){
            printf("%lld\n", max(f[i], g[i]));
            lst[i] = ;
            f[i] = g[i] = ;
            f_ch[i].clear();
        }
        cnt = ;
    }

    return ;
}