概念
众所周知,BST在特定的插入顺序下常常会退化成链,从而使单次查找的时间复杂度退化为,这是很糟糕的,因此出现了一箩筐的平衡树,在保证BST性质的同时,尽量保持BST的平衡。
Treap就是当中概念比较简单而且比较好写的一种真的吗?。Treap是Tree和Heap的合成,即“树堆”,其精髓就在于通过维护,使得一棵树在保有BST性质的同时,也具有大根堆的性质。为此,每个节点都会随机生成一个额外的权值用于堆性质的维护:
struct node {
int l, r; // 左右子树
int cnt, sz; // 当前键值副本树与子树(包括自身)大小
int val, w; // 键值与随机权值
} t[N];
int newnode(int val) {
t[++tot].val = val;
t[tot].cnt = t[tot].sz = 1;
t[tot].l = t[tot].r = 0;
t[tot].w = rand();
return tot;
}
void pushup(int p) { // 维护sz
t[p].sz = t[t[p].l].sz + t[t[p].r].sz + t[p].cnt;
}
为了防止讨论边界条件,我们插入两个值为的节点:
void build() {
tot = 0;
rt = newnode(-INF);
t[rt].r = newnode(INF);
pushup(rt);
}
旋转
改变BST的形状而不破坏其性质的操作是二叉树的旋转,分为右旋(zig)和左旋(zag)两种,以右旋为例,右旋把一个节点的左子树绕其“向右(顺时针)”旋转到原来的位置,并且让成为的右子树,原来的右子树成为的左子树,左旋类似:
落实在代码上,就是:
void zig(int &p) {
int q = t[p].l;
t[p].l = t[q].r, t[q].r = p;
p = q; pushup(t[p].r); pushup(p);
}
void zag(int &p) {
int q = t[p].r;
t[p].r = t[q].l, t[q].l = p;
p = q; pushup(t[p].l); pushup(p);
}
注意pushup
的顺序。
排名相关查询
我们接下来考虑平衡树的前两个查询操作,分别是查排名和按排名查数,两个都相对简单,可以递归实现:
int rank(int p, int val) {
if (p == 0) return 0;
if (val == t[p].val) return t[t[p].l].sz + 1;
if (val < t[p].val) return rank(t[p].l, val);
return t[t[p].l].sz + t[p].cnt + rank(t[p].r, val);
}
int query(int p, int rk) {
if (p == 0) return INF;
if (t[t[p].l].sz >= rk) return query(t[p].l, rk);
if (t[t[p].l].sz + t[p].cnt >= rk) return t[p].val;
return query(t[p].r, rk - t[t[p].l].sz - t[p].cnt);
}
注意这里查询的结果包含,要去掉。
前驱后继查询
接下来考虑查询的前驱/后继的操作,以前驱为例,我们初始化答案为,有以下几种情况:
- 没有找到节点:那么答案就在经过的节点上。
- 找到节点,可惜是叶子:同上,答案在经过的节点上。
- 找到节点,不是叶子:若节点有左子树,从开始一路往右走直到不存在右子树为止,此时节点对应的值即为前驱。
代码:
int querypre(int val) {
int ans = 1; // t[1].val = -inf
int p = rt;
while (p) {
if (val == t[p].val) {
if (t[p].l) {
p = t[p].l;
while (t[p].r) p = t[p].r;
ans = p;
}
break;
}
if (t[p].val < val && t[p].val > t[ans].val) ans = p;
p = val < t[p].val ? t[p].l : t[p].r;
}
return t[ans].val;
}
int querynxt(int val) {
int ans = 2; // t[2].val = inf
int p = rt;
while (p) {
if (val == t[p].val) {
if (t[p].r) {
p = t[p].r;
while (t[p].l) p = t[p].l;
ans = p;
}
break;
}
if (t[p].val > val && t[p].val < t[ans].val) ans = p;
p = val < t[p].val ? t[p].l : t[p].r;
}
return t[ans].val;
}
插入
递归进行,同时注意维护大根堆性质即可:
void insert(int &p, int val) {
if (p == 0) {
p = newnode(val);
return;
}
if (val == t[p].val) {
t[p].cnt++;
pushup(p);
return;
}
if (val < t[p].val) {
insert(t[p].l, val);
if (t[p].w < t[t[p].l].w) zig(p);
} else {
insert(t[p].r, val);
if (t[p].w < t[t[p].r].w) zag(p);
}
pushup(p);
}
删除
我们利用旋转一路把要删的节点转到叶子的位置,然后直接删除即可,同时一路上维护堆的性质:
void remove(int &p, int val) {
if (p == 0) return;
if (val == t[p].val) {
if (t[p].cnt > 1) {
t[p].cnt--;
pushup(p);
return;
}
if (t[p].l || t[p].r) {
if (t[p].r == 0 || t[t[p].l].w > t[t[p].r].w)
zig(p), remove(t[p].r, val);
else
zag(p), remove(t[p].l, val);
pushup(p);
} else p = 0;
return;
}
val < t[p].val ? remove(t[p].l, val) : remove(t[p].r, val);
pushup(p);
}
完整代码
以下是通过普通平衡树模板题的全部代码,共计150行左右,感觉平衡树这个东西代码量还是有点恐怖,而且需要注意非常多的细节,需要结合理解记忆:
#include <cstdio>
#include <cstdlib>
#include <ctime>
using namespace std;
const int N = 100010;
const int INF = 0x3f3f3f3f;
struct node {
int l, r;
int cnt, sz;
int val, w;
} t[N];
int n, tot = 0, rt;
int newnode(int val) {
t[++tot].val = val;
t[tot].cnt = t[tot].sz = 1;
t[tot].l = t[tot].r = 0;
t[tot].w = rand();
return tot;
}
void pushup(int p) {
t[p].sz = t[t[p].l].sz + t[t[p].r].sz + t[p].cnt;
}
void zig(int &p) {
int q = t[p].l;
t[p].l = t[q].r, t[q].r = p;
p = q; pushup(t[p].r); pushup(p);
}
void zag(int &p) {
int q = t[p].r;
t[p].r = t[q].l, t[q].l = p;
p = q; pushup(t[p].l); pushup(p);
}
void build() {
tot = 0;
rt = newnode(-INF);
t[rt].r = newnode(INF);
pushup(rt);
}
int rank(int p, int val) {
if (p == 0) return 0;
if (val == t[p].val) return t[t[p].l].sz + 1;
if (val < t[p].val) return rank(t[p].l, val);
return t[t[p].l].sz + t[p].cnt + rank(t[p].r, val);
}
int query(int p, int rk) {
if (p == 0) return INF;
if (t[t[p].l].sz >= rk) return query(t[p].l, rk);
if (t[t[p].l].sz + t[p].cnt >= rk) return t[p].val;
return query(t[p].r, rk - t[t[p].l].sz - t[p].cnt);
}
int querypre(int val) {
int ans = 1;
int p = rt;
while (p) {
if (val == t[p].val) {
if (t[p].l) {
p = t[p].l;
while (t[p].r) p = t[p].r;
ans = p;
}
break;
}
if (t[p].val < val && t[p].val > t[ans].val) ans = p;
p = val < t[p].val ? t[p].l : t[p].r;
}
return t[ans].val;
}
int querynxt(int val) {
int ans = 2;
int p = rt;
while (p) {
if (val == t[p].val) {
if (t[p].r) {
p = t[p].r;
while (t[p].l) p = t[p].l;
ans = p;
}
break;
}
if (t[p].val > val && t[p].val < t[ans].val) ans = p;
p = val < t[p].val ? t[p].l : t[p].r;
}
return t[ans].val;
}
void insert(int &p, int val) {
if (p == 0) {
p = newnode(val);
return;
}
if (val == t[p].val) {
t[p].cnt++;
pushup(p);
return;
}
if (val < t[p].val) {
insert(t[p].l, val);
if (t[p].w < t[t[p].l].w) zig(p);
} else {
insert(t[p].r, val);
if (t[p].w < t[t[p].r].w) zag(p);
}
pushup(p);
}
void remove(int &p, int val) {
if (p == 0) return;
if (val == t[p].val) {
if (t[p].cnt > 1) {
t[p].cnt--;
pushup(p);
return;
}
if (t[p].l || t[p].r) {
if (t[p].r == 0 || t[t[p].l].w > t[t[p].r].w)
zig(p), remove(t[p].r, val);
else
zag(p), remove(t[p].l, val);
pushup(p);
} else p = 0;
return;
}
val < t[p].val ? remove(t[p].l, val) : remove(t[p].r, val);
pushup(p);
}
int main() {
srand(time(0)); build();
scanf("%d", &n);
while (n--) {
int op, x;
scanf("%d%d", &op, &x);
switch (op) {
case 1: insert(rt, x); break;
case 2: remove(rt, x); break;
case 3: printf("%d\n", rank(rt, x) - 1); break;
case 4: printf("%d\n", query(rt, x + 1)); break;
case 5: printf("%d\n", querypre(x)); break;
case 6: printf("%d\n", querynxt(x)); break;
}
}
return 0;
}