Reinventing the Wheel: A Drop-in STL Multiset Alternative Using Treap

It turns out that writing an STL container from scratch is mostly tedious manual labor.

The STL was created to save C++ programmers from reinventing the wheel. Unfortunately, many STL data structures, most notably the self-balancing BSTs behind std::set and std::multiset, are not extensible and are by themselves too limited for competitive programming, which forces us to write our own BSTs again and again in competitions.

Here is a fact I find really interesting: a quick-and-dirty self-balancing BST written during a competition is only about 50-60 lines long, while the STL implementation of std::set and std::multiset usually runs to more than a thousand lines in total.

It makes me wonder: what accounts for this huge difference? And would the code of our BST bloat as well if we wrote it the STL way -- with generics, iterators, and all the other bits and pieces the reference requires?

To answer these questions, I implemented a multiset using a treap as the underlying data structure. The result, treap_multiset, is almost a drop-in replacement for std::multiset. The few places where it does not conform to the C++ standard are:

  1. It is currently not allocator-aware, so none of the allocator-related features are implemented.
  2. emplace and emplace_hint are not implemented.
  3. All operations that have logarithmic time complexity in std::multiset also have logarithmic time complexity here, but only in expectation (because a treap is a randomized data structure); their worst-case time complexity is linear, though that is extremely unlikely.
  4. erase(iterator) takes expected logarithmic time instead of amortized constant time.
  5. A few uncommonly-used member types are missing.

treap_multiset also supports two new operations that are not supported in the original std::multiset:

  1.    size_type rank(iterator it) const;
       size_type rank(const_iterator it) const;

    Both take expected logarithmic time and return the 1-based rank (position) of the element the iterator points to.

  2.    iterator at(size_type index);
       const_iterator at(size_type index) const;

    Both take expected logarithmic time and return the iterator at the given 1-based index (position).

The code is shown below:

#include <algorithm>
#include <cstddef>
#include <functional>
#include <iterator>
#include <random>
#include <stdexcept>
#include <type_traits>
#include <utility>

// Non-deterministic entropy source; it is consulted exactly once, to seed the engine below.
static std::random_device random_device;
// Shared generator that hands out the random weights (heap priorities) of treap nodes.
static std::mt19937_64 random_engine{random_device()};

/* One treap node: a BST node keyed on `value` that also carries a random heap
 * priority (`weight`) and the number of nodes in its subtree (`size`), so that
 * positional queries can be answered by descending from the root. */
template <typename T>
struct treap_node {
    using rand_weight_type = decltype(random_engine)::result_type;
    using size_type = std::size_t;

    treap_node *left, *right, *parent;  // links; nullptr when absent
    rand_weight_type weight;            // random priority used when joining trees
    size_type size;                     // nodes in this subtree, this one included
    T value;

    /* Leaf node holding a copy of `value` and a freshly drawn random weight.
     * `explicit` so a T is never silently converted into a heap-less node. */
    explicit treap_node(const T &value): left(nullptr), right(nullptr), parent(nullptr),
        weight(random_engine()), size(1), value(value) {}

    /* Node with every field supplied by the caller (e.g. when cloning a tree). */
    treap_node(treap_node *left, treap_node *right, treap_node *parent,
        rand_weight_type weight, size_type size, const T &value):
        left(left), right(right), parent(parent), weight(weight),
        size(size), value(value) {}

    /* Size of a possibly-null subtree. */
    static size_type subtree_size(const treap_node *node) {
        return node ? node->size : 0;
    }

    /* Recompute `size` from the children's cached sizes. */
    void update_size() {
        size = 1 + subtree_size(left) + subtree_size(right);
    }
};

/* In-order successor step shared by the iterator templates below.  Expects a
 * pointer variable named `node` (with left/right/parent links) in scope and
 * moves it to the next node in sorted order, or to nullptr if it was the last
 * one.  A null `node` is left untouched. */
