Or, how to optimize MapReduce, and when folds are faster than loops

Purely functional programming might actually be worth the pain, if you care about large-scale optimization.

Lately, I've been studying how to speed up parallel algorithms. Many parallel algorithms, such as Google's MapReduce, have two parts:

  1. First, you transform the data by mapping one or more functions over each value.
  2. Next, you repeatedly merge the transformed data, "reducing" it down to a final result.
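Here's that two-phase shape in miniature, using plain Haskell lists (just a sketch of the pattern, not real MapReduce):

```haskell
-- Phase 1: map a function over every value.
-- Phase 2: reduce the results down to a single answer.
-- Summing the squares of a list is about the smallest possible example.
squareSum :: [Int] -> Int
squareSum xs = foldr (+) 0 (map (\x -> x * x) xs)
```

Notice that, written this way, phase 1 builds an entire intermediate list just so phase 2 can consume it. That's the problem we'll be attacking below.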

Unfortunately, there are a couple of nasty performance problems lurking here. We really want to combine all those steps into a single pass, so that we can eliminate temporary working data. But we don't always want to do this optimization by hand---it would be better if the compiler could do it for us.

As it turns out, Haskell is an amazing testbed for this kind of optimization. Let's build a simple model, show where it breaks, and then crank the performance way up.

Trees, and the performance problems they cause

We'll use single-threaded trees for our testbed. They're simple enough to demonstrate the basic idea, and they can be generalized to parallel systems. (If you want to know how, check out the papers at the end of this article.)

A tree is either empty, or it is a node with a left child, a value and a right child:

data Tree a = Empty
            | Node (Tree a) a (Tree a)
  deriving (Show)

Here's a sample tree containing three values:

tree = (Node left 2 right)
  where left  = (Node Empty 1 Empty)
        right = (Node Empty 3 Empty)

We can use treeMap to apply a function to every value in a tree, creating a new tree:

treeMap :: (a -> b) -> Tree a -> Tree b

treeMap f Empty = Empty
treeMap f (Node l x r) =
  Node (treeMap f l) (f x) (treeMap f r)

Using treeMap, we can build various functions that manipulate trees:

-- Double each value in a tree.
treeDouble tree = treeMap (*2) tree

-- Add one to each value in a tree.
treeIncr tree   = treeMap (+1) tree

What if we want to add up all the values in a tree? Well, we could write a simple recursive sum function:

treeSum Empty = 0
treeSum (Node l x r) =
  treeSum l + x + treeSum r

But for reasons that will soon become clear, it's much better to refactor the recursive part of treeSum into a reusable treeFold function ("fold" is Haskell's name for "reduce"):

treeFold :: (b -> a -> b -> b) -> b -> Tree a -> b
treeFold f b Empty = b
treeFold f b (Node l x r) =
  f (treeFold f b l) x (treeFold f b r)

treeSum t = treeFold (\l x r -> l+x+r) 0 t
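One payoff of the refactoring: all sorts of other reductions now fall out of treeFold for free. Here are two examples of my own (definitions repeated so the snippet stands alone):

```haskell
data Tree a = Empty
            | Node (Tree a) a (Tree a)
  deriving (Show)

treeFold :: (b -> a -> b -> b) -> b -> Tree a -> b
treeFold f b Empty = b
treeFold f b (Node l x r) =
  f (treeFold f b l) x (treeFold f b r)

-- Count the nodes in a tree.
treeSize :: Tree a -> Int
treeSize = treeFold (\l _ r -> l + 1 + r) 0

-- Measure the depth of a tree.
treeDepth :: Tree a -> Int
treeDepth = treeFold (\l _ r -> 1 + max l r) 0
```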

Now we can double all the values in a tree, add 1 to each, and sum up the result:

treeSum (treeIncr (treeDouble tree))

But there's a very serious problem with this code. Imagine that we're working with a million-node tree. The two calls to treeMap (buried inside treeIncr and treeDouble) will each create a new million-node tree. Obviously, this will kill our performance, and it will make our garbage collector cry.

Fortunately, we can do a lot better than this, thanks to some funky GHC extensions.

Getting rid of the intermediate trees

So how do we get rid of those intermediate trees? Well, we could merge:

treeSum (treeIncr (treeDouble tree))

...into a single recursive call:

treeSumIncrDouble Empty = 0
treeSumIncrDouble (Node l x r) =
  treeSumIncrDouble l + (x*2+1) + treeSumIncrDouble r
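As a quick sanity check, the hand-fused version should agree with the compositional pipeline on our sample tree (definitions repeated so the snippet stands alone):

```haskell
data Tree a = Empty
            | Node (Tree a) a (Tree a)

treeMap :: (a -> b) -> Tree a -> Tree b
treeMap f Empty = Empty
treeMap f (Node l x r) =
  Node (treeMap f l) (f x) (treeMap f r)

treeSum :: Tree Int -> Int
treeSum Empty = 0
treeSum (Node l x r) = treeSum l + x + treeSum r

treeDouble, treeIncr :: Tree Int -> Tree Int
treeDouble = treeMap (*2)
treeIncr   = treeMap (+1)

-- The hand-fused, single-pass version.
treeSumIncrDouble :: Tree Int -> Int
treeSumIncrDouble Empty = 0
treeSumIncrDouble (Node l x r) =
  treeSumIncrDouble l + (x*2+1) + treeSumIncrDouble r

sample :: Tree Int
sample = Node (Node Empty 1 Empty) 2 (Node Empty 3 Empty)
```

Both paths should compute the same sum; the fused version just does it without allocating any intermediate trees.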

But that's ugly, because it breaks the encapsulation of treeSum, etc. Worse, it requires us to manually intervene and write code every time we hit a bottleneck.

Now, here's where the GHC magic comes in. First, we add the following line to the top of our source file:

{-# OPTIONS_GHC -O -fglasgow-exts -ddump-simpl-stats #-}

This turns on optimization, enables certain GHC-specific extensions, and tells GHC to summarize the work of the optimizer. (Also, we need to make sure that profiling is turned off, because it blocks certain optimizations.)

Next, let's walk through the first optimization we want the compiler to perform---merging two calls to treeMap into one:

treeIncr (treeDouble tree)

-- Inline treeIncr, treeDouble
treeMap (+1) (treeMap (*2) tree)

-- Combine into a single pass
treeMap ((+1) . (*2)) tree

Here's the magic part. We can use the RULES pragma to explain this optimization to the compiler:

{-# RULES
  "treeMap/treeMap"  forall f g t.
  treeMap f (treeMap g t) = treeMap (f . g) t
  #-}

Note that this is only valid in a pure functional language like Haskell! If we were working in ML or Lisp, then f and g might have side effects, and we couldn't safely combine the two passes without doing a lot more work.

We can similarly merge an adjacent treeFold/treeMap pair into a single pass:

{-# RULES
  "treeFold/treeMap"  forall f b g t.
  treeFold f b (treeMap g t) =
    treeFold (\l x r -> f l (g x) r) b t
  #-}
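With both rules in scope, GHC can chain them: inlining exposes the treeMap calls, the first rule collapses the two maps, and the second rule fuses the remaining map into the fold. The net effect is exactly the single-pass traversal we wrote by hand earlier. Here's the derivation, sketched as comments (definitions repeated so the snippet stands alone):

```haskell
data Tree a = Empty
            | Node (Tree a) a (Tree a)

treeFold :: (b -> a -> b -> b) -> b -> Tree a -> b
treeFold f b Empty = b
treeFold f b (Node l x r) =
  f (treeFold f b l) x (treeFold f b r)

-- treeSum (treeIncr (treeDouble t))
--   = treeFold (\l x r -> l+x+r) 0 (treeMap (+1) (treeMap (*2) t))  -- inline
--   = treeFold (\l x r -> l+x+r) 0 (treeMap ((+1) . (*2)) t)        -- "treeMap/treeMap"
--   = treeFold (\l x r -> l + (x*2 + 1) + r) 0 t                    -- "treeFold/treeMap"
fused :: Tree Int -> Int
fused = treeFold (\l x r -> l + (x*2 + 1) + r) 0
```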

Using just these two rules, I saw a 225% increase in the number of nodes processed per second. Under the right circumstances, GHC can even outperform C code by applying these kinds of inter-procedural optimizations.

Where to learn more

Rewrite rules are documented in the GHC manual and on the Haskell Wiki. Don Stewart also suggests using QuickCheck to verify the correctness of rewrite rules.
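In the spirit of that QuickCheck suggestion, here's a dependency-free sketch of the idea: state the treeMap/treeMap law as a boolean property and check it by hand on a few small trees. (QuickCheck would generate the trees for us; the helper name mapMapLaw is my own.)

```haskell
data Tree a = Empty
            | Node (Tree a) a (Tree a)
  deriving (Eq, Show)

treeMap :: (a -> b) -> Tree a -> Tree b
treeMap f Empty = Empty
treeMap f (Node l x r) =
  Node (treeMap f l) (f x) (treeMap f r)

-- The law behind the "treeMap/treeMap" rewrite rule: mapping twice
-- must give the same tree as mapping once with the composed function.
mapMapLaw :: (Eq c) => (b -> c) -> (a -> b) -> Tree a -> Bool
mapMapLaw f g t = treeMap f (treeMap g t) == treeMap (f . g) t
```

A handful of spot checks is no substitute for QuickCheck's random testing, but it catches the obvious ways to get a rule wrong.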

There are also a lot of good papers on this subject. Here are a few:

  1. Functional Programming with Bananas, Lenses, Envelopes and Barbed Wire uses fold and related combinators to optimize recursive functions.
  2. Theorems for Free shows how to automatically derive valid rewrite rules for map and any polymorphic function. This is closely related to the idea of a natural transformation in category theory.
  3. Cheap Deforestation for Non-Strict Functional Languages discusses techniques for eliminating intermediate "trees" from a computation. See also Deforestation: Transforming Programs to Eliminate Trees. Thanks, pejo!
  4. Comprehending Queries uses map/fold fusion to optimize database queries.
  5. Rewriting Haskell Strings uses rewrite rules to massively improve Haskell's string performance.
  6. Google's MapReduce Programming Model---Revisited analyzes MapReduce in more detail, porting it to Haskell. Thanks, augustss!
  7. Data Parallel Haskell: A status report shows how to use rewrite rules to optimize nested data parallel code. This is how to optimize a parallel map/reduce.


(Special thanks to Don Stewart, who helped me get all this working. See also his insanely optimized tree code for further performance ideas.)