Here’s a three-for-one special for you: a post about implementing the Levenshtein string distance algorithm in Scala, AND refactoring it from an imperative style to a functional style, AND I even throw in a short lesson on memoization. To make sure that our refactoring is correct and preserves the expected behavior, I’ll unit test the code along the way using ScalaTest. ScalaCheck, JUnit, or TestNG would work just as well.

First Things First

“What, exactly,” some of you may be asking, “is Levenshtein string distance? Some kind of Teutonic tailoring terminology?” Not at all. It’s a way of measuring how alike or different two strings of symbols are. For example, the string “sturgeon” is a lot more similar to “surgeon” than to “urgently”. “Sturgeon” and “surgeon” are only a single letter (t) apart. They have a string distance of 1. “Sturgeon” and “urgently” share some letters, but each has some letters not present in the other. So what’s their string distance? It’s not so obvious now.

String distance is useful. The use that most quickly springs to mind is spelling correction. If I type “computwr”, a string distance algorithm can tell us that “computwr” is very close to the dictionary word “computer”. But there’s more to it than that. In a lot of fuzzy problems you want to find which two sets of complex data are most similar. One way to solve a problem like that is to encode the data as strings and compare them using string distance. For example, you could write a program to find pen strokes in an image and encode their shapes (up, curve left, down, angle right, etc.) as a string, which could then be compared to known encodings for handwritten letters. By finding the closest matches you can create a handwriting recognition program. String distance is also useful in DNA analysis, in recognizing patterns in signals, and in a host of other situations.

The Levenshtein string distance algorithm was developed by Vladimir Levenshtein in 1965. It can easily tell us the distance from “sturgeon” to “urgently”. This algorithm breaks down string transformation into three basic operations: adding a character, deleting a character, and replacing a character. Each of these operations is assigned a cost, usually a cost of 1 for each operation. Leaving a character unchanged has a cost of 0. So to go from “surgeon” to “sturgeon” you leave the “s” unchanged for a cost of 0. Then you add a “t” for a cost of 1. All the other letters also remain unchanged, so the total cost is 1, just as we expected.

To change “sturgeon” to “urgently” is harder. They have the same number of letters, so we could just do a replacement on each one for a distance of 8. But is that the shortest distance? What if we try to re-use that “urge” from “sturgeon”? Can we re-use the “n”? Does that help? What about the “t”? We need an algorithm that we can follow.

The Grid

What we need is a way to find the cheapest combination of operations which changes the first string to the second. That’s the Levenshtein algorithm. It works like this. Write the first string vertically from top to bottom. To the right of each letter write 1, 2, 3, etc. Write a 0 above the number 1. Then write the second string horizontally and again add the numbers 1, 2, 3 etc. to the right of the 0. It will look something like this:

    u r g e n t l y
  0 1 2 3 4 5 6 7 8
s 1
t 2
u 3
r 4
g 5
e 6
o 7
n 8

That’s the first step. The rest of the grid stays empty for now. Can you guess what the grid positions are going to hold? They are going to hold the cost to convert the various prefixes of “sturgeon” to the various prefixes of “urgently”. The position at the intersection of row r4 and column r2, for example, will contain the cost to convert “stur” to “ur”. We fill in these positions (left to right, then top to bottom) with the smallest of the following three numbers:

  • The number above the current position plus one.
  • The number to the left of the current position plus one.
  • The above-left number if row letter and column letter are the same, or the above-left number plus one otherwise.

When we finish, the bottom right corner contains the cost to convert the first string to the second.
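Written as a formula, the three rules form a recurrence. In this sketch d(i, j) is the number in row i, column j of the grid, a_i is the i-th letter of the first string, and b_j is the j-th letter of the second; the bracket [a_i ≠ b_j] is 1 when the letters differ and 0 when they match. The notation is introduced here for convenience.

```latex
\begin{aligned}
d(i, 0) &= i, \qquad d(0, j) = j \\
d(i, j) &= \min\bigl(\, d(i-1, j) + 1,\; d(i, j-1) + 1,\; d(i-1, j-1) + [a_i \neq b_j] \,\bigr)
\end{aligned}
```

For strings of lengths m and n, the answer is d(m, n), the bottom right corner.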

But Why? (Understanding the Algorithm)

Those are some shockingly simple rules! Let’s examine how they translate into string distance.

First, what are those numbers 0 to n that we write along the top and left? They’re not just indices. The numbers along the top represent the cost to convert the empty string to the various prefixes of “urgently”. The cost is 0 to convert “” to “”, 1 to convert “” to “u”, 2 to convert “” to “ur”, and so on. The numbers on the left are the cost to convert the prefixes of “sturgeon” to the empty string: “” to “” is 0, “s” to “” is 1, and so forth. These costs are obvious. The only way to convert “” to “urgently” is to add eight letters. There’s nothing to delete, nothing to replace. The only way to convert “sturgeon” to “” is to delete all eight letters.

Each position, as we have established, represents the cost to convert the string of characters down the left side ending in the current row’s character into the string of characters along the top ending at the current column’s character. Put another way, for any given position let’s call the current row’s letter A, and the current column’s letter B. If we use a colon (:) to indicate string concatenation, then the beginning string, the one along the left of the grid, can be written Prefix1:A, and the target string along the top can be written Prefix2:B. So for our example word “sturgeon”, if we look at a position in row e6, then Prefix1 is “sturg” and the final letter, which we’re calling A, is “e”.

So, speaking in terms of Prefix1, Prefix2, A, and B, we use the following inputs:

  • The cost to change Prefix1 to Prefix2:B (the number above the current position).
  • The cost to change Prefix1:A to Prefix2 (the number to the left).
  • The cost to change Prefix1 to Prefix2 (the above-left number).

If we know the costs of these three conversions we can find the cost to change Prefix1:A to Prefix2:B using this logic:

  • We know that if converting Prefix1 to Prefix2:B has a cost of X, then Prefix1:A can be converted to Prefix2:B for the cost of X plus the cost of deleting A, or X + 1.
  • We know that if converting Prefix1:A to Prefix2 has a cost of Y, then Prefix1:A can be converted to Prefix2:B for the cost of Y plus the cost of adding B, or Y + 1.
  • We know that if converting Prefix1 to Prefix2 has a cost of Z, then Prefix1:A can be converted to Prefix2:B for the cost of Z plus the cost of replacing A with B, or Z + (0 or 1 depending on whether A = B).

If you understand the above logic, then you understand this really neat algorithm. It’s a quintessential example of dynamic programming. The solution is built up by solving simpler problems. You start with the trivial case of converting to and from the empty string, and then you build up the solution for the prefixes until you have the complete solution.

In the grid’s initial configuration there is only one location for which we know all three costs, and that is row s1, column u1. That’s the only space for which all three neighboring values (above, left, and above-left) are populated. After we fill in this one, two more spaces become available to us: (t2, u1) and (s1, r2). Ordinarily, though, the spaces are populated row by row, which guarantees that all three neighbors of each space are already filled in by the time we reach it.

A Simple Example

Let’s do a quick example by hand. Then we’ll take a stab at implementing it in code. What’s the string distance from “hat” to “tape”? First, our empty grid:

    t a p e
  0 1 2 3 4
h 1
a 2
t 3

The first space is row h, column t. The letters are different so our choices are 1 + 1 (above), 1 + 1 (left), or 0 + 1 (above-left). 0 + 1 is the smallest value, so the first space gets populated with 1. The next space has choices 2 + 1 (above), 1 + 1 (left), or 1 + 1 (above-left). 1 + 1 is the lowest, so we fill in the second space with 2. Once we finish the row, we have this:

    t a p e
  0 1 2 3 4
h 1 1 2 3 4
a 2
t 3

The next space is row a, column t. Our choices are 1 + 1 (above), 2 + 1 (left), or 1 + 1 (above-left). 1 + 1 is the smallest value, so the space gets populated with 2. The next space has choices 2 + 1 (above), 2 + 1 (left), or 1 + 0 (above-left). Why 1 + 0? Because the above-left value is 1 and both letters for this space are “a” so we can replace “a” with “a” for free. Go ahead and finish the whole grid. This is the result:

    t a p e
  0 1 2 3 4
h 1 1 2 3 4
a 2 2 1 2 3
t 3 2 2 2 3

The strings “hat” and “tape” have a distance of 3.
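The hand-filled table can be checked with a short program. This sketch (the object and method names are made up for illustration) builds the whole grid with the three rules and prints it in the same layout as above; the layout only lines up while every cost is a single digit.

```scala
object GridSketch {
  // Builds the complete grid for turning `source` into `target`.
  // Row i, column j holds the cost of converting the first i letters of
  // `source` into the first j letters of `target`.
  def grid(source: String, target: String): Array[Array[Int]] = {
    val d = Array.ofDim[Int](source.length + 1, target.length + 1)
    for (i <- 0 to source.length) d(i)(0) = i
    for (j <- 0 to target.length) d(0)(j) = j
    for (i <- 1 to source.length; j <- 1 to target.length) {
      val above     = d(i - 1)(j) + 1
      val left      = d(i)(j - 1) + 1
      val aboveLeft = d(i - 1)(j - 1) + (if (source(i - 1) == target(j - 1)) 0 else 1)
      d(i)(j) = math.min(math.min(above, left), aboveLeft)
    }
    d
  }

  // Renders the grid like the hand-drawn tables (single-digit costs only).
  def render(source: String, target: String): String = {
    val d = grid(source, target)
    val header = "    " + target.mkString(" ")
    val rows = d.indices.map { i =>
      val label = if (i == 0) " " else source(i - 1).toString
      label + " " + d(i).mkString(" ")
    }
    (header +: rows).mkString("\n")
  }
}
```

Calling GridSketch.render("hat", "tape") should reproduce the finished table above, with 3 in the bottom right corner.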

Some Code, Finally

As fascinating as Levenshtein distance is, and as much more as there is to say on the topic, the time has come to write some code. Here’s a Scala implementation that is very close to the pencil-and-paper approach that we just performed.

import scala.math.min
import scala.math.max

object StringDistance {
  def stringDistance(s1: String, s2: String): Int = {
    def minimum(i1: Int, i2: Int, i3: Int) = min(min(i1, i2), i3)

    // The grid: one row per prefix of s1, one column per prefix of s2
    val dist = Array.ofDim[Int](s1.length + 1, s2.length + 1)

    // Costs of converting to and from the empty string
    for (idx <- 0 to s1.length) dist(idx)(0) = idx
    for (jdx <- 0 to s2.length) dist(0)(jdx) = jdx

    // Fill the remaining positions left to right, then top to bottom
    for (idx <- 1 to s1.length; jdx <- 1 to s2.length)
      dist(idx)(jdx) = minimum (
        dist(idx-1)(jdx  ) + 1,                                     // above: delete
        dist(idx  )(jdx-1) + 1,                                     // left: insert
        dist(idx-1)(jdx-1) + (if (s1(idx-1) == s2(jdx-1)) 0 else 1) // above-left: replace or keep
      )

    // The bottom right corner holds the full distance
    dist(s1.length)(s2.length)
  }
}

Do you see what I mean when I say it’s close to the pencil-and-paper approach? We actually construct a two-dimensional Array to represent the grid we drew earlier. It’s a very literal implementation.

To explain the code briefly, we declare a singleton object StringDistance with a single method, stringDistance. Within this method we declare a three-argument minimum helper. (I wonder why Scala’s math package doesn’t define a “def min(params: Int*): Int”.) Then we create a two-dimensional array called dist and populate its top row and leftmost column. The final for loop visits each remaining array position from left to right, then top to bottom, and populates it according to the Levenshtein rules. Finally, once the grid is filled in, we return the number in the bottom right position.
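As for that missing variadic minimum, a helper is easy to write yourself. In this sketch minOf is a hypothetical name; taking the first argument separately means it can never be called with zero arguments, so unlike calling .min on an empty collection it cannot throw.

```scala
object MinSketch {
  // Hypothetical variadic helper: the smallest of one or more Ints.
  // `first` is separate so at least one argument is guaranteed at compile time.
  def minOf(first: Int, rest: Int*): Int =
    rest.foldLeft(first)((a, b) => math.min(a, b))
}
```

With it, the three-way minimum inside stringDistance could be written MinSketch.minOf(a, b, c).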

The job’s not done yet, of course. We haven’t written our unit tests. Here’s the test I wrote:

import org.scalatest.funsuite.AnyFunSuite

class StringDistanceSuite extends AnyFunSuite {

  test("stringDistance should work on empty strings") {
    assert( StringDistance.stringDistance(   "",    "") == 0 )
    assert( StringDistance.stringDistance(  "a",    "") == 1 )
    assert( StringDistance.stringDistance(   "",   "a") == 1 )
    assert( StringDistance.stringDistance("abc",    "") == 3 )
    assert( StringDistance.stringDistance(   "", "abc") == 3 )
  }

  test("stringDistance should work on equal strings") {
    assert( StringDistance.stringDistance(   "",    "") == 0 )
    assert( StringDistance.stringDistance(  "a",   "a") == 0 )
    assert( StringDistance.stringDistance("abc", "abc") == 0 )
  }

  test("stringDistance should work where only inserts are needed") {
    assert( StringDistance.stringDistance(   "",   "a") == 1 )
    assert( StringDistance.stringDistance(  "a",  "ab") == 1 )
    assert( StringDistance.stringDistance(  "b",  "ab") == 1 )
    assert( StringDistance.stringDistance( "ac", "abc") == 1 )
    assert( StringDistance.stringDistance("abcdefg", "xabxcdxxefxgx") == 6 )
  }

  test("stringDistance should work where only deletes are needed") {
    assert( StringDistance.stringDistance(  "a",    "") == 1 )
    assert( StringDistance.stringDistance( "ab",   "a") == 1 )
    assert( StringDistance.stringDistance( "ab",   "b") == 1 )
    assert( StringDistance.stringDistance("abc",  "ac") == 1 )
    assert( StringDistance.stringDistance("xabxcdxxefxgx", "abcdefg") == 6 )
  }

  test("stringDistance should work where only substitutions are needed") {
    assert( StringDistance.stringDistance(  "a",   "b") == 1 )
    assert( StringDistance.stringDistance( "ab",  "ac") == 1 )
    assert( StringDistance.stringDistance( "ac",  "bc") == 1 )
    assert( StringDistance.stringDistance("abc", "axc") == 1 )
    assert( StringDistance.stringDistance("xabxcdxxefxgx", "1ab2cd34ef5g6") == 6 )
  }

  test("stringDistance should work where many operations are needed") {
    assert( StringDistance.stringDistance("example", "samples") == 3 )
    assert( StringDistance.stringDistance("sturgeon", "urgently") == 6 )
    assert( StringDistance.stringDistance("levenshtein", "frankenstein") == 6 )
    assert( StringDistance.stringDistance("distance", "difference") == 5 )
    assert( StringDistance.stringDistance("java was neat", "scala is great") == 7 )
  }

}

It tests several special cases as well as the general case. All we have to do is compile our StringDistance object and the StringDistanceSuite unit test, fire up the Scala interpreter, and run the test:

scala> (new StringDistanceSuite).execute()
Test Starting - StringDistanceSuite: stringDistance should work on empty strings
Test Succeeded - StringDistanceSuite: stringDistance should work on empty strings
Test Starting - StringDistanceSuite: stringDistance should work on equal strings
Test Succeeded - StringDistanceSuite: stringDistance should work on equal strings
Test Starting - StringDistanceSuite: stringDistance should work where only inserts are needed
Test Succeeded - StringDistanceSuite: stringDistance should work where only inserts are needed
Test Starting - StringDistanceSuite: stringDistance should work where only deletes are needed
Test Succeeded - StringDistanceSuite: stringDistance should work where only deletes are needed
Test Starting - StringDistanceSuite: stringDistance should work where only substitutions are needed
Test Succeeded - StringDistanceSuite: stringDistance should work where only substitutions are needed
Test Starting - StringDistanceSuite: stringDistance should work where many operations are needed
Test Succeeded - StringDistanceSuite: stringDistance should work where many operations are needed

scala>

Refactoring: Reduced Memory Usage

The code passes all the tests, so let’s take things a step further. One shortcoming of our implementation is that it can require a lot of memory. That array has to be of size (n+1)*(m+1) where n and m are the lengths of the two strings we’re comparing. If you want to apply the method to strings that are a few hundred characters long (or longer) then you’re starting to talk about some serious memory requirements. How can we reduce the amount of memory required? Can you think of a way?

Once we complete one row of the grid, we use it again as an input when we compute the next row. But after that we’re done with it. Why leave it to clutter the heap? Let’s tweak our implementation slightly. Try rewriting the method using only two rows. Fill the first row with the initial 0, 1, 2, etc. Then use one to compute the other over and over. Think about how you would implement this, then have a look at my solution below.

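  def stringDistance(s1: String, s2: String): Int = {
    def minimum(i1: Int, i2: Int, i3: Int) = min(min(i1, i2), i3)

    // Two rows with one slot per prefix of s1. This time s1 runs along
    // each row and s2 down the side, so the grid is transposed.
    var dist = ( new Array[Int](s1.length + 1),
                 new Array[Int](s1.length + 1) )

    // The "old" row starts out as the empty-prefix row 0, 1, 2, ...
    for (idx <- 0 to s1.length) dist._2(idx) = idx

    for (jdx <- 1 to s2.length) {
      val (newDist, oldDist) = dist
      newDist(0) = jdx
      for (idx <- 1 to s1.length) {
        newDist(idx) = minimum (
          oldDist(idx) + 1,
          newDist(idx-1) + 1,
          oldDist(idx-1) + (if (s1(idx-1) == s2(jdx-1)) 0 else 1)
        )
      }
      // The row just computed becomes the "old" row for the next pass
      dist = dist.swap
    }

    dist._2(s1.length)
  }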

This one uses a Pair (also called a Tuple2) containing two one-dimensional arrays, instead of a 2*(n+1) array. Pair happens to have the very handy “swap” method which we can use to trade out the rows when we’ve finished one and are ready to compute the next.
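A further tweak, not in the original version: with every operation costing 1, the distance from s1 to s2 equals the distance from s2 to s1, so either string can run along the rows. This sketch (the object name is hypothetical) lets the shorter string define the row length, so each of the two rows needs only min(n, m) + 1 slots, and it swaps the two arrays directly instead of using a tuple.

```scala
object TwoRowSketch {
  def stringDistance(s1: String, s2: String): Int = {
    // Unit costs make the distance symmetric, so the shorter string sets the row length.
    val (shorter, longer) = if (s1.length <= s2.length) (s1, s2) else (s2, s1)

    var prev = (0 to shorter.length).toArray // empty-prefix row: 0, 1, 2, ...
    var curr = new Array[Int](shorter.length + 1)

    for (j <- 1 to longer.length) {
      curr(0) = j
      for (i <- 1 to shorter.length) {
        val cost = if (shorter(i - 1) == longer(j - 1)) 0 else 1
        curr(i) = math.min(math.min(prev(i) + 1, curr(i - 1) + 1), prev(i - 1) + cost)
      }
      // The finished row becomes the input for the next one.
      val tmp = prev; prev = curr; curr = tmp
    }
    prev(shorter.length)
  }
}
```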

This is where our unit tests really show their worth. No need to wonder whether this new code really does work. We just recompile, run the tests again, and we can see that the code still gives us the expected results.
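Hand-picked test cases can miss things, so one extra safety net while refactoring is to compare the old and new implementations on many random inputs. This is a plain-Scala sketch using scala.util.Random (ScalaCheck automates the same idea with properties); the object and method names are hypothetical, and the two implementations are condensed copies of the grid and two-row versions above.

```scala
import scala.util.Random

object CrossCheckSketch {
  // Reference implementation: the full grid.
  def gridDistance(s1: String, s2: String): Int = {
    val d = Array.ofDim[Int](s1.length + 1, s2.length + 1)
    for (i <- 0 to s1.length) d(i)(0) = i
    for (j <- 0 to s2.length) d(0)(j) = j
    for (i <- 1 to s1.length; j <- 1 to s2.length)
      d(i)(j) = math.min(math.min(d(i - 1)(j) + 1, d(i)(j - 1) + 1),
                         d(i - 1)(j - 1) + (if (s1(i - 1) == s2(j - 1)) 0 else 1))
    d(s1.length)(s2.length)
  }

  // Refactored implementation: two rows, swapped after each pass.
  def twoRowDistance(s1: String, s2: String): Int = {
    var dist = (new Array[Int](s1.length + 1), new Array[Int](s1.length + 1))
    for (i <- 0 to s1.length) dist._2(i) = i
    for (j <- 1 to s2.length) {
      val (newDist, oldDist) = dist
      newDist(0) = j
      for (i <- 1 to s1.length)
        newDist(i) = math.min(math.min(oldDist(i) + 1, newDist(i - 1) + 1),
                              oldDist(i - 1) + (if (s1(i - 1) == s2(j - 1)) 0 else 1))
      dist = dist.swap
    }
    dist._2(s1.length)
  }

  // Random strings over a small alphabet, so matching letters are common.
  def randomString(rnd: Random, maxLen: Int): String =
    List.fill(rnd.nextInt(maxLen + 1))("abc"(rnd.nextInt(3))).mkString

  // The first pair of strings on which the two implementations disagree, if any.
  def firstMismatch(trials: Int, seed: Long = 42L): Option[(String, String)] = {
    val rnd = new Random(seed)
    Iterator.fill(trials)((randomString(rnd, 8), randomString(rnd, 8)))
      .find { case (a, b) => gridDistance(a, b) != twoRowDistance(a, b) }
  }
}
```

A None from firstMismatch(1000) means the refactored version agreed with the original on every generated pair; a Some hands you a concrete failing input to turn into a new unit test.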

Refactoring: Recursion, Kind Of

What are some other ways we could write this code? Can it be improved? I thought I would try to replace the iteration in the previous implementations with recursion. Here’s what that code looks like:

  def stringDistance(s1: String, s2: String): Int = {
    def newCost(ins: Int, del: Int, subst: Int, c1: Char, c2:Char) =
      Math.min( Math.min( ins+1, del+1 ), subst + (if (c1 == c2) 0 else 1) )

    def getNewCosts(s1: List[Char], c2: Char, delVal: Int, prev: List[Int] ): List[Int] = (s1, prev) match {
      case (c1 :: _ , substVal :: insVal :: _) =>
        delVal :: getNewCosts(s1.tail, c2, newCost(insVal, delVal, substVal, c1, c2), prev.tail)
      case _ => List(delVal)
    }

    def sd(s1: List[Char], s2: List[Char], prev: List[Int]): Int = s2 match {
      case Nil => prev.last
      case _ => sd( s1, s2.tail, getNewCosts(s1, s2.head, prev.head+1, prev) )
    }

    (s1, s2) match {
      case (`s2`, `s1`) => 0
      case (_, "") | ("", _) => max(s1.length, s2.length)
      case _ => sd(s1.toList, s2.toList, (0 to s1.length).toList)
    }
  }

It’s a pretty naïve implementation, actually. It just replaces the repetition of the two for loops with the recursion of the two methods sd and getNewCosts. The sd method is even tail-recursive, which lets the Scala compiler turn it into a loop. (getNewCosts is not, because it still has to prepend delVal after its recursive call returns, so a very long first string could still overflow the stack.) It does the same basic thing as the for loop version, though: it recurses through the characters of a row in getNewCosts, and through the rows of the grid in sd.
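The compiler can confirm a tail-recursion claim like this one: annotating a method with @tailrec (from scala.annotation) makes compilation fail if the method is not tail-recursive. This sketch applies it to both helpers. To make the row helper qualify, it builds the new row in reverse in an accumulator and flips it at the end, which is a change from the version above; the object and method names are hypothetical.

```scala
import scala.annotation.tailrec

object TailRecSketch {
  def stringDistance(s1: String, s2: String): Int = {
    // Computes the next row of costs from `prev` for the character c2.
    // `acc` holds the new row so far, in reverse, so acc.head is the cell to the left.
    @tailrec
    def nextRow(chars: List[Char], c2: Char, prev: List[Int], acc: List[Int]): List[Int] =
      (chars, prev) match {
        case (c1 :: cs, diag :: (rest @ (above :: _))) =>
          val cost = math.min(math.min(above + 1, acc.head + 1), diag + (if (c1 == c2) 0 else 1))
          nextRow(cs, c2, rest, cost :: acc)
        case _ => acc.reverse
      }

    // One call per character of s2; the recursive call is the last thing each call does.
    @tailrec
    def rows(remaining: List[Char], prev: List[Int]): Int = remaining match {
      case Nil        => prev.last
      case c2 :: tail => rows(tail, nextRow(s1.toList, c2, prev, List(prev.head + 1)))
    }

    rows(s2.toList, (0 to s1.length).toList)
  }
}
```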

It looks more complicated than the previous implementations. It’s harder to read. But it passes our unit tests, so we can be pretty sure it’s correct.

Refactoring: List Methods

After the last implementation, I thought it looked a little sloppy. I wondered whether I could improve the situation by using some of the many useful methods built into Scala’s List class. Here is the comparatively brief code that resulted:

  def stringDistance(s1: String, s2: String): Int = {
    def sd(s1: List[Char], s2: List[Char], costs: List[Int]): Int = s2 match {
      case Nil => costs.last
      case c2 :: tail => sd( s1, tail,
          // Walk the (above-left, above) cost pairs alongside the letters of s1,
          // building the new row in reverse; a.head is the cell to the left
          costs.zip(costs.tail).zip(s1).foldLeft(List(costs.head+1)) {
            case (a, ((rep,ins), chr)) =>
              Math.min( Math.min( ins+1, a.head+1 ), rep + (if (chr==c2) 0 else 1) ) :: a
          }.reverse
        )
    }
    sd(s1.toList, s2.toList, (0 to s1.length).toList)
  }

Like I say, it’s brief. Those List methods give you a lot of mileage.
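One possible further step, not from the original post: scanLeft works like foldLeft but returns every intermediate accumulator, starting with the initial value. Those intermediate values are exactly the cells of the new row, in order, so the reverse call disappears. In this sketch the object name is hypothetical, and the rows are laid out as above, one entry per character of s1 plus one.

```scala
object ScanLeftSketch {
  def stringDistance(s1: String, s2: String): Int = {
    val firstRow = (0 to s1.length).toList
    // Each character of s2 turns the previous row into the next one.
    val lastRow = s2.toList.foldLeft(firstRow) { (prev, c2) =>
      // Pair each cell with its right neighbour (above-left, above) and the s1 letter;
      // scanLeft threads the cell to the left through as its accumulator.
      prev.zip(prev.tail).zip(s1.toList).scanLeft(prev.head + 1) {
        case (left, ((diag, above), c1)) =>
          math.min(math.min(above + 1, left + 1), diag + (if (c1 == c2) 0 else 1))
      }
    }
    lastRow.last
  }
}
```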

Refactoring: Real Recursion

The more I looked at my previous attempt at a recursive solution, the more I realized how hare-brained it was. It wasn’t real recursion. It was just iteration using the stack. So I went back to the drawing board. If I want to know the final answer, the value in the bottom right position, how do I get it? I apply my three rules to the above, left, and above-left positions. How do I get those positions? Apply the rules again. That’s real recursion. Here’s my first stab at it:

  def stringDistance(s1: String, s2: String): Int = {
    def min(a:Int, b:Int, c:Int) = Math.min( Math.min( a, b ), c)
    def sd(s1: List[Char], s2: List[Char]): Int = (s1, s2) match {
      case (_, Nil) => s1.length
      case (Nil, _) => s2.length
      case (c1::t1, c2::t2)  => min( sd(t1,s2) + 1, sd(s1,t2) + 1,
                                     sd(t1,t2) + (if (c1==c2) 0 else 1) )
    }
    sd( s1.toList, s2.toList )
  }

Now, THAT is a nice looking recursive function. That’s more like it. See the pattern match block? If we try to convert from any string to the empty string or from the empty to a non-empty string then we just use the string length. You see? That takes care of the positions along the top and left of the grid. All the others are determined in the last case. That last case just applies our three rules. That is so short and simple. It’s a thing of beauty.

The only problem? It doesn’t work.

It’s technically correct. It will return correct answers … eventually. Or, rather, I think it will. I can’t be sure because it’s too slow to complete my unit tests! For anything but very short inputs (4 or 5 characters), the function takes a long time to return. Why? Let’s look at how many recursive calls are made for some inputs.

If we use strings “a” and “b” we pass over the “(_, Nil)” and “(Nil, _)” cases in the first call to function sd, because both our strings (Lists, actually) are non-empty. This results in three more calls to sd. Each of these three calls includes an empty List of characters, so there is no more recursion. That’s a total of four calls to sd for strings “a” and “b”.

What about “ab” and “xy”? Think about it for a moment. Step through the function in your head. How many calls to sd will there be for “ab” and “xy”?

Have you done it? I count 19. What about “abc” and “xyz”? I’ll save you the trouble. It’s 94. For length 4 strings it’s 481. For length 5 it’s 2524. Length 6 is 13,483. Then 73 thousand, then 400 thousand, then 2 million and so forth. Why so many calls? Each call spawns three more, and their work overlaps: sd(t1, t2) is called directly, but it is also reached again through both sd(t1, s2) and sd(s1, t2). Nothing is remembered between calls, so positions near the top left of the grid (the ones furthest from the final answer) are computed and recomputed many times.
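To check these counts yourself, here is the recursive function above instrumented with a counter. The counter and the object name are additions for illustration; the matching logic is unchanged.

```scala
object CallCounter {
  // Runs the naive recursive sd and returns how many times sd was invoked.
  def countCalls(s1: String, s2: String): Long = {
    var calls = 0L
    def min(a: Int, b: Int, c: Int) = math.min(math.min(a, b), c)
    def sd(s1: List[Char], s2: List[Char]): Int = {
      calls += 1 // the only addition to the original function
      (s1, s2) match {
        case (_, Nil) => s1.length
        case (Nil, _) => s2.length
        case (c1 :: t1, c2 :: t2) => min(sd(t1, s2) + 1, sd(s1, t2) + 1,
                                         sd(t1, t2) + (if (c1 == c2) 0 else 1))
      }
    }
    sd(s1.toList, s2.toList)
    calls
  }
}
```

The counts depend only on the lengths of the two strings, not on their letters, because this version always makes all three recursive calls.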

There is a way to get around this, of course. You probably already have some ideas. We’re going to do something called memoization. When you memoize a function, you make it remember results that it computed previously without having to actually recompute them. I’ll do that by caching results in a map. The map’s key is a Pair of List[Char]s, the inputs to my inner function sd, and its data is an Int, the return type of sd. I will modify sd to first check the map to see if the result for the current parameters has already been cached. If so, we simply return it. If not, we compute the value, cache it, and return it.

  def stringDistance(s1: String, s2: String): Int = {
    // Cache of results already computed, keyed by the pair of inputs to sd
    val memo = scala.collection.mutable.Map[(List[Char],List[Char]),Int]()
    def min(a:Int, b:Int, c:Int) = Math.min( Math.min( a, b ), c)
    def sd(s1: List[Char], s2: List[Char]): Int = {
      // Only compute (and cache) the result if these inputs haven't been seen before
      if (!memo.contains((s1,s2)))
        memo((s1,s2)) = (s1, s2) match {
          case (_, Nil) => s1.length
          case (Nil, _) => s2.length
          case (c1::t1, c2::t2)  => min( sd(t1,s2) + 1, sd(s1,t2) + 1,
                                         sd(t1,t2) + (if (c1==c2) 0 else 1) )
        }
      memo((s1,s2))
    }

    sd( s1.toList, s2.toList )
  }

There. That uglies up my function somewhat, but at least it’s usable now. And it passes my unit tests, so I’m reasonably assured that it’s right.

Memoization only works if you expect the same result for each identical function call. If your function takes input from stdin, for example, you can’t memoize that. Or if it has a random component. Or if, for any other reason, its return value is not always the same for the same inputs. You can memoize functions in different ways. There’s a post on Michid’s Weblog about a more general solution, a memoizing class which wraps existing functions to give you a memoized version.
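As a taste of that more general approach, here is one way to separate the caching from the distance logic. This is a hypothetical helper written for this post, not the code from Michid’s Weblog: memoFix takes a function written in “open recursion” style, one that receives the function to use for its own recursive calls, and hands it a caching version of itself. That detail matters, because wrapping an ordinary recursive function would only cache the outermost call.

```scala
import scala.collection.mutable

object MemoSketch {
  // Hypothetical helper: caches every result of `f`, including the recursive
  // calls, because `f` is handed the caching version of itself to recurse through.
  def memoFix[A, B](f: (A => B) => A => B): A => B = {
    val cache = mutable.Map.empty[A, B]
    def memoized(a: A): B = cache.get(a) match {
      case Some(b) => b
      case None =>
        val b = f(memoized)(a)
        cache(a) = b
        b
    }
    memoized
  }

  def stringDistance(s1: String, s2: String): Int = {
    // The same three rules as the memoized sd, with the caching moved into memoFix.
    val sd = memoFix[(List[Char], List[Char]), Int] { rec => {
      case (l1, Nil) => l1.length
      case (Nil, l2) => l2.length
      case (l1 @ (c1 :: t1), l2 @ (c2 :: t2)) =>
        math.min(math.min(rec((t1, l2)) + 1, rec((l1, t2)) + 1),
                 rec((t1, t2)) + (if (c1 == c2) 0 else 1))
    }}
    sd((s1.toList, s2.toList))
  }
}
```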

PHEW!

Ok, I’ve tried to keep my rambling to a minimum in this post but it’s still a doozy. The things I wanted to get across are:

  • The usefulness of Levenshtein distance in solving a variety of problems.
  • How to understand the Levenshtein distance algorithm and why it works.
  • How to use unit tests to improve your code while maintaining some assurance that the new code still has the correct behavior.
  • Some of the different ways of implementing Levenshtein.

In the end, I think I like the second implementation (the one that swaps the two rows) and the last implementation best. In some informal performance tests, the second one offered a good mix of speed and simplicity. The last one, the memoized recursive one, appeals to me because it is in a more functional style and still has respectable performance.

Copyright © 2009 Matthew Jason Malone

One of my favorite functional programming tricks is folding. The fold left and fold right functions can do a lot of complicated things with a small amount of code. Today, I’d like to (1) introduce folding, (2) make note of some surprising, nay, shocking fold behavior, (3) review the folding code used in Scala’s List class, and (4) make some presumptuous suggestions on how to improve List.

Update: I’ve created a new post in which I list lots and lots of foldLeft examples in case you’d like to learn more about what folding can accomplish.

Know When to Hold ‘Em, Know When to Fold ‘Em

In case you’re not familiar with folding, I’ll describe it as briefly as I can.

Here’s the signature of the foldLeft function from List[A], a list of items of type A:

def foldLeft[B](z: B)(f: (B, A) => B): B

Firstly, foldLeft is a curried function (so is foldRight). If you don’t know about currying, that’s OK; this function just takes its two parameters (z and f) in two sets of parentheses instead of one. Currying isn’t the important part anyway.

The first parameter, z, is of type B, which is to say it can be different from the list contents type. The second parameter, f, is a function that takes a B and an A (a list item) as parameters, and it returns a value of type B. So the purpose of function f is to take a value of type B, combine it with a list item, and return a new value of type B.

The foldLeft function goes through the whole List, from head to tail, and passes each value to f. For the first list item, z is used as the first parameter to f. For the second list item, the result of the first call to f is used as the B parameter, and so on. The result of the last call to f is the result of foldLeft.

For example, say we had a list of Ints 1, 2, and 3. We could call foldLeft("X")((b,a) => b + a). For the first item, 1, the function we define would add string “X” to Int 1, returning string “X1”. For the second list item, 2, the function would add string “X1” to Int 2, returning “X12”. And for the final list item, 3, the function would add “X12” to 3 and return “X123”.

Here are a few more examples.

list.foldLeft(0)((b,a) => b+a)
list.foldLeft(1)((b,a) => b*a)
list.foldLeft(List[Int]())((b,a) => a :: b)

The first line is super simple. It’s almost like the example I described above, but the z value is the number 0 instead of string “X”. This fold combines the elements of the list by addition instead of concatenation. So the fold returns the sum of all Ints in the list. Line 2 combines them through multiplication. Do you see why the z value is 1 in this case?

Line 3 is a little more complex. Can you guess what it does? It starts out with an empty list of Ints and adds each item to the accumulator (we call the b parameter of our function the accumulator because it accumulates data from each of our list items). Because it starts with the head and adds to the beginning of the accumulator list until it gets to the last item of the original list, it returns the original list in reverse order.
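Here are those three folds, plus the “X” example from earlier, as a runnable sketch. The list 1, 2, 3, 4 is just an arbitrary input chosen for illustration. The comments also answer the question about z: each fold starts from the value that leaves the result unchanged, 0 for addition and 1 for multiplication.

```scala
object FoldExamples {
  val list = List(1, 2, 3, 4)

  // Start with "X" and append each element in turn: "X1", "X12", ...
  val concatenated = list.foldLeft("X")((b, a) => b + a)
  // Sum: start from 0, since adding 0 changes nothing
  val sum = list.foldLeft(0)((b, a) => b + a)
  // Product: start from 1, since multiplying by 1 changes nothing
  val product = list.foldLeft(1)((b, a) => b * a)
  // Prepending each element to the accumulator reverses the list
  val reversed = list.foldLeft(List[Int]())((b, a) => a :: b)
}
```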

The foldRight function works in much the same way as foldLeft. Can you guess the difference? You got it. It starts at the end of the list and works its way up to the head.
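To see the difference in direction, give both folds a function that records the order in which it sees the elements. This is a small sketch; note that foldRight’s function takes the list item first and the accumulator second, the reverse of foldLeft’s (b, a) order.

```scala
object FoldDirection {
  val list = List(1, 2, 3)

  // foldLeft visits 1, then 2, then 3
  val leftOrder = list.foldLeft("X")((acc, a) => s"$acc$a")
  // foldRight visits 3, then 2, then 1; its function takes (element, accumulator)
  val rightOrder = list.foldRight("X")((a, acc) => s"$acc$a")
}
```

leftOrder comes out as "X123" and rightOrder as "X321".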

Folds can be used for MUCH more than I’ve shown here. With folds, you can solve lots of different problems with a standard construct. You should read up on them if you’re just starting out in functional programming.

All That Glitters Is Not Fold

Now for the moment you’ve been waiting for. Fold’s dirty little secret! The following is taken from a Scala interpreter session.

scala> var shortList = 1 to 10 toList
shortList: List[Int] = List(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)

scala> var longList = 1 to 325000 toList
longList: List[Int] = List(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 
15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, ...

scala> shortList.foldLeft("")((x,y) => "X")
res1: java.lang.String = X

scala> shortList.foldRight("")((x,y) => "X")
res2: java.lang.String = X

scala> longList.foldLeft("")((x,y) => "X")
res3: java.lang.String = X

scala> longList.foldRight("")((x,y) => "X")
java.lang.StackOverflowError
        at scala.List.foldRight(List.scala:1079)
        at scala.List.foldRight(List.scala:1081)
        at scala.List.foldRight(List.scala:1081)
        at scala.List.foldRight(List.scala:1081)
        at scala.List.foldRight(List.scala:1081)
        at scala.List.foldRight(List.scala:1081)
        at scala.List.foldRight(List.scala:1081)
        at scala.List.foldRight(List.scala:1081)
        at scala.List.foldRig...

We created two lists: shortList with 10 items, and longList with 325,000 items. Then we perform a trivial foldLeft and foldRight on shortList. It’s trivial because the passed-in function always returns the string “X”; it doesn’t even use the list data.

Then we do a foldLeft on longList. This goes off without a hitch. Finally we try to do a foldRight, the same foldRight that succeeded on the shorter list, and it fails! The foldLeft worked. Why didn’t the foldRight work? It’s a perfectly reasonable call against a perfectly reasonable List. Something funny is going on here.

The error message says there was a stack overflow, and the stack trace shows a long list of calls at List.scala:1081. If you’ve read my post about tail-recursion, then you probably suspect that some recursive code is to blame.

Let’s look into List.scala, maybe the single most important Scala source file.

Fool’s Fold

Without further ado, here’s the code for foldLeft and foldRight from List.scala:

override def foldLeft[B](z: B)(f: (B, A) => B): B = {
  var acc = z
  var these = this
  while (!these.isEmpty) {
    acc = f(acc, these.head)
    these = these.tail
  }
  acc
}

override def foldRight[B](z: B)(f: (A, B) => B): B = this match {
  case Nil => z
  case x :: xs => f(x, xs.foldRight(z)(f))
}

Wow! Those two definitions are very different!

The foldLeft function is the one that worked for short and long lists. Can you see why? It isn’t head-recursive. In fact, it isn’t recursive at all: it is implemented as a while loop. On each iteration, the next list item is passed to the function f and the accumulator (called acc) is updated. When there are no more list items, the accumulator is returned. No recursion means no stack overflows.

The foldRight function is implemented in a totally different way. If the list is empty, the z parameter is returned. Otherwise, a recursive call is made on the tail (the whole list minus the first item) of this list, and that result is passed to the function f. Study the foldRight definition. Do you understand how it works? It’s an elegant recursive solution, and the code really is quite pretty, but it’s not tail recursive so it fails for large lists.

Why didn’t Mr Odersky just write foldRight using a while loop, too? Then this problem wouldn’t exist, right? The reason is that Scala’s List is implemented as a singly-linked list. Each list element has access to the next item in the list, but not to the previous item. You can only traverse a list in one direction! This works fine for foldLeft, which goes from head to tail, but foldRight has to start at the end of the list and work its way forward to the head. If foldRight uses recursion, it must recurse all the way to the end and then use the results of those recursive calls as the accumulator passed into function f.

See? The results of the recursive call must be used for further calculation, so the recursive call can’t be the last thing that happens, so it can’t be written as a tail-recursive function. If you don’t know what I’m talking about, read my introduction to tail-recursion.

Out With The Fold, In With The New

So is that it for foldRight? Is it hopeless? I say no!

There is a way to get the same result as foldRight, but using foldLeft. Can you guess what it is? Here’s how:

list.foldRight("X")((a,b) => a + b)
list.reverse.foldLeft("X")((b,a) => a + b)

These two lines are equivalent! They give the same result no matter what’s in list. Since foldRight processes list elements from last to first, that’s the same as processing the reversed list from first to last.
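To check the equivalence yourself, here is a small sketch assuming list is a List[String] (the wrapper object and method names are hypothetical). Note that the two function literals take their parameters in opposite orders: foldRight passes (element, accumulator), while foldLeft passes (accumulator, element).

```scala
object FoldEquivalence {
  // foldRight applies f(element, accumulator), starting from the last element.
  def viaFoldRight(list: List[String]): String =
    list.foldRight("X")((a, b) => a + b)

  // Reversing first makes foldLeft visit the elements last-to-first as well;
  // foldLeft's function receives (accumulator, element), hence (b, a).
  def viaReverseFoldLeft(list: List[String]): String =
    list.reverse.foldLeft("X")((b, a) => a + b)
}
```

Both `FoldEquivalence.viaFoldRight(List("a", "b", "c"))` and `FoldEquivalence.viaReverseFoldLeft(List("a", "b", "c"))` return `"abcX"`.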

Here are three possible implementations of foldRight that could replace the current one.

def foldRight[B](z: B)(f: (A, B) => B): B = 
  reverse.foldLeft(z)((b,a) => f(a,b))

def foldRight[B](z: B)(f: (A, B) => B): B =
  if (length > 50) reverse.foldLeft(z)((b,a) => f(a,b))
  else             originalFoldRight(z)(f)

def foldRight[B](z: B)(f: (A, B) => B): B =
  try {
    originalFoldRight(z)(f)
  } catch {
    case e1: StackOverflowError => reverse.foldLeft(z)((b,a) => f(a,b))
  }

The first one simply replaces the original recursive logic with the equivalent call to reverse and foldLeft. Why wasn’t foldRight implemented this way to begin with? It may be, in part, that the authors thought the extra overhead of reversing the list was unwarranted. To me, it doesn’t seem that bad. The original foldRight and foldLeft functions are O(n), meaning they run in an amount of time roughly proportional to the number of items in the list. If you look at the source for the reverse function, you’ll see it’s also O(n). So running reverse followed by foldLeft is still O(n); the main extra cost is allocating a reversed copy of the list.
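As a rough sketch of the first option outside the List class (SafeFold is a hypothetical helper, not part of the standard library), the reverse-based fold can be written as a plain function and tried on a list far longer than the recursive version can handle.

```scala
object SafeFold {
  // foldRight expressed through reverse + foldLeft: two O(n) passes and no
  // recursion, so stack depth stays constant however long the list is.
  def foldRight[A, B](list: List[A], z: B)(f: (A, B) => B): B =
    list.reverse.foldLeft(z)((acc, elem) => f(elem, acc))
}
```

`SafeFold.foldRight(List(1, 2, 3), List.empty[Int])(_ :: _)` rebuilds `List(1, 2, 3)`, and summing a million-element list with it needs no more stack than summing three elements.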

The second implementation is a compromise. It uses the original recursive version of foldRight (referred to as originalFoldRight in the above code) only when the list has 50 or fewer elements. The reverse.foldLeft is used for lists longer than 50 elements. The 50 is arbitrary, just a guess at a sensible limit on the number of recursive calls to allow.
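The compromise can be sketched the same way. This is an illustration under assumptions: HybridFold, recursiveFoldRight and MaxRecursionDepth are names made up here, and the cutoff of 50 is as arbitrary as in the text. One small change from the version above is that it calls lengthCompare instead of length, so deciding which path to take only walks the list as far as the cutoff rather than counting every element.

```scala
object HybridFold {
  // Hypothetical cutoff, playing the role of the 50 in the text.
  val MaxRecursionDepth = 50

  // The elegant head-recursive version: one stack frame per element.
  def recursiveFoldRight[A, B](list: List[A], z: B)(f: (A, B) => B): B = list match {
    case Nil     => z
    case x :: xs => f(x, recursiveFoldRight(xs, z)(f))
  }

  // Lists of up to MaxRecursionDepth elements take the recursive path;
  // longer ones fall back to reverse + foldLeft.
  def foldRight[A, B](list: List[A], z: B)(f: (A, B) => B): B =
    if (list.lengthCompare(MaxRecursionDepth) > 0)
      list.reverse.foldLeft(z)((acc, elem) => f(elem, acc))
    else
      recursiveFoldRight(list, z)(f)
}
```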

The third implementation tries the original foldRight logic first and if the call stack overflows then it uses reverse.foldLeft. This solution is, of course, completely ridiculous, but even this would be better than a foldRight which sometimes crashes your program.

That’s All, Folds!

As I pointed out before, the reverse.foldLeft implementation of foldRight is O(n), same as the original recursive version. The original foldRight may work just fine when your Scala application is young and working with small data sets. Over time more customers are added, more products are created, more orders are placed, and then one day, *POOF*, a runtime error! It’s a ticking time-bomb.

As you may well guess, I would like to see the reverse.foldLeft logic used instead of the recursive version. That would prevent the stack overflow errors. But I would settle for just deprecating foldRight. It would be better to eliminate foldRight and force the coder to work around it than to leave it in its current state. In fact, I don’t think any head-recursive functions belong in the List class.

Do any readers have any insight into why foldRight is coded the way it is?

Don’t forget to subscribe to my RSS feed, or follow this blog on Twitter.
Copyright © 2009 Matthew Jason Malone

In my last post (Tail-Recursion Basics In Scala), I went on and on about how Scala, as a functional programming language, was written to accommodate recursion. When you need to do a task over and over, the procedural way is to iterate, and the functional way is to recurse. That’s what I thought. But I was browsing the source for the Scala List class and I noticed something weird. It’s chock full of iteration and procedural code!

Well, I don’t have to tell you I was a little disappointed. I never claimed to have unreservedly swallowed the FP hype, but I figured the authors of scala.List must surely practice what they preach. I’m puzzled. To show you what I mean, here are a few methods written in a procedural style just as they appear in scala.List, alongside my attempt to write them in a functional, tail-recursive fashion. This isn’t an exhaustive list by any means. There are tons of methods in class scala.List like these.

Mark the following! I am not trying to correct the authors of scala.List or point their code out as inferior. On the contrary. They know what they are doing. I do not. They implemented a general-use functional/OO programming language. I have not. I am merely wondering aloud why scala.List was implemented the way it was.

I’m using http://lampsvn.epfl.ch/trac/scala/browser/scala/tags/R_2_7_2_RC2/src/library/scala/List.scala as my source.

List.length

I used a length function as an example in my last post on recursion. The length method of List returns a count of the items in the List. First, here’s the actual code for the length method:

sealed abstract class List[+A] extends Seq[A] with Product {
...
  def length: Int = {
    var these = this
    var len = 0
    while (!these.isEmpty) {
      len += 1
      these = these.tail
    }
    len
  }
...
}

See why I called this code procedural? Those vars and that while loop. Vars are used for variables whose values change. There’s nothing wrong with it, but it’s a feature of procedural code. Iterative loops are also more procedural than functional. You could argue that the while is not a procedural loop, but a functional closure. But I know iteration when I see it. Why wasn’t it written using recursion? Here’s how it could be implemented in a more functional style.

object List {
...
  def length(list: List[_], len: Int): Int = {
    if (list.isEmpty) len
    else length(list.tail, len + 1)
  }
...
}

sealed abstract class List[+A] extends Seq[A] with Product {
...
  def length: Int = List.length(this, 0)
...
}

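There is a List class as well as a List object. I put the recursive function in the List companion object because I couldn’t get the tail-recursion optimization when I implemented it as a recursive member function within class List. Scala only turns a self-recursive call into a loop when the method can’t be overridden (it is final or private, or it belongs to an object), and a public, non-final method of class List can be overridden by its subclasses.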
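Later Scala versions added the @tailrec annotation (scala.annotation.tailrec), which turns this guesswork into a compile-time check: the compiler reports an error if the annotated method can’t be optimized. Here is a minimal sketch, with ListOps as a hypothetical stand-in for the List object and a default value for the accumulator as my own convenience.

```scala
import scala.annotation.tailrec

object ListOps {
  // @tailrec makes compilation fail unless the recursive call compiles to a loop.
  // That works here because a method of an object can't be overridden.
  @tailrec
  def length(list: List[_], len: Int = 0): Int =
    if (list.isEmpty) len
    else length(list.tail, len + 1)
}
```

Because the recursion compiles to a loop, `ListOps.length(List.fill(1000000)(0))` returns 1000000 without overflowing the stack.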

List.indices

The indices method returns a List of the zero-based indices of a List. If a list has 4 items, its indices method returns List(0, 1, 2, 3). Once again, we’ll first look at the actual code for the indices method:

sealed abstract class List[+A] extends Seq[A] with Product {
...
  def indices: List[Int] = {
    val b = new ListBuffer[Int]
    var i = 0
    var these = this
    while (!these.isEmpty) {
      b += i
      i += 1
      these = these.tail
    }
    b.toList
  }
...
}

As before, this code steps through the list in a while loop. Here’s how it could be done using recursion:

object List {
...
  // Prepend each index (cheap on a singly-linked list), then reverse once at the end.
  def indices(list: List[_], idxList: List[Int], curIdx: Int): List[Int] = {
    if (list.isEmpty) idxList.reverse
    else List.indices(list.tail, curIdx :: idxList, curIdx + 1)
  }
...
}

sealed abstract class List[+A] extends Seq[A] with Product {
...
  def indices: List[Int] = List.indices(this, Nil, 0)
...
}
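Here is a standalone, runnable version of the recursive indices (IndicesOps and the default parameter values are my own choices, for illustration). It prepends each index, which is cheap on a singly-linked list, and reverses the accumulator once at the end.

```scala
import scala.annotation.tailrec

object IndicesOps {
  // Build the index list backwards with ::, then reverse it when the input runs out.
  @tailrec
  def indices(list: List[_], acc: List[Int] = Nil, curIdx: Int = 0): List[Int] =
    if (list.isEmpty) acc.reverse
    else indices(list.tail, curIdx :: acc, curIdx + 1)
}
```

`IndicesOps.indices(List("a", "b", "c", "d"))` returns `List(0, 1, 2, 3)`.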

List.last

This is a very simple method. It just returns the last item in the List. Again, here’s the actual code for the List.last method:

sealed abstract class List[+A] extends Seq[A] with Product {
...
  override def last: A =
    if (isEmpty) throw new Predef.NoSuchElementException("Nil.last")
    else {
      var cur = this
      var next = this.tail
      while (!next.isEmpty) {
        cur = next
        next = next.tail
      }
      cur.head
    }
...
}

Here’s what it might look like in a functional style.

object List {
...
  def last[A](list: List[A]): A =
    if (list.isEmpty) throw new Predef.NoSuchElementException("Nil.last")
    else if (list.tail.isEmpty) list.head   // only one element left: it is the last
    else last(list.tail)                     // otherwise keep walking the tail
...
}

sealed abstract class List[+A] extends Seq[A] with Product {
...
  override def last: A = List.last(this)
...
}
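The same idea as a self-contained sketch (LastOps is a hypothetical name), with @tailrec confirming that the recursion becomes a loop. Like the library version, it throws NoSuchElementException for an empty list.

```scala
import scala.annotation.tailrec

object LastOps {
  // Walk the list until only one element remains; that element is the last.
  @tailrec
  def last[A](list: List[A]): A =
    if (list.isEmpty) throw new NoSuchElementException("Nil.last")
    else if (list.tail.isEmpty) list.head
    else last(list.tail)
}
```

`LastOps.last(List(1, 2, 3))` returns 3, and it handles a million-element list without growing the stack.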

List.foreach

Ok, last one. This method runs a function once for each item in the List, passing the item as a parameter each time. Here’s the real code:

sealed abstract class List[+A] extends Seq[A] with Product {
...
  final override def foreach(f: A => Unit) {
    var these = this
    while (!these.isEmpty) {
      f(these.head)
      these = these.tail
    }
  }
...
}

And, again, here’s the code rendered in a more functional style.

object List {
...
  def foreach[A](list: List[A], f: A => Unit): Unit = {
    if (!list.isEmpty) {
      f(list.head)
      foreach(list.tail, f)
    }
  }
...
}

sealed abstract class List[+A] extends Seq[A] with Product {
...
  final override def foreach(f: A => Unit) = List.foreach(this, f)
...
}
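And a runnable sketch of the recursive foreach (ForeachOps is a hypothetical name). The recursive call is the last expression evaluated in the if branch, so @tailrec accepts it and the recursion compiles to a loop.

```scala
import scala.annotation.tailrec

object ForeachOps {
  // Apply f to the head, then recurse on the tail; nothing remains to be
  // done after the recursive call returns, so it is a tail call.
  @tailrec
  def foreach[A](list: List[A], f: A => Unit): Unit =
    if (list.nonEmpty) {
      f(list.head)
      foreach(list.tail, f)
    }
}
```

For example, `ForeachOps.foreach(List(1, 2, 3), (x: Int) => println(x))` prints 1, 2 and 3 on separate lines.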

So What?

So what? What does it matter that while loops are used instead of tail recursion? Well, it doesn’t particularly matter. Not to me, anyway. But it’s surprising that proponents, nay, authors of a functional language would write (perfectly good) procedural code like this.

I don’t know why. Maybe it’s for the benefit of the Java coders who Scala’s creators hope will become early adopters. Maybe the functions in the List object are frowned upon for some design reason. More likely it’s some technical issue that my inexperienced eye can’t discern. Is there some reason my examples would not work? If you know, please share.

Update: By the way, Debasish Ghosh wrote a related article today. It includes a comparison of the List implementation in Erlang and Scala.

Copyright © 2008 Matthew Jason Malone