天天看點

Educational Codeforces Round 8(E. Zbazi in Zeydabad(樹狀數組優化))

題目連結:點選打開連結

題意:一個n*m矩陣, 裡面的格子除了'z'就是'.',問有多少個z形圖案。

思路:因為n和m很大, 即使n^3複雜度也會逾時。  如果按照最樸素的方法, 我們可以處理一下字首和, 處理出一個格子向左l[i][j]、向右r[i][j]、斜向左下lr[i][j]連着的z到哪裡為止, 這樣我們用n^2複雜度枚舉每一個格子作為z形圖案的右上角,取min(l[i][j], lr[i][j]), 就可以立刻知道這個z形的最左下角到哪裡, 然後在這個對角線上掃一遍, 看看向右最遠是不是符合條件,複雜度n^3。  我們注意到, 最後一步是處理一個對角線區間上的問題。  這其實可以用樹狀數組快速累加和。

我們隻需要從右向左枚舉每一列, 将以這一列為最右端的線段最左端加進樹狀數組, 那麼我們再枚舉這一列的所有點作為Z形的右上角, 算出左下角。  由于隻有同一對角線上的行加列相同, 是以以此為樹狀數組編号, 開n + m棵樹狀數組, 就可以快速累加在這個對角線上的滿足列介于上面那一橫的範圍内的下面那一橫的數量了。 又以為從右向左枚舉列, 是以先加的一定滿足之後的。

細節參見代碼:

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<iostream>
#include<string>
#include<vector>
#include<stack>
#include<bitset>
#include<cstdlib>
#include<cmath>
#include<set>
#include<list>
#include<deque>
#include<map>
#include<queue>
#define Max(a,b) ((a)>(b)?(a):(b))
#define Min(a,b) ((a)<(b)?(a):(b))
using namespace std;
typedef long long ll;
const double PI = acos(-1.0);
const double eps = 1e-6;
const int mod = 1000000000 + 7;
const int INF = 1000000000;
const int maxn = 3000 + 10;
int T,n,m, bit[maxn*2][maxn], l[maxn][maxn],lr[maxn][maxn],r[maxn][maxn];
ll ans = 0;
struct node {
    int x, y;
};
vector<node> g[maxn];
char s[maxn][maxn];
int sum(int i, int x) {
    int ans = 0;
    while(x > 0) {
        ans += bit[i][x];
        x -= x & -x;
    }
    return ans;
}
void add(int i, int x, int d) {
    while(x <= m) {
        bit[i][x] += d;
        x += x & -x;
    }
}
void pre() {
    for(int i=1;i<=n;i++) {
        for(int j=1;j<=m;j++) {
            if(s[i][j] == 'z') l[i][j] = l[i][j-1]+1;
            else l[i][j] = 0;
        }
        for(int j=m;j>=1;j--) {
            if(s[i][j] == 'z') r[i][j] = r[i][j+1]+1;
            else r[i][j] = 0;
        }
    }
    for(int i=n;i>=1;i--) {
        for(int j=1;j<=m;j++) {
            if(s[i][j] != 'z') { lr[i][j] = 0; continue;}
            if(j-1 >= 1 && i+1 <= n) {
                lr[i][j] = lr[i+1][j-1] + 1;
            }
            else lr[i][j] = 1;
        }
    }
    for(int i=1;i<=n;i++) {
        for(int j=1;j<=m;j++) {
            g[j + r[i][j] - 1].push_back(node{i, j});
        }
    }
}
int main() {
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;i++) {
        scanf("%s",s[i]+1);
    }
    ans = 0;
    pre();
    for(int j=m;j>=1;j--) {
        int len = g[j].size();
        for(int i=0;i<len;i++) {
            int& x = g[j][i].x, &y = g[j][i].y;
            add(x+y, y, 1);
        }
        for(int i=1;i<=n;i++) {
            if(s[i][j] != 'z') continue;
            int len = min(l[i][j], lr[i][j]);
            int y = j - len + 1;
            ans += sum(j+i, j) - sum(i+j, j-len);
        }
    }
    printf("%I64d\n",ans);
    return 0;
}