This is yet another CPS introduction post, with a focus on the use cases of CPS. This is an area where I find some other CPS articles insufficient. For example, when I first read the CPS entry on Wikibooks, I understood what CPS is, but it wasn't quite clear to me in which ways CPS can profitably be used in practice. Besides, CPS is a fairly obscure, convoluted and counter-intuitive thing, so I reckon another post that explains it is always beneficial.

I'm aiming to write this post in a way that would have helped me understand CPS better when I first learned it. It can be considered a supplement to the above Wikibooks entry.

I'll first go over CPS basics, then discuss some use cases of CPS, and only at the end briefly touch upon the `Cont` monad and `callCC`. I personally find this flow more natural and digestible, since it keeps things concrete and "bare metal" for most of the post, without the extra layer of abstraction that `Cont` and `callCC` introduce.

Besides the CPS entry on Wikibooks, some examples and ideas in this post come from two excellent books: Design Concepts in Programming Languages (DCPL) by Franklyn Turbak and David Gifford with Mark Sheldon, and Compiling with Continuations (CwC) by Andrew W. Appel.

# CPS Basics

Normally a function takes some input and returns some value, and the caller of the function then continues to do something with that returned value. In CPS, instead of returning a value, the function takes a continuation as an extra, first-class parameter. The continuation is a function representing what the caller would do with the returned value, and the CPS function calls it at the point where the normal function would return its result.

So, a function of type `a -> b` would become `a -> (b -> r) -> r` in CPS, where `b -> r` is the continuation. When the result type `r` is left fully polymorphic, i.e., `forall r. a -> (b -> r) -> r`, the two types are isomorphic, due to the isomorphism between `b` and `forall r. (b -> r) -> r`:

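```haskell
{-# LANGUAGE RankNTypes #-}

-- Wrap a plain value: given any continuation, feed it the value.
toCPS :: a -> (forall r. (a -> r) -> r)
toCPS x k = k x    -- i.e., flip ($), eta-expanded for the nested forall

-- Unwrap by supplying the identity continuation.
fromCPS :: (forall r. (a -> r) -> r) -> a
fromCPS m = m id   -- i.e., ($ id)
```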

For a simple example, take a look at the `pythagoras` and `pythagoras_cps` functions in the Wikibooks entry.
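For readers without the Wikibooks entry at hand, here is a minimal sketch in the same spirit. The names (`square`, `squareCPS`, `addCPS`, `pythagorasCPS`) are my own and only approximate the Wikibooks code; the point is that each CPS helper hands its result to a continuation instead of returning it, which also pins down the evaluation order and gives every intermediate result a name.

```haskell
-- Direct style.
square :: Int -> Int
square x = x * x

pythagoras :: Int -> Int -> Int
pythagoras x y = square x + square y

-- CPS: each helper passes its result to the continuation `k`.
squareCPS :: Int -> (Int -> r) -> r
squareCPS x k = k (square x)

addCPS :: Int -> Int -> (Int -> r) -> r
addCPS x y k = k (x + y)

pythagorasCPS :: Int -> Int -> (Int -> r) -> r
pythagorasCPS x y k =
  squareCPS x $ \xSquared ->   -- `square x` is evaluated first, result named xSquared
  squareCPS y $ \ySquared ->   -- then `square y`
  addCPS xSquared ySquared k   -- finally, hand the sum to the caller's continuation
```

For example, `pythagorasCPS 3 4 id` is `25`, and `pythagorasCPS 3 4 show` is `"25"`: the caller decides what happens to the result.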

This is pretty much all there is to say about what CPS is. At this point, a pressing question would be: why is CPS useful? After all, it makes the types longer, makes the code more obscure, and doesn't seem to be able to do anything regular functions can't do.

To answer this question, recognize that CPS has the following characteristics:

1. Every function call is a tail call. You never call a function and then do something with the returned value. Instead, that "something" is passed to the function as the continuation. Functions don't return, but pass their results to their continuations. Therefore there are no non-tail calls.

2. CPS makes a number of concepts more explicit, including return address, call stack, evaluation order, intermediate results, etc.

   - When a function is called, a return address needs to be made available, which is where the callee should return to when it is done. For regular functions, return addresses are transparently handled by the subroutine mechanism: a return address is pushed onto the call stack just before entering a function. This is not something programmers implementing the functions need to care about. In CPS, continuations can be viewed as explicit return addresses: a function "returns" by calling its continuation.

   - The same goes for the call stack. Take a look at the `pythagoras_cps` example in the Wikibooks entry. The nested continuations can be regarded as representing the call stack. Each time a function is called, a new continuation is created, which corresponds to the new stack frame for the function call.

   - Evaluation order and intermediate results also become explicit. In the `pythagoras` example in the Wikibooks entry, it is unclear which is evaluated first, `square x` or `square y`. Indeed, the compiler is free to do either. And the results of `square x` and `square y` are not named. In `pythagoras_cps`, the evaluation order and the names of intermediate results are both explicitly spelled out.

3. You have more power when continuations are first-class functions. When implementing a regular function, you can only return the result at the end of the function. But when continuations are explicitly passed to you as first-class functions, you can get more creative: for instance, calling the continuation of an enclosing function, passing continuations to other functions, or storing them in some data structure and using them at a later time.

Because of these characteristics, there are a number of use cases for which CPS is a good fit. Next I'll discuss some of these use cases in more detail.

# CPS Use Cases

## Stack Safety

Since all function calls in CPS are tail calls, one would naturally wonder whether CPS can be used to convert a stack-unsafe function into a stack-safe one, and indeed it can. Here's a plain-vanilla, stack-unsafe factorial function, and the corresponding CPS version:

```haskell
-- Non-CPS, stack unsafe
fac :: Integer -> Integer
fac 0 = 1
fac n = n * fac (n-1)

-- CPS
facCPS :: Integer -> (Integer -> r) -> r
facCPS 0 k = k 1
facCPS n k = facCPS (n-1) $ k . (*n)
```

Two things about `facCPS` are worth pointing out:

1. I didn't bother to convert the `-` and `*` operators into CPS. To do so, you'd need to create functions `subCPS, multCPS :: Integer -> Integer -> (Integer -> r) -> r`, similar to the `addCPS` function in the `pythagoras` example in the Wikibooks entry. I'm simply treating `-` and `*` as primitive operators rather than regular functions, because implementing `subCPS` and `multCPS` would still require using `-` and `*` as primitive operators, which kinda defeats the purpose.
2. This `facCPS` implementation is in fact still not stack safe, because of Haskell's laziness. The `(*n)` keeps building up a large thunk before it is evaluated, and evaluating that thunk may cause a stack overflow for large `n`. To make it stack safe, we can replace `k . (*n)` with `` \x -> x `seq` k (x * n) ``. But this is a separate issue, so for simplicity I'm ignoring the laziness problem and just pretending that `k . (*n)` is fine.
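Here is a sketch of the strict variant suggested in item 2, assuming `n` is non-negative (like `facCPS`, it never terminates for negative input). The name `facCPS'` is hypothetical.

```haskell
-- Strict variant of facCPS: each continuation forces the value it receives
-- before multiplying, so the product is never a long chain of thunks.
facCPS' :: Integer -> (Integer -> r) -> r
facCPS' 0 k = k 1
facCPS' n k = facCPS' (n - 1) (\x -> x `seq` k (x * n))
```

As with `facCPS`, the caller picks the final continuation: `facCPS' 5 id` is `120` and `facCPS' 5 show` is `"120"`.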

Now `\n -> facCPS n id` is the factorial function, and it is stack safe as long as the programming language we use supports tail-call optimization (TCO), which Haskell does. It is still stack unsafe in a language without TCO. For example, `facCPS`, when translated into Scala, is not stack safe, because Scala only optimizes simple self-recursion and doesn't have TCO in general. In other words, Scala is not a properly tail-recursive language. In particular, function composition in Scala is not stack safe: if you compose the identity function with itself a million times, the result is a function that causes a stack overflow for any input. In Scala your options are trampolining the recursive function, or implementing it using tail recursion or iteration.

It is also worth pointing out that, even though `facCPS` is stack safe in Haskell, it is not the best way to implement the factorial function. The following iterative implementation is much more efficient:

```haskell
{-# LANGUAGE BangPatterns #-}

facIterative :: Integer -> Integer
facIterative = go 1
  where
    -- `acc` is the running product; the bang keeps it evaluated.
    go !acc 0 = acc
    go !acc n = go (acc*n) (n-1)
```

(This is called an iterative implementation because it is a simple self-recursive function that works very much like a `while` loop in an imperative language, and can easily be translated into one.)

`facIterative` is more efficient mainly because the implementation is simple enough that GHC, with `-O2` enabled, is able to compile it into a simple loop. `facIterative` and `\n -> facCPS n id` have similar performance without `-O2`, but with `-O2` the former is much faster and allocates much less memory.

## Non-Local Exits

A non-local exit is a "jump" in which a nested function doesn't return to its immediate caller, but to some outer level of control. This can be achieved using `goto` in many programming languages. `goto` is generally avoided, but the same effect can often be achieved with a combination of `break`, `continue` and `return`. Haskell, of course, has none of these things; instead, CPS is a nice way to implement non-local exits.

As an example, consider the following leaf-valued, non-empty tree data type:

```haskell
data Tree = Branch Tree Tree | Leaf Int
```

Suppose we are tasked to implement a function that sums up the leaf values of a tree. This is fairly straightforward:

```haskell
leafSum :: Tree -> Int
leafSum (Leaf x) = x
leafSum (Branch l r) = leafSum l + leafSum r
```

The CPS version is:

```haskell
leafSumCPS :: Tree -> (Int -> r) -> r
leafSumCPS (Leaf x) k = k x
leafSumCPS (Branch l r) k =
  leafSumCPS l $ \vl ->
  leafSumCPS r $ \vr ->
  k (vl + vr)
```

Now, suppose a (completely arbitrary and useless, admittedly) requirement is added to our task: if any leaf's value is 6, return 1000. Otherwise, still return the sum of all leaf values.

How do we modify our programs to accommodate this new requirement? It's not completely straightforward any more. We have a few options:

### Approach 1: Traverse Twice

The naive approach is to traverse the tree twice. In the first traversal we check whether there's any `Leaf 6`, and if not, we traverse it again to compute the leaf sum. The code is omitted. This works, but the problem is obvious: the tree is traversed twice.
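For completeness, the omitted code could look roughly like the following sketch; `hasSix` and `leafSumTwice` are hypothetical names. `hasSix` is the first traversal, and `leafSum` (repeated from above so the block stands alone) is the second.

```haskell
data Tree = Branch Tree Tree | Leaf Int

-- First traversal: is there any leaf with value 6?
hasSix :: Tree -> Bool
hasSix (Leaf x) = x == 6
hasSix (Branch l r) = hasSix l || hasSix r

-- Second traversal: the plain leaf sum.
leafSum :: Tree -> Int
leafSum (Leaf x) = x
leafSum (Branch l r) = leafSum l + leafSum r

-- The tree is walked a second time only when no `Leaf 6` was found.
leafSumTwice :: Tree -> Int
leafSumTwice t
  | hasSix t = 1000
  | otherwise = leafSum t
```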

### Approach 2: Fuse the Two Traversals

Instead of traversing the tree twice, we can fuse the two traversals, by making the inner function `go` return a pair. The first component of the pair indicates whether we have found a `Leaf 6`, and the second component is the leaf sum of the current subtree.

```haskell
leafSumFused :: Tree -> Int
leafSumFused = snd . go False
  where
    go True _ = (True, 1000)
    go False (Leaf 6) = (True, 1000)
    go False (Leaf x) = (False, x)
    go False (Branch l r) =
      let (bl, resl) = go False l
          (br, resr) = go False r
      in if bl || br
           then (True, 1000)
           else (False, resl + resr)
```

This reduces the number of traversals to one, but it's kinda ugly, and is definitely not the most readable and extensible code one could ever wish to see.

### Approach 3: Implement Tree Traversal Iteratively

Instead of a recursive tree traversal, we can write an iterative one. Unlike the factorial function, where the iterative implementation is straightforward, an iterative tree traversal requires explicitly managing a stack. The implementation looks like this:

```haskell
{-# LANGUAGE BangPatterns #-}

leafSumIterative :: Tree -> Int
leafSumIterative tree = go [tree] 0
  where
    -- The first argument is an explicit stack of subtrees still to visit.
    go []         !acc = acc
    go (top:rest) !acc = case top of
      Branch l r -> go (l:r:rest) acc
      Leaf 6     -> 1000
      Leaf x     -> go rest (acc+x)
```

This works nicely, but iterative implementations are often less natural, and harder to read and write, compared to recursive implementations. In this particular case, the iterative implementation needs a stack (the first argument to `go`). In the general case, the iterative implementation can be much more complex.

### Approach 4: CPS

This problem can be solved gracefully using CPS:

```haskell
leafSumCPS' :: Tree -> (Int -> r) -> r
leafSumCPS' tree k = go tree k
  where
    go (Leaf 6) _  = k  1000
    go (Leaf x) k' = k' x
    go (Branch l r) k' =
      go l $ \vl ->
      go r $ \vr ->
      k' (vl + vr)
```

Here, `k` is the outer continuation, which expects the final result; `k'` is the inner continuation, which expects the current local result. In the `Leaf 6` case, we pass 1000 to the outer continuation `k`. This effectively returns 1000 as the final result. Put another way, `k` here serves a similar purpose to the `return` keyword in many programming languages: calling the final continuation with a value is equivalent to returning that value.

The takeaway of this example is that, with CPS, it is much easier to jump to a certain part of the code, as long as we have the continuations we need at our disposal and invoke the right continuation at the right time. The next use case, backtracking, further demonstrates this point.
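To see the early exit in action, here is `leafSumCPS'` again in a self-contained block, together with a hypothetical helper `leafSumWithSix` that supplies `id` as the final continuation.

```haskell
data Tree = Branch Tree Tree | Leaf Int

leafSumCPS' :: Tree -> (Int -> r) -> r
leafSumCPS' tree k = go tree k
  where
    go (Leaf 6) _  = k 1000   -- jump straight to the outer continuation
    go (Leaf x) k' = k' x     -- normal "return" to the local continuation
    go (Branch l r) k' =
      go l $ \vl ->
      go r $ \vr ->
      k' (vl + vr)

-- Run the CPS traversal with the identity continuation to get a plain Int.
leafSumWithSix :: Tree -> Int
leafSumWithSix t = leafSumCPS' t id
```

For example, `leafSumWithSix (Branch (Leaf 6) undefined)` evaluates to `1000` without ever touching `undefined`: `go` visits the left subtree first, and on `Leaf 6` it calls `k` instead of the continuation that would have visited the right subtree.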

## Backtracking

Backtracking is a class of algorithms for finding solutions to computational problems by traversing a search space. It can be especially useful for NP-complete problems, where an asymptotically more efficient algorithm may not exist. Intuitively, in a tree- or graph-shaped search space, backtracking involves making a guess at a certain node, continuing the search, and upon hitting a wall, going back to that node and taking a different branch.

So when we explore the search space using backtracking, we need to jump to certain nodes at certain times. And for the same reason CPS is a good fit for non-local exits, it is also a good fit for backtracking. By calling the right continuation at the right time, you can jump around the search space however you want.

The following example, stolen from the DCPL book, illustrates CPS-based backtracking. It uses backtracking to solve the SAT problem, which is NP-complete. The following data type is used for boolean formulae:

```haskell
data BF = Var String | Not BF | And BF BF | Or BF BF
```

and the function we want to implement has the following type and expected behavior:

```haskell
solve :: BF -> Maybe (Map String Bool)

test1 = Var "a" `And` (Not (Var "b") `And` Var "c")
test2 = Var "a" `And` Not (Var "a")
test3 = Var "a" `Or` (Not (Var "b") `And` Var "c")
test4 = (Var "a" `Or` (Var "b" `Or` Var "c")) `And` (Not (Var "a") `And` Not (Var "b"))

-- Expected behavior, written with hspec's `shouldBe`
solve test1 `shouldBe` Just (Map.fromList [("a",True),("b",False),("c",True)])
solve test2 `shouldBe` Nothing
solve test3 `shouldBe` Just (Map.fromList [("a",True)])
solve test4 `shouldBe` Just (Map.fromList [("a",False),("b",False),("c",True)])
```

Note that the solution for `test3` only contains `("a",True)`, since the values of `b` and `c` don't matter.

Here's the implementation of `solve`:

```haskell
import Data.Map (Map, (!?))
import qualified Data.Map as Map
import Prelude hiding (fail, succ)

sat :: BF
    -> Map String Bool
    -> (Bool -> Map String Bool -> r -> r)
    -> r
    -> r
sat bf asst succ fail = case bf of
  Var v ->
    case asst !? v of
      Just b -> succ b asst fail
      Nothing ->
        let asstT = Map.insert v True asst
            asstF = Map.insert v False asst
            tryT = succ True asstT tryF
            tryF = succ False asstF fail
        in tryT

  Not bf' ->
    let succNot = succ . not
    in sat bf' asst succNot fail

  And l r ->
    let succAnd True asstAnd failAnd = sat r asstAnd succ failAnd
        succAnd False asstAnd failAnd = succ False asstAnd failAnd
    in sat l asst succAnd fail

  Or l r ->
    let succOr True asstOr failOr = succ True asstOr failOr
        succOr False asstOr failOr = sat r asstOr succ failOr
    in sat l asst succOr fail

solve :: BF -> Maybe (Map String Bool)
solve bf =
  sat
    bf
    Map.empty
    (\b asst fail -> if b then Just asst else fail)
    Nothing
```

The `Not`, `And` and `Or` cases are fairly routine. The key part of the implementation is the `Var` case: if the current variable is unassigned, we first try setting it to `True`, then continue the exploration by calling the `succ` continuation. If that eventually fails, the `tryF` continuation is called, which essentially goes back to the current node, sets the variable to `False` instead, and then continues with the same exploration by calling the same `succ` continuation. If that fails too, the variable in question can be neither `True` nor `False` given the choices made so far, so the search backtracks further by calling the `fail` continuation that was passed in. Once there are no earlier choices left to revisit, the top-level `fail` continuation is called, causing the entire calculation to return `Nothing`.

There are several continuations in this implementation, which may take some time to sort out. Specifically, `sat` takes a `succ` continuation of type `Bool -> Map String Bool -> r -> r`, a (global) `fail` continuation of type `r` (consider it a simplification of `() -> r`), and the `succ` continuation itself also takes a (local) `fail` continuation of type `r`. It is a good exercise to attempt to fully understand how this implementation works.
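One way to see the role of the continuations is to keep `sat` unchanged and only swap the top-level continuations. The sketch below repeats `sat` and `solve` so the block is self-contained, and adds a hypothetical `solveAll`: its top-level success continuation, instead of stopping, prepends the current assignment to whatever the `fail` continuation produces, and its top-level failure continuation is the empty list. Because calling `fail` resumes the search at the most recent choice point, this collects the satisfying (possibly partial) assignments found along the way.

```haskell
import Data.Map (Map, (!?))
import qualified Data.Map as Map
import Prelude hiding (fail, succ)

data BF = Var String | Not BF | And BF BF | Or BF BF

sat :: BF -> Map String Bool -> (Bool -> Map String Bool -> r -> r) -> r -> r
sat bf asst succ fail = case bf of
  Var v ->
    case asst !? v of
      Just b -> succ b asst fail
      Nothing ->
        let tryT = succ True (Map.insert v True asst) tryF
            tryF = succ False (Map.insert v False asst) fail
        in tryT
  Not bf' -> sat bf' asst (succ . not) fail
  And l r ->
    let succAnd True a f = sat r a succ f
        succAnd False a f = succ False a f
    in sat l asst succAnd fail
  Or l r ->
    let succOr True a f = succ True a f
        succOr False a f = sat r a succ f
    in sat l asst succOr fail

-- Stop at the first success.
solve :: BF -> Maybe (Map String Bool)
solve bf = sat bf Map.empty (\b asst fail -> if b then Just asst else fail) Nothing

-- Record each success, then resume the search by calling `fail`.
solveAll :: BF -> [Map String Bool]
solveAll bf = sat bf Map.empty (\b asst fail -> if b then asst : fail else fail) []
```

For example, `solveAll test3` yields `[fromList [("a",True)], fromList [("a",False),("b",False),("c",True)]]`: after `a = True` succeeds, the search backtracks into the `a = False` branch.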

Backtracking using a `succ` continuation and a `fail` continuation is the machinery behind the logict library.

## Changing the Associativity of `<>` and `>>=`

Left-associative `<>` is asymptotically expensive for some monoids, most notably the free monoid (i.e., cons-lists), and left-associative `>>=` is asymptotically expensive for some monads, most notably the free monad. One way to turn the tables and change the associativity of `<>` and `>>=` is by using continuations.

First of all, CPS computations can be composed like this:

```haskell
chainCPS :: ((a -> r) -> r) -> (a -> ((b -> r) -> r)) -> ((b -> r) -> r)
chainCPS s f = s . flip f
```

By the way, this suggests that `(_ -> r) -> r` is a monad (namely, the `Cont` monad). By turning `<>` and `>>=` into `chainCPS`, we can change left-associative operations into right-associative ones.

### Monoid

Every monoidal value `a` can be embedded into, and projected from a CPS computation of type `(() -> a) -> a`:

```haskell
monoidToCPS :: Monoid a => a -> (() -> a) -> a
monoidToCPS a = (a <>) . ($ ())

monoidFromCPS :: Monoid a => ((() -> a) -> a) -> a
monoidFromCPS cps = cps (const mempty)
```

Now, given a left-associative `<>`, we can turn each operand into a CPS computation, substitute `chainCPS` for `<>`, and finally, project the monoid out of the result, for instance

```haskell
-- left-associative <>
sumL = ([1,2,3] <> [4,5,6]) <> [7,8,9]

-- right-associative <>
sumR = monoidFromCPS $
  monoidToCPS [1,2,3]
    `chainCPS` (\_ -> monoidToCPS [4,5,6])
    `chainCPS` (\_ -> monoidToCPS [7,8,9])
```

If we expand `chainCPS` in the definition of `sumR`, and simplify, we get

```haskell
sumR = [1,2,3] <> ([4,5,6] <> [7,8,9])
```
``````

which is precisely the right-associative version of `sumL`. This is the mechanism behind DList.
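As a rough illustration of that mechanism, here is a minimal difference-list sketch. It is not the `dlist` package's API; `DList`, `fromListD` and `toListD` are defined here for illustration only. A list is represented by the function that prepends it, of type `[a] -> [a]` (the `a -> a` shape, which is isomorphic to `(() -> a) -> a`), and `<>` is plain function composition, so however the appends are nested, `toListD` ends up running `(++)` right-nested.

```haskell
-- A list represented as "prepend me to whatever comes next".
newtype DList a = DList ([a] -> [a])

fromListD :: [a] -> DList a
fromListD xs = DList (xs ++)

-- Supply the empty list as the final "rest of the list".
toListD :: DList a -> [a]
toListD (DList f) = f []

instance Semigroup (DList a) where
  -- Composition is associative, so left- and right-nested appends
  -- produce the same function.
  DList f <> DList g = DList (f . g)

instance Monoid (DList a) where
  mempty = DList id
```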

### Monad

In a completely analogous fashion, every monadic value `m a` can be embedded into, and projected from, a CPS computation of type `forall r. (a -> m r) -> m r`:

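```haskell
{-# LANGUAGE RankNTypes #-}

-- Embed: bind the monadic value to whatever continuation comes next.
monadToCPS :: Monad m => m a -> (forall r. (a -> m r) -> m r)
monadToCPS m k = m >>= k

-- Project: use `return` as the final continuation.
monadFromCPS :: Monad m => (forall r. (a -> m r) -> m r) -> m a
monadFromCPS cps = cps return
```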

Given a left-associative `>>=`, we can turn each operand into a CPS computation, substitute `chainCPS` for `>>=`, and project the monadic value out of the result, for instance

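```haskell
-- left-associative >>=
resL = [1,2,3] >>= (\x -> [x+1]) >>= (\y -> [y+2])

-- right-associative >>=
resR = monadFromCPS $
  monadToCPS [1,2,3]
    `chainCPS` (\x -> monadToCPS [x+1])
    `chainCPS` (\y -> monadToCPS [y+2])
```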

If we expand `chainCPS` in the definition of `resR`, and simplify, we get

```haskell
resR = [1,2,3] >>= (\x -> [x+1] >>= (\y -> [y+2]))
```
``````

which is precisely the right-associative version of `resL`. This is the mechanism behind Codensity.
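Here is a minimal sketch of the same idea packaged as a monad, assuming `RankNTypes`. The type `Cod` and the functions `liftCod` and `lowerCod` are hypothetical names playing the roles of `monadToCPS` and `monadFromCPS`; the `kan-extensions` package provides the real `Codensity`, with its own API. Because `>>=` on `Cod` only composes continuations, left-nested binds in `Cod` become right-nested binds in `m` when lowered.

```haskell
{-# LANGUAGE RankNTypes #-}

-- A computation that feeds an `a` to any continuation returning `m r`.
newtype Cod m a = Cod { runCod :: forall r. (a -> m r) -> m r }

instance Functor (Cod m) where
  fmap f (Cod c) = Cod (\k -> c (k . f))

instance Applicative (Cod m) where
  pure x = Cod (\k -> k x)
  Cod cf <*> Cod cx = Cod (\k -> cf (\f -> cx (k . f)))

instance Monad (Cod m) where
  -- No `>>=` of `m` happens here: we only nest continuations.
  Cod c >>= f = Cod (\k -> c (\a -> runCod (f a) k))

-- Embed, like monadToCPS.
liftCod :: Monad m => m a -> Cod m a
liftCod m = Cod (m >>=)

-- Project, like monadFromCPS.
lowerCod :: Monad m => Cod m a -> m a
lowerCod (Cod c) = c return
```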

Using a similar idea, we can express `foldl` in terms of `foldr` and vice versa. Take a look at the implementation and see if you can spot the continuations involved. Hint: `a -> a` and `(() -> a) -> a` are isomorphic types.
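One possible answer, as a sketch (these are not `base`'s actual definitions, and `foldlViaFoldr`/`foldrViaFoldl` are hypothetical names): each fold builds a function of type `b -> b` that expects the accumulator so far, i.e., a continuation in the `a -> a` form from the hint.

```haskell
-- foldr builds a chain of continuations; the initial accumulator `z`
-- is fed in at the end, and `id` is the final continuation.
foldlViaFoldr :: (b -> a -> b) -> b -> [a] -> b
foldlViaFoldr f z xs = foldr (\x k acc -> k (f acc x)) id xs z

-- Symmetrically, foldl builds the continuation chain for foldr.
foldrViaFoldl :: (a -> b -> b) -> b -> [a] -> b
foldrViaFoldl f z xs = foldl (\k x acc -> k (f x acc)) id xs z
```

For example, `foldlViaFoldr (-) 10 [1,2,3]` is `((10 - 1) - 2) - 3 = 4`, and `foldrViaFoldl (-) 0 [1,2,3]` is `1 - (2 - (3 - 0)) = 2`.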

## CPS as a Compiler IR

The CwC book gives an excellent, detailed explanation of the use of CPS in the Standard ML of New Jersey compiler. To very briefly summarize why CPS is useful as a compiler IR: since CPS makes a number of things explicit, such as return addresses, the call stack, intermediate results, and every aspect of the control flow, a CPS'ed program is usually much closer to the assembly code the compiler eventually generates than the original program is. For example, the variables in a CPS'ed program correspond quite closely to the registers of the target machine. This explicitness is often the opposite of what a programmer writing programs wants, but it is highly relevant to the compiler implementation.
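To make the "intermediate results become named variables" point a bit more concrete, here is a toy sketch. It is not how SML/NJ works, and every type and function in it is hypothetical. A converter written in CPS walks a nested arithmetic expression and hands each subexpression's value (a constant or a fresh temporary) to a continuation, producing a flat instruction list with a fixed evaluation order and a name for every intermediate result. The output is closer to A-normal form than to a full CPS IR, but the temporaries play the role registers would play in generated code.

```haskell
-- A tiny source language.
data Expr = Lit Int | Var String | Add Expr Expr | Mul Expr Expr

-- Operands of the flat program: constants or named temporaries.
data Atom = ALit Int | AVar String
  deriving (Eq, Show)

data Instr
  = Assign String Char Atom Atom  -- t := a op b
  | Return Atom
  deriving (Eq, Show)

-- `flatten e n k`: emit code for `e` using fresh names from t<n> on, then
-- pass the atom holding e's value and the next free index to `k`.
flatten :: Expr -> Int -> (Atom -> Int -> [Instr]) -> [Instr]
flatten (Lit i) n k = k (ALit i) n
flatten (Var v) n k = k (AVar v) n
flatten (Add a b) n k = binop '+' a b n k
flatten (Mul a b) n k = binop '*' a b n k

binop :: Char -> Expr -> Expr -> Int -> (Atom -> Int -> [Instr]) -> [Instr]
binop op a b n k =
  flatten a n $ \va n1 ->        -- left operand first
  flatten b n1 $ \vb n2 ->       -- then the right operand
    let t = "t" ++ show n2       -- name the intermediate result
    in Assign t op va vb : k (AVar t) (n2 + 1)

compile :: Expr -> [Instr]
compile e = flatten e 0 (\v _ -> [Return v])
```

For example, `compile (Add (Mul (Var "x") (Var "y")) (Lit 1))` yields `[Assign "t0" '*' (AVar "x") (AVar "y"), Assign "t1" '+' (AVar "t0") (ALit 1), Return (AVar "t1")]`, i.e., `t0 := x * y; t1 := t0 + 1; return t1`.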

# The Cont Monad and callCC

CPS computations can be composed as demonstrated by `chainCPS`. This gives rise to the `Cont` monad. `Cont r a` is simply a newtype wrapper around `(a -> r) -> r`.

The continuation `a -> r` is "hidden" by `Cont r a`, so how do we manipulate continuations using `Cont r a`, as we did in the Non-Local Exits and Backtracking examples? The answer is `callCC`. `callCC` stands for "call with current continuation", and it brings the current continuation, i.e., the continuation that expects the final result of the current computation, into scope.

If we adapt the `leafSumCPS'` function (from the non-local exits example) to use the `Cont` monad and `callCC`, it looks like this:

```haskell
import Control.Monad.Trans.Cont (Cont, callCC)

leafSumCPS'' :: Tree -> Cont r Int
leafSumCPS'' tree = callCC $ \k ->
  let
    go (Leaf 6) _  = k  1000
    go (Leaf x) k' = k' x
    go (Branch l r) k' =
      go l $ \vl ->
      go r $ \vr ->
      k' (vl + vr)
  in
    go tree k
```

The `go` function is exactly the same as before. We just need to add a `callCC` at the beginning of `leafSumCPS''`. More examples of `callCC` can be found in the Wikibooks entry.
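In `leafSumCPS''`, `go` never uses `>>=`, so the `Cont` monad is mostly along for the ride. For comparison, here is a sketch of `callCC` used as an early exit from ordinary `do` code, using `Control.Monad.Trans.Cont` from `transformers`; `sumUntilNegative` is a hypothetical example function.

```haskell
import Control.Monad.Trans.Cont (Cont, callCC, runCont)

-- Sum a list, but escape with the first negative element, if any.
-- `exit` is the current continuation captured by callCC.
sumUntilNegative :: [Int] -> Cont r (Either Int Int)
sumUntilNegative xs = callCC $ \exit -> do
  let go acc [] = return acc
      go acc (y:ys)
        | y < 0 = exit (Left y)      -- discards the rest of the do block
        | otherwise = go (acc + y) ys
  total <- go 0 xs
  return (Right total)
```

For example, `runCont (sumUntilNegative [1,2,3]) id` is `Right 6`, while `runCont (sumUntilNegative [1,-5,undefined]) id` is `Left (-5)`: calling `exit` abandons the rest of the computation, so the remaining elements are never inspected.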

The implementation of `callCC`, after stripping out the `cont` and `runCont` wrappers, is

```haskell
callCC :: ((a -> (b -> r) -> r) -> (a -> r) -> r)
       -> (a -> r) -> r
callCC f h = f (\a _ -> h a) h
```

It is a good exercise to try to decipher what the type means and what the implementation does. Hint: the function `\a _ -> h a` passed to `f` is the "exit function" (called `k` in `leafSumCPS''`). Note its ignored argument (`_`). Its type is `b -> r`, and it is a continuation that represents the rest of the computation, given the result of calling `k` (which can be any type `b`; it doesn't matter). The entire rest of the computation is ignored, so we effectively return whatever value is passed to `k`.

To make it even clearer, think about what happens if we never call the "exit function" `k`, versus if we do. For example, if we never call `k`, which is the first argument to `f`, then it is as if `f` ignored its first argument, say `f = const res` where `res :: (a -> r) -> r`. In this case `res` is also what `callCC f` returns, i.e., it returns the same result as if we hadn't used `callCC` in the first place.
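Both cases can be checked directly against the wrapper-free `callCC` above. In this sketch, `CPS`, `callCCRaw`, `neverExits` and `exitsEarly` are hypothetical names; `callCCRaw` has exactly the same definition as the stripped-down `callCC`.

```haskell
-- A raw CPS computation producing an `a`.
type CPS r a = (a -> r) -> r

-- Same definition as the stripped-down callCC above.
callCCRaw :: ((a -> CPS r b) -> CPS r a) -> CPS r a
callCCRaw f h = f (\a _ -> h a) h

-- Case 1: the exit function is never called, so the result is the same
-- as running `\k -> k 1` without callCC.
neverExits :: CPS r Int
neverExits = callCCRaw (\_exit k -> k 1)

-- Case 2: the exit function is called with 42. The continuation handed to
-- it (the "rest of the computation", which would have produced 0) is ignored.
exitsEarly :: CPS r Int
exitsEarly = callCCRaw (\exit k -> exit 42 (\_ -> k 0))
```

So `neverExits id` is `1`, and `exitsEarly id` is `42` rather than `0`.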

# Conclusions

This post is a gentle introduction to CPS and some of its use cases. Once again, CPS is obscure, labyrinthine, and counter-intuitive in many ways. Overusing it is hardly ever a compelling idea (some say it's like abusing `goto` statements in many programming languages). It can be highly rewarding to understand CPS, but it's probably advisable to resist the urge to use it except for the killer use cases.

# Acknowledgement

This post is adapted from a presentation I made in one of Formation's Haskell study group sessions. I received helpful feedback on both the presentation and a draft of this post from coworker Ian-Woo Kim.