avl.go 10 KB


  1. // avl.go - An AVL tree implementation.
  2. //
  3. // To the extent possible under law, Yawning Angel has waived all copyright
  4. // and related or neighboring rights to avl, using the Creative
  5. // Commons "CC0" public domain dedication. See LICENSE or
  6. // <http://creativecommons.org/publicdomain/zero/1.0/> for full details.
  7. // Package avl implements an AVL tree.
  8. package avl
// This is a fairly straightforward adaptation of the CC0 C implementation
// from https://github.com/ebiggers/avl_tree/ by Eric Biggers into what is
// hopefully idiomatic Go.
//
// The primary differences from the original package are:
// * The balance factor is stored separately from the parent pointer.
// * The container is non-intrusive.
// * Only in-order traversal is currently supported.
  17. import "errors"
  18. var (
  19. errNoCmpFn = errors.New("avl: no comparison function")
  20. errNotInTree = errors.New("avl: element not in tree")
  21. errInvalidDirection = errors.New("avl: invalid direction")
  22. )
  23. // CompareFunc is the function used to compare entries in the Tree to maintain
  24. // ordering. It MUST return < 0, 0, or > 0 if a is less than, equal to, or
  25. // greater than b respectively.
  26. //
  27. // Note: All calls made to the comparison function will pass the user supplied
  28. // value as a, and the in-Tree value as b.
  29. type CompareFunc func(a, b interface{}) int
  30. // Direction is the direction associated with an iterator.
  31. type Direction int8
  32. const (
  33. // Backward is backward in-order.
  34. Backward Direction = -1
  35. // Forward is forward in-order.
  36. Forward Direction = 1
  37. )
  38. // Iterator is a Tree iterator. Modifying the Tree while iterating is
  39. // unsupported except for removing the current Node.
  40. type Iterator struct {
  41. tree *Tree
  42. cur, next *Node
  43. sign int8
  44. initialized bool
  45. }
  46. // First moves the iterator to the first Node in the Tree and returns the
  47. // first Node or nil iff the Tree is empty. Note that "first" in this context
  48. // is dependent on the direction specified when constructing the iterator.
  49. func (it *Iterator) First() *Node {
  50. it.cur, it.next = it.tree.firstOrLastInOrder(-it.sign), nil
  51. if it.cur != nil {
  52. it.next = it.cur.nextOrPrevInOrder(it.sign)
  53. }
  54. it.initialized = true
  55. return it.cur
  56. }
  57. // Get returns the Node currently pointed to by the iterator. It is safe to
  58. // remove the Node returned from the Tree.
  59. func (it *Iterator) Get() *Node {
  60. if !it.initialized {
  61. return it.First()
  62. }
  63. return it.cur
  64. }
  65. // Next advances the iterator and returns the Node or nil iff the end of the
  66. // Tree has been reached.
  67. func (it *Iterator) Next() *Node {
  68. if !it.initialized {
  69. it.First()
  70. }
  71. if it.next == nil {
  72. return nil
  73. }
  74. it.cur = it.next
  75. it.next = it.cur.nextOrPrevInOrder(it.sign)
  76. return it.cur
  77. }
  78. // Node is a node of a Tree.
  79. type Node struct {
  80. // Value is the value stored by the Node.
  81. Value interface{}
  82. parent, left, right *Node
  83. balance int8
  84. }
  85. func (n *Node) reset() {
  86. // Note: This deliberately leaves Value intact.
  87. n.parent, n.left, n.right = n, nil, nil
  88. n.balance = 0
  89. }
  90. func (n *Node) setParentBalance(parent *Node, balance int8) {
  91. n.parent = parent
  92. n.balance = balance
  93. }
  94. func (n *Node) getChild(sign int8) *Node {
  95. if sign < 0 {
  96. return n.left
  97. }
  98. return n.right
  99. }
  100. func (n *Node) nextOrPrevInOrder(sign int8) *Node {
  101. var next, tmp *Node
  102. if next = n.getChild(+sign); next != nil {
  103. for {
  104. tmp = next.getChild(-sign)
  105. if tmp == nil {
  106. break
  107. }
  108. next = tmp
  109. }
  110. } else {
  111. tmp, next = n, n.parent
  112. for next != nil && tmp == next.getChild(+sign) {
  113. tmp, next = next, next.parent
  114. }
  115. }
  116. return next
  117. }
  118. func (n *Node) setChild(sign int8, child *Node) {
  119. if sign < 0 {
  120. n.left = child
  121. } else {
  122. n.right = child
  123. }
  124. }
  125. func (n *Node) adjustBalanceFactor(amount int8) {
  126. n.balance += amount
  127. }
  128. // Tree represents an AVL tree.
  129. type Tree struct {
  130. root *Node
  131. cmpFn CompareFunc
  132. size int
  133. }
  134. // Len returns the number of elements in the Tree.
  135. func (t *Tree) Len() int {
  136. return t.size
  137. }
  138. // First returns the first node in the Tree (in-order) or nil iff the Tree is
  139. // empty.
  140. func (t *Tree) First() *Node {
  141. return t.firstOrLastInOrder(-1)
  142. }
  143. // Last returns the last element in the Tree (in-order) or nil iff the Tree is
  144. // empty.
  145. func (t *Tree) Last() *Node {
  146. return t.firstOrLastInOrder(1)
  147. }
  148. // Find finds the value in the Tree, and returns the Node or nil iff the value
  149. // is not present.
  150. func (t *Tree) Find(v interface{}) *Node {
  151. if t.cmpFn == nil {
  152. panic(errNoCmpFn)
  153. }
  154. cur := t.root
  155. descendLoop:
  156. for cur != nil {
  157. cmp := t.cmpFn(v, cur.Value)
  158. switch {
  159. case cmp < 0:
  160. cur = cur.left
  161. case cmp > 0:
  162. cur = cur.right
  163. default:
  164. break descendLoop
  165. }
  166. }
  167. return cur
  168. }
  169. // Insert inserts the value into the Tree, and returns the newly created Node
  170. // or the existing Node iff the value is already present in the tree.
  171. func (t *Tree) Insert(v interface{}) *Node {
  172. if t.cmpFn == nil {
  173. panic(errNoCmpFn)
  174. }
  175. var cur *Node
  176. curPtr := &t.root
  177. for *curPtr != nil {
  178. cur = *curPtr
  179. cmp := t.cmpFn(v, cur.Value)
  180. switch {
  181. case cmp < 0:
  182. curPtr = &cur.left
  183. case cmp > 0:
  184. curPtr = &cur.right
  185. default:
  186. return cur
  187. }
  188. }
  189. n := &Node{
  190. Value: v,
  191. parent: cur,
  192. balance: 0,
  193. }
  194. *curPtr = n
  195. t.rebalanceAfterInsert(n)
  196. t.size++
  197. return n
  198. }