#define IMPL_ITERATOR_MOVE_NEXT do { \
    if (node == nullptr) break; \
    if (node->right != nullptr) { \
        /* successor = leftmost node of the right subtree */ \
        node = node->right; \
        while (node->left != nullptr) node = node->left; \
        break; \
    } \
    /* climb while we are a right child; the next parent is the successor */ \
    while (node->parent != nullptr && node->parent->right == node) \
        node = node->parent; \
    node = node->parent; \
} while (0)

/* In-order predecessor step; the mirror image of IMPL_ITERATOR_MOVE_NEXT.
 * Moves the in-scope pointer `node` to the previous node in sorted order, or
 * to nullptr if it was the first one.  A null `node` is left untouched. */
#define IMPL_ITERATOR_MOVE_PREV do { \
    if (node == nullptr) break; \
    if (node->left != nullptr) { \
        /* predecessor = rightmost node of the left subtree */ \
        node = node->left; \
        while (node->right != nullptr) node = node->right; \
        break; \
    } \
    /* climb while we are a left child; the next parent is the predecessor */ \
    while (node->parent != nullptr && node->parent->left == node) \
        node = node->parent; \
    node = node->parent; \
} while (0)

/* Declares a bidirectional iterator template `name<T>` over treap_node<T>.
 *   qualifier - empty for a mutable iterator, `const` for a read-only one.
 *   inc / dec - NEXT or PREV: the in-order step taken by ++ and by --;
 *               swapping them yields a reverse iterator.
 * The past-the-end position is stored as (last node visited, past_the_end =
 * true) so that -- from the end can step back onto that node.  Any two
 * past-the-end iterators compare equal, whatever node they remember.
 * Dereferencing a null/past-the-end iterator, or decrementing before the
 * first element, throws std::runtime_error.  The nested types make the
 * iterator usable with std::iterator_traits (std::distance, std::next, ...). */
#define TREAP_ITERATOR_DECL(name, qualifier, inc, dec) \
template <typename T> struct name { \
    using iterator_category = std::bidirectional_iterator_tag; \
    using value_type = T; \
    using difference_type = std::ptrdiff_t; \
    using pointer = qualifier T *; \
    using reference = qualifier T &; \
    qualifier treap_node<T> *node; bool past_the_end; \
    name(): node(nullptr), past_the_end(true) {} \
    name(qualifier treap_node<T> *node, bool past_the_end) \
        : node(node), past_the_end(past_the_end) {} \
    explicit name(qualifier treap_node<T> *node): node(node), past_the_end(false) {} \
    reference operator *() const { \
        if (!node || past_the_end) \
            throw std::runtime_error("dereferencing null/past-end iterator"); \
        return node->value; \
    } \
    pointer operator ->() const { return &**this; } \
    bool operator ==(const name<T> &b) const { \
        return past_the_end == b.past_the_end && (past_the_end || node == b.node); \
    } \
    bool operator !=(const name<T> &b) const { return !(*this == b); } \
    name<T> &operator ++() { \
        qualifier treap_node<T> *backup = node; \
        IMPL_ITERATOR_MOVE_##inc; \
        if (!node) past_the_end = true, node = backup; \
        return *this; \
    } \
    name<T> operator ++(int) { \
        name<T> ret(*this); \
        ++*this; \
        return ret; \
    } \
    name<T> &operator --() { \
        if (past_the_end) past_the_end = false; \
        else IMPL_ITERATOR_MOVE_##dec; \
        if (!node) \
            throw std::runtime_error("can't decrement at the beginning"); \
        return *this; \
    } \
    name<T> operator --(int) { \
        name<T> ret(*this); \
        --*this; \
        return ret; \
    } \
}

// Forward iterators: ++ moves to the in-order successor, -- to the predecessor.
TREAP_ITERATOR_DECL(treap_iterator, /* NO QUALIFIER */, NEXT, PREV);
TREAP_ITERATOR_DECL(const_treap_iterator, const, NEXT, PREV);

