只要你跑的够快,锅就追不上你

0%

「学习笔记」拉格朗日插值

题目大意

「模板」拉格朗日插值(Luogu 4781)

给定平面上的 $n$ 个点 $(x_1, y_1), (x_2, y_2), \cdots, (x_n, y_n)$,它们可以确定一个至多 $n - 1$ 次多项式。问这个多项式在 $k$ 点的取值$\bmod 998244353$ 的结果。

数据范围:$n \le 2000$。

思路分析

我们可以构造一个多项式 $f_i(k)$ 满足它在 $x_i$ 处为 $1$,但是在 $x_j (i \neq j)$ 处为 $0$:

于是题目中求的多项式 $f$ 就是:

这个多项式可以 $O(n^2)$ 计算,于是我们已经可以解决这个问题了。

但是我们发现在有些题目中,我们是动态维护这个过程的。也就是说,每次在平面上增加一个点,平面上的点确定的多项式的次数就会增加 $1$,然后我们求出所有这些多项式在 $k$ 点的值。

这时我们直接暴力做就变成 $O(n^3)$ 的了,不够优秀。观察式子,我们可以发现:

令 $g = \prod_{i} (k - x_j), t_i = \frac{y_i}{\prod_{i \neq j} x_i - x_j}$,那么:

这样,我们只需要再加入一个点的同时维护 $g, t_i$ 即可。$g$ 可以 $O(1)$ 地维护,而对于 $t_i$,我们需要扫一遍所有点,还要求 $n$ 个数的逆元,直接暴力做是 $O(n \log n)$ 的。但是,我们使用线性求逆元的科技即可把这一步优化到 $O(n)$。总复杂度 $O(n^2)$。

代码实现

这里采用动态的插值方法。

1
#include <cstdio>
2
3
const int maxn = 2e3, mod = 998244353;
4
int n, k, x[maxn + 3], y[maxn + 3], g, t[maxn + 3], num[maxn + 3], suf[maxn + 3], inv[maxn + 3];
5
6
int f(int x) {
7
	return x < 0 ? x + mod : x < mod ? x : x - mod;
8
}
9
10
int qpow(int a, int b) {
11
	int c = 1;
12
	for (; b; b >>= 1, a = 1ll * a * a % mod) {
13
		if (b & 1) c = 1ll * a * c % mod;
14
	}
15
	return c;
16
}
17
18
void solve(int num[], int inv[], int n) {
19
	num[0] = num[n + 1] = 1;
20
	for (int i = 0; i <= n + 1; i++) {
21
		suf[i] = num[i];
22
	}
23
	for (int i = 1; i <= n; i++) {
24
		num[i] = 1ll * num[i - 1] * num[i] % mod;
25
	}
26
	for (int i = n; i; i--) {
27
		suf[i] = 1ll * suf[i + 1] * suf[i] % mod;
28
	}
29
	int x = qpow(num[n], mod - 2);
30
	for (int i = 1; i <= n; i++) {
31
		inv[i] = 1ll * x * num[i - 1] % mod * suf[i + 1] % mod;
32
	}
33
}
34
35
int main() {
36
	scanf("%d %d", &n, &k);
37
	g = 1;
38
	for (int i = 1; i <= n; i++) {
39
		scanf("%d %d", &x[i], &y[i]);
40
		g = 1ll * g * f(k - x[i]) % mod;
41
		for (int j = 1; j < i; j++) {
42
			num[j] = f(x[i] - x[j]);
43
		}
44
		solve(num, inv, i - 1);
45
		t[i] = y[i];
46
		for (int j = 1; j < i; j++) {
47
			t[i] = 1ll * t[i] * inv[j] % mod;
48
		}
49
		for (int j = 1; j < i; j++) {
50
			t[j] = 1ll * t[j] * f(-inv[j]) % mod;
51
		}
52
	}
53
	int ans = 0;
54
	for (int i = 1; i <= n; i++) {
55
		ans = (ans + 1ll * t[i] * qpow(f(k - x[i]), mod - 2)) % mod;
56
	}
57
	ans = 1ll * ans * g % mod;
58
	printf("%d\n", ans);
59
	return 0;
60
}