题目大意
给定一棵 $n$ 个结点的树,根为 $r$。有 $Q$ 个询问,每次给定一个点集,问从根开始随机游走,到达这个点集中的所有点的期望步数。
数据范围:$n \le 18, Q \le 5 \times 10^3$。
思路分析
这个 $Q \le 5 \times 10^3$ 好像没什么用啊 QwQ
考虑预处理出每个点集的答案,然后对于每个询问直接查表。
这里介绍一下 Min-Max 容斥:
这么算 $n$ 个数的最大值有什么好处呢?考虑 $x_i$ 是随机变量,并且某个子集的最小值的期望比较好计算的情况。由于期望的线性性,我们可以使用 Min-Max 容斥,分别计算每个子集最小值的期望即可得到它们最大值的期望。注意到对于任意的两个随机变量 $x, y$,都有 $E(x + y) = E(x) + E(y)$,不要求 $x, y$ 独立。
对于这个问题,我们设共有 $m$ 个点,到达第 $i$ 个点的时间为 $t_i$,我们要求的就是 $E(\max {t_i})$。考虑计算 $E(\min {S_i}) (S \subseteq { t_i })$,也就是第一次到达一个点集的期望时间。
对于一个子集中的点,如果它到根的路径上已经存在点了,那么它就没有用了。我们扔掉没用的点以后,树的所有叶子结点一定都在点集中了。我们列出随机游走的方程,使用高斯消元法求解即可。最后,将每个点集的答案乘上容斥系数,再使用一遍高维前缀和,就可以求出原问题的答案了。
共要枚举 $2^n$ 个点集,暴力高斯消元的复杂度是 $O(n^3)$ 的,总复杂度 $O(2^n n^3)$,似乎不可通过。考虑图的结构是一棵树,我们可以 “手动高斯消元”。设 $L$ 是叶子结点集合,第 $u$ 个结点的度数是 $d(u)$,答案是 $f(u)$,那么有:
发现 $f(u)$ 肯定能表示成 $a(u) \times f(\text{fa}(u)) + b(u)$ 的形式。这样,我们从下往上递推,就可以直接算出 $a(u), b(u)$。算到根结点时,因为它没有父亲,所以 $b(u)$ 就是答案。这样,时间复杂度降低到了 $O(2^n n)$,可以通过本题。
代码实现
1 |
|
2 | using namespace std; |
3 | |
4 | const int maxn = 18, maxm = 1 << maxn, mod = 998244353; |
5 | int n, Q, rt, cur, inv[maxn + 3], deg[maxn + 3], a[maxn + 3], b[maxn + 3], f[maxm + 3]; |
6 | vector<int> G[maxn + 3]; |
7 | |
8 | int qpow(int a, int b) { |
9 | int c = 1; |
10 | for (; b; b >>= 1, a = 1ll * a * a % mod) { |
11 | if (b & 1) c = 1ll * a * c % mod; |
12 | } |
13 | return c; |
14 | } |
15 | |
16 | void add(int u, int v) { |
17 | G[u].push_back(v), deg[u]++; |
18 | } |
19 | |
20 | void dfs(int u, int pa = 0) { |
21 | if (cur >> (u - 1) & 1) { |
22 | a[u] = b[u] = 0; |
23 | return; |
24 | } |
25 | a[u] = inv[deg[u]], b[u] = 1; |
26 | int x = 1; |
27 | for (int i = 0, v; i < G[u].size(); i++) { |
28 | v = G[u][i]; |
29 | if (v == pa) continue; |
30 | dfs(v, u); |
31 | b[u] = (b[u] + 1ll * inv[deg[u]] * b[v]) % mod; |
32 | x = (x + 1ll * (mod - inv[deg[u]]) * a[v]) % mod; |
33 | } |
34 | x = qpow(x, mod - 2); |
35 | a[u] = 1ll * a[u] * x % mod, b[u] = 1ll * b[u] * x % mod; |
36 | } |
37 | |
38 | int main() { |
39 | scanf("%d %d %d", &n, &Q, &rt); |
40 | for (int i = 1; i <= n; i++) { |
41 | inv[i] = qpow(i, mod - 2); |
42 | } |
43 | for (int i = 1, u, v; i < n; i++) { |
44 | scanf("%d %d", &u, &v); |
45 | add(u, v), add(v, u); |
46 | } |
47 | for (int msk = 1; msk < 1 << n; msk++) { |
48 | cur = msk; |
49 | dfs(rt); |
50 | int num = mod - 1; |
51 | for (int i = 1; i <= n; i++) { |
52 | if (msk >> (i - 1) & 1) { |
53 | num = mod - num; |
54 | } |
55 | } |
56 | f[msk] = 1ll * num * b[rt] % mod; |
57 | } |
58 | for (int i = 1; i <= n; i++) { |
59 | for (int msk = 0; msk < 1 << n; msk++) { |
60 | if (msk >> (i - 1) & 1) { |
61 | f[msk] += f[msk ^ (1 << (i - 1))]; |
62 | f[msk] < mod ? 0 : f[msk] -= mod; |
63 | } |
64 | } |
65 | } |
66 | for (int k, x, msk; Q--; ) { |
67 | scanf("%d", &k); |
68 | msk = 0; |
69 | while (k--) { |
70 | scanf("%d", &x); |
71 | msk |= 1 << (x - 1); |
72 | } |
73 | printf("%d\n", f[msk]); |
74 | } |
75 | return 0; |
76 | } |