天天看点

ZZH与计数题目思路代码

题目

传送门 to usOJ

题目描述

众所周知,题目的难度与出题人的能力成正比。

现在「霸树中学」的同学要集训,要天天做题。由于 自愿捐助款项太多 资金周转,有时候会请来金牌出题人「标枪」,有时候请来废铁出题人「 O I D \tt OID OID」。

如果请来的是「标枪」,那么会触发效果「同学们的信心大幅下降了!」于是原本的信心值 v v v 的二进制(毕竟是信息竞赛)表示下为 1 1 1 的值,都会以 1 2 \frac{1}{2} 21​ 的概率变成 0 0 0 。

反之,如果请来「 O I D \tt OID OID」,会触发效果「同学们 对出题人的鄙夷 的信心大幅增强了!」于是原来的信心值 v v v 的二进制表示下为 0 0 0 的位置,都会以 1 2 \frac{1}{2} 21​ 的概率变为 1 1 1 。

而请来的是谁,也是一个概率问题。有 p p p 的概率请来「标枪」,也有 1 − p 1-p 1−p 的概率请来「 O I D \tt OID OID」,并且进行了 m m m 天集训(每天都重新请一个出题人,概率仍然为 p p p 与 1 − p 1-p 1−p 不变)。

告诉你最初每种信心值的数量,你能告诉我最后每种信心值的期望数量吗?

数据范围与提示

认为所有的二进制表示只有 n n n 位,也就是说,任何数字在任意时刻均小于 2 n 2^{n} 2n 。

n ≤ 17 n\le 17 n≤17 但是天数 m ≤ 1 0 9 m\le 10^9 m≤109 。由于输出要取模 998244353 998244353 998244353 ,所以 p = a b p=\frac{a}{b} p=ba​ 以及每种信心值的个数 v i v_i vi​ 满足 max ⁡ ( a , b , v i ) < 998244353 \max(a,b,v_i)<998244353 max(a,b,vi​)<998244353 就不足为奇。

时限 3 s 3\text{s} 3s ,空间 1 GB 1\text{GB} 1GB 。

思路

是不是有人认为可以每一位分开考虑呢 😮 然后你就会发现不对,因为两个二进制位之间有影响。

正确的姿势是用 d p \tt dp dp 整体的处理一个数字。对于某个数字, f ( t , i , j ) f(t,i,j) f(t,i,j) 表示 t t t 天之后得到另一个数,并且原来为 0 0 0 的,现在有 i i i 个是 1 1 1 ;原来是 1 1 1 的,现在有 j j j 个还是 1 1 1 。 f f f 求的是 概率 。

v = ( 00000 ⋯ 00    11111 ⋯ 11 ) 2 v ′ = ( 111 ⋯ 1 ⏟ i 个 1 000    111 ⋯ 1 ⏟ j 个 1 000 ) 2 \begin{aligned} v&=(00000\cdots00\;11111\cdots11)_2\\ v'&=(\underbrace{111\cdots1}_{i个1}000\;\underbrace{111\cdots1}_{j个1}000)_2 \end{aligned} vv′​=(00000⋯0011111⋯11)2​=(i个1

111⋯1​​000j个1

111⋯1​​000)2​​

为了方便观察,我在中间打了个空格,实际上没有。而且 v , v ′ v,v' v,v′ 中的 0 , 1 0,1 0,1 不需要排列的那么规整,不过为了观察,我把它们放在一起。

然后转移就比较简单了。假设原来有 i 0 i_0 i0​ 个 0 0 0 ,有 j 0 = n − i 0 j_0=n-i_0 j0​=n−i0​ 个 1 1 1 。那么

f ( t + 1 , i , j ) = ( 1 − p ) ∑ x = 0 i ∑ y = 0 j ( 1 2 ) n − x − y ( i x ) ( j y ) f ( t , x , y ) + p ∑ x = i i 0 ∑ y = j j 0 ( 1 2 ) x + y ( i 0 − i x − i ) ( j 0 − j y − j ) f ( t , x , y ) f(t+1,i,j)=\\ (1-p)\sum_{x=0}^{i}\sum_{y=0}^{j}\left({1\over 2}\right)^{n-x-y}{i\choose x}{j\choose y}f(t,x,y)\\ +p\sum_{x=i}^{i_0}\sum_{y=j}^{j_0}\left({1\over 2}\right)^{x+y}{i_0-i\choose x-i}{j_0-j\choose y-j}f(t,x,y) f(t+1,i,j)=(1−p)x=0∑i​y=0∑j​(21​)n−x−y(xi​)(yj​)f(t,x,y)+px=i∑i0​​y=j∑j0​​(21​)x+y(x−ii0​−i​)(y−jj0​−j​)f(t,x,y)

