题目大意
给定一个长度为 $n$ 的数列 $a_1, a_2, \cdots, a_n$,构造带权完全图 $G$,边 $(i, j)$ 的边权是 $a_i \oplus a_j$,其中 $\oplus$ 表示异或运算。求图 $G$ 的最小生成树。
数据范围:$n \le 2 \times 10^5, a_i \le 2^{30}$。
思路分析
暴力建图是不可取的。我们考虑异或运算的性质。
假设当前所有 $a_i$ 的最高位为 $k$。我们将 $a_i$ 分成两组:这位是 $0$ 的和这位是 $1$ 的。如果两组中都有数,那么生成树上至少要在两组之间连一条边。
两组之间连的边的权值的最高位肯定为 $k$,所以我们要连尽可能少的边,也就是只连一条。我们对于某一组的所有数建立 Trie,然后对于另一组的每个数都去查它与第一组数的最小异或和,这样我们就找到了一条权值最小的边。
接着,我们递归求解两组内部的最小生成树,也就是递归到一个子问题。这样,我们就求出了整个图的最小生成树。
由于递归层数为 $\log a_i$,Trie 树的复杂度为 $O(\log a_i)$,所以总时间复杂度 $O(n \log^2 a_i)$,可以通过本题。
代码实现
1 |
|
2 | using namespace std; |
3 | |
4 | typedef long long ll; |
5 | const int maxn = 2e5, logv = 30, maxm = maxn * logv; |
6 | int n, m, a[maxn + 3], ch[maxm + 3][2], cur; |
7 | |
8 | void clear() { |
9 | for (int i = 1; i <= m; i++) { |
10 | ch[i][0] = ch[i][1] = 0; |
11 | } |
12 | m = 1; |
13 | } |
14 | |
15 | void insert(int d, int x) { |
16 | int u = 1; |
17 | for (int i = d, k; ~i; i--) { |
18 | k = x >> i & 1; |
19 | if (!ch[u][k]) { |
20 | ch[u][k] = ++m; |
21 | } |
22 | u = ch[u][k]; |
23 | } |
24 | } |
25 | |
26 | int query(int d, int x) { |
27 | int u = 1, res = 0; |
28 | for (int i = d, k; ~i; i--) { |
29 | k = x >> i & 1; |
30 | if (ch[u][k]) { |
31 | u = ch[u][k]; |
32 | } else { |
33 | u = ch[u][k ^ 1]; |
34 | res |= 1 << i; |
35 | } |
36 | } |
37 | return res; |
38 | } |
39 | |
40 | bool comp(int i, int j) { |
41 | return (i >> cur & 1) < (j >> cur & 1); |
42 | } |
43 | |
44 | ll solve(int d, int l, int r) { |
45 | if (d < 0 || l >= r) { |
46 | return 0; |
47 | } |
48 | int t = 0; |
49 | for (int i = l; i <= r; i++) { |
50 | if (a[i] >> d & 1) { |
51 | t |= 1; |
52 | } else { |
53 | t |= 2; |
54 | } |
55 | } |
56 | if (t < 3) { |
57 | return solve(d - 1, l, r); |
58 | } |
59 | clear(); |
60 | for (int i = l; i <= r; i++) { |
61 | if (a[i] >> d & 1) { |
62 | insert(d - 1, a[i] ^ (1 << d)); |
63 | } |
64 | } |
65 | int cnt = 0; |
66 | ll ans = 1 << logv; |
67 | for (int i = l; i <= r; i++) { |
68 | if (~a[i] >> d & 1) { |
69 | cnt++; |
70 | ans = min(ans, 1ll * query(d - 1, a[i])); |
71 | } |
72 | } |
73 | ans += 1 << d; |
74 | cur = d; |
75 | sort(a + l, a + r + 1, comp); |
76 | ans += solve(d - 1, l, l + cnt - 1); |
77 | ans += solve(d - 1, l + cnt, r); |
78 | return ans; |
79 | } |
80 | |
81 | int main() { |
82 | scanf("%d", &n); |
83 | for (int i = 1; i <= n; i++) { |
84 | scanf("%d", &a[i]); |
85 | } |
86 | printf("%lld\n", solve(logv - 1, 1, n)); |
87 | return 0; |
88 | } |