莫队算法学习笔记

我太弱了,现在才学莫队

以下均假设 n,mn,m 同阶。

普通莫队

简介

莫队算法主要用于解决离线区间查询等问题,若区间 [l,r][l,r] 的答案能够 O(1)O(1) 转移到与 [l,r][l,r] 相邻的区间(即 [l1,r],[l+1,r],[l,r1],[l,r+1][l-1,r],[l+1,r],[l,r-1],[l,r+1]),那么就可以使用莫队算法在 O(nn)O(n\sqrt{n}) 内求出所有询问的答案。

引入

P2709 小B的询问

给定一个长为 nn 的整数序列 aa,值域为 [1,k][1,k],有 mm 次询问,每个询问给定一个区间 [l,r][l,r],求:

i=1kci2\sum\limits_{i=1}^k c_i^2

其中 cic_i 表示数字 ii[l,r][l,r] 中的出现次数。

1n,m,k5×1041\le n,m,k \le 5\times 10^4

朴素的暴力

不多说,开一个桶,按照题意模拟即可。

朴素的优化

开两个指针 l,rl,r,每次将两个指针一步一步挪到询问的区间,每挪一步的时间复杂度为 O(1)O(1)

写成代码:

1
2
3
4
5
6
7
8
9
//query[i].l 是当前询问的左端点,query[i].r 是当前询问的右端点
while (l > query[i].l)
add(a[--l]);
while (r < query[i].r)
add(a[++r]);
while (l < query[i].l)
del(a[l++]);
while (r > query[i].r)
del(a[r--]);
1
2
3
4
5
6
7
8
9
10
11
12
13
//bucket 是桶,表示每个数字出现的次数
//sum 即当前区间的答案
inline void add(int x) // 添加一个数
{
sum += bucket[x] * 2 + 1;
// 把完全平方公式拆开
++bucket[x];
}
inline void del(int x) // 删除一个数
{
sum -= bucket[x] * 2 - 1;
--bucket[x];
}

看起来高级的优化

把所有询问按照左端点排序,这样左指针最多移动 O(n)O(n) 次。

但实际上这个优化是假的,因为右端点是无序的,因此最坏时间复杂度仍然是 O(n2)O(n^2)

优美的暴力

考虑分块。

我们把 [1,n][1,n] 均匀分成 n\sqrt{n} 块,然后把询问以 ll 所在块的编号为第一关键字,rr 为第二关键字从小到大排序,然后再用“朴素的优化”中的方法暴力转移。

可以证明,这样的方法时间复杂度为 O(nn)O(n\sqrt{n})详细证明

这就是莫队算法。

实现

直接上代码,里面有注释:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 5e4 + 5;
int n, m, k, block;
int a[maxn], bucket[maxn];
// a 是题目给出的序列,bucket 是桶
int l = 1, r = 0;
ll ans[maxn]; // ans 用来记录答案
ll sum; // 不开 long long 见祖宗
struct node
{
int id, l, r; // id 是下标
} query[maxn];
bool cmp(node x, node y)
{
if (x.l / block == y.l / block) // l 在同一个块时
return x.r < y.r;
else
return x.l / block < y.l / block;
}
inline void add(int x) // 添加一个数
{
sum += bucket[x] * 2 + 1;
// 把完全平方公式拆开
++bucket[x];
}
inline void del(int x) // 删除一个数
{
sum -= bucket[x] * 2 - 1;
--bucket[x];
}
int main()
{
ios::sync_with_stdio(false);
cin >> n >> m >> k;
for (int i = 1; i <= n; i++)
cin >> a[i];
block = sqrt(n);
for (int i = 1; i <= m; i++)
{
cin >> query[i].l >> query[i].r;
query[i].id = i;
}
sort(query + 1, query + 1 + m, cmp);
for (int i = 1; i <= m; i++)
{
while (l > query[i].l)
add(a[--l]);
while (r < query[i].r)
add(a[++r]);
while (l < query[i].l)
del(a[l++]);
while (r > query[i].r)
del(a[r--]);
ans[query[i].id] = sum;
}
for (int i = 1; i <= m; i++)
cout << ans[i] << endl;
return 0;
}

