传送门

首先,遇到等差数列这种形式,最先要想到移项。

A[k]A[j]=A[j]A[i]A[k]+A[i]=2×A[j]A[k]-A[j]=A[j]-A[i] \to A[k]+A[i]=2 \times A[j]

于是很容易想到固定jj,而在jj两边枚举i,ki,k

注意到如果固定jj2×A[j]2 \times A[j]为常数,于是可以构造生成函数Fleft=i=0[iA[1]...A[j1]]xiF_{left}=\sum _{i=0}^\infty [i \in A[1]...A[j-1]] x^iFright=i=0[iA[j+1]...A[n]]xiF_{right}=\sum _{i=0}^\infty [i \in A[j+1]...A[n]] x^i

Fleft×FrightF_{left} \times F_{right}x2×A[j]x^{2 \times A[j]}的系数就是答案。

但是注意到如果这样每次都要两边构造生成函数,然后计算,是O(n2logn)O(n^2 \log n)的。

发现从jjj+1j+1的过程中,变化的FleftF_{left}FrightF_{right}并不多,每次重新FFTFFT似乎有些浪费,而且每次FFTFFT只能计算出一个数的贡献。

于是我们毒瘤地想到,我们要扩大Fleft,FrightF_{left},F_{right}每次变化的次数,比如每次让它变化n\sqrt{n}个数。

于是可以采用分块,对于i,j,ki,j,k中两个以上的数在同一个块的情况,暴力解决:

inline void Query1(int id){//k在i,j右边
	int lb=(id-1)*Size+1,rb=min(id*Size,n);
	memset(Right,0,sizeof(Right));
	Add(Right,rb+1,n);//注意去掉
	for (register int j=lb+1;j<=rb;++j){//枚举中间的j
		for (register int i=lb;i<j;++i){//枚举左边的i
			if (2*A[j]-A[i]>=0) ans+=Right[2*A[j]-A[i]];
		}
	}
}

Query1Query1计算的是i,ji,j在编号为idid的块,而kki,ji,j右边,而且不在i,ji,j所在的块的情况。

inline void Query2(int id){
	int lb=(id-1)*Size+1,rb=min(id*Size,n);
	memset(Left,0,sizeof(Left));
	Add(Left,1,rb-1);
	for (register int j=rb-1;j>=lb;--j){//枚举中间的j
		Left[A[j]]--;
		for (register int k=j+1;k<=rb;++k){//枚举右边的k
			if (2*A[j]-A[k]>=0) ans+=Left[2*A[j]-A[k]];
		}
	}
}

Query2Query2计算的是i,ji,j在编号为idid的块,而kki,ji,j左边的情况。

这样可以做到不重不漏。

剩下FFTFFT非常好写,只要把1lb11\to lb-1rb+1nrb+1 \to nA[i]A[i]丢进LeftLeftRightRight两个数组卷积即可。

时间复杂度分析:假设块大小为szsz,暴力时间复杂度O(sz×n)O(sz \times n)FFTFFT时间复杂度为O(nsz×nlogn)O( \frac{n}{sz} \times \sqrt {n \log n})

于是总时间复杂度为O(sz×n+nsz×nlogn)O(sz \times n + \frac{n}{sz} \times n\log n)

搞一下均值,sz+1sz×nlogn2×nlognsz + \frac{1}{sz} \times n\log n \le 2 \times \sqrt {n \log n}

于是sz=nlognsz=\sqrt{n \log n}时最优。

实测sz=2600sz=2600最优。

时间复杂度O(nnlogn)O(n \sqrt {n \log n})

