It turns out that writing an STL container from scratch is mostly tedious manual labor.
The STL was created to save C++ programmers the time of reinventing
wheels. Unfortunately, many STL data structures, most notably the
self-balancing BSTs (a.k.a. std::set and std::multiset), are not
extensible and are by themselves too limited to be used in the context of
competitive programming, forcing us to write our own BSTs again and
again in competitions.
Here is a fact I find really interesting: a quick and dirty
self-balancing BST implementation written during a competition is only
about 50-60 lines long, while the STL implementation of
std::set and std::multiset is usually more
than a thousand lines long in total.
It does make me wonder: what makes this huge difference? And, will the code of our BST bloat as well if we write it the STL way – with generics, iterators, and all the necessary bits and pieces as specified in the reference?
To answer my question, I tried writing my own implementation
of multiset using a treap as the underlying data structure. The result,
treap_multiset, is almost a drop-in replacement for
std::multiset. The few places where it does not conform to
the C++ standard are:
- It is currently not allocator-aware, so all allocator-related features are not implemented.
- emplace and emplace_hint are not implemented.
- All operations that have logarithmic time complexity in std::multiset still have logarithmic time complexity here, but only in the average sense (because a treap is a randomized data structure); they could have linear worst-case time complexity (though that is very, very, very unlikely).
- void erase(iterator) takes amortized logarithmic time instead of constant time.
- A few uncommonly used member types are missing.
treap_multiset also supports two new operations that are
not supported in the original std::multiset:
size_type rank(iterator it) const; size_type rank(const_iterator it) const;
Both take average logarithmic time and return the rank / position of the iterator.
iterator at(size_type index); const_iterator at(size_type index) const;
Both take average logarithmic time and return the iterator at the specified index/position.
The code is shown below:
#include <algorithm>
#include <cstddef>
#include <functional>
#include <iterator>
#include <random>
#include <stdexcept>
#include <utility>
// Random source used to draw the heap weight of every newly created node.
// The engine is seeded once from std::random_device during static
// initialization.
static std::random_device random_device;
static std::mt19937_64 random_engine(random_device());

// A single node of the treap.
//
// Each node stores its payload (`value`), a heap priority (`weight`), the
// number of nodes in the subtree rooted at it (`size`, including itself) and
// raw links to its children and its parent. The node does not own the nodes
// it links to; whoever builds the tree is responsible for freeing them.
template <typename T>
struct treap_node {
  using rand_weight_type = decltype(random_engine)::result_type;
  using size_type = std::size_t;

  treap_node *left, *right, *parent;
  rand_weight_type weight;
  size_type size;
  T value;

  // Creates a detached leaf holding `value` with a freshly drawn random weight.
  treap_node(const T &value): left(nullptr), right(nullptr), parent(nullptr),
    weight(random_engine()), size(1), value(value) {}

  // Creates a node with every field given explicitly (used when cloning a tree,
  // where the weight and size of the source node must be preserved).
  treap_node(treap_node *left, treap_node *right, treap_node *parent,
    rand_weight_type weight, size_type size, const T &value):
    left(left), right(right), parent(parent), weight(weight),
    size(size), value(value) {}

  // Recomputes `size` from the children's sizes; a missing child counts as 0.
  // The children's own `size` fields must already be up to date.
  void update_size() {
    size = 1 + (left ? left->size : 0) + (right ? right->size : 0);
  }
};
/*
 * In-order successor step. Operates on a pointer variable named `node` that
 * must be in scope at the expansion site and whose pointee exposes `left`,
 * `right` and `parent` links.
 *
 *  - A null `node` is left untouched.
 *  - If there is a right subtree, `node` becomes its leftmost node.
 *  - Otherwise `node` climbs until it leaves a left child; if the climb runs
 *    past the root there is no successor and `node` becomes nullptr.
 */
#define IMPL_ITERATOR_MOVE_NEXT do { \
    if (node == nullptr) break; \
    if (node->right != nullptr) { \
      for (node = node->right; node->left != nullptr; node = node->left) {} \
      break; \
    } \
    for (;;) { \
      if (node->parent == nullptr) { node = nullptr; break; } \
      const bool came_from_right = (node->parent->right == node); \
      node = node->parent; \
      if (!came_from_right) break; \
    } \
  } while (0)

