天天看点

Codeforces 785E 题解(树套树-树状数组套线段树)

题目大意:

对于一个长度为n的序列进行k次操作,每次操作都是交换序列中的某两个数。对于每一个操作,回答当前序列中有多少个逆序对。

题解:

每次更改序列都可以理解为,将答案减去被调换的位置原有数字的对答案的贡献,然后调换两数字的位置,然后将答案加上被调换的位置在调换之后的数字对答案的贡献。

每一个数字对答案的贡献都可以理解为在其出现位置之前的比它大的数字的个数加上在其出现位置之后的比它小的数字的个数。

由此,用树状数组套线段树可做。树状数组是根据位置维护的,树状数组的每个结点包含一个权值线段树,但是直接开权值线段树的话空间会爆,所以线段数要动态开点,那么每一次修改只会至多使用log2n的空间。

同时也要注意,计算贡献的时候要注意被调换的两个数字之间的相互影响。

代码:

树套树版本

#include <cstdio>
#include <iostream>
using namespace std;

#define lowbit(k) (k&-(k))
const int maxn=int(e5)+;
int n,m;

struct Node {
    int sum,ls,rs;
    Node() {}
    Node(int s,int l,int r):sum(s),ls(l),rs(r) {}
}node[maxn*200];

int root[maxn],tot=;

void seg_modify(int k,int l,int r,int pos,int val) {
    if(l==r && l==pos) {
        node[k].sum=val;
        return;
    }
    int mid=(l+r)>>;
    int &ls=node[k].ls,&rs=node[k].rs;
    if(pos<=mid) {
        if(!ls) ls=++tot;
        seg_modify(ls,l,mid,pos,val);
    }
    else {
        if(!rs) rs=++tot;
        seg_modify(rs,mid+,r,pos,val);
    }
    node[k].sum=(ls?node[ls].sum:)+(rs?node[rs].sum:);
    return;
}

int seg_query(int k,int l,int r,int ql,int qr) {
    if(ql<=l && r<=qr) return node[k].sum;
    int mid=(l+r)>>;
    int &ls=node[k].ls,&rs=node[k].rs;
    if(qr<=mid) return ls?seg_query(ls,l,mid,ql,qr):0;
    if(ql> mid) return rs?seg_query(rs,mid+,r,ql,qr):;
    return (ls?seg_query(ls,l,mid,ql,qr):)+(rs?seg_query(rs,mid+,r,ql,qr):);
}

void bit_build() {
    for(int i=;i<=n;i++) root[i]=i;
    tot=n;
    for(int i=;i<=n;i++) {
        int len=lowbit(i);
        for(int j=i;j>=(i-len+);j--) seg_modify(root[i],,n,j,);
    }
    return;
}

int bit_query(int pos,int l,int r) {
    int res=;
    while(pos) {
        res+=seg_query(root[pos],,n,l,r);
        pos-=lowbit(pos);
    }
    return res;
}

void bit_modify(int pos,int val,int dir) {
    while(pos<=n) {
        seg_modify(root[pos],,n,val,dir);
        pos+=lowbit(pos);
    }
    return;
}

int a[maxn];

int main() {
#ifndef ONLINE_JUDGE
    freopen("input.txt","r",stdin);
    freopen("output.txt","w",stdout);
#endif // ONLINE_JUDGE

    scanf("%d%d",&n,&m);
    bit_build();

    long long ans=;
    for(int i=;i<=n;i++) a[i]=i;

    while(m--) {
        int x,y;
        scanf("%d%d",&x,&y);
        int g1=bit_query(n,,a[x]-)-bit_query(x,,a[x]-)+bit_query(x-,a[x]+,n);
        bit_modify(x,a[x],);
        int g2=bit_query(n,,a[y]-)-bit_query(y,,a[y]-)+bit_query(y-,a[y]+,n);
        bit_modify(y,a[y],);
        ans-=g1+g2;
        swap(a[x],a[y]);

        bit_modify(x,a[x],);
        int c1=bit_query(n,,a[x]-)-bit_query(x,,a[x]-)+bit_query(x-,a[x]+,n);
        bit_modify(y,a[y],);
        int c2=bit_query(n,,a[y]-)-bit_query(y,,a[y]-)+bit_query(y-,a[y]+,n);
        ans+=c1+c2;
        printf("%I64d\n",ans);
    }

    return ;
}
           

分块套树状数组版本(lewin)

#include <cstdio>
#include <cstring>
#include <algorithm>
#include <iostream>
using namespace std;
// taken from here: https://github.com/igrsk/spoj/blob/master/SWAPS.cpp
#define FORO(i,n) for(int i = 0;i < n;i++)
#define FORI(i,n) for(int i = 1;i <= n;i++) 
const int MAXN = ;
const int MAXA = ;
const int MAXSQRTN = ;

int A[MAXN];
int N, M;
int bit[MAXSQRTN+][MAXA+];
int sqrtN;

void bitinc(int i,int v,int *d) {
  for(;i <= MAXA;i += i&-i) d[i]+=v;
}

void bitins(int x,int y,int v) {
  while(x <= sqrtN) {
    bitinc(y,v,bit[x]);
    x += x&-x;
  }
}

