只要你跑的够快,锅就追不上你

0%

「Codeforces 547E」Mike and Friends(后缀数组)

「Codeforces 547E」Mike and Friends

题目描述

给定 $n$ 个串,$q$ 次询问 $(l, r, k)$,问第 $k$ 个串在第 $[l, r]$ 个串中作为子串出现了多少次。

数据范围:$n \le 2 \times 10^5, q \le 5 \times 10^5$。

思路分析

把所有串用 # 隔开拼成一个大串,然后后缀排个序。第 $k$ 个串如果是某个后缀 $i$ 的前缀当且仅当 $\text{lcp}(p_k, i) \ge \text{len}(k)$,其中 $p_k$ 表示第 $k$ 个串开始的位置,$\text{len}(k)$ 表示第 $k$ 个串的长度。发现这样的 $i$ 在后缀数组 $\text{sa}$ 中形成一段连续的区间,并且我们可以倍增求出这段区间 $[L_k, R_k]$。所以每个串可以看成 $(i, \text{rnk}(i))$ 一个点,每个询问可以看成求矩形 $(p_l, L_k) - (p_{r + 1} - 1, R_k)$ 中点的个数。直接扫描线 + 树状数组即可,时间复杂度 $O(n \log n)$。

代码实现

本文的后缀数组板子较之前进行了优化,变得短了一些,推荐大家使用。

1
#include <bits/stdc++.h>
2
using namespace std;
3
4
const int maxn = 4e5, maxq = 1e6 + maxn, logn = 20;
5
int n, m, q, len[maxn + 3], pos[maxn + 3];
6
int sz, sa[maxn + 3], rnk[maxn + 3], cnt[maxn + 3], k_1[maxn + 3], k_2[maxn + 3], hei[maxn + 3];
7
int h[logn + 3][maxn + 3], lft[maxn + 3], rht[maxn + 3], tot, bit[maxn + 3], ans[maxq + 3];
8
char s[maxn + 3], t[maxn + 3];
9
10
struct event {
11
	int t, x, y, i, l, r;
12
	friend bool operator < (const event &a, const event &b) {
13
		return a.x == b.x ? a.t < b.t : a.x < b.x;
14
	}
15
} ev[maxq + 3];
16
17
void r_sort(int a[], int b[], int k[], int n) {
18
	fill(cnt + 1, cnt + sz + 1, 0);
19
	for (int i = 1; i <= n; i++) cnt[k[i]]++;
20
	for (int i = 2; i <= sz; i++) cnt[i] += cnt[i - 1];
21
	for (int i = n; i; i--) b[cnt[k[a[i]]]--] = a[i];
22
}
23
24
void make(char s[], int n) {
25
	sz = max(n + 1, 27);
26
	for (int i = 1; i <= n; i++) rnk[i] = s[i] - 'a' + 1, k_1[i] = i;
27
	r_sort(k_1, sa, rnk, n);
28
	for (int i = 1; i <= n; i++) rnk[sa[i]] = rnk[sa[i - 1]] + (s[sa[i]] != s[sa[i - 1]]);
29
	for (int k = 1; k < n; k <<= 1) {
30
		fill(k_1 + 1, k_1 + n + 1, 1);
31
		for (int i = 1; i <= n - k; i++) k_1[i] = rnk[i + k] + 1;
32
		for (int i = 1; i <= n; i++) k_2[i] = rnk[i], sa[i] = i;
33
		r_sort(sa, rnk, k_1, n), r_sort(rnk, sa, k_2, n);
34
		for (int i = 1; i <= n; i++) {
35
			bool f = k_1[sa[i]] == k_1[sa[i - 1]] && k_2[sa[i]] == k_2[sa[i - 1]];
36
			rnk[sa[i]] = f ? rnk[sa[i - 1]] : rnk[sa[i - 1]] + 1;
37
		}
38
		if (rnk[sa[n]] == n) break;
39
	}
40
	s[n + 1] = 0;
41
	for (int i = 1, j, k = 0; i <= n; hei[rnk[i++] - 1] = k) {
42
		for (k = max(0, k - 1), j = sa[rnk[i] - 1]; s[i + k] == s[j + k]; k++);
43
	}
44
	for (int i = 1; i <= n - 1; i++) {
45
		h[0][i] = hei[i];
46
	}
47
	for (int k = 1; 1 << k <= n - 1; k++) {
48
		for (int i = 1, j = (1 << (k - 1)) + 1; i <= n - (1 << k); i++, j++) {
49
			h[k][i] = min(h[k - 1][i], h[k - 1][j]);
50
		}
51
	}
52
}
53
54
void add(int t, int x, int y, int i, int l, int r) {
55
	ev[++tot].t = t, ev[tot].x = x, ev[tot].y = y, ev[tot].i = i, ev[tot].l = l, ev[tot].r = r;
56
}
57
58
void mod(int x) {
59
	for (int i = x; i <= m; i += i & -i) {
60
		bit[i]++;
61
	}
62
}
63
64
int sum(int x) {
65
	int y = 0;
66
	for (int i = x; i; i ^= i & -i) {
67
		y += bit[i];
68
	}
69
	return y;
70
}
71
72
int main() {
73
	scanf("%d %d", &n, &q);
74
	for (int i = 1; i <= n; i++) {
75
		scanf("%s", t + 1);
76
		len[i] = strlen(t + 1);
77
		pos[i] = m + 1;
78
		for (int j = 1; j <= len[i]; j++) {
79
			s[++m] = t[j];
80
		}
81
		s[++m] = '{';
82
	}
83
	pos[n + 1] = m + 1;
84
	make(s, m);
85
	for (int i = 1; i <= n; i++) {
86
		int x = rnk[pos[i]];
87
		for (int j = logn; ~j; j--) {
88
			int t = x - (1 << j);
89
			if (t > 0 && h[j][t] >= len[i]) x = t;
90
		}
91
		lft[i] = x, x = rnk[pos[i]];
92
		for (int j = logn; ~j; j--) {
93
			int t = x + (1 << j);
94
			if (t <= m && h[j][x] >= len[i]) x = t;
95
		}
96
		rht[i] = x;
97
	}
98
	for (int i = 1; i <= m; i++) {
99
		add(1, i, rnk[i], -1, -1, -1);
100
	}
101
	for (int i = 1, l, r, k; i <= q; i++) {
102
		scanf("%d %d %d", &l, &r, &k);
103
		l = pos[l], r = pos[r + 1] - 1;
104
		int L = lft[k], R = rht[k];
105
		add(2, l - 1, -1, i, L, R), add(2, r, 1, i, L, R);
106
	}
107
	sort(ev + 1, ev + tot + 1);
108
	for (int i = 1; i <= tot; i++) {
109
		if (ev[i].t == 1) {
110
			mod(ev[i].y);
111
		} else {
112
			ans[ev[i].i] += (sum(ev[i].r) - sum(ev[i].l - 1)) * ev[i].y;
113
		}
114
	}
115
	for (int i = 1; i <= q; i++) {
116
		printf("%d\n", ans[i]);
117
	}
118
	return 0;
119
}