题目大意
我们所可以自慰的,想来想去,也还是所谓对于将来的希望。
希望是附丽于存在的,有存在,便有希望,有希望,便是光明。
给定一棵 $n$ 个结点的树以及两个数 $L, k$。对于树上的一个连通快 $S$,定义 $R(S)$ 为在 $S$ 中且离 $S$ 中的每个点距离都不超过 $L$ 的点集。问任选 $k$ 个联通快 $S_1, S_2, \cdots, S_k$,使得 $R(S_1), R(S_2), \cdots, R(S_n)$ 有交的方案数 $\bmod 998244353$ 的结果。
数据范围:$n, L \le 10^6$。
思路分析
第一部分:初步转化
容易发现 $R(S)$ 也是一个联通块。联通块有如下的性质:
- 联通块和联通块的交还是联通块
- 非空联通块的点数等于边数 $+1$
于是答案就可以转化为:
其中 $f(i)$ 表示有多少个 $S$ 满足 $i \in R(S)$,$g(i)$ 表示有多少个 $S$ 满足 $i, \text{fa}(i) \in R(S)$。
第二部分:dp 建模
考虑 DP。设计如下状态:
- $\text{dp}(i, k)$ 满足在 $i$ 子树内有多少个包含 $i$ 的联通块,满足其中的所有点到 $i$ 的距离不超过 $k$
- $\text{up}(i, k)$ 表示在 $i$ 子树外(包含 $i$)有多少个包含 $i$ 的联通块,满足其中的所有点到 $i$ 的距离不超过 $k$。
发现 $f(i) = \text{dp}(i, L) \times \text{up}(i, L)$,$g(i) = \text{dp}(i, L - 1) \times (\text{up}(i, L) - 1)$。可以通过计算 $\text{dp}, \text{up}$ 得到 $f, g$。
列出状态转移方程:
- $\text{dp}(i, k) = \prod_{j \in \text{ch}(i)} \text{dp}(j, k - 1) + 1$
- $\text{up}(i, k) = \text{up}(\text{fa}(i, k - 1)) \times \prod_{j \in \text{ch}(\text{fa}(i)), j \neq i} \text{dp}(j, k - 2) + 1$
于是我们得到了 $O(nL)$ 的做法,期望得分 $36$ 分。
至此难度不大于 NOIP 提高组。
第三部分:长链剖分
最难部分。
这个 DP 是和深度有关的,容易想到长链剖分优化。
先考虑优化 DP 数组的转移。我们设当前结点为 $u$,当前儿子结点为 $v$。我们需要把短儿子依次合并到长儿子上面。
但是有一个问题:对于深度大于短儿子深度的部分,我们要将其整体乘上 $\text{dp}(v, d(v))$,其中 $d(i)$ 表示点 $i$ 往下延伸的最长长度。我们发现要乘上一个数的后缀长度很大,是 $O(d(u) - d(v))$ 级别的,但是长链剖分要求合并的复杂度为 $O(d(v))$ 才能够做到线性。
于是我们考虑维护整体乘标记。把后缀乘法操作转化为先整体乘一个数,再将对应的前缀乘这个数的逆元。我们还要维护整体加标记,因为每次做完一个结点后,我们要将它的 DP 值 $+1$。此外,为了方便转移,也要维护整体乘标记的逆元。
乍一看算法好像是正确的,其实不然。注意到 $\text{dp}(v, d(v))$ 可能不存在逆元,这样后缀乘法操作就变成了后缀变 $0$ 操作。所以还需维护一个后缀变 $x$ 标记,其中 $x$ 与乘法标记相乘再加上加法标记等于 $0$。
再考虑优化 $\text{up}$ 数组的转移。考虑先转移短儿子,再转移长儿子。
我们发现短儿子所用到的 $\text{up}$ 值的深度范围是 $d(v)$ 级别的,所以可以直接转移。
而长儿子可以打标记 + 将所有短儿子转移过来,复杂度为所有短儿子 $d$ 之和,也是均摊线性的。
现在主要的问题就是如何维护某个点的所有 “兄弟” 的 DP 值之积。发现它一定是一个前缀积和一个后缀积相乘。我们将 $u$ 的所有儿子倒过来考虑,前缀积可以通过在之前计算 $dp$ 值的时候将有更改的位置预先存起来,后缀积则可以边扫边更新。
此时离正解已经很接近了,可是求一个数逆元的复杂度还是带有 $\log$。所以时间复杂度为 $O(n \log n)$。
第四部分:离线求逆元
要做到严格线性,就要将求逆元的过程优化。
发现需要求的逆元只有 $h(i)$,其中 $h(i)$ 表示 $i$ 的子树内有多少个包含 $i$ 的联通块。这显然是可以预先线性地 DP 出来的。所以问题就变成了给定 $a_1, a_2, \cdots, a_n$,求 $a_1^{-1}, a_2^{-1}, \cdots, a_n^{-1} \mod 998244353$。
我们令 $A(i) = \prod_{j = 1}^{i} a_i$,可以先正着做一次前缀积求解 $A(i)$,然后通过 $A(n)$ 得出 $A(n)^{-1}$,最后反着做一次后缀积来求解 $A(i)^{-1}$。又有 $a_i = A(i)^{-1} \times A(i - 1)$,于是我们就可以线性求出逆元。
至此题目已经完美解决,时间复杂度 $O(n)$,期望得分 $100$ 分。
代码实现
1 |
|
2 |
|
3 |
|
4 |
|
5 | using namespace std; |
6 | typedef pair<int, int> pii; |
7 | |
8 | const int maxn = 1e6, maxm = 2 * maxn, mod = 998244353; |
9 | int n, L, k, d[maxn + 3], ch[maxn + 3], f[maxn + 3][2], g[maxn + 3]; |
10 | vector<int> G[maxn + 3]; |
11 | |
12 | int _pow(int a, int b) { |
13 | int c = 1; |
14 | for (; b; b >>= 1, a = 1ll * a * a % mod) { |
15 | if (b & 1) c = 1ll * a * c % mod; |
16 | } |
17 | return c; |
18 | } |
19 | |
20 | inline int func(const int &x) { |
21 | return x < 0 ? x + mod : x < mod ? x : x - mod; |
22 | } |
23 | |
24 | namespace prework { |
25 | int f[maxn + 3], cnt, id[maxn + 3], p[maxn + 3], ip[maxn + 3]; |
26 | void dfs(int u, int pa = 0) { |
27 | f[u] = 1; |
28 | for (int v: G[u]) if (v != pa) { |
29 | dfs(v, u); |
30 | f[u] = 1ll * f[u] * f[v] % mod; |
31 | if (d[v] + 1 > d[u]) { |
32 | d[u] = d[v] + 1, ch[u] = v; |
33 | } |
34 | } |
35 | f[u] = func(f[u] + 1); |
36 | if (f[u]) { |
37 | id[++cnt] = u; |
38 | } |
39 | } |
40 | void main() { |
41 | dfs(1); |
42 | p[0] = 1; |
43 | for (int i = 1; i <= cnt; i++) { |
44 | p[i] = 1ll * p[i - 1] * f[id[i]] % mod; |
45 | } |
46 | ip[cnt] = _pow(p[cnt], mod - 2); |
47 | for (int i = cnt; i; i--) { |
48 | ip[i - 1] = 1ll * ip[i] * f[id[i]] % mod; |
49 | } |
50 | for (int i = 1; i <= cnt; i++) { |
51 | f[id[i]] = 1ll * ip[i] * p[i - 1] % mod; |
52 | } |
53 | } |
54 | int query(int &x) { |
55 | return f[x]; |
56 | } |
57 | } |
58 | |
59 | struct node { |
60 | int a, ia, b, p, q; |
61 | vector<pii> P; |
62 | }; |
63 | |
64 | list<node> hist[maxn + 3]; |
65 | |
66 | namespace get_dp { |
67 | int arr[2 * maxm + 3], *cur = arr, *dp[maxm + 3], a[maxm + 3], ia[maxm + 3], b[maxm + 3], p[maxm + 3], q[maxm + 3]; |
68 | int query(int u, int d) { |
69 | return (1ll * a[u] * (d >= p[u] ? q[u] : dp[u][d]) + b[u]) % mod; |
70 | } |
71 | void update(int u, int v, int d, bool flag = true) { |
72 | int &ca = a[u], &cia = ia[u], &cb = b[u], &cp = p[u], &cq = q[u]; |
73 | node t; |
74 | t.a = ca, t.ia = cia, t.b = cb, t.p = cp, t.q = cq; |
75 | for (int i = 1; i <= d; i++) { |
76 | t.P.push_back(pii(i, dp[u][i])); |
77 | int val = query(v, i - 1); |
78 | if (i == cp) dp[u][cp++] = cq; |
79 | dp[u][i] = 1ll * cia * func(1ll * query(u, i) * val % mod - cb) % mod; |
80 | } |
81 | if (d < L) { |
82 | int val = query(v, d); |
83 | if (!val) { |
84 | cp = d + 1; |
85 | cq = func(mod - 1ll * cia * cb % mod); |
86 | } else { |
87 | t.P.push_back(pii(0, dp[u][0])); |
88 | int inv = prework::query(v); |
89 | for (int i = 0; i <= d; i++) { |
90 | dp[u][i] = 1ll * cia * func(1ll * query(u, i) * inv % mod - cb) % mod; |
91 | } |
92 | ca = 1ll * ca * val % mod, cb = 1ll * cb * val % mod; |
93 | cia = 1ll * cia * inv % mod; |
94 | } |
95 | } |
96 | if (flag) hist[u].push_back(t); |
97 | } |
98 | void dfs(int u, int pa = 0) { |
99 | int &ca = a[u], &cia = ia[u], &cb = b[u], &cp = p[u], &cq = q[u]; |
100 | ca = cia = 1, cp = L + 1; |
101 | if (ch[u]) { |
102 | dp[ch[u]] = dp[u] + 1; |
103 | dfs(ch[u], u); |
104 | ca = a[ch[u]], cia = ia[ch[u]], cb = b[ch[u]], cp = p[ch[u]], cq = q[ch[u]]; |
105 | } else { |
106 | cb = 2; |
107 | f[u][0] = f[u][1] = 1; |
108 | return; |
109 | } |
110 | dp[u][0] = 1ll * cia * (mod - cb + 1) % mod; |
111 | for (int v: G[u]) if (v != pa && v != ch[u]) { |
112 | dp[v] = cur, cur += d[v] + 1; |
113 | dfs(v, u); |
114 | update(u, v, min(d[v], L)); |
115 | } |
116 | cb = func(cb + 1); |
117 | f[u][0] = func(query(u, min(d[u], L)) - 1); |
118 | f[u][1] = func(query(u, min(d[u], L - 1)) - 1); |
119 | } |
120 | void main() { |
121 | dp[1] = cur, cur += d[1] + 1; |
122 | dfs(1); |
123 | } |
124 | void back(int u) { |
125 | node t = hist[u].back(); |
126 | hist[u].pop_back(); |
127 | a[u] = t.a, ia[u] = t.ia, b[u] = t.b, p[u] = t.p, q[u] = t.q; |
128 | for (pii p: t.P) dp[u][p.first] = p.second; |
129 | } |
130 | } |
131 | |
132 | using get_dp::dp; |
133 | |
134 | namespace get_up { |
135 | int arr[2 * maxn + 3], *cur = arr, *up[maxn + 3], a[maxn + 3], ia[maxn + 3], b[maxn + 3], p[maxn + 3], q[maxn + 3]; |
136 | int query(int u, int d) { |
137 | return (1ll * a[u] * (d >= p[u] ? q[u] : up[u][d]) + b[u]) % mod; |
138 | } |
139 | void dfs(int u, int pa = 0) { |
140 | if (d[u] >= L) { |
141 | up[u][0] = 1ll * ia[u] * (mod + 1 - b[u]) % mod; |
142 | } |
143 | g[u] = query(u, L); |
144 | if (!ch[u]) return; |
145 | int &ca = a[ch[u]], &cia = ia[ch[u]], &cb = b[ch[u]], &cp = p[ch[u]], &cq = q[ch[u]]; |
146 | ca = a[u], cia = ia[u], cb = b[u], cp = p[u], cq = q[u]; |
147 | int x = 1; |
148 | for (int v: G[u]) if (v != pa && v != ch[u]) { |
149 | x = max(x, d[v] + 1); |
150 | } |
151 | x = min(x, L); |
152 | int v = u + n; |
153 | dp[v] = get_dp::cur, get_dp::cur += x + 1; |
154 | get_dp::a[v] = get_dp::ia[v] = get_dp::b[v] = 1; |
155 | get_dp::p[v] = L + 1; |
156 | reverse(G[u].begin(), G[u].end()); |
157 | int ta = 1, tia = 1; |
158 | for (int t: G[u]) if (t != pa && t != ch[u]) { |
159 | get_dp::back(u); |
160 | int l = max(0, L - d[t]), r = L; |
161 | up[t] = cur + d[t] - l; |
162 | cur = up[t] + r + 1; |
163 | for (int i = max(l, 1); i <= r; i++) { |
164 | up[t][i] = 1ll * query(u, i - 1) * get_dp::query(u, min(i - 1, d[u])) % mod * get_dp::query(v, min(i - 1, x)) % mod; |
165 | } |
166 | a[t] = ia[t] = b[t] = 1, p[t] = L + 1; |
167 | get_dp::update(v, t, min(d[t], L), false); |
168 | if (x + 1 <= L) { |
169 | int val = get_dp::query(t, d[t]); |
170 | if (!val) { |
171 | cp = min(cp, d[t] + 1); |
172 | cq = func(mod - 1ll * cia * cb % mod); |
173 | } else { |
174 | ta = 1ll * ta * val % mod; |
175 | tia = 1ll * tia * prework::query(t) % mod; |
176 | } |
177 | } |
178 | dfs(t, u); |
179 | } |
180 | int t = ch[u]; |
181 | up[t] = up[u] - 1; |
182 | int l = max(0, L - d[t]), r = L; |
183 | cp = max(cp, l); |
184 | while (cp <= r && cp <= x + 1) { |
185 | up[t][cp++] = cq; |
186 | } |
187 | for (int i = max(l, 1); i <= r && i <= x; i++) { |
188 | up[t][i] = 1ll * cia * (1ll * query(u, i - 1) * get_dp::query(v, i - 1) % mod + mod - cb) % mod; |
189 | if (x + 1 <= r) { |
190 | up[t][i] = 1ll * cia * (1ll * query(t, i) * tia % mod + mod - cb) % mod; |
191 | } |
192 | } |
193 | if (x + 1 <= r) { |
194 | ca = 1ll * ca * ta % mod, cb = 1ll * cb * ta % mod; |
195 | cia = 1ll * cia * tia % mod; |
196 | } |
197 | cb = func(cb + 1); |
198 | dfs(t, u); |
199 | } |
200 | void main() { |
201 | a[1] = ia[1] = b[1] = 1, p[1] = L + 1; |
202 | up[1] = (cur += d[1]), cur += L + 1; |
203 | dfs(1); |
204 | } |
205 | } |
206 | |
207 | int main() { |
208 | scanf("%d %d %d", &n, &L, &k); |
209 | for (int i = 1, u, v; i < n; i++) { |
210 | scanf("%d %d", &u, &v); |
211 | G[u].push_back(v), G[v].push_back(u); |
212 | } |
213 | prework::main(); |
214 | get_dp::main(); |
215 | get_up::main(); |
216 | int ans = 0; |
217 | for (int i = 1; i <= n; i++) { |
218 | ans = (ans + _pow(1ll * f[i][0] * g[i] % mod, k)) % mod; |
219 | } |
220 | for (int i = 2; i <= n; i++) { |
221 | ans = func(ans - _pow(1ll * f[i][1] * func(g[i] - 1) % mod, k)); |
222 | } |
223 | printf("%d\n", ans); |
224 | return 0; |
225 | } |