The Free monad explained – part 1

My next series of blog posts will be on the Free monad. Don’t be scared by the name: it’s not that difficult to grasp (if explained correctly) and it can greatly improve your code.

Why? Because it allows a clear separation between the program definition (i.e. the business logic) and its implementation. That means if one day you decide to change your implementation (e.g. change your persistent storage) you just need to rewrite your interpreter, leaving the business logic untouched.

Moreover, being a functional construct, it composes “nicely”. We can combine different programs into a bigger one, re-using all the interpreters already defined.

Doesn’t all that sound great? So let’s get started and implement our Free monad as we go through a simple example.

In this post I’ll show you all the plumbing (at least a simplified version of it), which should give you a feel for how things work. If you want to use the free monad in your code, don’t implement it yourself (especially not with the code below, which is not stack-safe); use a library such as cats or scalaz instead.

Note: This post is largely inspired by a talk from Rúnar Bjarnason at ScalaDays 2014.

First, let’s create our own language with a set of instructions to interact with a user.

sealed trait Interact[A]
case class Tell(message: String) extends Interact[Unit]
case class Ask(prompt: String) extends Interact[String]

We define 2 instructions: Tell, which simply displays a message to the user, and Ask, which asks the user for an input. Ask extends Interact[String] because it ‘returns’ the input provided by the user. (Tell extends Interact[Unit] because there is nothing to ‘return’.)

Then we’d like to compose our program with a set of instructions. Something like:

val prog = List(
   Ask("What's your firstname?"),
   Ask("What's your surname?"),
   Tell("Hello, ????")
)

As you can see, we have a problem here. We can’t access the user inputs. In fact we’d rather write something like this:

val prog = for {
   firstname <- Ask("What's your firstname?")
   surname   <- Ask("What's your surname?")
   _         <- Tell(s"Hello, $firstname $surname")
} yield ()

That looks much better but it doesn’t compile. To be able to write this we need to provide map and flatMap operations for our Interact data type (basically we need to turn our data type into a monad).

import scala.language.higherKinds
trait Monad[M[_]] {
   def pure[A](a: A): M[A]
   def flatMap[A, B](ma: M[A])(f: A => M[B]): M[B]
}

This defines our Monad typeclass. What it says is that Monad takes a type constructor M[_] (think of M as Option, List, Future, …) and provides 2 operations:

  • pure: wraps a value into a context (e.g. Some). This is also known as ‘return’.
  • flatMap: turns an M[A] into an M[B] given a function that turns an A into an M[B]. (Again, if you think of M as Option, this is exactly Option’s flatMap function.) This is also known as ‘bind’.

(Of course a Monad needs to obey the monad laws, but let’s not worry about them here.)
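To make the two operations concrete, here is a minimal sketch of a Monad instance for Option. The names OptionMonadSketch, optionMonad and parse are made up for this illustration, and the Monad trait is repeated so the snippet stands alone.

```scala
import scala.language.higherKinds

object OptionMonadSketch {
  // Same shape as the Monad trait above, repeated so the snippet stands alone.
  trait Monad[M[_]] {
    def pure[A](a: A): M[A]
    def flatMap[A, B](ma: M[A])(f: A => M[B]): M[B]
  }

  // Hypothetical instance for Option, written only for illustration.
  val optionMonad: Monad[Option] = new Monad[Option] {
    def pure[A](a: A): Option[A] = Some(a)
    def flatMap[A, B](ma: Option[A])(f: A => Option[B]): Option[B] =
      ma match {
        case Some(a) => f(a)
        case None    => None
      }
  }

  // A function of the shape A => M[B]: parsing a number may fail.
  def parse(s: String): Option[Int] =
    if (s.nonEmpty && s.forall(_.isDigit)) Some(s.toInt) else None

  // pure wraps the string, flatMap chains the parsing step.
  val ok: Option[Int]  = optionMonad.flatMap(optionMonad.pure("42"))(parse)
  val bad: Option[Int] = optionMonad.flatMap(optionMonad.pure("x"))(parse)
}
```

Here ok is Some(42) and bad is None: the failure of the second step is handled by flatMap, not by the caller.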

So we have our Interact data type and our Monad trait. We now need to transform our data type into a monad. For that purpose, let me introduce you to the Free monad:

sealed trait Free[F[_], A] {
   def flatMap[B](f: A => Free[F, B]): Free[F, B] = 
      this match {
         case Return(a)  => f(a)
         case Bind(i, k) => Bind(i, k andThen (_ flatMap f))
      }

   // if we have flatMap we can have map as well
   def map[B](f: A => B): Free[F, B] = 
      flatMap(a => Return(f(a)))
}
case class Return[F[_], A](a: A) extends Free[F, A]
case class Bind[F[_], I, A](i: F[I], k: I => Free[F, A]) extends Free[F, A]

Okay, let’s stop for a moment and see what we have here.
We have 2 case classes corresponding to the pure() and flatMap() operations and a Free trait that defines flatMap (and map).
If you follow the types in the flatMap operation you realise that there are not many ways to implement it:

  • if we’re flatMapping over a Return(a) we just return f(a) which gives us a Free[F, B]
  • if we’re flatMapping over a Bind(i, k) we need to combine k with f to obtain a Free[F, B]. We apply k which gives us a Free[F, A] then we flatMap it with f to get a Free[F, B].
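The two cases above can be traced on small values. The following is only a sketch: it repeats Free, Return and Bind so that it stands alone, uses a one-instruction version of Interact, and introduces a hypothetical helper canned that answers every Ask with a fixed string.

```scala
import scala.language.higherKinds

object FlatMapSketch {
  // Free, Return and Bind repeated from above so the snippet stands alone.
  sealed trait Free[F[_], A] {
    def flatMap[B](f: A => Free[F, B]): Free[F, B] =
      this match {
        case Return(a)  => f(a)
        case Bind(i, k) => Bind(i, k andThen (_ flatMap f))
      }
  }
  case class Return[F[_], A](a: A) extends Free[F, A]
  case class Bind[F[_], I, A](i: F[I], k: I => Free[F, A]) extends Free[F, A]

  // A one-instruction language, just for this illustration.
  sealed trait Interact[A]
  case class Ask(prompt: String) extends Interact[String]

  val length: String => Free[Interact, Int] =
    s => Return[Interact, Int](s.length)

  // Case 1: flatMap over Return just applies the function.
  val fromReturn: Free[Interact, Int] =
    Return[Interact, String]("abc").flatMap(length)

  // Case 2: flatMap over Bind keeps the instruction and extends the continuation.
  val fromBind: Free[Interact, Int] =
    Bind[Interact, String, String](Ask("name?"), s => Return[Interact, String](s))
      .flatMap(length)

  // Hypothetical helper: answers every Ask with a fixed string.
  def canned[X](i: Interact[X]): X = i match {
    case Ask(_) => "John"
  }

  // Feed the canned answer to the continuation to see what comes next.
  val next: Free[Interact, Int] = fromBind match {
    case Bind(i, k) => k(canned(i))
    case done       => done
  }
}
```

fromReturn is Return(3). fromBind is still a Bind holding Ask("name?"), and only once an answer is supplied does the continuation produce Return(4), the length of "John".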

Let’s go on. We now have a Free monad and our data type. There is one thing missing to be able to write our for-comprehension: a way to turn Interact[A] into a Free[Interact, A].

So let’s write a method that does just that:

import scala.language.implicitConversions

implicit def liftIntoFree[F[_], A](fa: F[A]): Free[F, A] =
   Bind(fa, (a: A) => Return(a))

Brilliant, we can now lift any F (including Interact) into Free. We make it implicit so that Scala will call it for us.
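If you would rather not rely on an implicit conversion, the same lifting can be done explicitly. The sketch below is one possible alternative, not the approach used in the rest of this post: lift, ask, tell and greet are hypothetical names, and the Free definitions are repeated so the snippet stands alone.

```scala
import scala.language.higherKinds

object LiftSketch {
  // Free, Return and Bind repeated from above so the snippet stands alone.
  sealed trait Free[F[_], A] {
    def flatMap[B](f: A => Free[F, B]): Free[F, B] =
      this match {
        case Return(a)  => f(a)
        case Bind(i, k) => Bind(i, k andThen (_ flatMap f))
      }
    def map[B](f: A => B): Free[F, B] =
      flatMap(a => Return(f(a)))
  }
  case class Return[F[_], A](a: A) extends Free[F, A]
  case class Bind[F[_], I, A](i: F[I], k: I => Free[F, A]) extends Free[F, A]

  sealed trait Interact[A]
  case class Tell(message: String) extends Interact[Unit]
  case class Ask(prompt: String) extends Interact[String]

  // Explicit counterpart of liftIntoFree: one instruction followed by Return.
  def lift[F[_], A](fa: F[A]): Free[F, A] =
    Bind[F, A, A](fa, a => Return[F, A](a))

  // Hypothetical "smart constructors": one lifted helper per instruction.
  def ask(prompt: String): Free[Interact, String] =
    lift[Interact, String](Ask(prompt))
  def tell(message: String): Free[Interact, Unit] =
    lift[Interact, Unit](Tell(message))

  // The program now type-checks without any implicit conversion in scope.
  val greet: Free[Interact, Unit] = for {
    firstname <- ask("What's your firstname?")
    surname   <- ask("What's your surname?")
    _         <- tell(s"Hello, $firstname $surname")
  } yield ()
}
```

The trade-off is a little more boilerplate per instruction in exchange for conversions that are visible at the call site.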

Good news, we are ready to write our program using the for-comprehension.

val prog: Free[Interact, Unit] = for {
   firstname <- Ask("What's your firstname?")
   surname   <- Ask("What's your surname?")
   _         <- Tell(s"Hello, $firstname $surname")
} yield ()

Good job, this time it compiles! Remember that a for-comprehension is just syntactic sugar for map and flatMap. That means our program is equivalent to this:

val expandedProg: Free[Interact, Unit] =
   Ask("What's your firstname?").flatMap(firstname =>
      Ask("What's your surname?").flatMap(surname =>
         // `yield ()` desugars to a map that returns the plain value ()
         Tell(s"Hello, $firstname $surname").map(_ => ())
      )
   ) 

or even this:

val expandedProg2: Free[Interact, Unit] = 
   Bind[Interact, String, Unit](
      Ask("What's your firstname?"),
      firstname => Bind[Interact, String, Unit](
         Ask("What's your surname?"),
         surname => Bind[Interact, Unit, Unit](
            Tell(s"Hello, $firstname $surname"),
            _ => Return(())
         )
      )
   )

So what does our program do? Well, nothing yet! Remember we created Interact (Ask and Tell) out of thin air. They don’t do anything. So far we have just described what we want to do and that is brilliant because we have separated our program from its implementation. And this is just what the free monad is – a data structure that holds a computation without executing it.

To execute our program we need a way to turn Interact into real-world actions. This is the job of the interpreter, defined by the following trait:

trait ~>[F[_], G[_]] {
   def apply[A](fa: F[A]): G[A] 
}

It has a special name so that it reads F ‘to’ G. Let’s write our first interpreter that uses the console to run our program. In this case Console is an interpreter from Interact to Id.

type Id[A] = A

object Console extends (Interact ~> Id) {
   def apply[A](i: Interact[A]): Id[A] = 
      i match {
         case Tell(message) => println(message)
         case Ask(prompt)   =>
            println(prompt)
            scala.io.StdIn.readLine()
      }
}
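The ~> trait is not specific to Interact: it describes any transformation from F[A] to G[A] that works for every A. As a small sketch, assuming nothing beyond the standard library, here is a hypothetical OptionToList transformation (the ~> trait is repeated so the snippet stands alone).

```scala
import scala.language.higherKinds

object NaturalTransformationSketch {
  // Same trait as above, repeated so the snippet stands alone.
  trait ~>[F[_], G[_]] {
    def apply[A](fa: F[A]): G[A]
  }

  // Hypothetical example: every Option[A] can be turned into a List[A],
  // whatever A is. The transformation never looks at the value inside.
  object OptionToList extends (Option ~> List) {
    def apply[A](fa: Option[A]): List[A] = fa.toList
  }

  val one: List[Int]     = OptionToList(Some(1))
  val none: List[String] = OptionToList(Option.empty[String])
}
```

one is List(1) and none is the empty list. Console has the same shape, with Interact in place of Option and Id in place of List.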

We need one more function to glue things together. This function will use our interpreter to turn the ‘sequence’ of Interact in the for-comprehension into a ‘sequence’ of Id (i.e. apply the side effects).
For that purpose let’s add a foldMap function to our Free trait:

sealed trait Free[F[_], A] {
   // ... flatMap definition
   // ... map definition
   
   def foldMap[G[_]](f: F ~> G)(implicit monad: Monad[G]): G[A] =
      this match {
         case Return(a)  => monad.pure(a)
         case Bind(i, k) =>
            monad.flatMap(f(i))(a => k(a).foldMap(f))
      }
}

The implicit parameter is here to prove that G is a monad. We use it to wrap an A into a G and to flatMap G. So we need to provide a Monad instance for the Id type.

implicit val idMonad: Monad[Id] = new Monad[Id] {
   def pure[A](a: A): Id[A] = a
   def flatMap[A, B](a: Id[A])(f: A => Id[B]) = f(a)
}

At last we can run our little program:

prog.foldMap(Console)

Congratulations if you made it this far! We created a DSL (Domain Specific Language) with no implementation at all, wrote a tiny program in this language and ran it in the console using our Console interpreter. That’s a big achievement!

However, the test coverage isn’t great, so we need to do something about it.

Testing things on the console is not convenient. It would be much better if we could provide all the answers to the program and check that it produces the expected output.

Well, let’s do just that by turning our program into a function. This is what the type Tester does: it is just a function that takes a map of simulated user inputs (the keys are the prompts and the values are the user’s answers) and returns the list of messages produced by the program, along with the final result (which happens to be Unit here).

type Tester[A] = Map[String, String] => (List[String], A)

// Tester needs to be a monad
implicit val testerMonad: Monad[Tester] = new Monad[Tester] {
   // a Tester is a function, so pure ignores the inputs and emits no message
   def pure[A](a: A): Tester[A] = _ => (Nil, a)
   def flatMap[A, B](t: Tester[A])(f: A => Tester[B]): Tester[B] =
      inputs => {
         val (out1, a) = t(inputs)
         val (out2, b) = f(a)(inputs)
         (out1 ++ out2, b)
      }
}

Now let’s write a Test interpreter that turns Interact into a Tester function, and we’ll be ready to go.

object Test extends (Interact ~> Tester) {
   def apply[A](i: Interact[A]): Tester[A] = 
      i match {
         case Tell(message) => 
            _ => (message :: Nil, ())
         case Ask(prompt) =>
            inputs => (Nil, inputs(prompt))
      }
}

We can now easily test our program:

val inputs = Map(
   "What's your firstname?" -> "John",
   "What's your surname?"   -> "Doe")
val (messages, _) = prog.foldMap(Test).apply(inputs)
// messages = List("Hello, John Doe")
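As a further illustration of swapping interpreters, here is a sketch of a hypothetical Transcript interpreter that records the prompts and the simulated answers as well as the messages. It is not part of the original program: the names Transcript, ask, tell, greet and run are assumptions, and all the definitions from this post are repeated (with explicit lifting instead of the implicit conversion) so the snippet stands alone.

```scala
import scala.language.higherKinds

object TranscriptSketch {
  trait Monad[M[_]] {
    def pure[A](a: A): M[A]
    def flatMap[A, B](ma: M[A])(f: A => M[B]): M[B]
  }
  trait ~>[F[_], G[_]] { def apply[A](fa: F[A]): G[A] }

  sealed trait Free[F[_], A] {
    def flatMap[B](f: A => Free[F, B]): Free[F, B] = this match {
      case Return(a)  => f(a)
      case Bind(i, k) => Bind(i, k andThen (_ flatMap f))
    }
    def map[B](f: A => B): Free[F, B] = flatMap(a => Return(f(a)))
    def foldMap[G[_]](f: F ~> G)(implicit monad: Monad[G]): G[A] = this match {
      case Return(a)  => monad.pure(a)
      case Bind(i, k) => monad.flatMap(f(i))(a => k(a).foldMap(f))
    }
  }
  case class Return[F[_], A](a: A) extends Free[F, A]
  case class Bind[F[_], I, A](i: F[I], k: I => Free[F, A]) extends Free[F, A]

  sealed trait Interact[A]
  case class Tell(message: String) extends Interact[Unit]
  case class Ask(prompt: String) extends Interact[String]

  // Explicit lifting, used here instead of the implicit conversion.
  def ask(p: String): Free[Interact, String] =
    Bind[Interact, String, String](Ask(p), s => Return[Interact, String](s))
  def tell(m: String): Free[Interact, Unit] =
    Bind[Interact, Unit, Unit](Tell(m), u => Return[Interact, Unit](u))

  type Tester[A] = Map[String, String] => (List[String], A)

  implicit val testerMonad: Monad[Tester] = new Monad[Tester] {
    def pure[A](a: A): Tester[A] = _ => (Nil, a)
    def flatMap[A, B](t: Tester[A])(f: A => Tester[B]): Tester[B] =
      inputs => {
        val (out1, a) = t(inputs)
        val (out2, b) = f(a)(inputs)
        (out1 ++ out2, b)
      }
  }

  // Hypothetical variant of Test: Ask also logs the prompt and the answer.
  object Transcript extends (Interact ~> Tester) {
    def apply[A](i: Interact[A]): Tester[A] = i match {
      case Tell(message) => _ => (List(message), ())
      case Ask(prompt) =>
        inputs => {
          // assumption: a prompt with no simulated answer yields ""
          val answer = inputs.getOrElse(prompt, "")
          (List(s"$prompt $answer"), answer)
        }
    }
  }

  val greet: Free[Interact, Unit] = for {
    firstname <- ask("What's your firstname?")
    surname   <- ask("What's your surname?")
    _         <- tell(s"Hello, $firstname $surname")
  } yield ()

  // Runs the same program through Transcript and keeps only the log.
  def run(inputs: Map[String, String]): List[String] =
    greet.foldMap(Transcript).apply(inputs)._1
}
```

With the same inputs as above, run returns three lines: the two prompts followed by their answers, then "Hello, John Doe". The program itself did not change, only the interpreter did.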

I hope you appreciate how easy it was to switch the implementation of the program. This is one of the main advantages of the Free monad: the program is completely decoupled from its implementation.

User interaction is great, but we need to combine it with code from other domains to make a useful application. Stay tuned, because we’ll do just that in part 2.