Optimising free monad programs using Plated

Posted on October 26, 2017

This document is a Literate Haskell file, available here


In this article I demonstrate how to use classy prisms and Plated to write and apply optimisations to programs written in a free monad DSL.

Plated is a class in lens that provides powerful tools to work with self-recursive data structures. One such tool is recursive bottom-up rewriting, which repeatedly applies a transformation everywhere in a Plated structure until it can no longer be applied.
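Here is a small dependency-free sketch of "recursive bottom-up rewriting". It uses a toy expression type rather than a free monad. Expr, plateExpr, transformExpr, rewriteExpr and simplify are hypothetical names made up for this illustration. The structure follows my understanding of how lens builds rewrite out of transform and plate, but the real combinators work for any Plated type.

```haskell
import Data.Functor.Identity (Identity (..))

-- A toy self-recursive type, used only to illustrate rewriting.
data Expr = Lit Int | Neg Expr | Add Expr Expr
  deriving (Eq, Show)

-- Visit the immediate children of a node (the job lens's `plate` does).
plateExpr :: Applicative g => (Expr -> g Expr) -> Expr -> g Expr
plateExpr _ e@(Lit _) = pure e
plateExpr f (Neg e)   = Neg <$> f e
plateExpr f (Add a b) = Add <$> f a <*> f b

-- Bottom-up: transform the children first, then apply `f` to the rebuilt node.
transformExpr :: (Expr -> Expr) -> Expr -> Expr
transformExpr f = go
  where
    go = f . runIdentity . plateExpr (Identity . go)

-- Apply `rule` everywhere; whenever it fires, rewrite its output again,
-- so the final result contains no node the rule still matches.
rewriteExpr :: (Expr -> Maybe Expr) -> Expr -> Expr
rewriteExpr rule = transformExpr (\x -> maybe x (rewriteExpr rule) (rule x))

-- Example rules: cancel double negation, drop a zero on the left of Add.
simplify :: Expr -> Maybe Expr
simplify (Neg (Neg e))   = Just e
simplify (Add (Lit 0) e) = Just e
simplify _               = Nothing
```

For example, rewriteExpr simplify (Add (Lit 0) (Neg (Neg (Add (Lit 0) (Lit 7))))) gives Lit 7. The rules fire at every node where they match, starting from the leaves.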

The free monad has an instance of Plated:

Traversable f => Plated (Free f a)

so if you derive Foldable and Traversable for the underlying functor, you can use the Plated combinators on your free monad DSL.
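The Traversable constraint is needed because the children of a Free node are exactly the sub-programs that traverse visits inside the functor layer. In lens the instance is, roughly, plate f (Free as) = Free <$> traverse f as and plate _ x = pure x. The sketch below reproduces that idea using only base. The local Free type stands in for Control.Monad.Free, and plateFree, childrenFree and instructionCount are hypothetical names.

```haskell
{-# LANGUAGE DeriveFunctor, DeriveFoldable, DeriveTraversable #-}
import Data.Functor.Const (Const (..))

-- Local stand-in with the same shape as Control.Monad.Free's Free.
data Free f a = Pure a | Free (f (Free f a))

data InstF a = One a | Many Int a | Pause a
  deriving (Functor, Foldable, Traversable)

-- A Free node's children are the sub-programs inside its functor layer;
-- `traverse` visits exactly those, which is why Traversable f is required.
plateFree :: (Traversable f, Applicative g)
          => (Free f a -> g (Free f a)) -> Free f a -> g (Free f a)
plateFree g (Free fa) = Free <$> traverse g fa
plateFree _ p         = pure p

-- Collect the immediate children, in the spirit of lens's `children`.
childrenFree :: Traversable f => Free f a -> [Free f a]
childrenFree = getConst . plateFree (\c -> Const [c])

-- Count instruction layers by recursing through the children.
instructionCount :: Traversable f => Free f a -> Int
instructionCount (Pure _) = 0
instructionCount p        = 1 + sum (map instructionCount (childrenFree p))
```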

Defining classy prisms for the base functor gives us pattern matching on free monad programs for free. Combined with rewrite, this lets us apply optimisations to our free monad programs with minimal effort.


Let’s get into it.

{-# language DeriveFunctor #-}
{-# language DeriveFoldable #-}
{-# language DeriveTraversable #-}
{-# language FlexibleInstances #-}
{-# language FunctionalDependencies #-}
{-# language MultiParamTypeClasses #-}
{-# language TemplateHaskell #-}

module FreePlated where

import Control.Lens.Fold   ((^?))
import Control.Lens.Plated (Plated, rewrite)
import Control.Lens.Prism  (aside)
import Control.Lens.Review ((#))
import Control.Lens.TH     (makeClassyPrisms)
import Control.Monad.Free  (Free, liftF, _Free)
import Data.Monoid         (First(..))

First, define the DSL. This is a bit of a contrived one:

data InstF a
  = One a
  | Many Int a
  | Pause a
  deriving (Functor, Foldable, Traversable, Eq, Show)

type Inst = Free InstF

one :: Inst ()
one = liftF $ One ()

many :: Int -> Inst ()
many n = liftF $ Many n ()

pause :: Inst ()
pause = liftF $ Pause ()

  • one - “do something once”
  • many n - “do something n times”
  • pause - “take a break”

In this DSL, we are going to impose the property that many n is equivalent to replicateM_ n one.
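The property is stated informally. One way to make it checkable is an interpreter under which both sides behave identically. The sketch below uses a local Free type in place of Control.Monad.Free so that it only needs base. runTrace is a hypothetical interpreter invented for this example, not part of the article's code.

```haskell
{-# LANGUAGE DeriveFunctor #-}
import Control.Monad (replicateM_)

-- Local stand-in for Control.Monad.Free.
data Free f a = Pure a | Free (f (Free f a))

instance Functor f => Functor (Free f) where
  fmap g (Pure a)  = Pure (g a)
  fmap g (Free fa) = Free (fmap (fmap g) fa)

instance Functor f => Applicative (Free f) where
  pure = Pure
  Pure g  <*> x = fmap g x
  Free fg <*> x = Free (fmap (<*> x) fg)

instance Functor f => Monad (Free f) where
  Pure a  >>= k = k a
  Free fa >>= k = Free (fmap (>>= k) fa)

liftF :: Functor f => f a -> Free f a
liftF = Free . fmap Pure

data InstF a = One a | Many Int a | Pause a
  deriving Functor

type Inst = Free InstF

one, pause :: Inst ()
one   = liftF (One ())
pause = liftF (Pause ())

many :: Int -> Inst ()
many n = liftF (Many n ())

-- Hypothetical interpreter: a program's observable behaviour is the
-- sequence of steps it performs.
runTrace :: Inst a -> [String]
runTrace (Pure _)          = []
runTrace (Free (One k))    = "work" : runTrace k
runTrace (Free (Many n k)) = replicate n "work" ++ runTrace k
runTrace (Free (Pause k))  = "pause" : runTrace k
```

Under runTrace, many n and replicateM_ n one produce the same trace. Any optimisation that preserves runTrace, such as turning one followed by one into many 2, is therefore safe for this interpreter.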

Next, generate classy prisms for the functor.

makeClassyPrisms ''InstF

makeClassyPrisms ''InstF generates a class AsInstF whose methods are the following prisms:

  • _InstF :: AsInstF s a => Prism' s (InstF a)
  • _One :: AsInstF s a => Prism' s a
  • _Many :: AsInstF s a => Prism' s (Int, a)
  • _Pause :: AsInstF s a => Prism' s a
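The generated prisms work in two directions. s ^? _Many returns Just (n, a) when s is a Many and Nothing otherwise, and _Many # (n, a) builds a Many. The record below is a simplified, dependency-free mental model of that pair. SimplePrism, oneP, manyP and pauseP are hypothetical names. Unlike lens's Prism', these values do not compose with (.) and are not overloaded through a class like AsInstF.

```haskell
-- A simplified model of a prism: a partial match plus a total builder.
-- (Hypothetical type; lens's Prism' uses a different, composable representation.)
data SimplePrism s a = SimplePrism
  { match :: s -> Maybe a  -- roughly what (^?) does with a prism
  , build :: a -> s        -- roughly what (#) does with a prism
  }

data InstF a = One a | Many Int a | Pause a
  deriving (Eq, Show)

oneP :: SimplePrism (InstF a) a
oneP = SimplePrism m One
  where
    m (One a) = Just a
    m _       = Nothing

manyP :: SimplePrism (InstF a) (Int, a)
manyP = SimplePrism m (uncurry Many)
  where
    m (Many n a) = Just (n, a)
    m _          = Nothing

pauseP :: SimplePrism (InstF a) a
pauseP = SimplePrism m Pause
  where
    m (Pause a) = Just a
    m _         = Nothing
```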

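Lift the classy prisms into the free monad by making Inst a an instance of AsInstF. The _Free prism from Control.Monad.Free, which focuses on the outermost InstF layer of a program, is all the instance needs: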

instance AsInstF (Inst a) (Inst a) where
  _InstF = _Free

We can now use the prisms as if they had these types:

  • _One :: Prism' (Inst a) (Inst a)
  • _Many :: Prism' (Inst a) (Int, Inst a)
  • _Pause :: Prism' (Inst a) (Inst a)

If one of these prisms matches, the program begins with that particular instruction, and the Inst a it returns is the rest of the program.
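When used with ^?, composing prisms behaves like chaining partial matches. s ^? _One . _One succeeds only if the program starts with two One instructions, and it returns what follows them. The sketch below spells this out with plain pattern matching on a local Free type. matchOne, matchMany, oneOne and manyThenOne are hypothetical stand-ins for _One, _Many, _One . _One and _Many . aside _One.

```haskell
{-# LANGUAGE FlexibleContexts, StandaloneDeriving, UndecidableInstances #-}
import Control.Monad ((>=>))

-- Local stand-in for Control.Monad.Free, with structural Eq and Show.
data Free f a = Pure a | Free (f (Free f a))
deriving instance (Eq a, Eq (f (Free f a))) => Eq (Free f a)
deriving instance (Show a, Show (f (Free f a))) => Show (Free f a)

data InstF a = One a | Many Int a | Pause a
  deriving (Eq, Show)

type Inst = Free InstF

-- Stand-in for  s ^? _One : if the program starts with One, return its tail.
matchOne :: Inst a -> Maybe (Inst a)
matchOne (Free (One k)) = Just k
matchOne _              = Nothing

-- Stand-in for  s ^? _Many : return the count and the tail.
matchMany :: Inst a -> Maybe (Int, Inst a)
matchMany (Free (Many n k)) = Just (n, k)
matchMany _                 = Nothing

-- Stand-in for  s ^? _One . _One : chain the matches, keeping the final tail.
oneOne :: Inst a -> Maybe (Inst a)
oneOne = matchOne >=> matchOne

-- Stand-in for  s ^? _Many . aside _One : match One on the tail, keep the count.
manyThenOne :: Inst a -> Maybe (Int, Inst a)
manyThenOne s = do
  (n, k) <- matchMany s
  k'     <- matchOne k
  pure (n, k')
```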

Now it’s time to write optimisations over the free monad structure. A rewrite rule has the type a -> Maybe a. If the function returns Just x, the input is replaced with x. If it returns Nothing, no rewriting occurs.

optimisations :: AsInstF s s => [s -> Maybe s]
optimisations = [onesToMany, oneAndMany, manyAndOne]
  where

Rule 1: one followed by one is equivalent to many 2

    onesToMany s = do
      s' <- s ^? _One._One
      pure $ _Many # (2, s')

Rule 2: one followed by many n is equivalent to many (n+1)

    oneAndMany s = do
      (n, s') <- s ^? _One._Many
      pure $ _Many # (n+1, s')

Rule 3: many n followed by one is equivalent to many (n+1)

    manyAndOne s = do
      (n, s') <- s ^? _Many.aside _One
      pure $ _Many # (n+1, s')

The last step is to write a function that applies all the optimisations to a program.

optimise :: (Plated s, AsInstF s s) => s -> s
optimise = rewrite $ getFirst . foldMap (First .) optimisations

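getFirst . foldMap (First .) optimisations has type s -> Maybe s. foldMap uses the Monoid instance for functions to merge the list of rules into one function s -> First s. That function tries the rules in order and keeps the first Just, and getFirst unwraps the result. In other words, the combined rule picks the first rule that succeeds for the input. (As a standalone function of the rule list, the combinator is (getFirst .) . foldMap (First .) :: [a -> Maybe a] -> a -> Maybe a.)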
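The same combinator works for any list of rules. The sketch below checks its behaviour on plain Int rules. It also compares it with asum from Data.Foldable, which uses the Alternative instance of Maybe to keep the first Just. firstRule and firstRule' are hypothetical names.

```haskell
import Data.Foldable (asum)
import Data.Monoid (First (..))

-- Pick the first rule that fires. foldMap relies on the Monoid instance
-- for functions, (f <> g) x = f x <> g x, and First keeps the leftmost Just.
firstRule :: [a -> Maybe a] -> a -> Maybe a
firstRule rules = getFirst . foldMap (First .) rules

-- An equivalent formulation without First: apply each rule, keep the first Just.
firstRule' :: [a -> Maybe a] -> a -> Maybe a
firstRule' rules x = asum (map ($ x) rules)
```

Either version would slot into optimise. The First-based one reads directly as "the first rule that succeeds", which is why it fits a list of rewrite rules.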

Now we can optimise a program:

program :: Inst ()
program = do
  one
  one
  one
  pause
  one
  one
  many 3
  one

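The three ones before the pause should collapse into many 3. The instructions after the pause (one, one, many 3, one) should collapse into many 6. Note that the comparison below relies on the Eq instance for Free. Depending on your version of free, that instance may be (Eq1 f, Eq a) => Eq (Free f a), in which case InstF also needs an Eq1 instance from Data.Functor.Classes.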

ghci> optimise program == (many 3 *> pause *> many 6)
True

:)
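If you want to check the example without lens and free installed, the whole pipeline fits in a dependency-free sketch. It makes several assumptions. A local Free type with a derived structural Eq stands in for Control.Monad.Free. The three rules are written with ordinary pattern matching instead of the generated prisms. rewriteFree is a hypothetical hand-rolled version of rewrite specialised to Free, using fmap for the traversal that Plated would otherwise provide.

```haskell
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE FlexibleContexts, StandaloneDeriving, UndecidableInstances #-}
import Data.Monoid (First (..))

-- Local stand-in for Control.Monad.Free, with derived structural Eq and Show.
data Free f a = Pure a | Free (f (Free f a))
deriving instance (Eq a, Eq (f (Free f a))) => Eq (Free f a)
deriving instance (Show a, Show (f (Free f a))) => Show (Free f a)

instance Functor f => Functor (Free f) where
  fmap g (Pure a)  = Pure (g a)
  fmap g (Free fa) = Free (fmap (fmap g) fa)

instance Functor f => Applicative (Free f) where
  pure = Pure
  Pure g  <*> x = fmap g x
  Free fg <*> x = Free (fmap (<*> x) fg)

instance Functor f => Monad (Free f) where
  Pure a  >>= k = k a
  Free fa >>= k = Free (fmap (>>= k) fa)

liftF :: Functor f => f a -> Free f a
liftF = Free . fmap Pure

-- The DSL from the article.
data InstF a = One a | Many Int a | Pause a
  deriving (Functor, Eq, Show)

type Inst = Free InstF

one, pause :: Inst ()
one   = liftF (One ())
pause = liftF (Pause ())

many :: Int -> Inst ()
many n = liftF (Many n ())

-- Hand-rolled stand-in for rewrite, specialised to Free: rewrite the
-- children first, then try the rule at this node, and rewrite the rule's
-- output again until the rule no longer applies.
rewriteFree :: Functor f => (Free f a -> Maybe (Free f a)) -> Free f a -> Free f a
rewriteFree rule = go
  where
    go = step . descend
    descend (Free fa) = Free (fmap go fa)
    descend p         = p
    step x = maybe x go (rule x)

-- The three rules, using pattern matching instead of the generated prisms.
onesToMany, oneAndMany, manyAndOne :: Inst a -> Maybe (Inst a)
onesToMany (Free (One (Free (One s'))))    = Just (Free (Many 2 s'))
onesToMany _                               = Nothing
oneAndMany (Free (One (Free (Many n s')))) = Just (Free (Many (n + 1) s'))
oneAndMany _                               = Nothing
manyAndOne (Free (Many n (Free (One s')))) = Just (Free (Many (n + 1) s'))
manyAndOne _                               = Nothing

-- Combine the rules with First, as the article's optimise does.
optimise :: Inst a -> Inst a
optimise = rewriteFree (getFirst . foldMap (First .) [onesToMany, oneAndMany, manyAndOne])

program :: Inst ()
program = do
  one
  one
  one
  pause
  one
  one
  many 3
  one
```

Each rule requires a different pair of leading instructions: One then One, One then Many, and Many then One. Because of this, no two rules match the same node, and their order in the list does not affect the result here.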

Isaac Elliott

Isaac really likes types