只要你跑的够快,锅就追不上你

0%

「HNOI 2011」勾股定理(数论 + 动态规划)

题目大意

「HNOI 2011」勾股定理(Luogu 3213)

给定 $n$ 个数 $a_1, a_2, \cdots, a_n$,问可以选出多少种位置集合,满足集合内部任意两个位置对应的数不形成互素勾股数对。互素勾股数对的意义是:$(a, b) = 1$,且 $a^2 + b^2$ 可以表示成 $c^2$ 的形式。

数据范围:$n \le 10^6$,$1 \le a_i \le 2 \times 10^5$ 或 $2 \times 10^4 \le a_i \le 10^6$。

思路分析

互素勾股数对很少,我们考虑快速地找到他们。通过小学奥数可知,勾股数对一定可以写成如下形式:

如果 $x^2 - y^2$ 和 $2xy$ 互素,那么 $x, y$ 互素并且不全是奇数。考虑枚举 $x, y$,再对它们进行检验,具体代码如下:

1
void prework() {
2
	for (int x = 1; x * x <= m; x++) {
3
		for (int y = x + 1; 2 * x * y <= m; y++) {
4
			if (y * y > 2 * m) break;
5
			if (y * y - x * x <= m && (x & 1) != (y & 1) && gcd(x, y) == 1) {
6
				int a = y * y - x * x, b = 2 * x * y;
7
				// (a, b) 是一组 <= m 的互素勾股数对
8
			}
9
		}
10
	}
11
}

接下来根据预处理出的勾股数对建图,问题就转化成了求图的独立集个数。我们发现这个图的边数和点数很接近,于是我们可以搞出图的一个 DFS 树,然后对于非树边容斥,做树形 DP 即可。(其实就是「HNOI2018」毒瘤(Luogu 4426)一题的弱化版,不需要建虚树)

代码实现

1
#include <cstdio>
2
#include <vector>
3
using namespace std;
4
5
const int maxn = 1e6, mod = 1e9 + 7;
6
int n, m, bin[maxn + 3], cnt[maxn + 3], k, p[maxn + 3], q[maxn + 3], tm, lim[maxn + 3], f[maxn + 3][2];
7
bool vis[maxn + 3];
8
vector<int> G[maxn + 3], T[maxn + 3];
9
10
void add(int u, int v) {
11
	G[u].push_back(v);
12
}
13
14
int gcd(int a, int b) {
15
	return b ? gcd(b, a % b) : a;
16
}
17
18
void prework() {
19
	bin[0] = 1;
20
	for (int i = 1; i <= n; i++) {
21
		bin[i] = bin[i - 1] * 2 % mod;
22
	}
23
	for (int x = 1; x * x <= m; x++) {
24
		for (int y = x + 1; 2 * x * y <= m; y++) {
25
			if (y * y > 2 * m) break;
26
			if (y * y - x * x <= m && (x & 1) != (y & 1) && gcd(x, y) == 1) {
27
				int a = y * y - x * x, b = 2 * x * y;
28
				if (cnt[a] && cnt[b]) {
29
					add(a, b), add(b, a);
30
				}
31
			}
32
		}
33
	}
34
}
35
36
void dfs(int u, int pa = 0) {
37
	vis[u] = true;
38
	for (int i = 0, v; i < G[u].size(); i++) {
39
		v = G[u][i];
40
		if (v == pa) continue;
41
		if (!vis[v]) {
42
			T[u].push_back(v);
43
			dfs(v, u);
44
		} else {
45
			++k;
46
			p[k] = u;
47
			q[k] = v;
48
		}
49
	}
50
}
51
52
void dp(int u) {
53
	f[u][0] = 1, f[u][1] = bin[cnt[u]] - 1;
54
	if (lim[u] == tm) f[u][0] = 0;
55
	for (int i = 0, v; i < T[u].size(); i++) {
56
		dp(v = T[u][i]);
57
		f[u][0] = 1ll * f[u][0] * (f[v][0] + f[v][1]) % mod;
58
		f[u][1] = 1ll * f[u][1] * f[v][0] % mod;
59
	}
60
}
61
62
int bit_cnt(int x) {
63
	return __builtin_popcount(x);
64
}
65
66
int solve(int u) {
67
	k = 0;
68
	dfs(u);
69
	int ans = 0;
70
	for (int i = 0; i < 1 << k; i++) {
71
		++tm;
72
		for (int j = 1; j <= k; j++) {
73
			if (i >> (k - j) & 1) {
74
				lim[p[j]] = lim[q[j]] = tm;
75
			}
76
		}
77
		dp(u);
78
		ans = (ans + 1ll * (f[u][0] + f[u][1]) * (bit_cnt(i) & 1 ? mod - 1 : 1)) % mod;
79
	}
80
	return ans;
81
}
82
83
int main() {
84
	scanf("%d", &n);
85
	for (int i = 1, x; i <= n; i++) {
86
		scanf("%d", &x);
87
		m = max(m, x), cnt[x]++;
88
	}
89
	prework();
90
	int res = 1;
91
	for (int i = 1; i <= m; i++) if (cnt[i] && !vis[i]) {
92
		res = 1ll * res * solve(i) % mod;
93
	}
94
	printf("%d\n", (res + mod - 1) % mod);
95
	return 0;
96
}