Trampoline is a way to make non-tail recursive functions stack-safe. Its Scala implementation is explained by Rúnar Bjarnason in his paper, Stackless Scala With Free Monads, and his book, Functional Programming in Scala. Rúnar’s (old) blog also has a post illustrating the idea. In this post I’d like to apply this technique to a few simple, concrete examples, and show step by step how it works on them and why it makes them stack-safe.

The name “trampoline” may sound fancy or even intimidating (to me anyway, when I first learned this concept), but the basic intuition is pretty simple: instead of letting the JVM run a recursive function and push a new frame to the call stack each time the recursion is performed, we rewrite the recursive function in a way that we, rather than the JVM, have control of its execution. During the execution, we will build a structure which is essentially equivalent to the call stack, except that it is built on the heap.

This is best illustrated with examples.

# Factorial

The simple, stack-unsafe factorial function is:

```scala
def unsafeFac(n: Int): Int = {
  if (n == 0) 1
  else n * unsafeFac(n - 1)
}
```

(Please ignore integer overflow and negative numbers; they are irrelevant to this post.)

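When we call `unsafeFac(5)`, the call stack grows to six frames, one for each call from `unsafeFac(5)` down to `unsafeFac(0)`. Every frame except the last is waiting for the call below it to return so that it can multiply the result by its own `n`.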

Since it takes O(n) stack space, calling `unsafeFac` with a large `n` will result in a `StackOverflowError`.

To make further discussion easier, let me rewrite the `unsafeFac` function in a slightly more verbose way:

```scala
def unsafeFac(n: Int): Int = {
  if (n == 0) return 1
  else {
    val x = unsafeFac(n - 1)
    return n * x
  }
}
```

To apply trampoline to this function, we first create a `TailRec` type (which will also be used by other examples in this post):

```scala
sealed trait TailRec[A] {
  def map[B](f: A => B): TailRec[B] = flatMap(f andThen (Return(_)))
  def flatMap[B](f: A => TailRec[B]): TailRec[B] = FlatMap(this, f)
}
final case class Return[A](a: A) extends TailRec[A]
final case class Suspend[A](resume: () => TailRec[A]) extends TailRec[A]
final case class FlatMap[A, B](sub: TailRec[A], k: A => TailRec[B]) extends TailRec[B]
```

We then rewrite the original recursive function using the `TailRec` type, in the following manner:

- If the original function returns an `A`, the new function should return a `TailRec[A]`.
- Each `return` in the original function should be wrapped in a `Return`.
- Each recursive call in the original function should be wrapped in a `Suspend`.
- Things we do after the recursive call (in this case, multiply the result by `n`) should be wrapped in a `FlatMap`.

So our new “trampolined” factorial function is:

```scala
def fac(n: Int): TailRec[Int] = {
  if (n == 0) Return(1)
  else FlatMap[Int, Int](Suspend(() => fac(n - 1)), x => Return(n * x))
  // or equivalently:
  // else Suspend(() => fac(n - 1)).flatMap(x => Return(n * x))
}
```

The key things to note about the trampolined factorial function are:

- The `return`s (which pop stack frames) and recursive calls (which push stack frames) are gone, replaced by our own data types, `Return` and `Suspend`. This gives us control of how the new factorial function is executed.
- The `Suspend` class wraps a *thunk* (a function that takes no parameters). This makes it lazy: when we create a `Suspend`, the function it wraps is not evaluated. The function is only evaluated when we explicitly run it, which means we will only continue the recursion when we wish to do so.
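The laziness of `Suspend` can be checked directly. The following is a minimal sketch rather than part of the original code: the `LazinessDemo` object and the `evaluated` flag are names made up for illustration, and the `TailRec` type is cut down to the two cases needed here.

```scala
object LazinessDemo {
  // Cut-down copy of the TailRec type: only the two cases this demo needs.
  sealed trait TailRec[A]
  final case class Return[A](a: A) extends TailRec[A]
  final case class Suspend[A](resume: () => TailRec[A]) extends TailRec[A]

  // Hypothetical flag that records whether the thunk body has run.
  var evaluated = false

  // Building the Suspend only stores the function; its body does not run yet.
  val suspended: Suspend[Int] = Suspend { () =>
    evaluated = true
    Return(42)
  }
}
```

Reading `LazinessDemo.evaluated` right after construction gives `false`. It only becomes `true` once `LazinessDemo.suspended.resume()` is called, which is the step an interpreter takes when it decides to continue the recursion.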

To execute a trampolined function (i.e., to extract the `A` out of a `TailRec[A]`), we use the following tail-recursive `run` function:

```scala
def run[A](tr: TailRec[A]): A = tr match {
  case Return(a) => a
  case Suspend(r) => run(r())
  case FlatMap(x, f) => x match {
    case Return(a) => run(f(a))
    case Suspend(r) => run(FlatMap(r(), f))
    case FlatMap(y, g) => run(y.flatMap(g(_) flatMap f))
  }
}
```

Now let’s see what happens when we call `run(fac(5))`. As the first step, we need to evaluate `fac(5)`. According to the definition of `fac`, `fac(5)` returns

```
FlatMap(Suspend(() => fac(4)), x => Return(5 * x))
```

Note how this `FlatMap` resembles the first frame of the call stack for `unsafeFac(5)`. So now we have

```
run(FlatMap(Suspend(() => fac(4)), x => Return(5 * x)))
```

The argument to `run` is now fully evaluated, so we enter the `run` function. Note that we do *not* evaluate `fac(4)` at this point, because, as explained above, `fac(4)` is wrapped in a thunk in `Suspend`.

According to the definition of `run` (in particular, line 6, the inner `case Suspend(r)` branch, which calls the thunk), we now have

```
run(FlatMap(fac(4), x => Return(5 * x)))
```

Next we need to go back into the `fac` function to evaluate `fac(4)`, which gives us

```
run(FlatMap(FlatMap(Suspend(() => fac(3)), x => Return(4 * x)), x => Return(5 * x)))
```

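which looks very much like the first two frames of the call stack. Then we return to the `run` function to run this `FlatMap`. It first reassociates the nested `FlatMap`s (line 7 of `run`) and then resumes the `Suspend` (line 6), so we now have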

```
run(FlatMap(fac(3), x => FlatMap(Return(4 * x), x => Return(5 * x))))
```

At this point we go back into the `fac` function again to evaluate `fac(3)`, and the computation continues in a similar fashion. In the end, we will have built the following structure:

```
FlatMap(
  Return(1),
  x => FlatMap(
    Return(1 * x), x => FlatMap(
      Return(2 * x), x => FlatMap(
        Return(3 * x), x => FlatMap(
          Return(4 * x), x => Return(5 * x)
        )
      )
    )
  )
)

This `FlatMap` has the same structure as the call stack, except that it is on the heap. Running this `FlatMap` will give us the desired result, 120, in a stack-safe way.
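To try the walkthrough end to end, the pieces can be put together in one place. This is a sketch under a few assumptions: the `FacDemo` wrapper object is a made-up name, `map` and the nested-`FlatMap` case are written with explicit lambdas instead of `andThen` and placeholder syntax, and the `@annotation.tailrec` annotation is an addition that asks the compiler to check that `run` really is tail recursive.

```scala
object FacDemo {
  sealed trait TailRec[A] {
    def map[B](f: A => B): TailRec[B] = flatMap(a => Return(f(a)))
    def flatMap[B](f: A => TailRec[B]): TailRec[B] = FlatMap(this, f)
  }
  final case class Return[A](a: A) extends TailRec[A]
  final case class Suspend[A](resume: () => TailRec[A]) extends TailRec[A]
  final case class FlatMap[A, B](sub: TailRec[A], k: A => TailRec[B]) extends TailRec[B]

  // The interpreter from the post; every recursive call is in tail position,
  // so it compiles to a loop and uses constant stack space.
  @annotation.tailrec
  def run[A](tr: TailRec[A]): A = tr match {
    case Return(a) => a
    case Suspend(r) => run(r())
    case FlatMap(x, f) => x match {
      case Return(a) => run(f(a))
      case Suspend(r) => run(FlatMap(r(), f))
      // Reassociate to the right so the next step finds a non-FlatMap at the head.
      case FlatMap(y, g) => run(y.flatMap(a => g(a).flatMap(f)))
    }
  }

  // Trampolined factorial: returns a description of the computation immediately.
  def fac(n: Int): TailRec[Int] =
    if (n == 0) Return(1)
    else Suspend(() => fac(n - 1)).flatMap(x => Return(n * x))
}
```

With these definitions, `FacDemo.run(FacDemo.fac(5))` evaluates to 120. A call such as `FacDemo.run(FacDemo.fac(100000))` also completes without a `StackOverflowError`; its numeric value is meaningless because of `Int` overflow, which this post ignores.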

It should now be obvious why this technique is called “trampoline”: during the execution of `run(fac(5))`, we keep jumping back and forth between `run` and `fac`.

# Even and Odd

In this example we use two functions to check whether a number is even or odd:

```scala
def unsafeEven(n: Int): Boolean = {
  if (n == 0) true
  else unsafeOdd(n - 1)
}
def unsafeOdd(n: Int): Boolean = {
  if (n == 0) false
  else unsafeEven(n - 1)
}
```

These two functions are in fact tail recursive, but they are *mutually tail recursive*: they are defined in terms of each other. Scala cannot optimize mutual tail recursion due to limitations of the JVM (by contrast, Haskell can, so these two functions are stack-safe in Haskell), so passing a large `n` to either function will cause a `StackOverflowError`.

We can trampoline these two functions in the same way as the factorial function:

```scala
def even(n: Int): TailRec[Boolean] = {
  if (n == 0) Return(true)
  else Suspend(() => odd(n - 1))
}
def odd(n: Int): TailRec[Boolean] = {
  if (n == 0) Return(false)
  else Suspend(() => even(n - 1))
}
```

We do not need `FlatMap` in this case because no further step is needed after the tail call. And because no `FlatMap` is involved, running these functions doesn’t need to build a structure on the heap like the factorial function does. Running `even` and `odd` takes O(1) stack and O(1) heap.
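Because `even` and `odd` only ever produce `Return` and `Suspend`, a two-case interpreter is enough to run them. The sketch below makes that explicit; the `EvenOddDemo` object and its reduced `TailRec` (without `FlatMap`, `map` or `flatMap`) are simplifications assumed for this illustration, not the definitions used elsewhere in the post.

```scala
object EvenOddDemo {
  // Reduced TailRec: these two functions never build a FlatMap.
  sealed trait TailRec[A]
  final case class Return[A](a: A) extends TailRec[A]
  final case class Suspend[A](resume: () => TailRec[A]) extends TailRec[A]

  // Each step either finishes or resumes one thunk; only one Suspend is
  // reachable at a time, so earlier ones can be garbage collected.
  @annotation.tailrec
  def run[A](tr: TailRec[A]): A = tr match {
    case Return(a)  => a
    case Suspend(r) => run(r())
  }

  def even(n: Int): TailRec[Boolean] =
    if (n == 0) Return(true) else Suspend(() => odd(n - 1))

  def odd(n: Int): TailRec[Boolean] =
    if (n == 0) Return(false) else Suspend(() => even(n - 1))
}
```

For instance, `EvenOddDemo.run(EvenOddDemo.even(1000000))` returns `true` after a million bounces between `even`, `odd` and `run`, none of which deepens the stack.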

# Fibonacci

Now let’s trampoline the following Fibonacci function:

```scala
def unsafeFib(n: Int): Int = {
  if (n <= 1) n
  else unsafeFib(n - 2) + unsafeFib(n - 1)
}
```

What makes the Fibonacci function slightly more interesting is that it makes two recursive calls in its body, although this doesn’t really add any new challenge - the way we trampoline the Fibonacci function is the same as before:

```scala
def fib(n: Int): TailRec[Int] = {
  if (n <= 1) Return(n)
  else Suspend(() => fib(n - 2)).flatMap(x => Suspend(() => fib(n - 1)).flatMap(y => Return(x + y)))
}
```

Whatever happens after the first recursive call (`unsafeFib(n - 2)`) is wrapped in the first `flatMap`, and whatever happens after the second recursive call (`unsafeFib(n - 1)`) is wrapped in the second `flatMap`.

Since `flatMap(y => Return(x + y))` is the same as `map(y => x + y)`, we can simplify the above implementation a little bit:

```scala
def fib(n: Int): TailRec[Int] = {
  if (n <= 1) Return(n)
  else Suspend(() => fib(n - 2)).flatMap(x => Suspend(() => fib(n - 1)).map(x + _))
}
```

We can also turn a sequence of `flatMap`s followed by a `map` into Scala’s for-comprehension:

```scala
def fib(n: Int): TailRec[Int] = {
  if (n <= 1) Return(n)
  else for {
    x <- Suspend(() => fib(n - 2))
    y <- Suspend(() => fib(n - 1))
  } yield x + y
}
```

Again, running the `fib` function will build a structure on the heap similar to the call stack for the `unsafeFib` function.
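As a quick sanity check that the explicit `flatMap` form and the for-comprehension describe the same computation, both can be run side by side. This is only a sketch: the `FibDemo` object and the names `fibFlatMap` and `fibFor` are invented here to let the two variants coexist, and `TailRec` and `run` are repeated so that the block stands alone.

```scala
object FibDemo {
  sealed trait TailRec[A] {
    def map[B](f: A => B): TailRec[B] = flatMap(a => Return(f(a)))
    def flatMap[B](f: A => TailRec[B]): TailRec[B] = FlatMap(this, f)
  }
  final case class Return[A](a: A) extends TailRec[A]
  final case class Suspend[A](resume: () => TailRec[A]) extends TailRec[A]
  final case class FlatMap[A, B](sub: TailRec[A], k: A => TailRec[B]) extends TailRec[B]

  @annotation.tailrec
  def run[A](tr: TailRec[A]): A = tr match {
    case Return(a) => a
    case Suspend(r) => run(r())
    case FlatMap(x, f) => x match {
      case Return(a) => run(f(a))
      case Suspend(r) => run(FlatMap(r(), f))
      case FlatMap(y, g) => run(y.flatMap(a => g(a).flatMap(f)))
    }
  }

  // Variant 1: explicit flatMap calls.
  def fibFlatMap(n: Int): TailRec[Int] =
    if (n <= 1) Return(n)
    else Suspend(() => fibFlatMap(n - 2)).flatMap { x =>
      Suspend(() => fibFlatMap(n - 1)).flatMap(y => Return(x + y))
    }

  // Variant 2: the for-comprehension, which desugars to flatMap followed by map.
  def fibFor(n: Int): TailRec[Int] =
    if (n <= 1) Return(n)
    else for {
      x <- Suspend(() => fibFor(n - 2))
      y <- Suspend(() => fibFor(n - 1))
    } yield x + y
}
```

Both `FibDemo.run(FibDemo.fibFlatMap(10))` and `FibDemo.run(FibDemo.fibFor(10))` give 55. Trampolining removes the stack depth problem only; the number of recursive calls is still exponential in `n`.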

So having two recursive calls is not so difficult. What about an arbitrary number of recursive calls?

# Map over Tree

In the final example we shall play with a recursive data type representing an n-ary tree:

```scala
sealed trait Tree[A] {
  def label: A
}
final case class Leaf[A](label: A) extends Tree[A]
final case class Node[A](label: A, children: List[Tree[A]]) extends Tree[A]
```

The following is a recursive function that maps on the labels of tree nodes:

```scala
def unsafeTreeMap[A, B](tree: Tree[A], f: A => B): Tree[B] = tree match {
  case Leaf(a) => Leaf(f(a))
  case Node(a, children) => Node(f(a), children.map(unsafeTreeMap(_, f)))
}
```

Using the same approach as before (wrapping `return`s in `Return`s and recursive calls in `Suspend`s), this is what we get:

```scala
def treeMap[A, B](tree: Tree[A], f: A => B): TailRec[Tree[B]] = tree match {
  case Leaf(a) => Return(Leaf(f(a)))
  case Node(a, children) =>
    val ltt: List[TailRec[Tree[B]]] = children.map(child => Suspend(() => treeMap(child, f)))
    ???
}
```

The `Leaf` case is trivial, but in the `Node` case, what should we do after getting a `List[TailRec[Tree[B]]]`? It’s not quite obvious. It would be much better if we had a `TailRec[List[Tree[B]]]` instead; if that’s the case we can simply proceed with `flatMap` to get what we want:

```scala
def treeMap[A, B](tree: Tree[A], f: A => B): TailRec[Tree[B]] = tree match {
  case Leaf(a) => Return(Leaf(f(a)))
  case Node(a, children) =>
    val ltt: List[TailRec[Tree[B]]] = children.map(child => Suspend(() => treeMap(child, f)))
    val tlt: TailRec[List[Tree[B]]] = ???
    tlt.flatMap(lt => Return(Node(f(a), lt)))
    // or equivalently: tlt.map(Node(f(a), _))
}
```

It turns out that converting a `List[TailRec[Tree[B]]]` to a `TailRec[List[Tree[B]]]` is a standard operation in functional programming, known as `sequence`. Scala has a `sequence` method for `Future`, but not for `List`. Multiple open-source libraries provide the `sequence` method for `List` as part of the `Traverse` type class, including Scalaz, Cats and Structures. Here let’s just implement our own version:

```scala
def sequence[A](ltt: List[TailRec[A]]): TailRec[List[A]] =
  ltt.reverse.foldLeft(Return(Nil): TailRec[List[A]]) { (tla, ta) =>
    // Gives the same result as: ta.flatMap(a => tla.map(a :: _))
    ta map ((_: A) :: (_: List[A])).curried flatMap tla.map
  }
```
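The point-free body of `sequence` is compact but not easy to read. Here is a sketch of a more explicit spelling that should produce the same result, together with the supporting definitions needed to run it. The `SequenceDemo` wrapper is a hypothetical name, and the explicit lambda form is a rewriting for illustration rather than the code used in this post.

```scala
object SequenceDemo {
  sealed trait TailRec[A] {
    def map[B](f: A => B): TailRec[B] = flatMap(a => Return(f(a)))
    def flatMap[B](f: A => TailRec[B]): TailRec[B] = FlatMap(this, f)
  }
  final case class Return[A](a: A) extends TailRec[A]
  final case class Suspend[A](resume: () => TailRec[A]) extends TailRec[A]
  final case class FlatMap[A, B](sub: TailRec[A], k: A => TailRec[B]) extends TailRec[B]

  @annotation.tailrec
  def run[A](tr: TailRec[A]): A = tr match {
    case Return(a) => a
    case Suspend(r) => run(r())
    case FlatMap(x, f) => x match {
      case Return(a) => run(f(a))
      case Suspend(r) => run(FlatMap(r(), f))
      case FlatMap(y, g) => run(y.flatMap(a => g(a).flatMap(f)))
    }
  }

  // Fold from the right-hand end of the list: for each element, first run the
  // element itself, then the already-sequenced rest, and prepend the result.
  def sequence[A](ltt: List[TailRec[A]]): TailRec[List[A]] =
    ltt.reverse.foldLeft(Return(Nil): TailRec[List[A]]) { (tla, ta) =>
      ta.flatMap(a => tla.map(rest => a :: rest))
    }
}
```

Running `sequence` on a list holding `Return(1)`, `Suspend(() => Return(2))` and `Return(3)` yields `List(1, 2, 3)`, so the order of the input list is preserved. Note that `sequence` itself evaluates nothing: it only builds a bigger `TailRec`, and the work happens inside `run`.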

Now we can complete our implementation of the trampolined `treeMap`:

```scala
def treeMap[A, B](tree: Tree[A], f: A => B): TailRec[Tree[B]] = tree match {
  case Leaf(a) => Return(Leaf(f(a)))
  case Node(a, children) =>
    val ltt: List[TailRec[Tree[B]]] = children.map(child => Suspend(() => treeMap(child, f)))
    val tlt: TailRec[List[Tree[B]]] = sequence(ltt)
    tlt.map(Node(f(a), _))
}
```
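To see `treeMap` at work on a tree that is too deep for plain recursion, the sketch below bundles everything into one object. Several things here are assumptions made for the demonstration: the `TreeDemo` wrapper, the `sample` value, the `chain` helper that builds a degenerate one-child-per-node tree, the explicit-lambda spelling of `sequence`, and the type ascription inside `treeMap`.

```scala
object TreeDemo {
  sealed trait TailRec[A] {
    def map[B](f: A => B): TailRec[B] = flatMap(a => Return(f(a)))
    def flatMap[B](f: A => TailRec[B]): TailRec[B] = FlatMap(this, f)
  }
  final case class Return[A](a: A) extends TailRec[A]
  final case class Suspend[A](resume: () => TailRec[A]) extends TailRec[A]
  final case class FlatMap[A, B](sub: TailRec[A], k: A => TailRec[B]) extends TailRec[B]

  @annotation.tailrec
  def run[A](tr: TailRec[A]): A = tr match {
    case Return(a) => a
    case Suspend(r) => run(r())
    case FlatMap(x, f) => x match {
      case Return(a) => run(f(a))
      case Suspend(r) => run(FlatMap(r(), f))
      case FlatMap(y, g) => run(y.flatMap(a => g(a).flatMap(f)))
    }
  }

  def sequence[A](ltt: List[TailRec[A]]): TailRec[List[A]] =
    ltt.reverse.foldLeft(Return(Nil): TailRec[List[A]]) { (tla, ta) =>
      ta.flatMap(a => tla.map(rest => a :: rest))
    }

  sealed trait Tree[A] { def label: A }
  final case class Leaf[A](label: A) extends Tree[A]
  final case class Node[A](label: A, children: List[Tree[A]]) extends Tree[A]

  def treeMap[A, B](tree: Tree[A], f: A => B): TailRec[Tree[B]] = tree match {
    case Leaf(a) => Return(Leaf(f(a)))
    case Node(a, children) =>
      val ltt: List[TailRec[Tree[B]]] = children.map(child => Suspend(() => treeMap(child, f)))
      sequence(ltt).map(lt => Node(f(a), lt): Tree[B])
  }

  // Hypothetical sample data: a small tree, and a builder for a very deep one.
  val sample: Tree[Int] = Node(1, List[Tree[Int]](Leaf(2), Node(3, List[Tree[Int]](Leaf(4)))))

  // Built iteratively from the leaf upward, so construction is itself stack-safe.
  def chain(depth: Int): Tree[Int] =
    (1 to depth).foldLeft(Leaf(0): Tree[Int])((t, i) => Node(i, List(t)))
}
```

Mapping `_ * 10` over `TreeDemo.sample` gives a tree with labels 10, 20, 30 and 40 in the same shape. Mapping over `TreeDemo.chain(100000)` also completes, whereas `unsafeTreeMap` would need a stack frame per level. One caveat: a tree that deep should be inspected through `label` only, since the generated `equals`, `hashCode` and `toString` of case classes are themselves recursive.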

# A Final Word on Laziness and Stack-Safety

The key reason why trampolined functions are stack-safe is that `Suspend` is lazy; in other words, the recursion happens in a lazy structure. Generally speaking, lazy recursions tend to be stack-safe, even if they are not tail recursions. For example, the following function:

```
def func(x: Int): Stream[Int] = x #:: func(x + 1)
```

is stack-safe, even though it is not tail recursive, because the tail of a `Stream` is lazy. Calling

```
println(func(1).take(n).toList)
```

with a large `n` will not cause a `StackOverflowError`. The execution simply trampolines between `func` and `take`.
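A small runnable version of this last claim, as a sketch: the `LazyStreamDemo` object and the `firstN` helper are names made up here, and `Stream` is kept to match the text even though Scala 2.13 and later deprecate it in favour of `LazyList`.

```scala
object LazyStreamDemo {
  // Not tail recursive, but the recursive call sits in the lazy tail of the
  // stream, so it only runs when the next element is demanded.
  // (On Scala 2.13+ this compiles with a deprecation warning for Stream.)
  def func(x: Int): Stream[Int] = x #:: func(x + 1)

  // Hypothetical helper: forces the first n elements into a strict List.
  def firstN(n: Int): List[Int] = func(1).take(n).toList
}
```

`LazyStreamDemo.firstN(3)` gives `List(1, 2, 3)`, and asking for a hundred thousand elements works the same way, with each element forcing exactly one more call to `func`.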