int bitsum(int x,int y) {
  int ret = ;
  for(;x > ;x-=x&-x)
    for(int yy = y;yy > ;yy-=yy&-yy)
      ret += bit[x][yy];
  return ret;
}

void init() {
  for(sqrtN = ;sqrtN*sqrtN < N;sqrtN++) ;
  for(int i = ;i < N;i++) {
    bitins(i/sqrtN+,A[i],);
  }
}

int query(int i,int x) {
  int ret = bitsum(i/sqrtN,x);
  // 0 ~ sqrtN-1, sqrtN ~ 2*sqrtN-1, ... 
  for(int j = sqrtN*(i/sqrtN);j <= i;j++)
    if(A[j] <= x) ret++;
  return ret;
}

int main() {
  freopen("input.txt","r",stdin);
  freopen("output.txt","w",stdout);

  scanf("%d %d",&N, &M);
  FORO(i,N) A[i] = i+;

  init();

  long long orans = ;
  FORO(i,N) {
    orans += i-query(i,A[i])+;
  }
  cout<<"orans:"<<orans<<endl;

  int X, Y, q1, q2;
  FORO(mm,M) {
    scanf("%d %d",&X,&Y);
    // update and query
    X--; Y--;
    if (X != Y) {
      q1 = A[X];
      q2 = A[Y];


      orans -= (X-query(X-,A[X])-)+(query(N-,A[X]-)-query(X,A[X]-));
      bitins(X/sqrtN+,A[X],-);

      A[X] = q2;
      bitins(X/sqrtN+,q2,);
      orans += (X-query(X-,q2)-)+(query(N-,q2-)-query(X,q2-));

      orans -= (Y-query(Y-,A[Y])-)+(query(N-,A[Y]-)-query(Y,A[Y]-));
      bitins(Y/sqrtN+,A[Y],-);

      A[Y] = q1;
      bitins(Y/sqrtN+,q1,);
      orans += (Y-query(Y-,q1)-)+(query(N-,q1-)-query(Y,q1-));
    }

    printf("%lld\n",orans);
  }
  return ;
}
           

分块套vector版本(kmjp)

#include <bits/stdc++.h>
using namespace std;
typedef signed long long ll;

#undef _P
#define _P(...) (void)printf(__VA_ARGS__)
#define FOR(x,to) for(x=0;x<(to);x++)
#define FORR(x,arr) for(auto& x:arr)
#define ITR(x,c) for(__typeof(c.begin()) x=c.begin();x!=c.end();x++)
#define ALL(a) (a.begin()),(a.end())
#define ZERO(a) memset(a,0,sizeof(a))
#define MINUS(a) memset(a,0xff,sizeof(a))
//-------------------------------------------------------

int N,Q;
int A[];
int L,R;
const int D=;
vector<int> V[];

void erase(int id,int v) {
    int b=id/D;
    int i;
    FOR(i,V[b].size()) if(V[b][i]==v) {
        V[b].erase(V[b].begin()+i);
        return;
    }
}

void add(int id,int v) {
    int b=id/D;
    int i;
    FOR(i,V[b].size()) if(v<V[b][i]) {
        V[b].insert(V[b].begin()+i,v);
        return;
    }
    V[b].push_back(v);
}

int getmore(int id,int v) {
    int i,j;
    int ret=;
    FOR(i,) {
        if(id<(i+)*D) {
            for(j=i*D;j<id;j++) if(A[j]>v && A[j]!=<<) ret++;
            break;
        }
        else {
            ret += V[i].end()-lower_bound(ALL(V[i]),v);
        }
    }
    return ret;
}
int getless(int id,int v) {
    int i,j;
    int ret=;
    FOR(i,) {
        if(id<(i+)*D) {
            for(j=i*D;j<id;j++) if(A[j]<v) ret++;
            break;
        }
        else {
            ret += lower_bound(ALL(V[i]),v)-V[i].begin();
        }
    }
    return ret;
}



void solve() {
    int i,j,k,l,r,x,y; string s;

    cin>>N>>Q;
    FOR(i,N) {
        A[i]=i+;
        add(i,A[i]);
    }

    ll ret=;

    while(Q--) {
        cin>>L>>R;
        L--,R--;
        if(L==R) {
            cout<<ret<<endl;
            continue;
        }
        if(L>R) swap(L,R);

        if(A[L]<A[R]) ret--;
        else ret++;

        ret-=getmore(R,A[R]);
        ret-=A[R]--getless(R,A[R]);
        ret-=getmore(L,A[L]);
        ret-=A[L]--getless(L,A[L]);
        erase(R,A[R]);
        erase(L,A[L]);

        swap(A[L],A[R]);
        add(R,A[R]);
        add(L,A[L]);
        ret+=getmore(R,A[R]);
        ret+=A[R]--getless(R,A[R]);
        ret+=getmore(L,A[L]);
        ret+=A[L]--getless(L,A[L]);

        cout<<ret<<endl;
    }
}


int main(int argc,char** argv){
    string s;int i;
    if(argc==) ios::sync_with_stdio(false), cin.tie();
    FOR(i,argc-) s+=argv[i+],s+='\n';
    FOR(i,s.size()) ungetc(s[s.size()--i],stdin);
    solve(); return ;
}