Scala


I recently began writing, as an exercise, some unit-of-measure code in Scala. Some months ago I saw a headline in my newsreader about a Scala library for handling units of measure, and I made a point NOT to read it. It sounded like an interesting problem, and I wanted to take a stab at it myself first, then compare my solution to the one from the article and maybe even write my own article about it.

In the course of trying out a couple of designs I encountered a situation I wasn’t sure how to handle. I wanted an abstract class called Dimension which would encapsulate a measurement in some unit and I wanted to create several subtypes extending Dimension such as Length, Time, Mass, Temperature, etc. All Dimensions should be able to sum together, but only with their own kind. For example, 1 meter + 2 feet should give 1.6096 meters and 1 kilogram + 500 grams should give 1.5 kilograms. However, it makes no sense to add 30 seconds and 45 degrees Celsius. I wanted to arrange the types in such a way that a user of the library would not have the option of adding dimensions of two different types.

Here’s my initial code:

  abstract class Dimension(val value: Double, val name: String, val coef: Double) {
    def +(x: Dimension): Dimension
    override def toString(): String = value + " " + name
  }

  class Time(value: Double, name: String, coef: Double) extends
  Dimension(value, name, coef) {
    def +(x: Dimension): Time = new Time(value + coef * x.value / x.coef, name, coef)
  }

  class Length(value: Double, name: String, coef: Double) extends
  Dimension(value, name, coef) {
    def +(x: Dimension): Length = new Length(value + coef * x.value / x.coef, name, coef)
  }

You can see the Dimension class declares an addition operator to satisfy our requirement that all Dimensions must be additive. No surprises so far.
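The conversion hidden inside those + methods is easy to miss. Here is a minimal sketch of the arithmetic, assuming coef means “how many of this unit make up one base unit” (which is how the unit objects later in this post use it). The names CoefSketch, addConverted and feetPerMeter are made up for illustration and are not part of the library.

```scala
object CoefSketch {
  // Hypothetical helper mirroring the expression used in the + methods:
  // convert the other measurement into this unit, then add.
  def addConverted(value: Double, coef: Double,
                   otherValue: Double, otherCoef: Double): Double =
    value + coef * otherValue / otherCoef

  // Assumption: meters are the base unit, so a foot has coef 1 / 0.3048.
  val feetPerMeter: Double = 1.0 / 0.3048

  // 1 meter (coef 1.0) + 2 feet (coef feetPerMeter), expressed in meters.
  val sumInMeters: Double = addConverted(1.0, 1.0, 2.0, feetPerMeter)
}
```

With these numbers sumInMeters comes out at roughly 1.6096, the figure quoted at the top of the post.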

The Time and Length classes extend Dimension and are concrete, so they must implement the addition operator. The operators have the same signature as the one from Dimension except for the return type. When creating a subtype, we are allowed to narrow the return types, so I made them more specific. Time.+ returns not merely a Dimension as in the supertype but a Time. The parameter types, on the other hand, cannot be narrowed: a method that only takes a Time would not accept every Dimension, so it would not implement the abstract method at all, and the parameters remain Dimension. This is because the return type is a covariant position and the parameter type is a contravariant position. (In Scala the parameter types of an overriding method must in fact match exactly; a different parameter type declares an overload rather than an override.) If you don’t know what variance is, I have an article on it.

This code has two main weaknesses. First, a developer subclassing Dimension is trusted to return the same type as the class. That is to say, a developer could write:

  class Length(value: Double, name: String, coef: Double) extends 
  Dimension(value, name, coef) {
    def +(x: Dimension): Time = new Time(value + coef * x.value / x.coef, name, coef)
  }

The developer could mix the types! This would return an unexpected and nonsensical result. The second and much more serious weakness is that a user of the library doesn’t have to pass a Length to Length.+. The user could write:

  val sum = new Length(10.0, "meters", 1.0) + new Time(10.0, "seconds", 1.0)

Nothing is stopping him from doing this.

What I wanted was a Dimension class that enforced some additional rules on its descendants. Dimension should dictate not only that all its subtypes must implement the addition operator but also that the operator should only accept a parameter of the same type as the class in which it is defined and that the operator should also return that same type. In short, I wanted to force all subtypes of Dimension to look like this:

  class X(...) extends Dimension(...) {
    def +(x: X): X = ...
  }

Learning to Accept Yourself (As a Parameter)

As I considered the problem I pretty quickly noticed a similarity to a basic Scala trait: scala.Ordered. Ordered[A] is used primarily in sorting. By mixing in Ordered, you can define an ordering for any class. The reason I thought of Ordered is that it includes the abstract method “compare (that : A) : Int”. This method compares “this” to “that”. In other words, it takes a parameter that it is able to compare with itself, which is *usually* a parameter of the same type. Here’s a typical use of Ordered[A]:

class Student(val lastName: String, val firstName: String) 
  extends Ordered[Student]
{
  def compare(that: Student) = 
    (lastName + "," + firstName).compare(that.lastName + "," + that.firstName)
}

That’s close to what I want, but not quite. Ordered[A] allows you to implement a class that can be compared to itself, but it doesn’t require it. You could implement “class Apples extends Ordered[Oranges]”, literally comparing apples and oranges. So this arrangement (a parameterized class or trait in which the subtypes specify themselves as the type parameter) allows, but does not enforce, the structure that I want. So Ordered[A] provides a clue, but not the complete solution.
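To make the loophole concrete, here is a small sketch of the apples-and-oranges case. The classes and their weight field are invented for illustration; the only library piece used is scala.math.Ordered itself.

```scala
object ApplesAndOranges {
  // Hypothetical classes: nothing ties Apples to its own type parameter.
  class Oranges(val weight: Int)

  class Apples(val weight: Int) extends Ordered[Oranges] {
    // Perfectly legal: an Apples is compared against an Oranges.
    def compare(that: Oranges): Int = weight.compare(that.weight)
  }

  // Ordered supplies <, >, <= and >= on top of compare.
  val appleIsLighter: Boolean = new Apples(3) < new Oranges(5)
}
```

The compiler accepts this without complaint, which is exactly why Ordered[A] alone is not enough for Dimension.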

Becoming Self Aware

The missing piece is a little-known Scala construct called the explicit self type. It is a way of specifying what type the “this” reference must have. The Scala-lang website has an article explaining another situation in which explicit self types are useful: specifying that within a class “this” should refer to an abstract type.

Here’s a very simple example of how explicit self types work. Here’s a base trait called TraitA and two traits that make use of TraitA. TraitB1 uses an explicit self type to denote that “this” must be of type TraitA, and TraitB2 uses extends to inherit from TraitA.

trait TraitA {
  def t1(): String
}

trait TraitB1 {
  self: TraitA =>
  def t2(): String = "TraitB1.t2 !" + t1() + "!"
}

trait TraitB2 extends TraitA {
  def t2(): String = "TraitB2.t2!" + t1() + "!"
}

TraitB1 and TraitB2 are exactly alike except for the way they gain access to the t1() method. Here’s an interpreter session in which we create some classes that extend these traits.

scala> class Class1 extends TraitB1 {
     |   def t1() = "Class1.t1"
     |   override def toString() = "Class1: " + t1() + " " + t2()
     | }
<console>:6: error: illegal inheritance;
 self-type Class1 does not conform to TraitB1's selftype TraitB1 with TraitA
       class Class1 extends TraitB1 {
                            ^

scala> class Class2 extends TraitB1 with TraitA {
     |   def t1() = "Class2.t1"
     |   override def toString() = "Class2: " + t1() + " " + t2()
     | }
defined class Class2

scala> new Class2
res0: Class2 = Class2: Class2.t1 TraitB1.t2 !Class2.t1!

scala> class Class3 extends TraitB2 {
     |   def t1() = "Class3.t1"
     |   override def toString() = "Class3: " + t1() + " " + t2()
     | }
defined class Class3

scala> new Class3
res1: Class3 = Class3: Class3.t1 TraitB2.t2!Class3.t1!

Class1 extends TraitB1, the trait that used the explicit self type. The class defines t1(), so all the necessary implementation is there, but compilation fails anyway. The explicit self type says that “this” must be of type TraitA but neither Class1 nor TraitB1 extends TraitA, so even though the t1() method is supplied, “this” cannot have type TraitA for Class1 because Class1 does not inherit from TraitA.

Class2 and Class3 compile and run just fine. Class2 is identical to Class1 except that it mixes in TraitA. Since it is declared “with TraitA” the explicit self type is satisfied because “this” can have type TraitA.

Class3 is identical to the others except that it extends TraitB2. TraitB2 is declared as extending TraitA, so Class3 compiles because it indirectly inherits from TraitA.

Use explicit self types with care! They can be a little dangerous if you use them to subvert Scala’s compile-time type checking. For example:

trait StringMaker {
  def makeString(): String
}

class DoesntCompile extends StringMaker {
  override def toString() = "DoesntCompile " + makeString
}

class Compiles {
  self: StringMaker =>
  override def toString() = "Compiles " + makeString
}

The first class is declared as extending StringMaker, but it doesn’t implement the makeString method. This class fails to compile, and rightly so. The compiler does its job and warns you that the class won’t work.

The second class includes an explicit self type. The class named Compiles says that it must be a StringMaker. Now, it says this internally, not externally. The class declaration says nothing about StringMaker. Any code that used the Compiles class wouldn’t know that it’s supposed to be a StringMaker. The class compiles but when you try to instantiate it you get an exception. Not only is it an exception, it’s a NullPointerException that crashes my interpreter (!!!), which makes me think this may be a bug.

Time to Self Actualize

My solution to the problem was a combination of the “class X extends Ordered[X]” idiom and the explicit self type. Here it is:

  abstract class Dimension[T](val value: Double, val name: String, val coef: Double) {
    self: T =>
    protected def create(value: Double, name: String, coef: Double): T
    def +(x: Dimension[T]): T = create(value + coef * x.value / x.coef, name, coef)
    override def toString(): String = value + " " + name
  }

  class Time(value: Double, name: String, coef: Double) extends
        Dimension[Time](value, name, coef) {
    protected def create(a: Double, b: String, c: Double) = new Time(a, b, c)
  }

  class Length(value: Double, name: String, coef: Double) extends
        Dimension[Length](value, name, coef) {
    protected def create(a: Double, b: String, c: Double) = new Length(a, b, c)
  }

  class Mass(value: Double, name: String, coef: Double) extends
        Dimension[Length](value, name, coef) {
    protected def create(a: Double, b: String, c: Double) = new Length(a, b, c)
  }

This compiles just fine except for the last class, which I included to demonstrate that the compiler enforces conformance to the explicit self type. Every class that extends Dimension[X] must itself be an X. That enforces the rule I wanted. Here’s the compiler error for the last class.

$ scalac Units.scala
Units.scala:19: error: illegal inheritance;
 self-type Mass does not conform to Dimension[Length]'s selftype Dimension[Length] with Length
        Dimension[Length](value, name, coef) {
        ^
one error found

That basically says if you want to extend Dimension[Length] then you better be a Length yourself.

Considering that when I had investigated this problem for about 10 minutes I was nearly ready to call it impossible, this is a surprisingly simple and not too cryptic solution. Plus, it’s a usage of explicit self types that I hadn’t seen before. I wonder, in fact, why the Ordered[A] trait itself doesn’t use this trick.
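For what it’s worth, here is a rough sketch of what an Ordered-like trait using the same trick might look like. SelfOrdered is a hypothetical trait invented for this example, not anything in the Scala library, and it only defines one comparison operator to keep things short.

```scala
object SelfOrderedSketch {
  // Hypothetical trait: the self type forces every implementor of
  // SelfOrdered[A] to be an A itself.
  trait SelfOrdered[A] { self: A =>
    def compare(that: A): Int
    def <(that: A): Boolean = compare(that) < 0
  }

  class Student(val lastName: String) extends SelfOrdered[Student] {
    def compare(that: Student): Int = lastName.compare(that.lastName)
  }

  // class Apples extends SelfOrdered[Student] would be rejected with an
  // "illegal inheritance" error, because Apples is not a Student.
}
```

A Student can still be compared with another Student as usual, for example new Student("Adams") < new Student("Baker").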

As a bonus, here’s a sneak peek at part of my Units library so far.

  abstract class Dimension[T](val value: Double, val name: String, val coef: Double) {
    self: T =>
    protected def create(value: Double, name: String, coef: Double): T
    def +(x: Dimension[T]): T = create(value + coef * x.value / x.coef, name, coef)
    def -(x: Dimension[T]): T = create(value - coef * x.value / x.coef, name, coef)
    override def toString(): String = value + " " + name
  }

  class Time(value: Double, name: String, coef: Double) extends
        Dimension[Time](value, name, coef) {
    protected def create(a: Double, b: String, c: Double) = new Time(a, b, c)
  }

  class Length(value: Double, name: String, coef: Double) extends
        Dimension[Length](value, name, coef) {
    protected def create(a: Double, b: String, c: Double) = new Length(a, b, c)
  }

  abstract class TimeUnit(name: String, coef: Double) {
    def apply(value: Double) = new Time(value, name, coef)
    def apply(orig: Time) = new Time(0, name, coef) + orig
  }

  object Second   extends TimeUnit("seconds",    1.0)
  object Minute   extends TimeUnit("minutes",    1.0 / 60)
  object Hour     extends TimeUnit("hours",      1.0 / 3600)

  abstract class LengthUnit(name: String, coef: Double) {
    def apply(value: Double) = new Length(value, name, coef)
    def apply(orig: Length) = new Length(0, name, coef) + orig
  }

  object Meter      extends LengthUnit("meters",      1.0)
  object Inch       extends LengthUnit("inches",      1.0 / .0254)
  object Foot       extends LengthUnit("feet",        1.0 / .0254 / 12)

And here’s what it looks like in the interpreter:

scala> val length1 = Meter(3)
length1: Length = 3.0 meters

scala> val length2 = Foot(4.5)
length2: Length = 4.5 feet

scala> length1 + length2
res0: Length = 4.3716 meters

scala> length2 + length1
res1: Length = 14.34251968503937 feet

scala> Inch(length1 + length2)
res2: Length = 172.11023622047244 inches

scala> val time1 = Second(90)
time1: Time = 90.0 seconds

scala> val time2 = Hour(.75)
time2: Time = 0.75 hours

scala> Minute(time2)
res3: Time = 45.0 minutes

scala> time1 + length1
<console>:14: error: type mismatch;
 found   : Length
 required: Dimension[Time]
       time1 + length1
               ^

Copyright © 2009 Matthew Jason Malone

Here’s a three-for-one special for you: A post about implementing the Levenshtein string distance algorithm in Scala AND refactoring it from an imperative style to a functional style AND I even throw in a short lesson on memoization. To make sure that our refactoring is correct and preserves the expected behavior, I’ll unit test the code along the way using ScalaTest. ScalaCheck, JUnit, or TestNG would work just as well, but I used ScalaTest.

First Things First

“What, exactly,” some of you may be asking, “is Levenshtein string distance? Some kind of Teutonic tailoring terminology?” Not at all. It’s a way of measuring how alike or different two strings of symbols are. For example, the string “sturgeon” is a lot more similar to “surgeon” than to “urgently”. “Sturgeon” and “surgeon” are only a single letter (t) apart. They have a string distance of 1. “Sturgeon” and “urgently” share some letters, but each has some letters not present in the other. So what’s their string distance? It’s not so obvious now.

String distance is useful. The use that most quickly springs to mind is spelling correction. If I type “computwr” a string distance algorithm could tell us that “computwr” is very close to the dictionary word “computer”. But there’s more to it than that. There are a lot of fuzzy problems in which you want to find which two sets of complex data are most similar. One way to solve a problem like that is to encode it as a string comparison using string distance. For example, you could write a program to find pen strokes in an image and encode their shapes (up, curve left, down, angle right, etc) as a string which could be compared to known encodings for handwritten letters. By finding the closest matches you can create a handwriting recognition program. String distance is also useful in DNA analysis, of course, in recognizing patterns in signals, and in a host of other situations.

The Levenshtein string distance algorithm was developed by Vladimir Levenshtein in 1965. It can easily tell us the distance from “sturgeon” to “urgently”. This algorithm breaks down string transformation into three basic operations: adding a character, deleting a character, and replacing a character. Each of these operations is assigned a cost, usually a cost of 1 for each operation. Leaving a character unchanged has a cost of 0. So to go from “surgeon” to “sturgeon” you leave the “s” unchanged for a cost of 0. Then you add a “t” for a cost of 1. All the other letters also remain unchanged, so the total cost is 1, just as we expected.

To change “sturgeon” to “urgently” is harder. They have the same number of letters, so we could just do a replacement on each one for a distance of 8. But is that the shortest distance? What if we try to re-use that “urge” from “sturgeon”? Can we re-use the “n”? Does that help? What about the “t”? We need an algorithm that we can follow.

The Grid

What we need is a way to find the cheapest combination of operations which changes the first string to the second. That’s the Levenshtein algorithm. It works like this. Write the first string vertically from top to bottom. To the right of each letter write 1, 2, 3, etc. Write a 0 above the number 1. Then write the second string horizontally and again add the numbers 1, 2, 3 etc. to the right of the 0. It will look something like this:

    u r g e n t l y
  0 1 2 3 4 5 6 7 8
s 1
t 2
u 3
r 4
g 5
e 6
o 7
n 8

That’s the first step. The grid remains un-filled-in at this point. Can you guess what the grid positions are going to hold? They are going to hold the cost to convert the various prefixes of “sturgeon” to the various prefixes of “urgently”. The position at the intersection of row r4 and column r2, for example, will contain the cost to convert “stur” to “ur”. We fill in these positions (left to right, then top to bottom) with the smallest of the following three numbers:

  • The number above the current position plus one.
  • The number to the left of the current position plus one.
  • The above-left number if row letter and column letter are the same, or the above-left number plus one otherwise.

When we finish, the bottom right corner contains the cost to convert the first string to the second.
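The three rules boil down to one small function. This is only a sketch of the rule for a single grid position; the name cell and its parameter names are invented here for illustration.

```scala
object CellRule {
  // Value of one grid position, given its three already-filled neighbours
  // and the letters labelling its row and column.
  def cell(above: Int, left: Int, aboveLeft: Int,
           rowChar: Char, colChar: Char): Int = {
    val diagonal = aboveLeft + (if (rowChar == colChar) 0 else 1)
    math.min(math.min(above + 1, left + 1), diagonal)
  }
}
```

For the first open position of the grid above (row s1, column u1) the neighbours are 1 above, 1 to the left and 0 above-left, so cell(1, 1, 0, 's', 'u') gives 1.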

But Why? (Understanding the Algorithm)

Those are some shockingly simple rules! Let’s examine how they translate into string distance.

First, what are those numbers 0 to n that we write along the top and left? They’re not just indices. The numbers along the top represent the cost to convert the empty string to the various prefixes of “urgently”. The cost is 0 to convert “” to “”, 1 to convert “” to “u”, 2 to convert “” to “ur”, and so on. The numbers on the left are the cost to convert the prefixes of “sturgeon” to the empty string. “” to “” is 0, “s” to “” is 1, and so forth. These costs are obvious. The only way to convert “” to “urgently” is to add eight letters. There’s nothing to delete, nothing to replace. The only way to convert “sturgeon” to “” is to delete all eight letters.

Each position, as we have established, represents the cost to convert the string of characters down the left side ending in the current row’s character into the string of characters along the top ending at the current column’s character. Put another way, for any given position let’s call the current row’s letter A, and the current column’s letter B. If we use a colon (:) to indicate string concatenation then the beginning string, the one along the left of the grid, can be written Prefix1:A, and the target string, the one along the top, can be written Prefix2:B. So for our example word “sturgeon” if we look at a position in row e6 then Prefix1 is “sturg” and the final letter, which we’re calling A, is “e”.

So, speaking in terms of Prefix1, Prefix2, A, and B, we use the following inputs:

  • The cost to change Prefix1 to Prefix2:B (the number above the current position).
  • The cost to change Prefix1:A to Prefix2 (the number to the left).
  • The cost to change Prefix1 to Prefix2 (the above-left number).

If we know the costs of these three conversions we can find the cost to change Prefix1:A to Prefix2:B using this logic:

  • We know that if converting Prefix1 to Prefix2:B has a cost of X, then Prefix1:A can be converted to Prefix2:B for the cost of X plus the cost of deleting A, or X + 1.
  • We know that if converting Prefix1:A to Prefix2 has a cost of Y, then Prefix1:A can be converted to Prefix2:B for the cost of Y plus the cost of adding B, or Y + 1.
  • We know that if converting Prefix1 to Prefix2 has a cost of Z, then Prefix1:A can be converted to Prefix2:B for the cost of Z plus the cost of replacing A with B, or Z + (0 or 1 depending on whether A = B).

If you understand the above logic, then you understand this really neat algorithm. It’s a quintessential example of dynamic programming. The solution is built up by solving simpler problems. You start with the trivial case of converting to and from the empty string, and then you build up the solution for the prefixes until you have the complete solution.
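The same logic can be written as a recurrence. This is just a restatement in notation chosen for this note: a is the first string, b is the second, D(i, j) is the cost to convert the first i characters of a into the first j characters of b, and [a_i ≠ b_j] is 1 when the two letters differ and 0 when they match.

```latex
\begin{aligned}
D(i, 0) &= i, \qquad D(0, j) = j, \\
D(i, j) &= \min \begin{cases}
  D(i-1, j) + 1 & \text{(delete } a_i\text{)} \\
  D(i, j-1) + 1 & \text{(add } b_j\text{)} \\
  D(i-1, j-1) + [a_i \neq b_j] & \text{(replace } a_i \text{ with } b_j\text{)}
\end{cases}
\end{aligned}
```

The first line is the numbering along the left and top edges of the grid, and the three cases correspond to the above, left, and above-left neighbors, in that order.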

In the grid’s initial configuration there is only one location for which we know all three costs and that is row s1, column u1. That’s the only space for which all three neighboring values (above, left, and above-left) are populated. After we fill in this one, there are two more spaces available to us. Those are (t2, u1) and (s1, r2). Ordinarily, though, the spaces are populated line by line.

A Simple Example

Let’s do a quick example by hand. Then we’ll take a stab at implementing it in code. What’s the string distance from “hat” to “tape”? First, our empty grid:

    t a p e
  0 1 2 3 4
h 1
a 2
t 3

The first space is row h, column t. The letters are different so our choices are 1 + 1 (above), 1 + 1 (left), or 0 + 1 (above-left). 0 + 1 is the smallest value, so the first space gets populated with 1. The next space has choices 2 + 1 (above), 1 + 1 (left), or 1 + 1 (above-left). 1 + 1 is the lowest, so we fill in the second space with 2. Once we finish the row, we have this:

    t a p e
  0 1 2 3 4
h 1 1 2 3 4
a 2
t 3

The next space is row a, column t. Our choices are 1 + 1 (above), 2 + 1 (left), or 1 + 1 (above-left). 1 + 1 is the smallest value, so the space gets populated with 2. The next space has choices 2 + 1 (above), 2 + 1 (left), or 1 + 0 (above-left). Why 1 + 0? Because the above-left value is 1 and both letters for this space are “a” so we can replace “a” with “a” for free. Go ahead and finish the whole grid. This is the result:

    t a p e
  0 1 2 3 4
h 1 1 2 3 4
a 2 2 1 2 3
t 3 2 2 2 3

The strings “hat” and “tape” have a distance of 3.
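If you want to check a hand-filled grid, here is a small sketch that builds the whole grid row by row, top row first. GridSketch and grid are names made up for this example, and it favors brevity over speed.

```scala
object GridSketch {
  // Returns every row of the grid, including the initial 0, 1, 2, ... row.
  def grid(s1: String, s2: String): Vector[Vector[Int]] = {
    val top = (0 to s2.length).toVector
    s1.zipWithIndex.scanLeft(top) { case (prev, (c1, i)) =>
      // Each row starts with its left-hand number (i + 1) and grows rightwards.
      s2.zipWithIndex.scanLeft(i + 1) { case (left, (c2, j)) =>
        List(prev(j + 1) + 1,                      // above
             left + 1,                             // left
             prev(j) + (if (c1 == c2) 0 else 1)    // above-left
        ).min
      }.toVector
    }.toVector
  }
}
```

GridSketch.grid("hat", "tape") produces the rows 0 1 2 3 4, 1 1 2 3 4, 2 2 1 2 3 and 3 2 2 2 3, matching the grid above, with the distance 3 in the bottom right corner.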

Some Code, Finally

As fascinating as Levenshtein distance is, and as much more as there is to say on the topic, the time has come to write some code. Here’s a Scala implementation that is very close to the pencil-and-paper approach that we just performed.

import scala.math.min
import scala.math.max

object StringDistance {
  def stringDistance(s1: String, s2: String): Int = {
    def minimum(i1: Int, i2: Int, i3: Int) = min(min(i1, i2), i3)

    val dist = Array.ofDim[Int](s1.length + 1, s2.length + 1) // the grid, all zeros

    for (idx <- 0 to s1.length) dist(idx)(0) = idx // leftmost column: delete each prefix of s1
    for (jdx <- 0 to s2.length) dist(0)(jdx) = jdx // top row: add each prefix of s2

    for (idx <- 1 to s1.length; jdx <- 1 to s2.length)
      dist(idx)(jdx) = minimum (
        dist(idx-1)(jdx  ) + 1,
        dist(idx  )(jdx-1) + 1,
        dist(idx-1)(jdx-1) + (if (s1(idx-1) == s2(jdx-1)) 0 else 1)
      )

    dist(s1.length)(s2.length) // bottom right corner
  }
}

Do you see what I mean when I say it’s close to the pencil-and-paper approach? We actually construct a two-dimensional Array to represent the grid we drew earlier. It’s a very literal implementation.

To explain the code briefly, we declare a singleton StringDistance having a single method called stringDistance. Within this method we declare a 3-argument minimum method. (I wonder why there’s no “def min(params: Int*): Int” defined in scala.Math.) Then we create an array called “dist” and populate the top row and leftmost column in lines 8-11. The for loop on line 13 cycles through each array position from left to right then top to bottom, and populates them according to the Levenshtein logic. Finally, once the grid is filled in we return the number in the bottom right position.

The job’s not done yet, of course. We haven’t written our unit tests. Here’s the test I wrote:

import org.scalatest._

class StringDistanceSuite extends FunSuite with PrivateMethodTester {

  test("stringDistance should work on empty strings") {
    assert( StringDistance.stringDistance(   "",    "") == 0 )
    assert( StringDistance.stringDistance(  "a",    "") == 1 )
    assert( StringDistance.stringDistance(   "",   "a") == 1 )
    assert( StringDistance.stringDistance("abc",    "") == 3 )
    assert( StringDistance.stringDistance(   "", "abc") == 3 )
  }

  test("stringDistance should work on equal strings") {
    assert( StringDistance.stringDistance(   "",    "") == 0 )
    assert( StringDistance.stringDistance(  "a",   "a") == 0 )
    assert( StringDistance.stringDistance("abc", "abc") == 0 )
  }

  test("stringDistance should work where only inserts are needed") {
    assert( StringDistance.stringDistance(   "",   "a") == 1 )
    assert( StringDistance.stringDistance(  "a",  "ab") == 1 )
    assert( StringDistance.stringDistance(  "b",  "ab") == 1 )
    assert( StringDistance.stringDistance( "ac", "abc") == 1 )
    assert( StringDistance.stringDistance("abcdefg", "xabxcdxxefxgx") == 6 )
  }

  test("stringDistance should work where only deletes are needed") {
    assert( StringDistance.stringDistance(  "a",    "") == 1 )
    assert( StringDistance.stringDistance( "ab",   "a") == 1 )
    assert( StringDistance.stringDistance( "ab",   "b") == 1 )
    assert( StringDistance.stringDistance("abc",  "ac") == 1 )
    assert( StringDistance.stringDistance("xabxcdxxefxgx", "abcdefg") == 6 )
  }

  test("stringDistance should work where only substitutions are needed") {
    assert( StringDistance.stringDistance(  "a",   "b") == 1 )
    assert( StringDistance.stringDistance( "ab",  "ac") == 1 )
    assert( StringDistance.stringDistance( "ac",  "bc") == 1 )
    assert( StringDistance.stringDistance("abc", "axc") == 1 )
    assert( StringDistance.stringDistance("xabxcdxxefxgx", "1ab2cd34ef5g6") == 6 )
  }

  test("stringDistance should work where many operations are needed") {
    assert( StringDistance.stringDistance("example", "samples") == 3 )
    assert( StringDistance.stringDistance("sturgeon", "urgently") == 6 )
    assert( StringDistance.stringDistance("levenshtein", "frankenstein") == 6 )
    assert( StringDistance.stringDistance("distance", "difference") == 5 )
    assert( StringDistance.stringDistance("java was neat", "scala is great") == 7 )
  }

}

It tests several special cases as well as the general case. All we have to do is compile our StringDistance object and the StringDistanceSuite unit test, fire up the scala interpreter and run the test:

scala> (new StringDistanceSuite).execute()
Test Starting - StringDistanceSuite: stringDistance should work on empty strings
Test Succeeded - StringDistanceSuite: stringDistance should work on empty strings
Test Starting - StringDistanceSuite: stringDistance should work on equal strings
Test Succeeded - StringDistanceSuite: stringDistance should work on equal strings
Test Starting - StringDistanceSuite: stringDistance should work where only inserts are needed
Test Succeeded - StringDistanceSuite: stringDistance should work where only inserts are needed
Test Starting - StringDistanceSuite: stringDistance should work where only deletes are needed
Test Succeeded - StringDistanceSuite: stringDistance should work where only deletes are needed
Test Starting - StringDistanceSuite: stringDistance should work where only substitutions are needed
Test Succeeded - StringDistanceSuite: stringDistance should work where only substitutions are needed
Test Starting - StringDistanceSuite: stringDistance should work where many operations are needed
Test Succeeded - StringDistanceSuite: stringDistance should work where many operations are needed

scala>

Refactoring: Reduced Memory Usage

The code passes all the tests, so let’s take things a step further. One shortcoming of our implementation is that it can require a lot of memory. That array has to be of size (n+1)*(m+1) where n and m are the lengths of the two strings we’re comparing. If you want to apply the method to strings that are a few hundred characters long (or longer) then you’re starting to talk about some serious memory requirements. How can we reduce the amount of memory required? Can you think of a way?

Once we complete one row of the grid, we use it again as an input when we compute the next row. But after that we’re done with it. Why leave it to clutter the heap? Let’s tweak our implementation slightly. Try rewriting the method using only two rows. Fill the first row with the initial 0, 1, 2, etc. Then use one to compute the other over and over. Think about how you would implement this, then have a look at my solution below.

  def stringDistance(s1: String, s2: String): Int = {
    def minimum(i1: Int, i2: Int, i3: Int) = min(min(i1, i2), i3)

    var dist = ( new Array[Int](s1.length + 1),
                 new Array[Int](s1.length + 1) )

    for (idx <- 0 to s1.length) dist._2(idx) = idx

    for (jdx <- 1 to s2.length) {
      val (newDist, oldDist) = dist
      newDist(0) = jdx
      for (idx <- 1 to s1.length) {
        newDist(idx) = minimum (
          oldDist(idx) + 1,
          newDist(idx-1) + 1,
          oldDist(idx-1) + (if (s1(idx-1) == s2(jdx-1)) 0 else 1)
        )
      }
      dist = dist.swap
    }

    dist._2(s1.length)
  }

This one uses a Pair (also called a Tuple2) containing two one-dimensional arrays, instead of a 2*(n+1) array. Pair happens to have the very handy “swap” method which we can use to trade out the rows when we’ve finished one and are ready to compute the next.

This is where our unit tests really show their worth. No need to wonder whether this new code really does work. We just recompile, run the tests again, and we can see that the code still gives us the expected results.

Refactoring: Recursion, Kind Of

What are some other ways we could write this code? Can it be improved? I thought I would try to replace the iteration in the previous implementations with recursion. Here’s what that code looks like:

  def stringDistance(s1: String, s2: String): Int = {
    def newCost(ins: Int, del: Int, subst: Int, c1: Char, c2:Char) =
      Math.min( Math.min( ins+1, del+1 ), subst + (if (c1 == c2) 0 else 1) )

    def getNewCosts(s1: List[Char], c2: Char, delVal: Int, prev: List[Int] ): List[Int] = (s1, prev) match {
      case (c1 :: _ , substVal :: insVal :: _) =>
        delVal :: getNewCosts(s1.tail, c2, newCost(insVal, delVal, substVal, c1, c2), prev.tail)
      case _ => List(delVal)
    }

    def sd(s1: List[Char], s2: List[Char], prev: List[Int]): Int = s2 match {
      case Nil => prev.last
      case _ => sd( s1, s2.tail, getNewCosts(s1, s2.head, prev.head+1, prev) )
    }

    (s1, s2) match {
      case (`s2`, `s1`) => 0
      case (_, "") | ("", _) => max(s1.length, s2.length)
      case _ => sd(s1.toList, s2.toList, (0 to s1.length).toList)
    }
  }

It’s a pretty naïve implementation, actually. It just replaces the repetition of the two for loops with the repetition of the recursion of the two methods sd and getNewCosts. The sd method is even tail-recursive, allowing Scala to optimize it. It does the same basic thing as the for loop version, though. It recurses through the characters of a row in the getNewCosts method, and it recurses through the rows of the grid in the sd method.

It looks more complicated than the previous implementations. It’s harder to read. But it passes our unit tests, so we can be pretty sure it’s correct.

Refactoring: List Methods

After the last implementation, I thought it looked a little sloppy. I wondered whether I could improve the situation by using some of the many useful methods built into Scala’s List class. Here is the comparatively brief code that resulted:

  def stringDistance(s1: String, s2: String): Int = {
    def sd(s1: List[Char], s2: List[Char], costs: List[Int]): Int = s2 match {
      case Nil => costs.last
      case c2 :: tail => sd( s1, tail,
          costs.zip(costs.tail).zip(s1).foldLeft(List(costs.head+1)) {
            case (a, ((rep,ins), chr)) => Math.min( Math.min( ins+1, a.head+1 ), rep + (if (chr==c2) 0 else 1) ) :: a
          }.reverse
        )
    }
    sd(s1.toList, s2.toList, (0 to s1.length).toList)
  }

Like I say, it’s brief. Those List methods give you a lot of mileage.

Refactoring: Real Recursion

The more I looked at my previous attempt at a recursive solution, the more I realized how hare-brained it was. It wasn’t real recursion. It was just iteration using the stack. So I went back to the drawing board. If I want to know the final answer, the value in the bottom right position, how do I get it? I apply my three rules to the above, left, and above-left positions. How do I get those positions? Apply the rules again. That’s real recursion. Here’s my first stab at it:

  def stringDistance(s1: String, s2: String): Int = {
    def min(a:Int, b:Int, c:Int) = Math.min( Math.min( a, b ), c)
    def sd(s1: List[Char], s2: List[Char]): Int = (s1, s2) match {
      case (_, Nil) => s1.length
      case (Nil, _) => s2.length
      case (c1::t1, c2::t2)  => min( sd(t1,s2) + 1, sd(s1,t2) + 1,
                                     sd(t1,t2) + (if (c1==c2) 0 else 1) )
    }
    sd( s1.toList, s2.toList )
  }

Now, THAT is a nice looking recursive function. That’s more like it. See the pattern match block? If we try to convert from any string to the empty string or from the empty to a non-empty string then we just use the string length. You see? That takes care of the positions along the top and left of the grid. All the others are determined in the last case. That last case just applies our three rules. That is so short and simple. It’s a thing of beauty.

The only problem? It doesn’t work.

It’s technically correct. It will return correct answers … eventually. Or, rather, I think it will. I can’t be sure because it’s too slow to complete my unit tests! For anything but very short inputs (4 or 5 characters), the function takes a long time to return. Why? Let’s look at how many recursive calls are made for some inputs.

If we use strings “a” and “b” we pass over the “(_, Nil)” and “(Nil, _)” cases in the first call to function sd, because both our strings (Lists, actually) are non-empty. This results in three more calls to sd. Each of these three calls includes an empty List of characters, so there is no more recursion. That’s a total of four calls to sd for strings “a” and “b”.

What about “ab” and “xy”? Think about it for a moment. Step through the function in your head. How many calls to sd will there be for “ab” and “xy”?

Have you done it? I count 19. What about “abc” and “xyz”? I’ll save you the trouble. It’s 94. For length 4 strings it’s 481. For length 5 it’s 2524. Length 6 is 13,483. Then 73 thousand, then 400 thousand, then 2 million and so forth. Why so many calls? Each position in the grid is computed using all the positions to the left and all the positions above the current position. So a position in the top left will be computed and recomputed many times.
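
If you’d rather not count by hand, here is a small sketch that computes the number of calls instead of making them. The names CallCounter and calls are made up for this illustration. The recurrence just mirrors the three cases of sd: one call for the current invocation, plus the calls made by each of its three recursive branches.

```scala
object CallCounter {
  // Number of times the un-memoized sd is invoked for inputs of lengths m and n
  def calls(m: Int, n: Int): Long =
    if (m == 0 || n == 0) 1L // base cases: no further recursion
    else 1L + calls(m - 1, n) + calls(m, n - 1) + calls(m - 1, n - 1)
}

// CallCounter.calls(1, 1)  // 4
// CallCounter.calls(2, 2)  // 19
// CallCounter.calls(3, 3)  // 94
// CallCounter.calls(6, 6)  // 13483
```

Note that this counter is itself the same kind of un-memoized recursion, so it slows down for larger lengths in just the way sd does.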

There is a way to get around this, of course. You probably already have some ideas. We’re going to do something called memoization. When you memoize a function, you make it remember results that it computed previously without having to actually recompute them. I’ll do that by caching results in a map. The map’s key is a Pair of List[Char]s, the inputs to my inner function sd, and its data is an Int, the return type of sd. I will modify sd to first check the map to see if the result for the current parameters has already been cached. If so, we simply return it. If not, we compute the value, cache it, and return it.

  def stringDistance(s1: String, s2: String): Int = {
    val memo = scala.collection.mutable.Map[(List[Char],List[Char]),Int]()
    def min(a:Int, b:Int, c:Int) = Math.min( Math.min( a, b ), c)
    def sd(s1: List[Char], s2: List[Char]): Int = {
      if (memo.contains((s1,s2)) == false)
        memo((s1,s2)) = (s1, s2) match {
          case (_, Nil) => s1.length
          case (Nil, _) => s2.length
          case (c1::t1, c2::t2)  => min( sd(t1,s2) + 1, sd(s1,t2) + 1,
                                         sd(t1,t2) + (if (c1==c2) 0 else 1) )
        }
      memo((s1,s2))
    }

    sd( s1.toList, s2.toList )
  }

There. That uglies up my function somewhat, but at least it’s usable now. And it passes my unit tests, so I’m reasonably assured that it’s right.

Memoization only works if you expect the same result for each identical function call. If your function takes input from stdin, for example, you can’t memoize that. Or if it has a random component. Or if, for any other reason, its return value is not always the same for the same inputs. You can memoize functions in different ways. There’s a post on Michid’s Weblog about a more general solution, a memoizing class which wraps existing functions to give you a memoized version.
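
To give a rough idea of what a general-purpose wrapper can look like, here is a minimal sketch. It is my own illustration and not the code from Michid’s post; the names MemoSketch and memoize are hypothetical. It assumes a one-argument function whose result depends only on its input.

```scala
import scala.collection.mutable

object MemoSketch {
  // Hypothetical helper: wrap a one-argument function in a cache
  def memoize[A, B](f: A => B): A => B = {
    val cache = mutable.Map.empty[A, B]
    // Compute f(a) only the first time a is seen, then reuse the stored result
    (a: A) => cache.getOrElseUpdate(a, f(a))
  }
}

// val square = MemoSketch.memoize((x: Int) => x * x)
// square(4)  // computed
// square(4)  // served from the cache
```

A wrapper this simple only caches calls made through the wrapped function. Recursive calls that a function makes to itself, like the ones in sd, would bypass the cache, which is why sd checks the map itself.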

PHEW!

Ok, I’ve tried to keep my rambling to a minimum in this post but it’s still a doozy. The things I wanted to get across are:

  • The usefulness of Levenshtein distance in solving a variety of problems.
  • How to understand the Levenshtein distance algorithm and why it works.
  • How to use unit tests to improve your code while maintaining some assurance that the new code still has the correct behavior.
  • Some of the different ways of implementing Levenshtein.

In the end, I think I like the second implementation (the one that switches out the two rows) and the last implementation the best. The second one seems to have good performance. I did some informal performance tests and it has a good mix of performance and simplicity. The last one, the memoized recursive one, appeals to me because it is in a more functional style and still has respectable performance.

Don’t forget to subscribe to my RSS feed, or follow this blog on Twitter.
Copyright © 2009 Matthew Jason Malone

In Scala, as in Java, C and many other languages, identifiers may contain a mix of lower and upper case characters. These identifiers are treated in a case sensitive manner. For example “index”, “Index” and “INDEX” would be treated as three separate identifiers. You can define all three in the same scope. That goes for Scala, Java, and most if not all descendants of the C language. In most of these languages, although case is significant in distinguishing identifiers, and although various capitalization schemes are used by convention, case does not alter functionality. Whether you name a variable “index”, “Index” or “INDEX”, as long as you don’t hide an identifier from an enclosing scope, the code will function in exactly the same way.

Scala, though, diverges slightly from this tradition. Here’s an example. Say we have a Pair of two Ints. Say we also have two plain Int values and we want to know whether those two Ints are equal to the values inside the Pair. In the case that they do match, we also want to know in which order they appear in the Pair.

Operations on tuples (such as a Pair) can often be implemented neatly by pattern matching. Here’s one solution to this problem:

def matchPair(x: (Int,Int), A: Int, b: Int): String = 
x match {
  case (A, b) => "Matches (A, b)"
  case (b, A) => "Matches (b, A)"
  case _      => "Matches neither"
}

This is completely unsurprising code except for one little detail. One of the function parameters is upper case while the other two are lower case. Other than that, there’s nothing unusual so far. So let’s try out this code in the Scala interpreter.

scala> def matchPair(x: (Int,Int), A: Int, b: Int): String = 
     | x match {
     |   case (A, b) => "Matches (A, b)"
     |   case (b, A) => "Matches (b, A)"
     |   case _      => "Matches neither"
     | }
matchPair: ((Int, Int),Int,Int)String

scala> val pair = (5, 10)
pair: (Int, Int) = (5,10)

scala> matchPair( pair,  5, 10 )
res1: String = Matches (A, b)

scala> matchPair( pair, 10,  5 )
res2: String = Matches (b, A)

scala> matchPair( pair, 99, 99 )
res3: String = Matches neither

So far so good! It returns the expected value when the values match in order, in reverse order, and when both values don’t match. Is this sufficient unit testing? What other tests would you run?

As you may well guess, no, this isn’t sufficient unit testing. Let’s try the case where one Int matches but not the other:

scala> matchPair( pair,  5, 99 )
res4: String = Matches (A, b)

scala> matchPair( pair, 99, 10 )
res5: String = Matches (b, A)

That didn’t work right. Is the matchPair function telling us that ‘pair’ (which is (5, 10) ) matches (5, 99) or (99, 10)? That’s what it looks like, but no. Scala does something a little bit surprising here. Do you know why?

As I said before, you can have variables and constants in Scala with upper or lower case names. Both are legal, just as they are in Java. But Scala makes some distinctions that Java doesn’t. Within a pattern (the part between ‘case’ and ‘=>’) Scala treats simple lower case identifiers differently. It uses them as new variables into which matched data is stored, but this is not the case for identifiers that begin with an upper case letter!

If you want to capture results of a pattern match in Scala you must use a lower case identifier and that identifier will hide any identifiers with the same name from an enclosing scope. So in our example function, “case (A, b)” matches a Pair. The first element of the pair is “A” which starts with an upper case letter, so pattern matching results can’t be stored in it. It is used the way we intended, i.e. the pattern is matched if x._1 equals A.

The “b” in “case (A, b)”, though, begins with a lower case letter so it is assigned the value of x._2 (assuming x._1 equals A). It is as if you had typed “val b = x._2” in the function body. Within the case line, the “b” from the pattern hides the parameter named “b”.
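
One way to see the hiding for yourself is to print the “b” on the right-hand side of the case. This is a small sketch with made-up names (ShadowDemo, describe), not part of the original function:

```scala
object ShadowDemo {
  def describe(x: (Int, Int), A: Int, b: Int): String = x match {
    // A is compared with x._1; the lower case b is a NEW variable bound to x._2,
    // hiding the parameter b on this line
    case (A, b) => "first is " + A + ", b is now " + b
    case _      => "no match"
  }
}

// ShadowDemo.describe((5, 10), 5, 99)  // "first is 5, b is now 10"
// ShadowDemo.describe((5, 10), 7, 10)  // "no match"
```

The 99 passed in as b never shows up in the output, because inside the case the name b refers to the value taken from the pair.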

So how can we make this function work the way we want? Here’s one way:

def matchPair(x: (Int,Int), A: Int, B: Int): String = 
x match {
  case (A, B) => "Matches (A, B)"
  case (B, A) => "Matches (B, A)"
  case _      => "Matches neither"
}

Now both the Int parameters start with an upper case letter and are therefore tested against x._1 and x._2. This code passes our tests. Note that the code behaves differently simply based on the parameter names we choose. There’s another way to prevent Scala from using the identifiers for storing pattern results.

def matchPair(x: (Int,Int), a: Int, b: Int): String = 
x match {
  case (`a`, `b`) => "Matches (a, b)"
  case (`b`, `a`) => "Matches (b, a)"
  case _          => "Matches neither"
}

You can use the more traditional lower case parameter names if you quote them using the backquote character. That’s the key to the left of the “1” on my keyboard. This matchPair is equivalent to the one that used capital “A” and “B”.
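
To be sure of that, the five interpreter calls from earlier can be rerun against the backquoted version. Here is a sketch that bundles them up; the wrapper object MatchPairCheck and the checks list are just scaffolding I made up for this illustration.

```scala
object MatchPairCheck {
  // Same function as the backquoted version above
  def matchPair(x: (Int, Int), a: Int, b: Int): String = x match {
    case (`a`, `b`) => "Matches (a, b)"
    case (`b`, `a`) => "Matches (b, a)"
    case _          => "Matches neither"
  }

  val pair = (5, 10)

  // The five calls tried earlier, each compared with the answer we want
  val checks = List(
    matchPair(pair, 5, 10)  == "Matches (a, b)",
    matchPair(pair, 10, 5)  == "Matches (b, a)",
    matchPair(pair, 99, 99) == "Matches neither",
    matchPair(pair, 5, 99)  == "Matches neither",
    matchPair(pair, 99, 10) == "Matches neither"
  )
}

// MatchPairCheck.checks.forall(identity)  // true
```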

Another quick example:

scala> val pair = (5, 10)
pair: (Int, Int) = (5,10)

scala> val (a,b) = pair
a: Int = 5
b: Int = 10

scala> val (X,Y) = pair
<console>:5: error: not found: value X
       val (X,Y) = pair
            ^
<console>:5: error: not found: value Y
       val (X,Y) = pair
              ^

You know that you can use the “val (a,b) = pair” construct above to declare and initialize multiple vals or vars using a tuple, right? And you know how that magic is done? Patterns, so the same principle applies here. The capitalized identifiers “X” and “Y” are taken to refer to existing identifiers because they can’t be used to store pattern match results. Since no such identifiers had been defined, you get an error.

If we define these values beforehand then Scala tries to match their values:

scala> val pair = (5,10)
pair: (Int, Int) = (5,10)

scala> val I = 5
I: Int = 5

scala> val J = 10
J: Int = 10

scala> val (I,q) = pair
q: Int = 10

scala> val (J,r) = pair
scala.MatchError: (5,10)
        at .<init>(<console>:6)
        at .<clinit>(<console>)
        at RequestResult$.<init>(<console>:3)
        at RequestResult$.<clinit>(<console>)
        at RequestResult$result(<console>)
        at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
        at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:39)
        at sun.reflect.DelegatingMethodAccessorImpl.invoke(...

scala> val (I,J) = pair

In “val (I,q) = pair”, the value I has already been declared and initialized to 5 so it matches the first part of the Pair. The identifier q becomes a new val initialized to the value in the second part of the Pair.

In “val (J,r) = pair”, we do the same thing but we try to match J to the first part of the Pair. This won’t work since pair._1 is 5 and J is 10. A MatchError is thrown.

In the final statement, “val (I,J) = pair”, we use both of the capitalized identifiers. They match, but there are no lower case identifiers to make new values out of, so the line does nothing except to confirm (by not throwing an Error) that I equals pair._1 and J equals pair._2.

Now you know how to match when you want to match and assign results when you want to assign results. I hope that being able to match against already-defined identifiers will make your matching code more powerful.


I noticed some strange behavior in some Scala code recently. It was rather a mystery. I looked for my error and googled for a solution for the longest time with no success. Eventually I got my answer from the Scala mailing list / Nabble forum. Here’s the class that was causing the trouble.

class ArrayWrapper[A](length: Int) {
  private val array = new Array[A](length)
  def apply(x: Int) = array(x)
  def update(x: Int, value: A) = array(x) = value
  override def toString(): String = array.toString
}

The first thing you’ll notice about this class is that it is extremely simple! There aren’t a lot of moving parts. It’s a simple wrapper that exposes 3 basic array behaviors: apply (a ‘getter’), update (a ‘putter’), and good ol’ toString. Arrays in Scala take a type parameter, and to ensure that this class could wrap an array of any type I used a type parameter, too. Have a good look at the class and make sure you understand how it works. It won’t take long.

How do you expect this class to behave? Let’s play a little fill-in-the-blanks. Here is a Scala interpreter session with some results blanked out.

scala> class ArrayWrapper[A](length: Int) {
    |   private val array = new Array[A](length)
    |   def apply(x: Int) = array(x)
    |   def update(x: Int, value: A) = array(x) = value
    |   override def toString(): String = array.toString
    | }
defined class ArrayWrapper

scala> val a = new ArrayWrapper[Int](5)
??????????????

scala> val x = a(0)
??????????????

scala> x.toString
??????????????

scala> a(0).toString
??????????????

scala> a(0) = 0

scala> a.toString
??????????????

scala> a(0).toString
??????????????

scala>

There are 6 blanks. What do you expect to see in each of those? Well, the first blank follows the creation of a new ArrayWrapper[Int] and its assignment to a val ‘a’. So, according to our overriding definition of toString, it is simply the result of the underlying Array’s toString method. I know from experience how a brand new Array of Ints looks. It looks like this:

scala> new Array[Int](5)
res1: Array[Int] = Array(0, 0, 0, 0, 0)

So that’s what I expect to see here. Anyone expect something different? Here’s what I actually saw in the first blank:

scala> val a = new ArrayWrapper[Int](5)
a: ArrayWrapper[Int] = Array(null, null, null, null, null)

Hmm. That’s not what I expected. Did you predict this? Why is this array full of nulls when a new Array[Int] is usually full of zeros? I was stumped. The array is parameterized, I reasoned, so maybe type erasure was involved. That doesn’t make sense, though. No types should be erased at this point.

Let’s look at the next two statements, “val x = a(0)” and “x.toString”. I called a(0) (the apply method) and assigned the result to x. I then called the toString method on x. What do you expect in these two lines? I would have expected a(0) to return 0 and x.toString to return “0”, but my conviction is shaken by that last result. Will a(0) return null? Will x.toString throw a NullPointerException? Decide what you predict will happen. Here’s the actual result:

scala> val x = a(0)
x: Int = 0

scala> x.toString
res0: java.lang.String = 0

Each line behaves in the “correct” way even though we saw all those nulls in the underlying array. That’s good news, I suppose. Maybe the problem is limited to Array’s toString method. It should be smooth sailing now. Let’s now look at the next statement, in which we call a(0).toString. It’s just combining the operations (apply and toString) from the previous two lines without storing the intermediate result in ‘x’. I expected that to return String “0”. You can probably guess by now that what I expected is not what I got. Make your own prediction before you read the actual result below. What will happen when we call a(0).toString?

scala> a(0).toString
java.lang.NullPointerException
       at .<init>(<console>:7)
       at .<clinit>(<console>)
       at RequestResult$.<init>(<console>:3)
       at RequestResult$.<clinit>(<console>)
       at RequestResult$result(<console>)
       at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
       at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:39)
       at sun.reflect.DelegatingMethodAccessorImpl.i...

Ouch! NullPointerException! This is an unpleasant surprise. The call a(0) returned a zero earlier, and calling toString on that zero returned a String “0”. But now we get this disaster. I’m getting more and more confused. Do you have an explanation for this crazy behavior yet?

Moving along, we next assign 0 to a(0). Remember that a(0) returned 0 earlier. By the way, behind the scenes the line “a(0) = 0” doesn’t call the apply method, but the ‘update’ method. It succeeds. In the last two statements we call a.toString and a(0).toString. What will happen in each case? At this point, it’s anybody’s guess. The behavior has been so wacky I can’t even make a sensible prediction. Make a guess of your own, if you dare, and observe the actual result below:

scala> a(0) = 0

scala> a.toString
res3: String = Array(0, null, null, null, null)

scala> a(0).toString
res4: java.lang.String = 0

The underlying Array now appears to contain a zero in addition to the nulls. Also, the a(0).toString, which was throwing a NullPointerException earlier, is now succeeding.

As I say, I puzzled over this problem for some time. I wanted to blame the issue on type erasure in the parameterized type, but that explanation didn’t make sense. I posted a question to the Scala forum on Nabble and got a response back in short order from Daniel Sobral.

The culprit? Drumroll…

Boxing. Well, boxing, unboxing, and a peculiarity of parameterized types. To review, here is our ArrayWrapper class:

class ArrayWrapper[A](length: Int) {
  private val array = new Array[A](length)
  def apply(x: Int) = array(x)
  def update(x: Int, value: A) = array(x) = value
  override def toString(): String = array.toString
}

We declared ‘array’ to be an Array[A], which is to say an Array of who-knows-what. When the Array is defined in this way, with a type parameter of unknown type, the Array must be an array of object references! It cannot be an array of Java int primitives. That’s the peculiarity of parameterized types. That’s why the default values for the members of the array were null instead of 0. The underlying array is actually an array of java.lang.Integer objects.

When we ran ‘val x = a(0)’, Scala retrieved the value at index 0 which was null. The apply method has Int return type in our example, and null is not a legal value of an Int. Int is Scala’s version of the Java int primitive type. So the null was converted (unboxed) to Int value 0. Then it could be stored in val x, etc. Once it’s safely unboxed, it behaves like a normal Scala Int value.

So, why did a(0).toString not work? Shouldn’t the null returned from a(0) be unboxed to Int 0, then re-boxed for the toString call? Apparently it doesn’t work that way. The unboxing hasn’t happened at the time the toString call is executed, so that toString is called on the null, giving us the NullPointerException. I don’t know whether this behavior is imposed by the JVM or the Scala language. Either way, it seems to me like a violation of the Principle of Least Astonishment and an opportunity for improvement.

Once we call a(0) = 0, then the underlying array is populated with a boxed version of 0, which is to say an instance of java.lang.Integer. After it’s populated with a non-null it works normally.

Again, this only happens for Arrays with parameterized types. If we make ArrayWrapper non-parameterized and declare ‘array’ as an Array[Int] then the problem goes away.

scala> class ArrayWrapper(length: Int) {
     |   private val array = new Array[Int](length)
     |   def apply(x: Int) = array(x)
     |   def update(x: Int, value: Int) = array(x) = value
     |   override def toString(): String = array.toString
     | }
defined class ArrayWrapper

scala> val a = new ArrayWrapper(5)
a: ArrayWrapper = Array(0, 0, 0, 0, 0)

scala> a(0).toString
res0: java.lang.String = 0
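
Everything above describes the Scala version I was using when I wrote this. On later versions (2.10 and up, which is an assumption about your setup) there is a way to keep the type parameter: ask for a scala.reflect.ClassTag, which tells the compiler the element type at the point where the array is created. As far as I know, newer compilers won’t accept “new Array[A](length)” without one. Here is a sketch. The class name TaggedArrayWrapper is made up, and I’ve swapped array.toString for mkString because on newer versions Array’s own toString doesn’t print the elements.

```scala
import scala.reflect.ClassTag

// Sketch for newer Scala versions: the ClassTag context bound lets
// new Array[A] build a real primitive array when A is Int
class TaggedArrayWrapper[A: ClassTag](length: Int) {
  private val array = new Array[A](length)
  def apply(x: Int): A = array(x)
  def update(x: Int, value: A): Unit = array(x) = value
  override def toString: String = array.mkString("Array(", ", ", ")")
}

// val w = new TaggedArrayWrapper[Int](3)
// w.toString      // "Array(0, 0, 0)"
// w(0).toString   // "0"
```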

There are a few lessons in all this for the Scala developer:

  • Be vigilant about Array initialization. Initialize them explicitly, especially when dealing with primitives like Int, Long, Float, Double, Byte and Char. Don’t trust the default values.
  • Beware parameterized Arrays. They are flawed. Consider specifying their type or using another collection instead, such as a List or Map which can’t contain un-initialized values.
  • Unit test all your code, even those parts that look too simple to screw up.


In my last post I reviewed the implementation of scala.List’s foldLeft and foldRight methods. That post included a couple of simple examples, but today I’d like to give you a whole lot more. The foldLeft method is extremely versatile. It can do thousands of jobs. Of course, it’s not the best tool for EVERY job, but when working on a list problem it’s a good idea to stop and think, “Should I be using foldLeft?”

Below, I’ll present a list of problem descriptions and solutions. I thought about listing all the problems first, and then the solutions, so the reader could work on his own solution and then scroll down to compare. But this would be very annoying for those who refuse, against my strenuous urging, to start up a Scala interpreter and try to write their own solution to each problem before reading my solution.

Sum

Write a function called ‘sum’ which takes a List[Int] and returns the sum of the Ints in the list. Don’t forget to use foldLeft.

def sum(list: List[Int]): Int = list.foldLeft(0)((r,c) => r+c)
def sum(list: List[Int]): Int = list.foldLeft(0)(_+_)

I’ll explain this first example in a bit more depth than the others, just to make sure we all know how foldLeft works.

These two definitions above are equivalent. Let’s examine the first one. The foldLeft method is called on the list parameter. The first parameter is 0. This is the starting value, and the value that will be returned if list is empty. The second parameter is a function literal. It takes parameters ‘r’ (for result) and ‘c’ (for current) and returns the sum of these two values. Scala is smart enough to figure out that since the first parameter (0) is an Int, the ‘r’ parameter must also be an Int. The initial value is always the same type as ‘r’. Scala can also tell that since ‘list’ is a List[Int] the ‘c’ parameter must also be an Int, so we don’t have to specify their types in the parameter list.

The foldLeft method takes that initial value, 0, and the function literal, and it begins to apply the function on each member of the list (parameter ‘c’), updating the result value (parameter ‘r’) each time. That result value that we call ‘r’ is sometimes called the accumulator, since it accumulates the results of the function calls.

In the first definition, foldLeft’s second parameter (a function literal) uses explicitly named parameters. Notice that ‘r’ and ‘c’ are each referred to exactly once in the function literal, and in the same order as the parameter list. When function literal parameters are used in this way (once each, same order) you can use the shorthand demonstrated in the second definition. The first ‘_’ stands for ‘r’, and the second one stands for ‘c’.
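
If you want to watch the accumulator change, List’s scanLeft method takes the same arguments as foldLeft but keeps every intermediate value of ‘r’ instead of only the last one. Here’s a small sketch, assuming a Scala version whose List has scanLeft (2.8 or later). The object name FoldTrace is just for illustration.

```scala
object FoldTrace {
  val numbers = List(1, 2, 3, 4)

  // Every value the accumulator takes, starting with the initial value 0
  val steps = numbers.scanLeft(0)((r, c) => r + c)

  // foldLeft returns only the final accumulator
  val total = numbers.foldLeft(0)((r, c) => r + c)
}

// FoldTrace.steps  // List(0, 1, 3, 6, 10)
// FoldTrace.total  // 10
```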

Product

Now that you’ve got the idea, try this one. Write a function that takes a List[Int], and returns the product (of multiplication) of all the Ints in the list. It will be similar to the ‘sum’ function, but with a couple of differences.

def product(list: List[Int]): Int = list.foldLeft(1)(_*_)

Did you get it? It’s the same as ‘sum’ with two exceptions. The initial value is now 1, and the function literal’s parameters are multiplied instead of added. If the initial value were 0, as in ‘sum’, then the function would always return 0.

Count

This one’s a little different. Write a function that takes a List[Any] and returns the number of items in the list. Don’t just call list.length()! Implement it using foldLeft.

def count(list: List[Any]): Int =
  list.foldLeft(0)((sum,_) => sum + 1)

First, we pick our initial value. Remember that this is the value that will be returned for an empty list. An empty list has 0 elements, so we use 0. What function do we want to apply for every item in the list? We just want to increase the result value by one. We call that parameter ‘sum’ in this solution. We don’t care about the actual value of each list element, so we call the second parameter ‘_’, which means it should be discarded.

Average

Here’s a fun one. Write a function that takes a List[Double] and returns the average of the list’s values. There are two ways to go about this one. You could combine two of the previous solutions, using two foldLeft calls, or you could combine them into a single foldLeft. Try to find both solutions.

def average(list: List[Double]): Double =
  list.foldLeft(0.0)(_+_) / list.foldLeft(0.0)((r,c) => r+1)

def average(list: List[Double]): Double = list match {
  case head :: tail => tail.foldLeft( (head,1.0) )((r,c) =>
    ((r._1 + (c/r._2)) * r._2 / (r._2+1), r._2+1) )._1
  case Nil => Double.NaN
}

The first solution is pretty easy and combines the ‘sum’ and ‘count’ solutions. In real life, of course, you wouldn’t use foldLeft to find the length of the list. You’d just use the length() method. Other than that, though, this is a perfectly sensible solution.

The second solution is more complex. First, the list is matched against two patterns. It is either interpreted as a head item followed by a tail, or as an empty list (Nil). If it’s empty, the function returns the same thing as the first solution, NaN (Not a Number) because you can’t divide by 0.

If the list is not empty, we use a Pair as our initial value. A Pair is just an ordered pair of values. It’s a convenient way to bundle values together. We use it when we need to keep track of more than one accumulator value. In this case, we want to keep track of the average “so far” and also the number of values that the average represents. If the function literal were just passed the average so far, it wouldn’t know how to weight the next value. Members of a Pair are accessed using special methods called ‘_1’ and ‘_2’. You can have groupings longer than 2, also. These are named Tuple3, Tuple4, and so on. In fact, Pair is just an alias of Tuple2. Notice that we didn’t use the word Pair or Tuple2 anywhere in the code. If you enclose a comma-delimited series of values in parentheses, Scala converts that series into the appropriate TupleX.

After we have built up the result, it is a Pair containing the average and the number of items in the list. We only want to return the average so we call ‘_1’ on the result of foldLeft.
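
For comparison, here is a sketch of a different single-foldLeft version that carries the running sum and the count in the Pair, rather than the running average, and divides once at the end. It is an alternative I’m offering for illustration, not a replacement for the solution above. The names AverageSketch, sum and count are made up.

```scala
object AverageSketch {
  def average(list: List[Double]): Double = {
    // The accumulator is a Pair: (sum so far, number of items so far)
    val (sum, count) = list.foldLeft((0.0, 0)) {
      case ((s, n), c) => (s + c, n + 1)
    }
    // Same convention as the solutions above: NaN for an empty list
    if (count == 0) Double.NaN else sum / count
  }
}

// AverageSketch.average(List(1.0, 2.0, 3.0, 4.0))  // 2.5
```

The pattern “case ((s, n), c)” takes the Pair apart by name, which saves writing ‘_1’ and ‘_2’ everywhere.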

Last

Whew! That one was a little tough. Here’s an easier one. Given a List[A] return the last value in the list. Again, no using List’s last() method.

def last[A](list: List[A]): A =
  list.foldLeft[A](list.head)((_, c) => c)

Easy! Mostly. You’ll notice that we’re using a type parameter, A, in this one. If you’re not familiar with type parameters, too bad. I can’t explain them here. Suffice it to say that our use of A here allows us to take a list of any type of contents, and return a result of just that type. So Scala knows that when this is called on a List[Int], it will return an Int. When it’s called on a List[String], it returns a String.

First, we pick an initial value. For the empty list the concept of a last item doesn’t make any sense, so forget that. We can use any value, so long as it’s of type A. list.head is convenient, so that’s our initial value. The function literal is the simplest we’ve seen. For each item in the list, it just returns that item itself. So when it gets to the end of the list, the accumulator holds the last item. We don’t use the accumulator value in the function literal, so it gets parameter name ‘_’.
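
The solution above throws an exception on an empty list, since list.head fails there. If you would rather get a value back in that case, one option is to make the accumulator an Option. This is a sketch of that variation; the names LastSketch and lastOption are my own for this illustration.

```scala
object LastSketch {
  // Start with "nothing seen yet" and replace it with each item in turn
  def lastOption[A](list: List[A]): Option[A] =
    list.foldLeft(Option.empty[A])((_, c) => Some(c))
}

// LastSketch.lastOption(List(1, 2, 3))     // Some(3)
// LastSketch.lastOption(List.empty[Int])   // None
```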

Penultimate

Write a function called ‘penultimate’ that takes a List[A] and returns the penultimate item (i.e. the next to last item) in the list. Hint: Use a tuple.

def penultimate[A](list: List[A]): A =
  list.foldLeft( (list.head, list.tail.head) )((r, c) => (r._2, c) )._1

This one is very much like the function ‘last’, but instead of keeping just the current item it keeps a Pair containing the previous and current items. When foldLeft completes, its result is a Pair containing the next-to-last and last items. The “_1” method returns just the penultimate item.

Contains

Write a function called ‘contains’ that takes a List[A] and an item of type A, and returns true if the item is one of the members of the list, and false if it isn’t.

def contains[A](list: List[A], item: A): Boolean =
  list.foldLeft(false)(_ || _==item)

We choose an initial value of false. That is, we’ll assume the item is not in the list until we can prove otherwise. We use each of the two parameters exactly once and in the proper order, so we can use the ‘_’ shorthand in our function literal. That function literal returns the result so far (a Boolean) ORed with a comparison of the current item and the target value. If the target is ever found, the accumulator becomes true and stays true as foldLeft continues.

Get

Write a function called ‘get’ that takes a List[A] and an index Int, and returns the list value at the index position. Throw an exception if the index is out of bounds.

def get[A](list: List[A], idx: Int): A =
  list.tail.foldLeft((list.head,0)) {
    (r,c) => if (r._2 == idx) r else (c,r._2+1)
  } match {
    case (result, index) if (idx == index) => result
    case _ => throw new Exception("Bad index")
  }

This one has two parts. First there’s the foldLeft, and then the result is pattern matched. The foldLeft is pretty easy to follow. The accumulator is a Pair containing the current item and the current index. The current item keeps updating and the current index keeps incrementing until the current index equals the passed in idx. Once the correct index is found the same accumulator is returned over and over. This works fine if the idx parameter is in bounds. If it’s out of bounds, though, the foldLeft just returns a Pair containing the last item and the last index. That’s where the pattern match comes in. If the Pair contains the right index then we use the result item. Otherwise, we throw an exception.

MimicToString

Write a function called ‘mimicToString’ that mimics List’s own toString method. That is, it should return a String containing a comma-delimited series of string representations of the list contents with “List(” on the left and “)” on the right.

def mimicToString[A](list: List[A]): String = list match {
  case head :: tail => tail.foldLeft("List(" + head)(_ + ", " + _) + ")"
  case Nil => "List()"
}

This one also uses a pattern match, but this time the match happens first. The pattern match just treats the empty list as a special case. For the general case (a non-empty list) we use, of course, foldLeft. The accumulator starts out as “List(” + the head item. Then each remaining item (notice foldLeft is called on tail) is appended with a leading “, ” and a final “)” is added to the result of foldLeft.
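
In real code you would probably reach for List’s mkString method, which takes a prefix, a separator and a suffix and handles the empty list without a special case. Here’s a quick sketch for comparison; the names ToStringSketch and viaMkString are made up.

```scala
object ToStringSketch {
  // Prefix, separator and suffix in a single call
  def viaMkString[A](list: List[A]): String =
    list.mkString("List(", ", ", ")")
}

// ToStringSketch.viaMkString(List(1, 2, 3))  // "List(1, 2, 3)"
// ToStringSketch.viaMkString(Nil)            // "List()"
```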

Reverse

This one’s kind of fun. Make sure to try it before you look at my solution. Write a function called ‘reverse’ that takes a List and returns the same list in reverse order.

def reverse[A](list: List[A]): List[A] =
  list.foldLeft(List[A]())((r,c) => c :: r)

A very simple solution! The initial value of the accumulator is just an empty list. We don’t use Nil, but instead spell out the List type so that Scala will know what type to make ‘r’. As I say, we start with the empty list which is sensible because the reverse of an empty list is an empty list. Then, as we go through the list, we place each item at the front of the accumulator. So the item at the front of list becomes the last item in the accumulator. This goes on until we reach the end of list, and that last member of list goes onto the front of the accumulator. It’s a really neat and tidy solution.
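
To watch the accumulator grow, you can swap foldLeft for scanLeft, which keeps every intermediate accumulator. A small sketch, with ReverseTrace as a made-up name:

```scala
object ReverseTrace {
  // Each step places the current item at the front of the accumulator
  val steps = List(1, 2, 3).scanLeft(List[Int]())((r, c) => c :: r)
}

// ReverseTrace.steps
//   List(List(), List(1), List(2, 1), List(3, 2, 1))
```

The last entry is what foldLeft would return.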

Unique

Write a function called ‘unique’ that takes a List and returns the same List, but with duplicated items removed.

def unique[A](list: List[A]): List[A] =
  list.foldLeft(List[A]()) { (r,c) =>
    if (r.contains(c)) r else c :: r
  }.reverse

As usual, we start with an empty list. foldLeft looks at each list item, and if it’s already contained in the accumulator then the accumulator stays as it is. If it’s not in the accumulator then it’s added to the front. This code bears a striking similarity to the ‘reverse’ function we wrote earlier except for the “if (r.contains(c)) r” part. Because of this, the foldLeft result is actually the original list with duplicates removed, but in reverse order. To keep the output in the same order as the input, we add the call to reverse. We could also have chained on the foldLeft from the ‘reverse’ function, like so:

def unique[A](list: List[A]): List[A] =
  list.foldLeft(List[A]()) { (r,c) =>
    if (r.contains(c)) r else c :: r
  }.foldLeft(List[A]())((r,c) => c :: r)

ToSet

Write a function called ‘toSet’ that takes a List and returns a Set containing the unique elements of the list.

def toSet[A](list: List[A]): Set[A] =
  list.foldLeft(Set[A]())( (r,c) => r + c)

Super easy one. You just start out with an empty Set, which would be the right answer for an empty List. Then you just add each list item to the accumulator. Since the accumulator is a Set, it takes care of eliminating duplicates for you.

Double

Write a function called ‘double’ that takes a List and returns a new List in which each item appears twice in a row. For example double(List(1, 2, 3)) should return List(1, 1, 2, 2, 3, 3).

def double[A](list: List[A]): List[A] =
  list.foldLeft(List[A]())((r,c) => c :: c :: r).reverse

Again, pretty easy. Are you starting to see a pattern? When you use foldLeft to transform one list into another, you usually end up with the reverse of what you really want.

Alternately, you could have used the foldRight method instead. This does the same thing as foldLeft, except it accumulates its result from back to front instead of front to back. I can’t recommend using it, though, due to problems I point out in my other post on foldLeft and foldRight. But here’s what it would look like:

def double[A](list: List[A]): List[A] =
  list.foldRight(List[A]())((c,r) => c :: c :: r)

InsertionSort

This one takes some thinking. Write a function called ‘insertionSort’ that uses foldLeft to sort the input List using the insertion sort algorithm. Try it on your own before you look at the solution.

Need a hint? Use List’s ‘span’ method.

Did you find a solution? Here’s mine:

def insertionSort[A <% Ordered[A]](list: List[A]): List[A] =
  list.foldLeft(List[A]()) { (r,c) =>
    val (front, back) = r.span(_ < c)
    front ::: c :: back
  }

First, the type parameter ensures that we have elements that can be arranged in order. We start, predictably, with an empty list as our initial accumulator. Then, for each item we assume the accumulator is in order (which it always will be), and use span to split it into two sub-lists: all already-sorted items less than the current item, and all already-sorted items greater than or equal to the current item. We put the current item in between these two and the accumulator remains sorted. This is, of course, not the fastest way to sort a list. But it’s a neat foldLeft trick.
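
Here is a sketch of a single step of that fold, pulled out so you can try span on its own. I’ve assumed scala.math.Ordering in place of the view bound, since Scala 3 no longer supports view bounds. The names InsertionSortSketch, insert, and sort are mine, not part of any library.

```scala
object InsertionSortSketch {
  // One step of the fold: insert c into an already-sorted list r.
  def insert[A](r: List[A], c: A)(implicit ord: Ordering[A]): List[A] = {
    // front: leading items less than c; back: everything from the first item >= c
    val (front, back) = r.span(ord.lt(_, c))
    front ::: c :: back
  }

  // The whole sort is just that step folded over the input.
  def sort[A](list: List[A])(implicit ord: Ordering[A]): List[A] =
    list.foldLeft(List[A]())((r, c) => insert(r, c))
}
```

For example, InsertionSortSketch.insert(List(1, 3, 5), 4) splits the accumulator into List(1, 3) and List(5) and returns List(1, 3, 4, 5).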

Pivot

Speaking of sorting, you can implement part of quicksort with foldLeft, the pivot. Write a function called ‘pivot’ that takes a List, and returns a Tuple3 containing: (1) a list of all elements less than the original list’s first element, (2) the first element, and (3) a List of all elements greater than or equal to the first element.

def pivot[A <% Ordered[A]](list: List[A]): (List[A],A,List[A]) =
  list.tail.foldLeft[(List[A],A,List[A])]( (Nil, list.head, Nil) ) {
    (result, item) =>
    val (r1, pivot, r2) = result
    if (item < pivot) (item :: r1, pivot, r2) else (r1, pivot, item :: r2)
  }

We’re using the first element, head, as the pivot value, so we skip the head and call foldLeft on list.tail. We initialize the accumulator to a Tuple3 containing the head element with an empty list on either side. Then for each item in the list we just pick which of the two lists to add to based on a comparison with the pivot value.

If you take the additional step of turning this into a recursive call, you can implement a quicksort algorithm. It probably won’t be a very efficient one because it will involve a lot of building and rebuilding lists. Give it a try if you like, and then look at my solution:

def quicksort[A <% Ordered[A]](list: List[A]): List[A] = list match {
  case head :: _ :: _ =>
    list.foldLeft[(List[A],List[A],List[A])]( (Nil, Nil, Nil) ) {
      (result, item) =>
      val (r1, r2, r3) = result
      if      (item < head) (item :: r1, r2, r3)
      else if (item > head) (r1, r2, item :: r3)
      else                  (r1, item :: r2, r3)
    } match {
      case (list1, list2, list3) =>
        quicksort(list1) ::: list2  ::: quicksort(list3)
    }
  case _ => list
}

Basically, for all lists that have more than 1 element the function chooses the head element as the pivot value, uses foldLeft to divide the list into three (less than, equal to, and greater than the pivot), recursively sorts the less-than and greater-than lists, and knits the three together.

Encode

Ok, we got a little into the weeds with that last one. Here’s a simpler one. Write a function called ‘encode’ that takes a List and returns a list of Pairs containing the original values and the number of times they are repeated. So passing List(1, 2, 2, 2, 2, 2, 3, 2, 2) to encode will return List((1, 1), (2, 5), (3, 1), (2, 2)).

def encode[A](list: List[A]): List[(A,Int)] =
  list.foldLeft(List[(A,Int)]()){ (r,c) =>
    r match {
      case (value, count) :: tail =>
        if (value == c) (c, count+1) :: tail
        else            (c, 1) :: r
      case Nil =>
        (c, 1) :: r
    }
  }.reverse

Decode

You knew this was coming. Write a function called ‘decode’ that does the opposite of encode. Calling ‘decode(encode(list))’ should return the original list.

def decode[A](list: List[(A,Int)]): List[A] =
  list.foldLeft(List[A]()){ (r,c) =>
    var result = r
    for (_ <- 1 to c._2) result = c._1 :: result
    result
  }.reverse

Encode and decode could both have been written by using foldRight and dropping the call to reverse.
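
A sketch of what those foldRight versions might look like is below. The object name RunLengthRight is made up for this example, and I’ve used List.fill in decode in place of the for loop, which is my own substitution. Keep in mind the caveats about foldRight mentioned earlier.

```scala
object RunLengthRight {
  // foldRight visits items from last to first, so prepending keeps the
  // original order and no reverse is needed.
  def encode[A](list: List[A]): List[(A, Int)] =
    list.foldRight(List[(A, Int)]()) { (c, r) =>
      r match {
        case (value, count) :: tail if value == c => (c, count + 1) :: tail
        case _                                    => (c, 1) :: r
      }
    }

  // Each (value, count) pair expands to 'count' copies placed in front of
  // the already-decoded remainder.
  def decode[A](list: List[(A, Int)]): List[A] =
    list.foldRight(List[A]())((c, r) => List.fill(c._2)(c._1) ::: r)
}
```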

Group

One last example. Write a function called ‘group’ that takes a List and an Int size and groups the elements into sublists of the specified size. So calling “group( List(1, 2, 3, 4, 5, 6, 7), 3)” should return List(List(1, 2, 3), List(4, 5, 6), List(7)). Don’t forget to make sure list items are in the right order. Try it yourself before you look at the solution below.

def group[A](list: List[A], size: Int): List[List[A]] =
  list.foldLeft( (List[List[A]](),0) ) { (r,c) => r match {
    case (head :: tail, num) =>
      if (num < size)  ( (c :: head) :: tail , num + 1 )
      else             ( List(c) :: head :: tail , 1 )
    case (Nil, num) => (List(List(c)), 1)
    }
  }._1.foldLeft(List[List[A]]())( (r,c) => c.reverse :: r)

This code uses the first foldLeft to group the items in a way that’s convenient to list operations, and that last foldLeft to fix the order, which would otherwise be wrong in both the outer and inner lists.
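
If you want something to check your own answer against, newer versions of the standard library (2.8 and later, as far as I know) have a ‘grouped’ method that does the same job. A minimal sketch, with the wrapper names GroupCheck and groupBuiltIn invented for the example:

```scala
object GroupCheck {
  // grouped returns an Iterator of sublists, already in the right order,
  // so all that is left is to collect it into a List.
  def groupBuiltIn[A](list: List[A], size: Int): List[List[A]] =
    list.grouped(size).toList
}
```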

The End!

That’s all for now. If you know of any neat foldLeft tricks, please do leave a comment. I’d be interested to hear about it.

Don’t forget to subscribe to my RSS feed, or follow this blog on Twitter.
Copyright © 2009 Matthew Jason Malone

In Java you don’t see a lot of linked lists, and if you do it’s almost always java.util.LinkedList. People never write their own lists. They don’t really need to, I suppose. The one from java.util is fine. Plenty of people are leading fulfilling software careers never having implemented their own linked list. But it’s kind of a shame. Knowing how your data structures work makes you a better programmer.

It’s even rarer for a person to implement his own linked list in Scala. Scala’s scala.List is one of the most used classes in the language, so it’s packed with functionality. It’s abstract, covariant, it has helper objects such as List and Nil and the little-known ‘::’ class, it inherits from Product, Seq, Collection, Iterable, and PartialFunction. The machinery of List pulls in Array, ListBuffer, and more. It can be hard to take it all in.

So let’s build our own linked list. We’ll start out with something very basic and un-Scala-like. Then we’ll improve it gradually until we have something a little closer to scala.List. I encourage you to fire up your Scala interpreter and follow along.

Back to Basics

First, a short review of linked lists. A linked list is a chain of nodes, each referring to exactly one other node until you get to the end of the chain. You refer to the list by its first item and you follow the chain of references to reach the other nodes.

What are the requirements for our first try? Our node should be able to hold a piece of data, and refer to the next node. It should also be able to report its length, and provide a toString method so we can visualize the list. Here we go.

class MyList(val head: Any, val tail: MyList) {
  def isEmpty = (head == null && tail == null)
  def length: Int = if (isEmpty) 0 else 1 + tail.length
  override def toString: String = if (isEmpty) "" else head + " " + tail
}

The value ‘head’ holds the data for the node, ‘tail’ refers to the next element in the chain. The ‘isEmpty’ method is true if the head and tail are both null. The length and toString methods are both defined using a similar pattern: if (isEmpty) [base result] else [data for current node + result of same method on tail].

Here’s what it looks like when we use this class:

scala> var list = new MyList(null, null)
list: MyList =

scala> list.length
res0: Int = 0

scala> list.isEmpty
res1: Boolean = true

scala> list = new MyList("ABC", list)
list: MyList = ABC

scala> list.length
res3: Int = 1

scala> list.isEmpty
res4: Boolean = false

scala> list = new MyList("XYZ", list)
list: MyList = XYZ ABC

scala> list = new MyList("123", list)
list: MyList = 123 XYZ ABC

scala> list.tail.head
res7: Any = XYZ

Not bad. It gets the job done. But it has some problems. First is the use of ‘null’. Use of null references is sloppy and increases the odds of a null pointer exception so, ideally, we don’t want to see that. It has other problems, too. It’s too verbose. It’s not typesafe. But for now let’s concentrate on getting rid of the nulls.

No Nulls Is Good Nulls

How can we do it? We’re using the null as a special value, a marker to tell us when a node is at the end of a list. So we’ll just use something else as that marker instead. What can we use? We’ll create a special object for the empty list. It will be recognized as empty just based on its identity, not on null values. So let’s try it:

class MyList(val head: Any, val tail: MyList) {
  def isEmpty = false
  def length: Int = if (isEmpty) 0 else 1 + tail.length
  override def toString: String = if (isEmpty) "" else head + " " + tail
}

object MyListNil extends MyList("arbitrary value", null) {
  override def isEmpty = true
}

That’s better. (The observant reader will note the similarity of MyListNil to scala.List’s Nil object.) We got rid of the nulls in the isEmpty method, but we still have to put something in the head and tail parameters of the MyList constructor. We put an arbitrary non-null value in head, but what do we put for tail? Either null or create a new MyList. And how can that MyList be instantiated? It also needs a tail. Vicious circle. So this solution leaves us still stuck with a null.

Earlier, the null was there to mark a special node. We factored out that usage. Now it’s there to allow us to create the MyListNil. How can we factor that out? MyListNil is required to call its parent’s constructor. What if it had no parent? Then it wouldn’t be a MyList anymore. What if it had an abstract parent? Now you’re talking. Let’s see what that would look like.

abstract class MyList {
  def head: Any
  def tail: MyList
  def isEmpty: Boolean
  def length: Int
}

class MyListImpl(val head: Any, val tail: MyList) extends MyList {
  def isEmpty = false
  def length: Int = 1 + tail.length
  override def toString: String = head + " " + tail
}

object MyListNil extends MyList {
  def head: Any = throw new Exception("head of empty list")
  def tail: MyList = throw new Exception("tail of empty list")
  def isEmpty = true
  def length = 0
  override def toString =  ""
}

It’s a little more code, but much neater. There are no nulls anywhere. Here’s how it looks when we use this new MyList:

scala> var list: MyList = MyListNil
list: MyList =

scala> list = new MyListImpl("ABC", list)
list: MyList = ABC

scala> list = new MyListImpl("XYZ", list)
list: MyList = XYZ ABC

scala> list = new MyListImpl("123", list)
list: MyList = 123 XYZ ABC

scala> list.length
res3: Int = 3

scala> list.tail.head
res4: Any = XYZ

scala> list.tail.tail.tail.head
java.lang.Exception: head of empty list
        at ...

Pretty neat. The equivalent of MyListImpl in Scala’s real List implementation is a class called ‘::’, which has that funny name, by the way, because it looks nice in pattern matching code. Sometimes ‘::’ is referred to as cons. With nulls finally eliminated, we can concentrate on other issues.
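
To show what “looks nice in pattern matching code” means, here is a small sketch against the real scala.List (not our MyList). The names ConsMatch and describe are made up for the example.

```scala
object ConsMatch {
  // In a pattern, '::' takes a non-empty list apart into its head and tail,
  // mirroring the way '::' is used to build the list.
  def describe[A](list: List[A]): String = list match {
    case head :: tail => "head=" + head + " tail=" + tail
    case Nil          => "empty"
  }
}
```

ConsMatch.describe(List(1, 2, 3)) returns "head=1 tail=List(2, 3)", and ConsMatch.describe(Nil) returns "empty".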

Brevity Is The Heart Of List

The thing that I notice at this point is that a lot of typing (on the keyboard) is required to use this list. We have to type out “list = new MyListImpl(…, list)” every time we add an item. We can improve this with a new method.

abstract class MyList {
  [...]
  def add(item: Any): MyList = new MyListImpl(item, this)
}

Now we have classes referring to each other. MyList creates new MyListImpls, and MyListImpl extends MyList. So you’ll need to put these classes in a .scala file and compile them instead of just typing them into the Scala interpreter. But, wow! Look how much easier it is to use MyList now:

scala> var list = MyListNil add("ABC") add("XYZ") add("123")
list: MyList = 123 XYZ ABC

scala> list.length
res1: Int = 3

So much easier! One thing I notice, though, is that the order of items in the code is different from the order produced by toString. We can change our ‘add’ method so that it is right-associative instead of left-associative by using a method name that ends in ‘:’ (colon). We’ll use ‘::’ as the method name since that’s what scala.List uses.

abstract class MyList {
  [...]
  def ::(item: Any): MyList = new MyListImpl(item, this)
}

scala> var list = "ABC" :: "XYZ" :: "123" :: MyListNil
list: MyList = ABC XYZ 123
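
If the right-associative rule seems mysterious, this sketch spells it out using the real scala.List: an operator whose name ends in a colon is called on its right-hand operand. RightAssoc is just a name I picked to hold the two values.

```scala
object RightAssoc {
  // Operator form: each '::' is a method call on the list to its right.
  val infix: List[String] = "ABC" :: "XYZ" :: "123" :: Nil

  // The same expression written as ordinary method calls, innermost first.
  val explicit: List[String] = Nil.::("123").::("XYZ").::("ABC")
}
```

Both values are the same list, List(ABC, XYZ, 123).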

Now we’re really getting somewhere. This is starting to look more like scala.List. One other thing that the standard list implementation gives you is a shortcut for initializing lists. It looks like “List(1, 2, 3, 4)”. Notice there’s no ‘new’ keyword. This is done using the scala.List helper object and its ‘apply’ method. Below is our own MyList helper object.

object MyList {
  def apply(items: Any*): MyList = {
    var list: MyList = MyListNil
    for (idx <- (0 until items.length).reverse)
      list = items(idx) :: list
    list
  }
}

scala> var list = MyList("ABC", "XYZ", "123")
list: MyList = ABC XYZ 123

scala> list = "Cool" :: list
list: MyList = Cool ABC XYZ 123

Better Type-Safe Than Sorry

Better. Our code looks much neater now when we use MyList. I’ll introduce just one more improvement to MyList. It still has a rather glaring problem. It provides no type information. It keeps all of its data using a reference to Any. If you don’t see why this is a problem, let’s see what happens when we want to get the length of some items in a MyList:

scala> var list = MyList("ABC", 12345, "WXYZ")
list: MyList = ABC 12345 WXYZ

scala> list.head.length
<console>:6: error: value length is not a member of Any
       list.head.length
                 ^

scala> list.head.asInstanceOf[String].length
res10: Int = 3

scala> list.tail.head.asInstanceOf[String].length
java.lang.ClassCastException: java.lang.Integer cannot be cast to java.lang.String
        at .<init>(<console>:6)

Ouch! First, when we try to call method ‘length’ on list.head Scala complains that list.head is a reference to Any. Any doesn’t have a length method. This isn’t a dynamically typed language like, say, Ruby. An object has to have the right type before we can start calling methods. What to do? You could implement a MyStringList where the head has type String. But then you’ll need a MyIntList, MyDoubleList, etc. What we need is a way to specify the type of data in the list when we create the MyList instance. What we need is a type parameter.

Here’s the complete MyList code using a type parameter, and a little demonstration code:

abstract class MyList[A] {
  def head: A
  def tail: MyList[A]
  def isEmpty: Boolean
  def length: Int
  def ::(item: A): MyList[A] = new MyListImpl[A](item, this)
}

class MyListImpl[A](val head: A, val tail: MyList[A]) extends MyList[A] {
  def isEmpty = false
  def length: Int = 1 + tail.length
  override def toString: String = head + " " + tail
}

object MyListNil extends MyList[Nothing] {
  def head: Nothing = throw new Exception("head of empty list")
  def tail: MyList[Nothing] = throw new Exception("tail of empty list")
  def isEmpty = true
  def length = 0
  override def toString =  ""
}

object MyList {
  def apply[A](items: A*): MyList[A] = {
    var list: MyList[A] = MyListNil.asInstanceOf[MyList[A]]
    for (idx <- (0 until items.length).reverse)
      list = items(idx) :: list
    list
  }
}

scala> var list = MyList("ABC", "WXYZ", "123")
list: MyList[java.lang.String] = ABC WXYZ 123

scala> list.head.length
res0: Int = 3

scala> 3.14159 :: list
<console>:6: error: type mismatch;
 found   : Double
 required: java.lang.String
       3.14159 :: list
               ^

scala> var list = MyList("ABC", 123, 3.14159)
list: MyList[Any] = ABC 123 3.14159

Look at the first statement in the interpreter session. The “MyList(…)” call returns a MyList[String]. Scala figures out from the parameters what type to use. In the second statement, “list.head.length”, you can see how much easier it is to use the list contents when you know the type at compile time.

If you try to mix types, as in the last statement, Scala determines the nearest common ancestor of the types (Any, in this case) and uses that. However, if the type parameter is already determined, as in the “3.14159 :: list” statement, it won’t change when you try to add data of a different type. To make “3.14159 :: list” work, we can make a small change to the ‘::’ method:

abstract class MyList[A] {
  [...]
  def ::[B >: A](item: B): MyList[B] = 
    new MyListImpl(item, this.asInstanceOf[MyList[B]])
}

scala> var list = MyList("ABC", "XYZ")
list: MyList[java.lang.String] = ABC XYZ

scala> 3.14159 :: list
res0: MyList[Any] = 3.14159 ABC XYZ

This says that ‘::’ takes a parameter of type B, which is either A or a superclass of A, and returns a MyList[B]. So if you have a MyList[String] and you call ‘::’ on it with a Double parameter, Scala figures out that although Double is not a superclass of String, String and Double are both descendants of Any, and it returns a MyList[Any].
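
As mentioned at the top of this post, scala.List is covariant. If we declare our list covariant as well, the asInstanceOf casts are no longer needed. The following is only a sketch of that variation, with the names CoList, Cons, and CoNil invented so they don’t clash with the classes above.

```scala
// '+A' makes the list covariant: a CoList[String] can be used as a CoList[Any].
sealed abstract class CoList[+A] {
  def isEmpty: Boolean
  def length: Int
  // 'this' is a CoList[A], which is already a CoList[B] because B >: A.
  def ::[B >: A](item: B): CoList[B] = new Cons[B](item, this)
}

final class Cons[+A](val head: A, val tail: CoList[A]) extends CoList[A] {
  def isEmpty = false
  def length: Int = 1 + tail.length
  override def toString: String = head.toString + " " + tail
}

// A CoList[Nothing] is a valid tail for a list of any element type.
object CoNil extends CoList[Nothing] {
  def isEmpty = true
  def length = 0
  override def toString = ""
}
```

With this version, “ABC” :: “XYZ” :: CoNil is a CoList[String], and putting 3.14159 on the front of it gives a CoList[Any] without any casting.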

Conclusion

That’s a good stopping point for now. Obviously you can take the MyList class a lot further and add a lot more methods, but we’ve created some code that approximates the basics provided by scala.List. In fact, you could take several of the scala.List methods (foldLeft, for example) and basically drop them right into MyList and they’d work fine.


One of my favorite functional programming tricks is folding. The fold left and fold right functions can do a lot of complicated things with a small amount of code. Today, I’d like to (1) introduce folding, (2) make note of some surprising, nay, shocking fold behavior, (3) review the folding code used in Scala’s List class, and (4) make some presumptuous suggestions on how to improve List.

Update: I’ve created a new post in which I list lots and lots of foldLeft examples in case you’d like to learn more about what folding can accomplish.

Know When to Hold ‘Em, Know When to Fold ‘Em

In case you’re not familiar with folding, I’ll describe it as briefly as I can.

Here’s the signature of the foldLeft function from List[A], a list of items of type A:

def foldLeft[B](z: B)(f: (B, A) => B): B

Firstly, foldLeft is a curried function (So is foldRight). If you don’t know about currying, that’s ok; this function just takes its two parameters (z and f) in two sets of parentheses instead of one. Currying isn’t the important part anyway.

The first parameter, z, is of type B, which is to say it can be different from the list contents type. The second parameter, f, is a function that takes a B and an A (a list item) as parameters, and it returns a value of type B. So the purpose of function f is to take a value of type B, use a list item to modify that value and return it.

The foldLeft function goes through the whole List, from head to tail, and passes each value to f. For the first list item, that first parameter, z, is used as the first parameter to f. For the second list item, the result of the first call to f is used as the B type parameter.

For example, say we had a list of Ints 1, 2, and 3. We could call foldLeft(“X”)((b,a) => b + a). For the first item, 1, the function we define would add string “X” to Int 1, returning string “X1”. For the second list item, 2, the function would add string “X1” to Int 2, returning “X12”. And for the final list item, 3, the function would add “X12” to 3 and return “X123”.
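
That walk-through can be checked with a few lines of code. In this sketch, scanLeft is used alongside foldLeft because it returns every intermediate accumulator rather than only the final one. FoldTrace is a made-up name.

```scala
object FoldTrace {
  val list = List(1, 2, 3)

  // The fold from the paragraph above.
  val result: String = list.foldLeft("X")((b, a) => b + a)

  // Same function, but every accumulator value is kept, starting with z.
  val steps: List[String] = list.scanLeft("X")((b, a) => b + a)
}
```

Here result is "X123" and steps is List("X", "X1", "X12", "X123").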

Here are a few more examples.

list.foldLeft(0)((b,a) => b+a)
list.foldLeft(1)((b,a) => b*a)
list.foldLeft(List[Int]())((b,a) => a :: b)

The first line is super simple. It’s almost like the example I described above, but the z value is the number 0 instead of string “X”. This fold combines the elements of the list by addition instead of concatenation. So the fold returns the sum of all Ints in the list. Line 2 combines them through multiplication. Do you see why the z value is 1 in this case?

Line 3 is a little more complex. Can you guess what it does? It starts out with an empty list of Ints and adds each item to the accumulator (We call the b parameter of our function the accumulator because it accumulates data from each of our list items). Because it starts with the head and adds to the beginning of the accumulator list until it gets to the last item of the original list, it returns the original list in reverse order.

The foldRight function works in much the same way as foldLeft. Can you guess the difference? You got it. It starts at the end of the list and works its way up to the head.
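
One way to see the difference is to fold with a function that just writes down how it was called. This is a sketch with made-up names (FoldShape, left, right); notice that the parameter order of f is flipped between the two folds.

```scala
object FoldShape {
  val list = List(1, 2, 3)

  // foldLeft: the accumulator is the first argument and nesting grows on the left.
  val left: String =
    list.foldLeft("z")((b, a) => "f(" + b + ", " + a + ")")

  // foldRight: the accumulator is the second argument and nesting grows on the right.
  val right: String =
    list.foldRight("z")((a, b) => "f(" + a + ", " + b + ")")
}
```

The value of left is "f(f(f(z, 1), 2), 3)" and the value of right is "f(1, f(2, f(3, z)))".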

Folds can be used for MUCH more than I’ve shown here. With folds, you can solve lots of different problems with a standard construct. You should read up on them if you’re just starting out in functional programming.

All That Glitters Is Not Fold

Now for the moment you’ve been waiting for. Fold’s dirty little secret! The session below is taken from the Scala interpreter.

scala> var shortList = 1 to 10 toList
shortList: List[Int] = List(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)

scala> var longList = 1 to 325000 toList
longList: List[Int] = List(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 
15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, ...

scala> shortList.foldLeft("")((x,y) => "X")
res1: java.lang.String = X

scala> shortList.foldRight("")((x,y) => "X")
res2: java.lang.String = X

scala> longList.foldLeft("")((x,y) => "X")
res3: java.lang.String = X

scala> longList.foldRight("")((x,y) => "X")
java.lang.StackOverflowError
        at scala.List.foldRight(List.scala:1079)
        at scala.List.foldRight(List.scala:1081)
        at scala.List.foldRight(List.scala:1081)
        at scala.List.foldRight(List.scala:1081)
        at scala.List.foldRight(List.scala:1081)
        at scala.List.foldRight(List.scala:1081)
        at scala.List.foldRight(List.scala:1081)
        at scala.List.foldRight(List.scala:1081)
        at scala.List.foldRig...

We created two lists: shortList with 10 items, and longList with 325,000 items. Then we perform a trivial foldLeft and foldRight on shortList. It’s trivial because the passed-in function always returns the string “X”; it doesn’t even use the list data.

Then we do a foldLeft on longList. This goes off without a hitch. Finally we try to do a foldRight, the same foldRight that succeeded on the shorter list, and it fails! The foldLeft worked. Why didn’t the foldRight work? It’s a perfectly reasonable call against a perfectly reasonable List. Something funny is going on here.

The error message says there was a stack overflow, and the stack trace shows a long list of calls at List.scala:1081. If you’ve read my post about tail-recursion, then you probably suspect that some recursive code is to blame.

Let’s look into List.scala, maybe the single most important Scala source file.

Fool’s Fold

Without further ado, here’s the code for foldLeft and foldRight from List.scala:

override def foldLeft[B](z: B)(f: (B, A) => B): B = {
  var acc = z
  var these = this
  while (!these.isEmpty) {
    acc = f(acc, these.head)
    these = these.tail
  }
  acc
}

override def foldRight[B](z: B)(f: (A, B) => B): B = this match {
  case Nil => z
  case x :: xs => f(x, xs.foldRight(z)(f))
}

Wow! Those two definitions are very different!

The foldLeft function is the one that worked for short and long lists. Can you see why? It isn’t head-recursive. In fact, it isn’t recursive at all. It is implemented as a while loop. On each iteration, the next list item is passed to the function f and the accumulator (called acc) is updated. When there are no more list items, the accumulator is returned. No recursion means no stack overflows.

The foldRight function is implemented in a totally different way. If the list is empty, the z parameter is returned. Otherwise, a recursive call is made on the tail (the whole list minus the first item) of this list, and that result is passed to the function f. Study the foldRight definition. Do you understand how it works? It’s an elegant recursive solution, and the code really is quite pretty, but it’s not tail recursive so it fails for large lists.

Why didn’t Mr Odersky just write foldRight using a while loop, too? Then this problem wouldn’t exist, right? The reason is that Scala’s List is implemented as a singly-linked list. Each list element has access to the next item in the list, but not to the previous item. You can only traverse a list in one direction! This works fine for foldLeft, which goes from head to tail, but foldRight has to start at the end of the list and work its way forward to the head. If foldRight uses recursion, it must recurse all the way to the end and then use the results of those recursive calls as the accumulator passed into function f.

See? The results of the recursive call must be used for further calculation, so the recursive call can’t be the last thing that happens, so it can’t be written as a tail-recursive function. If you don’t know what I’m talking about, read my introduction to tail-recursion.
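
For contrast, foldLeft can be written recursively without this problem, because its recursive call is the very last thing that happens. The sketch below is my own standalone version, not the code from List.scala; the @tailrec annotation asks the compiler to reject the method if it can’t turn the recursion into a loop.

```scala
import scala.annotation.tailrec

object TailRecFold {
  // Nothing is left to do after the recursive call returns, so the compiler
  // can reuse the current stack frame instead of pushing a new one.
  @tailrec
  def foldLeft[A, B](list: List[A], acc: B)(f: (B, A) => B): B = list match {
    case Nil     => acc
    case x :: xs => foldLeft(xs, f(acc, x))(f)
  }
}
```

Try writing foldRight the same way and adding @tailrec. The call to f has to wait for the recursive result, so the compiler will refuse.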

Out With The Fold, In With The New

So is that it for foldRight? Is it hopeless? I say no!

There is a way to get the same result as foldRight, but using foldLeft. Can you guess what it is? Here’s how:

list.foldRight("X")((a,b) => a + b)
list.reverse.foldLeft("X")((b,a) => a + b)

These two lines are equivalent! They give the same result no matter what’s in list. Since foldRight processes list elements from last to first, that’s the same as processing the reversed list from first to last.

Here are three possible implementations of foldRight that could replace the current one.

def foldRight[B](z: B)(f: (A, B) => B): B = 
  reverse.foldLeft(z)((b,a) => f(a,b))

def foldRight[B](z: B)(f: (A, B) => B): B =
  if (length > 50) reverse.foldLeft(z)((b,a) => f(a,b))
  else             originalFoldRight(z)(f)

def foldRight[B](z: B)(f: (A, B) => B): B =
  try {
    originalFoldRight(z)(f)
  } catch {
    case e1: StackOverflowError => reverse.foldLeft(z)((b,a) => f(a,b))
  }

The first one simply replaces the original recursive logic with the equivalent call to reverse and foldLeft. Why wasn’t foldRight implemented this way to begin with? It may be, in part, that the authors thought the extra overhead of reversing the list was unwarranted. To me, it doesn’t seem that bad. The original foldRight and foldLeft functions are O(n), meaning they run in an amount of time roughly proportional to the number of items in the list. If you look at the source for the reverse function, you’ll see it’s also O(n). So running reverse followed by foldLeft is O(n).

The second implementation is a compromise. It uses the original recursive version of foldRight (referred to as originalFoldRight in the above code) only when the list is shorter than 50 elements. The reverse.foldLeft is used for lists of 50 elements or longer. 50 is just an arbitrary number, just a guess at a sensible limit on the number of recursive calls to allow.

The third implementation tries the original foldRight logic first and if the call stack overflows then it uses reverse.foldLeft. This solution is, of course, completely ridiculous, but even this would be better than a foldRight which sometimes crashes your program.
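
Since you can’t easily patch List yourself, here is a sketch of the first implementation as a standalone function you could try in the interpreter. SafeFold is a made-up name, and the list is passed in as an ordinary parameter.

```scala
object SafeFold {
  // foldRight expressed as reverse followed by foldLeft. Both steps are
  // loops over the list, so the depth of the call stack does not grow
  // with the length of the list.
  def foldRight[A, B](list: List[A], z: B)(f: (A, B) => B): B =
    list.reverse.foldLeft(z)((b, a) => f(a, b))
}
```

Calling SafeFold.foldRight(longList, "")((x, y) => "X") with the 325,000-item list from earlier returns "X" instead of overflowing the stack.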

That’s All, Folds!

As I pointed out before, the reverse.foldLeft implementation of foldRight is O(n), same as the original recursive version. The original foldRight may work just fine when your Scala application is young and working with small data sets. Over time more customers are added, more products are created, more orders are placed, and then one day, *POOF*, a runtime error! It’s a ticking time-bomb.

As you may well guess, I would like to see the reverse.foldLeft logic used instead of the recursive version. That would prevent the stack overflow errors. But I would settle for just deprecating foldRight. It would be better to eliminate foldRight and force the coder to work around it than to leave it in its current state. In fact, I don’t think any head-recursive functions belong in the List class.

Do any readers have any insight into why foldRight is coded the way it is?