带修改莫队

「国家集训队」数颜色/维护队列

给出一个序列,有 mm 个操作,有两种操作:

  1. 修改序列上某一位的数字
  2. 询问区间 [l,r][l,r] 中有多少种不同的数字

可以这样理解,序列的值是随着时间变化的,因此我们可以加入一个时间维度,变成查询区间 [l,r,t][l,r,t] 的答案,相当于空间上的区间查询,同样可以离线。

每次仍然可以 O(1)O(1) 转移,不过一共有六种转移,分别是 [l1,r,t],[l+1,r,t],[l,r1,t],[l,r,t1],[l,r,t+1][l-1,r,t],[l+1,r,t],[l,r-1,t],[l,r,t-1],[l,r,t+1]

还是考虑分块,以 ll 所在块编号为第一关键字,rr 所在块编号为第二关键字,tt 的大小为第三关键字,从小到大排序。

块的大小最好取 n23n^{\frac{2}{3}} ,证明我仍然不会,可以参考 OI Wiki。另外,千万不要把块长取为 n\sqrt{n} ,不然你将激情 T 飞。

上代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
#include <bits/stdc++.h>
using namespace std;
const int maxn = 133338;

inline int read()
{
int res = 0, w = 1;
char ch = getchar();
while (!isdigit(ch))
{
if (ch == '-')
w = -1;
ch = getchar();
}
while (isdigit(ch))
res = res * 10 + ch - 48, ch = getchar();
return res * w;
}

int n, m;
int a[maxn], b[1000001]; // b 是桶
struct node
{
int id, l, r, pre; // pre 记录上一次修改的编号
} query[maxn]; // 记录询问

struct node2
{
int p, val;
} modify[maxn]; // 记录修改

int ans, block;
int res[maxn];

bool cmp(node x, node y)
{
if (x.l / block == y.l / block)
{
if (x.r / block == y.r / block)
return x.pre < y.pre;
return x.r < y.r;
}
return x.l / block < y.l / block;
}
inline void add(int x)
{
++b[x];
if (b[x] == 1)
++ans;
}
inline void del(int x)
{
--b[x];
if (b[x] == 0)
--ans;
}
inline void upd(int now, int l, int r)
{
if (modify[now].p >= l && modify[now].p <= r) // 修改操作在当前询问区间内,会对答案造成影响
{
del(a[modify[now].p]);
add(modify[now].val);
}
swap(a[modify[now].p], modify[now].val); // 下一次操作一定和这一次相反,因此不需要写两个函数。
}
int main()
{
int l = 1, r = 0, now = 0;
n = read(), m = read();
block = pow(n, 0.666);
for (int i = 1; i <= n; i++)
a[i] = read();
int qnum = 0, mnum = 0;
for (int i = 1; i <= m; i++)
{
char ch = getchar();
while (ch != 'Q' && ch != 'R')
ch = getchar();
if (ch == 'Q')
{
++qnum;
query[qnum].l = read();
query[qnum].r = read();
query[qnum].id = qnum;
query[qnum].pre = mnum;
}
else
{
++mnum;
modify[mnum].p = read();
modify[mnum].val = read();
}
}
sort(query + 1, query + 1 + qnum, cmp);
for (int i = 1; i <= qnum; i++)
{
while (l > query[i].l)
add(a[--l]);
while (r < query[i].r)
add(a[++r]);
while (l < query[i].l)
del(a[l++]);
while (r > query[i].r)
del(a[r--]);
while (now > query[i].pre)
upd(now--, query[i].l, query[i].r);
while (now < query[i].pre)
upd(++now, query[i].l, query[i].r);
res[query[i].id] = ans;
}
for (int i = 1; i <= qnum; i++)
printf("%d\n", res[i]);
return 0;
}

未完待续……