This morning, a programmer visited #haskell and asked how to implement backtracking. Not surprisingly, most of the answers involved monads. After all, monads are ubiquitous in Haskell: They're used for IO, for probability, for error reporting, and even for quantum mechanics. If you program in Haskell, you'll probably want to understand monads. So where's the best place to start?

A friend of mine claims he didn't truly understand monads until he understood join. But once he figured that out, everything was suddenly obvious. That's the way it worked for me, too. But relatively few monad tutorials are based on join, so there's an open niche in a crowded market.

This monad tutorial uses join. Even better, it attempts to cram everything you need to know about monads into 15 minutes. (Hey, everybody needs a gimmick, right?)

Backtracking: The lazy way to code

We begin with a backtracking constraint solver. The idea: Given possible values for x and y, we want to pick the pairs whose product is 8:

solveConstraint = do
  x <- choose [1,2,3]
  y <- choose [4,5,6]
  guard (x*y == 8)
  return (x,y)

Every time choose is called, we save the current program state. And every time guard fails, we backtrack to a saved state and try again. Eventually, we'll hit the right answer:

> take 1 solveConstraint
[(2,4)]

Let's build this program step-by-step in Haskell. When we're done, we'll have a monad.

Implementing choose

How can we implement choose in Haskell? The obvious version hits a dead-end quickly:

-- Pick one element from the list, saving
-- a backtracking point for later on.
choose :: [a] -> a
choose xs = ...

We could be slightly sneakier, and return all the possible choices as a list. We'll use Choice whenever we talk about these lists, just to keep things clear:

type Choice a = [a]

choose :: [a] -> Choice a
choose xs = xs

Running this program returns all the possible answers:

> choose [1,2,3]
[1,2,3]

Now, since Haskell is a lazy language, we can work with an infinite number of choices, and only compute those we actually need:

> take 3 (choose [1..])
[1,2,3]
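To see the laziness at work, here is a small self-contained sketch that repeats the Choice and choose definitions from above. The name firstThreeEvens is ours, not part of the original example; filter and take only examine as much of the infinite list as they need.

```haskell
type Choice a = [a]

choose :: [a] -> Choice a
choose xs = xs

-- choose [1..] is infinite, but filter and take are lazy, so only a
-- finite prefix of it is ever computed.
firstThreeEvens :: Choice Integer
firstThreeEvens = take 3 (filter even (choose [1..]))
```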

Combining several choices

Now we have the list [1,2,3] from our example. But what about the list [4,5,6]? Let's ignore guard for a minute, and work on getting the final pairs of numbers, unfiltered by any constraint.

For each item in the first list, we need to pair it with every item in the second list. We can do that using map and the following helper function:

pair456 :: Int -> Choice (Int,Int)
pair456 x = choose [(x,4), (x,5), (x,6)]

Sure enough, this gives us all 9 combinations:

> map pair456 (choose [1,2,3])
[[(1,4),(1,5),(1,6)],
[(2,4),(2,5),(2,6)],
[(3,4),(3,5),(3,6)]]
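The nesting comes from map itself: it produces one inner Choice for every x in the outer list. A sketch with a hypothetical pairWith helper (our own generalization of pair456 that takes the second list as a parameter) makes the shape easier to see:

```haskell
type Choice a = [a]

choose :: [a] -> Choice a
choose xs = xs

-- Hypothetical generalization of pair456: pair x with every element of ys.
pairWith :: [Int] -> Int -> Choice (Int, Int)
pairWith ys x = choose (map (\y -> (x, y)) ys)

-- map builds one inner Choice per x, hence the two layers of lists.
nested :: [Choice (Int, Int)]
nested = map (pairWith [4,5,6]) (choose [1,2,3])
```

Here pairWith [4,5,6] behaves just like pair456.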

But now we have two layers of lists. We can fix that using join:

join :: Choice (Choice a) -> Choice a
join choices = concat choices

This collapses the two layers into one:

> join (map pair456 (choose [1,2,3]))
[(1,4),(1,5),(1,6),
(2,4),(2,5),(2,6),
(3,4),(3,5),(3,6)]
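join also hints at how failure will work later. When some inner choice is empty, concatenating it contributes nothing, so that branch simply disappears. A sketch with a hypothetical pairIfEven helper (not part of the original example), which offers no pairs at all for odd numbers:

```haskell
type Choice a = [a]

joinChoice :: Choice (Choice a) -> Choice a
joinChoice choices = concat choices

-- Hypothetical helper: odd numbers offer no pairs; even numbers offer two.
pairIfEven :: Int -> Choice (Int, Int)
pairIfEven x
  | even x    = [(x, 4), (x, 6)]
  | otherwise = []

-- The empty inner lists for 1 and 3 vanish when the layers are joined.
evenPairs :: Choice (Int, Int)
evenPairs = joinChoice (map pairIfEven [1,2,3,4])
```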

Now that we have join and map, we have two-thirds of a monad! (Math trivia: In category theory, join is usually written μ.)

In Haskell, join and map are usually combined into a single operator:

-- Hide the standard versions so we can
-- reimplement them.
import Prelude hiding ((>>=), return)

(>>=) :: Choice a -> (a -> Choice b) -> Choice b
choices >>= f = join (map f choices)

This allows us to simplify our example even further:

> choose [1,2,3] >>= pair456
[(1,4),(1,5),(1,6),
(2,4),(2,5),(2,6),
(3,4),(3,5),(3,6)]
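Because Choice is just a list, our >>= should agree with both concatMap and the >>= that the standard list monad already provides. A quick check, assuming we call our version bindChoice (our name) so that nothing from the Prelude needs hiding:

```haskell
type Choice a = [a]

pair456 :: Int -> Choice (Int, Int)
pair456 x = [(x, 4), (x, 5), (x, 6)]

-- Our bind, renamed so the Prelude's (>>=) stays available for comparison.
bindChoice :: Choice a -> (a -> Choice b) -> Choice b
bindChoice choices f = concat (map f choices)

ours, viaConcatMap, viaPrelude :: Choice (Int, Int)
ours         = bindChoice [1,2,3] pair456
viaConcatMap = concatMap pair456 [1,2,3]
viaPrelude   = [1,2,3] >>= pair456
```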

We're getting close! We only need to define the third monad function (and then figure out what to do about guard).

The missing function is almost too trivial to mention: Given a single value of type a, we need a convenient way to construct a value of type Choice a:

return :: a -> Choice a
return x = choose [x]

(More math trivia: return is also known as unit and η. That's a lot of names for a very simple idea.)
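return acts as a neutral element for >>=: wrapping a value with return and then binding it to f should be the same as calling f directly, and binding a choice to return should leave it unchanged. A sketch, using our own names returnChoice and bindChoice to avoid clashing with the Prelude:

```haskell
type Choice a = [a]

returnChoice :: a -> Choice a
returnChoice x = [x]

bindChoice :: Choice a -> (a -> Choice b) -> Choice b
bindChoice choices f = concat (map f choices)

pair456 :: Int -> Choice (Int, Int)
pair456 x = [(x, 4), (x, 5), (x, 6)]

-- Wrapping 2 with return and then binding is the same as calling pair456 2.
leftIdentityHolds :: Bool
leftIdentityHolds = bindChoice (returnChoice 2) pair456 == pair456 2

-- Binding a choice to return gives back the same choice.
rightIdentityHolds :: Bool
rightIdentityHolds = bindChoice [1, 2, 3 :: Int] returnChoice == [1, 2, 3]
```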

Let's start assembling the pieces. In the code below, (\x -> ...) creates a function with a single argument x. Pay careful attention to the parentheses:

makePairs :: Choice (Int,Int)
makePairs =
  choose [1,2,3] >>= (\x ->
  choose [4,5,6] >>= (\y ->
  return (x,y)))

When run, this gives us a list of all possible combinations of x and y:

> makePairs
[(1,4),(1,5),(1,6),
(2,4),(2,5),(2,6),
(3,4),(3,5),(3,6)]

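As it turns out, this is a really common idiom, so Haskell provides some nice syntactic sugar for us. (Strictly speaking, do notation desugars to the >>= from the standard Monad class, not the version we defined above; since Choice is just a list, the built-in list monad gives the same results.)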

makePairs' :: Choice (Int,Int)
makePairs' = do
  x <- choose [1,2,3]
  y <- choose [4,5,6]
  return (x,y)

This is equivalent to our previous implementation:

> makePairs'
[(1,4),(1,5),(1,6),
(2,4),(2,5),(2,6),
(3,4),(3,5),(3,6)]
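To see the correspondence concretely, this sketch writes the same pairs both ways using the standard list monad, so no hiding is needed; makePairsDo and makePairsDesugared are our names. The precise translation is spelled out in the Haskell Report, but for simple variable patterns like these it amounts to the nested-lambda form:

```haskell
-- The do block from the text, run in the standard list monad.
makePairsDo :: [(Int, Int)]
makePairsDo = do
  x <- [1,2,3]
  y <- [4,5,6]
  return (x, y)

-- The same computation written as nested >>= and lambdas.
makePairsDesugared :: [(Int, Int)]
makePairsDesugared =
  [1,2,3] >>= \x ->
  [4,5,6] >>= \y ->
  return (x, y)
```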

The final piece: guard

In our backtracking monad, we can represent failure as a choice between zero options. (And indeed, this is known as the "zero" for our monad. Not all useful monads have zeros, but you'll see them occasionally.)

-- Define a "zero" for our monad.  This
-- represents failure.
mzero :: Choice a
mzero = choose []

-- Either fail, or return something
-- useless and continue the computation.
guard :: Bool -> Choice ()
guard True  = return ()
guard False = mzero

solveConstraint :: Choice (Int,Int)
solveConstraint = do
  x <- choose [1,2,3]
  y <- choose [4,5,6]
  guard (x*y == 8)
  return (x,y)

Note that since the return value of guard is boring, we don't actually bind it to any variable. Haskell treats this as if we had written:

-- "_" is an anonymous variable.
_ <- guard (x*y == 8)
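A statement whose result is ignored is shorthand for binding that result and throwing it away, which Haskell spells >>. A sketch using Control.Monad's guard on plain lists (the standard library version, not the one defined above) shows that both spellings prune the same branches; the two names are ours:

```haskell
import Control.Monad (guard)

-- Binding guard's () result to a lambda that ignores it...
solveWithBind :: [(Int, Int)]
solveWithBind =
  [1,2,3] >>= \x ->
  [4,5,6] >>= \y ->
  guard (x * y == 8) >>= \_ ->
  return (x, y)

-- ...is what >> does for us.
solveWithThen :: [(Int, Int)]
solveWithThen =
  [1,2,3] >>= \x ->
  [4,5,6] >>= \y ->
  guard (x * y == 8) >>
  return (x, y)
```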

That's it!

> take 1 solveConstraint
[(2,4)]
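Nothing in these pieces is specific to this puzzle, so the same machinery should handle larger searches. Here is a self-contained sketch that repeats the definitions under suffixed names of our own (chooseC, bindC, returnC, guardC) so it needs no Prelude hiding, and uses them for solveConstraint plus a hypothetical search for Pythagorean triples up to 20:

```haskell
type Choice a = [a]

chooseC :: [a] -> Choice a
chooseC xs = xs

bindC :: Choice a -> (a -> Choice b) -> Choice b
bindC choices f = concat (map f choices)
infixl 1 `bindC`

returnC :: a -> Choice a
returnC x = [x]

-- Failure is an empty choice; guardC either continues with () or fails.
guardC :: Bool -> Choice ()
guardC True  = returnC ()
guardC False = []

-- solveConstraint from the text, written with explicit binds.
solveC :: Choice (Int, Int)
solveC =
  chooseC [1,2,3] `bindC` \x ->
  chooseC [4,5,6] `bindC` \y ->
  guardC (x * y == 8) `bindC` \_ ->
  returnC (x, y)

-- Hypothetical larger search: a <= b <= c <= 20 with a^2 + b^2 == c^2.
triples :: Choice (Int, Int, Int)
triples =
  chooseC [1..20] `bindC` \a ->
  chooseC [a..20] `bindC` \b ->
  chooseC [b..20] `bindC` \c ->
  guardC (a * a + b * b == c * c) `bindC` \_ ->
  returnC (a, b, c)
```

Thanks to laziness, take 1 triples stops as soon as it finds (3,4,5).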

Every monad has three pieces: return, map and join. This pattern crops up everywhere. For example, we can represent a computation which might fail using the Maybe monad:

returnMaybe :: a -> Maybe a
returnMaybe x = Just x

mapMaybe :: (a -> b) -> Maybe a -> Maybe b
mapMaybe f Nothing  = Nothing
mapMaybe f (Just x) = Just (f x)

joinMaybe :: Maybe (Maybe a) -> Maybe a
joinMaybe Nothing  = Nothing
joinMaybe (Just x) = x
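Combining mapMaybe and joinMaybe gives Maybe its >>=, exactly as join and map did for Choice: a Nothing anywhere short-circuits the rest. A sketch, where bindMaybe and the safeDiv helper are hypothetical names of ours:

```haskell
mapMaybe :: (a -> b) -> Maybe a -> Maybe b
mapMaybe _ Nothing  = Nothing
mapMaybe f (Just x) = Just (f x)

joinMaybe :: Maybe (Maybe a) -> Maybe a
joinMaybe Nothing  = Nothing
joinMaybe (Just x) = x

-- Bind for Maybe, built the same way as for Choice: map, then join.
bindMaybe :: Maybe a -> (a -> Maybe b) -> Maybe b
bindMaybe m f = joinMaybe (mapMaybe f m)

-- Hypothetical helper: integer division that fails on a zero divisor.
safeDiv :: Int -> Int -> Maybe Int
safeDiv _ 0 = Nothing
safeDiv n d = Just (n `div` d)
```

For instance, bindMaybe (safeDiv 100 5) (safeDiv 1000) divides 1000 by 20, while bindMaybe (Just 0) (safeDiv 100) fails.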

Once again, we can use do to string together individual steps which might fail:

tryToComputeX :: Maybe Int
tryToComputeX = ...

tryToComputeY :: Int -> Maybe Int
tryToComputeY x = ...

maybeExample :: Maybe (Int, Int)
maybeExample = do
  x <- tryToComputeX
  y <- tryToComputeY x
  return (x,y)
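Since tryToComputeX and tryToComputeY are left abstract, here is a runnable stand-in. parseX, halveEven and maybeExample' are hypothetical names of ours; parsing uses readMaybe from Text.Read, which returns Nothing on malformed input:

```haskell
import Text.Read (readMaybe)

-- Hypothetical stand-in for tryToComputeX: parse an Int, or fail.
parseX :: String -> Maybe Int
parseX s = readMaybe s

-- Hypothetical stand-in for tryToComputeY: halve x, failing on odd numbers.
halveEven :: Int -> Maybe Int
halveEven x
  | even x    = Just (x `div` 2)
  | otherwise = Nothing

-- Same shape as maybeExample: any Nothing stops the whole computation.
maybeExample' :: String -> Maybe (Int, Int)
maybeExample' s = do
  x <- parseX s
  y <- halveEven x
  return (x, y)
```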

Once you can explain how this works, you understand monads. And you'll start to see this pattern everywhere. There's something deep about monads and abstract algebra that I don't understand, but which keeps cropping up over and over again.

Miscellaneous notes

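In Haskell, monads are normally defined using the Monad type class, where you supply >>= (return has a default). In GHC 7.10 and later, every Monad must also be an Applicative and a Functor, so a complete definition provides fmap (in Functor), pure and <*> (in Applicative), and >>= (in Monad). The map function for monads is named fmap, and join is available from Control.Monad.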
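As a sketch of what this looks like in a modern GHC, here is the backtracking monad as real instances. We assume a newtype named Choice (a plain type synonym for lists would just reuse the existing list instance), and we write >>= as join after fmap to mirror this tutorial; the Alternative instance supplies the zero that Control.Monad's guard needs:

```haskell
import Control.Applicative (Alternative (..))
import Control.Monad (ap, guard)

-- A newtype is needed because lists already have a Monad instance.
newtype Choice a = Choice { runChoice :: [a] }

instance Functor Choice where
  fmap f (Choice xs) = Choice (map f xs)

instance Applicative Choice where
  pure x = Choice [x]
  (<*>) = ap

-- Flatten two layers of choices into one, as concat does for lists.
joinChoice :: Choice (Choice a) -> Choice a
joinChoice (Choice outer) = Choice (concatMap runChoice outer)

instance Monad Choice where
  m >>= f = joinChoice (fmap f m)

-- empty is the zero (failure); guard False produces it.
instance Alternative Choice where
  empty = Choice []
  Choice xs <|> Choice ys = Choice (xs ++ ys)

choose :: [a] -> Choice a
choose = Choice

solveConstraint :: Choice (Int, Int)
solveConstraint = do
  x <- choose [1,2,3]
  y <- choose [4,5,6]
  guard (x*y == 8)
  return (x,y)
```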

Also, every monad should obey three fairly reasonable rules if you don't want bad things to happen:

-- Adding and collapsing an outer layer
-- leaves a value unchanged.
join (return xs) == xs

-- Adding and collapsing an inner layer
-- leaves a value unchanged.
join (fmap return xs) == xs

-- Join order doesn't matter.
join (join xs) == join (fmap join xs)
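These laws can't be checked exhaustively, but testing them on sample values is a cheap sanity check. A sketch for the list monad using join from Control.Monad; the function names and sample inputs are ours:

```haskell
import Control.Monad (join)

-- join (return xs) == xs: adding and collapsing an outer layer.
outerUnitLaw :: [Int] -> Bool
outerUnitLaw xs = join (return xs) == xs

-- join (fmap return xs) == xs: adding and collapsing an inner layer.
innerUnitLaw :: [Int] -> Bool
innerUnitLaw xs = join (fmap return xs) == xs

-- join (join xs) == join (fmap join xs): join order doesn't matter.
associativityLaw :: [[[Int]]] -> Bool
associativityLaw xsss = join (join xsss) == join (fmap join xsss)
```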