8885e1238dc78320b711a63e01f29dc33b042754
[polintos/scott/priv.git] / include / c++ / util / rbtree.h
1 // util/rbtree.h -- A red/black tree implementation
2 //
3 // This software is copyright (c) 2006 Scott Wood <scott@buserror.net>.
4 // 
5 // This software is provided 'as-is', without any express or implied warranty.
6 // In no event will the authors or contributors be held liable for any damages
7 // arising from the use of this software.
8 // 
9 // Permission is hereby granted to everyone, free of charge, to use, copy,
10 // modify, prepare derivative works of, publish, distribute, perform,
11 // sublicense, and/or sell copies of the Software, provided that the above
12 // copyright notice and disclaimer of warranty be included in all copies or
13 // substantial portions of this software.
14
15 #ifndef _UTIL_RBTREE_H
16 #define _UTIL_RBTREE_H
17
18 #include <assert.h>
19 #include <stddef.h>
20 #include <stdint.h>
21
22 namespace Util {
23         // T must have an RBTree<T, NodeV, KeyV>::Node member called rbtree_node.
24         // NodeV < NodeV, NodeV < KeyV, and NodeV > KeyV must be supported.
25         //
26         // Using NodeV != KeyV allows things such as having NodeV represent a
27         // range of values and KeyV a single value, causing find() to search
28         // for the node in whose range KeyV falls.  It is an error to add
29         // nodes that are not well ordered.
30
31         template <typename T, typename NodeV, typename KeyV>
32         class RBTree {
33         public:
34                 struct Node {
35                         Node *parent, *left, *right;
36                         NodeV value;
37                         bool red;
38                         
39                         // A pointer to the parent's left or right pointer
40                         // (whichever side this node is on).  This is also
41                         // used for sanity checks (it is NULL for nodes not
42                         // on a tree).
43                         
44                         Node **parentlink;
45                         
46                         Node()
47                         {
48                                 parent = left = right = NULL;
49                                 parentlink = NULL;
50                         }
51                         
52                         bool is_on_rbtree()
53                         {
54                                 return parentlink != NULL;
55                         }
56                 };
57         
58         private:
59                 Node *top;
60         
61                 T *node_to_type(Node *n)
62                 {
63                         return reinterpret_cast<T *>(reinterpret_cast<uintptr_t>(n) - 
64                                                      offsetof(T, rbtree_node));
65                 }
66                 
67                 void assert_rb_condition(Node *left, Node *right)
68                 {
69                         assert(left->value < right->value);
70                         assert(!(right->value < left->value));
71                         assert(!left->red || !right->red);
72                 }
73                 
74                 void assert_well_ordered(Node *left, NodeV right)
75                 {
76                         assert(left->value < right);
77                         assert(!(left->value > right));
78                         assert(!left->red || !right->red);
79                 }
80                 
81                 Node *find_node(Node *n, NodeV &val, Node ***link);
82                 Node *find_node_key(KeyV val, bool exact = true);
83                 
84                 bool node_is_left(Node *n)
85                 {
86                         assert(n->parentlink == &n->parent->left ||
87                                n->parentlink == &n->parent->right);
88                 
89                         return n->parentlink == &n->parent->left;
90                 }
91                 
92                 Node *sibling(Node *n)
93                 {
94                         if (node_is_left(n))
95                                 return n->parent->right;
96
97                         return n->parent->left;
98                 }
99                 
100                 Node *uncle(Node *n)
101                 {
102                         return sibling(n->parent);
103                 }
104                 
105                 void rotate_left(Node *n);
106                 void rotate_right(Node *n);
107                 
108                 // B must have no right child and have A as an ancestor.
109                 void swap(Node *a, Node *b);
110
111                 bool red(Node *n)
112                 {
113                         return n && n->red;
114                 }
115                 
116 public:
117                 RBTree()
118                 {
119                         top = NULL;
120                 }
121                 
122                 bool empty()
123                 {
124                         return top == NULL;
125                 }
126                 
127                 T *find(KeyV value)
128                 {
129                         Node *n = find_node_key(value);
130                         
131                         if (n)
132                                 return node_to_type(n);
133                                 
134                         return NULL;
135                 }
136                 
137                 T *find_nearest(KeyV value)
138                 {
139                         Node *n = find_node_key(value, false);
140                         
141                         if (n)
142                                 return node_to_type(n);
143                         
144                         // Should only happen on an empty tree
145                         return NULL;
146                 }
147
148                 void add(T *t);
149                 void del(T *t);
150         };
151         
152         template <typename T, typename NodeV, typename KeyV>
153         typename RBTree<T, NodeV, KeyV>::Node *
154         RBTree<T, NodeV, KeyV>::find_node(Node *n, NodeV &val, Node ***link)
155         {
156                 if (link) {
157                         if (n == top)
158                                 *link = &top;
159                         else
160                                 *link = n->parentlink;
161                 }
162                 
163                 if (!top)
164                         return NULL;
165                 
166                 while (true) {
167                         assert(n);
168                         Node *next;
169                 
170                         if (n->value < val) {
171                                 next = n->right;
172                                 
173                                 if (next) {
174                                         assert_rb_condition(n, next);
175                                 } else {
176                                         if (link)
177                                                 *link = &n->right;
178
179                                         break;
180                                 }
181                         } else {
182                                 // This assert detects duplicate nodes, but not
183                                 // overlapping ranges (or otherwise non-well-ordered
184                                 // values).
185                                 
186                                 assert(val < n->value);
187                                 next = n->left;
188                                 
189                                 if (next) {
190                                         assert_rb_condition(next, n);
191                                 } else {
192                                         if (link)
193                                                 *link = &n->left;
194
195                                         break;
196                                 }
197                         }
198                         
199                         n = next;
200                 }
201                 
202                 return n;
203         }
204
205         template <typename T, typename NodeV, typename KeyV>
206         typename RBTree<T, NodeV, KeyV>::Node *
207         RBTree<T, NodeV, KeyV>::find_node_key(KeyV val, bool exact)
208         {
209                 Node *n = top;
210                 
211                 while (n) {
212                         Node *next;
213                 
214                         if (n->value < val) {
215                                 next = n->right;
216                                 
217                                 if (next)
218                                         assert_rb_condition(n, next);
219                         } else if (n->value > val) {
220                                 next = n->left;
221
222                                 if (next)
223                                         assert_rb_condition(next, n);
224                         } else {
225                                 break;
226                         }
227                         
228                         if (!next && !exact)
229                                 return n;
230                         
231                         n = next;
232                 }
233                 
234                 return n;
235         }
236         
237         template <typename T, typename NodeV, typename KeyV>
238         void RBTree<T, NodeV, KeyV>::rotate_left(Node *n)
239         {
240                 Node *new_top = n->right;
241                 
242                 assert(*n->parentlink == n);
243                 *n->parentlink = new_top;
244                 
245                 assert(new_top->parent == n);
246                 assert(new_top->parentlink == &n->right);
247
248                 new_top->parent = n->parent;
249                 new_top->parentlink = n->parentlink;
250                 
251                 n->parent = new_top;
252                 n->parentlink = &new_top->left;
253                 
254                 n->right = new_top->left;
255                 new_top->left = n;
256                 
257                 if (n->right) {
258                         assert(n->right->parent == new_top);
259                         assert(n->right->parentlink == &new_top->left);
260                 
261                         n->right->parent = n;
262                         n->right->parentlink = &n->right;
263                 }
264         }
265
266         template <typename T, typename NodeV, typename KeyV>
267         void RBTree<T, NodeV, KeyV>::rotate_right(Node *n)
268         {
269                 Node *new_top = n->left;
270                 
271                 assert(*n->parentlink == n);
272                 *n->parentlink = new_top;
273                 
274                 assert(new_top->parent == n);
275                 assert(new_top->parentlink == &n->left);
276
277                 new_top->parent = n->parent;
278                 new_top->parentlink = n->parentlink;
279                 
280                 n->parent = new_top;
281                 n->parentlink = &new_top->right;
282                 
283                 n->left = new_top->right;
284                 new_top->right = n;
285                 
286                 if (n->left) {
287                         assert(n->left->parent == new_top);
288                         assert(n->left->parentlink == &new_top->right);
289
290                         n->left->parent = n;
291                         n->left->parentlink = &n->left;
292                 }
293         }
294         
295         // B must have no right child and have A as an ancestor.
296         template <typename T, typename NodeV, typename KeyV>
297         void RBTree<T, NodeV, KeyV>::swap(Node *a, Node *b)
298         {
299                 Node *bparent = b->parent;
300                 Node **bparentlink = b->parentlink;
301                 
302                 assert(!b->right);
303                 assert(a->left || a->right);
304                 
305                 b->parent = a->parent;
306                 b->parentlink = a->parentlink;
307                 
308                 if (bparent == a) {
309                         a->parent = b;
310                         a->parentlink = &b->left;
311                 } else {
312                         a->parent = bparent;
313                         a->parentlink = bparentlink;
314                 }
315                 
316                 assert(a->parent != a);
317                 assert(b->parent != b);
318         
319                 Node *bleft = b->left;
320                 b->left = a->left;
321                 a->left = bleft;
322
323                 b->right = a->right;
324                 a->right = NULL;
325
326                 *a->parentlink = a;
327                 *b->parentlink = b;
328                 
329                 assert(a->parent != a);
330                 assert(b->parent != b);
331                 
332                 bool bred = b->red;
333                 b->red = a->red;
334                 a->red = bred;
335                 
336                 if (a->left) {
337                         a->left->parent = a;
338                         a->left->parentlink = &a->left;
339                 }
340
341                 if (b->right) {
342                         b->right->parent = b;
343                         b->right->parentlink = &b->right;
344                 }
345
346                 assert(b->left);
347                 b->left->parent = b;
348                 b->left->parentlink = &b->left;
349         }
350
351         template <typename T, typename NodeV, typename KeyV>
352         void RBTree<T, NodeV, KeyV>::add(T *t)
353         {
354                 Node *n = &t->rbtree_node;
355                 assert(!n->is_on_rbtree());
356                 
357                 Node *insert_at = find_node(top, n->value, &n->parentlink);
358                 
359                 assert(insert_at || !top);
360                 *n->parentlink = n;
361                 n->parent = insert_at;
362                 n->left = n->right = NULL;
363                 
364         repeat:
365                 assert(n->parentlink);
366                 assert(*n->parentlink == n);
367
368                 if (!n->parent) {
369                         n->red = false;
370                         return;
371                 }
372                 
373                 assert(n->parent->value < n->value != n->value < n->parent->value);
374                 n->red = true;
375
376                 if (!n->parent->red)
377                         return;
378
379                 Node *unc = uncle(n);
380                 if (red(unc)) {
381                         assert(!n->parent->parent->red);
382                         n->parent->red = unc->red = false;
383                         n = n->parent->parent;
384                         goto repeat;
385                 }
386                 
387                 if (node_is_left(n)) {
388                         if (!node_is_left(n->parent)) {
389                                 rotate_right(n->parent);
390                                 n = n->right;
391                         }
392                 } else {
393                         if (node_is_left(n->parent)) {
394                                 rotate_left(n->parent);
395                                 n = n->left;
396                         }
397                 }
398                 
399                 assert(n->parent->red);
400                 assert(!red(uncle(n)));
401                 assert(!n->parent->parent->red);
402                 
403                 n->parent->red = false;
404                 n->parent->parent->red = true;
405                 
406                 if (node_is_left(n)) {
407                         assert(node_is_left(n->parent));
408                         rotate_right(n->parent->parent);
409                 } else {
410                         assert(!node_is_left(n->parent));
411                         rotate_left(n->parent->parent);
412                 }
413         }
414
415         template <typename T, typename NodeV, typename KeyV>
416         void RBTree<T, NodeV, KeyV>::del(T *t)
417         {
418                 Node *n = &t->rbtree_node;
419                 assert(*n->parentlink == n);
420
421                 if (n->left && n->right) {
422                         Node *highest_on_left = find_node(n->left, n->value, NULL);
423                         assert(!highest_on_left->right);
424                         swap(n, highest_on_left);
425                         assert(!n->right);
426                 }
427
428                 Node *parent = n->parent;
429                 Node *child = n->left ? n->left : n->right;
430                 
431                 if (child) {
432                         assert(child->parent == n);
433
434                         child->parent = n->parent;
435                         child->parentlink = n->parentlink;
436                         *child->parentlink = child;
437                         assert(child != parent);
438                 } else {
439                         *n->parentlink = NULL;
440                 }
441                 
442                 n->parentlink = NULL;
443                 
444                 if (n->red)
445                         return;
446                 
447                 n = child;
448
449                 if (red(n)) {
450                         n->red = false;
451                         return;
452                 }
453
454         repeat:
455                 if (n == top) {
456                         assert(!red(n));
457                         return;
458                 }
459                 
460                 Node *sib;
461                 
462                 if (n)
463                         sib = sibling(n);
464                 else
465                         sib = parent->left ? parent->left : parent->right;
466                 
467                 if (sib->red) {
468                         assert(!parent->red);
469                         assert(!red(sib->left));
470                         assert(!red(sib->right));
471                         
472                         parent->red = true;
473                         sib->red = false;
474
475                         if (node_is_left(sib)) {
476                                 rotate_right(parent);
477                                 sib = parent->left;
478                         } else {
479                                 rotate_left(parent);
480                                 sib = parent->right;
481                         }
482                         
483                         if (n)
484                                 assert(sib == sibling(n));
485                 } else if (!parent->red && !red(sib->left) && !red(sib->right)) {
486                         sib->red = true;
487                         assert(n != parent);
488                         n = parent;
489                         parent = parent->parent;
490                         goto repeat;
491                 }
492
493                 assert(!sib->red);
494                 
495                 if (!parent->red && !red(sib->left) && !red(sib->right)) {
496                         sib->red = true;
497                         parent->red = false;
498                         return;
499                 }
500                 
501                 if (!red(sib->left) && !red(sib->right)) {
502                         assert(parent->red);
503                         sib->red = true;
504                         parent->red = false;
505                         return;
506                 }
507                 
508                 if (node_is_left(sib) && !red(sib->left) && red(sib->right)) {
509                         sib->red = true;
510                         sib->right->red = false;
511                         rotate_left(sib);
512                         sib = sib->parent;
513                 } else if (!node_is_left(sib) && !red(sib->right) && red(sib->left)) {
514                         sib->red = true;
515                         sib->left->red = false;
516                         rotate_right(sib);
517                         sib = sib->parent;
518                 }
519                 
520                 assert(parent == sib->parent);
521                 assert(!sib->red);
522                 
523                 sib->red = parent->red;
524                 parent->red = false;
525                 
526                 if (node_is_left(sib)) {
527                         assert(sib->left->red);
528                         sib->left->red = false;
529                         rotate_right(parent);
530                 } else {
531                         assert(sib->right->red);
532                         sib->right->red = false;
533                         rotate_left(parent);
534                 }
535         }
536
537         // RBPtr is a Pointer->Value associative array, and RBInt is an
538         // Integer->Value associative array.
539
540         template<typename Ptr, typename Val>
541         struct RBPtrNode {
542                 typedef RBTree<RBPtrNode, Ptr, Ptr> Tree;
543                 typename Tree::Node rbtree_node;
544                 Val value;
545                 
546                 intptr_t operator < (RBPtrNode &other)
547                 {
548                         return (intptr_t)other.rbtree_node.value -
549                                (intptr_t)rbtree_node.value;
550                 }
551
552                 intptr_t operator > (RBPtrNode &other)
553                 {
554                         return (intptr_t)rbtree_node.value -
555                                (intptr_t)other.rbtree_node.value;
556                 }
557
558                 operator Val &()
559                 {
560                         return value;
561                 }
562         };
563
564         template<typename Ptr, typename Val>
565         struct RBPtr : public RBTree<RBPtrNode<Ptr, Val>, Ptr, Ptr>
566         {
567                 typedef RBPtrNode<Ptr, Val> Node;
568                 typedef RBTree<Node, Ptr, Ptr> Tree;
569         
570                 void add(Ptr ptr, Val &val)
571                 {
572                         Node *node = new Node;
573                         node->value = val;
574                         node->rbtree_node.value = ptr;
575                         Tree::add(node);
576                 }
577                 
578                 void del(Ptr ptr)
579                 {
580                         delete find(ptr);
581                 }
582         };
583
584         template<typename Int, typename Val>
585         struct RBIntNode {
586                 typedef RBTree<RBIntNode, Int, Int> Tree;
587                 typename Tree::Node rbtree_node;
588                 Val value;
589                 
590                 intptr_t operator < (RBIntNode &other)
591                 {
592                         return other.rbtree_node.value - rbtree_node.value;
593                 }
594
595                 intptr_t operator > (RBIntNode &other)
596                 {
597                         return rbtree_node.value - other.rbtree_node.value;
598                 }
599                 
600                 operator Val &()
601                 {
602                         return value;
603                 }
604         };
605
606         template<typename Int, typename Val>
607         struct RBInt : public RBTree<RBIntNode<Int, Val>, Int, Int>
608         {
609                 typedef RBIntNode<Int, Val> Node;
610                 typedef RBTree<Node, Int, Int> Tree;
611         
612                 void add(Int key, Val &val)
613                 {
614                         Node *node = new Node;
615                         node->value = val;
616                         node->rbtree_node.value = key;
617                         Tree::add(node);
618                 }
619                 
620                 void del(Int key)
621                 {
622                         delete find(key);
623                 }
624         };
625 }
626
627 #endif