How Trampoline Works in Scala

Trampoline is a way to make non-tail recursive functions stack-safe. Its Scala implementation is explained by Rúnar Bjarnason in his paper, Stackless Scala With Free Monads, and in Functional Programming in Scala, the book he co-authored with Paul Chiusano. Rúnar’s (old) blog also has a post illustrating the idea. In this post I’d like to apply this technique to a few simple, concrete examples and show, step by step, how it works on them and why it makes them stack-safe.

The name “trampoline” may sound fancy or even intimidating (it did to me when I first learned the concept), but the basic intuition is pretty simple: instead of letting the JVM run a recursive function and push a new frame onto the call stack for every recursive call, we rewrite the function so that we, rather than the JVM, control its execution. During execution we build a structure that is essentially equivalent to the call stack, except that it lives on the heap.

This is best illustrated with examples.

Factorial

The simple, stack-unsafe factorial function is:

def unsafeFac(n: Int): Int = {
  if (n == 0) 1
  else n * unsafeFac(n - 1)
}

(Please ignore integer overflow and negative numbers; they are irrelevant to this post.)

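When we call unsafeFac(5), the JVM pushes one frame per call: unsafeFac(5) waits for the result of unsafeFac(4), which waits for unsafeFac(3), and so on down to unsafeFac(0). At that point six frames are on the stack, each holding an n that still has to be multiplied by the result of the call it is waiting on.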

Since it takes O(n) stack space, calling unsafeFac with a large n will result in StackOverflowError.
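To see the problem concretely, here is a small experiment. It is a sketch that assumes the JVM's default thread stack size: it calls unsafeFac with a small and a very large n and catches the resulting StackOverflowError. The overflows helper is hypothetical, and catching StackOverflowError is only reasonable in a throwaway demo like this one.

```scala
// The naive factorial from the text: each recursive call pushes a new JVM frame.
def unsafeFac(n: Int): Int = {
  if (n == 0) 1
  else n * unsafeFac(n - 1)
}

// Hypothetical helper for this demo: true if evaluating `body` exhausts the call stack.
def overflows(body: => Any): Boolean =
  try { body; false }
  catch { case _: StackOverflowError => true }

val small = unsafeFac(5)                     // six frames deep: fine
val blowsUp = overflows(unsafeFac(1000000))  // a million frames: far beyond a default stack
println(s"unsafeFac(5) = $small; unsafeFac(1000000) overflows: $blowsUp")
```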

To make further discussion easier, let me rewrite the unsafeFac function in a slightly more verbose way:

def unsafeFac(n: Int): Int = {
  if (n == 0) return 1
  else {
    val x = unsafeFac(n - 1)
    return n * x
  }
}

To apply trampoline to this function, we first create a TailRec type (which will also be used by other examples in this post):

sealed trait TailRec[A] {
  def map[B](f: A => B): TailRec[B] = flatMap(f andThen (Return(_)))
  def flatMap[B](f: A => TailRec[B]): TailRec[B] = FlatMap(this, f)
}

final case class Return[A](a: A) extends TailRec[A]
final case class Suspend[A](resume: () => TailRec[A]) extends TailRec[A]
final case class FlatMap[A, B](sub: TailRec[A], k: A => TailRec[B]) extends TailRec[B]

We then rewrite the original recursive function using the TailRec type, in the following manner:

  • If the original function returns an A, the new function should return a TailRec[A].
  • Each return in the original function should be wrapped in a Return.
  • Each recursive call in the original function should be wrapped in a Suspend.
  • Things we do after the recursive call (in this case, multiply the result by n) should be wrapped in a FlatMap.

So our new “trampolined” factorial function is:

def fac(n: Int): TailRec[Int] = {
  if (n == 0) Return(1)
  else FlatMap[Int, Int](Suspend(() => fac(n - 1)), x => Return(n * x))
  // or equivalently:
  // else Suspend(() => fac(n - 1)).flatMap(x => Return(n * x))
}

The key things to note about the trampolined factorial function are:

  1. The returns (which pop stack frames) and recursive calls (which push stack frames) are gone, replaced by our own data types, Return and Suspend. This gives us control over how the new factorial function is executed.
  2. The Suspend class wraps a thunk (a function that takes no parameters). This makes it lazy: when we create a Suspend, the function it wraps is not evaluated. The function is only evaluated when we explicitly call it, which means the recursion only continues when we choose to continue it.
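The laziness of Suspend can be observed directly. The sketch below repeats the TailRec definitions so that it runs on its own as a script, and adds a hypothetical facCalls counter to fac. Constructing fac(100000) enters fac only once. Each further level is reached only when something calls the thunk inside a Suspend.

```scala
// TailRec as defined in the post (map written with an explicit lambda).
sealed trait TailRec[A] {
  def map[B](f: A => B): TailRec[B] = flatMap(a => Return(f(a)))
  def flatMap[B](f: A => TailRec[B]): TailRec[B] = FlatMap(this, f)
}
final case class Return[A](a: A) extends TailRec[A]
final case class Suspend[A](resume: () => TailRec[A]) extends TailRec[A]
final case class FlatMap[A, B](sub: TailRec[A], k: A => TailRec[B]) extends TailRec[B]

// Hypothetical counter, added only to observe how often fac is entered.
var facCalls = 0

def fac(n: Int): TailRec[Int] = {
  facCalls += 1
  if (n == 0) Return(1)
  else Suspend(() => fac(n - 1)).flatMap(x => Return(n * x))
}

val tr = fac(100000)            // builds only the outermost FlatMap(Suspend(...), ...)
val callsAfterBuild = facCalls  // fac(99999) has not been called yet

// Forcing one thunk by hand is what moves the recursion one level further.
tr match {
  case FlatMap(Suspend(resume), _) => resume()
  case _                           => ()
}
val callsAfterOneStep = facCalls
```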

To execute a trampolined function (i.e., to extract the A out of a TailRec[A]), we use the following tail-recursive run function:

1
2
3
4
5
6
7
8
9
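@annotation.tailrec def run[A](tr: TailRec[A]): A = tr match {
  case Return(a) => a                        // finished: extract the value
  case Suspend(r) => run(r())                // resume the suspended computation
  case FlatMap(x, f) => x match {
    case Return(a) => run(f(a))              // feed the value to the continuation
    case Suspend(r) => run(FlatMap(r(), f))  // resume, keeping the continuation
    case FlatMap(y, g) => run(y.flatMap(a => g(a).flatMap(f)))  // re-associate to the right
  }
}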
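Here is a quick check of run on the trampolined fac. It is a self-contained sketch with the definitions repeated. The @annotation.tailrec annotation asks the compiler to verify that run really is tail-recursive, so the loop over Return, Suspend and FlatMap uses constant stack. The large input is only there to show that no StackOverflowError occurs. Its Int result has long since overflowed, which, as noted above, we ignore.

```scala
// TailRec and run as defined in the post; map uses an explicit lambda.
sealed trait TailRec[A] {
  def map[B](f: A => B): TailRec[B] = flatMap(a => Return(f(a)))
  def flatMap[B](f: A => TailRec[B]): TailRec[B] = FlatMap(this, f)
}
final case class Return[A](a: A) extends TailRec[A]
final case class Suspend[A](resume: () => TailRec[A]) extends TailRec[A]
final case class FlatMap[A, B](sub: TailRec[A], k: A => TailRec[B]) extends TailRec[B]

@annotation.tailrec
def run[A](tr: TailRec[A]): A = tr match {
  case Return(a)     => a
  case Suspend(r)    => run(r())
  case FlatMap(x, f) => x match {
    case Return(a)     => run(f(a))
    case Suspend(r)    => run(FlatMap(r(), f))
    case FlatMap(y, g) => run(y.flatMap(a => g(a).flatMap(f)))
  }
}

def fac(n: Int): TailRec[Int] =
  if (n == 0) Return(1)
  else Suspend(() => fac(n - 1)).flatMap(x => Return(n * x))

val small = run(fac(5))
val large = run(fac(100000))  // completes without a StackOverflowError
println(s"run(fac(5)) = $small")
```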

Now let’s see what happens when we call run(fac(5)). As the first step, we need to evaluate fac(5). According to the definition of fac, fac(5) returns

FlatMap(Suspend(() => fac(4)), x => Return(5 * x))

Note how this FlatMap resembles the first frame of the unsafeFac(5) call stack: a pending “multiply by 5” waiting on the result of the recursive call. So now we have

run(FlatMap(Suspend(() => fac(4)), x => Return(5 * x)))

The argument to run is now fully evaluated, so we enter the run function. Note that we do not evaluate fac(4) at this point, because, as I just explained above, fac(4) is wrapped in a thunk in Suspend.

According to line 6 of run, which handles a Suspend nested inside a FlatMap by calling the thunk and keeping the continuation, we now have

run(FlatMap(fac(4), x => Return(5 * x)))

Next we need to go back into the fac function to evaluate fac(4), which gives us

run(FlatMap(FlatMap(Suspend(() => fac(3)), x => Return(4 * x)), x => Return(5 * x)))

which looks very much like the first two frames of the call stack. Then we return to the run function to run this FlatMap, and we now have

run(FlatMap(fac(3), x => FlatMap(Return(4 * x), x => Return(5 * x))))

At this point we go back into the fac function again to evaluate fac(3), and the computation continues in a similar fashion. In the end, we will have built the following structure:

FlatMap(
  Return(1),
  x => FlatMap(
    Return(1 * x), x => FlatMap(
      Return(2 * x), x => FlatMap(
        Return(3 * x), x => FlatMap(
          Return(4 * x), x => Return(5 * x)
        )
      )
    )
  )
)

This FlatMap has the same structure as the call stack, except that it is on the heap. Running this FlatMap will give us the desired result, 120, in a stack-safe way.
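One way to check the “same structure, but on the heap” claim is to measure how deep the JVM stack gets. The sketch below is an experiment, not a benchmark. It uses Thread.currentThread().getStackTrace() inside a hypothetical recordDepth helper, called from both factorials, to record the deepest stack seen. For unsafeFac the depth grows by one frame per extra level of recursion. For run(fac(n)) it stays the same whatever n is, because fac is always called from the single loop frame of run.

```scala
sealed trait TailRec[A] {
  def map[B](f: A => B): TailRec[B] = flatMap(a => Return(f(a)))
  def flatMap[B](f: A => TailRec[B]): TailRec[B] = FlatMap(this, f)
}
final case class Return[A](a: A) extends TailRec[A]
final case class Suspend[A](resume: () => TailRec[A]) extends TailRec[A]
final case class FlatMap[A, B](sub: TailRec[A], k: A => TailRec[B]) extends TailRec[B]

@annotation.tailrec
def run[A](tr: TailRec[A]): A = tr match {
  case Return(a)     => a
  case Suspend(r)    => run(r())
  case FlatMap(x, f) => x match {
    case Return(a)     => run(f(a))
    case Suspend(r)    => run(FlatMap(r(), f))
    case FlatMap(y, g) => run(y.flatMap(a => g(a).flatMap(f)))
  }
}

// Hypothetical instrumentation: the deepest JVM stack observed inside either factorial.
var maxDepth = 0
def recordDepth(): Unit = {
  maxDepth = math.max(maxDepth, Thread.currentThread().getStackTrace().length)
}

def unsafeFac(n: Int): Int = {
  recordDepth()
  if (n == 0) 1 else n * unsafeFac(n - 1)
}

def fac(n: Int): TailRec[Int] = {
  recordDepth()
  if (n == 0) Return(1)
  else Suspend(() => fac(n - 1)).flatMap(x => Return(n * x))
}

// Hypothetical helper: reset the counter, evaluate `body`, report the deepest stack seen.
def deepestDuring(body: => Any): Int = { maxDepth = 0; body; maxDepth }

val unsafeGrowth = deepestDuring(unsafeFac(100)) - deepestDuring(unsafeFac(10)) // grows with n
val safeGrowth   = deepestDuring(run(fac(100))) - deepestDuring(run(fac(10)))   // does not
println(s"extra frames, n = 10 -> 100: unsafeFac $unsafeGrowth, run(fac) $safeGrowth")
```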

It should now be obvious why this technique is called “trampoline”: during the execution of run(fac(5)), we keep jumping back and forth between run and fac.
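The bouncing can be made visible by logging. In this sketch the definitions are repeated, and the trace buffer is a hypothetical addition. fac records each call and run records each iteration of its loop. For run(fac(2)) the log interleaves the two: fac builds one layer, run unwraps it and calls back into fac with a smaller n, and once fac(0) returns, run applies the pending continuations.

```scala
import scala.collection.mutable.ListBuffer

sealed trait TailRec[A] {
  def map[B](f: A => B): TailRec[B] = flatMap(a => Return(f(a)))
  def flatMap[B](f: A => TailRec[B]): TailRec[B] = FlatMap(this, f)
}
final case class Return[A](a: A) extends TailRec[A]
final case class Suspend[A](resume: () => TailRec[A]) extends TailRec[A]
final case class FlatMap[A, B](sub: TailRec[A], k: A => TailRec[B]) extends TailRec[B]

// Hypothetical log shared by run and fac.
val trace = ListBuffer.empty[String]

@annotation.tailrec
def run[A](tr: TailRec[A]): A = {
  trace += "run"  // one entry per iteration of the trampoline loop
  tr match {
    case Return(a)     => a
    case Suspend(r)    => run(r())
    case FlatMap(x, f) => x match {
      case Return(a)     => run(f(a))
      case Suspend(r)    => run(FlatMap(r(), f))
      case FlatMap(y, g) => run(y.flatMap(a => g(a).flatMap(f)))
    }
  }
}

def fac(n: Int): TailRec[Int] = {
  trace += s"fac($n)"
  if (n == 0) Return(1)
  else Suspend(() => fac(n - 1)).flatMap(x => Return(n * x))
}

val result = run(fac(2))
// fac(2) -> run -> fac(1) -> run -> run -> fac(0) -> run -> run -> run
println(trace.mkString(" -> "))
```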

Even and Odd

In this example we use two functions to check whether a number is even or odd:

def unsafeEven(n: Int): Boolean = {
  if (n == 0) true
  else unsafeOdd(n - 1)
}

def unsafeOdd(n: Int): Boolean = {
  if (n == 0) false
  else unsafeEven(n - 1)
}

These two functions are in fact tail recursive, but they are mutually tail recursive: they are defined in terms of each other. Scala cannot optimize mutual tail recursion. The compiler only turns self-recursive tail calls into loops, and the JVM has no general tail-call support. (By contrast, Haskell can, so these two functions are stack-safe in Haskell.) Passing a large n to either function will therefore cause a StackOverflowError.

We can trampoline these two functions in the same way as the factorial function:

def even(n: Int): TailRec[Boolean] = {
  if (n == 0) Return(true)
  else Suspend(() => odd(n - 1))
}

def odd(n: Int): TailRec[Boolean] = {
  if (n == 0) Return(false)
  else Suspend(() => even(n - 1))
}

We do not need FlatMap in this case because nothing remains to be done after the tail call. And because no FlatMap is involved, run does not need to build a structure on the heap as it does for the factorial function: each Suspend becomes garbage as soon as its thunk has been called. Running even and odd therefore takes O(1) stack and O(1) live heap.
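The following sketch exercises the trampolined even and odd with inputs far larger than a default JVM stack could handle with the original mutually recursive versions. The definitions are repeated so the snippet runs on its own. Only the Return and Suspend cases of run are ever reached here.

```scala
sealed trait TailRec[A] {
  def map[B](f: A => B): TailRec[B] = flatMap(a => Return(f(a)))
  def flatMap[B](f: A => TailRec[B]): TailRec[B] = FlatMap(this, f)
}
final case class Return[A](a: A) extends TailRec[A]
final case class Suspend[A](resume: () => TailRec[A]) extends TailRec[A]
final case class FlatMap[A, B](sub: TailRec[A], k: A => TailRec[B]) extends TailRec[B]

@annotation.tailrec
def run[A](tr: TailRec[A]): A = tr match {
  case Return(a)     => a
  case Suspend(r)    => run(r())
  case FlatMap(x, f) => x match {
    case Return(a)     => run(f(a))
    case Suspend(r)    => run(FlatMap(r(), f))
    case FlatMap(y, g) => run(y.flatMap(a => g(a).flatMap(f)))
  }
}

// Mutually recursive, but each "call" is just a Suspend that run unwraps in its loop.
def even(n: Int): TailRec[Boolean] =
  if (n == 0) Return(true) else Suspend(() => odd(n - 1))

def odd(n: Int): TailRec[Boolean] =
  if (n == 0) Return(false) else Suspend(() => even(n - 1))

val bigEven   = run(even(1000000))  // a million bounces, constant stack
val bigOdd    = run(odd(1000001))
val smallEven = run(even(7))
println(s"even(1000000) = $bigEven, odd(1000001) = $bigOdd, even(7) = $smallEven")
```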

Fibonacci

Now let’s trampoline the following Fibonacci function:

def unsafeFib(n: Int): Int = {
  if (n <= 1) n
  else unsafeFib(n - 2) + unsafeFib(n - 1)
}

What makes the Fibonacci function slightly more interesting is that it makes two recursive calls in its body. This doesn’t really add any new challenge, though; we trampoline it the same way as before:

def fib(n: Int): TailRec[Int] = {
  if (n <= 1) Return(n)
  else Suspend(() => fib(n - 2)).flatMap(x => Suspend(() => fib(n - 1)).flatMap(y => Return(x + y)))
}

Whatever happens after the first recursive call (unsafeFib(n - 2)) is wrapped in the first flatMap, and whatever happens after the second recursive call (unsafeFib(n - 1)) is wrapped in the second flatMap.

Since flatMap(y => Return(x + y)) is the same as map(y => x + y), we can simplify the above implementation a little:

def fib(n: Int): TailRec[Int] = {
  if (n <= 1) Return(n)
  else Suspend(() => fib(n - 2)).flatMap(x => Suspend(() => fib(n - 1)).map(x + _))
}

We can also turn a sequence of flatMaps followed by a map into Scala’s for-comprehension:

def fib(n: Int): TailRec[Int] = {
  if (n <= 1) Return(n)
  else for {
    x <- Suspend(() => fib(n - 2))
    y <- Suspend(() => fib(n - 1))
  } yield x + y
}
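As a sanity check of the for-comprehension version, this sketch compares it with unsafeFib for small inputs. Both are exponential-time, so a large n would be slow rather than unsafe. The definitions are repeated so the snippet is self-contained.

```scala
sealed trait TailRec[A] {
  def map[B](f: A => B): TailRec[B] = flatMap(a => Return(f(a)))
  def flatMap[B](f: A => TailRec[B]): TailRec[B] = FlatMap(this, f)
}
final case class Return[A](a: A) extends TailRec[A]
final case class Suspend[A](resume: () => TailRec[A]) extends TailRec[A]
final case class FlatMap[A, B](sub: TailRec[A], k: A => TailRec[B]) extends TailRec[B]

@annotation.tailrec
def run[A](tr: TailRec[A]): A = tr match {
  case Return(a)     => a
  case Suspend(r)    => run(r())
  case FlatMap(x, f) => x match {
    case Return(a)     => run(f(a))
    case Suspend(r)    => run(FlatMap(r(), f))
    case FlatMap(y, g) => run(y.flatMap(a => g(a).flatMap(f)))
  }
}

def unsafeFib(n: Int): Int =
  if (n <= 1) n else unsafeFib(n - 2) + unsafeFib(n - 1)

// Desugars to Suspend(...).flatMap(x => Suspend(...).map(y => x + y)).
def fib(n: Int): TailRec[Int] =
  if (n <= 1) Return(n)
  else for {
    x <- Suspend(() => fib(n - 2))
    y <- Suspend(() => fib(n - 1))
  } yield x + y

val fib20  = run(fib(20))
val agrees = (0 to 20).forall(n => run(fib(n)) == unsafeFib(n))
println(s"fib(20) = $fib20; matches unsafeFib on 0..20: $agrees")
```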

Again, running the fib function builds a structure on the heap similar to the call stack of the unsafeFib function.

So having two recursive calls is not so difficult. What about an arbitrary number of recursive calls?

Map over Tree

In the final example we shall play with a recursive data type representing an n-ary tree:

sealed trait Tree[A] {
  def label: A
}
final case class Leaf[A](label: A) extends Tree[A]
final case class Node[A](label: A, children: List[Tree[A]]) extends Tree[A]

The following is a recursive function that maps a function over the labels of a tree’s nodes:

def unsafeTreeMap[A, B](tree: Tree[A], f: A => B): Tree[B] = tree match {
  case Leaf(a) => Leaf(f(a))
  case Node(a, children) => Node(f(a), children.map(unsafeTreeMap(_, f)))
}
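unsafeTreeMap works fine on shallow trees but uses at least one JVM frame per level of depth. The sketch below assumes a default JVM stack size. It builds degenerate trees with a hypothetical chain helper, which creates a path of single-child nodes using a loop. It then shows that mapping over a deep one overflows.

```scala
sealed trait Tree[A] { def label: A }
final case class Leaf[A](label: A) extends Tree[A]
final case class Node[A](label: A, children: List[Tree[A]]) extends Tree[A]

// The naive map from the text: recursion depth equals tree depth.
def unsafeTreeMap[A, B](tree: Tree[A], f: A => B): Tree[B] = tree match {
  case Leaf(a)           => Leaf(f(a))
  case Node(a, children) => Node(f(a), children.map(unsafeTreeMap(_, f)))
}

// Hypothetical helper: a path of `depth` single-child nodes ending in a Leaf,
// built with a loop so that constructing it needs no recursion.
def chain(depth: Int): Tree[Int] = {
  var t: Tree[Int] = Leaf(0)
  for (i <- 1 to depth) t = Node(i, List(t))
  t
}

val shallow = unsafeTreeMap(chain(3), (x: Int) => x * 10)
val deepOverflows =
  try { unsafeTreeMap(chain(100000), (x: Int) => x * 10); false }
  catch { case _: StackOverflowError => true }
println(s"shallow = $shallow; deep overflows: $deepOverflows")
```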

Using the same approach as before (wrapping returns in Return and recursive calls in Suspend), this is what we get:

def treeMap[A, B](tree: Tree[A], f: A => B): TailRec[Tree[B]] = tree match {
  case Leaf(a) => Return(Leaf(f(a)))
  case Node(a, children) =>
    val ltt: List[TailRec[Tree[B]]] = children.map(child => Suspend(() => treeMap(child, f)))
    ???
}

The Leaf case is trivial, but in the Node case, what should we do after getting a List[TailRec[Tree[B]]]? It’s not quite obvious. It would be much better if we had a TailRec[List[Tree[B]]] instead; then we could simply proceed with flatMap to get what we want:

def treeMap[A, B](tree: Tree[A], f: A => B): TailRec[Tree[B]] = tree match {
  case Leaf(a) => Return(Leaf(f(a)))
  case Node(a, children) =>
    val ltt: List[TailRec[Tree[B]]] = children.map(child => Suspend(() => treeMap(child, f)))
    val tlt: TailRec[List[Tree[B]]] = ???
    tlt.flatMap(lt => Return(Node(f(a), lt)))
    // or equivalently: tlt.map(Node(f(a), _))
}

It turns out that converting a List[TailRec[Tree[B]]] to a TailRec[List[Tree[B]]] is a standard operation in functional programming, known as sequence. Scala has a sequence method for Future, but not for List. Multiple open-source libraries provide the sequence method for List as part of the Traverse type class, including Scalaz, Cats and Structures. Here let’s just implement our own version:

def sequence[A](ltt: List[TailRec[A]]): TailRec[List[A]] =
  ltt.reverse.foldLeft(Return(Nil): TailRec[List[A]]) { (tla, ta) =>
    ta map ((_: A) :: (_: List[A])).curried flatMap tla.map
  }
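The point-free fold above can be hard to read. Below is a sketch of a more explicit version that produces the same result. It is our own rewrite, not taken from a library. Walking the list from the right, each step runs one TailRec and prepends its value to the list produced by the rest. The sketch also shows that sequence itself copes with a long list.

```scala
sealed trait TailRec[A] {
  def map[B](f: A => B): TailRec[B] = flatMap(a => Return(f(a)))
  def flatMap[B](f: A => TailRec[B]): TailRec[B] = FlatMap(this, f)
}
final case class Return[A](a: A) extends TailRec[A]
final case class Suspend[A](resume: () => TailRec[A]) extends TailRec[A]
final case class FlatMap[A, B](sub: TailRec[A], k: A => TailRec[B]) extends TailRec[B]

@annotation.tailrec
def run[A](tr: TailRec[A]): A = tr match {
  case Return(a)     => a
  case Suspend(r)    => run(r())
  case FlatMap(x, f) => x match {
    case Return(a)     => run(f(a))
    case Suspend(r)    => run(FlatMap(r(), f))
    case FlatMap(y, g) => run(y.flatMap(a => g(a).flatMap(f)))
  }
}

// Explicit version of the fold: run `ta`, then the rest (`tla`), then prepend.
def sequence[A](ltt: List[TailRec[A]]): TailRec[List[A]] =
  ltt.reverse.foldLeft(Return(Nil): TailRec[List[A]]) { (tla, ta) =>
    ta.flatMap(a => tla.map(as => a :: as))
  }

val mixed = run(sequence(List[TailRec[Int]](Return(1), Suspend(() => Return(2)), Return(3))))
val many  = run(sequence(List.fill(100000)(Return(1): TailRec[Int])))
println(s"mixed = $mixed; many has ${many.length} elements")
```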

Now we can complete our implementation of the trampolined treeMap:

def treeMap[A, B](tree: Tree[A], f: A => B): TailRec[Tree[B]] = tree match {
  case Leaf(a) => Return(Leaf(f(a)))
  case Node(a, children) =>
    val ltt: List[TailRec[Tree[B]]] = children.map(child => Suspend(() => treeMap(child, f)))
    val tlt: TailRec[List[Tree[B]]] = sequence(ltt)
    tlt.map(Node(f(a), _))
}
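Putting the pieces together, this sketch runs treeMap on a small tree. It also runs treeMap on the same kind of deep, single-child chain that makes the naive version overflow. It uses the explicit form of sequence and two hypothetical helpers: chain builds the deep tree with a loop, and levels walks the result with a loop. The loop in levels is needed because the generated equals and toString of the case classes are themselves recursive and would overflow on such a deep tree.

```scala
sealed trait TailRec[A] {
  def map[B](f: A => B): TailRec[B] = flatMap(a => Return(f(a)))
  def flatMap[B](f: A => TailRec[B]): TailRec[B] = FlatMap(this, f)
}
final case class Return[A](a: A) extends TailRec[A]
final case class Suspend[A](resume: () => TailRec[A]) extends TailRec[A]
final case class FlatMap[A, B](sub: TailRec[A], k: A => TailRec[B]) extends TailRec[B]

@annotation.tailrec
def run[A](tr: TailRec[A]): A = tr match {
  case Return(a)     => a
  case Suspend(r)    => run(r())
  case FlatMap(x, f) => x match {
    case Return(a)     => run(f(a))
    case Suspend(r)    => run(FlatMap(r(), f))
    case FlatMap(y, g) => run(y.flatMap(a => g(a).flatMap(f)))
  }
}

sealed trait Tree[A] { def label: A }
final case class Leaf[A](label: A) extends Tree[A]
final case class Node[A](label: A, children: List[Tree[A]]) extends Tree[A]

def sequence[A](ltt: List[TailRec[A]]): TailRec[List[A]] =
  ltt.reverse.foldLeft(Return(Nil): TailRec[List[A]]) { (tla, ta) =>
    ta.flatMap(a => tla.map(as => a :: as))
  }

def treeMap[A, B](tree: Tree[A], f: A => B): TailRec[Tree[B]] = tree match {
  case Leaf(a) => Return(Leaf(f(a)))
  case Node(a, children) =>
    val ltt: List[TailRec[Tree[B]]] = children.map(child => Suspend(() => treeMap(child, f)))
    sequence(ltt).map[Tree[B]](kids => Node(f(a), kids))
}

// Hypothetical helper: a path of `depth` single-child nodes ending in a Leaf, built with a loop.
def chain(depth: Int): Tree[Int] = {
  var t: Tree[Int] = Leaf(0)
  for (i <- 1 to depth) t = Node(i, List(t))
  t
}

// Hypothetical helper: counts levels along the first child without recursion on the stack.
@annotation.tailrec
def levels[A](tree: Tree[A], acc: Int = 1): Int = tree match {
  case Node(_, child :: _) => levels(child, acc + 1)
  case _                   => acc
}

val small = run(treeMap(Node(1, List(Leaf(2), Node(3, List(Leaf(4))))), (x: Int) => x * 10))
val deep  = run(treeMap(chain(100000), (x: Int) => x + 1))
println(s"small = $small; deep has ${levels(deep)} levels, root label ${deep.label}")
```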

A Final Word on Laziness and Stack-Safety

The key reason trampolined functions are stack-safe is that Suspend is lazy; in other words, the recursion happens inside a lazy structure. Generally speaking, lazy recursions tend to be stack-safe, even if they are not tail recursive. For example, the following function:

def func(x: Int): Stream[Int] = x #:: func(x + 1)

is stack-safe, even though it is not tail recursive, because the tail of a Stream is lazy. Calling

println(func(1).take(n).toList)

with a large n will not cause StackOverflowError. The execution simply trampolines between func and take.
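Here is a runnable version of this example. It assumes Scala 2.13 or later, where Stream is deprecated in favour of LazyList. LazyList's #:: also takes its tail lazily. The value n = 100000 is chosen purely for illustration.

```scala
// #:: takes its tail by name, so func(x + 1) runs only when that tail is first needed.
def func(x: Int): LazyList[Int] = x #:: func(x + 1)

val n = 100000
val xs = func(1).take(n).toList  // forces the elements one at a time, without deep recursion
println(s"${xs.head} .. ${xs.last} (${xs.length} elements)")
```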