只要你跑的够快,锅就追不上你

0%

「ZJOI 2016」线段树(动态规划)

题目大意

「ZJOI 2016」线段树(Luogu 3352)

有一个长度为 $n$ 的数列 $a_1, a_2, \cdots, a_n$。进行 $q$ 次操作,每次随机选出一个区间,然后把区间中的所有数赋值为区间的最大值。问最后每一个位置上的数的期望乘上所有可能操作个数后$\bmod 10 ^ 9 + 7$ 是多少。

数据范围:$n, q \le 400$。

思路分析

算法一

由于数很大,我们先将序列离散化。我们假设数列已经离散化完毕,并且数列中等于 $i$ 的数原来是 $\text{val}(i)$。

对于每个 $x$ 分别考虑它最后在每个位置出现了几次。令 $\text{dp}(x, i, l, r)$ 表示 $i$ 次操作过后,$a_l, a_{l + 1}, \cdots, a_r \le x$,$a_{l - 1}, a_{r + 1} > x$,即 $[l, r]$ 是极长的 $\le x$ 区间的方案数。

考虑第 $i$ 次选择的区间:

  • 若它完全在原区间内部或外部,则原区间不会变化。
  • 若它横跨原区间的左端点,但是不经过右端点,那么新区间的右端点不变,左端点向右移动到选择的区间的右端点 $+ 1$ 处。
  • 若它横跨原区间的右端点,但是不经过左端点,那么新区间的左端点不变,右端点向左移动到选择的区间的左端点 $- 1$ 处。
  • 若它包含整个区间,那么原区间就会消失。

综上所述,我们可以列出 DP 的转移方程:

其中 $g(l, r)$ 表示对区间 $[l, r]$ 没有影响的操作区间个数。

使用前缀和优化即可达到 $O(n ^ 2 q)$ 的复杂度。由于有 $n$ 个不同的 $x$,总复杂度为 $O(n ^ 3 q)$。

那么对于一个位置 $i$,它 $\le x$ 的情况数就是 $\sum_{i \in [l, r]} \text{dp}(x, q, l, r)$。我们记这个数为 $f(i, x)$,位置 $i$ 等于 $x$ 的方案数就是 $f(i, x) - f(i, x - 1)$。所以位置 $i$ 的最终答案为:$\sum_{x} \text{val}(x) \times (f(i, x) - f(i, x - 1))$。

期望得分 $70$ 分。

算法二

我们考虑在算法一的基础上优化。发现复杂度瓶颈在于一开始的枚举 $x$,我们应该去掉这个过程。注意到对于 $f(i, x)$,它对位置 $i$ 有 $\text{val}(x) - \text{val}(x + 1)$ 的贡献。于是我们不必枚举 $x$,只需要在一开始给 DP 赋初始值的时候将贡献加进去即可。也就是说,对于每个 $x$,找出它对应的极长区间,然后让 $\text{dp}(0, l, r) \leftarrow \text{dp}(0, l, r) + \text{val}(x) - \text{val}(x + 1)$。这样的复杂度就减小到了 $O(n ^ 2 q)$。详见代码。

期望得分 $100$ 分。

代码实现

算法一(70 分)

1
#include <cstdio>
2
#include <algorithm>
3
using namespace std;
4
5
const int maxn = 400, mod = 1e9 + 7;
6
int n, m, q, a[maxn + 3], v[maxn + 3], c, p[maxn + 3], dp[maxn + 3][maxn + 3][maxn + 3], f[maxn + 3][maxn + 3], g[maxn + 3][maxn + 3], h[maxn + 3][maxn + 3];
7
8
int calc(int x) {
9
    return x * (x + 1) / 2;
10
}
11
12
int main() {
13
    scanf("%d %d", &n, &q);
14
    for (int i = 1; i <= n; i++) {
15
        scanf("%d", &a[i]), v[i] = a[i];
16
    }
17
    sort(v + 1, v + n + 1);
18
    m = unique(v + 1, v + n + 1) - (v + 1);
19
    for (int i = 1; i <= n; i++) {
20
        a[i] = lower_bound(v + 1, v + m + 1, a[i]) - v;
21
    }
22
    for (int x = 1; x <= m; x++) {
23
        c = 0;
24
        p[++c] = 0;
25
        for (int i = 1; i <= n; i++) {
26
            if (a[i] > x) {
27
                p[++c]= i;
28
            }
29
        }
30
        p[++c] = n + 1;
31
        for (int i = 1; i <= n; i++) {
32
            for (int j = i; j <= n; j++) {
33
                dp[0][i][j] = 0;
34
            }
35
        }
36
        for (int i = 1; i < c; i++) {
37
            if (p[i] != p[i + 1] - 1) {
38
                dp[0][p[i] + 1][p[i + 1] - 1] = 1;
39
            }
40
        }
41
        for (int t = 1; t <= q; t++) {
42
            for (int i = 1; i <= n; i++) {
43
                f[i][0] = 0;
44
 				for (int j = 1; j < i; j++) {
45
 					f[i][j] = (f[i][j - 1] + 1ll * (j - 1) * dp[t - 1][j][i]) % mod;
46
 				}
47
            }
48
            for (int i = n; i; i--) {
49
                g[i][n + 1] = 0;
50
                for (int j = n; j > i; j--) {
51
                    g[i][j] = (g[i][j + 1] + 1ll * (n - j) * dp[t - 1][i][j]) % mod;
52
                }
53
            }
54
            for (int l = 1; l <= n; l++) {
55
                for (int r = l; r <= n; r++) {
56
                    dp[t][l][r] = (1ll * dp[t - 1][l][r] * (calc(l - 1) + calc(r - l + 1) + calc(n - r)) + f[r][l - 1] + g[l][r + 1]) % mod;
57
                }
58
            }
59
        }
60
        for (int i = 1; i <= n; i++) {
61
            for (int j = i; j <= n; j++) {
62
                h[x][i] = (h[x][i] + dp[q][i][j]) % mod;
63
                h[x][j + 1] = (h[x][j + 1] - dp[q][i][j] + mod) % mod;
64
            }
65
        }
66
        for (int i = 1; i <= n; i++) {
67
            h[x][i] = (h[x][i] + h[x][i - 1]) % mod;
68
        }
69
    }
70
    for (int i = 1; i <= n; i++) {
71
        int ans = 0;
72
        for (int x = 1; x <= m; x++) {
73
            ans = (ans + 1ll * (h[x][i] - h[x - 1][i] + mod) * v[x]) % mod;
74
        }
75
        printf("%d%c", ans, " \n"[i == n]);
76
    }
77
    return 0;
78
}