// Remove removes the Node from the Tree. The Node's Value is left intact,
// and the Node is marked as removed by making it its own parent.
//
// node is expected to be linked into t; membership in t specifically is not
// verified (NOTE(review): passing a Node from another Tree is undetected).
// Passing a Node that has already been removed panics with errNotInTree.
func (t *Tree) Remove(node *Node) {
	var parent *Node
	var leftDeleted bool

	// A Node that is its own parent was reset by an earlier Remove.
	if node.parent == node {
		panic(errNotInTree)
	}

	t.size--
	if node.left != nil && node.right != nil {
		// Two children: move node's in-order successor into its place.
		// parent is where rebalancing starts, and leftDeleted says which
		// of its sides became shorter.
		parent, leftDeleted = t.swapWithSuccessor(node)
	} else {
		// At most one child: link that child (possibly nil) straight to
		// node's parent.
		child := node.left
		if child == nil {
			child = node.right
		}
		parent = node.parent
		if parent != nil {
			if node == parent.left {
				parent.left = child
				leftDeleted = true
			} else {
				parent.right = child
				leftDeleted = false
			}
			if child != nil {
				child.parent = parent
			}
		} else {
			// node was the root: its child (if any) becomes the new root
			// and there are no ancestors to rebalance.
			if child != nil {
				child.parent = parent
			}
			t.root = child
			node.reset()
			return
		}
	}

	// Walk upward fixing balance factors. handleSubtreeShrink returns nil
	// once a subtree's height is unchanged or the root has been processed.
	for {
		if leftDeleted {
			parent = t.handleSubtreeShrink(parent, +1, &leftDeleted)
		} else {
			parent = t.handleSubtreeShrink(parent, -1, &leftDeleted)
		}
		if parent == nil {
			break
		}
	}

	node.reset()
}
  247. // Iterator returns an iterator that traverses the tree (in-order) in the
  248. // specified direction. Modifying the Tree while iterating is unsupported
  249. // except for removing the current Node.
  250. func (t *Tree) Iterator(direction Direction) *Iterator {
  251. switch direction {
  252. case Forward, Backward:
  253. default:
  254. panic(errInvalidDirection)
  255. }
  256. return &Iterator{
  257. tree: t,
  258. sign: int8(direction),
  259. }
  260. }
  261. // ForEach executes a function for each Node in the tree, visiting the nodes
  262. // in-order in the direction specified. If the provided function returns
  263. // false, the iteration is stopped. Modifying the Tree from within the
  264. // function is unsupprted except for removing the current Node.
  265. func (t *Tree) ForEach(direction Direction, fn func(*Node) bool) {
  266. it := t.Iterator(direction)
  267. for node := it.Get(); node != nil; node = it.Next() {
  268. if !fn(node) {
  269. return
  270. }
  271. }
  272. }
  273. func (t *Tree) firstOrLastInOrder(sign int8) *Node {
  274. first := t.root
  275. if first != nil {
  276. for {
  277. tmp := first.getChild(+sign)
  278. if tmp == nil {
  279. break
  280. }
  281. first = tmp
  282. }
  283. }
  284. return first
  285. }
  286. func (t *Tree) replaceChild(parent, oldChild, newChild *Node) {
  287. if parent != nil {
  288. if oldChild == parent.left {
  289. parent.left = newChild
  290. } else {
  291. parent.right = newChild
  292. }
  293. } else {
  294. t.root = newChild
  295. }
  296. }
  297. func (t *Tree) rotate(a *Node, sign int8) {
  298. b := a.getChild(-sign)
  299. e := b.getChild(+sign)
  300. p := a.parent
  301. a.setChild(-sign, e)
  302. a.parent = b
  303. b.setChild(+sign, a)
  304. b.parent = p
  305. if e != nil {
  306. e.parent = a
  307. }
  308. t.replaceChild(p, a, b)
  309. }