#include <bits/stdc++.h>
#define MAXN 200005
using namespace std;
inline int read(){
	int x=0,f=1;
	char ch=getchar();
	while (ch<'0'||ch>'9'){
		x=(x<<1)+(x<<3)+(ch^'0');
		ch=getchar();
	}
	while (ch>='0'&&ch<='9'){
		x=(x<<1)+(x<<3)+(ch^'0');
		ch=getchar();
	}
	return x*f;
}
namespace FFT{
	const double PI=acos(-1.0);
	struct Complex{
    	double x,y;
	}a[MAXN],b[MAXN];
	inline Complex operator + (const Complex &A,const Complex &B){
	    return Complex{A.x+B.x,A.y+B.y};
	}
	inline Complex operator - (const Complex &A,const Complex &B){
	    return Complex{A.x-B.x,A.y-B.y};
	}
	inline Complex operator * (const Complex &A,const Complex &B){
	    return Complex{A.x*B.x-A.y*B.y,A.x*B.y+A.y*B.x};
	}
	int r[MAXN];
	inline void FFT(Complex *A,int n,int type){
	    for (register int i=0;i<n;++i) if (i<r[i]) swap(A[i],A[r[i]]);
	    for (register int i=1;i<n;i<<=1){
	        int R=i<<1;
	        Complex Wn=Complex{cos(2*PI/R),type*sin(2*PI/R)};
	        for (register int j=0;j<n;j+=R){
	            Complex w=Complex{1,0};
	            for (register int k=0;k<i;++k,w=w*Wn){
	                Complex x=A[j+k],y=w*A[i+j+k];
	                A[j+k]=x+y,A[i+j+k]=x-y;
	            }
	        }
	    }
	}
	int m,L;
	inline void Init(int len){
		m=1,L=0;
	    while (m<=2*len) m<<=1,L++;
	    memset(r,0,sizeof(r));
	    for (register int i=0;i<=m;++i){
	        r[i]=(r[i>>1]>>1|((i&1)<<(L-1)));
	    }
	}
	inline void Mul(int *des,int *A,int *B,int len){
		Init(len);
		for (register int i=0;i<=len;++i) a[i]=Complex{(double)A[i],0},b[i]=Complex{(double)B[i],0};
		for (register int i=len+1;i<m;++i) a[i]=Complex{0,0},b[i]=Complex{0,0};
		FFT(a,m,1),FFT(b,m,1);
		for (register int i=0;i<m;++i) a[i]=a[i]*b[i];
		FFT(a,m,-1);
		for (register int i=0;i<=len;++i) des[i]=(int)((double)a[i].x/m+0.5);
	}
}
using namespace FFT;
int A[MAXN],id[MAXN],Size,Max;
inline void Add(int *F,int l,int r){
	for (register int i=l;i<=r;++i){
		F[A[i]]++,Max=max(Max,A[i]);
	}
}
//A[k]-A[j]=A[j]-A[i]
//A[k]+A[i]=2*A[j]
//找到j
int Left[MAXN],Right[MAXN],res[MAXN];
long long ans;
int n;
inline void Query1(int id){//k在i,j右边 
	int lb=(id-1)*Size+1,rb=min(id*Size,n);
	memset(Right,0,sizeof(Right));
	Add(Right,rb+1,n);//注意去掉
	for (register int j=lb+1;j<=rb;++j){//枚举中间的j
		for (register int i=lb;i<j;++i){//枚举左边的i
			if (2*A[j]-A[i]>=0) ans+=Right[2*A[j]-A[i]];
		}
	}
}
inline void Query2(int id){
	int lb=(id-1)*Size+1,rb=min(id*Size,n);
	memset(Left,0,sizeof(Left));
	Add(Left,1,rb-1);
	for (register int j=rb-1;j>=lb;--j){//枚举中间的j
		Left[A[j]]--;
		for (register int k=j+1;k<=rb;++k){//枚举右边的k
			if (2*A[j]-A[k]>=0) ans+=Left[2*A[j]-A[k]];
		}
	}
}
int main(){
	n=read();
	for (register int i=1;i<=n;++i){
		A[i]=read();
	}
	Size=sqrt(n*log(n)/log(2));
	for (register int i=1;i<=n;++i){
		id[i]=(i-1)/Size+1;
	}
	for (register int i=1;i<=id[n];++i){//计算每个块中的 
		Query1(i),Query2(i);
	}
	int temp=ans;
	for (register int i=2;i<=id[n]-1;++i){
		int lb=(i-1)*Size+1,rb=min(i*Size,n);
		memset(Left,0,sizeof(Left)),memset(Right,0,sizeof(Right));
		Max=0;
		Add(Left,1,lb-1),Add(Right,rb+1,n);//两边的构造生成函数
		Mul(res,Left,Right,Max*2);
		for (register int j=lb;j<=rb;++j){
			ans+=res[A[j]*2];
		}
	}
	printf("%lld\n",ans);
}