只要你跑得够快,锅就追不上你

0%

「学习笔记」快速傅立叶变换

快速傅立叶变换(Fast Fourier Transform, FFT)可以在 $O(n \log n)$ 的时间内,将多项式在系数表示法和点值表示法(以单位复根为求值点)之间互相转换。

推荐阅读

Algocode 算法博客

代码实现

快速傅立叶变换(FFT):

1
#include <cmath>
2
#include <cstdio>
3
#include <algorithm>
4
using namespace std;
5
6
// Floating-point type used by the FFT.
using db = double;
// maxn is not referenced by this program (presumably the intended bound on
// n and m); maxm is the capacity of every work array below.
const int maxn = 100000;
const int maxm = 1 << 18;
const db pi = acos(-1.0);
// n, m: degrees of the two input polynomials; k: log2 of the transform length;
// rev: bit-reversal permutation; ans: rounded product coefficients.
int n, m, k, rev[maxm + 3], ans[maxm + 3];
10
11
struct complex {
12
	db r, i;
13
	complex(db r = 0, db i = 0): r(r), i(i) {}
14
	friend complex operator+ (const complex &a, const complex &b) {
15
		return complex(a.r + b.r, a.i + b.i);
16
	}
17
	friend complex operator- (const complex &a, const complex &b) {
18
		return complex(a.r - b.r, a.i - b.i);
19
	}
20
	friend complex operator* (const complex &a, const complex &b) {
21
		return complex(a.r * b.r - a.i * b.i, a.r * b.i + a.i * b.r);
22
	}
23
	friend complex operator/ (const complex &a, const db &b) {
24
		return complex(a.r / b, a.i / b);
25
	}
26
};
27
28
// Work buffers of maxm + 3 entries each, default-constructed to 0 + 0i.
// a and b receive the two input polynomials and are transformed in place;
// c receives their pointwise product before the inverse transform.
complex a[maxm + 3];
complex b[maxm + 3];
complex c[maxm + 3];
29
30
// In-place iterative radix-2 FFT over a[0..n).
// type = 1 computes the forward transform; type = -1 the inverse transform,
// including the final division by n.
// n must be a power of two and the global rev[] must already hold the
// bit-reversal permutation for n (main builds it for n = 1 << k).
void dft(complex a[], int n, int type) {
	// Put the input in bit-reversed order so every butterfly works in place.
	for (int idx = 0; idx < n; ++idx) {
		if (idx < rev[idx]) swap(a[idx], a[rev[idx]]);
	}
	// half = size of the sub-transforms being merged at this level.
	for (int half = 1; half < n; half <<= 1) {
		// Principal (2 * half)-th root of unity, conjugated when type = -1.
		const complex step(cos(pi / half), type * sin(pi / half));
		for (int base = 0; base < n; base += half << 1) {
			complex w(1);
			for (int off = 0; off < half; ++off) {
				const complex lo = a[base + off];
				const complex hi = a[base + off + half] * w;
				a[base + off] = lo + hi;
				a[base + off + half] = lo - hi;
				// Advance to the next twiddle factor by repeated multiplication.
				w = step * w;
			}
		}
	}
	if (type == -1) {
		for (int idx = 0; idx < n; ++idx) a[idx] = a[idx] / n;
	}
}
50
51
int main() {
52
	scanf("%d %d", &n, &m);
53
	for (int i = 0, x; i <= n; i++) {
54
		scanf("%d", &x), a[i].r = x;
55
	}
56
	for (int i = 0, x; i <= m; i++) {
57
		scanf("%d", &x), b[i].r = x;
58
	}
59
	for (k = 0; 1 << k <= n + m; k++);
60
	for (int i = 1; i < 1 << k; i++) {
61
		rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (k - 1));
62
	}
63
	dft(a, 1 << k, 1), dft(b, 1 << k, 1);
64
	for (int i = 0; i < 1 << k; i++) {
65
		c[i] = a[i] * b[i];
66
	}
67
	dft(c, 1 << k, -1);
68
	for (int i = 0; i <= n + m; i++) {
69
		ans[i] = c[i].r + .5;
70
	}
71
	for (int i = 0; i <= n + m; i++) {
72
		printf("%d%c", ans[i], " \n"[i == n + m]);
73
	}
74
	return 0;
75
}

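快速数论变换(Number Theoretic Transform, NTT)用模 $998244353$ 意义下原根 $g = 3$ 的幂代替单位复根,从而避免浮点误差;注意程序输出的是乘积各项系数对 $998244353$ 取模后的结果: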

1
#include <cstdio>
2
#include <algorithm>
3
using namespace std;
4
5
// maxn is not referenced by this program (presumably the intended bound on
// n and m); maxm is the capacity of every work array below.
const int maxn = 100000;
const int maxm = 1 << 18;
// NTT modulus 998244353 = 119 * 2^23 + 1, and g is the generator whose powers
// dft uses as roots of unity.
const int mod = 998244353;
const int g = 3;
// n, m: input degrees; k: log2 of the transform length; rev: bit-reversal
// permutation; a, b: input coefficient buffers; c: product buffer.
int n, m, k, rev[maxm + 3], a[maxm + 3], b[maxm + 3], c[maxm + 3];
7
8
// Reduces x into [0, y) with at most one correction, which is exact when
// -y <= x < 2y (e.g. a sum or difference of two residues).
// y defaults to the NTT modulus.
inline int func(int x, int y = mod) {
	if (x < 0) {
		return x + y;
	}
	if (x >= y) {
		return x - y;
	}
	return x;
}
11
12
// Returns a^b modulo `mod` by square-and-multiply.
// Expects 0 <= a < mod and b >= 0: the loop only stops once b reaches 0.
int qpow(int a, int b) {
	int result = 1;
	while (b != 0) {
		if (b & 1) {
			result = 1ll * a * result % mod;
		}
		a = 1ll * a * a % mod;
		b >>= 1;
	}
	return result;
}
19
20
void dft(int a[], int n, int type) {
21
	for (int i = 0; i < n; i++) if (i < rev[i]) {
22
		swap(a[i], a[rev[i]]);
23
	}
24
	for (int k = 1; k < n; k *= 2) {
25
		int x = qpow(g, func(type * (mod - 1) / (k * 2), mod - 1));
26
		for (int i = 0; i < n; i += 2 * k) {
27
			int y = 1;
28
			for (int j = i; j < i + k; j++, y = 1ll * x * y % mod) {
29
				int p = a[j], q = 1ll * a[j + k] * y % mod;
30
				a[j] = func(p + q), a[j + k] = func(p - q);
31
			}
32
		}
33
	}
34
	if (type == -1) {
35
		int x = qpow(n, mod - 2);
36
		for (int i = 0; i < n; i++) {
37
			a[i] = 1ll * a[i] * x % mod;
38
		}
39
	}
40
}
41
42
int main() {
43
	scanf("%d %d", &n, &m);
44
	for (int i = 0; i <= n; i++) {
45
		scanf("%d", &a[i]);
46
	}
47
	for (int i = 0; i <= m; i++) {
48
		scanf("%d", &b[i]);
49
	}
50
	for (k = 0; 1 << k <= n + m; k++);
51
	for (int i = 1; i < 1 << k; i++) {
52
		rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (k - 1));
53
	}
54
	dft(a, 1 << k, 1), dft(b, 1 << k, 1);
55
	for (int i = 0; i < 1 << k; i++) {
56
		c[i] = 1ll * a[i] * b[i] % mod;
57
	}
58
	dft(c, 1 << k, -1);
59
	for (int i = 0; i <= n + m; i++) {
60
		printf("%d%c", c[i], " \n"[i == n + m]);
61
	}
62
	return 0;
63
}