天天看点

P4755-Beautiful Pair【笛卡尔树,线段树】

正题

题目链接:https://www.luogu.com.cn/problem/P4755

题目大意

\(n\)个数字的一个序列,求有多少个点对\(i,j\)满足\(a_i\times a_j\leq max\{a_k\}(k\in[l,r])\)

解题思路

如果构建一棵笛卡尔树的话那么两个点之间的\(max\)就在笛卡尔树的\(LCA\)位置。

所以对于每个位置维护一个线段树,然后每次暴力枚举小的那棵子树在大子树的线段树中查询即可。然后线段树合并或者启发式合并上去就好了。

建笛卡尔树的时候用\(\text{RMQ}\)查询区间最大值然后递归下去就好了。

当然因为是乘法所以小的那个值域不会超过\(\sqrt{10^9}\)所以也可以树状数组+启发式合并。

这里写的是线段树的做法,时间复杂度都是\(O(n\log^2 n)\)

注意\(1\)要特判就好了

code

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int N=1e5+10,L=20;
int n,a[N],lg[N],f[N][L+1],inf;
long long ans;
struct Seg_Tree{
    int cnt,w[N<<6],ls[N<<6],rs[N<<6];
    void Change(int &x,int L,int R,int pos,int val){
        if(!x)x=++cnt;w[x]+=val;
        if(L==R)return;
        int mid=(L+R)>>1;
        if(pos<=mid)Change(ls[x],L,mid,pos,val);
        else Change(rs[x],mid+1,R,pos,val);
        return;
    }
    int Ask(int x,int L,int R,int l,int r){
        if(!x||l>r)return 0;
        if(L==l&&R==r)return w[x];
        int mid=(L+R)>>1;
        if(r<=mid)return Ask(ls[x],L,mid,l,r);
        if(l>mid)return Ask(rs[x],mid+1,R,l,r);
        return Ask(ls[x],L,mid,l,mid)+Ask(rs[x],mid+1,R,mid+1,r);
    }
    int Merge(int x,int y,int L,int R){
        if(!x||!y)return x+y;
        int mid=(L+R)>>1;w[x]+=w[y];
        if(L==R)return x;
        ls[x]=Merge(ls[x],ls[y],L,mid);
        rs[x]=Merge(rs[x],rs[y],mid+1,R);
        return x;
    }
}T;
int Ask(int l,int r){
    int z=lg[r-l+1];
    int x=f[l][z],y=f[r-(1<<z)+1][z];
    return (a[x]>=a[y])?x:y;
}
int solve(int l,int r){
    if(l>r)return 0;
    int x=Ask(l,r),ls,rs;
    ls=solve(l,x-1);
    rs=solve(x+1,r);
    if(ls)ans+=T.Ask(ls,1,inf,1,1);
    if(rs)ans+=T.Ask(rs,1,inf,1,1);
    if(x-l<=r-x){
        for(int i=l;i<x;i++)
            ans+=T.Ask(rs,1,inf,1,a[x]/a[i]);
    }
    else{
        for(int i=x+1;i<=r;i++)
            ans+=T.Ask(ls,1,inf,1,a[x]/a[i]);
    }
    ls=T.Merge(ls,rs,1,inf);
    T.Change(ls,1,inf,a[x],1);
    return ls;
}
int main()
{
    // printf("%d\n",sizeof(T)>>20);
    scanf("%d",&n);
    for(int i=1;i<=n;i++){
        scanf("%d",&a[i]);
        inf=max(inf,a[i]);
        ans+=(a[i]==1);
        f[i][0]=i;
    }
    inf=1e9;
    for(int i=2;i<=n;i++)lg[i]=lg[i>>1]+1;
    for(int j=1;(1<<j)<=n;j++)
        for(int i=1;i+(1<<j)-1<=n;i++){
            int x=f[i][j-1],y=f[i+(1<<j-1)][j-1];
            if(a[x]>=a[y])f[i][j]=x;
            else f[i][j]=y;
        }
    solve(1,n);
    printf("%lld",ans);
    return 0;
}