算法二(100 分)

1
#include <cstdio>
2
#include <algorithm>
3
using namespace std;
4
5
const int maxn = 400, mod = 1e9 + 7;
6
int n, m, q, a[maxn + 3], v[maxn + 3], c, p[maxn + 3], dp[2][maxn + 3][maxn + 3], f[maxn + 3][maxn + 3], g[maxn + 3][maxn + 3], h[maxn + 3];
7
8
int calc(int x) {
9
	return x * (x + 1) / 2;
10
}
11
12
int func(int x) {
13
	return x < 0 ? x + mod : x < mod ? x : x - mod;
14
}
15
16
int main() {
17
	scanf("%d %d", &n, &q);
18
	for (int i = 1; i <= n; i++) {
19
		scanf("%d", &a[i]), v[i] = a[i];
20
	}
21
	sort(v + 1, v + n + 1);
22
	m = unique(v + 1, v + n + 1) - (v + 1);
23
	for (int i = 1; i <= n; i++) {
24
		a[i] = lower_bound(v + 1, v + m + 1, a[i]) - v;
25
	}
26
	for (int x = 0; x <= m; x++) {
27
		c = 0;
28
		p[++c] = 0;
29
		for (int i = 1; i <= n; i++) {
30
			if (a[i] > x) {
31
				p[++c]= i;
32
			}
33
		}
34
		p[++c] = n + 1;
35
		for (int i = 1; i < c; i++) {
36
			if (p[i] != p[i + 1] - 1) {
37
				dp[0][p[i] + 1][p[i + 1] - 1] = func(dp[0][p[i] + 1][p[i + 1] - 1] + v[x] - v[x + 1]);
38
			}
39
		}
40
	}
41
	int cur = 0, lst = 1;
42
	for (int t = 1; t <= q; t++) {
43
		swap(cur, lst);
44
		for (int i = 1; i <= n; i++) {
45
			f[i][0] = 0;
46
			for (int j = 1; j < i; j++) {
47
				f[i][j] = (f[i][j - 1] + 1ll * (j - 1) * dp[lst][j][i]) % mod;
48
			}
49
		}
50
		for (int i = n; i; i--) {
51
			g[i][n + 1] = 0;
52
			for (int j = n; j > i; j--) {
53
				g[i][j] = (g[i][j + 1] + 1ll * (n - j) * dp[lst][i][j]) % mod;
54
			}
55
		}
56
		for (int l = 1; l <= n; l++) {
57
			for (int r = l; r <= n; r++) {
58
				dp[cur][l][r] = (1ll * dp[lst][l][r] * (calc(l - 1) + calc(r - l + 1) + calc(n - r)) + f[r][l - 1] + g[l][r + 1]) % mod;
59
			}
60
		}
61
	}
62
	for (int i = 1; i <= n; i++) {
63
		for (int j = i; j <= n; j++) {
64
			h[i] = (h[i] + dp[cur][i][j]) % mod;
65
			h[j + 1] = (h[j + 1] - dp[cur][i][j] + mod) % mod;
66
		}
67
	}
68
	for (int i = 1; i <= n; i++) {
69
		h[i] = (h[i] + h[i - 1]) % mod;
70
		printf("%d%c", h[i], " \n"[i == n]);
71
	}
72
	return 0;
73
}