只要你跑的够快,锅就追不上你

0%

「Codeforces 888G」Xor MST(分治 + 字典树)

题目大意

「Codeforces 888G」Xor MST

给定一个长度为 $n$ 的数列 $a_1, a_2, \cdots, a_n$,构造带权完全图 $G$,边 $(i, j)$ 的边权是 $a_i \oplus a_j$,其中 $\oplus$ 表示异或运算。求图 $G$ 的最小生成树。

数据范围:$n \le 2 \times 10^5, a_i \le 2^{30}$。

思路分析

暴力建图是不可取的。我们考虑异或运算的性质。

假设当前所有 $a_i$ 的最高位为 $k$。我们将 $a_i$ 分成两组:这位是 $0$ 的和这位是 $1$ 的。如果两组中都有数,那么生成树上至少要在两组之间连一条边。

两组之间连的边的权值的最高位肯定为 $k$,所以我们要连尽可能少的边,也就是只连一条。我们对于某一组的所有数建立 Trie,然后对于另一组的每个数都去查它与第一组数的最小异或和,这样我们就找到了一条权值最小的边。

接着,我们递归求解两组内部的最小生成树,也就是递归到一个子问题。这样,我们就求出了整个图的最小生成树。

由于递归层数为 $\log a_i$,Trie 树的复杂度为 $O(\log a_i)$,所以总时间复杂度 $O(n \log^2 a_i)$,可以通过本题。

代码实现

1
#include <bits/stdc++.h>
2
using namespace std;
3
4
typedef long long ll;
5
const int maxn = 2e5, logv = 30, maxm = maxn * logv;
6
int n, m, a[maxn + 3], ch[maxm + 3][2], cur;
7
8
void clear() {
9
	for (int i = 1; i <= m; i++) {
10
		ch[i][0] = ch[i][1] = 0;
11
	}
12
	m = 1;
13
}
14
15
void insert(int d, int x) {
16
	int u = 1;
17
	for (int i = d, k; ~i; i--) {
18
		k = x >> i & 1;
19
		if (!ch[u][k]) {
20
			ch[u][k] = ++m;
21
		}
22
		u = ch[u][k];
23
	}
24
}
25
26
int query(int d, int x) {
27
	int u = 1, res = 0;
28
	for (int i = d, k; ~i; i--) {
29
		k = x >> i & 1;
30
		if (ch[u][k]) {
31
			u = ch[u][k];
32
		} else {
33
			u = ch[u][k ^ 1];
34
			res |= 1 << i;
35
		}
36
	}
37
	return res;
38
}
39
40
bool comp(int i, int j) {
41
	return (i >> cur & 1) < (j >> cur & 1);
42
}
43
44
ll solve(int d, int l, int r) {
45
	if (d < 0 || l >= r) {
46
		return 0;
47
	}
48
	int t = 0;
49
	for (int i = l; i <= r; i++) {
50
		if (a[i] >> d & 1) {
51
			t |= 1;
52
		} else {
53
			t |= 2;
54
		}
55
	}
56
	if (t < 3) {
57
		return solve(d - 1, l, r);
58
	}
59
	clear();
60
	for (int i = l; i <= r; i++) {
61
		if (a[i] >> d & 1) {
62
			insert(d - 1, a[i] ^ (1 << d));
63
		}
64
	}
65
	int cnt = 0;
66
	ll ans = 1 << logv;
67
	for (int i = l; i <= r; i++) {
68
		if (~a[i] >> d & 1) {
69
			cnt++;
70
			ans = min(ans, 1ll * query(d - 1, a[i]));
71
		}
72
	}
73
	ans += 1 << d;
74
	cur = d;
75
	sort(a + l, a + r + 1, comp);
76
	ans += solve(d - 1, l, l + cnt - 1);
77
	ans += solve(d - 1, l + cnt, r);
78
	return ans;
79
}
80
81
int main() {
82
	scanf("%d", &n);
83
	for (int i = 1; i <= n; i++) {
84
		scanf("%d", &a[i]);
85
	}
86
	printf("%lld\n", solve(logv - 1, 1, n));
87
	return 0;
88
}