天天看點

HDU 5955 Guessing the Dice Roll —— AC 自動機 + 機率 DP + 高斯消元 (Aho–Corasick automaton + probability DP + Gaussian elimination)

#include <bits/stdc++.h>
using namespace std;
// Capacity shared by the trie (node 0 is the root), the linear system and the
// per-player arrays. Assumes N * L + 1 <= maxn for the given input -- TODO confirm
// against the problem's constraints.
constexpr int maxn = 105;

int T;               // number of test cases
int N;               // number of players / sequences in the current test case
int L;               // length of every sequence
// Current sequence being read, 1-based (data[1..L]).
// NOTE(review): with `using namespace std;` under C++17, unqualified `data` is
// ambiguous with std::data; refer to this array as `::data`.
int data[maxn];
int id[maxn];        // id[i]: trie node where player i's sequence ends (1-based i)
double a[maxn][maxn]; // coefficient matrix of the linear system a * x = b
double b[maxn];       // right-hand side; holds the solution after Gauss()
// Gauss-Jordan elimination with partial pivoting on the global system
// a[0..num-1][0..num-1] * x = b[0..num-1], performed in place.
// On return b[i] holds x[i] and the coefficient block of `a` has been reduced
// (to the identity when every pivot is non-zero).
// No singularity check: a zero pivot produces inf/NaN in b.
void Gauss(int num) {
    for (int p = 0; p < num; ++p) {
        // Partial pivoting: first row at or below p with the largest |a[row][p]|.
        int best = p;
        for (int r = p + 1; r < num; ++r) {
            if (fabs(a[r][p]) > fabs(a[best][p])) best = r;
        }
        if (best != p) {
            // Columns left of p are already zero in both rows, so start at p.
            for (int c = p; c < num; ++c) swap(a[p][c], a[best][c]);
            swap(b[p], b[best]);
        }

        // Scale the pivot row so the pivot becomes exactly 1.
        const double pivot = a[p][p];
        b[p] /= pivot;
        for (int c = p + 1; c < num; ++c) a[p][c] /= pivot;
        a[p][p] = 1;

        // Clear column p in every other row, above and below the pivot.
        for (int r = 0; r < num; ++r) {
            if (r == p) continue;
            const double factor = a[r][p];
            b[r] -= b[p] * factor;
            for (int c = p + 1; c < num; ++c) a[r][c] -= a[p][c] * factor;
            a[r][p] = 0;
        }
    }
}
int tot;          // index of the most recently created trie node (root is node 0)
int ch[maxn][10]; // ch[node][face]: trie child, completed into goto transitions by Getfail (faces 1..6)
int val[maxn];    // val[node] == 1 when some inserted sequence ends at node
int f[maxn];      // f[node]: failure link
int que[maxn];    // BFS queue used by Getfail (1-based)

// Resets the automaton to a lone root node before a new test case.
// The linear-system arrays (a, b) are not touched here.
void init() {
    memset(f, 0, sizeof f);
    memset(val, 0, sizeof val);
    memset(ch, 0, sizeof ch);
    tot = 0;
}
void Insert(int idx) {
    int now = 0;
    for (int i = 1; i <= L; i++) {
        if (!ch[now][data[i]]) ch[now][data[i]] = ++tot;
        now = ch[now][data[i]];
    }
    val[now] = 1;
    id[idx] = now;
}
void Getfail() {
    int l = 1, r = 0;
    for (int i = 1; i <= 6; i++) if (ch[0][i]) que[++r] = ch[0][i];
    while (l <= r) {
        int u = que[l++];
        for (int i = 1; i <= 6; i++) {
            if (!ch[u][i]) { ch[u][i] = ch[f[u]][i]; continue; }
            int v = ch[u][i];
            que[++r] = v;
            int t = f[u];
            while (t && !ch[t][i]) t = f[t];
            f[v] = ch[t][i];
        }
    }
}
// Builds the linear system a * x = b, for nodes 0..tot, that Gauss(tot + 1) solves.
//
// The walk starts at the root (node 0). On each step it follows ch[i][d] for a
// face d in 1..6, each with weight 1/6. It stops on entering a node with
// val[] set. x[j] is the expected number of visits to node j, so for a
// terminal node it is the probability that the walk ends there:
//     x[j] - (1/6) * sum_{non-terminal i, face d : ch[i][d] == j} x[i] = [j == 0]
//
// The whole (tot+1) x (tot+1) block is cleared explicitly. Previously only the
// diagonal was reset, so correctness depended on the previous test case's
// Gauss() having left exact zeros off the diagonal.
void Solve() {
    const int n = tot + 1;
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) a[i][j] = 0.0;
        a[i][i] = 1.0;
        b[i] = 0.0;
    }
    for (int i = 0; i < n; i++) {
        if (val[i]) continue;  // absorbing node: the walk never leaves it
        for (int d = 1; d <= 6; d++) {
            a[ch[i][d]][i] -= 1.0 / 6.0;
        }
    }
    b[0] = 1.0;  // the walk starts at the root
}
int main() {
    scanf("%d", &T);
    while (T--) {
        scanf("%d%d", &N, &L);
        init();
        for (int i = 1; i <= N; i++) {
            for (int j = 1; j <= L; j++) scanf("%d", &data[j]);
            Insert(i);
        }
        Getfail();
        Solve();
        Gauss(1+tot);
        for (int i = 1; i <= N; i++) {
            printf("%.6lf", b[id[i]]);
            if (i == N) printf("\n");
            else printf(" ");
        }
    }
    return 0;
}      
下一篇: jBPM4入門

繼續閱讀