Browse Source

Misc improvements.

 * Minor code/documentation cleanups.
 * Add ForEach(), for people that want to use a closure or whatever
   instead of an iterator.
Yawning Angel 1 year ago
parent
commit
bf9e2f6cef
2 changed files with 57 additions and 19 deletions
  1. 26 6
      avl.go
  2. 31 13
      avl_test.go

+ 26 - 6
avl.go

@@ -21,6 +21,7 @@ import "errors"
 
 var (
 	errNoCmpFn          = errors.New("avl: no comparison function")
+	errNotInTree        = errors.New("avl: element not in tree")
 	errInvalidDirection = errors.New("avl: invalid direction")
 )
 
@@ -36,10 +37,10 @@ type CompareFunc func(a, b interface{}) int
 type Direction int8
 
 const (
-	// Backward is backward (in-order).
+	// Backward is backwards in-order.
 	Backward Direction = -1
 
-	// Forward is forward (in-order).
+	// Forward is forwards in-order.
 	Forward Direction = 1
 )
 
@@ -53,7 +54,8 @@ type Iterator struct {
 }
 
 // First moves the iterator to the first Node in the Tree and returns the
-// first Node or nil iff the Tree is empty.
+// first Node or nil iff the Tree is empty.  Note that "first" in this context
+// is dependent on the direction specified when constructing the iterator.
 func (it *Iterator) First() *Node {
 	it.cur, it.next = it.tree.firstOrLastInOrder(-it.sign), nil
 	if it.cur != nil {
@@ -98,7 +100,7 @@ type Node struct {
 
 func (n *Node) reset() {
 	// Note: This deliberately leaves Value intact.
-	n.parent, n.left, n.right = nil, nil, nil
+	n.parent, n.left, n.right = n, nil, nil
 	n.balance = 0
 }
 
@@ -179,6 +181,7 @@ func (t *Tree) Find(v interface{}) *Node {
 	}
 
 	cur := t.root
+descendLoop:
 	for cur != nil {
 		cmp := t.cmpFn(v, cur.Value)
 		switch {
@@ -187,11 +190,11 @@ func (t *Tree) Find(v interface{}) *Node {
 		case cmp > 0:
 			cur = cur.right
 		default:
-			return cur
+			break descendLoop
 		}
 	}
 
-	return nil
+	return cur
 }
 
 // Insert inserts the value into the Tree, and returns the newly created Node
@@ -233,6 +236,10 @@ func (t *Tree) Remove(node *Node) {
 	var parent *Node
 	var leftDeleted bool
 
+	if node.parent == node {
+		panic(errNotInTree)
+	}
+
 	t.size--
 	if node.left != nil && node.right != nil {
 		parent, leftDeleted = t.swapWithSuccessor(node)
@@ -292,6 +299,19 @@ func (t *Tree) Iterator(direction Direction) *Iterator {
 	}
 }
 
+// ForEach executes a function for each Node in the tree, visiting the nodes
+// in-order in the direction specified.  If the provided function returns
+// false, the iteration is stopped.  Modifying the Tree from within the
+// function is unsupprted except for removing the current Node.
+func (t *Tree) ForEach(direction Direction, fn func(*Node) bool) {
+	it := t.Iterator(direction)
+	for node := it.Get(); node != nil; node = it.Next() {
+		if !fn(node) {
+			return
+		}
+	}
+}
+
 func (t *Tree) firstOrLastInOrder(sign int8) *Node {
 	first := t.root
 	if first != nil {

+ 31 - 13
avl_test.go

@@ -60,19 +60,24 @@ func TestAVLTree(t *testing.T) {
 	}
 
 	// Test the forward/backward iterators.
-	inOrder := make([]int, 0, nrEntries)
+	fwdInOrder := make([]int, 0, nrEntries)
 	for k := range insertedMap {
-		inOrder = append(inOrder, k)
+		fwdInOrder = append(fwdInOrder, k)
+	}
+	sort.Ints(fwdInOrder)
+	require.Equal(fwdInOrder[0], tree.First().Value, "First(), full")
+	require.Equal(fwdInOrder[nrEntries-1], tree.Last().Value, "Last(), full")
+
+	revInOrder := make([]int, 0, nrEntries)
+	for i := len(fwdInOrder) - 1; i >= 0; i-- {
+		revInOrder = append(revInOrder, fwdInOrder[i])
 	}
-	sort.Ints(inOrder)
-	require.Equal(inOrder[0], tree.First().Value, "First(), full")
-	require.Equal(inOrder[nrEntries-1], tree.Last().Value, "Last(), full")
 
 	iter = tree.Iterator(Forward)
 	visited := 0
 	for node := iter.First(); node != nil; node = iter.Next() {
 		v, idx := node.Value.(int), visited
-		require.Equal(inOrder[visited], v, "Iterator: Forward[%v]", idx)
+		require.Equal(fwdInOrder[visited], v, "Iterator: Forward[%v]", idx)
 		require.Equal(node, iter.Get(), "Iterator: Forward[%v]: Get()", idx)
 		visited++
 	}
@@ -81,16 +86,29 @@ func TestAVLTree(t *testing.T) {
 	iter = tree.Iterator(Backward)
 	visited = 0
 	for node := iter.First(); node != nil; node = iter.Next() {
-		v, idx := node.Value.(int), nrEntries-1-visited
-		require.Equal(inOrder[idx], v, "Iterator: Backward[%v]", idx)
+		v, idx := node.Value.(int), visited
+		require.Equal(revInOrder[idx], v, "Iterator: Backward[%v]", idx)
 		require.Equal(node, iter.Get(), "Iterator: Backward[%v]: Get()", idx)
 		visited++
 	}
 	require.Equal(nrEntries, visited, "Iterator: Backward: Visited")
 
+	// Test the forward/backward ForEach.
+	forEachValues := make([]int, 0, nrEntries)
+	forEachFn := func(n *Node) bool {
+		forEachValues = append(forEachValues, n.Value.(int))
+		return true
+	}
+	tree.ForEach(Forward, forEachFn)
+	require.Equal(fwdInOrder, forEachValues, "ForEach: Forward")
+
+	forEachValues = make([]int, 0, nrEntries)
+	tree.ForEach(Backward, forEachFn)
+	require.Equal(revInOrder, forEachValues, "ForEach: Backward")
+
 	// Test removal.
-	for i, idx := range rand.Perm(nrEntries) {
-		v := inOrder[idx]
+	for i, idx := range rand.Perm(nrEntries) { // In random order.
+		v := fwdInOrder[idx]
 		node := tree.Find(v)
 		require.Equal(v, node.Value, "Find(): %v (Pre-remove)", v)
 
@@ -106,7 +124,7 @@ func TestAVLTree(t *testing.T) {
 	require.Nil(tree.Last(), "Last(): After removal")
 
 	// Refill the tree.
-	for _, v := range inOrder {
+	for _, v := range fwdInOrder {
 		tree.Insert(v)
 	}
 
@@ -115,8 +133,8 @@ func TestAVLTree(t *testing.T) {
 	visited = 0
 	for node := iter.Get(); node != nil; node = iter.Next() { // Omit calling First().
 		v, idx := node.Value.(int), visited
-		require.Equal(inOrder[idx], v, "Iterator: Forward[%v] (Pre-Remove)", idx)
-		require.Equal(inOrder[idx], tree.First().Value, "First() (Iterator, remove)")
+		require.Equal(fwdInOrder[idx], v, "Iterator: Forward[%v] (Pre-Remove)", idx)
+		require.Equal(fwdInOrder[idx], tree.First().Value, "First() (Iterator, remove)")
 		visited++
 
 		tree.Remove(node)