只要你跑的够快,锅就追不上你

0%

「NOI 2018」你的名字(后缀自动机 + 线段树)

「NOI 2018」你的名字(UOJ 395)

题目描述

给定字符串 $S$,$q$ 次询问给定 $(T, l, r)$,问 $T$ 有多少个本质不同的子串不在 $S$ 中出现。

数据范围:$\vert S \vert, \vert T \vert \le 5 \times 10^5, \sum \vert T \vert \le 10^6$。

思路分析

首先建出 $T$ 的 SAM,对于 $T$ 的 SAM 中的每个点考虑它表示的串的集合中有哪些不是 $S$ 的子串。假设某个点表示 $\text{abcd, bcd, cd}$ 三个串,如果 $\text{bcd}$ 出现过,那么之后的 $\text{cd}$ 也一定出现过,所以我们只需要求出最长的出现过的串的长度即可。记 $\text{lim}_i$ 表示 $T[1, i]$ 中是 $S$ 子串的最长后缀长度,假设我们现在考虑结点 $x$,那么我们找到 $x$ 的某个 endpos 记为 $\text{pos}_x$,那么它对答案的贡献就是:

考虑如何求出 $\text{pos}_x$ 和 $\text{lim}_x$。对于一个新产生的结点,我们将它的 $\text{pos}$ 定为插入的字符所在的位置。对于一个复制来的结点,我们将它的 $\text{pos}$ 定为原始结点的 $\text{pos}$。对于 $\text{lim}_x$,我们先考虑 $l = 1, r = \vert S \vert$ 的情况。我们建出 $S$ 的 SAM,假设我们已经得到了 $T[1, i - 1]$ 的答案以及在 $S$ 的 SAM 中这个后缀所在的结点,我们就可以尝试走这个结点的 $T_i$ 这条出边,如果没有这条出边的话我们就一直往上跳 fail,直到当前结点有这条出边为止,这样我们就得到了 $T[1, i]$ 的答案。

当 $l, r$ 任意的情况,我们考虑对于 $S$ 的 SAM 中的每个点维护一个它子树的 endpos 集合。这个可以每个点维护一个权值线段树,然后使用线段树合并来求出。我们还是按照刚才的思路,假设上次的结点为 $x$,答案为 $y$,我们每次检查 $x$ 从 $T_i$ 这条边走到的点的 endpos 集合中是否有 $[l + y, r]$ 之间的数,如果没有我们就把 $y \gets y - 1$,然后判断 $x$ 是否会变成 $x$ 的父亲。类似地,找到一个符合条件的出边时,我们就可以走出去,并把 $y \gets y + 1$。这样的复杂度是对的,因为我们每次只会把 $y$ 至多加 $1$。于是整个题就做完了,时间复杂度 $O(n \log n)$。

代码实现

