题目大意
「Codeforces 1096G」Lucky Tickets
给定偶数 $n$ 和 $k$ 个数位,求长度为 $n$ 的数字串,满足只使用给定的数位,且前 $\frac{n}{2}$ 位的和等于后 $\frac{n}{2}$ 位的和的个数 $\bmod 998244353$ 的结果。
数据范围:$n \le 2 \times 10^5$。
思路分析
记 $A_i$ 表示使用给定数位,$\frac{n}{2}$ 位的和为 $i$ 的方案数,那么答案为 $\sum_{i = 0}^{\infty}A_i^2$,于是我们只需求 $A$ 即可。
发现 $A$ 是由给定数位形成的生成函数的 $\frac{n}{2}$ 次方,因为生成函数的次数不超过 $9$,所以 $A$ 的最大次数不超过 $4.5n$,用 NTT 求解即可。 时间复杂度 $O(n \log n)$。
代码实现
1 |
|
2 |
|
3 | using namespace std; |
4 | |
5 | const int maxn = 1 << 20, mod = 998244353, g = 3; |
6 | int n, m, q, k, rev[maxn + 3], a[maxn + 3]; |
7 | |
8 | inline int func(int x, int y = mod) { |
9 | return x < 0 ? x + y : x < y ? x : x - y; |
10 | } |
11 | |
12 | int qpow(int a, int b) { |
13 | int c = 1; |
14 | for (; b; b >>= 1, a = 1ll * a * a % mod) { |
15 | if (b & 1) c = 1ll * a * c % mod; |
16 | } |
17 | return c; |
18 | } |
19 | |
20 | void dft(int a[], int n, int type) { |
21 | for (int i = 0; i < n; i++) if (i < rev[i]) { |
22 | swap(a[i], a[rev[i]]); |
23 | } |
24 | for (int k = 1; k < n; k *= 2) { |
25 | int x = qpow(g, func(type * (mod - 1) / (k * 2), mod - 1)); |
26 | for (int i = 0; i < n; i += 2 * k) { |
27 | int y = 1; |
28 | for (int j = i; j < i + k; j++, y = 1ll * x * y % mod) { |
29 | int p = a[j], q = 1ll * a[j + k] * y % mod; |
30 | a[j] = func(p + q), a[j + k] = func(p - q); |
31 | } |
32 | } |
33 | } |
34 | if (type == -1) { |
35 | int x = qpow(n, mod - 2); |
36 | for (int i = 0; i < n; i++) { |
37 | a[i] = 1ll * a[i] * x % mod; |
38 | } |
39 | } |
40 | } |
41 | |
42 | int main() { |
43 | scanf("%d %d", &n, &q); |
44 | n /= 2, m = n * 9; |
45 | for (int x; q--; ) { |
46 | scanf("%d", &x); |
47 | a[x] = 1; |
48 | } |
49 | for (k = 0; 1 << k < m; k++); |
50 | for (int i = 1; i < 1 << k; i++) { |
51 | rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (k - 1)); |
52 | } |
53 | dft(a, 1 << k, 1); |
54 | for (int i = 0; i < 1 << k; i++) { |
55 | a[i] = qpow(a[i], n); |
56 | } |
57 | dft(a, 1 << k, -1); |
58 | int ans = 0; |
59 | for (int i = 0; i <= m; i++) { |
60 | ans = (ans + 1ll * a[i] * a[i]) % mod; |
61 | } |
62 | printf("%d\n", ans); |
63 | return 0; |
64 | } |