题目大意
给定 $n$ 个数 $a_1, a_2, \cdots, a_n$,问可以选出多少种位置集合,满足集合内部任意两个位置对应的数不形成互素勾股数对。互素勾股数对的意义是:$(a, b) = 1$,且 $a^2 + b^2$ 可以表示成 $c^2$ 的形式。
数据范围:$n \le 10^6$,$1 \le a_i \le 2 \times 10^5$ 或 $2 \times 10^4 \le a_i \le 10^6$。
思路分析
互素勾股数对很少,我们考虑快速地找到他们。通过小学奥数可知,勾股数对一定可以写成如下形式:
如果 $x^2 - y^2$ 和 $2xy$ 互素,那么 $x, y$ 互素并且不全是奇数。考虑枚举 $x, y$,再对它们进行检验,具体代码如下:
1 | void prework() { |
2 | for (int x = 1; x * x <= m; x++) { |
3 | for (int y = x + 1; 2 * x * y <= m; y++) { |
4 | if (y * y > 2 * m) break; |
5 | if (y * y - x * x <= m && (x & 1) != (y & 1) && gcd(x, y) == 1) { |
6 | int a = y * y - x * x, b = 2 * x * y; |
7 | // (a, b) 是一组 <= m 的互素勾股数对 |
8 | } |
9 | } |
10 | } |
11 | } |
接下来根据预处理出的勾股数对建图,问题就转化成了求图的独立集个数。我们发现这个图的边数和点数很接近,于是我们可以搞出图的一个 DFS 树,然后对于非树边容斥,做树形 DP 即可。(其实就是「HNOI2018」毒瘤(Luogu 4426)一题的弱化版,不需要建虚树)
代码实现
1 |
|
2 |
|
3 | using namespace std; |
4 | |
5 | const int maxn = 1e6, mod = 1e9 + 7; |
6 | int n, m, bin[maxn + 3], cnt[maxn + 3], k, p[maxn + 3], q[maxn + 3], tm, lim[maxn + 3], f[maxn + 3][2]; |
7 | bool vis[maxn + 3]; |
8 | vector<int> G[maxn + 3], T[maxn + 3]; |
9 | |
10 | void add(int u, int v) { |
11 | G[u].push_back(v); |
12 | } |
13 | |
14 | int gcd(int a, int b) { |
15 | return b ? gcd(b, a % b) : a; |
16 | } |
17 | |
18 | void prework() { |
19 | bin[0] = 1; |
20 | for (int i = 1; i <= n; i++) { |
21 | bin[i] = bin[i - 1] * 2 % mod; |
22 | } |
23 | for (int x = 1; x * x <= m; x++) { |
24 | for (int y = x + 1; 2 * x * y <= m; y++) { |
25 | if (y * y > 2 * m) break; |
26 | if (y * y - x * x <= m && (x & 1) != (y & 1) && gcd(x, y) == 1) { |
27 | int a = y * y - x * x, b = 2 * x * y; |
28 | if (cnt[a] && cnt[b]) { |
29 | add(a, b), add(b, a); |
30 | } |
31 | } |
32 | } |
33 | } |
34 | } |
35 | |
36 | void dfs(int u, int pa = 0) { |
37 | vis[u] = true; |
38 | for (int i = 0, v; i < G[u].size(); i++) { |
39 | v = G[u][i]; |
40 | if (v == pa) continue; |
41 | if (!vis[v]) { |
42 | T[u].push_back(v); |
43 | dfs(v, u); |
44 | } else { |
45 | ++k; |
46 | p[k] = u; |
47 | q[k] = v; |
48 | } |
49 | } |
50 | } |
51 | |
52 | void dp(int u) { |
53 | f[u][0] = 1, f[u][1] = bin[cnt[u]] - 1; |
54 | if (lim[u] == tm) f[u][0] = 0; |
55 | for (int i = 0, v; i < T[u].size(); i++) { |
56 | dp(v = T[u][i]); |
57 | f[u][0] = 1ll * f[u][0] * (f[v][0] + f[v][1]) % mod; |
58 | f[u][1] = 1ll * f[u][1] * f[v][0] % mod; |
59 | } |
60 | } |
61 | |
62 | int bit_cnt(int x) { |
63 | return __builtin_popcount(x); |
64 | } |
65 | |
66 | int solve(int u) { |
67 | k = 0; |
68 | dfs(u); |
69 | int ans = 0; |
70 | for (int i = 0; i < 1 << k; i++) { |
71 | ++tm; |
72 | for (int j = 1; j <= k; j++) { |
73 | if (i >> (k - j) & 1) { |
74 | lim[p[j]] = lim[q[j]] = tm; |
75 | } |
76 | } |
77 | dp(u); |
78 | ans = (ans + 1ll * (f[u][0] + f[u][1]) * (bit_cnt(i) & 1 ? mod - 1 : 1)) % mod; |
79 | } |
80 | return ans; |
81 | } |
82 | |
83 | int main() { |
84 | scanf("%d", &n); |
85 | for (int i = 1, x; i <= n; i++) { |
86 | scanf("%d", &x); |
87 | m = max(m, x), cnt[x]++; |
88 | } |
89 | prework(); |
90 | int res = 1; |
91 | for (int i = 1; i <= m; i++) if (cnt[i] && !vis[i]) { |
92 | res = 1ll * res * solve(i) % mod; |
93 | } |
94 | printf("%d\n", (res + mod - 1) % mod); |
95 | return 0; |
96 | } |