Forging a DSL using Scala type classes


In this post we’re going to explore how to build a DSL (Domain Specific Language) with a user-friendly syntax while maintaining as much type safety as possible. We want any operation that is not allowed by the business rules to fail at compile time. This is valuable because it ensures that no one can write such forbidden logic, even by mistake.

Moreover, Scala provides really nice syntactic sugar that can make a DSL syntax pretty neat.

If you don’t know what type classes are, or don’t feel very comfortable with the concept, follow along as we’ll also explore how we can use them to dissociate data and behaviours (always a good practice).

To support this post we’ll build a DSL for computations on people’s salaries.

Before getting there let’s start with something simple. We’ll definitely need some money so let’s define the following case classes:

final case class GBP(value: BigDecimal)
final case class EUR(value: BigDecimal)

We have two case classes, one representing an amount of British pounds and one an amount of euros. We can now create some amounts in these currencies:

val tenPounds = GBP(10)
val tenEuros  = EUR(10)

We can improve our DSL syntax by allowing the users to write something like:

val tenPounds = 10.GBP
val tenEuros  = 10.EUR

To enable this syntax we need an implicit class that defines a .GBP and .EUR method on Int. It’s good practice to place these classes into the companion object.

object GBP {
  implicit class GBPFromInt(value: Int) {
    def GBP = new GBP(value)
  }
  // could add more implicit classes for other types
}
object EUR {
  implicit class EURFromInt(value: Int) {
    def EUR = new EUR(value)
  }
  // could add more implicit classes for other types
}
import EUR._
import GBP._
val tenPounds = 10.GBP
val tenEuros  = 10.EUR

Nice! It works like magic. In fact it’s Scala implicits that make it all possible. The compiler sees that we are calling the method EUR (or GBP) on an Int. Of course there is no such method, so it looks for an implicit class that takes an Int in its constructor and provides the EUR (or GBP) method. That turns out to be exactly the implicit classes we wrote.
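To make the rewrite concrete, here is a minimal sketch of what the compiler does with 10.GBP. The wrapper object ImplicitClassSketch and the names sugared and desugared are only there for illustration; the rest mirrors the definitions above.

```scala
object ImplicitClassSketch {
  final case class GBP(value: BigDecimal)

  object GBP {
    // Same shape as the implicit class defined in the companion object above.
    implicit class GBPFromInt(value: Int) {
      def GBP: GBP = new GBP(BigDecimal(value))
    }
  }

  import GBP._

  // What we write in the DSL:
  val sugared: GBP = 10.GBP
  // Roughly what the compiler turns it into: wrap the Int, then call the method.
  val desugared: GBP = new GBPFromInt(10).GBP
}
```

Both values are built through the same wrapper class, so they compare equal.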

So far so good, however there is not much we can do with our case classes. We would like to do some basic computations at least, something like:

10.GBP + 5.GBP
10.EUR - 5.EUR
10.GBP + 10.EUR // shouldn't compile: we don't want to mix euros and pounds together

Mmmh… What do we have here? Yes, behaviours. Our data model doesn’t need to change (our case classes are still valid: pounds are still pounds and euros are still euros) but we need more functionality. We need a way to add things together and a way to take away a given amount from another amount. We can even express minus in terms of plus if we have a way to negate an amount (i.e. multiply it by -1).

Let’s define what we need in a trait:

trait Amount[A] {
  def plus(a: A, b: A): A 
  def times(a: A, b: BigDecimal): A
}

These are the methods we need our case classes to support. If we have these operations available we can implement the operators we need: + and - (and even * and /).

The good news is that in Scala arithmetic operators are just regular methods.

// On a type A we can define the following method
def +(b: A): A
// and call it like any other method
val c = a.+(b)
// or with infix notation
val d = a + b

There is no magic here; these are just two ways to call the method +. One looks nicer than the other but the compiler knows that it’s actually the same method call.

Notice that * and / should only allow multiplying (or dividing) an amount by a constant. Otherwise we’d get EUR², which doesn’t really make sense. Fortunately we can encode exactly this in the method signatures:

trait Amount[A] {
  def plus(a: A, b: A): A 
  def times(a: A, b: BigDecimal): A

  class Ops(a: A) {
    def +(b: A): A = plus(a, b)
    def -(b: A): A = plus(a, times(b, -1))
    def *(b: BigDecimal): A = times(a, b)
    def /(b: BigDecimal): A = times(a, BigDecimal(1) / b)
  }
}

So we have our case class on one side and some behaviour on the other side. That doesn’t get us anywhere yet. We need a way to connect them together. We do this by implementing the Amount trait for each of our case classes.

object GBP {
  implicit class GBPFromInt(value: Int) {
    def GBP = new GBP(value)
  }
  implicit object GBPIsAmount extends Amount[GBP] {
    def plus(a: GBP, b: GBP): GBP = GBP(a.value + b.value)
    def times(a: GBP, b: BigDecimal): GBP = GBP(a.value * b)
  }
  implicit def ops(a: GBP) = new GBPIsAmount.Ops(a)
}
object EUR {
  implicit class EURFromInt(value: Int) {
    def EUR = new EUR(value)
  }
  implicit object EURIsAmount extends Amount[EUR] {
    def plus(a: EUR, b: EUR): EUR = EUR(a.value + b.value)
    def times(a: EUR, b: BigDecimal): EUR = EUR(a.value * b)
  }
  implicit def ops(a: EUR) = new EURIsAmount.Ops(a)
}

And that’s it. We have just added some behaviours to our case classes and we can now write

import EUR._
import GBP._
val fifteenEuros = 10.EUR + 5.EUR
val fifteenPounds = 20.GBP - 5.GBP
val twentyEuros = 10.EUR * 2
val tenPounds = 20.GBP / 2
// we can't mix pounds and euros
10.EUR + 10.GBP // doesn't compile
// we can't multiply amounts
10.EUR * 10.EUR // doesn't compile

It works! Isn’t that great? The syntax is nice and we can only perform operations that make sense! We have also implemented our first type class, so let’s review what we’ve done so far.

  1. We have created some case classes to represent our data
  2. We have defined a trait with the methods we’d like our case classes to implement
  3. We have provided an implementation of the above trait for each of our case classes

We didn’t have to change our case classes in this example. The compiler figures out where +, -, * and / reside by looking at the available implicit conversions. In our case the implicit ops() methods wrap a value in the Ops class, which provides the arithmetic operators.

The typeclass trait defines the behaviour and the typeclass instance implements this behaviour for a specific class.

Let’s just put some words on these concepts. The Amount trait is called a type class and the implementations (EURIsAmount and GBPIsAmount) are called type class instances. These instances can be seen as a proof or evidence that EUR and GBP are actually Amounts.
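The "evidence" view can be illustrated with a function that works for any type that has an Amount instance. This is a minimal sketch: the helper double and the wrapper object AmountEvidenceSketch are hypothetical names, not part of the DSL.

```scala
object AmountEvidenceSketch {
  trait Amount[A] {
    def plus(a: A, b: A): A
    def times(a: A, b: BigDecimal): A
  }

  final case class GBP(value: BigDecimal)
  object GBP {
    implicit object GBPIsAmount extends Amount[GBP] {
      def plus(a: GBP, b: GBP): GBP = GBP(a.value + b.value)
      def times(a: GBP, b: BigDecimal): GBP = GBP(a.value * b)
    }
  }

  // Hypothetical helper: accepts any A for which the compiler can find
  // an Amount[A] instance (the evidence that A is an amount).
  def double[A](a: A)(implicit amount: Amount[A]): A = amount.plus(a, a)

  // The compiler finds GBPIsAmount in the GBP companion object.
  val twentyPounds: GBP = double(GBP(BigDecimal(10)))
  // double("ten") would not compile: there is no Amount[String] instance.
}
```

Nothing in double knows about GBP; the instance is the only link between the data and the behaviour.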

Introducing percent into the DSL

Now we’d like to compute a percentage of an amount. To that end let’s introduce a new case class Percent with a smart constructor (the same as we did for GBP and EUR).

case class Percent(value: BigDecimal)
object Percent {
  implicit class PercentFromInt(value: Int) {
    def percent = Percent(value)
  }
}

We can now write some percentages

import Percent._
val tenPercent = 10.percent

It would have been nice to use the percent sign (%) instead but it conflicts with the modulo operator.

Now we need to add behaviours for our percentages. However Percent is not an Amount. It is actually a constant that can be used to multiply an Amount. In fact if we can convert our percentage to a BigDecimal we can use it directly in Amount’s operations.

Moreover there is still one problem with our implementation of Amount as we can’t switch the terms of the multiplication.

val tenPounds = 5.GBP * 2
val tenEuros = 2 * 5.EUR // doesn't compile

Let’s introduce a new type class to address both of these issues.

trait Constant[A] {
  def underlying(a: A): BigDecimal
  // Add some ops for the Constant instances
  class Ops(a: A) {
    def *[B](b: B)(implicit amount: Amount[B]) = amount.times(b, underlying(a))
  }
}

We now have a type class that lets us retrieve the underlying value of a Constant so that we can use it with Amount.times.

We also added a * operation so that we can write our multiplication starting with the constant. Note that this method takes an implicit parameter, amount. This is the evidence that B is an Amount. Given this Amount[B] we can use it to multiply the amount by a constant number.

At this point we need to revisit the Amount Ops so that * and / can work with any Constant (and not just BigDecimal).

trait Amount[A] {
  def plus(a: A, b: A): A
  def times(a: A, b: BigDecimal): A
  class Ops(a: A) {
    def +(b: A) = plus(a, b)
    def -(b: A) = plus(a, times(b, -1))
    def *[B](b: B)(implicit constant: Constant[B]): A =
      times(a, constant.underlying(b))
    def /[B](b: B)(implicit constant: Constant[B]): A =
      times(a, BigDecimal(1) / constant.underlying(b))
  }
}

Next step is to provide type class instances for Percent and Int.

object Percent {
  implicit class PercentFromInt(value: Int) {
    def percent = Percent(value)
  }
  implicit object PercentIsConstant extends Constant[Percent] {
    override def underlying(a: Percent): BigDecimal = a.value / 100 
  }
  implicit def percentOps(a: Percent) = new PercentIsConstant.Ops(a)
}
implicit object IntIsConstant extends Constant[Int] {
  override def underlying(a: Int): BigDecimal = BigDecimal(a)
}
implicit def intConstantOps(b: Int) = new IntIsConstant.Ops(b)

We’re now all set up, so let’s try our new functionality:

val onePound = 10.percent * 10.GBP
val oneEuro = 10.EUR * 10.percent
val twoEuros = 2 * oneEuro 
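One way to trace what 10.percent * 10.GBP computes is to strip away the operator syntax and pass the two pieces of evidence to a plain function. This is a simplified sketch: the helper scale and the wrapper object PercentSketch are hypothetical, and the type classes are trimmed down to the single method each one needs here.

```scala
object PercentSketch {
  trait Amount[A] { def times(a: A, b: BigDecimal): A }
  trait Constant[A] { def underlying(a: A): BigDecimal }

  final case class GBP(value: BigDecimal)
  final case class Percent(value: BigDecimal)

  implicit object GBPIsAmount extends Amount[GBP] {
    def times(a: GBP, b: BigDecimal): GBP = GBP(a.value * b)
  }
  implicit object PercentIsConstant extends Constant[Percent] {
    // 10 percent becomes the multiplier 0.1
    def underlying(a: Percent): BigDecimal = a.value / BigDecimal(100)
  }

  // Hypothetical helper doing what the * operator on a constant does:
  // turn the constant into a BigDecimal, then delegate to Amount.times.
  def scale[C, A](c: C, a: A)(implicit constant: Constant[C], amount: Amount[A]): A =
    amount.times(a, constant.underlying(c))

  // 10 percent of 10 GBP: 10 * (10 / 100) = 1
  val onePound: GBP = scale(Percent(BigDecimal(10)), GBP(BigDecimal(10)))
}
```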

Periodic amount

In this part we’re going to add periodic amounts to represent both monthly and yearly incomes. Of course we want some type safety in our DSL, so we shouldn’t be able to write a program that mixes yearly and monthly incomes together.

It would also be nice if we could re-use the Amount type class we have defined previously for our periodic amounts.

Let’s start by defining our monthly and yearly periods.

sealed trait Month
sealed trait Year
final val Month = new Month() {}
final val Year = new Year() {}

Month and Year have nothing in common. They are not related in any way, although they both represent a period. We could have used a sealed trait Period and have both Year and Month extend Period but, as this post is all about type classes, let’s define a Period type class.

trait Period[A] 
object Period {
  implicit object MonthIsPeriod extends Period[Month] 
  implicit object YearIsPeriod extends Period[Year] 
}

We could have used objects to define month and year but the type signature would then show as Month.type, which is not user-friendly enough for our DSL. In fact there is a much nicer trick we can do with these definitions.

Now we have amounts on one side and periods on the other. Let’s combine them so that we can have amounts per period. To that end let’s introduce the case class Per:

final case class Per[A, B: Period](value: A, period: B)
implicit class PerFromValue[A](value: A) {
  def per[B: Period](period: B): Per[A, B] = Per(value, period)
}

Nothing new here, except the B: Period syntax (a context bound), which is just another way to say that there should be an implicit parameter of type Period[B] available.
I tend to prefer this syntax for type classes as it reads nicely: “B which complies to Period”.
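The equivalence between the two notations can be sketched as follows. The methods viaContextBound and viaImplicitParam and the wrapper object ContextBoundSketch are hypothetical names used only for this illustration.

```scala
object ContextBoundSketch {
  trait Period[A]

  sealed trait Month
  val Month: Month = new Month {}
  implicit object MonthIsPeriod extends Period[Month]

  // Context-bound form: the evidence has no name, so we summon it with implicitly.
  def viaContextBound[B: Period](period: B): Period[B] = implicitly[Period[B]]

  // Explicit form: the same evidence written as a named implicit parameter.
  def viaImplicitParam[B](period: B)(implicit ev: Period[B]): Period[B] = ev

  // Both calls resolve to the same instance, MonthIsPeriod.
  val a: Period[Month] = viaContextBound(Month)
  val b: Period[Month] = viaImplicitParam(Month)
}
```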

Now let’s use some scala syntactic sugar which is available for free and we can create some periodic amounts:

val regularIncome: Per[GBP, Year] = 45000.GBP.per(Year)
// which can also be written as
val syntacticIncome: GBP Per Year = 65000.GBP per Year

Pretty neat, isn’t it? Now you see why a Year.type wasn’t very welcome in the type signature.
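The sugar at work here is infix type notation: a type constructor with two parameters can be written between its arguments. A tiny sketch, using Int and String as stand-in type arguments (an assumption made only to keep the example self-contained):

```scala
object InfixTypeSketch {
  final case class Per[A, B](value: A, period: B)

  // Prefix notation for the type
  val prefix: Per[Int, String] = Per(45000, "year")
  // Infix notation: exactly the same type, so the assignment compiles
  val infix: Int Per String = prefix
}
```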

In order to be able to use our periodic amount as a regular Amount we need to provide implicit instances of Amount for the Per class.

implicit def perIsAmount[A: Amount, B: Period]: Amount[A Per B] = new Amount[A Per B] {
  val amount = implicitly[Amount[A]]
  val period = implicitly[Period[B]]
  // delegate to the underlying Amount[A] and keep the period unchanged
  def plus(a: A Per B, b: A Per B): A Per B =
    Per(amount.plus(a.value, b.value), a.period)
  def times(a: A Per B, b: BigDecimal): A Per B =
    Per(amount.times(a.value, b), a.period)
}
implicit def perAmountOps[A, B](a: A Per B)(implicit amount: Amount[A Per B]) =
  new amount.Ops(a)

And that’s it: we can now use periodic amounts just like we use regular amounts.

val income1 = 2000.EUR per Month
val income2 = 1000.EUR per Month
val totalIncome = income1 + income2
val income3 = 1000.GBP per Month
val income4 = 12000.EUR per Year
income1 + income3 // doesn't compile: Can't mix EUR and GBP
income1 + income4 // doesn't compile: Can't mix Month and Year

One thing still missing is the ability to convert a monthly amount to a yearly amount and vice versa.

Let’s define a convertTo method on the Period type class

sealed trait Period[B] {
  def changePeriod[A: Amount, C: Period](a: A Per B, c: C): A Per C =
    if (c == a.period) Per(a.value, c)                                                    // same period: nothing to do
    else if (c == Month) Per(implicitly[Amount[A]].times(a.value, BigDecimal(1) / 12), c) // yearly to monthly
    else Per(implicitly[Amount[A]].times(a.value, 12), c)                                 // monthly to yearly

  class Ops[A: Amount](a: A Per B) {
    def convertTo[C: Period](c: C): A Per C = changePeriod(a, c)
  }
}
implicit def perPeriodOps[A: Amount, B: Period](a: A Per B) = {
   val period = implicitly[Period[B]]
   new period.Ops(a)
}

We can now easily turn monthly amounts into yearly amounts

val monthlyAmount = 1000.EUR per Month
val yearlyAmount = monthlyAmount convertTo Year
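The arithmetic behind the conversion can also be expressed as a ratio of months. The sketch below is a variant, not the implementation used in this post: the months member on Period and the convert function are assumptions, and it works on a bare BigDecimal instead of going through the Amount type class. Multiplying before dividing avoids going through the rounded value of 1/12.

```scala
object ConvertSketch {
  sealed trait Month
  sealed trait Year
  val Month: Month = new Month {}
  val Year: Year = new Year {}

  // Assumption: each period instance knows how many months it spans.
  trait Period[A] { def months: Int }
  implicit object MonthIsPeriod extends Period[Month] { val months = 1 }
  implicit object YearIsPeriod extends Period[Year] { val months = 12 }

  final case class Per[A, B](value: A, period: B)

  // Hypothetical conversion: scale the value by (target months / source months).
  def convert[B, C](a: Per[BigDecimal, B], c: C)(
      implicit from: Period[B], to: Period[C]): Per[BigDecimal, C] =
    Per(a.value * BigDecimal(to.months) / BigDecimal(from.months), c)

  val monthly: Per[BigDecimal, Month] = Per(BigDecimal(1000), Month)
  // 1000 * 12 / 1 = 12000
  val yearly: Per[BigDecimal, Year] = convert(monthly, Year)
}
```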

Summing over a sequence of amounts

If you’re familiar with the Scala collections you probably know that there is a sum method available on these collections. The signature of sum looks like this:

def sum[B >: A](implicit num: Numeric[B]): B

Yes, you’ve spotted it right there! Numeric is a type class. If our amounts were numeric we could sum them up, right out of the box! However I chose not to do so because Numeric defines a method “times” which doesn’t make sense in the case of amounts (as we discussed previously).

On the other hand I’d like to have a sum functionality for collections of amounts. So let’s add this:

trait Amount[A] {
  def zero: A // new method which creates a zero amount
  def plus(a: A, b: A): A
  def times(a: A, b: BigDecimal): A
  
  class Ops(a: A) {
    def +(b: A) = plus(a, b)
    def -(b: A) = plus(a, times(b, -1))
    def *[B](b: B)(implicit constant: Constant[B]): A = times(a, constant.underlying(b))
    def /[B](b: B)(implicit constant: Constant[B]): A = times(a, BigDecimal(1) / constant.underlying(b))
  }

  // new class which adds the total method on sequences
  class SeqOps(seq: Seq[A]) {
    // sum might be a more appropriate name, however it is already defined on TraversableOnce (which Seq inherits from)
    // and that existing sum would be picked without giving the compiler a chance to resolve this one
    def total: A = seq.foldLeft(zero)(plus)
  }
}
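The role of zero becomes clear when the fold is written as a standalone function: it is both the starting value of the fold and the result for an empty sequence. A minimal sketch, where the free-standing total function and the wrapper object TotalSketch are hypothetical and Amount is trimmed to the two methods the fold needs:

```scala
object TotalSketch {
  trait Amount[A] {
    def zero: A
    def plus(a: A, b: A): A
  }

  final case class EUR(value: BigDecimal)
  implicit object EURIsAmount extends Amount[EUR] {
    def zero: EUR = EUR(BigDecimal(0))
    def plus(a: EUR, b: EUR): EUR = EUR(a.value + b.value)
  }

  // Start from zero and add every element in turn.
  def total[A](seq: Seq[A])(implicit amount: Amount[A]): A =
    seq.foldLeft(amount.zero)(amount.plus)

  val threeEuros: EUR = total(List(EUR(BigDecimal(1)), EUR(BigDecimal(2))))
  // With no elements the fold simply returns zero.
  val nothing: EUR = total(List.empty[EUR])
}
```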

As we’ve added a new method “zero” to the Amount trait we need to implement it in all the typeclass instances.

It’s quite straightforward for EUR and GBP but slightly more tricky for periodic amounts as we need a way to create a period out of nothing.

We add an “instance” method in the Period typeclass which will just return a Month or Year depending on the typeclass instance that is used.

trait Period[A] {
  def instance: A
}
implicit object MonthIsPeriod extends Period[Month] {
  def instance: Month = Month
}
implicit object YearIsPeriod extends Period[Year] {
  def instance: Year = Year
}
implicit def perIsAmount[A: Amount, B: Period]: Amount[A Per B] = new Amount[A Per B] {
  val amount = implicitly[Amount[A]]
  val period = implicitly[Period[B]]
  def zero: A Per B = Per(amount.zero, period.instance)
  // ...
}
implicit object EURIsAmount extends Amount[EUR] {
  def zero: EUR = EUR(0)
  // ...
}
implicit object GBPIsAmount extends Amount[GBP] {
  def zero: GBP = GBP(0)
  // ...
}

Then, we need to define the implicit SeqOps instances:

implicit def seqOps[A](seq: Seq[A])(implicit amount: Amount[A]) = new amount.SeqOps(seq)

And here we go – another new functionality available in our DSL!

val totalIncome = List(1000.GBP per Month, 500.GBP per Month, 300.GBP per Month).total
val mixedIncome = List(1000.EUR per Month, 10000.EUR per Year).total // doesn't compile

It would be nice if we could sum up periodic amounts over different periods. We already have all the pieces we need:
we can sum up a list of amounts and we can convert a periodic amount to another period.

implicit class PerMonthOrYearSequence[A: Amount](seq: Seq[A Per _]) {
  def unifyPer[C: Period](c: C): Seq[A Per C] =
    seq.map { a =>
      // the period type is unknown here, so we check the value and cast accordingly
      if (a.period == Month) a.asInstanceOf[A Per Month] convertTo c
      else a.asInstanceOf[A Per Year] convertTo c
    }
}

I agree, it doesn’t look very good but I haven’t yet figured out a proper way to handle the period type, as it looks something like

Per[A, _ >: Year with Month <: Object]

Any suggestions welcome here!

Still it works and we can now solve our mixed income problem:

val mixedIncome = List(1000.EUR per Month, 10000.EUR per Year).unifyPer(Year).total

Conclusion

Scala really shines for DSL definition. Our tiny finance DSL reads like plain English. This means programs written with this DSL can be easily understood by non-technical people. Moreover we gain additional type safety as we cannot mix different units together (EUR with GBP or monthly with yearly amounts). And we were able to define this DSL with just plain Scala, without importing any library whatsoever. We only used the syntactic sugar offered by the language!

I started this blog post to experiment with type classes. It turns out that type classes are a very powerful tool.
They are much more flexible than class inheritance and allow us to decouple data representation from behaviour. This is great as data tends to be more stable than its associated behaviours.

There is a bit of boilerplate, though, when defining the type class instances. The EUR and GBP companion objects look very similar indeed. Fortunately there is a way to reduce such boilerplate by deriving type class instances automatically with shapeless.

Sometimes the implicit resolution can be tricky – especially when the implicit is shadowed by another method with a similar name (e.g. sum).
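The shadowing effect can be reproduced in a few lines. In this sketch (all names are hypothetical) an implicit class tries to add a size method to Seq[Int], but Seq already has a member called size, so the compiler never looks at the extension; only the method with a fresh name, total, is reached through the implicit class.

```scala
object ShadowingSketch {
  implicit class SeqSyntax(seq: Seq[Int]) {
    // Never selected: Seq already has a member named size.
    def size: String = "extension size"
    // Selected: Seq has no member named total.
    def total: Int = seq.foldLeft(0)(_ + _)
  }

  // Resolves to the built-in member, so this is an Int.
  val builtIn: Int = List(1, 2, 3).size
  // Resolves through the implicit class.
  val viaExtension: Int = List(1, 2, 3).total
}
```

This is the same reason the DSL exposes total instead of sum on sequences of amounts.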

Finally you can have a look at the source code used in this post directly on GitHub.