只要你跑的够快,锅就追不上你

0%

「十二省联考 2019」希望(组合计数 + 动态规划 + 长链剖分)

题目大意

我们所可以自慰的,想来想去,也还是所谓对于将来的希望。
希望是附丽于存在的,有存在,便有希望,有希望,便是光明。

「十二省联考 2019」希望(LOJ 3053)

给定一棵 $n$ 个结点的树以及两个数 $L, k$。对于树上的一个连通快 $S$,定义 $R(S)$ 为在 $S$ 中且离 $S$ 中的每个点距离都不超过 $L$ 的点集。问任选 $k$ 个联通快 $S_1, S_2, \cdots, S_k$,使得 $R(S_1), R(S_2), \cdots, R(S_n)$ 有交的方案数 $\bmod 998244353$ 的结果。

数据范围:$n, L \le 10^6$。

思路分析

第一部分:初步转化

容易发现 $R(S)$ 也是一个联通块。联通块有如下的性质:

  • 联通块和联通块的交还是联通块
  • 非空联通块的点数等于边数 $+1$

于是答案就可以转化为:

其中 $f(i)$ 表示有多少个 $S$ 满足 $i \in R(S)$,$g(i)$ 表示有多少个 $S$ 满足 $i, \text{fa}(i) \in R(S)$。

第二部分:dp 建模

考虑 DP。设计如下状态:

  • $\text{dp}(i, k)$ 满足在 $i$ 子树内有多少个包含 $i$ 的联通块,满足其中的所有点到 $i$ 的距离不超过 $k$
  • $\text{up}(i, k)$ 表示在 $i$ 子树外(包含 $i$)有多少个包含 $i$ 的联通块,满足其中的所有点到 $i$ 的距离不超过 $k$。

发现 $f(i) = \text{dp}(i, L) \times \text{up}(i, L)$,$g(i) = \text{dp}(i, L - 1) \times (\text{up}(i, L) - 1)$。可以通过计算 $\text{dp}, \text{up}$ 得到 $f, g$。

列出状态转移方程:

  • $\text{dp}(i, k) = \prod_{j \in \text{ch}(i)} \text{dp}(j, k - 1) + 1$
  • $\text{up}(i, k) = \text{up}(\text{fa}(i, k - 1)) \times \prod_{j \in \text{ch}(\text{fa}(i)), j \neq i} \text{dp}(j, k - 2) + 1$

于是我们得到了 $O(nL)$ 的做法,期望得分 $36$ 分。

至此难度不大于 NOIP 提高组。

第三部分:长链剖分

最难部分。

这个 DP 是和深度有关的,容易想到长链剖分优化。

先考虑优化 DP 数组的转移。我们设当前结点为 $u$,当前儿子结点为 $v$。我们需要把短儿子依次合并到长儿子上面。

但是有一个问题:对于深度大于短儿子深度的部分,我们要将其整体乘上 $\text{dp}(v, d(v))$,其中 $d(i)$ 表示点 $i$ 往下延伸的最长长度。我们发现要乘上一个数的后缀长度很大,是 $O(d(u) - d(v))$ 级别的,但是长链剖分要求合并的复杂度为 $O(d(v))$ 才能够做到线性。

于是我们考虑维护整体乘标记。把后缀乘法操作转化为先整体乘一个数,再将对应的前缀乘这个数的逆元。我们还要维护整体加标记,因为每次做完一个结点后,我们要将它的 DP 值 $+1$。此外,为了方便转移,也要维护整体乘标记的逆元。

乍一看算法好像是正确的,其实不然。注意到 $\text{dp}(v, d(v))$ 可能不存在逆元,这样后缀乘法操作就变成了后缀变 $0$ 操作。所以还需维护一个后缀变 $x$ 标记,其中 $x$ 与乘法标记相乘再加上加法标记等于 $0$。

再考虑优化 $\text{up}$ 数组的转移。考虑先转移短儿子,再转移长儿子。