// Reverse iterators: the two step directions are swapped.
TREAP_ITERATOR_DECL(reverse_treap_iterator, /* NO QUALIFIER */, PREV, NEXT);
TREAP_ITERATOR_DECL(const_reverse_treap_iterator, const, PREV, NEXT);

/* Ordered multiset modelled on std::multiset, backed by a treap: a binary
 * search tree on the element values that is also a min-heap on each node's
 * random `weight`.
 *
 * rank() and at() use 1-based positions: rank(begin()) == 1; at(k) is the
 * k-th smallest element for 1 <= k <= size(); any other k yields end(); and
 * rank(end()) == size() + 1.
 *
 * find/count/lower_bound/upper_bound/equal_range only walk the tree and never
 * modify it, so Compare must be callable on a const object.  Nodes are never
 * relocated, so iterators to other elements survive insert and erase.  An
 * end() iterator remembers the last node at the time it was created, so take
 * a fresh end() before decrementing it after the maximum may have changed.
 */
template <typename T, typename Compare = std::less<T>>
class treap_multiset {
public:
    using key_type = T;
    using value_type = T;
    using size_type = typename treap_node<T>::size_type;
    using key_compare = Compare;
    using value_compare = Compare;
    using node_type = treap_node<T>*;   // raw node pointer, not a C++17 node handle
    using iterator = treap_iterator<T>;
    using reverse_iterator = reverse_treap_iterator<T>;
    using const_iterator = const_treap_iterator<T>;
    using const_reverse_iterator = const_reverse_treap_iterator<T>;

    treap_multiset(): comp(), root(nullptr) {}

    /* Empty multiset ordered by `cmp` (e.g. a stateful functor or a lambda). */
    explicit treap_multiset(const Compare &cmp): comp(cmp), root(nullptr) {}

    /* Multiset holding a copy of every element in [first, last). */
    template <typename II>
    treap_multiset(II first, II last, const Compare &cmp = Compare())
        : comp(cmp), root(nullptr) {
        try {
            insert(first, last);
        } catch (...) {
            clear();  // the destructor does not run for a half-built object
            throw;
        }
    }

    treap_multiset(const treap_multiset &b): comp(b.comp), root(deep_copy(b.root)) {}

    /* Takes over b's nodes and leaves b empty; the comparator is copied so b stays usable. */
    treap_multiset(treap_multiset &&b)
        noexcept(std::is_nothrow_copy_constructible<Compare>::value)
        : comp(b.comp), root(b.root) {
        b.root = nullptr;
    }

    /* Copy-and-swap: *this is left unchanged if copying b throws. */
    treap_multiset &operator=(const treap_multiset &b) {
        if (this != &b) {
            treap_multiset copy(b);
            swap(copy);
        }
        return *this;
    }

    /* Releases the current nodes, then takes over b's nodes (b is left empty). */
    treap_multiset &operator=(treap_multiset &&b)
        noexcept(std::is_nothrow_copy_assignable<Compare>::value) {
        if (this != &b) {
            comp = b.comp;  // first, so a throwing assignment leaves *this intact
            clear();
            root = b.root;
            b.root = nullptr;
        }
        return *this;
    }

    ~treap_multiset() { clear(); }

    bool empty() const { return root == nullptr; }

    size_type size() const { return subtree_size(root); }

    size_type max_size() const { return 0x7FFFFFFF; }

    key_compare key_comp() const { return comp; }

    value_compare value_comp() const { return comp; }

    // On an empty set begin()/rbegin() must compare equal to end()/rend().
    iterator begin()
        { return root ? iterator(leftmost(root)) : end(); }

    const_iterator begin() const
        { return root ? const_iterator(leftmost(root)) : end(); }

    const_iterator cbegin() const { return begin(); }

    iterator end()
        { return iterator(rightmost(root), true); }

    const_iterator end() const
        { return const_iterator(rightmost(root), true); }

    const_iterator cend() const { return end(); }

