This is yet another introduction to CPS (continuation-passing style), with a focus on the use cases of CPS. This is where I find some other CPS articles lacking. For example, when I first read the CPS entry on Wikibooks, I understood what CPS is, but it wasn’t quite clear to me in what ways CPS can profitably be used in practice. Besides, CPS is a fairly obscure, convoluted and counter-intuitive thing, so I reckon another post that explains it is always beneficial.
I’m aiming to write this post in a way that would have helped myself understand CPS better when I first learned it. It can be considered a supplement to the above Wikibooks entry.
I’ll first go over CPS basics, then discuss some use cases of CPS, and only at the end will I briefly touch upon the Cont monad and callCC. I personally find this flow to be more natural and digestible, since it keeps things more concrete and “bare metal” for most of the post, without the extra layer of abstraction.
Besides the CPS entry on Wikibooks, some examples and ideas in this post come from two excellent books: Design Concepts in Programming Languages (DCPL) by Franklyn Turbak and David Gifford with Mark Sheldon, and Compiling with Continuations (CwC) by Andrew W. Appel.
Normally a function takes some input, returns some value, and then the caller of the function continues to do something with that returned value. In CPS, instead of returning some value, the function would take a continuation function, which represents what the caller would do with the returned value, as a first-class parameter, and calls the continuation where the normal function returns the result.
So, a function of type a -> b would become a -> (b -> r) -> r in CPS, where b -> r is the continuation. These two types are isomorphic due to the isomorphism between b and forall r. (b -> r) -> r:
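    {-# LANGUAGE RankNTypes #-}

    -- Written with explicit arguments: the point-free forms flip ($) and ($ id)
    -- may be rejected by GHC 9.x because of simplified subsumption.
    toCPS :: a -> (forall r. (a -> r) -> r)
    toCPS a k = k a

    fromCPS :: (forall r. (a -> r) -> r) -> a
    fromCPS cps = cps id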
For a simple example, take a look at the pythagoras and pythagoras_cps functions in the Wikibooks entry.
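For readers who do not have the Wikibooks entry at hand, here is a minimal sketch in the spirit of that example. It is reconstructed for this post rather than quoted, and the helper names add_cps and square_cps are assumptions of the sketch:

```haskell
-- Direct style: each function returns its result to the caller.
add :: Int -> Int -> Int
add x y = x + y

square :: Int -> Int
square x = x * x

pythagoras :: Int -> Int -> Int
pythagoras x y = add (square x) (square y)

-- CPS: each function hands its result to a continuation instead of returning it.
add_cps :: Int -> Int -> (Int -> r) -> r
add_cps x y k = k (add x y)

square_cps :: Int -> (Int -> r) -> r
square_cps x k = k (square x)

-- The intermediate results are named (xSquared, ySquared) and the order
-- in which they are computed is spelled out by the nesting.
pythagoras_cps :: Int -> Int -> (Int -> r) -> r
pythagoras_cps x y k =
  square_cps x $ \xSquared ->
    square_cps y $ \ySquared ->
      add_cps xSquared ySquared k
```

Passing id as the final continuation recovers the ordinary result: pythagoras_cps 3 4 id evaluates to 25, the same as pythagoras 3 4.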
This is pretty much all there is to say about what CPS is. At this point, a pressing question would be: why is CPS useful? After all, it makes the types longer, makes the code more obscure, and doesn’t seem to be able to do anything regular functions can’t do.
To answer this question, recognize that CPS has the following characteristics:
1. Every function call is a tail call. You don’t ever call a function, then do something with the returned value. Instead, that “something” is passed to the function as the continuation. Functions don’t return, but pass their results to the continuations. Therefore there are no non-tail calls.
2. CPS makes a number of concepts more explicit, including return address, call stack, evaluation order, intermediate results, etc.
When a function is called, a return address needs to be made available, which is where the callee should return to when it is done. For regular functions, return addresses are transparently handled by the subroutine mechanism: a return address is pushed onto the call stack just before entering a function. This is not something programmers implementing the functions need to care about. In CPS, continuations can be viewed as explicit return addresses: a function “returns” by calling its continuation.
It is similar for the call stack. Take a look at the pythagoras_cps example in the Wikibooks entry. The nested continuations can be regarded as representing the call stack. Each time a function is called, a new continuation is created which corresponds to the new stack frame for the function call.
Evaluation order and intermediate results also become explicit. In the pythagoras example in the Wikibooks entry, it is unclear which is evaluated first, square x or square y. Indeed, the compiler is free to do either. And the results of square x and square y are not named. In pythagoras_cps, the evaluation order and the names of intermediate results are both explicitly spelled out.
3. You have more power when you have continuations as first-class functions. When implementing a regular function, you can only return the result at the end of the function. But when you have continuations explicitly passed to you as first-class functions, you can get more creative, for instance calling the continuation of an enclosing function, passing the continuations to other functions, or storing them in some data structure and using them at a later time.
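As a small illustration of the third point, here is a sketch in which a function receives two continuations and forwards them to other functions. The names divCPS and divTwice, and the choice of a String error message, are assumptions made for this example:

```haskell
-- Division with two continuations: one for failure, one for success.
divCPS :: Int -> Int -> (String -> r) -> (Int -> r) -> r
divCPS _ 0 onError _ = onError "division by zero"
divCPS x y _ onOk = onOk (x `div` y)

-- Computes (a / b) / c. The caller's onError continuation is passed
-- unchanged to both divisions, so a failure in either one jumps
-- straight back to the caller and skips whatever was still pending.
divTwice :: Int -> Int -> Int -> (String -> r) -> (Int -> r) -> r
divTwice a b c onError onOk =
  divCPS a b onError $ \q ->
    divCPS q c onError onOk
```

For instance, divTwice 100 5 2 Left Right evaluates to Right 10, while divTwice 100 0 2 Left Right evaluates to Left "division by zero" without ever attempting the second division.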
Because of these characteristics, there are a number of use cases for which CPS is a good fit. Next I’ll discuss some of these use cases in more detail.
CPS Use Cases
Since all function calls in CPS are tail calls, one would naturally wonder if CPS can be used to convert a stack unsafe function into a stack safe one, and it can. Here’s a plain-vanilla, stack unsafe factorial function, and the corresponding CPS version:
    -- Non-CPS, stack unsafe
    fac :: Integer -> Integer
    fac 0 = 1
    fac n = n * fac (n-1)

    -- CPS
    facCPS :: Integer -> (Integer -> r) -> r
    facCPS 0 k = k 1
    facCPS n k = facCPS (n-1) $ k . (*n)
Two things about facCPS are worth pointing out:
- I didn’t bother to convert the - and * operators into CPS. To do so, you’d need to create functions subCPS, multCPS :: Int -> Int -> (Int -> r) -> r, similar to the addCPS function in the pythagoras example in the Wikibooks entry. I’m simply treating - and * as primitive operators rather than regular functions, because to implement subCPS and multCPS you would still need to use - and * as primitive operators, which kinda defeats the purpose.
- This facCPS implementation is in fact still not stack safe, because of Haskell’s laziness. The k . (*n) will keep building up a large thunk before evaluating it, and evaluating the thunk may cause stack overflow for large n. To make it stack safe we can replace k . (*n) with \x -> x `seq` k (x * n). But this is a separate issue, so for simplicity I’m ignoring this laziness problem and just pretending that k . (*n) is fine.
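For completeness, the strict variant described in the second point can be written out as follows. The name facCPSStrict is an assumption of this sketch, and the lazy facCPS is repeated only so that the block stands on its own:

```haskell
-- Lazy version from the text: the multiplications pile up as a thunk.
facCPS :: Integer -> (Integer -> r) -> r
facCPS 0 k = k 1
facCPS n k = facCPS (n - 1) $ k . (* n)

-- Strict variant: each continuation forces the product it receives
-- before multiplying and passing the result on, so no long chain of
-- unevaluated multiplications is built up.
facCPSStrict :: Integer -> (Integer -> r) -> r
facCPSStrict 0 k = k 1
facCPSStrict n k = facCPSStrict (n - 1) $ \x -> x `seq` k (x * n)
```

Both versions compute the same values (for example, facCPSStrict 5 id is 120); they differ only in when the multiplications are evaluated. Like fac, neither terminates for negative input.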
flip facCPS id is the factorial function, and it is stack safe as long as the programming language we use supports tail-call optimization (TCO), which Haskell does.
It is still stack unsafe in a non-TCO language. For example,
facCPS, when translated into Scala, is not stack safe,
because Scala only optimizes simple self-recursions, but doesn’t have TCO in general.
In other words, Scala is not a properly tail recursive language. In particular, function
composition in Scala is not stack safe: if you compose the identity function with itself a million times,
the result is a function that causes stack overflow for all input. In Scala your options are trampolining the
recursive function, or implementing it using tail recursion or iteration.
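To give an idea of what trampolining means, here is a minimal sketch. It is written in Haskell only to keep the examples in this post in one language; Haskell itself does not need it. The Trampoline type and the function names are assumptions of the sketch, not a library API:

```haskell
-- A computation is either finished, or a suspended step to resume later.
data Trampoline a
  = Done a
  | More (() -> Trampoline a)

-- The driver loop: a simple self-recursive function that keeps
-- resuming suspended steps until a final value is reached.
runTrampoline :: Trampoline a -> a
runTrampoline (Done x) = x
runTrampoline (More step) = runTrampoline (step ())

-- CPS factorial in which every call is handed back to the driver
-- as a suspended step instead of being made directly.
facTrampolined :: Integer -> (Integer -> Trampoline r) -> Trampoline r
facTrampolined 0 k = k 1
facTrampolined n k =
  More $ \_ -> facTrampolined (n - 1) $ \x -> More $ \_ -> k (x * n)
```

The computation is started with Done as the final continuation: runTrampoline (facTrampolined 5 Done) evaluates to 120. In a language without general TCO, only the driver loop needs to be a plain self-recursion (or a while loop).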
It is also worth pointing out that, even though facCPS is stack safe in Haskell, it is not the best way to implement the factorial function. The following iterative implementation is much more efficient:
    {-# LANGUAGE BangPatterns #-}

    facIterative :: Integer -> Integer
    facIterative = go 1
      where
        go !acc 0 = acc
        go !acc n = go (acc*n) (n-1)
(This is called an iterative implementation because it is a simple self-recursive implementation: it works very much like a while loop in an imperative language, and can easily be translated into one.)
facIterative is more efficient mainly because the implementation is simple enough that GHC, with -O2, is able to compile it into a simple loop. facIterative and flip facCPS id have similar performance without -O2, but with -O2, the former is much faster and performs much less memory allocation.
A non-local exit refers to a “jump” in which a nested function doesn’t return to its immediate caller, but returns to some outer level of control. This can be achieved using goto in many programming languages. goto is generally avoided, but the same effect can often be achieved with other control-flow constructs such as return. Haskell, of course, has none of these things; instead, CPS is a nice way to implement non-local exits.
As an example, consider the following leaf-valued, non-empty tree data type:
data Tree = Branch Tree Tree | Leaf Int
Suppose we are tasked to implement a function that sums up the leaf values of a tree. This is fairly straightforward:
    leafSum :: Tree -> Int
    leafSum (Leaf x) = x
    leafSum (Branch l r) = leafSum l + leafSum r
The CPS version is:
    leafSumCPS :: Tree -> (Int -> r) -> r
    leafSumCPS (Leaf x) k = k x
    leafSumCPS (Branch l r) k =
      leafSumCPS l $ \vl ->
        leafSumCPS r $ \vr ->
          k (vl + vr)
Now, suppose a (completely arbitrary and useless, admittedly) requirement is added to our task: if any leaf’s value is 6, return 1000. Otherwise, still return the sum of all leaf values.
How do we modify our programs to accommodate this new requirement? It’s not completely straightforward any more. We have a few options:
Approach 1: Traverse Twice
The naive approach is to traverse the tree twice. In the first traversal we check whether there is a Leaf 6, and if not, we traverse it again to compute the leaf sum. The code is omitted.
This works, but the problem is obvious: the tree is traversed twice.
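For reference, the two-pass approach might look like the following sketch. The names hasLeafSix and leafSumTwoPass are made up here, and Tree and leafSum are repeated so that the block is self-contained:

```haskell
data Tree = Branch Tree Tree | Leaf Int

-- First traversal: does any leaf hold the value 6?
hasLeafSix :: Tree -> Bool
hasLeafSix (Leaf x) = x == 6
hasLeafSix (Branch l r) = hasLeafSix l || hasLeafSix r

-- Second traversal: the plain leaf sum from earlier.
leafSum :: Tree -> Int
leafSum (Leaf x) = x
leafSum (Branch l r) = leafSum l + leafSum r

-- Combine the two passes: check first, sum only if no 6 was found.
leafSumTwoPass :: Tree -> Int
leafSumTwoPass tree
  | hasLeafSix tree = 1000
  | otherwise = leafSum tree
```

The first pass stops at the first 6 it finds thanks to the laziness of ||, but a tree without any 6 is still walked in full twice.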
Approach 2: Fuse the Two Traversals
Instead of traversing the tree twice, we can fuse the two traversals by making the inner function go return a pair. The first component of the pair indicates whether we have found a Leaf 6, and the second component is the leaf sum of the subtree:
    leafSumFused :: Tree -> Int
    leafSumFused = snd . go False
      where
        go True _ = (True, 1000)
        go False (Leaf 6) = (True, 1000)
        go False (Leaf x) = (False, x)
        go False (Branch l r) =
          let (bl, resl) = go False l
              (br, resr) = go False r
           in if bl || br then (True, 1000) else (False, resl + resr)
This reduces the number of traversals to one, but it’s kinda ugly, and is definitely not the most readable and extensible code one could ever wish to see.
Approach 3: Implement Tree Traversal Iteratively
Instead of a recursive tree traversal, we can write an iterative one. Unlike the factorial function, where the iterative implementation is straightforward, an iterative tree traversal requires explicitly managing a stack. The implementation looks like this:
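    {-# LANGUAGE BangPatterns #-}

    leafSumIterative :: Tree -> Int
    leafSumIterative tree = go [tree] 0
      where
        -- The first argument is the explicit stack of subtrees still to visit.
        go [] !acc = acc
        go (top:rest) !acc = case top of
          Branch l r -> go (l:r:rest) acc
          Leaf 6 -> 1000
          Leaf x -> go rest (acc+x)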
This works nicely, but iterative implementations are often less natural, and harder to read and write,
compared to recursive implementations.
In this particular case, the iterative implementation needs a stack (the first argument to
go). In the general case,
the iterative implementation can be much more complex.
Approach 4: CPS
This problem can be solved gracefully using CPS:
    leafSumCPS' :: Tree -> (Int -> r) -> r
    leafSumCPS' tree k = go tree k
      where
        go (Leaf 6) _ = k 1000
        go (Leaf x) k' = k' x
        go (Branch l r) k' = go l $ \vl -> go r $ \vr -> k' (vl + vr)
k is the outer continuation, which expects the final result; k' is the inner continuation, which expects the current local result. In the Leaf 6 case, we pass 1000 to the outer continuation k. This effectively returns 1000 as the final result.
Put another way, k here serves a similar purpose to the return keyword in many programming languages. Calling the final continuation with a value is equivalent to returning that value.
The takeaway of this example is that with CPS it is much easier to jump to a certain part of the code, as long as we have the continuations we need at our disposal and invoke the right continuation at the right time. The next use case, backtracking, further demonstrates this point.
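The same pattern applies to other traversals. As a second, hypothetical illustration (productCPS is not from any of the sources mentioned in this post), here is a list product that bails out as soon as it sees a 0:

```haskell
-- Multiply the elements of a list. If a 0 is encountered, call the
-- outer continuation k directly, which discards every multiplication
-- that was still pending in the inner continuation.
productCPS :: [Int] -> (Int -> r) -> r
productCPS xs k = go xs k
  where
    go [] k' = k' 1
    go (0 : _) _ = k 0
    go (y : ys) k' = go ys $ \p -> k' (y * p)
```

productCPS [1,2,3,4] id evaluates to 24, whereas productCPS [2,0,5] id evaluates to 0 without performing any multiplication and without looking at the elements after the 0.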
Backtracking is a class of algorithms for finding solutions to computational problems by traversing a search space. It can be especially useful for NP-complete problems, where an asymptotically more efficient algorithm may not exist. Intuitively, in a tree- or graph-shaped search space, backtracking involves making a guess at a certain node, continuing the search, and upon hitting a wall, going back to that node and taking a different branch.
So when we explore the search space using backtracking, we need to jump to certain nodes at certain times. And for the same reason CPS is a good fit for non-local exits, it is also a good fit for backtracking. By calling the right continuation at the right time, you can jump around the search space however you want.
The following example, stolen from the DCPL book, illustrates CPS-based backtracking. It uses backtracking to solve the SAT problem, which is NP-complete. The following data type is used for boolean formulae:
data BF = Var String | Not BF | And BF BF | Or BF BF
and the function we want to implement has the following type and expected behavior:
    solve :: BF -> Maybe (Map String Bool)

    test1 = Var "a" `And` (Not (Var "b") `And` Var "c")
    test2 = Var "a" `And` Not (Var "a")
    test3 = Var "a" `Or` (Not (Var "b") `And` Var "c")
    test4 = (Var "a" `Or` (Var "b" `Or` Var "c")) `And` (Not (Var "a") `And` Not (Var "b"))

    solve test1 `shouldBe` Just (Map.fromList [("a",True),("b",False),("c",True)])
    solve test2 `shouldBe` Nothing
    solve test3 `shouldBe` Just (Map.fromList [("a",True)])
    solve test4 `shouldBe` Just (Map.fromList [("a",False),("b",False),("c",True)])
Note that the solution for test3 only contains ("a",True) since the values of b and c don’t matter once a is True.
Here’s the implementation of solve:
    import Data.Map (Map, (!?))
    import qualified Data.Map as Map
    import Prelude hiding (fail, succ)

    sat :: BF -> Map String Bool -> (Bool -> Map String Bool -> r -> r) -> r -> r
    sat bf asst succ fail = case bf of
      Var v -> case asst !? v of
        Just b -> succ b asst fail
        Nothing ->
          let asstT = Map.insert v True asst
              asstF = Map.insert v False asst
              tryT = succ True asstT tryF
              tryF = succ False asstF fail
           in tryT
      Not bf' ->
        let succNot = succ . not
         in sat bf' asst succNot fail
      And l r ->
        let succAnd True asstAnd failAnd = sat r asstAnd succ failAnd
            succAnd False asstAnd failAnd = succ False asstAnd failAnd
         in sat l asst succAnd fail
      Or l r ->
        let succOr True asstOr failOr = succ True asstOr failOr
            succOr False asstOr failOr = sat r asstOr succ failOr
         in sat l asst succOr fail

    solve :: BF -> Maybe (Map String Bool)
    solve bf = sat bf Map.empty (\b asst fail -> if b then Just asst else fail) Nothing
The Not, And and Or cases are fairly routine. The key of the implementation is the Var case, where if the current variable is unassigned, we first try setting it to True, then continue with the exploration by calling the succ continuation. If it eventually fails, the tryF continuation will be called, which essentially goes back to the current node, resets the variable to False, and then continues with the same exploration by calling the same succ continuation. If it fails again, then it means the variable in question can be neither True nor False, suggesting that there’s no solution to the Boolean formula. The fail continuation is then called, causing the entire calculation to fail.
There are several continuations in this implementation, which may take some time to sort out. Specifically, sat takes a succ continuation of type Bool -> Map String Bool -> r -> r, a (global) fail continuation of type r (consider it a simplification of () -> r), and the succ continuation itself also takes a fail continuation of type r. It is a good exercise to attempt to fully understand how this implementation works.
Backtracking using a succ continuation and a fail continuation is the machinery behind the logict library.
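One nice consequence of this design is that the caller decides what "success" and "failure" mean. As a sketch (solveAll is a name invented here, and sat is repeated from above so that the block compiles on its own), choosing a list for r and continuing to backtrack after each success enumerates every satisfying assignment the search finds, instead of stopping at the first one:

```haskell
import Data.Map (Map, (!?))
import qualified Data.Map as Map
import Prelude hiding (fail, succ)

data BF = Var String | Not BF | And BF BF | Or BF BF

-- Same sat as in the text.
sat :: BF -> Map String Bool -> (Bool -> Map String Bool -> r -> r) -> r -> r
sat bf asst succ fail = case bf of
  Var v -> case asst !? v of
    Just b -> succ b asst fail
    Nothing ->
      let asstT = Map.insert v True asst
          asstF = Map.insert v False asst
          tryT = succ True asstT tryF
          tryF = succ False asstF fail
       in tryT
  Not bf' -> sat bf' asst (succ . not) fail
  And l r ->
    let succAnd True asstAnd failAnd = sat r asstAnd succ failAnd
        succAnd False asstAnd failAnd = succ False asstAnd failAnd
     in sat l asst succAnd fail
  Or l r ->
    let succOr True asstOr failOr = succ True asstOr failOr
        succOr False asstOr failOr = sat r asstOr succ failOr
     in sat l asst succOr fail

-- On success, record the assignment and then keep going by running
-- the failure continuation, which resumes at the latest choice point.
solveAll :: BF -> [Map String Bool]
solveAll bf =
  sat bf Map.empty (\b asst fail -> if b then asst : fail else fail) []
```

For test3 from above, solveAll returns two (partial) assignments: first a = True, then a = False, b = False, c = True. For the unsatisfiable test2 it returns the empty list. Because the list is produced lazily, taking only its first element does no more searching than solve does.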
Associativity of Monoid/Monad Operations
Left-associative <> is asymptotically expensive for some monoids, most notably the free monoid (i.e., cons-lists), and left-associative >>= is asymptotically expensive for some monads, most notably the free monad. One way to turn the tables and change the associativity of <> and >>= is by using continuations.
First of all, CPS computations can be composed like this:
    chainCPS :: ((a -> r) -> r) -> (a -> ((b -> r) -> r)) -> ((b -> r) -> r)
    chainCPS s f = s . flip f
By the way, this suggests that ((_ -> r) -> r) is a monad (i.e., the Cont monad). By turning <> and >>= into chainCPS, we can change left-associative operations into right-associative ones.
Every monoidal value a can be embedded into, and projected from, a CPS computation of type (() -> a) -> a:
    monoidToCPS :: Monoid a => a -> (() -> a) -> a
    monoidToCPS a = (a <>) . ($ ())

    monoidFromCPS :: Monoid a => ((() -> a) -> a) -> a
    monoidFromCPS cps = cps (const mempty)
Now, given a left-associative <>, we can turn each operand into a CPS computation, substitute chainCPS for <>, and finally project the monoid out of the result, for instance
    -- left-associative <>
    sumL = ([1,2,3] <> [4,5,6]) <> [7,8,9]

    -- right-associative <>
    sumR = monoidFromCPS $
      monoidToCPS [1,2,3]
        `chainCPS` (\_ -> monoidToCPS [4,5,6])
        `chainCPS` (\_ -> monoidToCPS [7,8,9])
If we expand chainCPS in the definition of sumR, and simplify, we get

    sumR = [1,2,3] <> ([4,5,6] <> [7,8,9])

which is precisely the right-associative version of sumL. This is the mechanism behind difference lists.
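The expansion is short enough to spell out. The sketch below repeats the definitions from above, adds a type signature to sumR (an assumption, to pin the element type), and records the simplification steps as comments:

```haskell
chainCPS :: ((a -> r) -> r) -> (a -> ((b -> r) -> r)) -> ((b -> r) -> r)
chainCPS s f = s . flip f

monoidToCPS :: Monoid a => a -> (() -> a) -> a
monoidToCPS a = (a <>) . ($ ())

monoidFromCPS :: Monoid a => ((() -> a) -> a) -> a
monoidFromCPS cps = cps (const mempty)

sumR :: [Int]
sumR = monoidFromCPS $
  monoidToCPS [1,2,3]
    `chainCPS` (\_ -> monoidToCPS [4,5,6])
    `chainCPS` (\_ -> monoidToCPS [7,8,9])

-- Pointwise, the definitions read:
--   chainCPS s f k      = s (\a -> f a k)
--   monoidToCPS a k     = a <> k ()
--   monoidFromCPS cps   = cps (const mempty)
--
-- Writing xs, ys, zs for the three lists:
--   sumR
--     = ((monoidToCPS xs `chainCPS` fy) `chainCPS` fz) (const mempty)
--     = (monoidToCPS xs `chainCPS` fy) (\_ -> zs <> mempty)
--     = monoidToCPS xs (\_ -> ys <> (zs <> mempty))
--     = xs <> (ys <> (zs <> mempty))
-- where fy = \_ -> monoidToCPS ys and fz = \_ -> monoidToCPS zs.
```

The trailing mempty comes from monoidFromCPS and disappears by the monoid identity law, leaving the right-nested form shown above.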
In a completely analogous fashion, every monadic value m a can be embedded into, and projected from, a CPS computation of type forall r. (a -> m r) -> m r:
    {-# LANGUAGE RankNTypes #-}

    monadToCPS :: Monad m => m a -> (forall r. (a -> m r) -> m r)
    monadToCPS ma = (ma >>=)

    monadFromCPS :: Monad m => (forall r. (a -> m r) -> m r) -> m a
    monadFromCPS cps = cps pure
Given a left-associative >>=, we can turn each operand into a CPS computation, substitute chainCPS for >>=, and project the monadic value out of the result, for instance
    -- left-associative >>=
    resL = [1,2,3] >>= (\x -> [x+1]) >>= (\y -> [y+2])

    -- right-associative >>=
    resR = monadFromCPS $
      monadToCPS [1,2,3]
        `chainCPS` (\x -> monadToCPS [x+1])
        `chainCPS` (\y -> monadToCPS [y+2])
If we expand chainCPS in the definition of resR, and simplify, we get

    resR = [1,2,3] >>= (\x -> [x+1] >>= (\y -> [y+2]))

which is precisely the right-associative version of resL. This is the mechanism behind the Codensity monad.
Using a similar idea, we can express foldl in terms of foldr and vice versa. Take a look at the implementation and see if you can spot the continuations involved.
Note that a -> a and (() -> a) -> a are isomorphic types.
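For concreteness, here is one common way to write foldl using foldr, together with a pair of functions witnessing the isomorphism just mentioned. The names foldlViaFoldr, toUnitCPS and fromUnitCPS are assumptions of this sketch, not library functions:

```haskell
-- foldl expressed with foldr: the fold builds a function from the
-- accumulator to the final result. In step, k is the continuation
-- "process the rest of the list with this accumulator".
foldlViaFoldr :: (b -> a -> b) -> b -> [a] -> b
foldlViaFoldr f z xs = foldr step id xs z
  where
    step x k acc = k (f acc x)

-- The two directions of the isomorphism between a -> a and (() -> a) -> a.
toUnitCPS :: (a -> a) -> ((() -> a) -> a)
toUnitCPS f k = f (k ())

fromUnitCPS :: ((() -> a) -> a) -> (a -> a)
fromUnitCPS g x = g (const x)
```

As a quick check, foldlViaFoldr (-) 10 [1,2,3] gives 4, the same as foldl (-) 10 [1,2,3], and fromUnitCPS (toUnitCPS (+1)) behaves like (+1).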
CPS as a Compiler IR
The CwC book gives an excellent, detailed explanation of the use of CPS in the Standard ML of New Jersey compiler. To very briefly summarize why CPS is useful as a compiler IR: since CPS makes a number of things more explicit, such as return address, call stack, intermediate results, as well as every aspect of the control flow, a CPS’ed program is usually much closer to the assembly code that the compiler eventually generates, compared to the original program. For example, the variables in a CPS’ed program correspond quite closely to the registers of the target machine. This explicitness is often the opposite of what a programmer writing programs wants, but it is highly relevant to the compiler implementation.
The Cont Monad and callCC
CPS computations can be composed as demonstrated by chainCPS. This gives rise to the Cont monad. Cont r a is simply a newtype wrapper around (a -> r) -> r.
The continuation a -> r is “hidden” by Cont r a, so how do we manipulate the continuations in Cont r a like what we did in the examples in Non-Local Exit and Backtracking? The answer is callCC. callCC stands for “call with current continuation”, and it brings the current continuation, i.e., the continuation that expects the final result of the current computation, into scope.
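To make the newtype wrapper concrete, here is a hand-rolled sketch of such a monad. It is deliberately named ContSketch to make clear that it is not the library's Cont type; the instances are written out only to show that >>= has the same shape as chainCPS:

```haskell
newtype ContSketch r a = ContSketch { runContSketch :: (a -> r) -> r }

instance Functor (ContSketch r) where
  fmap f (ContSketch c) = ContSketch $ \k -> c (k . f)

instance Applicative (ContSketch r) where
  pure a = ContSketch ($ a)
  ContSketch cf <*> ContSketch ca =
    ContSketch $ \k -> cf (\f -> ca (k . f))

instance Monad (ContSketch r) where
  -- Same shape as chainCPS: run c, and feed its result to f
  -- together with the continuation k of the whole computation.
  ContSketch c >>= f =
    ContSketch $ \k -> c (\a -> runContSketch (f a) k)

-- With the monad instance, the continuations no longer appear in the code.
squareC :: Int -> ContSketch r Int
squareC x = pure (x * x)

pythagorasC :: Int -> Int -> ContSketch r Int
pythagorasC x y = do
  xSquared <- squareC x
  ySquared <- squareC y
  pure (xSquared + ySquared)
```

Running the computation means supplying the final continuation: runContSketch (pythagorasC 3 4) id evaluates to 25.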
If we adapt the leafSumCPS' function (in the non-local exits example) to use the Cont monad and callCC, it would look like this:
    leafSumCPS'' :: Tree -> Cont r Int
    leafSumCPS'' tree = callCC $ \k ->
      let go (Leaf 6) _ = k 1000
          go (Leaf x) k' = k' x
          go (Branch l r) k' = go l $ \vl -> go r $ \vr -> k' (vl + vr)
       in go tree k
The go function is exactly the same as before. We just need to add a callCC at the beginning of leafSumCPS''. More examples of callCC can be found in the Wikibooks entry.
The implementation of callCC, after stripping out the runCont wrappers, is

    callCC :: ((a -> (b -> r) -> r) -> (a -> r) -> r) -> (a -> r) -> r
    callCC f h = f (\a _ -> h a) h
It is a good exercise to try to decipher what the type means and what the implementation does. Hint: note the ignored argument (_) in the implementation. Its type is b -> r, and it is a continuation that basically represents the rest of the computation given the result of calling k (which can be any type b; it doesn’t matter). The entire rest of the computation is ignored, and so we are effectively returning whatever value is passed to k.
To make it clearer, think about what happens if we never call the “exit function” k vs. if we do call k. For example, if we never call k, which is the first argument to f, then it is the same as if f ignores its first argument, say f = const res where res :: (a -> r) -> r. In this case res is also what callCC f returns, i.e., it returns the same result as if we didn’t use callCC in the first place.
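Both situations can be tried out with the stripped-down callCC from above. In this sketch it is renamed callCCRaw to avoid clashing with the library function, and noExit and earlyExit are hypothetical examples:

```haskell
callCCRaw :: ((a -> (b -> r) -> r) -> (a -> r) -> r) -> (a -> r) -> r
callCCRaw f h = f (\a _ -> h a) h

-- The exit function is never called: the result is the same as if
-- callCCRaw had not been used at all.
noExit :: (Int -> r) -> r
noExit = callCCRaw $ \_exit h -> h 10

-- The exit function is called with 42. The continuation passed to it,
-- (\_ -> h 0), stands for "the rest of the computation" and is
-- discarded, so 0 is never produced.
earlyExit :: (Int -> r) -> r
earlyExit = callCCRaw $ \exit h -> exit 42 (\_ -> h 0)
```

With id as the final continuation, noExit id is 10 and earlyExit id is 42.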
This post is a gentle introduction to CPS and some of its use cases. Once again, CPS is obscure, labyrinthine, and counter-intuitive in many ways. Overusing it is hardly ever a compelling idea (some say it’s like abusing goto in many programming languages). It can be highly rewarding to understand CPS, but it’s probably advisable to resist the urge to use it except for the killer use cases.
This post is adapted from a presentation I made in one of Formation’s Haskell study group sessions. I received helpful feedback on both the presentation and a draft of this post from coworker Ian-Woo Kim.