LG3809 后缀排序 题解

作者:XiaoQuQu,发表于 Tue Feb 27 2024。

写这题的时候才发现之前写过的“SA”是使用 std::sort() 版的 $n\log ^2 n$ 的后缀数组,然后爆补 $n \log n$ 的计数排序。发现对于计数排序还是有很多不理解的地方,写一篇题解,怕自己啥时候又忘了。

阅读这篇题解可能需要你对后缀数组有一定的基础认识,可以前往 OI-Wiki 等查看“后缀数组”的介绍,本篇题解只是作为对于其他博客等的补充。

Part1. 计数排序

考虑排序的本质,其实是求出每个数的排名。

计数排序先用桶储存了所有数的出现次数,如序列 $a=[2,2,1,2,3,4]$,桶数组 $b=[1,3,1,1]$,对 $b$ 数组做前缀和,得到前缀和数组 $c=[1,4,5,6]$。

接下来我们从最大数到最小数枚举 $i$,则我们发现,所有数 $i$ 的排名都应该为 $c_{i-1}+1$。

Part2. 倍增,倍增

这一部分在阅读其他博文时可能有“看起来很简单,写起来很复杂”的感觉,在这里我尝试尽量用贴近代码语言的方式表述。

考虑倍增地求后缀数组,枚举 $k=2^w$,保证 $k\le n$,假设我们已知所有长度为 $2^{w-1}$ 的字符串的排名,从 $i$ 开始的字符串排名记为 $r_i$,我们希望求长度为 $2^w$ 次方的字符串的排名。

考虑如何比较两个从 $i,j$ 开始的字符串的大小,其实相当于比较 $r_i,r_j$ 的大小,若相等比较 $r_{i+2^{w-1}},r_{j+2^{w-1}}$ 的大小。

所以我们有一个很朴素的思想,也就是按照 $r_{i+2^{w-1}}$ 的大小作为第二关键字,$r_i$ 的大小作为第一关键字,然后进行排序,注意表述顺序,因为我们这里使用 LSD 进行基数排序,会先对优先级低的关键字比较。

但是这样的常数会很大,考虑我们怎么样省略掉比较第二关键字的步骤。

Part3. 去掉第二关键字

我们发现,可以枚举 $i$,对于 $i+k>n$ 的 $i$,他按照第二关键字排名后肯定是在最前面的,因为其第二关键字为 $0$。

对于 $i+k<n$ 的怎么办?我们可以枚举 $i$。若在这一轮倍增之前第 $i$ 小的字符串起始于 $sa_i$ 且 $sa_i-k>0$,那么我们就可以肯定,对于 $sa_i-k$ 这一项,排完序后在 $i+k<n$ 的部分的排名是 $i$。

这样我们就省略了对于第二关键字的排序,直接对于第一关键字排序即可,具体写起来是这样的。

int p = 0;
for (int i = n; i + (1 << w) > n; --i) id[++p] = i; // for i + k > n
for (int i = 1; i <= n; ++i)
  	if (sa[i] > (1 << w)) id[++p] = sa[i] - (1 << w); // 第 i 个位置的数字应该为 sa[i] - k

Part4. 大的来了

设对于第二关键字排好序时,第 $i$ 小的字符串起始于 $id_i$。考虑计数排序,可以直接按照 $r_{id_i}$ 进行计数排序,得到的结果直接存在 $sa_i$ 中。

接下来考虑更新 $r$ 数组,考虑从小到大枚举 $i$,然后直接将 $r_{sa_i}$ 更新为 $i$。但是这样会有问题。即有些字符串是相同的,我们需要对这些字符串进行去重。

考虑我们是如何比较两个字符串的大小,发现判断两个以 $i,j$ 开头的字符串是否相同,可以直接判断旧的 $r_i,r_j$ 与 $r_{i+k},r_{j+k}$ 是否相同即可。

完整代码如下。

const int MAXN = 2e6 + 5, MAXD = 256;
int n, sa[MAXN], rk[MAXN], cnt[MAXN], oldrk[MAXN], id[MAXN];
char s[MAXN];

void work() {
	cin >> (s + 1); n = strlen(s + 1);
	for (int i = 1; i <= n; ++i) cnt[rk[i] = s[i]]++;
	for (int i = 1; i <= MAXD; ++i) cnt[i] += cnt[i - 1];
	for (int i = n; i >= 1; --i) sa[cnt[rk[i]]--] = i;
	for (int w = 0; (1 << w) <= n; ++w) {
		int p = 0;
		for (int i = n; i + (1 << w) > n; --i) id[++p] = i;
		for (int i = 1; i <= n; ++i)
			if (sa[i] > (1 << w)) id[++p] = sa[i] - (1 << w);
		for (int i = 0; i <= p; ++i) cnt[i] = 0;
		for (int i = 1; i <= n; ++i) ++cnt[rk[id[i]]];
		for (int i = 1; i <= p; ++i) cnt[i] += cnt[i - 1];
		for (int i = n; i >= 1; --i) sa[cnt[rk[id[i]]]--] = id[i];
		p = 0;
		for (int i = 1; i <= n; ++i) oldrk[i] = rk[i];
		for (int i = 1; i <= n; ++i) {
			if (oldrk[sa[i]] == oldrk[sa[i - 1]] && oldrk[sa[i] + (1 << w)] == oldrk[sa[i - 1] + (1 << w)])
				rk[sa[i]] = p;
			else rk[sa[i]] = ++p;
		}
		if (p == n) break;
	}
	for (int i = 1; i <= n; ++i) cout << sa[i] << ' '; 
}

Copyright © 2024 LVJ, Open-Source Project. 本站内容在无特殊说明情况下均遵循 CC-BY-SA 4.0 协议,内容版权归属原作者。