题目大意
有一个长度为 $n$ 的数列 $a_1, a_2, \cdots, a_n$。进行 $q$ 次操作,每次随机选出一个区间,然后把区间中的所有数赋值为区间的最大值。问最后每一个位置上的数的期望乘上所有可能操作个数后$\bmod 10 ^ 9 + 7$ 是多少。
数据范围:$n, q \le 400$。
思路分析
算法一
由于数很大,我们先将序列离散化。我们假设数列已经离散化完毕,并且数列中等于 $i$ 的数原来是 $\text{val}(i)$。
对于每个 $x$ 分别考虑它最后在每个位置出现了几次。令 $\text{dp}(x, i, l, r)$ 表示 $i$ 次操作过后,$a_l, a_{l + 1}, \cdots, a_r \le x$,$a_{l - 1}, a_{r + 1} > x$,即 $[l, r]$ 是极长的 $\le x$ 区间的方案数。
考虑第 $i$ 次选择的区间:
- 若它完全在原区间内部或外部,则原区间不会变化。
- 若它横跨原区间的左端点,但是不经过右端点,那么新区间的右端点不变,左端点向右移动到选择的区间的右端点 $+ 1$ 处。
- 若它横跨原区间的右端点,但是不经过左端点,那么新区间的左端点不变,右端点向左移动到选择的区间的左端点 $- 1$ 处。
- 若它包含整个区间,那么原区间就会消失。
综上所述,我们可以列出 DP 的转移方程:
其中 $g(l, r)$ 表示对区间 $[l, r]$ 没有影响的操作区间个数。
使用前缀和优化即可达到 $O(n ^ 2 q)$ 的复杂度。由于有 $n$ 个不同的 $x$,总复杂度为 $O(n ^ 3 q)$。
那么对于一个位置 $i$,它 $\le x$ 的情况数就是 $\sum_{i \in [l, r]} \text{dp}(x, q, l, r)$。我们记这个数为 $f(i, x)$,位置 $i$ 等于 $x$ 的方案数就是 $f(i, x) - f(i, x - 1)$。所以位置 $i$ 的最终答案为:$\sum_{x} \text{val}(x) \times (f(i, x) - f(i, x - 1))$。
期望得分 $70$ 分。
算法二
我们考虑在算法一的基础上优化。发现复杂度瓶颈在于一开始的枚举 $x$,我们应该去掉这个过程。注意到对于 $f(i, x)$,它对位置 $i$ 有 $\text{val}(x) - \text{val}(x + 1)$ 的贡献。于是我们不必枚举 $x$,只需要在一开始给 DP 赋初始值的时候将贡献加进去即可。也就是说,对于每个 $x$,找出它对应的极长区间,然后让 $\text{dp}(0, l, r) \leftarrow \text{dp}(0, l, r) + \text{val}(x) - \text{val}(x + 1)$。这样的复杂度就减小到了 $O(n ^ 2 q)$。详见代码。
期望得分 $100$ 分。
代码实现
算法一(70 分)
1 |
|
2 |
|
3 | using namespace std; |
4 | |
5 | const int maxn = 400, mod = 1e9 + 7; |
6 | int n, m, q, a[maxn + 3], v[maxn + 3], c, p[maxn + 3], dp[maxn + 3][maxn + 3][maxn + 3], f[maxn + 3][maxn + 3], g[maxn + 3][maxn + 3], h[maxn + 3][maxn + 3]; |
7 | |
8 | int calc(int x) { |
9 | return x * (x + 1) / 2; |
10 | } |
11 | |
12 | int main() { |
13 | scanf("%d %d", &n, &q); |
14 | for (int i = 1; i <= n; i++) { |
15 | scanf("%d", &a[i]), v[i] = a[i]; |
16 | } |
17 | sort(v + 1, v + n + 1); |
18 | m = unique(v + 1, v + n + 1) - (v + 1); |
19 | for (int i = 1; i <= n; i++) { |
20 | a[i] = lower_bound(v + 1, v + m + 1, a[i]) - v; |
21 | } |
22 | for (int x = 1; x <= m; x++) { |
23 | c = 0; |
24 | p[++c] = 0; |
25 | for (int i = 1; i <= n; i++) { |
26 | if (a[i] > x) { |
27 | p[++c]= i; |
28 | } |
29 | } |
30 | p[++c] = n + 1; |
31 | for (int i = 1; i <= n; i++) { |
32 | for (int j = i; j <= n; j++) { |
33 | dp[0][i][j] = 0; |
34 | } |
35 | } |
36 | for (int i = 1; i < c; i++) { |
37 | if (p[i] != p[i + 1] - 1) { |
38 | dp[0][p[i] + 1][p[i + 1] - 1] = 1; |
39 | } |
40 | } |
41 | for (int t = 1; t <= q; t++) { |
42 | for (int i = 1; i <= n; i++) { |
43 | f[i][0] = 0; |
44 | for (int j = 1; j < i; j++) { |
45 | f[i][j] = (f[i][j - 1] + 1ll * (j - 1) * dp[t - 1][j][i]) % mod; |
46 | } |
47 | } |
48 | for (int i = n; i; i--) { |
49 | g[i][n + 1] = 0; |
50 | for (int j = n; j > i; j--) { |
51 | g[i][j] = (g[i][j + 1] + 1ll * (n - j) * dp[t - 1][i][j]) % mod; |
52 | } |
53 | } |
54 | for (int l = 1; l <= n; l++) { |
55 | for (int r = l; r <= n; r++) { |
56 | dp[t][l][r] = (1ll * dp[t - 1][l][r] * (calc(l - 1) + calc(r - l + 1) + calc(n - r)) + f[r][l - 1] + g[l][r + 1]) % mod; |
57 | } |
58 | } |
59 | } |
60 | for (int i = 1; i <= n; i++) { |
61 | for (int j = i; j <= n; j++) { |
62 | h[x][i] = (h[x][i] + dp[q][i][j]) % mod; |
63 | h[x][j + 1] = (h[x][j + 1] - dp[q][i][j] + mod) % mod; |
64 | } |
65 | } |
66 | for (int i = 1; i <= n; i++) { |
67 | h[x][i] = (h[x][i] + h[x][i - 1]) % mod; |
68 | } |
69 | } |
70 | for (int i = 1; i <= n; i++) { |
71 | int ans = 0; |
72 | for (int x = 1; x <= m; x++) { |
73 | ans = (ans + 1ll * (h[x][i] - h[x - 1][i] + mod) * v[x]) % mod; |
74 | } |
75 | printf("%d%c", ans, " \n"[i == n]); |
76 | } |
77 | return 0; |
78 | } |
算法二(100 分)
1 |
|
2 |
|
3 | using namespace std; |
4 | |
5 | const int maxn = 400, mod = 1e9 + 7; |
6 | int n, m, q, a[maxn + 3], v[maxn + 3], c, p[maxn + 3], dp[2][maxn + 3][maxn + 3], f[maxn + 3][maxn + 3], g[maxn + 3][maxn + 3], h[maxn + 3]; |
7 | |
8 | int calc(int x) { |
9 | return x * (x + 1) / 2; |
10 | } |
11 | |
12 | int func(int x) { |
13 | return x < 0 ? x + mod : x < mod ? x : x - mod; |
14 | } |
15 | |
16 | int main() { |
17 | scanf("%d %d", &n, &q); |
18 | for (int i = 1; i <= n; i++) { |
19 | scanf("%d", &a[i]), v[i] = a[i]; |
20 | } |
21 | sort(v + 1, v + n + 1); |
22 | m = unique(v + 1, v + n + 1) - (v + 1); |
23 | for (int i = 1; i <= n; i++) { |
24 | a[i] = lower_bound(v + 1, v + m + 1, a[i]) - v; |
25 | } |
26 | for (int x = 0; x <= m; x++) { |
27 | c = 0; |
28 | p[++c] = 0; |
29 | for (int i = 1; i <= n; i++) { |
30 | if (a[i] > x) { |
31 | p[++c]= i; |
32 | } |
33 | } |
34 | p[++c] = n + 1; |
35 | for (int i = 1; i < c; i++) { |
36 | if (p[i] != p[i + 1] - 1) { |
37 | dp[0][p[i] + 1][p[i + 1] - 1] = func(dp[0][p[i] + 1][p[i + 1] - 1] + v[x] - v[x + 1]); |
38 | } |
39 | } |
40 | } |
41 | int cur = 0, lst = 1; |
42 | for (int t = 1; t <= q; t++) { |
43 | swap(cur, lst); |
44 | for (int i = 1; i <= n; i++) { |
45 | f[i][0] = 0; |
46 | for (int j = 1; j < i; j++) { |
47 | f[i][j] = (f[i][j - 1] + 1ll * (j - 1) * dp[lst][j][i]) % mod; |
48 | } |
49 | } |
50 | for (int i = n; i; i--) { |
51 | g[i][n + 1] = 0; |
52 | for (int j = n; j > i; j--) { |
53 | g[i][j] = (g[i][j + 1] + 1ll * (n - j) * dp[lst][i][j]) % mod; |
54 | } |
55 | } |
56 | for (int l = 1; l <= n; l++) { |
57 | for (int r = l; r <= n; r++) { |
58 | dp[cur][l][r] = (1ll * dp[lst][l][r] * (calc(l - 1) + calc(r - l + 1) + calc(n - r)) + f[r][l - 1] + g[l][r + 1]) % mod; |
59 | } |
60 | } |
61 | } |
62 | for (int i = 1; i <= n; i++) { |
63 | for (int j = i; j <= n; j++) { |
64 | h[i] = (h[i] + dp[cur][i][j]) % mod; |
65 | h[j + 1] = (h[j + 1] - dp[cur][i][j] + mod) % mod; |
66 | } |
67 | } |
68 | for (int i = 1; i <= n; i++) { |
69 | h[i] = (h[i] + h[i - 1]) % mod; |
70 | printf("%d%c", h[i], " \n"[i == n]); |
71 | } |
72 | return 0; |
73 | } |