A Gentle Run-through of Continuation Passing Style and Its Use Cases

This is yet another CPS introduction post, with a focus on the use cases of CPS. This is where I find some other CPS articles insufficient: for example, when I first read the CPS entry on Wikibooks, I got what CPS is, but it wasn’t quite clear to me in what different ways CPS can profitably be used in practice. Besides, CPS is a fairly obscure, convoluted and counter-intuitive thing, so I reckon another post that explains it is always beneficial.

I’m aiming to write this post in a way that would have helped me understand CPS better when I first learned it. It can be considered a supplement to the above Wikibooks entry.

I’ll first go over CPS basics, then discuss some use cases of CPS, and only at the end will I briefly touch upon the Cont monad and callCC. I personally find this flow more natural and digestible, since it keeps things concrete and “bare metal” for most of the post, without the extra layer of abstraction that Cont and callCC introduce.

Besides the CPS entry on Wikibooks, some examples and ideas in this post come from two excellent books: Design Concepts in Programming Languages (DCPL) by Franklyn Turbak and David Gifford with Mark Sheldon, and Compiling with Continuations (CwC) by Andrew W. Appel.

CPS Basics

Normally a function takes some input, returns some value, and the caller of the function then continues to do something with that returned value. In CPS, instead of returning a value, the function takes an extra first-class parameter: a continuation function, which represents what the caller would do with the returned value. The function then calls the continuation wherever a normal function would return its result.

So, a function of type a -> b would become a -> (b -> r) -> r in CPS, where b -> r is the continuation. These two types are isomorphic due to the isomorphism between b and forall r. (b -> r) -> r:

toCPS :: a -> (forall r. (a -> r) -> r)
toCPS = flip ($)

fromCPS :: (forall r. (a -> r) -> r) -> a
fromCPS = ($ id)

For a simple example, take a look at the pythagoras and pythagoras_cps functions in the Wikibooks entry.
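For readers who don’t want to click through, here is a rough sketch of what those functions look like (adapted from the Wikibooks entry; the names and exact types there may differ slightly):

add :: Int -> Int -> Int
add x y = x + y

square :: Int -> Int
square x = x * x

-- Non-CPS version
pythagoras :: Int -> Int -> Int
pythagoras x y = add (square x) (square y)

-- CPS helpers: each takes a continuation instead of returning a value
addCPS :: Int -> Int -> (Int -> r) -> r
addCPS x y k = k (add x y)

squareCPS :: Int -> (Int -> r) -> r
squareCPS x k = k (square x)

-- CPS version: evaluation order and intermediate results are spelled out
pythagoras_cps :: Int -> Int -> (Int -> r) -> r
pythagoras_cps x y k =
  squareCPS x $ \x2 ->
    squareCPS y $ \y2 ->
      addCPS x2 y2 k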

This is pretty much all there is to say about what CPS is. At this point, a pressing question would be: why is CPS useful? After all, it makes the types longer, makes the code more obscure, and doesn’t seem to be able to do anything regular functions can’t do.

To answer this question, recognize that CPS has the following characteristics:

1. Every function call is a tail call. You don’t ever call a function, then do something with the returned value. Instead, that “something” is passed to the function as the continuation. Functions don’t return, but pass their results to the continuations. Therefore there are no non-tail calls.

2. CPS makes a number of concepts more explicit, including return address, call stack, evaluation order, intermediate results, etc.

  • When a function is called, a return address needs to be made available, which is where the callee should return to when it is done. For regular functions, return addresses are transparently handled by the subroutine mechanism: a return address is pushed onto the call stack just before entering a function. This is not something programmers implementing the functions need to care about. In CPS, continuations can be viewed as explicit return addresses: a function “returns” by calling its continuation.

  • It is similar for the call stack. Take a look at the pythagoras_cps example in the Wikibooks entry. The nested continuations can be regarded as representing the call stack. Each time a function is called, a new continuation is created which corresponds to the new stack frame for the function call.

  • Evaluation order and intermediate results also become explicit. In the pythagoras example in the Wikibooks entry, it is unclear which is evaluated first, square x or square y. Indeed, the compiler is free to do either. And the results of square x and square y are not named. In pythagoras_cps, the evaluation order and the names of intermediate results are both explicitly spelled out.

3. You have more power when you have continuations as first-class functions. When implementing a regular function, you can only return the result at the end of the function. But when continuations are explicitly passed to you as first-class functions, you can get more creative: for instance, calling the continuation of an enclosing function, passing continuations to other functions, or storing them in some data structure and using them at a later time.

Because of these characteristics, there are a number of use cases for which CPS is a good fit. Next I’ll discuss some of these use cases in more detail.

CPS Use Cases

Stack Safety

Since all function calls in CPS are tail calls, one would naturally wonder if CPS can be used to convert a stack unsafe function into a stack safe one, and it can. Here’s a plain-vanilla, stack unsafe factorial function, and the corresponding CPS version:

-- Non-CPS, stack unsafe
fac :: Integer -> Integer
fac 0 = 1
fac n = n * fac (n-1)

-- CPS
facCPS :: Integer -> (Integer -> r) -> r
facCPS 0 k = k 1
facCPS n k = facCPS (n-1) $ k . (*n)

Two things about facCPS are worth pointing out:

  1. I didn’t bother to convert the - and the * operators into CPS. To do so, you’d need to create functions subCPS, multCPS :: Integer -> Integer -> (Integer -> r) -> r, similar to the addCPS function in the pythagoras example in the Wikibooks entry. I’m simply treating - and * as primitive operators, rather than regular functions, because to implement subCPS and multCPS you’d still need to use - and * as primitive operators, which kinda defeats the purpose.
  2. This facCPS implementation is in fact still not stack safe, because of Haskell’s laziness. The k . (*n) compositions keep building up a large thunk, and evaluating that thunk may cause a stack overflow for large n. To make it stack safe we can replace k . (*n) with \x -> x `seq` k (x * n) (see the sketch after this list). But this is a separate issue, so for simplicity I’m ignoring the laziness problem and pretending that k . (*n) is fine.

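For completeness, here is a sketch of the strict variant mentioned in point 2 above (the name facCPS' is mine):

-- CPS, with the intermediate product forced at each step to avoid thunk buildup
facCPS' :: Integer -> (Integer -> r) -> r
facCPS' 0 k = k 1
facCPS' n k = facCPS' (n-1) $ \x -> x `seq` k (x * n)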
Now facCPS id is the factorial function, and it is stack safe as long as the programming language we use supports tail-call optimization (TCO), which Haskell does. It is still stack unsafe in a non-TCO language. For example, facCPS, when translated into Scala, is not stack safe, because Scala only optimizes simple self-recursion and doesn’t have TCO in general; in other words, Scala is not a properly tail-recursive language. In particular, function composition in Scala is not stack safe: if you compose the identity function with itself a million times, the result is a function that causes a stack overflow for any input. In Scala your options are trampolining the recursive function, or implementing it using tail recursion or iteration.

It is also worth pointing out that, even though facCPS is stack safe in Haskell, it is not the best way to implement the factorial function. The following iterative implementation is much more efficient:

-- requires the BangPatterns extension (enabled by default in GHC2021)
facIterative :: Integer -> Integer
facIterative = go 1
  where
    go !acc 0 = acc
    go !acc n = go (acc*n) (n-1)

(This is called an iterative implementation because it is a simple self-recursive implementation that works very much like a while loop in an imperative language, and can easily be translated into one.)

facIterative is more efficient mainly because the implementation is simple enough that GHC, with -O2 enabled, is able to compile it into a simple loop. facIterative and facCPS id have similar performance without -O2, but with -O2, the former is much faster and allocates much less memory.

Non-Local Exits

A non-local exit is a “jump” in which a nested function doesn’t return to its immediate caller, but returns to some outer level of control. This can be achieved using goto in many programming languages; goto is generally avoided, but the same effect can be achieved by a combination of break, continue and return. Haskell, of course, has none of these things. Instead, CPS is a nice way to implement non-local exits.

As an example, consider the following leaf-valued, non-empty tree data type:

data Tree = Branch Tree Tree | Leaf Int

Suppose we are tasked to implement a function that sums up the leaf values of a tree. This is fairly straightforward:

leafSum :: Tree -> Int
leafSum (Leaf x) = x
leafSum (Branch l r) = leafSum l + leafSum r

The CPS version is:

leafSumCPS :: Tree -> (Int -> r) -> r
leafSumCPS (Leaf x) k = k x
leafSumCPS (Branch l r) k =
  leafSumCPS l $ \vl ->
    leafSumCPS r $ \vr ->
      k (vl + vr)
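For a quick sanity check, here is a small example tree (my own, not from the original sources) and the results we’d expect:

exampleTree :: Tree
exampleTree = Branch (Leaf 1) (Branch (Leaf 6) (Leaf 2))

-- leafSum exampleTree       == 9
-- leafSumCPS exampleTree id == 9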

Now, suppose a (completely arbitrary and useless, admittedly) requirement is added to our task: if any leaf’s value is 6, return 1000. Otherwise, still return the sum of all leaf values.

How do we modify our programs to accommodate this new requirement? It’s not completely straightforward any more. We have a few options:

Approach 1: Traverse Twice

The naive approach is to traverse the tree twice: in the first traversal we check whether there’s any Leaf 6, and if not, we traverse the tree again to compute the leaf sum, as sketched below. This works, but the problem is obvious: the tree is traversed twice.
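A possible sketch of this approach (the helper names hasLeaf6 and leafSumTwice are mine):

hasLeaf6 :: Tree -> Bool
hasLeaf6 (Leaf x)     = x == 6
hasLeaf6 (Branch l r) = hasLeaf6 l || hasLeaf6 r

leafSumTwice :: Tree -> Int
leafSumTwice t = if hasLeaf6 t then 1000 else leafSum t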

Approach 2: Fuse the Two Traversals

Instead of traversing the tree twice, we can fuse the two traversals, by making the inner function go return a pair. The first component of the pair indicates whether we have found a Leaf 6, and the second component is the leaf sum of the current subtree.

leafSumFused :: Tree -> Int
leafSumFused = snd . go False
  where
    go True _ = (True, 1000)
    go False (Leaf 6) = (True, 1000)
    go False (Leaf x) = (False, x)
    go False (Branch l r) =
      let (bl, resl) = go False l
          (br, resr) = go False r
       in if bl || br
            then (True, 1000)
            else (False, resl + resr)

This reduces the number of traversals to one, but it’s kinda ugly, and is definitely not the most readable and extensible code one could ever wish to see.

Approach 3: Implement Tree Traversal Iteratively

Instead of a recursive tree traversal, we can write an iterative one. Unlike the factorial function, where the iterative implementation is straightforward, an iterative tree traversal requires explicitly managing a stack. The implementation looks like this:

leafSumIterative :: Tree -> Int
leafSumIterative tree = go [tree] 0
  where
    go []         !acc = acc
    go (top:rest) !acc = case top of
      Branch l r -> go (l:r:rest) acc
      Leaf 6 -> 1000
      Leaf x -> go rest (acc+x)

This works nicely, but iterative implementations are often less natural, and harder to read and write, compared to recursive implementations. In this particular case, the iterative implementation needs a stack (the first argument to go). In the general case, the iterative implementation can be much more complex.

Approach 4: CPS

This problem can be solved gracefully using CPS:

leafSumCPS' :: Tree -> (Int -> r) -> r
leafSumCPS' tree k = go tree k
  where
    go (Leaf 6) _  = k  1000
    go (Leaf x) k' = k' x
    go (Branch l r) k' =
      go l $ \vl ->
        go r $ \vr ->
          k' (vl + vr)

Here, k is the outer continuation, which expects the final result; k' is the inner continuation, which expects the current local result. In the Leaf 6 case, we pass 1000 to the outer continuation k, which effectively returns 1000 as the final result. Put another way, k here serves a similar purpose to the return keyword in many programming languages: calling the final continuation with a value is equivalent to returning that value.
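Using the exampleTree from the earlier sanity check (which contains a Leaf 6), along with a tree that doesn’t:

-- leafSumCPS' exampleTree id                == 1000
-- leafSumCPS' (Branch (Leaf 1) (Leaf 2)) id == 3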

The takeaway from this example is that, with CPS, it is much easier to jump to a certain part of the code, as long as we have the continuations we need at our disposal and invoke the right continuation at the right time. The next use case, backtracking, further demonstrates this point.

Backtracking

Backtracking is a class of algorithms for finding solutions to computational problems by traversing a search space. It can be especially useful for NP-complete problems, where an asymptotically more efficient algorithm may not exist. Intuitively, in a tree- or graph-shaped search space, backtracking involves making a guess at a certain node, continuing the search, and upon hitting a wall, going back to that node and taking a different branch.

So when we explore the search space using backtracking, we need to jump to certain nodes at certain times. And for the same reason CPS is a good fit for non-local exits, it is also a good fit for backtracking. By calling the right continuation at the right time, you can jump around the search space however you want.

The following example, stolen from the DCPL book, illustrates CPS-based backtracking. It uses backtracking to solve the SAT problem, which is NP-complete. The following data type is used for boolean formulae:

data BF = Var String | Not BF | And BF BF | Or BF BF

and the function we want to implement has the following type and expected behavior:

solve :: BF -> Maybe (Map String Bool)

test1 = Var "a" `And` (Not (Var "b") `And` Var "c")
test2 = Var "a" `And` Not (Var "a")
test3 = Var "a" `Or` (Not (Var "b") `And` Var "c")
test4 = (Var "a" `Or` (Var "b" `Or` Var "c")) `And` (Not (Var "a") `And` Not (Var "b"))

solve test1 `shouldBe` Just (Map.fromList [("a",True),("b",False),("c",True)])
solve test2 `shouldBe` Nothing
solve test3 `shouldBe` Just (Map.fromList [("a",True)])
solve test4 `shouldBe` Just (Map.fromList [("a",False),("b",False),("c",True)])

Note that the solution for test3 only contains ("a",True) since the values of b and c don’t matter.

Here’s the implementation of solve:

import Data.Map (Map, (!?))
import qualified Data.Map as Map
import Prelude hiding (fail, succ)

sat :: BF
    -> Map String Bool
    -> (Bool -> Map String Bool -> r -> r)
    -> r
    -> r
sat bf asst succ fail = case bf of
  Var v ->
    case asst !? v of
      Just b -> succ b asst fail
      Nothing ->
        let asstT = Map.insert v True asst
            asstF = Map.insert v False asst
            tryT = succ True asstT tryF
            tryF = succ False asstF fail
         in tryT

  Not bf' ->
    let succNot = succ . not
     in sat bf' asst succNot fail

  And l r ->
    let succAnd True asstAnd failAnd = sat r asstAnd succ failAnd
        succAnd False asstAnd failAnd = succ False asstAnd failAnd
     in sat l asst succAnd fail

  Or l r ->
    let succOr True asstOr failOr = succ True asstOr failOr
        succOr False asstOr failOr = sat r asstOr succ failOr
     in sat l asst succOr fail

solve :: BF -> Maybe (Map String Bool)
solve bf =
  sat
    bf
    Map.empty
    (\b asst fail -> if b then Just asst else fail)
    Nothing

The Not, And and Or cases are fairly routine. The key to the implementation is the Var case: if the current variable is unassigned, we first try setting it to True and continue the exploration by calling the succ continuation. If that exploration eventually fails, the tryF continuation is called, which effectively goes back to the current node, re-assigns the variable to False, and continues the same exploration by calling the same succ continuation. If that fails too, the fail continuation passed into this call of sat is invoked, backtracking further to an earlier choice point; if every attempt ultimately fails, the top-level fail continuation is reached and the entire computation returns Nothing.

There are several continuations in this implementation, which may take some time to sort out. Specifically, sat takes a succ continuation of type Bool -> Map String Bool -> r -> r, a (global) fail continuation of type r (consider it a simplification of () -> r), and the succ continuation itself also takes a (local) fail continuation of type r. It is a good exercise to attempt to fully understand how this implementation works.

Backtracking using a succ continuation and a fail continuation is the machinery behind the logict library.

Associativity of Monoid/Monad Operations

Left-associative <> is asymptotically expensive for some monoids, most notably the free monoid (i.e., cons-lists), and left-associative >>= is asymptotically expensive for some monads, most notably the free monad. One way to turn the tables and change the associativity of <> and >>= is to use continuations.

First of all, CPS computations can be composed like this:

chainCPS :: ((a -> r) -> r) -> (a -> ((b -> r) -> r)) -> ((b -> r) -> r)
chainCPS s f = s . flip f

By the way, this suggests that (a -> r) -> r is a monad (i.e., the Cont monad). By embedding values into CPS computations and substituting chainCPS for <> and >>=, we can change left-associative operations into right-associative ones.

Monoid

Every monoidal value a can be embedded into, and projected from a CPS computation of type (() -> a) -> a:

monoidToCPS :: Monoid a => a -> (() -> a) -> a
monoidToCPS a = (a <>) . ($ ())

monoidFromCPS :: Monoid a => ((() -> a) -> a) -> a
monoidFromCPS cps = cps (const mempty)

Now, given a left-associative <>, we can turn each operand into a CPS computation, substitute chainCPS for <>, and finally, project the monoid out of the result, for instance

-- left-associative <>
sumL = ([1,2,3] <> [4,5,6]) <> [7,8,9]

-- right-associative <>
sumR = monoidFromCPS $
  monoidToCPS [1,2,3]
    `chainCPS` (\_ -> monoidToCPS [4,5,6])
    `chainCPS` (\_ -> monoidToCPS [7,8,9])

If we expand chainCPS in the definition of sumR, and simplify, we get

sumR = [1,2,3] <> ([4,5,6] <> [7,8,9])

which is precisely the right-associative version of sumL. This is the mechanism behind DList.
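To make the connection concrete, here is a minimal DList-style sketch built on the same idea (the names are mine and differ from the actual dlist library’s API):

newtype DiffList a = DiffList ([a] -> [a])

toDiff :: [a] -> DiffList a
toDiff xs = DiffList (xs ++)

fromDiff :: DiffList a -> [a]
fromDiff (DiffList f) = f []

-- Appending is function composition, so the underlying (++) always
-- ends up right-associated when we finally apply the function to []
appendDiff :: DiffList a -> DiffList a -> DiffList a
appendDiff (DiffList f) (DiffList g) = DiffList (f . g)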

Monad

In a completely analogous fashion, every monadic value m a can be embedded into, and projected from a CPS computation of type forall r. (a -> m r) -> m r:

monadToCPS :: Monad m => m a -> (forall r. (a -> m r) -> m r)
monadToCPS ma = (ma >>=)

monadFromCPS :: Monad m => (forall r. (a -> m r) -> m r) -> m a
monadFromCPS cps = cps pure

Given a left-associative >>=, we can turn each operand into a CPS computation, substitute chainCPS for >>=, and project the monadic value out of the result, for instance

-- left-associative >>=
resL = [1,2,3] >>= (\x -> [x+1]) >>= (\y -> [y+2])

-- right-associative >>=
resR = monadFromCPS $
  monadToCPS [1,2,3]
    `chainCPS` (\x -> monadToCPS [x+1])
    `chainCPS` (\y -> monadToCPS [y+2])

If we expand chainCPS in the definition of resR, and simplify, we get

resR = [1,2,3] >>= (\x -> [x+1] >>= (\y -> [y+2]))

which is precisely the right-associative version of resL. This is the mechanism behind Codensity.

Using a similar idea, we can express foldl in terms of foldr and vice versa. Take a look at the sketch below and see if you can spot the continuations involved. Hint: a -> a and (() -> a) -> a are isomorphic types.
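Here is the standard trick for writing foldl via foldr (my rendering; the continuations are the functions of type b -> b that step builds up):

foldlViaFoldr :: (b -> a -> b) -> b -> [a] -> b
foldlViaFoldr f z xs = foldr step id xs z
  where
    -- step builds a continuation that expects the accumulator computed so far
    step x k = \acc -> k (f acc x)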

CPS as a Compiler IR

The CwC book gives an excellent, detailed explanation of the use of CPS in the Standard ML of New Jersey compiler. To very briefly summarize why CPS is useful as a compiler IR: since CPS makes a number of things more explicit, such as return addresses, the call stack, intermediate results, as well as every aspect of the control flow, a CPS’ed program is usually much closer to the assembly code that the compiler eventually generates than the original program is. For example, the variables in a CPS’ed program correspond quite closely to the registers of the target machine. This explicitness is often the opposite of what a programmer wants when writing programs, but it is highly relevant to the compiler implementation.

The Cont Monad and callCC

CPS computations can be composed as demonstrated by chainCPS. This gives rise to the Cont monad. Cont r a is simply a newtype wrapper around (a -> r) -> r.
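For concreteness, here is a minimal sketch of the wrapper and its Monad instance (the real Cont in the transformers library is defined in terms of ContT, so this standalone newtype is a simplification):

newtype Cont r a = Cont { runCont :: (a -> r) -> r }

instance Functor (Cont r) where
  fmap f (Cont c) = Cont $ \k -> c (k . f)

instance Applicative (Cont r) where
  pure a = Cont ($ a)
  Cont cf <*> Cont ca = Cont $ \k -> cf (\f -> ca (k . f))

instance Monad (Cont r) where
  -- (>>=) is chainCPS, modulo the newtype wrapper
  Cont c >>= f = Cont $ \k -> c (\a -> runCont (f a) k)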

The continuation a -> r is “hidden” by Cont r a, so how do we manipulate continuations with Cont r a the way we did in the Non-Local Exits and Backtracking examples? The answer is callCC. callCC stands for “call with current continuation”, and it brings the current continuation, i.e., the continuation that expects the final result of the current computation, into scope.

If we adapt the leafSumCPS' function (from the non-local exits example) to use the Cont monad and callCC, it looks like this:

leafSumCPS'' :: Tree -> Cont r Int
leafSumCPS'' tree = callCC $ \k ->
  let
    go (Leaf 6) _  = k  1000
    go (Leaf x) k' = k' x
    go (Branch l r) k' =
      go l $ \vl ->
        go r $ \vr ->
          k' (vl + vr)
  in
    go tree k

The go function is exactly the same as before. We just need to add a callCC at the beginning of leafSumCPS''. More examples of callCC can be found in the Wikibooks entry.
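To run it, we unwrap the computation with runCont, reusing the exampleTree from earlier:

-- runCont (leafSumCPS'' exampleTree) id                == 1000
-- runCont (leafSumCPS'' (Branch (Leaf 1) (Leaf 2))) id == 3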

The implementation of callCC, after stripping out the cont and runCont wrappers, is

callCC :: ((a -> (b -> r) -> r) -> (a -> r) -> r)
       -> (a -> r) -> r
callCC f h = f (\a _ -> h a) h

It is a good exercise to try to decipher what the type means and what the implementation does. Hint: note the ignored argument (_) in the implementation. Its type is b -> r, and it is a continuation representing the rest of the computation following a call of k (whose result can be of any type b; it doesn’t matter). That entire rest of the computation is ignored, so we effectively return whatever value is passed to k.

To make this even clearer, think about what happens if we never call the “exit function” k versus what happens if we do. For example, if we never call k, which is the first argument to f, then it is as if f ignores its first argument, say f = const res where res :: (a -> r) -> r. In this case res is also what callCC f returns, i.e., callCC f returns the same result as if we hadn’t used callCC in the first place.
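A tiny illustration of both cases (the names noExit and earlyExit are mine):

-- k is never called: this behaves exactly like pure 42
noExit :: Cont r Int
noExit = callCC $ \_ -> pure 42

-- k is called: the rest of the block is discarded and the result is 1
earlyExit :: Cont r Int
earlyExit = callCC $ \k -> do
  _ <- k 1
  pure 42   -- never reached

-- runCont noExit id    == 42
-- runCont earlyExit id == 1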

Conclusions

This post is a gentle introduction to CPS and some of its use cases. Once again, CPS is obscure, labyrinthine, and counter-intuitive in many ways. Overusing it is hardly ever a compelling idea (some say it’s like abusing goto statements in many programming languages). It can be highly rewarding to understand CPS, but it’s probably advisable to resist the urge to use it except for the killer use cases.

Acknowledgement

This post is adapted from a presentation I made in one of Formation’s Haskell study group sessions. I received helpful feedback on both the presentation and a draft of this post from coworker Ian-Woo Kim.