    reverse_iterator rbegin()
        { return root ? reverse_iterator(rightmost(root)) : rend(); }

    const_reverse_iterator rbegin() const
        { return root ? const_reverse_iterator(rightmost(root)) : rend(); }

    const_reverse_iterator crbegin() const { return rbegin(); }

    reverse_iterator rend()
        { return reverse_iterator(leftmost(root), true); }

    const_reverse_iterator rend() const
        { return const_reverse_iterator(leftmost(root), true); }

    const_reverse_iterator crend() const { return rend(); }

    /* Inserts a copy of `value` after every element equal to it and returns
     * an iterator to the new element. */
    iterator insert(const value_type &value) {
        // Allocate before splitting: if allocation or T's copy throws, the
        // tree has not been touched yet.
        node_type fresh = new treap_node<T>(value);
        node_type left, right;
        split_le(root, value, left, nullptr, right, nullptr);
        set_root(join(join(left, fresh), right));
        return iterator(fresh);
    }

    /* The hint is ignored; equivalent to insert(value). */
    iterator insert(iterator /* hint */, const value_type &value)
        { return insert(value); }

    template <typename II>
    void insert(II first, II last) {
        for (; first != last; ++first)
            insert(*first);
    }

    /* 1-based position of *it; size() + 1 for a past-the-end iterator. */
    size_type rank(iterator it) const { return index_of(it.node, it.past_the_end) + 1; }

    size_type rank(const_iterator it) const { return index_of(it.node, it.past_the_end) + 1; }

    /* Iterator to the element at 1-based position `index`, or end() when
     * index is 0 or greater than size(). */
    iterator at(size_type index) {
        if (index == 0 || index > size()) return end();
        return iterator(node_at(index));
    }

    const_iterator at(size_type index) const {
        if (index == 0 || index > size()) return end();
        return const_iterator(node_at(index));
    }

    /* Removes *pos and returns an iterator to the element that followed it.
     * Throws std::runtime_error if pos is past-the-end or not from this set. */
    iterator erase(iterator pos) {
        assert_valid(pos);
        iterator next = pos;
        ++next;  // successor found before unlinking; it is a different node
        const size_type r = rank_of(pos.node);  // 1-based
        node_type before, target, after;
        split_size(root, r, before, nullptr, after, nullptr);
        split_size(before, r - 1, before, nullptr, target, nullptr);
        set_root(join(before, after));
        delete target;  // target is exactly pos.node
        // `next` still refers to the erased node when pos was the last element.
        return next.past_the_end ? end() : next;
    }

    /* Removes every element equal to `key`; returns how many were removed. */
    size_type erase(const value_type &key) {
        node_type below, equal, above;
        split_le(root, key, below, nullptr, above, nullptr);
        split_re(below, key, below, nullptr, equal, nullptr);
        set_root(join(below, above));
        const size_type removed = subtree_size(equal);
        recursive_free(equal);
        return removed;
    }

    /* Removes the elements in [first, last) and returns an iterator to the
     * element that followed them. */
    iterator erase(iterator first, iterator last) {
        const size_type lo = index_of(first.node, first.past_the_end);
        const size_type hi = index_of(last.node, last.past_the_end);
        if (lo < hi) {
            node_type before, middle, after;
            split_size(root, hi, before, nullptr, after, nullptr);
            split_size(before, lo, before, nullptr, middle, nullptr);
            set_root(join(before, after));
            recursive_free(middle);
        }
        return last.past_the_end ? end() : last;
    }

    void clear() {
        recursive_free(root);
        root = nullptr;
    }

    void swap(treap_multiset &b) {
        std::swap(comp, b.comp);
        std::swap(root, b.root);
    }

    friend void swap(treap_multiset &a, treap_multiset &b) { a.swap(b); }

    /* Iterator to the first element equal to `key`, or end(). */
    iterator find(const value_type &key) {
        node_type hit = find_node(key);
        return hit ? iterator(hit) : end();
    }

