「Codeforces 547E」Mike and Friends
题目描述
给定 $n$ 个由小写字母组成的串 $s_1, \dots, s_n$,以及 $q$ 次询问 $(l, r, k)$:问 $s_k$ 在 $s_l, s_{l+1}, \dots, s_r$ 中作为子串一共出现了多少次。
数据范围:$n \le 2 \times 10^5$,$q \le 5 \times 10^5$,$\sum |s_i| \le 2 \times 10^5$。
思路分析
把所有串用一个不在字母表中的分隔符(文中记作 `#`;代码中用的是 `'z'` 的下一个字符 `{`)隔开拼成一个大串,然后对它做后缀排序。第 $k$ 个串是后缀 $i$ 的前缀,当且仅当 $\text{lcp}(p_k, i) \ge \text{len}(k)$,其中 $p_k$ 表示第 $k$ 个串在大串中开始的位置,$\text{len}(k)$ 表示第 $k$ 个串的长度。这样的 $i$ 在后缀数组 $\text{sa}$ 中对应一段连续的排名区间 $[L_k, R_k]$,可以在 $\text{height}$ 数组的 ST 表上倍增求出。于是把大串的每个位置 $i$ 看成平面上的一个点 $(i, \text{rnk}(i))$,每个询问就是求矩形 $[p_l, p_{r + 1} - 1] \times [L_k, R_k]$ 中点的个数。离线扫描线 + 树状数组即可,时间复杂度 $O((L + q) \log (L + q))$,其中 $L$ 为大串的长度。
代码实现
本文的后缀数组板子较之前进行了优化,变得短了一些,推荐大家使用。
#include <algorithm>
#include <cstdio>
#include <cstring>
using namespace std;

// Capacity limits. maxn bounds the concatenated text (total length plus one
// separator per string); maxq bounds the event list (one point per text
// position plus two sweep probes per query); 2^logn exceeds maxn.
const int maxn = 400000;
const int maxq = 1000000 + maxn;
const int logn = 20;

// Input sizes, and per-string length / start position in the text.
int n, m, q;
int len[maxn + 3], pos[maxn + 3];

// Suffix-array construction state (alphabet size, SA, ranks, radix-sort
// buckets, the two sort keys, and adjacent LCPs).
int sz;
int sa[maxn + 3], rnk[maxn + 3], cnt[maxn + 3];
int k_1[maxn + 3], k_2[maxn + 3], hei[maxn + 3];

// Sparse table over hei, occurrence rank intervals, event count,
// Fenwick tree, and per-query answers.
int h[logn + 3][maxn + 3];
int lft[maxn + 3], rht[maxn + 3];
int tot;
int bit[maxn + 3];
int ans[maxq + 3];

