天天看點

LOJ2542「PKUWC2018」随機遊走(min-max容斥,高斯消元,樹上)

傳送門:https://loj.ac/problem/2542

m i n − m a x min-max min−max容斥模闆

設 m i n ( S ) min(S) min(S)為 S S S集合中最早被通路的點期望步數

m a x ( S ) max(S) max(S)為集合中最晚被通路的點的期望步數

詢問即是求 m a x ( S ) max(S) max(S)

由最值反演( m i n − m a x min-max min−max容斥)即可得到 m a x ( S ) = ∑ T ⊆ S ( − 1 ) ∣ T ∣ − 1 m i n ( T ) max(S)=\sum_{T\subseteq S}(-1)^{|T|-1}min(T) max(S)=∑T⊆S​(−1)∣T∣−1min(T)

現在的問題就是如何求 m i n ( T ) min(T) min(T)

用 f i , S f_{i,S} fi,S​表示從 i i i出發,第一次通路 S S S中的點所花的期望步數

f x , S f_{x,S} fx,S​即為我們要求的 m i n ( S ) min(S) min(S),寫出轉移方程用高斯消元解決

O ( 2 n ∗ n 3 ) O(2^n*n^3) O(2n∗n3), T L E TLE TLE

考慮優化,我們可以發現高斯消元的系數矩陣内隻有 O ( n ) O(n) O(n)個位置有系數,并且隻有父子之間有系數的關系。

當然如果你足夠熟練,你也可以發現這是樹上期望環狀 d p dp dp的套路

f x , s f_{x,s} fx,s​肯定可以表示為 k ∗ f f a x , x + b k*f_{fa_x,x}+b k∗ffax​,x​+b的形勢(也可以通過自下向上消元來了解)

列出方程解出消元的系數後從下往上消即可

注意容斥的時候不要用 3 n 3^n 3n枚舉子集

#include<bits/stdc++.h>
using namespace std;
#define rep(i,j,k) for(int i = j;i <= k;++i)
#define repp(i,j,k) for(int i = j;i >= k;--i)
#define rept(i,x) for(int i = linkk[x];i;i = e[i].n)
#define P pair<int,int>
#define Pil pair<int,ll>
#define Pli pair<ll,int>
#define Pll pair<ll,ll>
#define pb push_back 
#define pc putchar
#define mp make_pair
#define file(k) memset(k,0,sizeof(k))
#define ll long long
int rd()
{
    int sum = 0;char c = getchar();bool flag = true;
    while(c < '0' || c > '9') {if(c == '-') flag = false;c = getchar();}
    while(c >= '0' && c <= '9') sum = sum * 10 + c - 48,c = getchar();
    if(flag) return sum;
    else return -sum;
}
const int p = 998244353;
int n,q,X,ms;
int linkk[20],t;
int tmp[20][20],a[20][20],du[20];
int g[20],cnt[263000],k[20],f[263000];
bool flag[20];
//a為系數矩陣
struct node{int n,y;}e[40];
int Pow(int a,int x)
{
    int now = 1;
    for(;x;x >>= 1,a = 1ll*a*a%p) if(x&1) now = 1ll*now*a%p;
    return now;
}
int mul(int a,int b){return 1ll*a*b%p;}
int del(int a,int b){return ((a-b)%p+p)%p;}
int calc(int a,int b){return (a+b)%p;}
void pre()
{
    n = rd();q = rd();X = rd();
    int x,y;
    rep(i,1,n-1)
    {
        du[x = rd()]++;du[y = rd()]++;
        e[++t].y = y;e[t].n = linkk[x];linkk[x] = t;
        e[++t].y = x;e[t].n = linkk[y];linkk[y] = t;
    }
    ms = (1<<n)-1;
}
void dfs(int x,int fa,int s)
{
    if(s>>x-1&1){g[x] = k[x] = 0;return;}
    g[x] = k[x] = du[x];
    for(int i = linkk[x];i;i = e[i].n)
    {
        int y = e[i].y;
        if(y == fa) continue;

        dfs(y,x,s);
        k[x] = del(k[x],k[y]);
        g[x] = calc(g[x],g[y]);
    }
    k[x] = Pow(k[x],p-2);g[x] = 1ll*g[x]*k[x]%p;
}
int main()
{
    pre();
    rep(s,1,ms)
    {
        rep(i,1,n) if(s>>i-1&1)
            cnt[s]++;
        dfs(X,0,s);
        f[s] = g[X];
        if(cnt[s] % 2 == 0) f[s] = -f[s];
    }
    rep(j,1,n) rep(s,1,ms)
            if(s>>j-1&1) (f[s] += f[s-(1<<(j-1))])%=p;
    rep(s,1,ms) f[s] = (f[s]+p)%p;
    rep(i,1,q)
    {
        int x = rd(),y,s = 0;
        rep(i,1,x)
            y = rd(),s += (1<<(y-1));
        printf("%d\n",f[s]);
    }
    return 0;
}