On Github lionell / data_structures
Created by Ruslan Sakevych / @lionell
int a[] = {1, 2, 3, 4, 5, 6};
build();
sum(1, 4);     // 2 + 3 + 4 + 5 = 14
sum(4, 5);     // 5 + 6 = 11
update(2, 10); // a[2] = 10, so a = {1, 2, 10, 4, 5, 6}
sum(2, 3);     // 10 + 4 = 14
// Node storage for the 1-indexed tree (root 1, children of v at 2v and 2v+1).
// 4 * MAX_N slots are a safe upper bound for that layout (memory-limit concern).
int t[MAX_N * 4];
// Fills node v, which covers a[tl..tr]: a leaf copies its array element,
// an inner node holds the sum of its two children.
void build(int v = 1, int tl = 0, int tr = n - 1) {
    if (tl != tr) {
        int mid = (tl + tr) / 2;
        int lc = 2 * v;
        int rc = 2 * v + 1;
        build(lc, tl, mid);
        build(rc, mid + 1, tr);
        t[v] = t[lc] + t[rc];
        return;
    }
    t[v] = a[tl];
}
// Returns a[l] + ... + a[r]. [l, r] must lie inside node v's segment [tl, tr];
// an empty range (l > r) contributes 0.
int sum(int l, int r, int v = 1, int tl = 0, int tr = n - 1) {
    if (l > r) {
        return 0;
    }
    if (tl == l && tr == r) {
        return t[v];
    }
    int tm = (tl + tr) / 2;
    // Clip the query to each half; one side may become empty (l > r) and return 0.
    return sum(l, min(r, tm), 2 * v, tl, tm)
         + sum(max(tm + 1, l), r, 2 * v + 1, tm + 1, tr);
}
// Point assignment a[i] = val: walk down to leaf i, then refresh the sums
// of every node on the way back up.
void update(int i, int val, int v = 1, int tl = 0, int tr = n - 1) {
    if (tl == tr) {
        t[v] = val;
        return;
    }
    int mid = (tl + tr) / 2;
    bool goLeft = i <= mid;
    update(i, val,
           goLeft ? 2 * v : 2 * v + 1,
           goLeft ? tl : mid + 1,
           goLeft ? mid : tr);
    t[v] = t[2 * v] + t[2 * v + 1];
}
... int tm = tl + (tr - tl) / 2; // instead of int tm = (tl + tr) / 2; since tl <= tr, tr - tl cannot overflow int, while tl + tr can for large indices ...
x >> 1 // instead of x / 2; equal only for x >= 0 (for negative x, / truncates toward zero, >> does not)
x << 1 // instead of x * 2; left-shifting a negative value is undefined in C and before C++20
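Each call of update recurses into only one child, so it follows a single root-to-leaf path (strictly speaking it is not a tail call, because the parent is recomputed after the call returns).
So we can easily convert it into a loop that walks down that path and then recomputes the nodes on the way back up.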
Next, we need a generic way to support different kinds of queries.
Let's define a function that combines the information stored in two child nodes, and one that builds a leaf node from a single array element.
// Generic segment tree: T is the summary stored in each node.
// Same 4 * MAX_N bound as the int version; MAX_N slots are too few for a
// 1-indexed tree whose children of v live at 2v and 2v+1.
T t[4 * MAX_N];
// Merges the summaries of two adjacent segments (left, then right).
T combine(T l, T r) { ... }
// Builds the summary of a one-element segment from an array value.
T make(...) { ... }
void build(...) { if (tl == tr) { // t[v] = a[tl]; t[v] = make(a[tl]); } else { int tm = tl + (tr - tl) / 2; build(2 * v, tl, tm); build(2 * v + 1, tm + 1, tr); // t[v] = t[2 * v] + t[2 * v + 1]; t[v] = combine(t[2 * v], t[2 * v + 1]); } }
void update(int i, T val, ...) { if (tl == tr) { // t[v] = val; t[v] = make(val); } else { int tm = tl + (tr - tl) / 2; if (i <= tm) { update(i, val, 2 * v, tl, tm); } else { update(i, val, 2 * v + 1, tm + 1, tr); } // t[v] = t[2 * v] + t[2 * v + 1]; t[v] = combine(t[2 * v], t[2 * v + 1]); } }
API example
// Range increment, point query.
int a[] = {0, 0, 0, 0, 0};
inc(1, 3, 1);  // add 1 to a[1..3]  -> {0, 1, 1, 1, 0}
inc(0, 2, -2); // add -2 to a[0..2] -> {-2, -1, -1, 1, 0}
get(1);        // yields -1
// Adds x to every a[i] with l <= i <= r. A fully covered node only records x
// in t[v]; nothing is pushed down, and get() sums the records along a path.
// "..." stands for the (v, tl, tr) recursion parameters.
void inc(int l, int r, int x, ...) {
    if (l > r) {
        return;
    }
    if (tl == l && tr == r) {
        t[v] += x;
    } else {
        int tm = tl + (tr - tl) / 2;
        inc(l, min(r, tm), x, 2 * v, tl, tm);
        // Argument order is (l, r, x, ...) for both halves.
        inc(max(tm + 1, l), r, x, 2 * v + 1, tm + 1, tr);
    }
}
int get(int i, ...) { if (tl == tr) { return a[i]; } int tm = tl + (tr - tl) / 2; if (i <= tm) { return t[v] + get(i, 2 * v, tl, tm); } return t[v] + get(i, 2 * v + 1, tm + 1, tr); }
API example
// Range assignment, point query.
int a[] = {0, 0, 0, 0, 0};
let(3, 4, 1); // a = {0, 0, 0, 1, 1}
let(2, 3, 7); // a = {0, 0, 7, 7, 1}
get(3);       // a[3] == 7
// Hands node v's pending assignment to both children and clears it.
// -1 is the "nothing pending" sentinel, so the value -1 itself cannot be assigned.
void push(int v) {
    auto pending = t[v];
    if (pending != -1) {
        t[2 * v] = pending;
        t[2 * v + 1] = pending;
        t[v] = -1;
    }
}
// Assigns x to every a[i] with l <= i <= r. A fully covered node just stores x;
// a partially covered node first pushes its own pending value to the children
// so it is not lost, then recurses into both halves.
// "..." stands for the (v, tl, tr) recursion parameters.
void let(int l, int r, int x, ...) {
    if (l > r) {
        return;
    }
    if (tl == l && tr == r) {
        t[v] = x;
    } else {
        push(v);
        int tm = tl + (tr - tl) / 2;
        let(l, min(r, tm), x, 2 * v, tl, tm);
        // Argument order is (l, r, x, ...) for both halves.
        let(max(l, tm + 1), r, x, 2 * v + 1, tm + 1, tr);
    }
}
// Point query for the range-assignment tree. Walks from node v down to leaf i,
// pushing pending assignments along the way, and returns the leaf's value.
// "..." stands for the (v, tl, tr) recursion parameters, updated in place here.
int get(int i, ...) {
    while (tl != tr) {
        push(v);
        int mid = tl + (tr - tl) / 2;
        if (i <= mid) {
            v = 2 * v;
            tr = mid;
        } else {
            v = 2 * v + 1;
            tl = mid + 1;
        }
    }
    return t[v];
}
API example
// Range assignment + range sum.
int a[] = {1, 2, 0, 4, 0};
build();
let(3, 4, 1); // a[3..4] = 1 -> {1, 2, 0, 1, 1}
let(2, 3, 7); // a[2..3] = 7 -> {1, 2, 7, 7, 1}
sum(1, 3);    // 2 + 7 + 7 = 16
pii get(int v, int tl, int tr) { return t[v].second == INF ? t[v].first : t[v].second * (tr - tl + 1); }
// Moves node v's pending assignment (.second) to both children and clears it.
// Children's .first is left untouched; readers go through get(), which
// prefers .second whenever it is set.
void push(int v) {
    auto pending = t[v].second;
    if (pending != INF) {
        t[2 * v].second = pending;
        t[2 * v + 1].second = pending;
        t[v].second = INF;
    }
}
void let(int l, int r, int x, ...) { if (l > r) { return; } if (tl == l && tr == r) { t[v].second = x; } else { push(v); int tm = tl + (tr - tl) / 2; let(l, min(r, tm), x, 2 * v, tl, tm); let(max(l, tm + 1), r, x, 2 * v + 1, tm + 1, tr); t[v] = {get(2 * v, tl, tm) + get(2 * v + 1, tm + 1, tr), INF}; } }
int sum(int l, int r, ...) { if (l > r) { return 0; } if (tl == l && tr == r) { return get(v, tl, tr); } push(v); int tm = tl + (tr - tl) / 2; return sum(l, min(r, tm), 2 * v, tl, tm) + sum(max(l, tm + 1), r, 2 * v + 1, tm + 1, tr); }
API example
t->insert(1); // t = {4} ... t->insert(6); // t = {1, 2, 3, 4, 5, 6} t->remove(3); // t = {1, 2, 4, 5, 6} t->split(2, l, r); // l = {1}, r = {2, 4, 5, 6} r->sum(); // 2 + 4 + 5 + 6 = 17 t = merge(l, r); // t = {1, 2, 4, 5, 6}
// Treap keyed by x (binary-search-tree order) with priority y (larger y sits
// higher). Operations allocate new nodes along the touched path and return a
// new root; existing nodes are not modified.
struct Treap {
    int x;        // key
    int y;        // priority (random in insert())
    Treap *left;
    Treap *right;

    Treap(int x, int y, Treap *left = nullptr, Treap *right = nullptr);
    // Concatenates l and r; every key in l must be <= every key in r.
    static Treap *merge(Treap *l, Treap *r);
    // Splits this tree into l (keys <= key) and r (keys > key); either may be nullptr.
    // l and r are references so the results reach the caller.
    void split(int key, Treap *&l, Treap *&r);
    // Returns a new root that also contains x.
    Treap *insert(int x);
    // Returns a new root without any node whose key equals x.
    Treap *remove(int x);
};

Treap::Treap(int x, int y, Treap *left, Treap *right)
    : x(x), y(y), left(left), right(right) {}

Treap *Treap::merge(Treap *l, Treap *r) {
    if (l == nullptr) {
        return r;
    }
    if (r == nullptr) {
        return l;
    }
    if (l->y > r->y) {
        // l's root keeps the top; merge the rest of l with r on its right.
        Treap *newRight = merge(l->right, r);
        return new Treap(l->x, l->y, l->left, newRight);
    }
    Treap *newLeft = merge(l, r->left);
    return new Treap(r->x, r->y, newLeft, r->right);
}

void Treap::split(int key, Treap *&l, Treap *&r) {
    Treap *newTree = nullptr;
    if (x <= key) {
        // This node goes left; only its right subtree can straddle the key.
        if (right == nullptr) {
            r = nullptr;
        } else {
            right->split(key, newTree, r);
        }
        l = new Treap(x, y, left, newTree);
    } else {
        if (left == nullptr) {
            l = nullptr;
        } else {
            left->split(key, l, newTree);
        }
        r = new Treap(x, y, newTree, right);
    }
}

Treap *Treap::insert(int x) {
    Treap *l = nullptr;
    Treap *r = nullptr;
    split(x, l, r);
    Treap *m = new Treap(x, rand());
    return merge(merge(l, m), r);
}

Treap *Treap::remove(int x) {
    Treap *lessEq = nullptr;
    Treap *r = nullptr;
    split(x, lessEq, r);     // lessEq: keys <= x, r: keys > x
    if (lessEq == nullptr) { // every key is > x: nothing to remove
        return r;
    }
    Treap *l = nullptr;
    Treap *equal = nullptr;
    lessEq->split(x - 1, l, equal); // l: keys < x, equal: keys == x (dropped); x - 1 overflows for INT_MIN
    return merge(l, r);
}