Stream, Laziness and Stack Safety

This post contains some notes on chapter 5 (Strictness and laziness) of the book Functional Programming in Scala (a.k.a. The Red Book). The chapter covers the basics of non-strict and lazy evaluation in Scala, as well as an implementation of Stream, a lazy version of List.

Thunk vs. By-Name Parameter

A thunk is just a Function0: () => A, which takes an empty parameter list and produces a value of type A. A function value is considered fully evaluated (i.e., it is in normal form) even if it takes an empty parameter list. For example, consider the following function literal:

val foo: () => Int = () => 2 + 3

The expression 2 + 3 is not evaluated when foo is created. It is only evaluated to normal form (i.e., 5) when you explicitly force it by calling foo(). Therefore, thunks can be used to make a method lazy in its parameters: just turn each parameter of type A into a parameter of type () => A.

A by-name parameter a: => A is also lazy, and is usually more convenient than thunks. If you have a function foo(a: => A), the caller can just pass an A to foo, rather than having to convert it to a thunk () => A.
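
As a rough illustration of the difference at the call site, here is a minimal sketch. The names ByNameDemo, expensive, maybeTwiceThunk and maybeTwiceByName are made up for this example and are not part of the Stream code in this post. Note that a by-name argument is re-evaluated every time it is referenced, just like calling a thunk twice.

```scala
object ByNameDemo {
  // Counts how many times the argument expression is actually evaluated.
  var evaluations = 0

  // Hypothetical "expensive" computation used as the lazy argument.
  def expensive(): Int = { evaluations += 1; 2 + 3 }

  // Explicit thunk: the caller has to wrap the argument in () => ...
  def maybeTwiceThunk(cond: Boolean, a: () => Int): Int =
    if (cond) a() + a() else 0

  // By-name: the caller passes a plain expression, which is
  // evaluated each time `a` is referenced in the body.
  def maybeTwiceByName(cond: Boolean, a: => Int): Int =
    if (cond) a + a else 0
}
```

For instance, ByNameDemo.maybeTwiceByName(false, ByNameDemo.expensive()) returns 0 without ever running expensive(), while passing true runs it twice.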

In some cases, though, you have to use thunks, for example in case class constructors. A case class constructor cannot take by-name parameters, because each case class constructor parameter is treated as a val parameter, and val parameters cannot be by-name. This is why the Stream constructor Cons has to use thunks rather than by-name parameters. On the other hand, the smart constructor cons takes by-name parameters.

sealed trait Stream[+A]
case object Empty extends Stream[Nothing]
case class Cons[+A](h: () => A, t: () => Stream[A]) extends Stream[A]

object Stream {
  def cons[A](hd: => A, tl: => Stream[A]): Stream[A] = {
    lazy val head = hd
    lazy val tail = tl
    Cons(() => head, () => tail)
  }

  def empty[A]: Stream[A] = Empty
}
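
Besides the nicer call syntax, the lazy vals in cons cache the head and tail, so each is evaluated at most once, whereas a thunk passed directly to Cons reruns on every call. The sketch below tries to make this visible. It assumes the definitions above are wrapped in a hypothetical MemoDemo object, and the evaluations counter and headTwice helper are illustrative additions.

```scala
object MemoDemo {
  sealed trait Stream[+A]
  case object Empty extends Stream[Nothing]
  case class Cons[+A](h: () => A, t: () => Stream[A]) extends Stream[A]

  object Stream {
    def cons[A](hd: => A, tl: => Stream[A]): Stream[A] = {
      lazy val head = hd
      lazy val tail = tl
      Cons(() => head, () => tail)
    }
  }

  // Counts how many times a head expression is evaluated.
  var evaluations = 0

  // Built with the smart constructor: the head expression runs at most once.
  val memoized: Stream[Int] = Stream.cons({ evaluations += 1; 1 }, Empty)

  // Built with the raw constructor: the head thunk reruns on every call.
  val raw: Stream[Int] = Cons(() => { evaluations += 1; 1 }, () => Empty)

  // Forces the head twice and adds the two results.
  def headTwice(s: Stream[Int]): Int = s match {
    case Cons(h, _) => h() + h()
    case Empty      => 0
  }
}
```

Calling MemoDemo.headTwice(MemoDemo.memoized) bumps the counter once, while MemoDemo.headTwice(MemoDemo.raw) bumps it twice.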

Laziness and Stack Safety

Take a look at the following implementations of toListUnsafe and take on Stream:

def toListUnsafe: List[A] = this match {
  case Cons(h, t) => h() :: t().toListUnsafe
  case _ => List()
}


def take(n: Int): Stream[A] = this match {
  case Cons(h, t) if n > 1 => cons(h(), t().take(n - 1))
  case Cons(h, _) if n == 1 => cons(h(), empty)
  case _ => empty
}

Both methods are recursive but not tail recursive. However, toListUnsafe is not stack-safe: it can throw a StackOverflowError when called on a large stream. To make it stack-safe, we need to make it tail recursive using an inner helper function:

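import scala.annotation.tailrec
import scala.collection.mutable.ListBuffer

def toList: List[A] = {
  // Accumulate elements in a mutable buffer that never escapes this method.
  val buf = ListBuffer[A]()
  @tailrec
  def go(s: Stream[A]): List[A] = s match {
    case Cons(head, tail) =>
      buf += head()
      go(tail()) // tail call: compiled into a loop, so the stack does not grow
    case _ => buf.toList
  }
  go(this)
}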

On the other hand, take is stack-safe and there is no need to make it tail recursive. The reason for the difference is where the recursion happens. In toListUnsafe, the recursion happens in the list constructor ::, which is strict in both parameters. This means we must fully evaluate both h() and t().toListUnsafe before starting to evaluate ::. In take, on the other hand, the recursion happens in the stream constructor cons, which is lazy in both parameters, so we can start evaluating cons without evaluating h() or t().take(n - 1).

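This is essentially how trampolining works: instead of making the recursive call directly, each step returns a suspended computation, and a loop (here, the tail-recursive go in toList) forces one step at a time.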

To make this easier to see, let’s use a concrete example. Suppose we have an infinite stream s: Stream[Int] = 1, 2, 3,... and let’s see what happens when we call s.take(5).toList.

Since s matches the first case in take, we call cons(h(), t().take(4)), where t() is Stream(2, 3,...). Because cons is lazy in both parameters, neither h() nor t().take(4) is evaluated, and the call returns Cons(() => h(), () => Stream(2, 3,...).take(4)). Note that this is a fully evaluated value and there is nothing further to reduce. So at this point, the take method returns, and its stack frame is popped off the stack.

Next we call toList on the value returned by s.take(5), i.e., Cons(() => h(), () => Stream(2, 3,...).take(4)). It matches the Cons(head, tail) case, where head = () => h() and tail = () => Stream(2, 3,...).take(4). When buf += head() is called, head() is fully evaluated into an A so that it can be added to buf. In this case h() evaluates to 1, and we add 1 to buf. We then recursively call go(tail()); since this is a tail call, it does not consume an additional stack frame. go is strict in its parameter, so tail() needs to be fully evaluated before go is called.

Evaluating tail() means evaluating Stream(2, 3,...).take(4). As before, it is evaluated into Cons(() => h(), () => Stream(3, 4,...).take(3)). The execution continues until we have taken 5 elements from the original stream. During this process, we alternate between the execution of take and toList; in other words, we are trampolining.
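
The alternation described above can be observed by recording each step. The following is a sketch rather than the book's code: the TrampolineTrace wrapper, the log buffer and the from helper (an infinite stream n, n + 1, ...) are assumptions added for illustration, and the example takes 3 elements instead of 5 to keep the trace short.

```scala
import scala.annotation.tailrec
import scala.collection.mutable.ListBuffer

object TrampolineTrace {
  // Records the order in which take and toList do their work.
  val log = ListBuffer[String]()

  sealed trait Stream[+A] {
    import Stream._

    def take(n: Int): Stream[A] = {
      log += s"take($n)"
      this match {
        case Cons(h, t) if n > 1  => cons(h(), t().take(n - 1))
        case Cons(h, _) if n == 1 => cons(h(), empty)
        case _                    => empty
      }
    }

    def toList: List[A] = {
      val buf = ListBuffer[A]()
      @tailrec
      def go(s: Stream[A]): List[A] = s match {
        case Cons(head, tail) =>
          val a = head()
          log += s"toList got $a"
          buf += a
          go(tail()) // forcing tail() runs the next suspended take
        case _ => buf.toList
      }
      go(this)
    }
  }
  case object Empty extends Stream[Nothing]
  case class Cons[+A](h: () => A, t: () => Stream[A]) extends Stream[A]

  object Stream {
    def cons[A](hd: => A, tl: => Stream[A]): Stream[A] = {
      lazy val head = hd
      lazy val tail = tl
      Cons(() => head, () => tail)
    }
    def empty[A]: Stream[A] = Empty

    // Hypothetical helper: the infinite stream n, n + 1, n + 2, ...
    def from(n: Int): Stream[Int] = cons(n, from(n + 1))
  }

  def run(): List[Int] = Stream.from(1).take(3).toList
}
```

After TrampolineTrace.run(), the log reads take(3), toList got 1, take(2), toList got 2, take(1), toList got 3: each call to take returns before toList asks for the next one.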

This is also why it doesn’t create an intermediate list when we call Stream(1, 2, 3, 4).map(_ + 10).filter(_ % 2 == 0).toList. It basically trampolines between map, filter and toList, and processes one element at a time.
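
To see the element-at-a-time processing, here is a self-contained sketch that logs every call made by map and filter. It assumes map and filter are implemented via foldRight, as in the book; the InterleaveDemo wrapper, the log buffer and the apply constructor are illustrative assumptions.

```scala
import scala.annotation.tailrec
import scala.collection.mutable.ListBuffer

object InterleaveDemo {
  sealed trait Stream[+A] {
    import Stream._

    def foldRight[B](z: => B)(f: (A, => B) => B): B = this match {
      case Cons(h, t) => f(h(), t().foldRight(z)(f))
      case _          => z
    }

    def map[B](f: A => B): Stream[B] =
      foldRight(empty[B])((h, t) => cons(f(h), t))

    def filter(p: A => Boolean): Stream[A] =
      foldRight(empty[A])((h, t) => if (p(h)) cons(h, t) else t)

    def toList: List[A] = {
      val buf = ListBuffer[A]()
      @tailrec
      def go(s: Stream[A]): List[A] = s match {
        case Cons(h, t) =>
          buf += h()
          go(t())
        case _ => buf.toList
      }
      go(this)
    }
  }
  case object Empty extends Stream[Nothing]
  case class Cons[+A](h: () => A, t: () => Stream[A]) extends Stream[A]

  object Stream {
    def cons[A](hd: => A, tl: => Stream[A]): Stream[A] = {
      lazy val head = hd
      lazy val tail = tl
      Cons(() => head, () => tail)
    }
    def empty[A]: Stream[A] = Empty
    def apply[A](as: A*): Stream[A] =
      if (as.isEmpty) empty else cons(as.head, apply(as.tail: _*))
  }

  // Records the order in which the user functions are invoked.
  val log = ListBuffer[String]()

  def run(): List[Int] =
    Stream(1, 2, 3, 4)
      .map { x => log += s"map $x"; x + 10 }
      .filter { x => log += s"filter $x"; x % 2 == 0 }
      .toList
}
```

InterleaveDemo.run() returns List(12, 14), and the log alternates strictly: map 1, filter 11, map 2, filter 12, map 3, filter 13, map 4, filter 14. No element reaches map before the previous one has gone through filter.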

What happens if we use toListUnsafe, as in s.take(n).toListUnsafe? We will still be alternating between take and toListUnsafe, but it won’t be stack-safe: each call to take will not consume an additional stack frame because it returns a Cons without immediately going into the next level of recursion; however, each recursive call to toListUnsafe will consume an additional stack frame since it is not a tail call. This will cause a StackOverflowError for a large n.

foldRight Implementation

In The Red Book, the foldRight method on Stream is implemented in the following way:

def foldRight[B](z: => B)(f: (A, => B) => B): B =
  this match {
    case Cons(h,t) => f(h(), t().foldRight(z)(f))
    case _ => z
  }

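Note that f takes its first parameter by value and its second parameter by name. Taking the second parameter by name lets f decide whether and when the rest of the fold is evaluated, which is what allows foldRight to be used on large or even infinite streams. It is stack-safe as long as f does not force its second argument before returning (for example, when f builds its result with cons); a function that forces it right away, such as (a, b) => a + b, still uses one stack frame per element.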
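
As a small illustration of what the by-name second parameter buys, here is a sketch of exists written via foldRight (similar to the definition in the book) running on an infinite stream. The FoldRightDemo wrapper, the from helper and the inspected counter are assumptions made for this example.

```scala
object FoldRightDemo {
  sealed trait Stream[+A] {
    def foldRight[B](z: => B)(f: (A, => B) => B): B = this match {
      case Cons(h, t) => f(h(), t().foldRight(z)(f))
      case _          => z
    }

    // || is lazy in its right operand, so b (the rest of the fold)
    // is never forced once p(a) is true.
    def exists(p: A => Boolean): Boolean =
      foldRight(false)((a, b) => p(a) || b)
  }
  case object Empty extends Stream[Nothing]
  case class Cons[+A](h: () => A, t: () => Stream[A]) extends Stream[A]

  object Stream {
    def cons[A](hd: => A, tl: => Stream[A]): Stream[A] = {
      lazy val head = hd
      lazy val tail = tl
      Cons(() => head, () => tail)
    }

    // Hypothetical helper: the infinite stream n, n + 1, n + 2, ...
    def from(n: Int): Stream[Int] = cons(n, from(n + 1))
  }

  // Counts how many elements the predicate has looked at.
  var inspected = 0

  def findTen(): Boolean =
    Stream.from(1).exists { x => inspected += 1; x == 10 }
}
```

FoldRightDemo.findTen() terminates with true after inspecting exactly 10 elements, even though the stream is infinite. Note that each element for which the predicate is false still costs a stack frame here, because || forces b in that case.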

I don’t think there’s any reason though, that the first parameter shouldn’t also be by-name. In fact, I think it should be - there’s no reason that the first parameter A has to be fully evaluated before f is called. Here’s another implementation where both parameters of f are by-name:

def foldRight2[B](z: => B)(f: (=> A, => B) => B): B =
  this match {
    // Recurse via foldRight2: f's type does not match what foldRight expects.
    case Cons(h,t) => f(h(), t().foldRight2(z)(f))
    case _ => z
  }

Let’s implement map via foldRight and foldRight2, respectively, and see how they behave differently when mapping over a stream:

def map[B](f: A => B): Stream[B] = foldRight(empty[B])((h, t) => cons(f(h), t))

def map2[B](f: A => B): Stream[B] = foldRight2(empty[B])((h, t) => cons(f(h), t))

Let’s create the following stream:

val s: Stream[Int] =
  cons(
    {println("One"); 1},
    cons(
      {println("Two"); 2},
      cons(
        {println("Three"); 3},
        cons({println("Four"); 4}, empty)
      )
    )
  )

Calling s.map(_ + 10) prints One, indicating that the head element in s is evaluated, while calling s.map2(_ + 10) does not. Thus foldRight2, in which both parameters to f are by-name, is useful if you want to transform a stream without forcing even a single element.
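
For readers who want to check this behavior, here is a self-contained sketch that replaces println with a buffer so the effect can be asserted on. The MapDemo wrapper, the forced buffer, the fresh helper and headOption are assumptions added for this example; foldRight2 recurses into itself here.

```scala
import scala.collection.mutable.ListBuffer

object MapDemo {
  // Records which element expressions have been evaluated (stands in for println).
  val forced = ListBuffer[String]()

  sealed trait Stream[+A] {
    import Stream._

    def foldRight[B](z: => B)(f: (A, => B) => B): B = this match {
      case Cons(h, t) => f(h(), t().foldRight(z)(f))
      case _          => z
    }

    def foldRight2[B](z: => B)(f: (=> A, => B) => B): B = this match {
      case Cons(h, t) => f(h(), t().foldRight2(z)(f))
      case _          => z
    }

    def map[B](f: A => B): Stream[B] =
      foldRight(empty[B])((h, t) => cons(f(h), t))

    def map2[B](f: A => B): Stream[B] =
      foldRight2(empty[B])((h, t) => cons(f(h), t))

    def headOption: Option[A] = this match {
      case Cons(h, _) => Some(h())
      case _          => None
    }
  }
  case object Empty extends Stream[Nothing]
  case class Cons[+A](h: () => A, t: () => Stream[A]) extends Stream[A]

  object Stream {
    def cons[A](hd: => A, tl: => Stream[A]): Stream[A] = {
      lazy val head = hd
      lazy val tail = tl
      Cons(() => head, () => tail)
    }
    def empty[A]: Stream[A] = Empty
  }

  // A fresh stream on every call, so memoized heads are not shared between runs.
  def fresh(): Stream[Int] =
    Stream.cons(
      { forced += "One"; 1 },
      Stream.cons({ forced += "Two"; 2 }, Stream.empty)
    )
}
```

With this setup, MapDemo.fresh().map(_ + 10) leaves "One" in forced, while MapDemo.fresh().map2(_ + 10) leaves it empty until the head of the result is actually requested, for example via headOption.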

The Scala standard library also has a Stream type that is similar to the one in The Red Book, except that the head element in Scala’s stream is strict, which means whenever you have a non-empty stream, the head element is always fully evaluated. If the above s were a Scala stream, One would be printed immediately when s is created.
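
A quick way to check the strict head is to count evaluations when building a standard library stream. This is a sketch: the StdStreamDemo object and its two counters are made up for the example, and since scala.collection.immutable.Stream is deprecated as of Scala 2.13, compiling it may produce a deprecation warning.

```scala
import scala.collection.immutable.Stream

object StdStreamDemo {
  var headEvaluations = 0
  var tailEvaluations = 0

  // Stream.cons takes the head by value and the tail by name,
  // so the head block runs right here, while the tail block is deferred.
  val s: Stream[Int] =
    Stream.cons(
      { headEvaluations += 1; 1 },
      { tailEvaluations += 1; Stream.empty[Int] }
    )
}
```

Right after StdStreamDemo.s is created, headEvaluations is 1 and tailEvaluations is still 0; the tail block only runs once s.tail is accessed.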