题目大意
求满足下列条件的,长度为 $n$ 的正整数序列 $a$ 数量 $\bmod 998244353$ 的结果:
- $\forall a_i \le A$
- $\forall i \in [1, Q], \max { a_{l_i}, a_{l_i + 1}, \cdots, a_{r_i} } = m_i$
数据范围:$n, A \le 9 \times 10 ^ 8, Q \le 500$。
思路分析
先将序列离散化。对于离散化后的每一段,处理出这段可能达到的最大值。
考虑将限制按照最大值分组,最大值相同的限制一起处理,最后将每组的答案相乘得到总答案。其正确性是因为对于每段区间,它只会对包含它的限制的最小值(等于这段可能达到的最大值)贡献,所以贡献是不重不漏的。
于是问题就转化成了求满足下列条件的,长度为 $n’$ 的正整数序列 $a’$ 数量 $\bmod 998244353$ 的结果:
- $\forall a’_i \le A’$
- $\forall i \in [1, Q’], \max { a’_{l’_i}, a’_{l’_i + 1}, \cdots, a’_{r’_i} } = m’$
可以使用 DP 的方法来求解该问题。令 $\text{len}_i$ 表示第 $i$ 段的长度,预处理 $\text{mn}_i$ 表示右端点为第 $i$ 段区间的限制中左端点所在段的最小值。令 $\text{dp}_{i, j}$ 表示考虑到第 $i$ 位,最后一个 $A ^ {\prime}$ 在第 $j$ 段上的方案数。有两种转移:
- $\text{dp}_{i, j} \leftarrow \text{dp}_{i - 1, j} \times (A ^ {\prime} - 1) ^ {\text{len}_i} (j \in [\text{mn}_i, i - 1])$
- $\text{dp}_{i, i} \leftarrow \text{dp}_{i - 1, j} \times ((A ^ {\prime}) ^ {\text{len}_i} - (A ^ {\prime} - 1) ^ {\text{len}_i}) (j \in [0, i - 1])$
总时间复杂度 $O(T \times Q^2 \times \log n)$。
代码实现
1 |
|
2 |
|
3 |
|
4 |
|
5 |
|
6 | using namespace std; |
7 | |
8 | const int maxn = 500, maxm = 2 * maxn, mod = 998244353; |
9 | bool vis[maxn + 3], ok[maxm + 3]; |
10 | int T, n, q, A, l[maxn + 3], r[maxn + 3], a[maxn + 3], m, pos[maxm + 3]; |
11 | int M, L[maxm + 3], R[maxm + 3], mx[maxm + 3], mn[maxm + 3], Q, tm[maxn + 3]; |
12 | int dp[maxm + 3][maxm + 3]; |
13 | vector<int> V[maxm + 3]; |
14 | multiset<int> S; |
15 | |
16 | int Pow(int a, int b) { |
17 | int c = 1; |
18 | for (; b; b >>= 1, a = 1ll * a * a % mod) { |
19 | if (b & 1) c = 1ll * a * c % mod; |
20 | } |
21 | return c; |
22 | } |
23 | |
24 | int solve(int n, int w) { |
25 | // dp[i][j] 表示前 i 位的最后一个当前最大值在 j 的方案数目 |
26 | dp[0][0] = 1; |
27 | for (int i = 1, k = pos[1]; i <= n; i++, k = pos[i]) { |
28 | for (int j = 0; j <= i; j++) dp[i][j] = 0; |
29 | int x = Pow(w - 1, R[k] - L[k] + 1), y = Pow(w, R[k] - L[k] + 1) - x; |
30 | y < 0 ? y += mod : 0; |
31 | for (int j = 0; j < i; j++) if (dp[i - 1][j]) { |
32 | if (j >= mn[i]) dp[i][j] = (dp[i][j] + 1ll * x * dp[i - 1][j]) % mod; |
33 | dp[i][i] = (dp[i][i] + 1ll * y * dp[i - 1][j]) % mod; |
34 | } |
35 | } |
36 | int res = 0; |
37 | for (int i = 0; i <= n; i++) { |
38 | res += dp[n][i], res < mod ? 0 : res -= mod; |
39 | } |
40 | return res; |
41 | } |
42 | |
43 | int main() { |
44 | scanf("%d", &T); |
45 | while (T--) { |
46 | scanf("%d %d %d", &n, &q, &A); |
47 | m = 0, pos[++m] = 1, pos[++m] = n + 1; |
48 | for (int i = 1; i <= q; i++) { |
49 | scanf("%d %d %d", &l[i], &r[i], &a[i]); |
50 | pos[++m] = l[i], pos[++m] = r[i] + 1; |
51 | } |
52 | sort(pos + 1, pos + m + 1); |
53 | m = unique(pos + 1, pos + m + 1) - (pos + 1); |
54 | M = m - 1; |
55 | for (int i = 1; i <= M; i++) { |
56 | L[i] = pos[i], R[i] = pos[i + 1] - 1; |
57 | V[i].clear(); |
58 | } |
59 | for (int i = 1; i <= q; i++) { |
60 | l[i] = lower_bound(pos + 1, pos + m + 1, l[i]) - pos; |
61 | r[i] = upper_bound(pos + 1, pos + m + 1, r[i]) - (pos + 1); |
62 | V[l[i]].push_back(i), V[r[i] + 1].push_back(i); |
63 | } |
64 | memset(vis, false, sizeof(vis)); |
65 | S.clear(); |
66 | for (int i = 1; i <= M; i++) { |
67 | for (int k: V[i]) { |
68 | if (!vis[k]) { |
69 | vis[k] = true; |
70 | S.insert(a[k]); |
71 | } else { |
72 | S.erase(S.lower_bound(a[k])); |
73 | } |
74 | } |
75 | mx[i] = S.empty() ? -1 : *S.begin(); |
76 | } |
77 | Q = 0; |
78 | for (int i = 1; i <= q; i++) { |
79 | tm[++Q] = a[i]; |
80 | } |
81 | sort(tm + 1, tm + Q + 1); |
82 | Q = unique(tm + 1, tm + Q + 1) - (tm + 1); |
83 | int ans = 1; |
84 | bool flag = true; |
85 | for (int i = 1; i <= Q; i++) { |
86 | m = 0; |
87 | for (int j = 1; j <= M; j++) { |
88 | if (tm[i] == mx[j]) pos[++m] = j; |
89 | } |
90 | for (int j = 1; j <= m; j++) mn[j] = -1; |
91 | for (int j = 1; j <= q; j++) { |
92 | if (tm[i] == a[j]) { |
93 | if (!m) { flag = false; break; } |
94 | l[j] = lower_bound(pos + 1, pos + m + 1, l[j]) - pos; |
95 | r[j] = upper_bound(pos + 1, pos + m + 1, r[j]) - (pos + 1); |
96 | mn[r[j]] = max(mn[r[j]], l[j]); |
97 | } |
98 | } |
99 | if (!flag) break; |
100 | ans = 1ll * ans * solve(m, tm[i]) % mod; |
101 | } |
102 | if (!flag) { puts("0"); continue; } |
103 | for (int i = 1; i <= M; i++) { |
104 | if (mx[i] == -1) ans = 1ll * ans * Pow(A, R[i] - L[i] + 1) % mod; |
105 | } |
106 | printf("%d\n", ans); |
107 | } |
108 | return 0; |
109 | } |