只要你跑的够快,锅就追不上你

0%

「AHOI / HNOI 2017」礼物(生成函数 + 多项式)

题目大意

「AHOI / HNOI 2017」礼物(Luogu 3723)

有两个长度为 $n$ 的手环,每个位置上的亮度分别为 $x_1, x_2, \cdots, x_n$ 和 $y_1, y_2, \cdots, y_n$,它们都是 $[1, m]$ 中的整数。你可以给某个手环整体加上 $c$(整数),并将它旋转(循环位移)$k$ 位,然后使得 $\sum_{i = 1}^{n} (x_i - y_i) ^ 2$ 最小。

数据范围:$n \le 5 \times 10^4, 1 \le m \le 100$。

思路分析

我们先假设手环已经旋转完成,我们考虑如何选择最好的 $c$。可以发现:

和 $c$ 相关的是一个二次函数,其中二次项和一次项都为定值,所以我们可以求出二次函数的最小值。常数项中前面一项也是定值,而后面一项会变化,于是我们只要找到后面一项的最小值即可。换而言之,我们要找到进行循环移位后最小的:

考虑对于每个 $k$ 算出上式。发现这是一个卷积的变形形式,我们只需将一个数组翻转,另一个数组倍长后做一次卷积即可。那么通过 FFT,可以将复杂度做到 $O(n \log n)$。

代码实现

1
#include <cmath>
2
#include <cstdio>
3
#include <algorithm>
4
using namespace std;
5
6
typedef double db;
7
const int maxn = 5e4, maxm = 1 << 18, inf = 1e9 + 1;
8
const db pi = acos(-1);
9
int n, m, k, x[maxn + 3], y[maxn + 3], rev[maxm + 3];
10
11
struct complex {
12
    db r, i;
13
    complex(db r = 0, db i = 0): r(r), i(i) {}
14
    friend complex operator+ (const complex &a, const complex &b) {
15
        return complex(a.r + b.r, a.i + b.i);
16
    }
17
    friend complex operator- (const complex &a, const complex &b) {
18
        return complex(a.r - b.r, a.i - b.i);
19
    }
20
    friend complex operator* (const complex &a, const complex &b) {
21
        return complex(a.r * b.r - a.i * b.i, a.r * b.i + a.i * b.r);
22
    }
23
    friend complex operator/ (const complex &a, const db &b) {
24
        return complex(a.r / b, a.i / b);
25
    }
26
};
27
28
complex a[maxm + 3], b[maxm + 3], c[maxm + 3];
29
30
void dft(complex a[], int n, int type) {
31
    for (int i = 0; i < n; i++) if (i < rev[i]) {
32
        swap(a[i], a[rev[i]]);
33
    }
34
    for (int k = 1; k < n; k *= 2) {
35
        complex x = complex(cos(pi / k), type * sin(pi / k));
36
        for (int i = 0; i < n; i += k * 2) {
37
            complex y = 1;
38
            for (int j = i; j < i + k; j++, y = x * y) {
39
                complex p = a[j], q = a[j + k] * y;
40
                a[j] = p + q, a[j + k] = p - q;
41
            }
42
        }
43
    }
44
    if (type == -1) {
45
        for (int i = 0; i < n; i++) {
46
            a[i] = a[i] / n;
47
        }
48
    }
49
}
50
51
int main() {
52
    scanf("%d %d", &n, &m);
53
    int p = n, q = 0, r = 0;
54
    for (int i = 1; i <= n; i++) {
55
        scanf("%d", &x[i]);
56
        q += 2 * x[i], r += x[i] * x[i];
57
    }
58
    for (int i = 1; i <= n; i++) {
59
        scanf("%d", &y[i]);
60
        q -= 2 * y[i], r += y[i] * y[i];
61
    }
62
    int ans = inf;
63
    for (int x = -m; x <= m; x++) {
64
        ans = min(1ll * ans, 1ll * p * x * x + 1ll * q * x + r);
65
    }
66
    for (int i = 1; i <= n; i++) {
67
        a[n - i] = x[i];
68
    }
69
    for (int i = 1; i <= n; i++) {
70
        b[i - 1] = b[n + i - 1] = y[i];
71
    }
72
    for (k = 0; 1 << k <= 3 * n - 2; k++);
73
    for (int i = 1; i < 1 << k; i++) {
74
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (k - 1));
75
    }
76
    dft(a, 1 << k, 1), dft(b, 1 << k, 1);
77
    for (int i = 0; i < 1 << k; i++) {
78
        c[i] = a[i] * b[i];
79
    }
80
    dft(c, 1 << k, -1);
81
    int res = -inf;
82
    for (int i = n - 1; i < 2 * n - 1; i++) {
83
        res = max(res, int(c[i].r + .5));
84
    }
85
    ans -= 2 * res;
86
    printf("%d\n", ans);
87
    return 0; 
88
}