题目大意
「AHOI / HNOI 2017」礼物(Luogu 3723)
有两个长度为 $n$ 的手环,每个位置上的亮度分别为 $x_1, x_2, \cdots, x_n$ 和 $y_1, y_2, \cdots, y_n$,它们都是 $[1, m]$ 中的整数。你可以给某个手环整体加上 $c$(整数),并将它旋转(循环位移)$k$ 位,然后使得 $\sum_{i = 1}^{n} (x_i - y_i) ^ 2$ 最小。
数据范围:$n \le 5 \times 10^4, 1 \le m \le 100$。
思路分析
我们先假设手环已经旋转完成,我们考虑如何选择最好的 $c$。可以发现:
和 $c$ 相关的是一个二次函数,其中二次项和一次项都为定值,所以我们可以求出二次函数的最小值。常数项中前面一项也是定值,而后面一项会变化,于是我们只要找到后面一项的最小值即可。换而言之,我们要找到进行循环移位后最小的:
考虑对于每个 $k$ 算出上式。发现这是一个卷积的变形形式,我们只需将一个数组翻转,另一个数组倍长后做一次卷积即可。那么通过 FFT,可以将复杂度做到 $O(n \log n)$。
代码实现
1 |
|
2 |
|
3 |
|
4 | using namespace std; |
5 | |
6 | typedef double db; |
7 | const int maxn = 5e4, maxm = 1 << 18, inf = 1e9 + 1; |
8 | const db pi = acos(-1); |
9 | int n, m, k, x[maxn + 3], y[maxn + 3], rev[maxm + 3]; |
10 | |
11 | struct complex { |
12 | db r, i; |
13 | complex(db r = 0, db i = 0): r(r), i(i) {} |
14 | friend complex operator+ (const complex &a, const complex &b) { |
15 | return complex(a.r + b.r, a.i + b.i); |
16 | } |
17 | friend complex operator- (const complex &a, const complex &b) { |
18 | return complex(a.r - b.r, a.i - b.i); |
19 | } |
20 | friend complex operator* (const complex &a, const complex &b) { |
21 | return complex(a.r * b.r - a.i * b.i, a.r * b.i + a.i * b.r); |
22 | } |
23 | friend complex operator/ (const complex &a, const db &b) { |
24 | return complex(a.r / b, a.i / b); |
25 | } |
26 | }; |
27 | |
28 | complex a[maxm + 3], b[maxm + 3], c[maxm + 3]; |
29 | |
30 | void dft(complex a[], int n, int type) { |
31 | for (int i = 0; i < n; i++) if (i < rev[i]) { |
32 | swap(a[i], a[rev[i]]); |
33 | } |
34 | for (int k = 1; k < n; k *= 2) { |
35 | complex x = complex(cos(pi / k), type * sin(pi / k)); |
36 | for (int i = 0; i < n; i += k * 2) { |
37 | complex y = 1; |
38 | for (int j = i; j < i + k; j++, y = x * y) { |
39 | complex p = a[j], q = a[j + k] * y; |
40 | a[j] = p + q, a[j + k] = p - q; |
41 | } |
42 | } |
43 | } |
44 | if (type == -1) { |
45 | for (int i = 0; i < n; i++) { |
46 | a[i] = a[i] / n; |
47 | } |
48 | } |
49 | } |
50 | |
51 | int main() { |
52 | scanf("%d %d", &n, &m); |
53 | int p = n, q = 0, r = 0; |
54 | for (int i = 1; i <= n; i++) { |
55 | scanf("%d", &x[i]); |
56 | q += 2 * x[i], r += x[i] * x[i]; |
57 | } |
58 | for (int i = 1; i <= n; i++) { |
59 | scanf("%d", &y[i]); |
60 | q -= 2 * y[i], r += y[i] * y[i]; |
61 | } |
62 | int ans = inf; |
63 | for (int x = -m; x <= m; x++) { |
64 | ans = min(1ll * ans, 1ll * p * x * x + 1ll * q * x + r); |
65 | } |
66 | for (int i = 1; i <= n; i++) { |
67 | a[n - i] = x[i]; |
68 | } |
69 | for (int i = 1; i <= n; i++) { |
70 | b[i - 1] = b[n + i - 1] = y[i]; |
71 | } |
72 | for (k = 0; 1 << k <= 3 * n - 2; k++); |
73 | for (int i = 1; i < 1 << k; i++) { |
74 | rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (k - 1)); |
75 | } |
76 | dft(a, 1 << k, 1), dft(b, 1 << k, 1); |
77 | for (int i = 0; i < 1 << k; i++) { |
78 | c[i] = a[i] * b[i]; |
79 | } |
80 | dft(c, 1 << k, -1); |
81 | int res = -inf; |
82 | for (int i = n - 1; i < 2 * n - 1; i++) { |
83 | res = max(res, int(c[i].r + .5)); |
84 | } |
85 | ans -= 2 * res; |
86 | printf("%d\n", ans); |
87 | return 0; |
88 | } |