(易错点:为什么没有一些你们认为的组合数?因为这里计算的是到达某类数字中 任意一个 的概率,而不是所有这一类数的概率的总和。为什么有一些没想到的组合数?因为两类数字之间可能有多条边。)

枚举一个 i 0 i_0 i0​ ,然后做 O ( n 6 log ⁡ m ) \mathcal O(n^6\log m) O(n6logm) 的矩阵加速,复杂度 O ( n 7 log ⁡ m ) \mathcal O(n^7\log m) O(n7logm) 。蛤?感觉不行?但是 i 0 × j 0 ≤ n 4 i_0\times j_0\le\frac{n}{4} i0​×j0​≤4n​ ,于是有了 1 64 \frac{1}{64} 641​ 的超级牛逼的常数!

然而我们只求得了概率。 O ( 2 n × 2 n ) \mathcal O(2^n\times 2^n) O(2n×2n) 的计算贡献是不可接受的。所以我们还得 d p \tt dp dp 求。(一个值得思考的问题,为啥这里 d p \tt dp dp 可以做到更快?因为这里的贡献计算是有过程的,是 0 0 0 与 1 1 1 转化的数量决定。)

用 g ( t , S , i , j ) g(t,S,i,j) g(t,S,i,j) 表示,考虑了前 t t t 个二进制位之后,得到一个过程中的数字 S S S ,但是还有 i i i 个 0 0 0 要变成 1 1 1 ,有 j j j 个 1 1 1 要变成 0 0 0 。(关于这个,可以理解为在 D A G \tt DAG DAG 上推贡献。)

初始时 g ( 0 , S , i , j 0 − j ) = f ( m , i , j ) × v S g(0,S,i,j_0-j)=f(m,i,j)\times v_S g(0,S,i,j0​−j)=f(m,i,j)×vS​

(如果你仔细思考会明白,我还要指出 j 0 = b i t c o u n t ( S ) j_0=bitcount(S) j0​=bitcount(S) 才算严谨。)

复杂度 O ( n 3 × 2 n ) \mathcal O(n^3\times 2^n) O(n3×2n) 。然而同样也有 1 4 \frac{1}{4} 41​ 的常数(实际上可能更小)。

答案当然是 g ( n , 0 , 0 , 0 ) , … , g ( n , 2 n − 1 , 0 , 0 ) g(n,0,0,0),\dots,g(n,2^n-1,0,0) g(n,0,0,0),…,g(n,2n−1,0,0) ,推到底了。然后你发现 g g g 开不下,要把第一位滚动掉。然后你发现非常容易滚动,因为 S S S 只会与 S ⊕ 2 t S\oplus 2^t S⊕2t 发生关联。

代码

卡常技巧:计算 g g g 的复杂度极高!你的循环是怎样枚举的, g g g 的变量就要怎么开。

所以我开成了 g ( i , j , S ) g(i,j,S) g(i,j,S) 而不是 g ( S , i , j ) g(S,i,j) g(S,i,j) 。

#include <cstdio>
#include <iostream>
#include <vector>
#include <cstring>
using namespace std;
typedef long long int_;
const int __M__ = 20000000;
char inBuf[__M__], *iS = inBuf, *iT = inBuf;
inline char getChar(){
	if(iS == iT){
		iT = fread(inBuf,1,__M__,stdin)
			+ (iS = inBuf);
	}
	return *(iS ++);
}
inline int readint(){
	int a = 0; char c = getChar(), f = 1;
	for(; c<'0'||c>'9'; c=getChar())
		if(c == '-') f = -f;
	for(; '0'<=c&&c<='9'; c=getChar())
		a = (a<<3)+(a<<1)+(c^48);
	return a*f;
}
inline void writeint(int x){
	if(x > 9) writeint(x/10);
	putchar((x-x/10*10)^48);
}

const int Mod = 998244353;
inline int qkpow(int_ b,int q){
	int ans = 1; b %= Mod;
	for(; q; q>>=1,b=b*b%Mod)
		if(q&1) ans = ans*b%Mod;
	return ans;
}