我们发现短儿子所用到的 $\text{up}$ 值的深度范围是 $d(v)$ 级别的,所以可以直接转移。

而长儿子可以打标记 + 将所有短儿子转移过来,复杂度为所有短儿子 $d$ 之和,也是均摊线性的。

现在主要的问题就是如何维护某个点的所有 “兄弟” 的 DP 值之积。发现它一定是一个前缀积和一个后缀积相乘。我们将 $u$ 的所有儿子倒过来考虑,前缀积可以通过在之前计算 $dp$ 值的时候将有更改的位置预先存起来,后缀积则可以边扫边更新。

此时离正解已经很接近了,可是求一个数逆元的复杂度还是带有 $\log$。所以时间复杂度为 $O(n \log n)$。

第四部分:离线求逆元

要做到严格线性,就要将求逆元的过程优化。

发现需要求的逆元只有 $h(i)$,其中 $h(i)$ 表示 $i$ 的子树内有多少个包含 $i$ 的联通块。这显然是可以预先线性地 DP 出来的。所以问题就变成了给定 $a_1, a_2, \cdots, a_n$,求 $a_1^{-1}, a_2^{-1}, \cdots, a_n^{-1} \mod 998244353$。

我们令 $A(i) = \prod_{j = 1}^{i} a_i$,可以先正着做一次前缀积求解 $A(i)$,然后通过 $A(n)$ 得出 $A(n)^{-1}$,最后反着做一次后缀积来求解 $A(i)^{-1}$。又有 $a_i = A(i)^{-1} \times A(i - 1)$,于是我们就可以线性求出逆元。

至此题目已经完美解决,时间复杂度 $O(n)$,期望得分 $100$ 分。

代码实现

