只要你跑的够快,锅就追不上你

0%

「LOJ 575」不等关系(容斥原理 + 多项式)

题目大意

「LOJ 575」不等关系

给定一个长度为 $n$ 的字符串 $s$,$s_i \in {<, >}$。要求计数长度为 $n + 1$ 的排列 $a$,满足 $a_i < a_{i + 1}$ 当且仅当 $s_i$ 为 $<$,模数为 $998244353$。

数据范围:$n \le 10^5$。

思路分析

首先想到一个状态数为 $n^2$ 的 $\text{dp}$,但是发现它不能优化。于是考虑容斥。

对于所有 $<$,我们考虑把它容斥成 $\neq$ 的方案数减去 $>$ 的方案数。这样容斥后,问题就变成了字符串的每一位都是无限制或大于号,求方案数。假设第 $i$ 个大于号的连续段长度为 $\text{len}_i - 1$,则总方案数为 $\displaystyle \frac{n!}{\prod \text{len}_i!}$。这样,我们就得到了一个 $O(2^n)$ 的做法。

图一

我们发现我们可以使用带容斥系数的 $\text{dp}$ 来解决这个问题。那么:

其中 $\text{cnt}_i$ 表示前 $i$ 个位置有多少个 $\text{<}$,而答案就等于 $dp_{n + 1} \times (n + 1)!$。这样做的时间复杂度为 $O(n^2)$。

发现这个式子可以使用分治 FFT 优化。时间复杂度 $O(n \log^2 n)$,可以通过本题。

代码实现

1
#include <bits/stdc++.h>
2
#define mid ((l + r) >> 1)
3
using namespace std;
4
5
const int maxn = 1 << 18, mod = 998244353;
6
int n, m, k, cnt[maxn + 3], fact[maxn + 3], finv[maxn + 3], num[maxn + 3];
7
int f[maxn + 3], g[maxn + 3], a[maxn + 3], b[maxn + 3], rev[maxn + 3], dp[maxn + 3];
8
char s[maxn + 3];
9
10
void add(int &x, int y) {
11
	x += y, x < mod ? 0 : x -= mod;
12
}
13
14
int qpow(int a, int b) {
15
	b < 0 ? b += mod - 1 : 0;
16
	int c = 1;
17
	for (; b; b >>= 1, a = 1ll * a * a % mod) {
18
		if (b & 1) c = 1ll * a * c % mod;
19
	}
20
	return c;
21
}
22
23
void prework(int n) {
24
	for (int i = 1; i <= n; i++) {
25
		cnt[i] = cnt[i - 1] + (s[i] == '<');
26
	}
27
	fact[0] = finv[0] = 1;
28
	for (int i = 1; i <= n; i++) {
29
		fact[i] = 1ll * fact[i - 1] * i % mod;
30
	}
31
	finv[n] = qpow(fact[n], mod - 2);
32
	for (int i = n; i; i--) {
33
		finv[i - 1] = 1ll * finv[i] * i % mod;
34
	}
35
	num[0] = 1;
36
	for (int i = 1; i <= n; i++) {
37
		num[i] = mod - num[i - 1];
38
	}
39
}
40
41
void dft(int a[], int n, int type) {
42
	for (int i = 0; i < n; i++) {
43
		if (i < rev[i]) swap(a[i], a[rev[i]]);
44
	}
45
	for (int k = 1; k < n; k <<= 1) {
46
		int x = qpow(3, (mod - 1) / (k << 1) * type);
47
		for (int i = 0; i < n; i += k << 1) {
48
			int y = 1;
49
			for (int j = i; j < i + k; j++, y = 1ll * x * y % mod) {
50
				int p = a[j], q = 1ll * y * a[j + k] % mod;
51
				a[j] = p + q, a[j] < mod ? 0 : a[j] -= mod;
52
				a[j + k] = p - q, a[j + k] < 0 ? a[j + k] += mod : 0;
53
			}
54
		}
55
	}
56
	if (type == -1) {
57
		int x = qpow(n, mod - 2);
58
		for (int i = 0; i < n; i++) {
59
			a[i] = 1ll * a[i] * x % mod;
60
		}
61
	}
62
}
63
64
void cdq(int l, int r) {
65
	if (l == r) {
66
		if (!l) f[l] = 1;
67
		f[l] = 1ll * f[l] * num[cnt[l - 1]] % mod;
68
		dp[l] = f[l];
69
		f[l] = 1ll * f[l] * num[cnt[l]] % mod * (!l || s[l] == '<');
70
		return;
71
	}
72
	cdq(l, mid);
73
	for (int i = l; i <= mid; i++) {
74
		a[i - l] = f[i];
75
	}
76
	for (int i = 0; i <= r - l; i++) {
77
		b[i] = g[i]; 
78
	}
79
	int x = mid - l, y = r - l;
80
	for (k = 0, m = 1; m <= x + y; m <<= 1) k++;
81
	for (int i = x + 1; i < m; i++) {
82
		a[i] = 0;
83
	}
84
	for (int i = y + 1; i < m; i++) {
85
		b[i] = 0;
86
	}
87
	for (int i = 1; i < m; i++) {
88
		rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (k - 1));
89
	}
90
	dft(a, m, 1), dft(b, m, 1);
91
	for (int i = 0; i < m; i++) {
92
		a[i] = 1ll * a[i] * b[i] % mod; 
93
	}
94
	dft(a, m, -1);
95
	for (int i = mid + 1; i <= r; i++) {
96
		add(f[i], a[i - l]);
97
	}
98
	cdq(mid + 1, r);
99
}
100
101
int main() {
102
	scanf("%s", s + 1);
103
	n = strlen(s + 1) + 1;
104
	prework(n);
105
	for (int i = 1; i <= n; i++) {
106
		g[i] = finv[i];
107
	}
108
	cdq(0, n);
109
	int ans = 1ll * dp[n] * fact[n] % mod;
110
	printf("%d\n", ans);
111
	return 0;
112
}