Here’s a three-for-one special for you: A post about implementing the Levenshtein string distance algorithm in Scala AND refactoring it from an imperative style to a functional style AND I even throw in a short lesson on memoization. To make sure that our refactoring is correct and preserves the expected behavior, I’ll unit test the code along the way using ScalaTest. ScalaCheck, JUnit, or TestNG would work just as well, but I used ScalaTest.

First Things First

“What, exactly,” some of you may be asking, “is Levenshtein string distance? Some kind of Teutonic tailoring terminology?” Not at all. It’s a way of measuring how alike or different two strings of symbols are. For example, the string “sturgeon” is a lot more similar to “surgeon” than to “urgently”. “Sturgeon” and “surgeon” are only a single letter (t) apart. They have a string distance of 1. “Sturgeon” and “urgently” share some letters, but each has some letters not present in the other. So what’s their string distance? It’s not so obvious now.

String distance is useful. The use that most quickly springs to mind is spelling correction. If I type “computwr”, a string distance algorithm could tell me that “computwr” is very close to the dictionary word “computer”. But there’s more to it than that. There are a lot of fuzzy problems in which you want to find which two sets of complex data are most similar. One way to solve a problem like that is to encode it into a string compare using string distance. For example, you could write a program to find pen strokes in an image and encode their shapes (up, curve left, down, angle right, etc) as a string which could be compared to known encodings for handwritten letters. By finding the closest matches you can create a handwriting recognition program. String distance is useful in DNA analysis, of course, recognizing patterns in signals, and a host of other situations.

The Levenshtein string distance algorithm was developed by Vladimir Levenshtein in 1965. It can easily tell us the distance from “sturgeon” to “urgently”. This algorithm breaks down string transformation into three basic operations: adding a character, deleting a character, and replacing a character. Each of these operations is assigned a cost, usually a cost of 1 for each operation. Leaving a character unchanged has a cost of 0. So to go from “surgeon” to “sturgeon” you leave the “s” unchanged for a cost of 0. Then you add a “t” for a cost of 1. All the other letters remain unchanged, so the total cost is 1, just as we expected.

To change “sturgeon” to “urgently” is harder. They have the same number of letters, so we could just do a replacement on each one for a distance of 8. But is that the shortest distance? What if we try to re-use that “urge” from “sturgeon”? Can we re-use the “n”? Does that help? What about the “t”? We need an algorithm that we can follow.

The Grid

What we need is a way to find the cheapest combination of operations which changes the first string to the second. That’s the Levenshtein algorithm. It works like this. Write the first string vertically from top to bottom. To the right of each letter write 1, 2, 3, etc. Write a 0 above the number 1. Then write the second string horizontally and again add the numbers 1, 2, 3 etc. to the right of the 0. It will look something like this:

    u r g e n t l y
  0 1 2 3 4 5 6 7 8
s 1
t 2
u 3
r 4
g 5
e 6
o 7
n 8

That’s the first step. The grid remains un-filled-in at this point. Can you guess what the grid positions are going to hold? They are going to hold the cost to convert the various prefixes of “sturgeon” to the various prefixes of “urgently”. The position at the intersection of row r4 and column r2, for example, will contain the cost to convert “stur” to “ur”. We fill in these positions (left to right, then top to bottom) with the smallest of the following three numbers:

  • The number above the current position plus one.
  • The number to the left of the current position plus one.
  • The above-left number if row letter and column letter are the same, or the above-left number plus one otherwise.

When we finish, the bottom right corner contains the cost to convert the first string to the second.
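As a minimal sketch of the three rules above, the value of a single grid position can be computed from its three neighbors like this. The names CellRule and cellCost, and the parameter names, are made up for this illustration.

```scala
object CellRule {
  // Cost for one grid position, given its three already-filled neighbors.
  // rowChar and colChar are the letters labelling the current row and column.
  def cellCost(above: Int, left: Int, aboveLeft: Int,
               rowChar: Char, colChar: Char): Int = {
    // Replacing a letter with itself is free; any other replacement costs 1.
    val replaceCost = if (rowChar == colChar) 0 else 1
    math.min(math.min(above + 1, left + 1), aboveLeft + replaceCost)
  }
}
```

For the first open position in the grid (row s, column u) the neighbors are 1, 1, and 0, so CellRule.cellCost(1, 1, 0, 's', 'u') gives 1.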

But Why? (Understanding the Algorithm)

Those are some shockingly simple rules! Let’s examine how they translate into string distance.

First, what are those numbers 0 to n that we write along the top and left? They’re not just indices. The numbers along the top represent the cost to convert the empty string to the various prefixes of “urgently”. The cost is 0 to convert “” to “”, 1 to convert “” to “u”, 2 to convert “” to “ur”, and so on. The numbers on the left are the cost to convert the prefixes of “sturgeon” to the empty string. “” to “” is 0, “s” to “” is 1, and so forth. These costs are obvious. The only way to convert “” to “urgently” is to add eight letters. There’s nothing to delete, nothing to replace. The only way to convert “sturgeon” to “” is to delete all eight letters.

Each position, as we have established, represents the cost to convert the string of characters down the left side ending in the current row’s character into the string of characters along the top ending at the current column’s character. Put another way, for any given position let’s call the current row’s letter A, and the current column’s letter B. If we use a colon (:) to indicate string concatenation then the beginning string, the one along the left of the grid, can be written Prefix1:A. Likewise, the target string, the one along the top of the grid, can be written Prefix2:B. So for our example word “sturgeon” if we look at a position in row e6 then Prefix1 is “sturg” and the final letter, which we’re calling A, is “e”.

So, speaking in terms of Prefix1, Prefix2, A, and B, we use the following inputs:

  • The cost to change Prefix1 to Prefix2:B (the number above the current position).
  • The cost to change Prefix1:A to Prefix2 (the number to the left).
  • The cost to change Prefix1 to Prefix2 (the above-left number).

If we know the costs of these three conversions we can find the cost to change Prefix1:A to Prefix2:B using this logic:

  • We know that if converting Prefix1 to Prefix2:B has a cost of X, then Prefix1:A can be converted to Prefix2:B for the cost of X plus the cost of deleting A, or X + 1.
  • We know that if converting Prefix1:A to Prefix2 has a cost of Y, then Prefix1:A can be converted to Prefix2:B for the cost of Y plus the cost of adding B, or Y + 1.
  • We know that if converting Prefix1 to Prefix2 has a cost of Z, then Prefix1:A can be converted to Prefix2:B for the cost of Z plus the cost of replacing A with B, or Z + (0 or 1 depending on whether A = B).

If you understand the above logic, then you understand this really neat algorithm. It’s a quintessential example of dynamic programming. The solution is built up by solving simpler problems. You start with the trivial case of converting to and from the empty string, and then you build up the solution for the prefixes until you have the complete solution.

In the grid’s initial configuration there is only one location for which we know all three costs and that is row s1, column u1. That’s the only space for which all three neighboring values (above, left, and above-left) are populated. After we fill in this one, there are two more spaces available to us. Those are (t2, u1) and (s1, r2). Ordinarily, though, the spaces are populated line by line.

A Simple Example

Let’s do a quick example by hand. Then we’ll take a stab at implementing it in code. What’s the string distance from “hat” to “tape”? First, our empty grid:

    t a p e
  0 1 2 3 4
h 1
a 2
t 3

The first space is row h, column t. The letters are different so our choices are 1 + 1 (above), 1 + 1 (left), or 0 + 1 (above-left). 0 + 1 is the smallest value, so the first space gets populated with 1. The next space has choices 2 + 1 (above), 1 + 1 (left), or 1 + 1 (above-left). 1 + 1 is the lowest, so we fill in the second space with 2. Once we finish the row, we have this:

    t a p e
  0 1 2 3 4
h 1 1 2 3 4
a 2
t 3

The next space is row a, column t. Our choices are 1 + 1 (above), 2 + 1 (left), or 1 + 1 (above-left). 1 + 1 is the smallest value, so the space gets populated with 2. The next space has choices 2 + 1 (above), 2 + 1 (left), or 1 + 0 (above-left). Why 1 + 0? Because the above-left value is 1 and both letters for this space are “a” so we can replace “a” with “a” for free. Go ahead and finish the whole grid. This is the result:

    t a p e
  0 1 2 3 4
h 1 1 2 3 4
a 2 2 1 2 3
t 3 2 2 2 3

The strings “hat” and “tape” have a distance of 3.

Some Code, Finally

As fascinating as Levenshtein distance is, and as much more as there is to say on the topic, the time has come to write some code. Here’s a Scala implementation that is very close to the pencil-and-paper approach that we just performed.

import scala.math.min
import scala.math.max

object StringDistance {
  def stringDistance(s1: String, s2: String): Int = {
    def minimum(i1: Int, i2: Int, i3: Int) = min(min(i1, i2), i3)

    val dist = Array.ofDim[Int](s1.length + 1, s2.length + 1)

    for (idx <- 0 to s1.length) dist(idx)(0) = idx  // leftmost column
    for (jdx <- 0 to s2.length) dist(0)(jdx) = jdx  // top row

    for (idx <- 1 to s1.length; jdx <- 1 to s2.length)
      dist(idx)(jdx) = minimum (
        dist(idx-1)(jdx  ) + 1,  // above: delete a character
        dist(idx  )(jdx-1) + 1,  // left: add a character
        dist(idx-1)(jdx-1) + (if (s1(idx-1) == s2(jdx-1)) 0 else 1)  // above-left: replace or keep
      )

    dist(s1.length)(s2.length)
  }
}

Do you see what I mean when I say it’s close to the pencil-and-paper approach? We actually construct a two-dimensional Array to represent the grid we drew earlier. It’s a very literal implementation.

To explain the code briefly, we declare a singleton StringDistance having a single method called stringDistance. Within this method we declare a 3-argument minimum method. (I wonder why there’s no “def min(params: Int*): Int” defined in scala.Math.) Then we create an array called “dist” and populate the top row and leftmost column in lines 8-11. The for loop on line 13 cycles through each array position from left to right then top to bottom, and populates them according to the Levenshtein logic. Finally, once the grid is filled in we return the number in the bottom right position.
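If you want to compare the program’s work with a grid you filled in by hand, here is a small variation that returns the whole grid instead of only the bottom right value. It’s a sketch rather than part of the implementation above; GridDemo and grid are names made up for this illustration.

```scala
object GridDemo {
  // Builds the full cost grid: rows follow s1, columns follow s2.
  def grid(s1: String, s2: String): Array[Array[Int]] = {
    val dist = Array.ofDim[Int](s1.length + 1, s2.length + 1)
    for (i <- 0 to s1.length) dist(i)(0) = i // leftmost column
    for (j <- 0 to s2.length) dist(0)(j) = j // top row
    for (i <- 1 to s1.length; j <- 1 to s2.length) {
      val replaceCost = if (s1(i - 1) == s2(j - 1)) 0 else 1
      dist(i)(j) = math.min(
        math.min(dist(i - 1)(j) + 1, dist(i)(j - 1) + 1),
        dist(i - 1)(j - 1) + replaceCost)
    }
    dist
  }

  def main(args: Array[String]): Unit = {
    // One output line per grid row, top row first.
    grid("hat", "tape").foreach(row => println(row.mkString(" ")))
  }
}
```

Running it prints the same four rows of numbers that we worked out by hand for “hat” and “tape”, with 3 in the bottom right corner.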

The job’s not done yet, of course. We haven’t written our unit tests. Here’s the test I wrote:

import org.scalatest._

class StringDistanceSuite extends FunSuite with PrivateMethodTester {

  test("stringDistance should work on empty strings") {
    assert( StringDistance.stringDistance(   "",    "") == 0 )
    assert( StringDistance.stringDistance(  "a",    "") == 1 )
    assert( StringDistance.stringDistance(   "",   "a") == 1 )
    assert( StringDistance.stringDistance("abc",    "") == 3 )
    assert( StringDistance.stringDistance(   "", "abc") == 3 )
  }

  test("stringDistance should work on equal strings") {
    assert( StringDistance.stringDistance(   "",    "") == 0 )
    assert( StringDistance.stringDistance(  "a",   "a") == 0 )
    assert( StringDistance.stringDistance("abc", "abc") == 0 )
  }

  test("stringDistance should work where only inserts are needed") {
    assert( StringDistance.stringDistance(   "",   "a") == 1 )
    assert( StringDistance.stringDistance(  "a",  "ab") == 1 )
    assert( StringDistance.stringDistance(  "b",  "ab") == 1 )
    assert( StringDistance.stringDistance( "ac", "abc") == 1 )
    assert( StringDistance.stringDistance("abcdefg", "xabxcdxxefxgx") == 6 )
  }

  test("stringDistance should work where only deletes are needed") {
    assert( StringDistance.stringDistance(  "a",    "") == 1 )
    assert( StringDistance.stringDistance( "ab",   "a") == 1 )
    assert( StringDistance.stringDistance( "ab",   "b") == 1 )
    assert( StringDistance.stringDistance("abc",  "ac") == 1 )
    assert( StringDistance.stringDistance("xabxcdxxefxgx", "abcdefg") == 6 )
  }

  test("stringDistance should work where only substitutions are needed") {
    assert( StringDistance.stringDistance(  "a",   "b") == 1 )
    assert( StringDistance.stringDistance( "ab",  "ac") == 1 )
    assert( StringDistance.stringDistance( "ac",  "bc") == 1 )
    assert( StringDistance.stringDistance("abc", "axc") == 1 )
    assert( StringDistance.stringDistance("xabxcdxxefxgx", "1ab2cd34ef5g6") == 6 )
  }

  test("stringDistance should work where many operations are needed") {
    assert( StringDistance.stringDistance("example", "samples") == 3 )
    assert( StringDistance.stringDistance("sturgeon", "urgently") == 6 )
    assert( StringDistance.stringDistance("levenshtein", "frankenstein") == 6 )
    assert( StringDistance.stringDistance("distance", "difference") == 5 )
    assert( StringDistance.stringDistance("java was neat", "scala is great") == 7 )
  }

}

It tests several special cases as well as the general case. All we have to do is compile our StringDistance object and the StringDistanceSuite unit test, fire up the scala interpreter and run the test:

scala> (new StringDistanceSuite).execute()
Test Starting - StringDistanceSuite: stringDistance should work on empty strings
Test Succeeded - StringDistanceSuite: stringDistance should work on empty strings
Test Starting - StringDistanceSuite: stringDistance should work on equal strings
Test Succeeded - StringDistanceSuite: stringDistance should work on equal strings
Test Starting - StringDistanceSuite: stringDistance should work where only inserts are needed
Test Succeeded - StringDistanceSuite: stringDistance should work where only inserts are needed
Test Starting - StringDistanceSuite: stringDistance should work where only deletes are needed
Test Succeeded - StringDistanceSuite: stringDistance should work where only deletes are needed
Test Starting - StringDistanceSuite: stringDistance should work where only substitutions are needed
Test Succeeded - StringDistanceSuite: stringDistance should work where only substitutions are needed
Test Starting - StringDistanceSuite: stringDistance should work where many operations are needed
Test Succeeded - StringDistanceSuite: stringDistance should work where many operations are needed

scala>
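Besides hand-picked cases, a few general properties hold for any pair of strings: the distance is the same in both directions, it is zero only for equal strings, it is never less than the difference in lengths, and it is never more than the length of the longer string. Here is a rough sketch of a checker for those properties. It uses plain Scala rather than ScalaCheck, and the names DistanceProperties, holdsFor, and holdsForAll are my own inventions for this example.

```scala
object DistanceProperties {
  // True when `distance` satisfies some general properties of
  // Levenshtein distance for the pair (a, b).
  def holdsFor(distance: (String, String) => Int, a: String, b: String): Boolean = {
    val d = distance(a, b)
    val symmetric   = d == distance(b, a)
    val zeroIffSame = (d == 0) == (a == b)
    val lowerBound  = d >= math.abs(a.length - b.length)
    val upperBound  = d <= math.max(a.length, b.length)
    symmetric && zeroIffSame && lowerBound && upperBound
  }

  // Checks every ordered pair drawn from a small sample of strings.
  def holdsForAll(distance: (String, String) => Int, samples: Seq[String]): Boolean =
    samples.forall(a => samples.forall(b => holdsFor(distance, a, b)))
}
```

In the suite this could become one more assertion, something like assert(DistanceProperties.holdsForAll(StringDistance.stringDistance, Seq("", "a", "hat", "tape", "sturgeon"))). Passing these checks doesn’t prove an implementation correct, but failing one proves it wrong.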

Refactoring: Reduced Memory Usage

The code passes all the tests, so let’s take things a step further. One shortcoming of our implementation is that it can require a lot of memory. That array has to be of size (n+1)*(m+1) where n and m are the lengths of the two strings we’re comparing. If you want to apply the method to strings that are a few hundred characters long (or longer) then you’re starting to talk about some serious memory requirements. How can we reduce the amount of memory required? Can you think of a way?

Once we complete one row of the grid, we use it again as an input when we compute the next row. But after that we’re done with it. Why leave it to clutter the heap? Let’s tweak our implementation slightly. Try rewriting the method using only two rows. Fill the first row with the initial 0, 1, 2, etc. Then use one to compute the other over and over. Think about how you would implement this, then have a look at my solution below.

  def stringDistance(s1: String, s2: String): Int = {
    def minimum(i1: Int, i2: Int, i3: Int) = min(min(i1, i2), i3)

    var dist = ( new Array[Int](s1.length + 1),
                 new Array[Int](s1.length + 1) )

    for (idx <- 0 to s1.length) dist._2(idx) = idx

    for (jdx <- 1 to s2.length) {
      val (newDist, oldDist) = dist
      newDist(0) = jdx
      for (idx <- 1 to s1.length) {
        newDist(idx) = minimum (
          oldDist(idx) + 1,
          newDist(idx-1) + 1,
          oldDist(idx-1) + (if (s1(idx-1) == s2(jdx-1)) 0 else 1)
        )
      }
      dist = dist.swap
    }

    dist._2(s1.length)
  }

This one uses a Pair (also called a Tuple2) containing two one-dimensional arrays of n+1 entries each, rather than a single two-dimensional array with 2 rows. Pair happens to have the very handy “swap” method which we can use to trade out the rows when we’ve finished one and are ready to compute the next.
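To put rough numbers on the savings, here is a back-of-the-envelope sketch. It assumes 4 bytes per Int and ignores array headers and other JVM overhead, and the names MemoryEstimate, fullGridBytes, and twoRowBytes are made up for the example.

```scala
object MemoryEstimate {
  // Assumption: each Int takes 4 bytes; array headers are ignored.
  val bytesPerInt = 4L

  // Full grid: (n + 1) * (m + 1) cells.
  def fullGridBytes(n: Int, m: Int): Long = (n + 1L) * (m + 1L) * bytesPerInt

  // Two rows of n + 1 cells each.
  def twoRowBytes(n: Int): Long = 2L * (n + 1L) * bytesPerInt
}
```

For two strings of 10,000 characters each (a size I picked just for illustration), fullGridBytes(10000, 10000) comes to 400,080,004 bytes, about 400 MB, while twoRowBytes(10000) is 80,008 bytes, about 80 KB.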

This is where our unit tests really show their worth. No need to wonder whether this new code really does work. We just recompile, run the tests again, and we can see that the code still gives us the expected results.

Refactoring: Recursion, Kind Of

What are some other ways we could write this code? Can it be improved? I thought I would try to replace the iteration in the previous implementations with recursion. Here’s what that code looks like:

  def stringDistance(s1: String, s2: String): Int = {
    def newCost(ins: Int, del: Int, subst: Int, c1: Char, c2:Char) =
      Math.min( Math.min( ins+1, del+1 ), subst + (if (c1 == c2) 0 else 1) )

    def getNewCosts(s1: List[Char], c2: Char, delVal: Int, prev: List[Int] ): List[Int] = (s1, prev) match {
      case (c1 :: _ , substVal :: insVal :: _) =>
        delVal :: getNewCosts(s1.tail, c2, newCost(insVal, delVal, substVal, c1, c2), prev.tail)
      case _ => List(delVal)
    }

    def sd(s1: List[Char], s2: List[Char], prev: List[Int]): Int = s2 match {
      case Nil => prev.last
      case _ => sd( s1, s2.tail, getNewCosts(s1, s2.head, prev.head+1, prev) )
    }

    (s1, s2) match {
      case (`s2`, `s1`) => 0
      case (_, "") | ("", _) => max(s1.length, s2.length)
      case _ => sd(s1.toList, s2.toList, (0 to s1.length).toList)
    }
  }

It’s a pretty naïve implementation, actually. It just replaces the repetition of the two for loops with the repetition of the recursion of the two methods sd and getNewCosts. The sd method is even tail-recursive, allowing Scala to optimize it. It does the same basic thing as the for loop version, though. It recurses through the characters of a row in the getNewCosts method, and it recurses through the rows of the grid in the sd method.

It looks more complicated than the previous implementations. It’s harder to read. But it passes our unit tests, so we can be pretty sure it’s correct.
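If the recursion is hard to follow, it may help to run a single step on its own. The sketch below lifts newCost and getNewCosts out of the method so they can be called directly. The wrapper object RowStep is my own addition, and I’ve spelled the min calls as math.min so the snippet stands alone.

```scala
object RowStep {
  // Same rule as newCost in the listing above.
  def newCost(ins: Int, del: Int, subst: Int, c1: Char, c2: Char): Int =
    math.min(math.min(ins + 1, del + 1), subst + (if (c1 == c2) 0 else 1))

  // Builds one new line of costs for character c2 from the previous line.
  // delVal is the first value of the new line; each recursive call
  // computes the next value and consumes one character of s1.
  def getNewCosts(s1: List[Char], c2: Char, delVal: Int, prev: List[Int]): List[Int] =
    (s1, prev) match {
      case (c1 :: _, substVal :: insVal :: _) =>
        delVal :: getNewCosts(s1.tail, c2, newCost(insVal, delVal, substVal, c1, c2), prev.tail)
      case _ => List(delVal)
    }
}
```

Calling RowStep.getNewCosts("hat".toList, 't', 1, List(0, 1, 2, 3)) returns List(1, 1, 2, 2). Those are the numbers under the “t” of “tape” in the grid we filled in by hand for “hat” and “tape”, read from top to bottom.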

Refactoring: List Methods

Looking back at the last implementation, I thought it was a little sloppy. I wondered whether I could improve the situation by using some of the many useful methods built into Scala’s List class. Here is the comparatively brief code that resulted:

  def stringDistance(s1: String, s2: String): Int = {
    def sd(s1: List[Char], s2: List[Char], costs: List[Int]): Int = s2 match {
      case Nil => costs.last
      case c2 :: tail => sd( s1, tail,
          (List(costs.head+1) /: costs.zip(costs.tail).zip(s1))((a,b) => b match {
            case ((rep,ins), chr) => Math.min( Math.min( ins+1, a.head+1 ), rep + (if (chr==c2) 0 else 1) ) :: a
          }).reverse
        )
    }
    sd(s1.toList, s2.toList, (0 to s1.length).toList)
  }

Like I say, it’s brief. Those List methods give you a lot of mileage.
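For comparison, here is one more way the same idea might be spelled, using foldLeft over the second string and scanLeft to build each new line of costs. Treat it as a sketch of an alternative, not a replacement for the version above. The object name FoldDistance and the names inside it are mine.

```scala
object FoldDistance {
  def stringDistance(s1: String, s2: String): Int = {
    val chars = s1.toList
    val firstLine = (0 to s1.length).toList

    // Each fold step turns the previous line of costs into the next one.
    val lastLine = s2.foldLeft(firstLine) { (prev, c2) =>
      // Pair each previous cost with its successor and the matching
      // character of s1, then scan through the pairs. `newer` is the
      // value just computed for the new line.
      prev.zip(prev.tail).zip(chars).scanLeft(prev.head + 1) {
        case (newer, ((rep, ins), c1)) =>
          math.min(math.min(ins + 1, newer + 1), rep + (if (c1 == c2) 0 else 1))
      }
    }
    lastLine.last
  }
}
```

Because scanLeft keeps its intermediate results in order, there is no need for the reverse at the end. FoldDistance.stringDistance("hat", "tape") returns 3, the same answer we got by hand.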

Refactoring: Real Recursion

The more I looked at my previous attempt at a recursive solution, the more I realized how hare-brained it was. It wasn’t real recursion. It was just iteration using the stack. So I went back to the drawing board. If I want to know the final answer, the value in the bottom right position, how do I get it? I apply my three rules to the above, left, and above-left positions. How do I get those positions? Apply the rules again. That’s real recursion. Here’s my first stab at it:

  def stringDistance(s1: String, s2: String): Int = {
    def min(a:Int, b:Int, c:Int) = Math.min( Math.min( a, b ), c)
    def sd(s1: List[Char], s2: List[Char]): Int = (s1, s2) match {
      case (_, Nil) => s1.length
      case (Nil, _) => s2.length
      case (c1::t1, c2::t2)  => min( sd(t1,s2) + 1, sd(s1,t2) + 1,
                                     sd(t1,t2) + (if (c1==c2) 0 else 1) )
    }
    sd( s1.toList, s2.toList )
  }

Now, THAT is a nice looking recursive function. That’s more like it. See the pattern match block? If we try to convert from any string to the empty string or from the empty string to a non-empty string then we just use the string length. You see? That takes care of the positions along the top and left of the grid. All the others are determined in the last case. That last case just applies our three rules. That is so short and simple. It’s a thing of beauty.

The only problem? It doesn’t work.

It’s technically correct. It will return correct answers … eventually. Or, rather, I think it will. I can’t be sure because it’s too slow to complete my unit tests! For anything but very short inputs (4 or 5 characters), the function takes a long time to return. Why? Let’s look at how many recursive calls are made for some inputs.

If we use strings “a” and “b” we pass over the “(_, Nil)” and “(Nil, _)” cases in the first call to function sd, because both our strings (Lists, actually) are non-empty. This results in three more calls to sd. Each of these three calls includes an empty List of characters, so there is no more recursion. That’s a total of four calls to sd for strings “a” and “b”.

What about “ab” and “xy”? Think about it for a moment. Step through the function in your head. How many calls to sd will there be for “ab” and “xy”?

Have you done it? I count 19. What about “abc” and “xyz”? I’ll save you the trouble. It’s 94. For length 4 strings it’s 481. For length 5 it’s 2524. Length 6 is 13,483. Then 73 thousand, then 400 thousand, then 2 million and so forth. Why so many calls? Each position in the grid is computed using all the positions to the left and all the positions above the current position. So a position in the top left will be computed and recomputed many times.
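You don’t have to count by hand. A call to sd with two non-empty inputs makes three more calls, and a call with an empty input makes none, so the totals can be tallied with a tiny function. This is only a sketch for counting; the names CallCount and calls are invented for it.

```scala
object CallCount {
  // Number of calls to sd for inputs of lengths m and n: the call itself,
  // plus the calls made by its three recursive branches.
  def calls(m: Int, n: Int): Long =
    if (m == 0 || n == 0) 1L
    else 1L + calls(m - 1, n) + calls(m, n - 1) + calls(m - 1, n - 1)
}
```

CallCount.calls(1, 1) is 4, CallCount.calls(2, 2) is 19, and CallCount.calls(3, 3) is 94, matching the counts above. The counter is deliberately just as wasteful as the function it measures, so keep the lengths small.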

There is a way to get around this, of course. You probably already have some ideas. We’re going to do something called memoization. When you memoize a function, you make it remember results that it computed previously without having to actually recompute them. I’ll do that by caching results in a map. The map’s key is a Pair of List[Char]s, the inputs to my inner function sd, and its data is an Int, the return type of sd. I will modify sd to first check the map to see if the result for the current parameters has already been cached. If so, we simply return it. If not, we compute the value, cache it, and return it.

  def stringDistance(s1: String, s2: String): Int = {
    val memo = scala.collection.mutable.Map[(List[Char],List[Char]),Int]()
    def min(a:Int, b:Int, c:Int) = Math.min( Math.min( a, b ), c)
    def sd(s1: List[Char], s2: List[Char]): Int = {
      if (!memo.contains((s1, s2)))
        memo((s1,s2)) = (s1, s2) match {
          case (_, Nil) => s1.length
          case (Nil, _) => s2.length
          case (c1::t1, c2::t2)  => min( sd(t1,s2) + 1, sd(s1,t2) + 1,
                                         sd(t1,t2) + (if (c1==c2) 0 else 1) )
        }
      memo((s1,s2))
    }

    sd( s1.toList, s2.toList )
  }

There. That uglies up my function somewhat, but at least it’s usable now. And it passes my unit tests, so I’m reasonably assured that it’s right.

Memoization only works if you expect the same result for each identical function call. If your function takes input from stdin, for example, you can’t memoize that. Or if it has a random component. Or if, for any other reason, its return value is not always the same for the same inputs. You can memoize functions in different ways. There’s a post on Michid’s Weblog about a more general solution, a memoizing class which wraps existing functions to give you a memoized version.
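To give a flavor of what a more general solution can look like, here is a minimal sketch of a wrapper that memoizes any one-argument function. To be clear, this is not the code from that post. The names Memo and memoize are my own, and it makes no attempt at thread safety or at limiting the size of the cache.

```scala
import scala.collection.mutable

object Memo {
  // Wraps a one-argument function so that each distinct input is
  // computed at most once; later calls are answered from the cache.
  def memoize[A, B](f: A => B): A => B = {
    val cache = mutable.Map.empty[A, B]
    a => cache.get(a) match {
      case Some(result) => result
      case None =>
        val result = f(a)
        cache(a) = result
        result
    }
  }
}
```

If you wrap a function that bumps a counter each time it really runs, calling the wrapped version twice with the same argument bumps the counter only once. One caveat: wrapping a recursive function this way only caches the outermost call, because the inner calls still go to the original function. That’s why sd above checks the map itself.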

PHEW!

Ok, I’ve tried to keep my rambling to a minimum in this post but it’s still a doozy. The things I wanted to get across are:

  • The usefulness of Levenshtein distance in solving a variety of problems.
  • How to understand the Levenshtein distance algorithm and why it works.
  • How to use unit tests to improve your code while maintaining some assurance that the new code still has the correct behavior.
  • Some of the different ways of implementing Levenshtein.

In the end, I think I like the second implementation (the one that switches out the two rows) and the last implementation the best. In some informal performance tests, the second one showed a good mix of performance and simplicity. The last one, the memoized recursive one, appeals to me because it is in a more functional style and still has respectable performance.