1
#include <cstdio>
2
#include <list>
3
#include <vector>
4
#include <algorithm>
5
using namespace std;
6
typedef pair<int, int> pii;
7
8
const int maxn = 1e6, maxm = 2 * maxn, mod = 998244353;
9
int n, L, k, d[maxn + 3], ch[maxn + 3], f[maxn + 3][2], g[maxn + 3];
10
vector<int> G[maxn + 3];
11
12
int _pow(int a, int b) {
13
	int c = 1;
14
	for (; b; b >>= 1, a = 1ll * a * a % mod) {
15
		if (b & 1) c = 1ll * a * c % mod;
16
	}
17
	return c;
18
}
19
20
inline int func(const int &x) {
21
	return x < 0 ? x + mod : x < mod ? x : x - mod;
22
}
23
24
namespace prework {
25
	int f[maxn + 3], cnt, id[maxn + 3], p[maxn + 3], ip[maxn + 3];
26
	void dfs(int u, int pa = 0) {
27
		f[u] = 1;
28
		for (int v: G[u]) if (v != pa) {
29
			dfs(v, u);
30
			f[u] = 1ll * f[u] * f[v] % mod;
31
			if (d[v] + 1 > d[u]) {
32
				d[u] = d[v] + 1, ch[u] = v;
33
			}
34
		}
35
		f[u] = func(f[u] + 1);
36
		if (f[u]) {
37
			id[++cnt] = u;
38
		}
39
	}
40
	void main() {
41
		dfs(1);
42
		p[0] = 1;
43
		for (int i = 1; i <= cnt; i++) {
44
			p[i] = 1ll * p[i - 1] * f[id[i]] % mod;
45
		}
46
		ip[cnt] = _pow(p[cnt], mod - 2);
47
		for (int i = cnt; i; i--) {
48
			ip[i - 1] = 1ll * ip[i] * f[id[i]] % mod;
49
		}
50
		for (int i = 1; i <= cnt; i++) {
51
			f[id[i]] = 1ll * ip[i] * p[i - 1] % mod;
52
		}
53
	}
54
	int query(int &x) {
55
		return f[x];
56
	}
57
}
58
59
struct node {
60
	int a, ia, b, p, q;
61
	vector<pii> P;
62
};
63
64
list<node> hist[maxn + 3];
65
66
namespace get_dp {
67
	int arr[2 * maxm + 3], *cur = arr, *dp[maxm + 3], a[maxm + 3], ia[maxm + 3], b[maxm + 3], p[maxm + 3], q[maxm + 3];
68
	int query(int u, int d) {
69
		return (1ll * a[u] * (d >= p[u] ? q[u] : dp[u][d]) + b[u]) % mod;
70
	}
71
	void update(int u, int v, int d, bool flag = true) {
72
		int &ca = a[u], &cia = ia[u], &cb = b[u], &cp = p[u], &cq = q[u];
73
		node t;
74
		t.a = ca, t.ia = cia, t.b = cb, t.p = cp, t.q = cq;
75
		for (int i = 1; i <= d; i++) {
76
			t.P.push_back(pii(i, dp[u][i]));
77
			int val = query(v, i - 1);
78
			if (i == cp) dp[u][cp++] = cq;
79
			dp[u][i] = 1ll * cia * func(1ll * query(u, i) * val % mod - cb) % mod;
80
		}
81
		if (d < L) {
82
			int val = query(v, d);
83
			if (!val) {
84
				cp = d + 1;
85
				cq = func(mod - 1ll * cia * cb % mod);
86
			} else {
87
				t.P.push_back(pii(0, dp[u][0]));
88
				int inv = prework::query(v);
89
				for (int i = 0; i <= d; i++) {
90
					dp[u][i] = 1ll * cia * func(1ll * query(u, i) * inv % mod - cb) % mod;
91
				}
92
				ca = 1ll * ca * val % mod, cb = 1ll * cb * val % mod;
93
				cia = 1ll * cia * inv % mod;
94
			}
95
		}
96
		if (flag) hist[u].push_back(t);
97
	}
98
	void dfs(int u, int pa = 0) {
99
		int &ca = a[u], &cia = ia[u], &cb = b[u], &cp = p[u], &cq = q[u];
100
		ca = cia = 1, cp = L + 1;
101
		if (ch[u]) {
102
			dp[ch[u]] = dp[u] + 1;
103
			dfs(ch[u], u);
104
			ca = a[ch[u]], cia = ia[ch[u]], cb = b[ch[u]], cp = p[ch[u]], cq = q[ch[u]];
105
		} else {
106
			cb = 2;
107
			f[u][0] = f[u][1] = 1;
108
			return;
109
		}
110
		dp[u][0] = 1ll * cia * (mod - cb + 1) % mod;
111
		for (int v: G[u]) if (v != pa && v != ch[u]) {
112
			dp[v] = cur, cur += d[v] + 1;
113
			dfs(v, u);
114
			update(u, v, min(d[v], L));
115
		}
116
		cb = func(cb + 1);
117
		f[u][0] = func(query(u, min(d[u], L)) - 1);
118
		f[u][1] = func(query(u, min(d[u], L - 1)) - 1);
119
	}
120
	void main() {
121
		dp[1] = cur, cur += d[1] + 1;
122
		dfs(1);
123
	}
124
	void back(int u) {
125
		node t = hist[u].back();
126
		hist[u].pop_back();
127
		a[u] = t.a, ia[u] = t.ia, b[u] = t.b, p[u] = t.p, q[u] = t.q;
128
		for (pii p: t.P) dp[u][p.first] = p.second;
129
	}
130
}
131
132
using get_dp::dp;
133
134
namespace get_up {
135
	int arr[2 * maxn + 3], *cur = arr, *up[maxn + 3], a[maxn + 3], ia[maxn + 3], b[maxn + 3], p[maxn + 3], q[maxn + 3];
136
	int query(int u, int d) {
137
		return (1ll * a[u] * (d >= p[u] ? q[u] : up[u][d]) + b[u]) % mod;
138
	}
139
	void dfs(int u, int pa = 0) {
140
		if (d[u] >= L) {
141
			up[u][0] = 1ll * ia[u] * (mod + 1 - b[u]) % mod;
142
		}
143
		g[u] = query(u, L);
144
		if (!ch[u]) return;
145
		int &ca = a[ch[u]], &cia = ia[ch[u]], &cb = b[ch[u]], &cp = p[ch[u]], &cq = q[ch[u]];
146
		ca = a[u], cia = ia[u], cb = b[u], cp = p[u], cq = q[u];
147
		int x = 1;
148
		for (int v: G[u]) if (v != pa && v != ch[u]) {
149
			x = max(x, d[v] + 1);
150
		}
151
		x = min(x, L);
152
		int v = u + n;
153
		dp[v] = get_dp::cur, get_dp::cur += x + 1;
154
		get_dp::a[v] = get_dp::ia[v] = get_dp::b[v] = 1;
155
		get_dp::p[v] = L + 1;
156
		reverse(G[u].begin(), G[u].end());
157
		int ta = 1, tia = 1;
158
		for (int t: G[u]) if (t != pa && t != ch[u]) {
159
			get_dp::back(u);
160
			int l = max(0, L - d[t]), r = L;
161
			up[t] = cur + d[t] - l;
162
			cur = up[t] + r + 1;
163
			for (int i = max(l, 1); i <= r; i++) {
164
				up[t][i] = 1ll * query(u, i - 1) * get_dp::query(u, min(i - 1, d[u])) % mod * get_dp::query(v, min(i - 1, x)) % mod;
165
			}
166
			a[t] = ia[t] = b[t] = 1, p[t] = L + 1;
167
			get_dp::update(v, t, min(d[t], L), false);
168
			if (x + 1 <= L) {
169
				int val = get_dp::query(t, d[t]);
170
				if (!val) {
171
					cp = min(cp, d[t] + 1);
172
					cq = func(mod - 1ll * cia * cb % mod);
173
				} else {
174
					ta = 1ll * ta * val % mod;
175
					tia = 1ll * tia * prework::query(t) % mod;
176
				}
177
			}
178
			dfs(t, u);
179
		}
180
		int t = ch[u];
181
		up[t] = up[u] - 1;
182
		int l = max(0, L - d[t]), r = L;
183
		cp = max(cp, l);
184
		while (cp <= r && cp <= x + 1) {
185
			up[t][cp++] = cq;
186
		}
187
		for (int i = max(l, 1); i <= r && i <= x; i++) {
188
			up[t][i] = 1ll * cia * (1ll * query(u, i - 1) * get_dp::query(v, i - 1) % mod + mod - cb) % mod;
189
			if (x + 1 <= r) {
190
				up[t][i] = 1ll * cia * (1ll * query(t, i) * tia % mod + mod - cb) % mod; 
191
			}
192
		}
193
		if (x + 1 <= r) {
194
			ca = 1ll * ca * ta % mod, cb = 1ll * cb * ta % mod;
195
			cia = 1ll * cia * tia % mod;
196
		}
197
		cb = func(cb + 1);
198
		dfs(t, u);
199
	}
200
	void main() {
201
		a[1] = ia[1] = b[1] = 1, p[1] = L + 1;
202
		up[1] = (cur += d[1]), cur += L + 1;
203
		dfs(1);
204
	}
205
}
206
207
int main() {
208
	scanf("%d %d %d", &n, &L, &k);
209
	for (int i = 1, u, v; i < n; i++) {
210
		scanf("%d %d", &u, &v);
211
		G[u].push_back(v), G[v].push_back(u);
212
	}
213
	prework::main();
214
	get_dp::main();
215
	get_up::main();
216
	int ans = 0;
217
	for (int i = 1; i <= n; i++) {
218
		ans = (ans + _pow(1ll * f[i][0] * g[i] % mod, k)) % mod;
219
	}
220
	for (int i = 2; i <= n; i++) {
221
		ans = func(ans - _pow(1ll * f[i][1] * func(g[i] - 1) % mod, k));
222
	}
223
	printf("%d\n", ans);
224
	return 0;
225
}