天天看点

LOJ2542「PKUWC2018」随机游走(min-max容斥,高斯消元,树上)

传送门:https://loj.ac/problem/2542

m i n − m a x min-max min−max容斥模板

设 m i n ( S ) min(S) min(S)为 S S S集合中最早被访问的点期望步数

m a x ( S ) max(S) max(S)为集合中最晚被访问的点的期望步数

询问即是求 m a x ( S ) max(S) max(S)

由最值反演( m i n − m a x min-max min−max容斥)即可得到 m a x ( S ) = ∑ T ⊆ S ( − 1 ) ∣ T ∣ − 1 m i n ( T ) max(S)=\sum_{T\subseteq S}(-1)^{|T|-1}min(T) max(S)=∑T⊆S​(−1)∣T∣−1min(T)

现在的问题就是如何求 m i n ( T ) min(T) min(T)

用 f i , S f_{i,S} fi,S​表示从 i i i出发,第一次访问 S S S中的点所花的期望步数

f x , S f_{x,S} fx,S​即为我们要求的 m i n ( S ) min(S) min(S),写出转移方程用高斯消元解决

O ( 2 n ∗ n 3 ) O(2^n*n^3) O(2n∗n3), T L E TLE TLE

考虑优化,我们可以发现高斯消元的系数矩阵内只有 O ( n ) O(n) O(n)个位置有系数,并且只有父子之间有系数的关系。

当然如果你足够熟练,你也可以发现这是树上期望环状 d p dp dp的套路

f x , s f_{x,s} fx,s​肯定可以表示为 k ∗ f f a x , x + b k*f_{fa_x,x}+b k∗ffax​,x​+b的形势(也可以通过自下向上消元来理解)

列出方程解出消元的系数后从下往上消即可

注意容斥的时候不要用 3 n 3^n 3n枚举子集

#include<bits/stdc++.h>
using namespace std;
#define rep(i,j,k) for(int i = j;i <= k;++i)
#define repp(i,j,k) for(int i = j;i >= k;--i)
#define rept(i,x) for(int i = linkk[x];i;i = e[i].n)
#define P pair<int,int>
#define Pil pair<int,ll>
#define Pli pair<ll,int>
#define Pll pair<ll,ll>
#define pb push_back 
#define pc putchar
#define mp make_pair
#define file(k) memset(k,0,sizeof(k))
#define ll long long
int rd()
{
    int sum = 0;char c = getchar();bool flag = true;
    while(c < '0' || c > '9') {if(c == '-') flag = false;c = getchar();}
    while(c >= '0' && c <= '9') sum = sum * 10 + c - 48,c = getchar();
    if(flag) return sum;
    else return -sum;
}
const int p = 998244353;
int n,q,X,ms;
int linkk[20],t;
int tmp[20][20],a[20][20],du[20];
int g[20],cnt[263000],k[20],f[263000];
bool flag[20];
//a为系数矩阵
struct node{int n,y;}e[40];
int Pow(int a,int x)
{
    int now = 1;
    for(;x;x >>= 1,a = 1ll*a*a%p) if(x&1) now = 1ll*now*a%p;
    return now;
}
int mul(int a,int b){return 1ll*a*b%p;}
int del(int a,int b){return ((a-b)%p+p)%p;}
int calc(int a,int b){return (a+b)%p;}
void pre()
{
    n = rd();q = rd();X = rd();
    int x,y;
    rep(i,1,n-1)
    {
        du[x = rd()]++;du[y = rd()]++;
        e[++t].y = y;e[t].n = linkk[x];linkk[x] = t;
        e[++t].y = x;e[t].n = linkk[y];linkk[y] = t;
    }
    ms = (1<<n)-1;
}
void dfs(int x,int fa,int s)
{
    if(s>>x-1&1){g[x] = k[x] = 0;return;}
    g[x] = k[x] = du[x];
    for(int i = linkk[x];i;i = e[i].n)
    {
        int y = e[i].y;
        if(y == fa) continue;

        dfs(y,x,s);
        k[x] = del(k[x],k[y]);
        g[x] = calc(g[x],g[y]);
    }
    k[x] = Pow(k[x],p-2);g[x] = 1ll*g[x]*k[x]%p;
}
int main()
{
    pre();
    rep(s,1,ms)
    {
        rep(i,1,n) if(s>>i-1&1)
            cnt[s]++;
        dfs(X,0,s);
        f[s] = g[X];
        if(cnt[s] % 2 == 0) f[s] = -f[s];
    }
    rep(j,1,n) rep(s,1,ms)
            if(s>>j-1&1) (f[s] += f[s-(1<<(j-1))])%=p;
    rep(s,1,ms) f[s] = (f[s]+p)%p;
    rep(i,1,q)
    {
        int x = rd(),y,s = 0;
        rep(i,1,x)
            y = rd(),s += (1<<(y-1));
        printf("%d\n",f[s]);
    }
    return 0;
}