const int MaxN = 17;
int N; // 矩阵的大小
struct Matrix{
	int a[90][90];
	void clear(){
		for(int i=0; i<N; ++i)
		for(int j=0; j<N; ++j)
			a[i][j] = 0;
	}
	Matrix operator * (const Matrix &b) const {
		Matrix c; c.clear();
		for(int i=0; i<N; ++i)
		for(int j=0; j<N; ++j)
		if(a[i][j] != 0)
		for(int k=0; k<N; ++k)
			c.a[i][k] = (c.a[i][k]+1ll
				*a[i][j]*b.a[j][k])%Mod;
		return c; // 复杂度还蛮高的
	}
};
Matrix qkpow(Matrix b,int q){
	Matrix ans = b; -- q;
	for(; q; q>>=1,b=b*b)
		if(q&1) ans = ans*b;
	return ans;
}

const int inv2 = (Mod+1)>>1;
int n, m, p, c[MaxN+1][MaxN+1];
int f[MaxN+1][MaxN+1][MaxN+1];
int powTwo[MaxN+1]; // (1/2)^x
void solveF(){
	powTwo[0] = 1;
	for(int i=1; i<=n; ++i)
		powTwo[i] = 1ll*powTwo[i-1]*inv2%Mod;
	for(int i=0; i<=n; ++i)
	for(int j=c[i][0]=1; j<=i; ++j)
		c[i][j] = c[i-1][j-1]+c[i-1][j];
	Matrix S; // S 为转移矩阵
	for(int i0=0; i0<=n; ++i0){
		int j0 = n-i0+1; // 开区间
		N = (i0+1)*j0; S.clear();
		for(int i=0; i<=i0; ++i)
		for(int j=0; j<j0; ++j){
			int zxy = i*j0+j; // 卡常
			for(int x=0; x<=i; ++x)
			for(int y=0; y<=j; ++y)
				S.a[x*j0+y][zxy] =
					1ll*powTwo[n-x-y]*
					c[i][x]%Mod*c[j][y]%Mod
					*(Mod+1-p)%Mod;
			for(int x=i; x<=i0; ++x)
			for(int y=j; y<j0; ++y)
				S.a[x*j0+y][zxy] =
					1ll*powTwo[x+y]*
					c[i0-i][x-i]%Mod*
					c[j0-1-j][y-j]%Mod
					*p%Mod;
			S.a[zxy][zxy] += (Mod+1ll-p)
				*powTwo[n-i-j]%Mod; // 不变
			S.a[zxy][zxy] %= Mod;
		}
		if(m) S = qkpow(S,m);
		else{
			S.clear(); // 单位矩阵
			for(int i=0; i<N; ++i)
				S.a[i][i] = 1;
		}
		for(int i=0; i<=i0; ++i)
		for(int j=0; j<j0; ++j)
			f[i0][i][j] = S.a[j0-1][i*j0+j];
	}
}

int g[MaxN+1][MaxN+1][1<<MaxN];
int cnt[1<<MaxN]; // bitcount
int num[1<<MaxN]; // amount of number
void solveG(){
	for(int S=0; S<(1<<n); ++S){
		if(S) cnt[S] = cnt[S-(S&-S)]+1;
		for(int i=0; i<=n-cnt[S]; ++i)
		for(int j=0; j<=cnt[S]; ++j)
			g[i][cnt[S]-j][S] = 1ll*num[S]
				*f[n-cnt[S]][i][j]%Mod;
	}
	for(int t=0; t<n; ++t)
	for(int i=0; i<n-t; ++i)
	for(int j=0; i+j<n-t; ++j){
		int *g_ = g[i][j], *gi = g[i+1][j], *gj = g[i][j+1];
		for(int S=0; S<(1<<n); ++S)
			if(!(S>>t&1))
				g_[S] = (g_[S]+gj[S^(1<<t)])%Mod;
			else // if(S>>t&1)
				g_[S] = (g_[S]+gi[S^(1<<t)])%Mod;
	}
}

int main(){
	n = readint(), m = readint();
	p = readint(); p = 1ll*p*
		qkpow(readint(),Mod-2)%Mod;
	solveF(); // 预处理 F
	for(int i=0; i<(1<<n); ++i)
		num[i] = readint();
	solveG();
	writeint(g[0][0][0]); // 去掉行末空格
	for(int i=1; i<(1<<n); ++i){
		putchar(' ');
		writeint(g[0][0][i]);
	}
	putchar('\n');
	return 0;
}
           

继续阅读