/*
 * In-order predecessor step; the exact mirror image of
 * IMPL_ITERATOR_MOVE_NEXT. `node` becomes nullptr when it has no predecessor.
 */
#define IMPL_ITERATOR_MOVE_PREV do { \
    if (node == nullptr) break; \
    if (node->left != nullptr) { \
      for (node = node->left; node->right != nullptr; node = node->right) {} \
      break; \
    } \
    for (;;) { \
      if (node->parent == nullptr) { node = nullptr; break; } \
      const bool came_from_left = (node->parent->left == node); \
      node = node->parent; \
      if (!came_from_left) break; \
    } \
  } while (0)
#define TREAP_ITERATOR_DECL(name, qualifier, inc, dec) \
template <typename T> struct name { \
qualifier treap_node<T> *node; bool past_the_end; \
name(): node(nullptr), past_the_end(true) {} \
name(qualifier treap_node<T> *node, bool past_the_end) \
: node(node), past_the_end(past_the_end) {} \
name(qualifier treap_node<T> *node): node(node), past_the_end(false) {} \
qualifier T &operator *() qualifier { \
if (!node || past_the_end) \
throw std::runtime_error("dereferencing null/past-end iterator"); \
return node->value; \
} \
bool operator ==(const name<T> &b) const { \
return node == b.node && past_the_end == b.past_the_end; \
} \
bool operator !=(const name<T> &b) const { \
return node != b.node || past_the_end != b.past_the_end; \
} \
name<T> &operator ++() { \
qualifier treap_node<T> *backup = node; \
IMPL_ITERATOR_MOVE_##inc; \
if (!node) past_the_end = true, node = backup; \
return *this; \
} \
name<T> operator ++(int) { \
name<T> ret(*this); \
return ++(*this), ret; \
} \
name<T> &operator --() { \
if (past_the_end) past_the_end = false; \
else IMPL_ITERATOR_MOVE_##dec; \
if (!node) \
throw std::runtime_error("can't decrement at the beginning"); \
return *this; \
} \
name<T> operator --(int) { \
name<T> ret(*this); \
return --(*this), ret; \
} \
}
(treap_iterator, /* NO QUALIFIER */, NEXT, PREV);
TREAP_ITERATOR_DECL(reverse_treap_iterator, /* NO QUALIFIER */, PREV, NEXT);
TREAP_ITERATOR_DECL(const_treap_iterator, const, NEXT, PREV);
TREAP_ITERATOR_DECL(const_reverse_treap_iterator, const, PREV, NEXT);
TREAP_ITERATOR_DECL
template <typename T, typename Compare = std::less<T>>
class treap_multiset {
public:
using key_type = T;
using value_type = T;
using size_type = typename treap_node<T>::size_type;
using key_compare = Compare;
using value_compare = Compare;
using node_type = treap_node<T>*;
using iterator = treap_iterator<T>;
using reverse_iterator = reverse_treap_iterator<T>;
using const_iterator = const_treap_iterator<T>;
using const_reverse_iterator = const_reverse_treap_iterator<T>;
(): root(nullptr) {}
treap_multiset
(const treap_multiset &b): root(deep_copy(b.root)), comp(b.comp) {}
treap_multiset
(treap_multiset &&b): root(b.root), comp(b.comp) {}
treap_multiset
~treap_multiset() { if (root) recursive_free(root); }
bool empty() const { return root == nullptr; }
size_type size() const { return root ? root->size : 0; }
size_type max_size() const { return 0x7FFFFFFF; }
() const { return comp; }
key_compare key_comp
() const { return comp; }
value_compare value_comp
()
iterator begin{ return iterator(leftmost(root)); }
() const
const_iterator begin{ return const_iterator(leftmost(root)); }
() const
const_iterator cbegin{ return const_iterator(leftmost(root)); }
()
iterator end{ return iterator(rightmost(root), true); }
() const
const_iterator end{ return const_iterator(rightmost(root), true); }
() const
const_iterator cend{ return const_iterator(rightmost(root), true); }
()
reverse_iterator rbegin{ return reverse_iterator(rightmost(root)); }
() const
const_reverse_iterator rbegin{ return const_reverse_iterator(rightmost(root)); }
() const
const_reverse_iterator crbegin{ return const_reverse_iterator(rightmost(root)); }
()
reverse_iterator rend{ return reverse_iterator(leftmost(root), true); }
() const
const_reverse_iterator rend{ return const_reverse_iterator(leftmost(root), true); }
() const
const_reverse_iterator crend{ return const_reverse_iterator(leftmost(root), true); }
(const value_type &value) {
iterator insertnode_type left, right;
(root, value, left, nullptr, right, nullptr);
split_lenode_type temp = new treap_node<T>(value);
= join(join(left, temp), right);
root return iterator(temp);
}
(iterator position, const value_type &value)
iterator insert{ return insert(value); }
template <typename II>
void insert(II first, II last) {
for (; first != last; first++)
(*first);
insert}
size_type rank(iterator it) const { return rank(it.node); }
size_type rank(const_iterator it) const { return rank(it.node); }
(size_type index) {
iterator atif (index < 0 || index > size()) return end();
return iterator(at_internal(index));
}
(size_type index) const {
const_iterator atif (index < 0 || index > size()) return end();
return const_iterator(at_internal(index));
}
void erase(iterator pos) {
(pos);
assert_validnode_type a, b, c;
size_type rank = this->rank(pos.node);
(root, rank, a, nullptr, c, nullptr);
split_size(a, rank - 1, a, nullptr, b, nullptr);
split_size= join(a, c);
root // assert(b == pos.node);
delete b;
}
size_type erase(const value_type &key) {
node_type a, b, c;
(root, key, a, nullptr, c, nullptr);
split_le(a, key, a, nullptr, b, nullptr);
split_re= join(a, c);
root if (b) {
size_type ret = b->size;
(b);
recursive_freereturn ret;
}
return 0;
}
void erase(iterator first, iterator last) {
size_type rank_first = rank(first);
size_type rank_last = rank(last);
node_type a, b, c;
(root, rank_last, a, nullptr, c, nullptr);
split_size(a, rank_last - 1, a, nullptr, b, nullptr);
split_size= join(a, c);
root if (b) recursive_free(b);
}
void clear() { if (root) recursive_free(root); }
void swap(treap_multiset &b) { swap(root, b.root); }
(const value_type &key) {
iterator findnode_type ret = find_internal(key);
return ret ? iterator(ret) : end();
}
(const value_type &key) const {
const_iterator findnode_type ret = const_cast<treap_multiset<T>*>(this)->find_internal(key);
return ret ? const_iterator(ret) : end();
}
size_type count(const value_type &key) const {
node_type a, b, c;
<T> *thiz = const_cast<treap_multiset<T>*>(this);
treap_multiset->split_le(root, key, a, nullptr, c, nullptr);
thiz->split_re(root, key, a, nullptr, b, nullptr);
thizsize_type ret = b ? b->size : 0;
->root = thiz->join(thiz->join(a, b), c);
thizreturn ret;
}
(const value_type &key) {
iterator lower_boundnode_type ret = lower_bound_internal(key);
return ret ? iterator(ret) : end();
}
(const value_type &key) const {
const_iterator lower_boundnode_type ret = const_cast<treap_multiset<T>*>(this)->lower_bound_internal(key);
return ret ? const_iterator(ret) : end();
}
(const value_type &key) {
iterator upper_boundnode_type ret = upper_bound_internal(key);
return ret ? iterator(ret) : end();
}
(const value_type &key) const {
const_iterator upper_boundnode_type ret = const_cast<treap_multiset<T>*>(this)->upper_bound_internal(key);
return ret ? const_iterator(ret) : end();
}
std::pair<iterator, iterator> equal_range(const value_type &key) {
return std::make_pair(lower_bound(key), upper_bound(key));
}
std::pair<const_iterator, const_iterator> equal_range(const value_type &key) const {
return std::make_pair(lower_bound(key), upper_bound(key));
}
size_type depth(node_type node) {
if (!node) return 0;
return 1 + std::max(depth(node->left), depth(node->right));
}
private:
;
Compare compnode_type root;
void recursive_free(node_type root) {
if (root->left)
(root->left);
recursive_freeif (root->right)
(root->right);
recursive_freedelete root;
}
void split_le(node_type root, const value_type &key,
node_type &left, node_type left_parent,
node_type &right, node_type right_parent) {
if (!root) { left = right = nullptr; return; }
if (!comp(key, root->value)) {
= root; root->parent = left_parent;
left (root->right, key, root->right, root, right, right_parent);
split_le} else {
= root; root->parent = right_parent;
right (root->left, key, left, left_parent, root->left, root);
split_le}
->update_size();
root}
void split_re(node_type root, const value_type &key,
node_type &left, node_type left_parent,
node_type &right, node_type right_parent) {
if (!root) { left = right = nullptr; return; }
if (comp(root->value, key)) {
= root; root->parent = left_parent;
left (root->right, key, root->right, root, right, right_parent);
split_re} else {
= root; root->parent = right_parent;
right (root->left, key, left, left_parent, root->left, root);
split_re}
->update_size();
root}
void split_size(node_type root, size_type size,
node_type &left, node_type left_parent,
node_type &right, node_type right_parent) {
if (!root) { left = right = nullptr; return; }
size_type left_size = 1 + (root->left ? root->left->size : 0);
if (left_size <= size) {
= root; root->parent = left_parent;
left (root->right,
split_size- left_size, root->right, root, right, right_parent);
size } else {
= root; root->parent = right_parent;
right (root->left, size, left, left_parent, root->left, root);
split_size}
->update_size();
root}
node_type join(node_type left, node_type right) {
if (!left) return right;
if (!right) return left;
if (left->weight <= right->weight) {
node_type temp = join(left->right, right);
if (temp) temp->parent = left;
->right = temp;
left->update_size();
leftreturn left;
} else {
node_type temp = join(left, right->left);
if (temp) temp->parent = right;
->left = temp;
right->update_size();
rightreturn right;
}
}
node_type leftmost(node_type x) const {
if (!x) return nullptr;
node_type ret = x; while (ret->left) ret = ret->left;
return ret;
}
node_type rightmost(node_type x) const {
if (!x) return nullptr;
node_type ret = x; while (ret->right) ret = ret->right;
return ret;
}
node_type find_internal(const value_type &key) {
node_type a, b, c;
(root, key, a, nullptr, c, nullptr);
split_le(root, key, a, nullptr, b, nullptr);
split_re= join(join(a, b), c);
root return b;
}
node_type lower_bound_internal(const value_type &key) {
node_type left, right;
(root, key, left, nullptr, right, nullptr);
split_renode_type ret = leftmost(right);
= join(left, right);
root return ret;
}
node_type upper_bound_internal(const value_type &key) {
node_type left, right;
(root, key, left, nullptr, right, nullptr);
split_lenode_type ret = leftmost(right);
= join(left, right);
root return ret;
}
node_type at_internal(size_type index) const {
node_type temp = root;
while (true) {
size_type left_size = 1 + (temp->left ? temp->left->size : 0);
if (index == left_size) return temp;
else if (index < left_size) temp = temp->left;
else temp = temp->right, index -= left_size;
}
return nullptr; // UNREACHABLE
}
size_type rank(node_type node) {
bool from_right = true;
size_type ret = 0;
while (node) {
if (from_right)
+= 1 + (node->left ? node->left->size : 0);
ret if (node->parent)
= node == node->parent->right;
from_right = node->parent;
node }
return ret;
}
node_type deep_copy(node_type node) {
if (!node) return nullptr;
node_type left = deep_copy(node->left);
node_type right = deep_copy(node->right);
node_type ret = new treap_node<T>(
, right, nullptr,
left->weight, node->size, node->value
node);
if (left) left->parent = ret;
if (right) right->parent = ret;
return ret;
}
void assert_valid(iterator it) {
if (!it.node || it.past_the_end)
throw std::runtime_error("invalid iterator");
node_type temp = it.node;
while (temp->parent) temp = temp->parent;
if (temp != root)
throw std::runtime_error("invalid iterator");
}
};
(The code above hasn’t been thoroughly tested yet and could still contain bugs).