rbtree: silence parentheses warning
[polintos/scott/priv.git] / include / c++ / util / rbtree.h
1 // util/rbtree.h -- A red/black tree implementation
2 //
3 // This software is copyright (c) 2006 Scott Wood <scott@buserror.net>.
4 // 
5 // This software is provided 'as-is', without any express or implied warranty.
6 // In no event will the authors or contributors be held liable for any damages
7 // arising from the use of this software.
8 // 
9 // Permission is hereby granted to everyone, free of charge, to use, copy,
10 // modify, prepare derivative works of, publish, distribute, perform,
11 // sublicense, and/or sell copies of the Software, provided that the above
12 // copyright notice and disclaimer of warranty be included in all copies or
13 // substantial portions of this software.
14
15 #ifndef _UTIL_RBTREE_H
16 #define _UTIL_RBTREE_H
17
18 #include <assert.h>
19 #include <stddef.h>
20 #include <stdint.h>
21
22 namespace Util {
23         // T must have an RBTree<T, NodeV, KeyV>::Node member called rbtree_node.
24         // NodeV < NodeV, NodeV < KeyV, and NodeV > KeyV must be supported.
25         //
26         // Using NodeV != KeyV allows things such as having NodeV represent a
27         // range of values and KeyV a single value, causing find() to search
28         // for the node in whose range KeyV falls.  It is an error to add
29         // nodes that are not well ordered.
30
31         template <typename T, typename NodeV, typename KeyV>
32         class RBTree {
33         public:
34                 struct Node {
35                         Node *parent, *left, *right;
36                         NodeV value;
37                         bool red;
38                         
39                         // A pointer to the parent's left or right pointer
40                         // (whichever side this node is on).  This is also
41                         // used for sanity checks (it is NULL for nodes not
42                         // on a tree).
43                         
44                         Node **parentlink;
45                         
46                         Node()
47                         {
48                                 parent = left = right = NULL;
49                                 parentlink = NULL;
50                         }
51                         
52                         bool is_on_rbtree()
53                         {
54                                 return parentlink != NULL;
55                         }
56                 };
57         
58         private:
59                 Node *top;
60         
61                 T *node_to_type(Node *n)
62                 {
63                         return reinterpret_cast<T *>(reinterpret_cast<uintptr_t>(n) - 
64                                                      offsetof(T, rbtree_node));
65                 }
66                 
67                 void assert_rb_condition(Node *left, Node *right)
68                 {
69                         assert(left->value < right->value);
70                         assert(!(right->value < left->value));
71                         assert(!left->red || !right->red);
72                 }
73                 
74                 void assert_well_ordered(Node *left, NodeV right)
75                 {
76                         assert(left->value < right);
77                         assert(!(left->value > right));
78                         assert(!left->red || !right->red);
79                 }
80                 
81                 Node *find_node(Node *n, NodeV &val, Node ***link);
82                 Node *find_node_key(KeyV val, bool exact = true);
83                 
84                 bool node_is_left(Node *n)
85                 {
86                         assert(n->parentlink == &n->parent->left ||
87                                n->parentlink == &n->parent->right);
88                 
89                         return n->parentlink == &n->parent->left;
90                 }
91                 
92                 Node *sibling(Node *n)
93                 {
94                         if (node_is_left(n))
95                                 return n->parent->right;
96
97                         return n->parent->left;
98                 }
99                 
100                 Node *uncle(Node *n)
101                 {
102                         return sibling(n->parent);
103                 }
104                 
105                 void rotate_left(Node *n);
106                 void rotate_right(Node *n);
107                 
108                 // B must have no right child and have A as an ancestor.
109                 void swap(Node *a, Node *b);
110
111                 bool red(Node *n)
112                 {
113                         return n && n->red;
114                 }
115                 
116 public:
117                 RBTree()
118                 {
119                         top = NULL;
120                 }
121                 
122                 bool empty()
123                 {
124                         return top == NULL;
125                 }
126                 
127                 T *find(KeyV value)
128                 {
129                         Node *n = find_node_key(value);
130                         
131                         if (n)
132                                 return node_to_type(n);
133                                 
134                         return NULL;
135                 }
136                 
137                 T *find_nearest(KeyV value)
138                 {
139                         Node *n = find_node_key(value, false);
140                         
141                         if (n)
142                                 return node_to_type(n);
143                         
144                         // Should only happen on an empty tree
145                         return NULL;
146                 }
147
148                 void add(T *t);
149                 void del(T *t);
150         };
151         
152         template <typename T, typename NodeV, typename KeyV>
153         typename RBTree<T, NodeV, KeyV>::Node *
154         RBTree<T, NodeV, KeyV>::find_node(Node *n, NodeV &val, Node ***link)
155         {
156                 if (link) {
157                         if (n == top)
158                                 *link = &top;
159                         else
160                                 *link = n->parentlink;
161                 }
162                 
163                 if (!top)
164                         return NULL;
165                 
166                 while (true) {
167                         assert(n);
168                         Node *next;
169                 
170                         if (n->value < val) {
171                                 next = n->right;
172                                 
173                                 if (next) {
174                                         assert_rb_condition(n, next);
175                                 } else {
176                                         if (link)
177                                                 *link = &n->right;
178
179                                         break;
180                                 }
181                         } else {
182                                 // This assert detects duplicate nodes, but not
183                                 // overlapping ranges (or otherwise non-well-ordered
184                                 // values).
185                                 
186                                 assert(val < n->value);
187                                 next = n->left;
188                                 
189                                 if (next) {
190                                         assert_rb_condition(next, n);
191                                 } else {
192                                         if (link)
193                                                 *link = &n->left;
194
195                                         break;
196                                 }
197                         }
198                         
199                         n = next;
200                 }
201                 
202                 return n;
203         }
204
205         template <typename T, typename NodeV, typename KeyV>
206         typename RBTree<T, NodeV, KeyV>::Node *
207         RBTree<T, NodeV, KeyV>::find_node_key(KeyV val, bool exact)
208         {
209                 Node *n = top;
210                 
211                 while (n) {
212                         Node *next;
213                 
214                         if (n->value < val) {
215                                 next = n->right;
216                                 
217                                 if (next)
218                                         assert_rb_condition(n, next);
219                         } else if (n->value > val) {
220                                 next = n->left;
221
222                                 if (next)
223                                         assert_rb_condition(next, n);
224                         } else {
225                                 break;
226                         }
227                         
228                         if (!next && !exact)
229                                 return n;
230                         
231                         n = next;
232                 }
233                 
234                 return n;
235         }
236         
237         template <typename T, typename NodeV, typename KeyV>
238         void RBTree<T, NodeV, KeyV>::rotate_left(Node *n)
239         {
240                 Node *new_top = n->right;
241                 
242                 assert(*n->parentlink == n);
243                 *n->parentlink = new_top;
244                 
245                 assert(new_top->parent == n);
246                 assert(new_top->parentlink == &n->right);
247
248                 new_top->parent = n->parent;
249                 new_top->parentlink = n->parentlink;
250                 
251                 n->parent = new_top;
252                 n->parentlink = &new_top->left;
253                 
254                 n->right = new_top->left;
255                 new_top->left = n;
256                 
257                 if (n->right) {
258                         assert(n->right->parent == new_top);
259                         assert(n->right->parentlink == &new_top->left);
260                 
261                         n->right->parent = n;
262                         n->right->parentlink = &n->right;
263                 }
264         }
265
266         template <typename T, typename NodeV, typename KeyV>
267         void RBTree<T, NodeV, KeyV>::rotate_right(Node *n)
268         {
269                 Node *new_top = n->left;
270                 
271                 assert(*n->parentlink == n);
272                 *n->parentlink = new_top;
273                 
274                 assert(new_top->parent == n);
275                 assert(new_top->parentlink == &n->left);
276
277                 new_top->parent = n->parent;
278                 new_top->parentlink = n->parentlink;
279                 
280                 n->parent = new_top;
281                 n->parentlink = &new_top->right;
282                 
283                 n->left = new_top->right;
284                 new_top->right = n;
285                 
286                 if (n->left) {
287                         assert(n->left->parent == new_top);
288                         assert(n->left->parentlink == &new_top->right);
289
290                         n->left->parent = n;
291                         n->left->parentlink = &n->left;
292                 }
293         }
294         
        // Exchange the tree positions of nodes A and B: links (parent,
        // children, parentlink, and the pointer that refers to each node) and
        // colors are swapped, while each node keeps its own value.  del()
        // uses this to move a node with two children down to the position of
        // its in-order predecessor.
        //
        // B must have no right child and have A as an ancestor.  When B is
        // A's direct child, the code below assumes it is A's LEFT child (as
        // it is when called from del()).
        template <typename T, typename NodeV, typename KeyV>
        void RBTree<T, NodeV, KeyV>::swap(Node *a, Node *b)
        {
                // Save B's upward links before they are overwritten.
                Node *bparent = b->parent;
                Node **bparentlink = b->parentlink;

                assert(!b->right);
                assert(a->left || a->right);

                // B takes A's place under A's parent (or as the root).
                b->parent = a->parent;
                b->parentlink = a->parentlink;

                // A takes B's old place.  If B was A's direct child, that
                // place is now underneath B itself.
                if (bparent == a) {
                        a->parent = b;
                        a->parentlink = &b->left;
                } else {
                        a->parent = bparent;
                        a->parentlink = bparentlink;
                }

                assert(a->parent != a);
                assert(b->parent != b);

                // Exchange left subtrees.  In the adjacent case a->left is B,
                // so b->left briefly points at B itself; the store through
                // a->parentlink (== &b->left) below repairs it to A.
                Node *bleft = b->left;
                b->left = a->left;
                a->left = bleft;

                // B had no right child, so A ends up with none and B inherits
                // A's right subtree.
                b->right = a->right;
                a->right = NULL;

                // Point the referring child pointers (or top) at the nodes'
                // new positions.
                *a->parentlink = a;
                *b->parentlink = b;

                assert(a->parent != a);
                assert(b->parent != b);

                // Colors belong to positions, so they are swapped too.
                bool bred = b->red;
                b->red = a->red;
                a->red = bred;

                // Re-parent the children each node inherited.
                if (a->left) {
                        a->left->parent = a;
                        a->left->parentlink = &a->left;
                }

                if (b->right) {
                        b->right->parent = b;
                        b->right->parentlink = &b->right;
                }

                // B always has a left child now: A's old left child, or A
                // itself in the adjacent case.
                assert(b->left);
                b->left->parent = b;
                b->left->parentlink = &b->left;
        }
350
        // Insert T into the tree and rebalance.  T's rbtree_node.value must
        // already be set, and T must not currently be on a tree.
        template <typename T, typename NodeV, typename KeyV>
        void RBTree<T, NodeV, KeyV>::add(T *t)
        {
                Node *n = &t->rbtree_node;
                assert(!n->is_on_rbtree());

                // Find the empty child slot (or &top) where N belongs; the
                // returned node becomes N's parent (NULL for an empty tree).
                Node *insert_at = find_node(top, n->value, &n->parentlink);

                assert(insert_at || !top);
                *n->parentlink = n;
                n->parent = insert_at;
                n->left = n->right = NULL;

                // Fix-up loop: each pass colors N red (or black if it is the
                // root) and resolves a red-red link between N and its parent.
        repeat:
                assert(n->parentlink);
                assert(*n->parentlink == n);

                // The root is always black.
                if (!n->parent) {
                        n->red = false;
                        return;
                }

                assert((n->parent->value < n->value) != 
                       (n->value < n->parent->value));
                n->red = true;

                // Black parent: no violation.
                if (!n->parent->red)
                        return;

                // Red parent and red uncle: move the grandparent's blackness
                // down to both, then continue from the grandparent.
                Node *unc = uncle(n);
                if (red(unc)) {
                        assert(!n->parent->parent->red);
                        n->parent->red = unc->red = false;
                        n = n->parent->parent;
                        goto repeat;
                }

                // Black uncle.  If N is an inner grandchild, rotate it into
                // its parent's place and continue with the old parent, so the
                // red pair lies on the outer side.
                if (node_is_left(n)) {
                        if (!node_is_left(n->parent)) {
                                rotate_right(n->parent);
                                n = n->right;
                        }
                } else {
                        if (node_is_left(n->parent)) {
                                rotate_left(n->parent);
                                n = n->left;
                        }
                }

                assert(n->parent->red);
                assert(!red(uncle(n)));
                assert(!n->parent->parent->red);

                // Outer case: recolor, then rotate the grandparent away from
                // N so the now-black parent tops the subtree.
                n->parent->red = false;
                n->parent->parent->red = true;

                if (node_is_left(n)) {
                        assert(node_is_left(n->parent));
                        rotate_right(n->parent->parent);
                } else {
                        assert(!node_is_left(n->parent));
                        rotate_left(n->parent->parent);
                }
        }
415
416         template <typename T, typename NodeV, typename KeyV>
417         void RBTree<T, NodeV, KeyV>::del(T *t)
418         {
419                 Node *n = &t->rbtree_node;
420                 assert(*n->parentlink == n);
421
422                 if (n->left && n->right) {
423                         Node *highest_on_left = find_node(n->left, n->value, NULL);
424                         assert(!highest_on_left->right);
425                         swap(n, highest_on_left);
426                         assert(!n->right);
427                 }
428
429                 Node *parent = n->parent;
430                 Node *child = n->left ? n->left : n->right;
431                 
432                 if (child) {
433                         assert(child->parent == n);
434
435                         child->parent = n->parent;
436                         child->parentlink = n->parentlink;
437                         *child->parentlink = child;
438                         assert(child != parent);
439                 } else {
440                         *n->parentlink = NULL;
441                 }
442                 
443                 n->parentlink = NULL;
444                 
445                 if (n->red)
446                         return;
447                 
448                 n = child;
449
450                 if (red(n)) {
451                         n->red = false;
452                         return;
453                 }
454
455         repeat:
456                 if (n == top) {
457                         assert(!red(n));
458                         return;
459                 }
460                 
461                 Node *sib;
462                 
463                 if (n)
464                         sib = sibling(n);
465                 else
466                         sib = parent->left ? parent->left : parent->right;
467                 
468                 if (sib->red) {
469                         assert(!parent->red);
470                         assert(!red(sib->left));
471                         assert(!red(sib->right));
472                         
473                         parent->red = true;
474                         sib->red = false;
475
476                         if (node_is_left(sib)) {
477                                 rotate_right(parent);
478                                 sib = parent->left;
479                         } else {
480                                 rotate_left(parent);
481                                 sib = parent->right;
482                         }
483                         
484                         if (n)
485                                 assert(sib == sibling(n));
486                 } else if (!parent->red && !red(sib->left) && !red(sib->right)) {
487                         sib->red = true;
488                         assert(n != parent);
489                         n = parent;
490                         parent = parent->parent;
491                         goto repeat;
492                 }
493
494                 assert(!sib->red);
495                 
496                 if (!parent->red && !red(sib->left) && !red(sib->right)) {
497                         sib->red = true;
498                         parent->red = false;
499                         return;
500                 }
501                 
502                 if (!red(sib->left) && !red(sib->right)) {
503                         assert(parent->red);
504                         sib->red = true;
505                         parent->red = false;
506                         return;
507                 }
508                 
509                 if (node_is_left(sib) && !red(sib->left) && red(sib->right)) {
510                         sib->red = true;
511                         sib->right->red = false;
512                         rotate_left(sib);
513                         sib = sib->parent;
514                 } else if (!node_is_left(sib) && !red(sib->right) && red(sib->left)) {
515                         sib->red = true;
516                         sib->left->red = false;
517                         rotate_right(sib);
518                         sib = sib->parent;
519                 }
520                 
521                 assert(parent == sib->parent);
522                 assert(!sib->red);
523                 
524                 sib->red = parent->red;
525                 parent->red = false;
526                 
527                 if (node_is_left(sib)) {
528                         assert(sib->left->red);
529                         sib->left->red = false;
530                         rotate_right(parent);
531                 } else {
532                         assert(sib->right->red);
533                         sib->right->red = false;
534                         rotate_left(parent);
535                 }
536         }
537
538         // RBPtr is a Pointer->Value associative array, and RBInt is an
539         // Integer->Value associative array.
540
541         template<typename Ptr, typename Val>
542         struct RBPtrNode {
543                 typedef RBTree<RBPtrNode, Ptr, Ptr> Tree;
544                 typename Tree::Node rbtree_node;
545                 Val value;
546                 
547                 intptr_t operator < (RBPtrNode &other)
548                 {
549                         return (intptr_t)other.rbtree_node.value -
550                                (intptr_t)rbtree_node.value;
551                 }
552
553                 intptr_t operator > (RBPtrNode &other)
554                 {
555                         return (intptr_t)rbtree_node.value -
556                                (intptr_t)other.rbtree_node.value;
557                 }
558
559                 operator Val &()
560                 {
561                         return value;
562                 }
563         };
564
565         template<typename Ptr, typename Val>
566         struct RBPtr : public RBTree<RBPtrNode<Ptr, Val>, Ptr, Ptr>
567         {
568                 typedef RBPtrNode<Ptr, Val> Node;
569                 typedef RBTree<Node, Ptr, Ptr> Tree;
570         
571                 void add(Ptr ptr, Val &val)
572                 {
573                         Node *node = new Node;
574                         node->value = val;
575                         node->rbtree_node.value = ptr;
576                         Tree::add(node);
577                 }
578                 
579                 void del(Ptr ptr)
580                 {
581                         delete find(ptr);
582                 }
583         };
584
585         template<typename Int, typename Val>
586         struct RBIntNode {
587                 typedef RBTree<RBIntNode, Int, Int> Tree;
588                 typename Tree::Node rbtree_node;
589                 Val value;
590                 
591                 intptr_t operator < (RBIntNode &other)
592                 {
593                         return other.rbtree_node.value - rbtree_node.value;
594                 }
595
596                 intptr_t operator > (RBIntNode &other)
597                 {
598                         return rbtree_node.value - other.rbtree_node.value;
599                 }
600                 
601                 operator Val &()
602                 {
603                         return value;
604                 }
605         };
606
607         template<typename Int, typename Val>
608         struct RBInt : public RBTree<RBIntNode<Int, Val>, Int, Int>
609         {
610                 typedef RBIntNode<Int, Val> Node;
611                 typedef RBTree<Node, Int, Int> Tree;
612         
613                 void add(Int key, Val &val)
614                 {
615                         Node *node = new Node;
616                         node->value = val;
617                         node->rbtree_node.value = key;
618                         Tree::add(node);
619                 }
620                 
621                 void del(Int key)
622                 {
623                         delete find(key);
624                 }
625         };
626 }
627
628 #endif