// doDoubleRotate performs a double rotation at a. b is expected to be a's
// child on the -sign side (callers arrange this), and e is b's child on the
// +sign side. e is lifted into a's position with b as its -sign child and a
// as its +sign child. e's former subtrees are handed down: f (its -sign
// child) goes to b and g (its +sign child) goes to a. The new balance
// factors of a and b are derived from e's old balance factor, e ends with a
// balance factor of 0, and e replaces a under a's old parent p (or becomes
// the root). It returns e, the new root of the rotated subtree.
func (t *Tree) doDoubleRotate(b, a *Node, sign int8) *Node {
	e := b.getChild(+sign)
	f := e.getChild(-sign)
	g := e.getChild(+sign)
	p := a.parent
	eBal := e.balance

	// a adopts g in b's place. a's balance is -eBal only when
	// sign*eBal < 0, otherwise 0.
	a.setChild(-sign, g)
	aBal := -eBal
	if sign*eBal >= 0 {
		aBal = 0
	}
	a.setParentBalance(e, aBal)

	// b adopts f in e's place. b's balance is -eBal only when
	// sign*eBal > 0, otherwise 0.
	b.setChild(+sign, f)
	bBal := -eBal
	if sign*eBal <= 0 {
		bBal = 0
	}
	b.setParentBalance(e, bBal)

	// e becomes the subtree root with a and b beneath it.
	e.setChild(+sign, a)
	e.setChild(-sign, b)
	e.setParentBalance(p, 0)

	// Fix the upward links of the subtrees that changed parents.
	if g != nil {
		g.parent = a
	}
	if f != nil {
		f.parent = b
	}
	t.replaceChild(p, a, e)
	return e
}
// handleSubtreeGrowth updates parent after the subtree rooted at node has
// grown one level taller. node is parent's child on the sign side (-1 for
// left, +1 for right). It returns true when rebalancing can stop, and
// false when parent's own subtree grew and the caller must continue with
// parent's parent.
func (t *Tree) handleSubtreeGrowth(node, parent *Node, sign int8) bool {
	oldBalanceFactor := parent.balance
	if oldBalanceFactor == 0 {
		// parent was balanced: it now leans toward sign and is one level
		// taller, so the caller keeps going upward.
		parent.adjustBalanceFactor(sign)
		return false
	}
	newBalanceFactor := oldBalanceFactor + sign
	if newBalanceFactor == 0 {
		// parent leaned the other way: the growth evened it out and its
		// height is unchanged.
		parent.adjustBalanceFactor(sign)
		return true
	}
	// parent would be doubly heavy on the sign side. Its stored balance is
	// still the old ±1 (the ±2 state is never written), so one -sign step
	// below brings it back to 0.
	if sign*node.balance > 0 {
		// node leans the same way as it hangs: a single rotation suffices,
		// after which both parent and node are balanced.
		t.rotate(parent, -sign)
		parent.adjustBalanceFactor(-sign)
		node.adjustBalanceFactor(-sign)
	} else {
		// node leans the opposite way: double rotation.
		t.doDoubleRotate(node, parent, -sign)
	}
	return true
}
  360. func (t *Tree) rebalanceAfterInsert(inserted *Node) {
  361. node, parent := inserted, inserted.parent
  362. switch {
  363. case parent == nil:
  364. return
  365. case node == parent.left:
  366. parent.adjustBalanceFactor(-1)
  367. default:
  368. parent.adjustBalanceFactor(+1)
  369. }
  370. if parent.balance == 0 {
  371. return
  372. }
  373. for done := false; !done; {
  374. node = parent
  375. if parent = node.parent; parent == nil {
  376. return
  377. }
  378. if node == parent.left {
  379. done = t.handleSubtreeGrowth(node, parent, -1)
  380. } else {
  381. done = t.handleSubtreeGrowth(node, parent, +1)
  382. }
  383. }
  384. }
