avl_test.go 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
// avl_test.go - AVL tree tests.
//
// To the extent possible under law, Yawning Angel has waived all copyright
// and related or neighboring rights to avl, using the Creative
// Commons "CC0" public domain dedication. See LICENSE or
// <http://creativecommons.org/publicdomain/zero/1.0/> for full details.
  7. package avl
  8. import (
  9. "math/rand"
  10. "sort"
  11. "testing"
  12. "github.com/stretchr/testify/require"
  13. )
  14. func TestAVLTree(t *testing.T) {
  15. require := require.New(t)
  16. tree := New(func(a, b interface{}) int {
  17. return a.(int) - b.(int)
  18. })
  19. require.Equal(0, tree.Len(), "Len(): empty")
  20. require.Nil(tree.First(), "First(): empty")
  21. require.Nil(tree.Last(), "Last(): empty")
  22. iter := tree.Iterator(Forward)
  23. require.Nil(iter.First(), "Iterator: First(), empty")
  24. require.Nil(iter.Next(), "Iterator: Next(), empty")
  25. // Test insertion.
  26. const nrEntries = 1024
  27. insertedMap := make(map[int]*Node)
  28. for len(insertedMap) != nrEntries {
  29. v := rand.Int()
  30. if insertedMap[v] != nil {
  31. continue
  32. }
  33. insertedMap[v] = tree.Insert(v)
  34. tree.validate(require)
  35. }
  36. require.Equal(nrEntries, tree.Len(), "Len(): After insertion")
  37. tree.validate(require)
  38. // Ensure that all entries can be found.
  39. for k, v := range insertedMap {
  40. require.Equal(v, tree.Find(k), "Find(): %v", k)
  41. require.Equal(k, v.Value, "Find(): %v Value", k)
  42. }
  43. // Test the forward/backward iterators.
  44. fwdInOrder := make([]int, 0, nrEntries)
  45. for k := range insertedMap {
  46. fwdInOrder = append(fwdInOrder, k)
  47. }
  48. sort.Ints(fwdInOrder)
  49. require.Equal(fwdInOrder[0], tree.First().Value, "First(), full")
  50. require.Equal(fwdInOrder[nrEntries-1], tree.Last().Value, "Last(), full")
  51. revInOrder := make([]int, 0, nrEntries)
  52. for i := len(fwdInOrder) - 1; i >= 0; i-- {
  53. revInOrder = append(revInOrder, fwdInOrder[i])
  54. }
  55. iter = tree.Iterator(Forward)
  56. visited := 0
  57. for node := iter.First(); node != nil; node = iter.Next() {
  58. v, idx := node.Value.(int), visited
  59. require.Equal(fwdInOrder[visited], v, "Iterator: Forward[%v]", idx)
  60. require.Equal(node, iter.Get(), "Iterator: Forward[%v]: Get()", idx)
  61. visited++
  62. }
  63. require.Equal(nrEntries, visited, "Iterator: Forward: Visited")
  64. iter = tree.Iterator(Backward)
  65. visited = 0
  66. for node := iter.First(); node != nil; node = iter.Next() {
  67. v, idx := node.Value.(int), visited
  68. require.Equal(revInOrder[idx], v, "Iterator: Backward[%v]", idx)
  69. require.Equal(node, iter.Get(), "Iterator: Backward[%v]: Get()", idx)
  70. visited++
  71. }
  72. require.Equal(nrEntries, visited, "Iterator: Backward: Visited")
  73. // Test the forward/backward ForEach.
  74. forEachValues := make([]int, 0, nrEntries)
  75. forEachFn := func(n *Node) bool {
  76. forEachValues = append(forEachValues, n.Value.(int))
  77. return true
  78. }
  79. tree.ForEach(Forward, forEachFn)
  80. require.Equal(fwdInOrder, forEachValues, "ForEach: Forward")
  81. forEachValues = make([]int, 0, nrEntries)
  82. tree.ForEach(Backward, forEachFn)
  83. require.Equal(revInOrder, forEachValues, "ForEach: Backward")
  84. // Test removal.
  85. for i, idx := range rand.Perm(nrEntries) { // In random order.
  86. v := fwdInOrder[idx]
  87. node := tree.Find(v)
  88. require.Equal(v, node.Value, "Find(): %v (Pre-remove)", v)
  89. tree.Remove(node)
  90. require.Equal(nrEntries-(i+1), tree.Len(), "Len(): %v (Post-remove)", v)
  91. tree.validate(require)
  92. node = tree.Find(v)
  93. require.Nil(node, "Find(): %v (Post-remove)", v)
  94. }
  95. require.Equal(0, tree.Len(), "Len(): After removal")
  96. require.Nil(tree.First(), "First(): After removal")
  97. require.Nil(tree.Last(), "Last(): After removal")
  98. // Refill the tree.
  99. for _, v := range fwdInOrder {
  100. tree.Insert(v)
  101. }
  102. // Test that removing the node doesn't break the iterator.
  103. iter = tree.Iterator(Forward)
  104. visited = 0
  105. for node := iter.Get(); node != nil; node = iter.Next() { // Omit calling First().
  106. v, idx := node.Value.(int), visited
  107. require.Equal(fwdInOrder[idx], v, "Iterator: Forward[%v] (Pre-Remove)", idx)
  108. require.Equal(fwdInOrder[idx], tree.First().Value, "First() (Iterator, remove)")
  109. visited++
  110. tree.Remove(node)
  111. tree.validate(require)
  112. }
  113. require.Equal(0, tree.Len(), "Len(): After iterating removal")
  114. }
  115. func (t *Tree) validate(require *require.Assertions) {
  116. checkInvariants(require, t.root, nil)
  117. }
  118. func checkInvariants(require *require.Assertions, node, parent *Node) int {
  119. if node == nil {
  120. return 0
  121. }
  122. // Validate the parent pointer.
  123. require.Equal(parent, node.parent)
  124. // Validate that the balance factor is -1, 0, 1.
  125. require.Condition(func() bool {
  126. switch node.balance {
  127. case -1, 0, 1:
  128. return true
  129. }
  130. return false
  131. })
  132. // Recursively derive the height of the left and right sub-trees.
  133. lHeight := checkInvariants(require, node.left, node)
  134. rHeight := checkInvariants(require, node.right, node)
  135. // Validate the AVL invariant and the balance factor.
  136. require.Equal(int(node.balance), rHeight-lHeight)
  137. if lHeight > rHeight {
  138. return lHeight + 1
  139. }
  140. return rHeight + 1
  141. }