你也许听过
傅里叶变换,它可以把信号在时域和频域互相转化。打个比方,如果说时域是一首歌曲的波形,那么频域就是乐谱。最后进入我们耳朵的是复杂的声波,但它是由一系列特定频率的简单波形按一定规律组合得到的。直接在时域上对波形处理可能是比较困难的,但经过傅里叶变换,我们可以把它转化到较好处理的频域上,处理后再通过相应的
逆变换转化回去。
维基百科上介绍傅里叶变换的经典gif
离散傅里叶变换(Discrete Fourier Transform,
DFT)是傅里叶变换在时域和频域上都呈离散的形式。它在很多领域都有各种不同的应用,但在算法竞赛上,主要是用来解决
多项式乘法(卷积)等问题。
我们往往把
次多项式写成
,它由它的
个系数
完全确定,所以这也叫
系数表达法。这可以看作这个多项式的 频域,多项式就是由若干简单的幂函数线性组合而成。
但还有另一种表达方法,即
点值表达法。在多项式上取
个不同的点(这相当于 时域上的 采样):
,这些点也可以唯一地确定多项式。
点值表达法有一个好处:假如
是另一个多项式
的点值表达法,那么设
,则立得
的点值表达(就是分别相乘):
。
但是,题目可不会好心给我们点值表达法,所以我们就需要把系数表达法转化为点值表达法。是的,我前面已经暗示了,这里可以用上
离散傅里叶变换了。现在是频域不好处理,我们就把它转化到时域。取一组特殊点
,并设
(
是虚数单位,这里把多项式当作 复多项式处理,这些特殊点实际上就是单位根),则:
(或:
)
这里的
是多项式的系数。对应地,有逆变换:
(或:
)
(实际上,这个“逆变换”才是信号领域的DFT,我们平时说的DFT其实是DFT的逆变换,不信你随便看一个跟算法竞赛无关的介绍DFT的博客。当然系数也可能有点差别,但是这个系数其实不重要,反正都是要转化回来的,只要保证两个系数相乘等于
就好了。 )
别急着抄公式,这个公式看起来很高端,其实就是带进去变了下形,和随便选
个点没什么区别,时间复杂度仍然是
。但是,我们可以利用单位根的一些性质,将它的复杂度减小到
。现在,我们请出这篇笔记的主角:
快速傅里叶变换。
快速傅里叶变换(Fast Fourier Transform,
FFT)利用分治思想简化DFT的计算,时间复杂度是
。
我们来玩一下这个公式:
。不妨设
为偶数,把
奇偶项分开:
如果我们把原来的系数序列的奇偶项分别看作一个
新的系数序列,即令
,
,我们也可以
分别对它们进行离散傅里叶变换,分别设
、
,注意到
、
分别就是
和
,所以我们得到一个公式:
但是这个式子只在
时才成立(因为
、
长度均为
,只有在
时,
、
才能分别用
和
代替),如何确定剩下得
个系数呢?其实用
代替
分别代入
和
会发现,
与
是相等的(注意欧拉恒等式
),
也是如此。而
,所以:
就是单位根
,所以也可以写成:
所以只需要求出
和
,就可以在
内求出
,我们把 问题规模缩小了一半!显然
和
也可以用同样的方法求下去,这样一直
递归下去,就可以用
的时间完成DFT!
当然能一直递归下去的条件是
是
的整次幂。但这不成问题,若非如此,直接补成
的整次幂即可,前导零不影响计算。对于
次多项式,一般可以用这样的代码:
int N = 1 << int(log2(n) + 1);
但这里还有一个小坑,如果有0次多项式,会出现
log2(0)
这样的表达式,注意避免……
现在我们来一步一步实现这个算法。注意我们这里的计算大部分是在
复数域上进行的,我们可以用
<complex>
中提供的复数模板类:
typedef complex<double> comp; // 重命名一下以免麻烦
定义一些常量:
const comp I(0, 1);
const double PI = acos(-1); // 环境允许也可以使用M_PI
注意到DFT与其逆变换之间之间只有很小的区别,所以我们可以只用一个函数实现两者的功能:
// 把F从频域变换到时域,或从时域变换到频域
void fft(comp F[], int N, int inv = 1); // inv = 1表示DFT,-1表示其逆变换
递归出口是N=1,这时频域和时域相同(都只有一个数):
if (N == 1)
return;
先把偶数项放到F的前半部分,奇数项放到后半部分:
memcpy(tmp, F, sizeof(comp) * N); // 先复制一个临时数组
for (int i = 0; i < N; i++)
*(i % 2 ? F + i / 2 + N / 2 : F + i / 2) = tmp[i]; // 这里写得有点花哨,也可以展开成if-else
递归地对前半部分和后半部分进行变换:
fft(F, N / 2, inv), fft(F + N / 2, N / 2, inv);
然后按刚刚推导的公式进行计算即可:
comp *G = F, *H = F + N / 2; // 用指针比较直观
comp cur = 1, step = exp(2 * PI / N * inv * I); // 单位根可以递推计算,逆变换要多乘一个-1
for (int k = 0; k < N / 2; k++)
{
tmp[k] = G[k] + cur * H[k];
tmp[k + N / 2] = G[k] - cur * H[k];
cur *= step;
}
memcpy(F, tmp, sizeof(comp) * N);
// 这里逆变换没有除以N,因为复数除法很慢,可以等要用到时先用real()求实部再除以N
这是朴素的FFT算法,它可以通过模板题(多项式乘法),代码如下:
#include <bits/stdc++.h>
using namespace std;
inline int read()
{
int ans = 0;
char c = getchar();
while (!isdigit(c))
c = getchar();
while (isdigit(c))
{
ans = ans * 10 + c - '0';
c = getchar();
}
return ans;
}
typedef complex<double> comp;
const int MAXN = 1000005;
const comp I(0, 1);
const double PI = acos(-1);
comp A[MAXN * 3], B[MAXN * 3], tmp[MAXN * 3], ans[MAXN * 3]; // 数组开大一些
void fft(comp F[], int N, int inv = 1)
{
if (N == 1)
return;
memcpy(tmp, F, sizeof(comp) * N);
for (int i = 0; i < N; i++)
*(i % 2 ? F + i / 2 + N / 2 : F + i / 2) = tmp[i];
fft(F, N / 2, inv), fft(F + N / 2, N / 2, inv);
comp *G = F, *H = F + N / 2;
comp cur = 1, step = exp(2 * PI / N * inv * I);
for (int k = 0; k < N / 2; k++)
{
tmp[k] = G[k] + cur * H[k];
tmp[k + N / 2] = G[k] - cur * H[k];
cur *= step;
}
memcpy(F, tmp, sizeof(comp) * N);
}
int main()
{
int n = read(), m = read(), N = 1 << int(log2(n + m) + 1); // 补成2的整次幂
for (int i = 0; i <= n; ++i)
A[i] = read();
for (int i = 0; i <= m; ++i)
B[i] = read();
fft(A, N), fft(B, N);
for (int i = 0; i < N; ++i)
ans[i] = A[i] * B[i];
fft(ans, N, -1);
for (int i = 0; i <= n + m; ++i)
printf("%d ", int(ans[i].real() / N + 0.1)); // +0.1规避浮点误差
return 0;
}
这个算法把数组复制过来复制过去,不免还是有点慢,我们可以对其进行
常数优化。
为了避免反复的拷贝,我们可以提前确定序列中每个元素最后的位置。把下标用二进制表示,模拟递归过程如下:
把偶数项放左边,奇数项放右边:
再对左右子序列按奇偶排序:
观察一下,是否每个元素最后的下标,都是原来的二进制表示的
翻转?这被称为
蝴蝶变换,我们通过蝴蝶变换,就可以提前让序列就位,然后通过迭代而不是递归地进行求解,防止拷贝。
蝴蝶变换可以通过递推
地实现:
int bit = log2(N) - 1;
for (int i = 0; i < N; ++i)
{
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << bit);
if (i < rev[i]) // 这里要防止防止重复交换
swap(a[i], a[rev[i]]);
}
(rev[i >> 1] >> 1) | ((i & 1) << limit)
这段代码很有意思,我们把一个二进制数分为两段:head+tail,其中tail是最后一位,head是前面的位,我们希望得到tail+rev(head)。
那么i >> 1就是0+head,rev[i >> 1]就是rev(head)+0,rev[i>>1]>>1就是rev(head),这是前半部分。而后半部分则是取出最后一位,然后左移到rev(head)前面。
接下来我们再考虑合并的问题,再来看这组公式:
由于每个数的位置都已经确定了,这就完全可以用迭代的方式进行,具体看代码:
// 非递归版FFT,确保N是2的整次幂
int rev[MAXN * 3];
const comp I(0, 1);
const double PI = acos(-1);
void fft(comp F[], int N, int inv = 1)
{
int bit = log2(N);
for (int i = 0; i < N; ++i) // 蝴蝶变换
{
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));
if (i < rev[i])
swap(F[i], F[rev[i]]);
}
for (int l = 1; l < N; l *= 2) // 枚举合并前的区间长度
{
comp step = exp(inv * PI / l * I);
for (int i = 0; i < N; i += l * 2) // 遍历起始点
{
comp cur(1, 0);
for (int k = i; k < i + l; ++k)
{
comp g = F[k], h = F[k + l] * cur;
F[k] = g + h, F[k + l] = g - h;
cur *= step;
}
}
}
}
// 逆变换记得在外部把实部除以N
这个方法比朴素的递归FFT快了近一倍,也是最常用的版本。
顺带一提,对于多项式乘法,还有一个“
三步变两步”优化:设
和
是 实多项式,
,则
,注意到我们要求的
正是
虚部的一半。这样只需要两次FFT就可以求出结果。
for (int i = 0; i <= n; ++i)
A[i] = read();
for (int i = 0; i <= m; ++i)
B[i] = read();
for (int i = 0; i <= max(n, m); ++i)
F[i] = comp(A[i], B[i]);
fft(F, N);
for (int i = 0; i < N; ++i)
F[i] = F[i] * F[i];
fft(F, N, -1);
for (int i = 0; i <= n + m; ++i)
printf("%d ", int(F[i].imag() / (N * 2) + 0.1));
https://zhuanlan.zhihu.com/p/105467597zhuanlan.zhihu.com