This post aims to give a quick-and-dirty explanation of MonadFix. It is quick and
dirty because MonadFix is complex and weird stuff (or at least it seems so), and I do not
think I am able to present a comprehensive discussion. Instead, this post mainly consists of a simple example
that hopefully can give one a rough idea of what MonadFix does and where it might
be useful.
It is recommended that one has basic knowledge of how fix
works before reading on, since I will be making comparisons between fix and mfix, the method
of the MonadFix type class.
TL;DR
MonadFix is a type class containing a single method, mfix:
class Monad m => MonadFix m where
  mfix :: (a -> m a) -> m a
Note the similarity between the types of fix and mfix:
fix :: (a -> a) -> a
mfix :: (a -> m a) -> m a
mfix is basically the monadic version of fix. Their definitions also bear similarity:
fix f = f (fix f)
mfix f = f =<< mfix f
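Note that these are not how fix and mfix are implemented. The equation for fix can be read as a
specification. The equation for mfix is only a rough intuition: mfix f executes the action f
only once and feeds the eventual result back in as the argument, rather than running f over and over.
So, very loosely,
fix f = f (f (f ...
mfix f = f =<< f =<< f =<< ...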
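As a quick sanity check of the two types, here is a minimal sketch that builds an infinite list of ones with both fix and mfix. The Maybe monad is just my choice for illustration (any MonadFix instance would do), and the names onesFix and onesMfix are made up for this example.

```haskell
import Control.Monad.Fix (mfix)
import Data.Function (fix)

-- Pure version: the list refers to itself through the argument of the lambda.
onesFix :: [Integer]
onesFix = fix (\xs -> 1 : xs)

-- Monadic version: xs is the list *inside* the Maybe that mfix returns.
-- The lambda never forces xs before returning, which is what makes this work.
onesMfix :: Maybe [Integer]
onesMfix = mfix (\xs -> Just (1 : xs))
```

In ghci, take 3 onesFix gives [1,1,1] and fmap (take 3) onesMfix gives Just [1,1,1].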
Next, I’ll give a quick recap of fix (it won’t be comprehensive by any means) before giving an
example of mfix, since understanding the former helps with the latter.
Recap of fix
fix is a combinator that returns the least fixed point of a function, and can be used to
express recursion without explicit self-reference. For example, suppose we have a recursively defined value a of type A:
a :: A
a = <body that mentions a>
A recursive value like this can be defined non-recursively as fix f for some function f: if we define
<body that mentions a> as f a, i.e.,
f :: A -> A
f x = <body that mentions x>
then we have a = f a, and so a is a fixed point of f. Note that a is a fixed point
of f, not the fixed point, since f may have more than one fixed point. For example, take a to be
the following function:
a :: Integer -> Integer
a = \x -> a (x+1)
Obviously a x evaluates to bottom (or ⊥) for every x. The corresponding f is
f :: (Integer -> Integer) -> Integer -> Integer
f a = \x -> a (x+1)
Now, this f has infinitely many fixed points. Any function const n, where n is an integer or ⊥, is a
fixed point of f. For instance,
f (const 42) x = const 42 (x+1) = 42 = const 42 x
for any x, which shows that const 42 is a fixed point of f. The function a defined above is basically
const ⊥, which is the least fixed point of f (it is “least” because ⊥ is considered
to be “less” than any non-⊥ value, where “less” is not < on numbers, but a partial order defined
on the values of a given type). Since fix f always returns the least fixed point of f,
we have a = fix f.
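To make the fixed-point claim a bit more concrete, the following sketch checks f g x == g x at a handful of sample points. The helper isFixedPointAt and the sample ranges are assumptions introduced for this check; it only tests finitely many inputs, and it does not evaluate fix f itself, since that diverges.

```haskell
-- f from the text: shifts the argument by one before calling its parameter.
f :: (Integer -> Integer) -> Integer -> Integer
f a = \x -> a (x + 1)

-- Hypothetical helper: does g agree with f g at the sample point x?
isFixedPointAt :: (Integer -> Integer) -> Integer -> Bool
isFixedPointAt g x = f g x == g x

-- const n passes the check for every sampled n and x.
constantsAreFixedPoints :: Bool
constantsAreFixedPoints =
  and [isFixedPointAt (const n) x | n <- [-3 .. 3], x <- [-10 .. 10]]

-- A non-constant function such as (+ 1) fails it: f (+ 1) 0 is 2, but 0 + 1 is 1.
successorIsFixedPoint :: Bool
successorIsFixedPoint = isFixedPointAt (+ 1) 0
```

Here constantsAreFixedPoints is True and successorIsFixedPoint is False.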
As another example, here’s the original and fixpoint definition of the factorial function:
import Data.Function (fix)
import Numeric.Natural (Natural)
facOriginal :: Natural -> Natural
facOriginal = \x -> if x == 0 then 1 else x * facOriginal (x-1)
facFix :: Natural -> Natural
facFix = fix g
  where
    g :: (Natural -> Natural) -> Natural -> Natural
    g f = \x -> if x == 0 then 1 else x * f (x-1)
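One way to get a feel for why fix g behaves like the factorial is to apply g a finite number of times to a function that is undefined everywhere. The sketch below does that; approx is a hypothetical helper of mine, not something from a library, and g is lifted to the top level so that it can be reused.

```haskell
import Data.Function (fix)
import Numeric.Natural (Natural)

-- g from the text: one "layer" of the factorial, with the recursive call
-- abstracted out as the parameter f.
g :: (Natural -> Natural) -> Natural -> Natural
g f = \x -> if x == 0 then 1 else x * f (x - 1)

-- Hypothetical helper: apply g n times to a function that fails on every input.
-- The result is defined only for inputs smaller than n.
approx :: Int -> Natural -> Natural
approx n = iterate g (\_ -> error "not enough unrolling") !! n

-- The least fixed point is defined on every input.
facFix :: Natural -> Natural
facFix = fix g
```

approx 6 5 and facFix 5 both give 120, while approx 6 6 runs into the error. fix g can be thought of as the limit of these ever more defined approximations.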
MonadFix
As mentioned before, mfix is the monadic version of fix. It is useful if we want to compute
a monadic value whose body mentions the value inside the monad, i.e.,
aM :: M A
aM = <body that mentions a>
where M is some monad. Note that the body of aM doesn’t mention aM itself (otherwise it would be a normal
recursion), but mentions the value a :: A inside aM :: M A. This is where mfix is useful.
Analogous to what we did with fix above, let’s define fM to be the following function:
fM :: A -> M A
fM x = <body that mentions x>
Then, loosely speaking, we have aM = fM =<< aM (except that the effects are performed only once), and so aM = mfix fM.
Now let’s see a concrete example.
An Example of mfix
In this example, let the monad be IO. The goal is to write an IO action that asks the user
to enter a boolean value, and depending on that boolean value, either returns the factorial function,
or the function that computes the kth triangular number (i.e., 1 + 2 + ... + k).
A “regular” implementation of this IO action would look like this:
fIO :: IO (Natural -> Natural)
fIO = do
  let fac = \x -> if x == 0 then 1 else x * fac (x-1)
      tri = \x -> if x == 0 then 0 else x + tri (x-1)
  b <- readLn
  pure $ if b then fac else tri
fIO defines two local functions fac and tri, and returns one of them based on the
boolean value. Test it in ghci by running fIO >>= pure . ($5). It will display 120 (the factorial of 5)
if you enter True, and 15 (the 5th triangular number) if you enter False.
Now, can fIO be defined in terms of fix gIO for some gIO, similar to what we did for the pure
factorial function? That would be difficult, since we can’t really make fIO a recursive function
that calls itself. Perhaps the best we can do is this:
import System.IO.Unsafe (unsafePerformIO)
gIO :: IO (Natural -> Natural) -> IO (Natural -> Natural)
gIO fIO' = do
  b <- readLn
  pure $ \x ->
    if b
      then if x == 0 then 1 else x * unsafePerformIO fIO' (x-1)
      else if x == 0 then 0 else x + unsafePerformIO fIO' (x-1)
unsafePerformIO is used to run the IO action fIO' :: IO (Natural -> Natural) and obtain the
Natural -> Natural inside it, since it is that function, as opposed to fIO' itself,
that we need to apply to (x-1).
And even if we are cool with unsafePerformIO, the action
fix gIO behaves quite differently from fIO, since we are running the IO action via
unsafePerformIO on every recursive call. So if you want to compute the factorial of 5 via
fix gIO by running
fix gIO >>= pure . ($5)
you’ll need to enter True 6 times (once for the initial run and once for each of the 5 recursive calls), whereas with fIO you only enter it once.
This is exactly the scenario where mfix comes into play. Again, note the type
of mfix:
mfix :: (a -> m a) -> m a
The argument of type a gives you access to the value inside the m a that mfix returns, which is the result
we are trying to obtain. Using mfix to implement fIO is nice and simple:
import Control.Monad.Fix (mfix)
fIOmfix :: IO (Natural -> Natural)
fIOmfix = mfix gIO'
  where
    gIO' :: (Natural -> Natural) -> IO (Natural -> Natural)
    gIO' f = do
      b <- readLn
      pure $ \x ->
        if b
          then if x == 0 then 1 else x * f (x-1)
          else if x == 0 then 0 else x + f (x-1)
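Here f is the pure function inside fIOmfix. Since gIO' takes this pure function (unlike gIO), no unsafePerformIO
is needed, and the IO action is performed only once. Note that gIO' only mentions f under the lambda
and never forces it before returning; if it did, mfix would not be able to produce a result.
So mfix gIO' is equivalent to fIO.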
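The claim that the action runs only once can be checked without any keyboard input. The sketch below is a variant of gIO' that takes the boolean as an argument instead of reading it, and bumps a counter every time its IO part runs. The names gCounted and runCounted, and the use of an IORef as the counter, are assumptions made purely for this illustration.

```haskell
import Control.Monad.Fix (mfix)
import Data.IORef (IORef, modifyIORef', newIORef, readIORef)
import Numeric.Natural (Natural)

-- Like gIO', but the Bool is passed in, and every execution of the IO part
-- increments the counter.
gCounted :: IORef Int -> Bool -> (Natural -> Natural) -> IO (Natural -> Natural)
gCounted counter b f = do
  modifyIORef' counter (+ 1)
  pure $ \x ->
    if b
      then if x == 0 then 1 else x * f (x - 1)
      else if x == 0 then 0 else x + f (x - 1)

-- Build the function with mfix, apply it to 5, and report how many times
-- the IO part of gCounted was executed.
runCounted :: Bool -> IO (Natural, Int)
runCounted b = do
  counter <- newIORef 0
  h <- mfix (gCounted counter b)
  let result = h 5
  -- Force the result first, so that all recursive calls have happened
  -- before the counter is read.
  runs <- result `seq` readIORef counter
  pure (result, runs)
```

runCounted True gives (120,1) and runCounted False gives (15,1): five recursive calls, but only one execution of the IO part.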
mdo and rec Syntax Sugar
GHC provides two pieces of syntactic sugar, mdo and rec (both enabled by the RecursiveDo extension), either
of which can be used to prettify definitions like fIOmfix:
{-# LANGUAGE RecursiveDo #-}
fIOmdo :: IO (Natural -> Natural)
fIOmdo = mdo
  b <- readLn
  res <- pure $ \x ->
    if b
      then if x == 0 then 1 else x * res (x-1)
      else if x == 0 then 0 else x + res (x-1)
  pure res
fIOrec :: IO (Natural -> Natural)
fIOrec = do
  b <- readLn
  rec res <- pure $ \x ->
        if b
          then if x == 0 then 1 else x * res (x-1)
          else if x == 0 then 0 else x + res (x-1)
  pure res
In both fIOmdo and fIOrec, res is the pure Natural -> Natural function inside
the result (IO (Natural -> Natural)) that we are trying to obtain. They are both
equivalent to fIO and mfix gIO'.
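Roughly speaking, a rec block stands for a call to mfix in which the recursively bound variable becomes the argument of the function. The sketch below puts the two side by side in a non-interactive form: the Bool is an argument rather than read from the user, and the names viaRec and viaMfix are made up here. The hand-written version is my approximation of the translation, not the exact code GHC generates.

```haskell
{-# LANGUAGE RecursiveDo #-}

import Control.Monad.Fix (mfix)
import Numeric.Natural (Natural)

-- Written with rec: res is used on the right-hand side of its own binding.
viaRec :: Bool -> IO (Natural -> Natural)
viaRec b = do
  rec res <- pure $ \x ->
        if b
          then if x == 0 then 1 else x * res (x - 1)
          else if x == 0 then 0 else x + res (x - 1)
  pure res

-- Roughly what the rec block stands for: the recursive variable becomes the
-- argument of the function passed to mfix.
viaMfix :: Bool -> IO (Natural -> Natural)
viaMfix b = mfix $ \res ->
  pure $ \x ->
    if b
      then if x == 0 then 1 else x * res (x - 1)
      else if x == 0 then 0 else x + res (x - 1)
```

Both viaRec True >>= pure . ($ 5) and viaMfix True >>= pure . ($ 5) give 120, and with False both give 15.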
There are subtle differences between mdo and rec, but I won’t go into them here.
Further Reading
The following are helpful for getting a better understanding of how the magic of
mfix works internally: