天天看点

JZOJ 6868. 【2020.11.17提高组模拟】数树(容斥+树上背包)JZOJ 6868. 【2020.11.17提高组模拟】数树

JZOJ 6868. 【2020.11.17提高组模拟】数树

题目大意

  • 给出一棵大小为 N N N的树,树边有向,求 ∀ i ∈ [ 1 , N ) ( a i , a i + 1 ) ∉ E d g e \forall i\in[1,N)(a_i,a_{i+1})\notin Edge ∀i∈[1,N)(ai​,ai+1​)∈/​Edge的排列方案数。
  • N ≤ 5000 N\leq 5000 N≤5000.

题解

  • 题意即任意一条树边的两点不能在排列中按边的指向方向连续出现。
  • 不难会想到容斥,假设已经知道了至少有 i i i条边不合法的方案数 s i s_i si​,则 a n s = ∑ s i ∗ ( n − i ) ! ∗ ( − 1 ) n − i ans=\sum s_i*(n-i)!*(-1)^{n-i} ans=∑si​∗(n−i)!∗(−1)n−i,
  • 容斥系数怎么得来的?
  • 如果没有边不合法,则 n n n个数任意排列,方案数为 n ! n! n!。每连上一条边,则把两个点锁定在一起,方案数为 ( n − 1 ) ! (n-1)! (n−1)!,以此类推。
  • 现在的问题是如何求 s i s_i si​?这里 s i s_i si​相当于在 n − 1 n-1 n−1条边中选出 i i i条边的方案数。
  • 可以用组合数直接算吗?由于所选的边需满足起点互不相同,终点互不相同,所以没法直接组合数计算。
  • 那么试着在树上DP, f i , j f_{i,j} fi,j​表示第 i i i个点的子树内选了 j j j条边的方案数,此外还要记录其它的状态,分别为:没有选择与该点相连的边,选择了该点的一条入边,选择了该点的一条出边,选择了该点的一条入边和一条出边。
  • 转移比较自然,每次从儿子节点转移,与儿子的连边选或不选,用树上背包即可。
  • 要枚举DFS每个节点,然后枚举背包的容量,再枚举转移的大小,复杂度是 O ( N 3 ) O(N^3) O(N3)的吗?
  • 如果真的就这样写的话,确实是 O ( N 3 ) O(N^3) O(N3)的,但一般的树上背包并不需要这样写。
  • 每次转移时,一个一个儿子按顺序转移,每做完一个儿子才把它的 s i z e size size加到父亲的 s i z e size size中,背包容量每次只需要枚举到当前的 s i z e size size,而转移的大小只需要枚举到儿子的 s i z e size size,这样确实的优化了不少,但复杂度有减少吗?
  • 发现枚举背包容量和转移大小的总操作次数等同于枚举子树中任意两个点匹配,而在整个程序过程中,任意两个点只会在它们的LCA出“匹配”一次,所以总的复杂度为 O ( N 2 ) O(N^2) O(N2)。

代码

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
#define N 5010
#define ll long long
#define md 998244353
ll F[N], f[N][N][4], g[N][4];
int last[N], nxt[N * 2], to[N * 2], o[N * 2], len = 0;
int si[N];
void add(int x, int y, int c) {
	to[++len] = y;
	nxt[len] = last[x];
	o[len] = c;
	last[x] = len;
}
// 0 neither
// 1 in
// 2 out
// 3 both
void dfs(int k, int fa) {
	si[k] = 0;
	f[k][0][0] = 1;
	for(int i = last[k]; i; i = nxt[i]) if(to[i] != fa) {
		int x = to[i];
		dfs(x, k);
		memset(g, 0, sizeof(g));
		for(int j = 0; j <= si[k]; j++) 
			for(int h = 0; h <= si[x]; h++) {
				ll t = f[x][h][0] + f[x][h][1] + f[x][h][2] +f[x][h][3];
				(g[j + h][0] += f[k][j][0] * t) %= md;
				(g[j + h][1] += f[k][j][1] * t) %= md;
				if(o[i] == 0) (g[j + h + 1][1] += f[k][j][0] * (f[x][h][0] + f[x][h][1])) %= md;
				(g[j + h][2] += f[k][j][2] * t) %= md;
				if(o[i] == 1) (g[j + h + 1][2] += f[k][j][0] * (f[x][h][0] + f[x][h][2])) %= md;
				(g[j + h][3] += f[k][j][3] * t) %= md;
				if(o[i] == 0) (g[j + h + 1][3] += f[k][j][2] * (f[x][h][0] + f[x][h][1])) %= md;
				if(o[i] == 1) (g[j + h + 1][3] += f[k][j][1] * (f[x][h][0] + f[x][h][2])) %= md;
			}
		si[k] += si[x];
		for(int j = 0; j <= si[k]; j++)
			for(int h = 0; h < 4; h++) f[k][j][h] = g[j][h];
	}
	si[k]++;
}
int main() {
	int n, i, j, x, y;
	scanf("%d", &n);
	for(i = 1; i < n; i++) {
		scanf("%d%d", &x, &y);
		add(x, y, 1), add(y, x, 0);
	}	
	dfs(1, 0);
	int r = 1;
	F[0] = 1;
	for(i = 1; i <= n; i++) F[i] = F[i - 1] * i % md;
	ll ans = 0;
	for(i = 0; i < n; i++) {
		ll t = 0;
		for(j = 0; j < 4; j++) (t += f[1][i][j]) %= md;
		(ans += F[n - i] * t % md * r + md) %= md;
		r = -r;
	}
	printf("%lld\n", ans);
	return 0;
}