参考:
http://hi.baidu.com/strongoier/item/fe47a4191c18a37c1009b515
http://hi.baidu.com/shuxk/item/2bd327977576038159146161
http://hi.baidu.com/gugugupan/item/f29befcd12accc67f7c95db9
给出N(1 <= N <= 10000)个结点的树,求使得路径u -> v长度不超过k的点对(u, v)的个数。
点分治法:每次找出分治点,求经过其的路径的点对,子树递归处理即可。
会有重复值,需要减去子节点相关的一些值
注意技巧:
删除点的办法。adj[r ^ 1].mk = false;
找出分支点的方法。(与直径中心的差别?)
算法就是基于树的分治,对于每个树,先要选定一个根节点,这个根节点要保证它的深度尽量小,然后才可以保证时间复杂度。选根的过程大致就是一次DFS,找到以 I 为根节点时儿子最多的子树的儿子树最少的 I 就可以作为根了。
#pragma comment(linker, "/STACK:1024000000,1024000000")
#include <cstdio>
#include <ctime>
#include <cstdlib>
#include <cstring>
#include <queue>
#include <string>
#include <set>
#include <stack>
#include <map>
#include <cmath>
#include <vector>
#include <iostream>
#include <algorithm>
using namespace std;
#define FE(i, a, b) for(int i = (a); i <= (b); ++i)
#define FD(i, b, a) for(int i = (b); i >= (a); --i)
#define REP(i, N) for(int i = 0; i < (N); ++i)
#define CLR(a, v) memset(a, v, sizeof(a))
#define PB push_back
#define MP make_pair
#define RI(n) scanf("%d", &n)
#define RIII(a, b, c) scanf("%d%d%d",&a, &b, &c)
typedef long long LL;
const double eps = 1e-10;
const int maxn = 10010;
int a[maxn], an;
int sz[maxn], szn;
int f[maxn], d[maxn];
int n, m;
int cur;///curroot
int ans;
struct Edge {
int to, w, next;
bool mk;
}adj[maxn * 2];
int adj_tot, head[maxn];
void adj_add(int x, int y, int w)
{
adj[adj_tot].to = y; adj[adj_tot].w = w;
adj[adj_tot].next = head[x];
adj[adj_tot].mk = true;
head[x] = adj_tot++;
}
void find_root(int u, int fa)///找分治点
{
sz[u] = 1;///
f[u] = 0;
for (int r = head[u]; ~r; r = adj[r].next)
{
int v = adj[r].to;
if (adj[r].mk && v != fa)
{
find_root(v, u);
sz[u] += sz[v];
f[u] = max(f[u], sz[v]);
}
}
f[u] = max(f[u], szn - sz[u]);
if (f[u] < f[cur]) cur = u;
}
void dfs(int u, int fa)
{
sz[u] = 1;///
a[an++] = d[u];
for (int r = head[u]; ~r; r = adj[r].next)
{
int v = adj[r].to;
if (adj[r].mk && v != fa)
{
d[v] = d[u] + adj[r].w;
dfs(v, u);
sz[u] += sz[v];
}
}
}
int calc(int x, int initval)
{
int ret = an = 0;
d[x] = initval;
dfs(x, 0);
sort(a, a + an);
int l = 0, r = an - 1;
for (; l < r;)///!!!
{
if (a[r] + a[l] <= m)
ret += r - l++;
else r--;
}
return ret;
}
void solve(int u)
{
ans += calc(u, 0);
for (int r = head[u]; ~r; r = adj[r].next)
{
int v = adj[r].to;
if (adj[r].mk)
{
adj[r ^ 1].mk = false;///!!!
ans -= calc(v, adj[r].w);///???
f[0] = szn = sz[v];
find_root(v, cur = 0);
solve(cur);
}
}
}
int main ()
{
int x, y, w;
while (scanf("%d%d", &n, &m) != EOF && n + m )
{
adj_tot = 0; CLR(head, -1);
ans = 0;
for (int i = 0; i < n - 1; i++)
{
scanf("%d%d%d", &x, &y, &w);
adj_add(x, y, w); adj_add(y, x, w);
}
f[0] = szn = n;
find_root(1, cur = 0);
solve(cur);
printf("%d\n", ans);
}
return 0;
}