只要你跑的够快,锅就追不上你

0%

「清华集训 2017」某位歌姬的故事(组合计数 + 动态规划)

题目大意

「清华集训 2017」某位歌姬的故事(UOJ 346)

求满足下列条件的,长度为 $n$ 的正整数序列 $a$ 数量 $\bmod 998244353$ 的结果:

  • $\forall a_i \le A$
  • $\forall i \in [1, Q], \max { a_{l_i}, a_{l_i + 1}, \cdots, a_{r_i} } = m_i$

数据范围:$n, A \le 9 \times 10 ^ 8, Q \le 500$。

思路分析

先将序列离散化。对于离散化后的每一段,处理出这段可能达到的最大值。

考虑将限制按照最大值分组,最大值相同的限制一起处理,最后将每组的答案相乘得到总答案。其正确性是因为对于每段区间,它只会对包含它的限制的最小值(等于这段可能达到的最大值)贡献,所以贡献是不重不漏的。

于是问题就转化成了求满足下列条件的,长度为 $n’$ 的正整数序列 $a’$ 数量 $\bmod 998244353$ 的结果:

  • $\forall a’_i \le A’$
  • $\forall i \in [1, Q’], \max { a’_{l’_i}, a’_{l’_i + 1}, \cdots, a’_{r’_i} } = m’$

可以使用 DP 的方法来求解该问题。令 $\text{len}_i$ 表示第 $i$ 段的长度,预处理 $\text{mn}_i$ 表示右端点为第 $i$ 段区间的限制中左端点所在段的最小值。令 $\text{dp}_{i, j}$ 表示考虑到第 $i$ 位,最后一个 $A ^ {\prime}$ 在第 $j$ 段上的方案数。有两种转移:

  • $\text{dp}_{i, j} \leftarrow \text{dp}_{i - 1, j} \times (A ^ {\prime} - 1) ^ {\text{len}_i} (j \in [\text{mn}_i, i - 1])$
  • $\text{dp}_{i, i} \leftarrow \text{dp}_{i - 1, j} \times ((A ^ {\prime}) ^ {\text{len}_i} - (A ^ {\prime} - 1) ^ {\text{len}_i}) (j \in [0, i - 1])$

总时间复杂度 $O(T \times Q^2 \times \log n)$。

代码实现

1
#include <cstdio>
2
#include <cstring>
3
#include <set>
4
#include <vector>
5
#include <algorithm>
6
using namespace std;
7
8
const int maxn = 500, maxm = 2 * maxn, mod = 998244353;
9
bool vis[maxn + 3], ok[maxm + 3];
10
int T, n, q, A, l[maxn + 3], r[maxn + 3], a[maxn + 3], m, pos[maxm + 3];
11
int M, L[maxm + 3], R[maxm + 3], mx[maxm + 3], mn[maxm + 3], Q, tm[maxn + 3];
12
int dp[maxm + 3][maxm + 3];
13
vector<int> V[maxm + 3];
14
multiset<int> S;
15
16
int Pow(int a, int b) {
17
	int c = 1;
18
	for (; b; b >>= 1, a = 1ll * a * a % mod) {
19
		if (b & 1) c = 1ll * a * c % mod;
20
	}
21
	return c;
22
}
23
24
int solve(int n, int w) {
25
	// dp[i][j] 表示前 i 位的最后一个当前最大值在 j 的方案数目
26
	dp[0][0] = 1;
27
	for (int i = 1, k = pos[1]; i <= n; i++, k = pos[i]) {
28
		for (int j = 0; j <= i; j++) dp[i][j] = 0;
29
		int x = Pow(w - 1, R[k] - L[k] + 1), y = Pow(w, R[k] - L[k] + 1) - x;
30
		y < 0 ? y += mod : 0;
31
		for (int j = 0; j < i; j++) if (dp[i - 1][j]) {
32
			if (j >= mn[i]) dp[i][j] = (dp[i][j] + 1ll * x * dp[i - 1][j]) % mod;
33
			dp[i][i] = (dp[i][i] + 1ll * y * dp[i - 1][j]) % mod;
34
		}
35
	}
36
	int res = 0;
37
	for (int i = 0; i <= n; i++) {
38
		res += dp[n][i], res < mod ? 0 : res -= mod;
39
	}
40
	return res;
41
}
42
43
int main() {
44
	scanf("%d", &T);
45
	while (T--) {
46
		scanf("%d %d %d", &n, &q, &A);
47
		m = 0, pos[++m] = 1, pos[++m] = n + 1;
48
		for (int i = 1; i <= q; i++) {
49
			scanf("%d %d %d", &l[i], &r[i], &a[i]);
50
			pos[++m] = l[i], pos[++m] = r[i] + 1;
51
		}
52
		sort(pos + 1, pos + m + 1);
53
		m = unique(pos + 1, pos + m + 1) - (pos + 1);
54
		M = m - 1;
55
		for (int i = 1; i <= M; i++) {
56
			L[i] = pos[i], R[i] = pos[i + 1] - 1;
57
			V[i].clear();
58
		}
59
		for (int i = 1; i <= q; i++) {
60
			l[i] = lower_bound(pos + 1, pos + m + 1, l[i]) - pos;
61
			r[i] = upper_bound(pos + 1, pos + m + 1, r[i]) - (pos + 1);
62
			V[l[i]].push_back(i), V[r[i] + 1].push_back(i);
63
		}
64
		memset(vis, false, sizeof(vis));
65
		S.clear();
66
		for (int i = 1; i <= M; i++) {
67
			for (int k: V[i]) {
68
				if (!vis[k]) {
69
					vis[k] = true;
70
					S.insert(a[k]);
71
				} else {
72
					S.erase(S.lower_bound(a[k]));
73
				}
74
			}
75
			mx[i] = S.empty() ? -1 : *S.begin();
76
		}
77
		Q = 0;
78
		for (int i = 1; i <= q; i++) {
79
			tm[++Q] = a[i];
80
		}
81
		sort(tm + 1, tm + Q + 1);
82
		Q = unique(tm + 1, tm + Q + 1) - (tm + 1);
83
		int ans = 1;
84
		bool flag = true;
85
		for (int i = 1; i <= Q; i++) {
86
			m = 0;
87
			for (int j = 1; j <= M; j++) {
88
				if (tm[i] == mx[j]) pos[++m] = j;
89
			}
90
			for (int j = 1; j <= m; j++) mn[j] = -1;
91
			for (int j = 1; j <= q; j++) {
92
				if (tm[i] == a[j]) {
93
					if (!m) { flag = false; break; }
94
					l[j] = lower_bound(pos + 1, pos + m + 1, l[j]) - pos;
95
					r[j] = upper_bound(pos + 1, pos + m + 1, r[j]) - (pos + 1);
96
					mn[r[j]] = max(mn[r[j]], l[j]);
97
				}
98
			}
99
			if (!flag) break;
100
			ans = 1ll * ans * solve(m, tm[i]) % mod;
101
		}
102
		if (!flag) { puts("0"); continue; }
103
		for (int i = 1; i <= M; i++) {
104
			if (mx[i] == -1) ans = 1ll * ans * Pow(A, R[i] - L[i] + 1) % mod;
105
		}
106
		printf("%d\n", ans);
107
	}
108
	return 0;
109
}