1
#include <bits/stdc++.h>
2
#define mid ((l + r) >> 1)
3
#define ls(x) son[x][0]
4
#define rs(x) son[x][1]
5
using namespace std;
6
7
typedef long long ll;
8
const int maxn = 1e6, maxm = 4e7;
9
int n, q, m, cnt, son[maxm + 3][2], lim[maxn + 3];
10
char s[maxn + 3], t[maxn + 3];
11
12
struct sam {
13
	int n, tot, lst, ch[maxn + 3][26], lnk[maxn + 3], len[maxn + 3];
14
	int tmp[maxn + 3], ord[maxn + 3], pos[maxn + 3], rt[maxn + 3];
15
16
	void clear() {
17
		for (int i = 1; i <= tot; i++) {
18
			memset(ch[i], 0, sizeof(ch[i]));
19
		}
20
		fill(lnk + 1, lnk + tot + 1, 0);
21
		fill(len + 1, len + tot + 1, 0);
22
		n = 0, tot = 1, lst = 1;
23
	}
24
25
	void extend(int k) {
26
		int x = ++tot, y = lst;
27
		pos[x] = ++n, len[x] = len[y] + 1, lst = x;
28
		while (y && !ch[y][k]) ch[y][k] = x, y = lnk[y];
29
		if (!y) {
30
			lnk[x] = 1;
31
			return;
32
		}
33
		int z = ch[y][k];
34
		if (len[y] + 1 == len[z]) {
35
			lnk[x] = z;
36
		} else {
37
			int t = ++tot;
38
			len[t] = len[y] + 1, lnk[t] = lnk[z], pos[t] = pos[z];
39
			memcpy(ch[t], ch[z], sizeof(ch[t]));
40
			lnk[x] = lnk[z] = t;
41
			while (ch[y][k] == z) ch[y][k] = t, y = lnk[y];
42
		}
43
	}
44
45
	void sort() {
46
		fill(tmp + 1, tmp + n + 1, 0);
47
		for (int i = 1; i <= tot; i++) tmp[len[i]]++;
48
		for (int i = 1; i <= n; i++) tmp[i] += tmp[i - 1];
49
		for (int i = tot; i; i--) ord[tmp[len[i]]--] = i;
50
	}
51
52
	int merge(int x, int y) {
53
		if (!x || !y) return x + y;
54
		int z = ++cnt;
55
		ls(z) = merge(ls(x), ls(y));
56
		rs(z) = merge(rs(x), rs(y));
57
		return z;
58
	}
59
60
	void mod(int &x, int l, int r, int y) {
61
		if (!x) x = ++cnt;
62
		if (l == r) return;
63
		if (y <= mid) mod(ls(x), l, mid, y);
64
		else mod(rs(x), mid + 1, r, y);
65
	}
66
	
67
	void solve() {
68
		sort();
69
		for (int i = tot; i > 1; i--) {
70
			rt[lnk[ord[i]]] = merge(rt[lnk[ord[i]]], rt[ord[i]]);
71
		}
72
	}
73
74
	bool query(int x, int l, int r, int lx, int rx) {
75
		if (!x) return false;
76
		if (l >= lx && r <= rx) return true;
77
		bool ret = false;
78
		if (lx <= mid) ret |= query(ls(x), l, mid, lx, rx);
79
		if (rx > mid) ret |= query(rs(x), mid + 1, r, lx, rx);
80
		return ret;
81
	}
82
} A, B;
83
84
int main() {
85
	scanf("%s", s + 1);
86
	n = strlen(s + 1);
87
	A.clear();
88
	for (int i = 1; i <= n; i++) {
89
		A.extend(s[i] - 'a');
90
		A.mod(A.rt[A.lst], 1, n, i);
91
	}
92
	A.solve();
93
	scanf("%d", &q);
94
	for (int l, r; q --> 0; ) {
95
		scanf("%s %d %d", t + 1, &l, &r);
96
		m = strlen(t + 1);
97
		B.clear();
98
		for (int i = 1; i <= m; i++) {
99
			B.extend(t[i] - 'a');
100
		}
101
		int	cur = 1, len = 0;
102
		for (int i = 1; i <= m; i++) {
103
			int k = t[i] - 'a';
104
			while (len && !A.query(A.rt[A.ch[cur][k]], 1, n, l + len, r)) {
105
				len--;
106
				if (len == A.len[A.lnk[cur]]) {
107
					cur = A.lnk[cur];
108
				}
109
			}
110
			if (A.query(A.rt[A.ch[cur][k]], 1, n, l + len, r)) {
111
				len++, cur = A.ch[cur][k];
112
			}
113
			lim[i] = len;
114
		}
115
		ll ans = 0;
116
		for (int i = 2; i <= B.tot; i++) {
117
			ans += max(0, B.len[i] - max(B.len[B.lnk[i]], lim[B.pos[i]]));
118
		}
119
		printf("%lld\n", ans);
120
	}
121
	return 0;
122
}