只要你跑的够快,锅就追不上你

0%

「PKUWC 2018」随机游走(容斥原理 + 高斯消元)

题目大意

「PKUWC 2018」随机游走(LOJ 2542)

给定一棵 $n$ 个结点的树,根为 $r$。有 $Q$ 个询问,每次给定一个点集,问从根开始随机游走,到达这个点集中的所有点的期望步数。

数据范围:$n \le 18, Q \le 5 \times 10^3$。

思路分析

这个 $Q \le 5 \times 10^3$ 好像没什么用啊 QwQ

考虑预处理出每个点集的答案,然后对于每个询问直接查表。

这里介绍一下 Min-Max 容斥:

这么算 $n$ 个数的最大值有什么好处呢?考虑 $x_i$ 是随机变量,并且某个子集的最小值的期望比较好计算的情况。由于期望的线性性,我们可以使用 Min-Max 容斥,分别计算每个子集最小值的期望即可得到它们最大值的期望。注意到对于任意的两个随机变量 $x, y$,都有 $E(x + y) = E(x) + E(y)$,不要求 $x, y$ 独立。

对于这个问题,我们设共有 $m$ 个点,到达第 $i$ 个点的时间为 $t_i$,我们要求的就是 $E(\max {t_i})$。考虑计算 $E(\min {S_i}) (S \subseteq { t_i })$,也就是第一次到达一个点集的期望时间。

对于一个子集中的点,如果它到根的路径上已经存在点了,那么它就没有用了。我们扔掉没用的点以后,树的所有叶子结点一定都在点集中了。我们列出随机游走的方程,使用高斯消元法求解即可。最后,将每个点集的答案乘上容斥系数,再使用一遍高维前缀和,就可以求出原问题的答案了。

共要枚举 $2^n$ 个点集,暴力高斯消元的复杂度是 $O(n^3)$ 的,总复杂度 $O(2^n n^3)$,似乎不可通过。考虑图的结构是一棵树,我们可以 “手动高斯消元”。设 $L$ 是叶子结点集合,第 $u$ 个结点的度数是 $d(u)$,答案是 $f(u)$,那么有:

发现 $f(u)$ 肯定能表示成 $a(u) \times f(\text{fa}(u)) + b(u)$ 的形式。这样,我们从下往上递推,就可以直接算出 $a(u), b(u)$。算到根结点时,因为它没有父亲,所以 $b(u)$ 就是答案。这样,时间复杂度降低到了 $O(2^n n)$,可以通过本题。

代码实现

1
#include <bits/stdc++.h>
2
using namespace std;
3
4
const int maxn = 18, maxm = 1 << maxn, mod = 998244353;
5
int n, Q, rt, cur, inv[maxn + 3], deg[maxn + 3], a[maxn + 3], b[maxn + 3], f[maxm + 3];
6
vector<int> G[maxn + 3];
7
8
int qpow(int a, int b) {
9
	int c = 1;
10
	for (; b; b >>= 1, a = 1ll * a * a % mod) {
11
		if (b & 1) c = 1ll * a * c % mod;
12
	}
13
	return c;
14
}
15
16
void add(int u, int v) {
17
	G[u].push_back(v), deg[u]++;
18
}
19
20
void dfs(int u, int pa = 0) {
21
	if (cur >> (u - 1) & 1) {
22
		a[u] = b[u] = 0;
23
		return;
24
	}
25
	a[u] = inv[deg[u]], b[u] = 1;
26
	int x = 1;
27
	for (int i = 0, v; i < G[u].size(); i++) {
28
		v = G[u][i];
29
		if (v == pa) continue;
30
		dfs(v, u);
31
		b[u] = (b[u] + 1ll * inv[deg[u]] * b[v]) % mod;
32
		x = (x + 1ll * (mod - inv[deg[u]]) * a[v]) % mod;
33
	}
34
	x = qpow(x, mod - 2);
35
	a[u] = 1ll * a[u] * x % mod, b[u] = 1ll * b[u] * x % mod;
36
}
37
38
int main() {
39
	scanf("%d %d %d", &n, &Q, &rt);
40
	for (int i = 1; i <= n; i++) {
41
		inv[i] = qpow(i, mod - 2);
42
	}
43
	for (int i = 1, u, v; i < n; i++) {
44
		scanf("%d %d", &u, &v);
45
		add(u, v), add(v, u);
46
	}
47
	for (int msk = 1; msk < 1 << n; msk++) {
48
		cur = msk;
49
		dfs(rt);
50
		int num = mod - 1;
51
		for (int i = 1; i <= n; i++) {
52
			if (msk >> (i - 1) & 1) {
53
				num = mod - num;
54
			}
55
		}
56
		f[msk] = 1ll * num * b[rt] % mod;
57
	}
58
	for (int i = 1; i <= n; i++) {
59
		for (int msk = 0; msk < 1 << n; msk++) {
60
			if (msk >> (i - 1) & 1) {
61
				f[msk] += f[msk ^ (1 << (i - 1))];
62
				f[msk] < mod ? 0 : f[msk] -= mod;
63
			}
64
		}
65
	}
66
	for (int k, x, msk; Q--; ) {
67
		scanf("%d", &k);
68
		msk = 0;
69
		while (k--) {
70
			scanf("%d", &x);
71
			msk |= 1 << (x - 1);
72
		}
73
		printf("%d\n", f[msk]);
74
	}
75
	return 0;
76
}