题目大意
给定一个长度为 $n$ 的字符串 $s$,$s_i \in {<, >}$。要求计数长度为 $n + 1$ 的排列 $a$,满足 $a_i < a_{i + 1}$ 当且仅当 $s_i$ 为 $<$,模数为 $998244353$。
数据范围:$n \le 10^5$。
思路分析
首先想到一个状态数为 $n^2$ 的 $\text{dp}$,但是发现它不能优化。于是考虑容斥。
对于所有 $<$,我们考虑把它容斥成 $\neq$ 的方案数减去 $>$ 的方案数。这样容斥后,问题就变成了字符串的每一位都是无限制或大于号,求方案数。假设第 $i$ 个大于号的连续段长度为 $\text{len}_i - 1$,则总方案数为 $\displaystyle \frac{n!}{\prod \text{len}_i!}$。这样,我们就得到了一个 $O(2^n)$ 的做法。
我们发现我们可以使用带容斥系数的 $\text{dp}$ 来解决这个问题。那么:
其中 $\text{cnt}_i$ 表示前 $i$ 个位置有多少个 $\text{<}$,而答案就等于 $dp_{n + 1} \times (n + 1)!$。这样做的时间复杂度为 $O(n^2)$。
发现这个式子可以使用分治 FFT 优化。时间复杂度 $O(n \log^2 n)$,可以通过本题。
代码实现
1 |
|
2 |
|
3 | using namespace std; |
4 | |
5 | const int maxn = 1 << 18, mod = 998244353; |
6 | int n, m, k, cnt[maxn + 3], fact[maxn + 3], finv[maxn + 3], num[maxn + 3]; |
7 | int f[maxn + 3], g[maxn + 3], a[maxn + 3], b[maxn + 3], rev[maxn + 3], dp[maxn + 3]; |
8 | char s[maxn + 3]; |
9 | |
10 | void add(int &x, int y) { |
11 | x += y, x < mod ? 0 : x -= mod; |
12 | } |
13 | |
14 | int qpow(int a, int b) { |
15 | b < 0 ? b += mod - 1 : 0; |
16 | int c = 1; |
17 | for (; b; b >>= 1, a = 1ll * a * a % mod) { |
18 | if (b & 1) c = 1ll * a * c % mod; |
19 | } |
20 | return c; |
21 | } |
22 | |
23 | void prework(int n) { |
24 | for (int i = 1; i <= n; i++) { |
25 | cnt[i] = cnt[i - 1] + (s[i] == '<'); |
26 | } |
27 | fact[0] = finv[0] = 1; |
28 | for (int i = 1; i <= n; i++) { |
29 | fact[i] = 1ll * fact[i - 1] * i % mod; |
30 | } |
31 | finv[n] = qpow(fact[n], mod - 2); |
32 | for (int i = n; i; i--) { |
33 | finv[i - 1] = 1ll * finv[i] * i % mod; |
34 | } |
35 | num[0] = 1; |
36 | for (int i = 1; i <= n; i++) { |
37 | num[i] = mod - num[i - 1]; |
38 | } |
39 | } |
40 | |
41 | void dft(int a[], int n, int type) { |
42 | for (int i = 0; i < n; i++) { |
43 | if (i < rev[i]) swap(a[i], a[rev[i]]); |
44 | } |
45 | for (int k = 1; k < n; k <<= 1) { |
46 | int x = qpow(3, (mod - 1) / (k << 1) * type); |
47 | for (int i = 0; i < n; i += k << 1) { |
48 | int y = 1; |
49 | for (int j = i; j < i + k; j++, y = 1ll * x * y % mod) { |
50 | int p = a[j], q = 1ll * y * a[j + k] % mod; |
51 | a[j] = p + q, a[j] < mod ? 0 : a[j] -= mod; |
52 | a[j + k] = p - q, a[j + k] < 0 ? a[j + k] += mod : 0; |
53 | } |
54 | } |
55 | } |
56 | if (type == -1) { |
57 | int x = qpow(n, mod - 2); |
58 | for (int i = 0; i < n; i++) { |
59 | a[i] = 1ll * a[i] * x % mod; |
60 | } |
61 | } |
62 | } |
63 | |
64 | void cdq(int l, int r) { |
65 | if (l == r) { |
66 | if (!l) f[l] = 1; |
67 | f[l] = 1ll * f[l] * num[cnt[l - 1]] % mod; |
68 | dp[l] = f[l]; |
69 | f[l] = 1ll * f[l] * num[cnt[l]] % mod * (!l || s[l] == '<'); |
70 | return; |
71 | } |
72 | cdq(l, mid); |
73 | for (int i = l; i <= mid; i++) { |
74 | a[i - l] = f[i]; |
75 | } |
76 | for (int i = 0; i <= r - l; i++) { |
77 | b[i] = g[i]; |
78 | } |
79 | int x = mid - l, y = r - l; |
80 | for (k = 0, m = 1; m <= x + y; m <<= 1) k++; |
81 | for (int i = x + 1; i < m; i++) { |
82 | a[i] = 0; |
83 | } |
84 | for (int i = y + 1; i < m; i++) { |
85 | b[i] = 0; |
86 | } |
87 | for (int i = 1; i < m; i++) { |
88 | rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (k - 1)); |
89 | } |
90 | dft(a, m, 1), dft(b, m, 1); |
91 | for (int i = 0; i < m; i++) { |
92 | a[i] = 1ll * a[i] * b[i] % mod; |
93 | } |
94 | dft(a, m, -1); |
95 | for (int i = mid + 1; i <= r; i++) { |
96 | add(f[i], a[i - l]); |
97 | } |
98 | cdq(mid + 1, r); |
99 | } |
100 | |
101 | int main() { |
102 | scanf("%s", s + 1); |
103 | n = strlen(s + 1) + 1; |
104 | prework(n); |
105 | for (int i = 1; i <= n; i++) { |
106 | g[i] = finv[i]; |
107 | } |
108 | cdq(0, n); |
109 | int ans = 1ll * dp[n] * fact[n] % mod; |
110 | printf("%d\n", ans); |
111 | return 0; |
112 | } |