    const_iterator find(const value_type &key) const {
        node_type hit = find_node(key);
        return hit ? const_iterator(hit) : end();
    }

    /* Number of elements equal to `key`. */
    size_type count(const value_type &key) const {
        return count_not_greater(key) - count_less(key);
    }

    /* First element not less than `key`, or end(). */
    iterator lower_bound(const value_type &key) {
        node_type hit = lower_node(key);
        return hit ? iterator(hit) : end();
    }

    const_iterator lower_bound(const value_type &key) const {
        node_type hit = lower_node(key);
        return hit ? const_iterator(hit) : end();
    }

    /* First element greater than `key`, or end(). */
    iterator upper_bound(const value_type &key) {
        node_type hit = upper_node(key);
        return hit ? iterator(hit) : end();
    }

    const_iterator upper_bound(const value_type &key) const {
        node_type hit = upper_node(key);
        return hit ? const_iterator(hit) : end();
    }

    std::pair<iterator, iterator> equal_range(const value_type &key) {
        return std::make_pair(lower_bound(key), upper_bound(key));
    }

    std::pair<const_iterator, const_iterator> equal_range(const value_type &key) const {
        return std::make_pair(lower_bound(key), upper_bound(key));
    }

    /* Height of the subtree rooted at `node` (0 for nullptr); a debugging aid. */
    size_type depth(node_type node) const {
        if (!node) return 0;
        return 1 + std::max(depth(node->left), depth(node->right));
    }

private:
    Compare comp;
    node_type root;

    static size_type subtree_size(const treap_node<T> *node) {
        return node ? node->size : 0;
    }

    /* Installs `node` as the root and detaches it from any former parent. */
    void set_root(node_type node) {
        root = node;
        if (root) root->parent = nullptr;
    }

    /* Deletes every node in the subtree rooted at `node` (null-safe). */
    static void recursive_free(node_type node) {
        if (!node) return;
        recursive_free(node->left);
        recursive_free(node->right);
        delete node;
    }

    /* Splits `node` into values <= key (left) and values > key (right). */
    void split_le(node_type node, const value_type &key,
        node_type &left, node_type left_parent,
        node_type &right, node_type right_parent) {
        if (!node) { left = right = nullptr; return; }
        if (!comp(key, node->value)) {
            left = node; node->parent = left_parent;
            split_le(node->right, key, node->right, node, right, right_parent);
        } else {
            right = node; node->parent = right_parent;
            split_le(node->left, key, left, left_parent, node->left, node);
        }
        node->update_size();
    }

    /* Splits `node` into values < key (left) and values >= key (right). */
    void split_re(node_type node, const value_type &key,
        node_type &left, node_type left_parent,
        node_type &right, node_type right_parent) {
        if (!node) { left = right = nullptr; return; }
        if (comp(node->value, key)) {
            left = node; node->parent = left_parent;
            split_re(node->right, key, node->right, node, right, right_parent);
        } else {
            right = node; node->parent = right_parent;
            split_re(node->left, key, left, left_parent, node->left, node);
        }
        node->update_size();
    }

    /* Splits `node` into its first k elements (left) and the rest (right). */
    void split_size(node_type node, size_type k,
        node_type &left, node_type left_parent,
        node_type &right, node_type right_parent) {
        if (!node) { left = right = nullptr; return; }
        const size_type through_node = 1 + subtree_size(node->left);
        if (through_node <= k) {
            left = node; node->parent = left_parent;
            split_size(node->right, k - through_node, node->right, node, right, right_parent);
        } else {
            right = node; node->parent = right_parent;
            split_size(node->left, k, left, left_parent, node->left, node);
        }
        node->update_size();
    }

    /* Concatenates two treaps where every value of `left` precedes every value
     * of `right`; the node with the smaller weight becomes the parent. */
    node_type join(node_type left, node_type right) {
        if (!left) return right;
        if (!right) return left;
        if (left->weight <= right->weight) {
            node_type temp = join(left->right, right);
            if (temp) temp->parent = left;
            left->right = temp;
            left->update_size();
            return left;
        } else {
            node_type temp = join(left, right->left);
            if (temp) temp->parent = right;
            right->left = temp;
            right->update_size();
            return right;
        }
    }

