天天看点

CF161D Distance in Tree(点分治)

题目

CF161D Distance in Tree(点分治)

题解

  • 思路:点分治。对每个分治中心统计经过它的点对数,做法与统计“树上距离不超过 $k$ 的点对数”的经典点分治题相同。
  • 用距离小于等于 $k$ 的点对数减去距离小于 $k$ 的点对数,即得距离恰好等于 $k$ 的点对数。

code

#include <map>
#include <set>
#include <list>
#include <cmath>
#include <deque>
#include <queue>
#include <stack>
#include <cctype>
#include <string>
#include <cstdio>
#include <vector>
#include <complex>
#include <cstring>
#include <iomanip>    
#include <iostream>
#include <algorithm>
#include <functional>
using namespace std; 
using LL = long long;              // 64-bit accumulator type: pair counts can exceed INT_MAX
constexpr int maxn = 50000 + 100;  // array capacity: problem bound on n plus slack

// Fast integer reader from stdin.
// Skips any non-digit characters (remembering a '-' as the sign), then
// accumulates consecutive decimal digits into s.
// At end of input, s is set to 0 instead of spinning forever on EOF.
template <class T>
inline void read(T &s) {
	s = 0;
	T w = 1;
	int ch = getchar();  // getchar() returns int so that EOF is representable
	while (ch != EOF && !isdigit(ch)) {
		if (ch == '-') w = -1;
		ch = getchar();
	}
	while (isdigit(ch)) {  // isdigit(EOF) is false, so EOF also ends this loop
		s = s * 10 + (ch - '0');
		ch = getchar();
	}
	s *= w;
}

// ---- Global state shared by the centroid-decomposition routines ----
LL ans;              // running answer accumulated by solve()
LL ans2;             // currently unused
LL d[maxn];          // d[v]: distance offset of v, filled by cal()/get_dis()
int n;               // number of vertices
int m;               // currently unused
int k;               // target path length
int tot;             // number of edges stored so far (edge[] is 1-based)
int all_node;        // size of the component get_root() is searching
int root;            // best centroid candidate found by get_root()
int lin[maxn];       // lin[u]: index of the first edge out of u (0 = none)
int max_part[maxn];  // max_part[u]: largest piece left if u were removed
int size[maxn];      // size[u]: subtree size of u in the last get_root() DFS
int len[maxn];       // len[0]: count; len[1..len[0]]: distances from get_dis()
bool vis[maxn];      // vis[u]: u has already been used as a centroid

// Forward-star adjacency list entry.
struct node {
	int next;  // index of the next edge leaving the same vertex (0 ends the list)
	int to;    // destination vertex
} edge[maxn << 1];

// Append the directed edge from -> to at the front of from's edge list.
inline void add(int from, int to) {
	++tot;
	edge[tot].next = lin[from];
	edge[tot].to = to;
	lin[from] = tot;
}

void get_root(int u, int fa) { 
	max_part[u] = 0, size[u] = 1; 
	for (int i = lin[u]; i; i = edge[i].next) {
		int v = edge[i].to; 
		if (v == fa || vis[v]) continue;  
		get_root(v, u); 
		size[u] += size[v]; 
		max_part[u] = max(max_part[u], size[v]); 
	}
	max_part[u] = max(max_part[u], all_node - max_part[u]); 
	if (max_part[u] < max_part[root]) root = u; 
}

// Append d[] of every vertex reachable from u to len[1..len[0]].
// The walk never steps back to fa or into a vis[] vertex.
// The caller sets d[u]; each child gets its parent's d plus one.
void get_dis(int u, int fa) {
	++len[0];
	len[len[0]] = d[u];
	for (int e = lin[u]; e != 0; e = edge[e].next) {
		const int to = edge[e].to;
		if (to != fa && !vis[to]) {
			d[to] = d[u] + 1;
			get_dis(to, u);
		}
	}
}

// Two-pointer count over the sorted values len[1..len[0]].
// Counts index pairs i < j whose sum is <= k (inclusive) or < k (!inclusive).
static LL count_pairs_vs_k(bool inclusive) {
	LL pairs = 0;
	int lo = 1, hi = len[0];
	while (lo < hi) {
		const int sum = len[lo] + len[hi];
		const bool fits = inclusive ? sum <= k : sum < k;
		if (fits) {
			pairs += hi - lo;  // len[lo] pairs with each of len[lo+1..hi]
			++lo;
		} else {
			--hi;
		}
	}
	return pairs;
}

// Count unordered vertex pairs in u's component whose d[] values sum to exactly k.
// d[u] is set to `now`: solve() passes 0 for the centroid itself and 1 for a
// neighbouring sub-component. The count is (pairs with sum <= k) minus
// (pairs with sum < k).
// Side effect: len[0] ends up equal to the number of vertices visited.
LL cal(int u, int now) {
	d[u] = now;
	len[0] = 0;
	get_dis(u, 0);
	sort(len + 1, len + len[0] + 1);
	return count_pairs_vs_k(true) - count_pairs_vs_k(false);
}

void solve(int u) { 
	vis[u] = true; 
	ans += cal(u, 0); 
	for (int i = lin[u]; i; i = edge[i].next) {
		int v = edge[i].to; 
		if (vis[v]) continue; 
		ans -= cal(v, 1); 
		all_node = size[v]; 
		root = 0; 
		get_root(v, u); 
		solve(root); 
	}
}

// Input:  n k, then n - 1 undirected edges "u v".
// Output: the number of unordered vertex pairs at distance exactly k.
int main() {
	read(n);
	read(k);
	for (int i = 1; i < n; ++i) {
		int u, v;
		read(u);
		read(v);
		add(u, v);
		add(v, u);
	}
	// Sentinel: start with root = 0 and max_part[0] = n, so the first real vertex
	// with a smaller max_part replaces it inside get_root().
	all_node = n, max_part[0] = n, root = 0;
	get_root(1, 0);
	solve(root);
	printf("%lld\n", ans);
	return 0;
}