// swapWithSuccessor unlinks x, which must have two children, by moving its
// in-order successor y (the leftmost Node of x's right subtree) into x's
// position. y inherits x's children, parent, and balance factor. x's own
// link fields are not modified here; Remove resets x afterwards.
//
// It returns the Node at which rebalancing must start, and whether the side
// of that Node that became shorter is its left side. That Node is y itself
// when y was x's right child (its right side shrank), otherwise y's former
// parent q (its left side shrank).
func (t *Tree) swapWithSuccessor(x *Node) (*Node, bool) {
	var ret *Node
	var leftDeleted bool
	y := x.right
	if y.left == nil {
		// The successor is x's right child itself; y keeps its own right
		// subtree.
		ret = y
	} else {
		// Descend to the leftmost Node y of x's right subtree, tracking its
		// parent q.
		var q *Node
		for {
			q = y
			if y = y.left; y.left == nil {
				break
			}
		}
		// Detach y: its right subtree (if any) takes its place under q.
		if q.left = y.right; q.left != nil {
			q.left.parent = q
		}
		// y adopts x's right subtree.
		y.right = x.right
		x.right.parent = y
		ret = q
		leftDeleted = true
	}
	// In both cases y adopts x's left subtree, parent, and balance factor,
	// and takes x's place in the tree.
	y.left = x.left
	x.left.parent = y
	y.parent = x.parent
	y.balance = x.balance
	t.replaceChild(x.parent, x, y)
	return ret, leftDeleted
}
// handleSubtreeShrink updates parent after one of its subtrees became one
// level shorter. sign is the direction parent's balance factor moves: +1
// when its left subtree shrank, -1 when its right subtree did. A rotation
// is performed if parent becomes doubly unbalanced.
//
// It returns nil when rebalancing can stop, either because the height at
// this position is unchanged or because the root has been processed.
// Otherwise it returns the next ancestor to process and stores in
// *leftDeleted whether the shortened subtree is that ancestor's left child.
func (t *Tree) handleSubtreeShrink(parent *Node, sign int8, leftDeleted *bool) *Node {
	oldBalanceFactor := parent.balance
	if oldBalanceFactor == 0 {
		// parent was balanced: it now leans toward sign but keeps its
		// height, so rebalancing stops here.
		parent.adjustBalanceFactor(sign)
		return nil
	}
	var node *Node
	newBalanceFactor := oldBalanceFactor + sign
	if newBalanceFactor == 0 {
		// parent leaned toward the side that shrank and is now balanced.
		// Its height dropped by one, so continue above parent.
		parent.adjustBalanceFactor(sign)
		node = parent
	} else {
		// parent is now doubly heavy on the sign side and node is its
		// taller child there. parent's stored balance is still the old ±1
		// (the ±2 state is never written), which the fix-ups below rely on.
		node = parent.getChild(sign)
		if sign*node.balance >= 0 {
			t.rotate(parent, -sign)
			if node.balance == 0 {
				// node was balanced: after the rotation parent keeps its
				// stored balance, node leans by one, and the subtree's
				// height is unchanged, so stop.
				node.adjustBalanceFactor(-sign)
				return nil
			}
			// node leaned toward sign: both end balanced, and the subtree
			// (now rooted at node) is one level shorter.
			parent.adjustBalanceFactor(-sign)
			node.adjustBalanceFactor(-sign)
		} else {
			// node leans away from sign: double rotation, which returns
			// the new subtree root.
			node = t.doDoubleRotate(node, parent, -sign)
		}
	}
	// Continue with node's parent, recording which side of it shrank.
	if parent = node.parent; parent != nil {
		*leftDeleted = node == parent.left
	}
	return parent
}
  444. // New returns an initialized Tree.
  445. func New(cmpFn CompareFunc) *Tree {
  446. if cmpFn == nil {
  447. panic(errNoCmpFn)
  448. }
  449. return &Tree{cmpFn: cmpFn}
  450. }