    static node_type leftmost(node_type x) {
        if (!x) return nullptr;
        while (x->left) x = x->left;
        return x;
    }

    static node_type rightmost(node_type x) {
        if (!x) return nullptr;
        while (x->right) x = x->right;
        return x;
    }

    /* First node whose value is not less than key, or nullptr. */
    node_type lower_node(const value_type &key) const {
        node_type best = nullptr;
        for (node_type cur = root; cur != nullptr;) {
            if (comp(cur->value, key)) {
                cur = cur->right;
            } else {
                best = cur;
                cur = cur->left;
            }
        }
        return best;
    }

    /* First node whose value is greater than key, or nullptr. */
    node_type upper_node(const value_type &key) const {
        node_type best = nullptr;
        for (node_type cur = root; cur != nullptr;) {
            if (comp(key, cur->value)) {
                best = cur;
                cur = cur->left;
            } else {
                cur = cur->right;
            }
        }
        return best;
    }

    /* First node equal to key, or nullptr. */
    node_type find_node(const value_type &key) const {
        node_type hit = lower_node(key);
        return (hit && !comp(key, hit->value)) ? hit : nullptr;
    }

    /* Number of elements strictly less than key. */
    size_type count_less(const value_type &key) const {
        size_type n = 0;
        for (node_type cur = root; cur != nullptr;) {
            if (comp(cur->value, key)) {
                n += 1 + subtree_size(cur->left);
                cur = cur->right;
            } else {
                cur = cur->left;
            }
        }
        return n;
    }

    /* Number of elements not greater than key. */
    size_type count_not_greater(const value_type &key) const {
        size_type n = 0;
        for (node_type cur = root; cur != nullptr;) {
            if (!comp(key, cur->value)) {
                n += 1 + subtree_size(cur->left);
                cur = cur->right;
            } else {
                cur = cur->left;
            }
        }
        return n;
    }

    /* Node at 1-based position `index`; requires 1 <= index <= size(). */
    node_type node_at(size_type index) const {
        node_type cur = root;
        for (;;) {
            const size_type here = 1 + subtree_size(cur->left);
            if (index == here) return cur;
            if (index < here) {
                cur = cur->left;
            } else {
                index -= here;
                cur = cur->right;
            }
        }
    }

    /* 1-based position of a non-null node, found by climbing to the root. */
    size_type rank_of(const treap_node<T> *node) const {
        size_type r = 1 + subtree_size(node->left);
        for (; node->parent != nullptr; node = node->parent)
            if (node == node->parent->right)
                r += 1 + subtree_size(node->parent->left);
        return r;
    }

    /* 0-based position of an iterator's target; size() for past-the-end/null. */
    size_type index_of(const treap_node<T> *node, bool past_the_end) const {
        return (past_the_end || !node) ? size() : rank_of(node) - 1;
    }

    /* Structural clone of a subtree (values, weights and sizes preserved). */
    static node_type deep_copy(node_type node) {
        if (!node) return nullptr;
        node_type left = deep_copy(node->left);
        node_type right = deep_copy(node->right);
        node_type ret = new treap_node<T>(
            left, right, nullptr,
            node->weight, node->size, node->value
        );
        if (left) left->parent = ret;
        if (right) right->parent = ret;
        return ret;
    }

    /* Throws unless `it` points at an element of this tree. */
    void assert_valid(iterator it) const {
        if (!it.node || it.past_the_end)
            throw std::runtime_error("invalid iterator");
        node_type temp = it.node;
        while (temp->parent) temp = temp->parent;
        if (temp != root)
            throw std::runtime_error("invalid iterator");
    }
};

(The code above has not been thoroughly tested yet and may still contain bugs.)