A Quick-and-Dirty Explanation of MonadFix

This post aims to give a quick-and-dirty explanation of MonadFix. It is quick and dirty because MonadFix is complex and weird stuff (or at least it seems so), and I do not think I am able to present a comprehensive discussion. Instead, this post mainly consists of a simple example that hopefully can give one a rough idea of what MonadFix does and where it might be useful.

It is recommended that one has basic knowledge of how fix works before reading on, since I will be making comparisons between fix and mfix, the method of the MonadFix type class.

TL;DR

MonadFix is a type class containing a single method, mfix:

class Monad m => MonadFix m where
  mfix :: (a -> m a) -> m a

Note the similarity between the types of fix and mfix:

fix :: (a -> a) -> a
mfix :: (a -> m a) -> m a

mfix is basically the monadic version of fix. Their definitions also bear similarity:

fix f = f (fix f)
mfix f = f =<< mfix f

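Note that these are not how fix and mfix are implemented. They are better read as informal specifications than as implementations. In particular, the equation for mfix only describes the value being computed: mfix runs the effects of f once, not once per unfolding.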

So,

fix f = f (f (f ...
mfix f = f =<< f =<< f =<< ...
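To make the comparison concrete, here is a minimal sketch that ties the same knot twice, once with fix and once with mfix in the Maybe monad. The names ones, onesM, firstThree and firstThreeM are made up for illustration; fix and mfix come from base.

```haskell
import Control.Monad.Fix (mfix)
import Data.Function (fix)

-- fix ties the knot on a pure value: ones = 1 : ones
ones :: [Int]
ones = fix (1 :)

-- mfix ties the knot on the value inside a monad (here Maybe):
-- the list inside the Just refers to itself
onesM :: Maybe [Int]
onesM = mfix (\xs -> Just (1 : xs))

-- Both lists are infinite, so only a finite prefix is inspected.
firstThree :: [Int]
firstThree = take 3 ones          -- [1,1,1]

firstThreeM :: Maybe [Int]
firstThreeM = take 3 <$> onesM    -- Just [1,1,1]
```

This only works because the function passed to mfix never inspects xs before producing the Just; a function that is strict in its argument would not terminate.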

Next, I’ll give a quick recap of fix (it won’t be comprehensive by any means) before giving an example of mfix, since understanding the former helps with the latter.

Recap of fix

fix is a combinator that returns the (least) fixed point of a function, and can be used to eliminate explicit recursion. For example, suppose we have a recursively defined value a of type A:

a :: A
a = <body that mentions a>

A recursive value like this can be defined non-recursively as fix f for some function f: if we define <body that mentions a> as f a, i.e.,

f :: A -> A
f x = <body that mentions x>

then we have a = f a, and so a is a fixed point of f. Note that a is a fixed point of f, not the fixed point, since f may have more than one fixed point. For example, take a to be the following function:

a :: Integer -> Integer
a = \x -> a (x+1)

Obviously a x evaluates to bottom (or ⊥) for every x. The corresponding f is

f :: (Integer -> Integer) -> Integer -> Integer
f a = \x -> a (x+1)

Now, this f has infinitely many fixed points: any function const n, where n is an integer or ⊥, is a fixed point of f. For instance,

f (const 42) x = const 42 (x+1) = 42 = const 42 x

for any x, which shows that const 42 is a fixed point of f. The function a defined above is basically const ⊥, which is the least fixed point of f (it is “least” because ⊥ is considered to be “less” than any non-⊥ value, where “less” is not < on numbers, but a partial order defined on values of a given type). Since fix f always returns the least fixed point of f, we have a = fix f.

As another example, here are the original and fixed-point definitions of the factorial function:

import Data.Function (fix)
import Numeric.Natural (Natural)

facOriginal :: Natural -> Natural
facOriginal = \x -> if x == 0 then 1 else x * facOriginal (x-1)

facFix :: Natural -> Natural
facFix = fix g
  where
    g :: (Natural -> Natural) -> Natural -> Natural
    g f = \x -> if x == 0 then 1 else x * f (x-1)
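One way to see what “fix f = f (f (f ...” means for the factorial is to unfold g a finite number of times, starting from a function that is undefined everywhere. The sketch below does that; approx and facOf3 are hypothetical helper names, not part of any library.

```haskell
import Data.Function (fix)
import Numeric.Natural (Natural)

g :: (Natural -> Natural) -> Natural -> Natural
g f = \x -> if x == 0 then 1 else x * f (x - 1)

-- n-fold unfolding of g, starting from a function that fails on every input
approx :: Int -> (Natural -> Natural)
approx n = iterate g (\_ -> error "not enough unfoldings") !! n

-- Four unfoldings are enough for arguments 0 through 3.
facOf3 :: Natural
facOf3 = approx 4 3   -- 6

-- fix g behaves like the limit of these approximations.
facOf5 :: Natural
facOf5 = fix g 5      -- 120
```

Each extra unfolding makes the function defined on one more argument, and fix g is the least function that agrees with all of these approximations.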

MonadFix

As mentioned before, mfix is the monadic version of fix. It is useful if we want to compute a monadic value whose body mentions the value inside the monad, i.e.,

aM :: M A
aM = <body that mentions a>

where M is some monad. Note that the body of aM doesn’t mention aM itself (otherwise it would be a normal recursion), but mentions the value a :: A inside aM :: M A. This is where mfix is useful.

Analogous to what we did with fix above, let’s define fM to be the following function:

fM :: A -> M A
fM x = <body that mentions x>

Then we have aM = fM =<< aM (in the same informal sense as the equation for mfix above), and so aM = mfix fM.
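As a small illustration of a body that mentions the value inside the monad, here is a sketch that builds a node whose mutable “next” pointer points back at the node itself. The Node type and the names selfLoop and labelOfNext are invented for this example; mfix for IO and the IORef functions come from base.

```haskell
import Control.Monad.Fix (mfix)
import Data.IORef (IORef, newIORef, readIORef)

-- Hypothetical node type: a label plus a mutable pointer to the next node.
data Node = Node
  { nodeLabel :: Int
  , nodeNext  :: IORef Node
  }

-- The body mentions `self`, the Node that the whole action will return.
-- newIORef only stores `self` without forcing it, so this terminates.
selfLoop :: IO Node
selfLoop = mfix $ \self -> do
  ref <- newIORef self
  pure (Node 1 ref)

-- Follows the pointer once and returns the label found there.
labelOfNext :: IO Int
labelOfNext = do
  node <- selfLoop
  next <- readIORef (nodeNext node)
  pure (nodeLabel next)
```

Running labelOfNext should give 1, since the node reached through the pointer is the node itself.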

Now let’s see a concrete example.

An Example of mfix

In this example, let the monad be IO. The goal is to write an IO action that asks the user to enter a boolean value, and depending on that boolean value, either returns the factorial function, or the function that computes the kth triangular number (i.e., 1 + 2 + ... + k).

A “regular” implementation of this IO action would look like this:

fIO :: IO (Natural -> Natural)
fIO = do
  let fac = \x -> if x == 0 then 1 else x * fac (x-1)
      tri = \x -> if x == 0 then 0 else x + tri (x-1)
  b <- readLn
  pure $ if b then fac else tri

fIO defines two inner functions, fac and tri, and returns one of them based on the boolean value. Test it in GHCi by running fIO >>= pure . ($5). It will display 120 if you enter True, and 15 if you enter False.

Now, can fIO be defined as fix gIO for some gIO, similar to what we did for the pure factorial function? That would be difficult, since we can’t really make fIO a recursive function that calls itself. Perhaps the best we can do is this:

gIO :: IO (Natural -> Natural) -> IO (Natural -> Natural)
gIO fIO' = do
  b <- readLn
  pure $ \x ->
    if b
      then if x == 0 then 1 else x * unsafePerformIO fIO' (x-1)
      else if x == 0 then 0 else x + unsafePerformIO fIO' (x-1)

unsafePerformIO (from System.IO.Unsafe) is used to run the IO action fIO' :: IO (Natural -> Natural) and obtain the Natural -> Natural inside it, since it is that function, not fIO' itself, that we need to apply to (x-1).

And even if we are cool with unsafePerformIO, fix gIO is quite different from fIO, since we are running the IO action via unsafePerformIO on every recursive call. So if you want to compute the factorial of 5 via fix gIO by running

fix gIO >>= pure . ($5)

in GHCi, you’ll need to enter True 6 times (once for each of x = 5, 4, ..., 0), whereas with fIO you only enter it once.

This is exactly the scenario where mfix comes into play. Again, note the type of mfix:

mfix :: (a -> m a) -> m a

The argument of type a gives you access to the value inside m a, which is the result we are trying to obtain. Using mfix to implement fIO is nice and simple:

fIOmfix :: IO (Natural -> Natural)
fIOmfix = mfix gIO'
  where
    gIO' :: (Natural -> Natural) -> IO (Natural -> Natural)
    gIO' f = do
      b <- readLn
      pure $ \x ->
        if b
          then if x == 0 then 1 else x * f (x-1)
          else if x == 0 then 0 else x + f (x-1)

Here f is the pure function inside fIOmfix. Since gIO' takes this pure function (unlike gIO), no unsafePerformIO is needed, and the IO action is performed only once. So mfix gIO' is equivalent to fIO.
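The claim that the IO action is performed only once can be checked without typing anything, by replacing readLn with an action that counts its own executions. The following is a sketch under that assumption; gWith and demo are hypothetical names, and gWith is simply gIO' with the source of the boolean passed in as a parameter.

```haskell
import Control.Monad.Fix (mfix)
import Data.IORef (modifyIORef', newIORef, readIORef)
import Numeric.Natural (Natural)

-- Hypothetical variant of gIO': instead of readLn, the boolean comes from
-- an IO action supplied by the caller.
gWith :: IO Bool -> (Natural -> Natural) -> IO (Natural -> Natural)
gWith getB f = do
  b <- getB
  pure $ \x ->
    if b
      then if x == 0 then 1 else x * f (x - 1)
      else if x == 0 then 0 else x + f (x - 1)

-- Runs mfix with an action that counts how often it is executed.
-- Returns the function applied to 5, together with the execution count.
demo :: Bool -> IO (Natural, Int)
demo choice = do
  counter <- newIORef (0 :: Int)
  let getB = modifyIORef' counter (+ 1) >> pure choice
  h <- mfix (gWith getB)
  let r = h 5
  -- force the whole recursion before reading the counter
  n <- r `seq` readIORef counter
  pure (r, n)
```

With this setup, demo True should return (120, 1) and demo False should return (15, 1): the recursion unfolds six times, but the counting action runs once.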

mdo and rec Syntax Sugar

GHC provides two pieces of syntactic sugar, mdo and rec (both enabled by the RecursiveDo extension), which can be used to prettify definitions like fIOmfix:

{-# LANGUAGE RecursiveDo #-}

fIOmdo :: IO (Natural -> Natural)
fIOmdo = mdo
  b <- readLn
  res <- pure $ \x ->
    if b
      then if x == 0 then 1 else x * res (x-1)
      else if x == 0 then 0 else x + res (x-1)
  pure res

fIOrec :: IO (Natural -> Natural)
fIOrec = do
  b <- readLn
  rec res <- pure $ \x ->
        if b
          then if x == 0 then 1 else x * res (x-1)
          else if x == 0 then 0 else x + res (x-1)
  pure res

In both fIOmdo and fIOrec, res is the pure Natural -> Natural function inside the result (IO (Natural -> Natural)) that we are trying to obtain. They are both equivalent to fIO and mfix gIO'.
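To get a feel for what the sugar stands for, here is a sketch of a rec block in the Maybe monad next to a hand-written mfix version. The translation shown is only roughly what GHC does (a lazy tuple pattern collecting the recursive variables); the names withRec and withMfix are made up for this example.

```haskell
{-# LANGUAGE RecursiveDo #-}

import Control.Monad.Fix (mfix)

-- Two mutually recursive bindings inside a rec block.
withRec :: Maybe [Int]
withRec = do
  rec xs <- Just (1 : ys)
      ys <- Just (2 : xs)
  pure (take 4 xs)

-- Roughly the same thing written with mfix directly. The lazy pattern
-- ~(xs0, ys0) matters: a strict tuple pattern would force the result
-- before it exists.
withMfix :: Maybe [Int]
withMfix = do
  (xs, _ys) <- mfix (\ ~(xs0, ys0) -> do
    xs' <- Just (1 : ys0)
    ys' <- Just (2 : xs0)
    pure (xs', ys'))
  pure (take 4 xs)
```

Both definitions should evaluate to Just [1,2,1,2].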

There are subtle differences between mdo and rec, but I won’t go into them here.

Further Reading

The following are helpful for getting a better understanding of how the magic stuff of mfix works internally:

  • Take a look at the source code of Control.Monad.Fix to see how mfix is implemented for various monads
  • Will Fancher’s post on MonadFix
  • A paper on RecursiveDo
  • MonadFix on HaskellWiki