// Concatenated text s[1..m] and a scratch buffer for reading one string.
char s[maxn + 3], t[maxn + 3];
9 | |
10 | struct event { |
11 | int t, x, y, i, l, r; |
12 | friend bool operator < (const event &a, const event &b) { |
13 | return a.x == b.x ? a.t < b.t : a.x < b.x; |
14 | } |
15 | } ev[maxq + 3]; |
16 | |
17 | void r_sort(int a[], int b[], int k[], int n) { |
18 | fill(cnt + 1, cnt + sz + 1, 0); |
19 | for (int i = 1; i <= n; i++) cnt[k[i]]++; |
20 | for (int i = 2; i <= sz; i++) cnt[i] += cnt[i - 1]; |
21 | for (int i = n; i; i--) b[cnt[k[a[i]]]--] = a[i]; |
22 | } |
23 | |
24 | void make(char s[], int n) { |
25 | sz = max(n + 1, 27); |
26 | for (int i = 1; i <= n; i++) rnk[i] = s[i] - 'a' + 1, k_1[i] = i; |
27 | r_sort(k_1, sa, rnk, n); |
28 | for (int i = 1; i <= n; i++) rnk[sa[i]] = rnk[sa[i - 1]] + (s[sa[i]] != s[sa[i - 1]]); |
29 | for (int k = 1; k < n; k <<= 1) { |
30 | fill(k_1 + 1, k_1 + n + 1, 1); |
31 | for (int i = 1; i <= n - k; i++) k_1[i] = rnk[i + k] + 1; |
32 | for (int i = 1; i <= n; i++) k_2[i] = rnk[i], sa[i] = i; |
33 | r_sort(sa, rnk, k_1, n), r_sort(rnk, sa, k_2, n); |
34 | for (int i = 1; i <= n; i++) { |
35 | bool f = k_1[sa[i]] == k_1[sa[i - 1]] && k_2[sa[i]] == k_2[sa[i - 1]]; |
36 | rnk[sa[i]] = f ? rnk[sa[i - 1]] : rnk[sa[i - 1]] + 1; |
37 | } |
38 | if (rnk[sa[n]] == n) break; |
39 | } |
40 | s[n + 1] = 0; |
41 | for (int i = 1, j, k = 0; i <= n; hei[rnk[i++] - 1] = k) { |
42 | for (k = max(0, k - 1), j = sa[rnk[i] - 1]; s[i + k] == s[j + k]; k++); |
43 | } |
44 | for (int i = 1; i <= n - 1; i++) { |
45 | h[0][i] = hei[i]; |
46 | } |
47 | for (int k = 1; 1 << k <= n - 1; k++) { |
48 | for (int i = 1, j = (1 << (k - 1)) + 1; i <= n - (1 << k); i++, j++) { |
49 | h[k][i] = min(h[k - 1][i], h[k - 1][j]); |
50 | } |
51 | } |
52 | } |
53 | |
54 | void add(int t, int x, int y, int i, int l, int r) { |
55 | ev[++tot].t = t, ev[tot].x = x, ev[tot].y = y, ev[tot].i = i, ev[tot].l = l, ev[tot].r = r; |
56 | } |
57 | |
58 | void mod(int x) { |
59 | for (int i = x; i <= m; i += i & -i) { |
60 | bit[i]++; |
61 | } |
62 | } |
63 | |
64 | int sum(int x) { |
65 | int y = 0; |
66 | for (int i = x; i; i ^= i & -i) { |
67 | y += bit[i]; |
68 | } |
69 | return y; |
70 | } |
71 | |
72 | int main() { |
73 | scanf("%d %d", &n, &q); |
74 | for (int i = 1; i <= n; i++) { |
75 | scanf("%s", t + 1); |
76 | len[i] = strlen(t + 1); |
77 | pos[i] = m + 1; |
78 | for (int j = 1; j <= len[i]; j++) { |
79 | s[++m] = t[j]; |
80 | } |
81 | s[++m] = '{'; |
82 | } |
83 | pos[n + 1] = m + 1; |
84 | make(s, m); |
85 | for (int i = 1; i <= n; i++) { |
86 | int x = rnk[pos[i]]; |
87 | for (int j = logn; ~j; j--) { |
88 | int t = x - (1 << j); |
89 | if (t > 0 && h[j][t] >= len[i]) x = t; |
90 | } |
91 | lft[i] = x, x = rnk[pos[i]]; |
92 | for (int j = logn; ~j; j--) { |
93 | int t = x + (1 << j); |
94 | if (t <= m && h[j][x] >= len[i]) x = t; |
95 | } |
96 | rht[i] = x; |
97 | } |
98 | for (int i = 1; i <= m; i++) { |
99 | add(1, i, rnk[i], -1, -1, -1); |
100 | } |
101 | for (int i = 1, l, r, k; i <= q; i++) { |
102 | scanf("%d %d %d", &l, &r, &k); |
103 | l = pos[l], r = pos[r + 1] - 1; |
104 | int L = lft[k], R = rht[k]; |
105 | add(2, l - 1, -1, i, L, R), add(2, r, 1, i, L, R); |
106 | } |
107 | sort(ev + 1, ev + tot + 1); |
108 | for (int i = 1; i <= tot; i++) { |
109 | if (ev[i].t == 1) { |
110 | mod(ev[i].y); |
111 | } else { |
112 | ans[ev[i].i] += (sum(ev[i].r) - sum(ev[i].l - 1)) * ev[i].y; |
113 | } |
114 | } |
115 | for (int i = 1; i <= q; i++) { |
116 | printf("%d\n", ans[i]); |
117 | } |
118 | return 0; |
119 | } |