I noticed some strange behavior in some Scala code recently. It was rather a mystery. I looked for my error and googled for a solution for the longest time with no success. Eventually I got my answer from the Scala mailing list / Nabble forum. Here’s the class that was causing the trouble.

class ArrayWrapper[A](length: Int) {
  private val array = new Array[A](length)
  def apply(x: Int) = array(x)
  def update(x: Int, value: A) = array(x) = value
  override def toString(): String = array.toString
}

The first thing you’ll notice about this class is that it is extremely simple! There aren’t a lot of moving parts. It’s a simple wrapper that exposes 3 basic array behaviors: apply (a ‘getter’), update (a ‘putter’), and good ol’ toString. Arrays in Scala take a type parameter, and to ensure that this class could wrap an array of any type I used a type parameter, too. Have a good look at the class and make sure you understand how it works. It won’t take long.

How do you expect this class to behave? Let’s play a little fill-in-the-blanks. Here is a Scala interpreter session with some results blanked out.

scala> class ArrayWrapper[A](length: Int) {
    |   private val array = new Array[A](length)
    |   def apply(x: Int) = array(x)
    |   def update(x: Int, value: A) = array(x) = value
    |   override def toString(): String = array.toString
    | }
defined class ArrayWrapper

scala> val a = new ArrayWrapper[Int](5)
??????????????

scala> val x = a(0)
??????????????

scala> x.toString
??????????????

scala> a(0).toString
??????????????

scala> a(0) = 0

scala> a.toString
??????????????

scala> a(0).toString
??????????????

scala>

There are 6 blanks. What do you expect to see in each of those? Well, the first blank follows the creation of a new ArrayWrapper[Int] and its assignment to a val ‘a’. So, according to our overriding definition of toString, it is simply the result of the underlying Array’s toString method. I know from experience how a brand new Array of Ints looks. It looks like this:

scala> new Array[Int](5)
res1: Array[Int] = Array(0, 0, 0, 0, 0)

So that’s what I expect to see here. Anyone expect something different? Here’s what I actually saw in the first blank:

scala> val a = new ArrayWrapper[Int](5)
a: ArrayWrapper[Int] = Array(null, null, null, null, null)

Hmm. That’s not what I expected. Did you predict this? Why is this array full of nulls when a new Array[Int] is usually full of zeros? I was stumped. The array is parameterized, I reasoned, so maybe type erasure was involved. That doesn’t make sense, though. No types should be erased at this point.

Let’s look at the next two statements. I called a(0) (the apply method) and assigned the result to x. I then called the toString method on x. What do you expect from these two statements? I would have expected a(0) to return 0 and x.toString to return “0”, but my conviction is shaken by that last result. Will a(0) return null? Will x.toString throw a NullPointerException? Decide what you predict will happen. Here’s the actual result:

scala> val x = a(0)
x: Int = 0

scala> x.toString
res0: java.lang.String = 0

Each line behaves in the “correct” way even though we saw all those nulls in the underlying array. That’s good news, I suppose. Maybe the problem is limited to Array’s toString method. It should be smooth sailing now. Let’s now look at the next statement, in which we call a(0).toString. It’s just combining the operations (apply and toString) from the previous two statements without storing the intermediate result in ‘x’. I expected that to return the String “0”. You can probably guess by now that what I expected is not what I got. Make your own prediction before you read the actual result below. What will happen when we call a(0).toString?

scala> a(0).toString
java.lang.NullPointerException
       at .<init>(<console>:7)
       at .<clinit>(<console>)
       at RequestResult$.<init>(<console>:3)
       at RequestResult$.<clinit>(<console>)
       at RequestResult$result(<console>)
       at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
       at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:39)
       at sun.reflect.DelegatingMethodAccessorImpl.i...

Ouch! NullPointerException! This is an unpleasant surprise. The call a(0) returned a zero earlier, and calling toString on that zero returned a String “0”. But now we get this disaster. I’m getting more and more confused. Do you have an explanation for this crazy behavior yet?

Moving along, we next assign 0 to a(0). Remember that a(0) returned 0 earlier. By the way, behind the scenes the statement “a(0) = 0” doesn’t call the apply method, but the ‘update’ method. It succeeds. After that we call a.toString and a(0).toString. What will happen in each case? At this point, it’s anybody’s guess. The behavior has been so wacky I can’t even make a sensible prediction. Make a guess of your own, if you dare, and observe the actual result below:

scala> a(0) = 0

scala> a.toString
res3: String = Array(0, null, null, null, null)

scala> a(0).toString
res4: java.lang.String = 0

The underlying Array now appears to contain a zero in addition to the nulls. Also, the a(0).toString call, which threw a NullPointerException earlier, now succeeds.

As I say, I puzzled over this problem for some time. I wanted to blame the issue on type erasure in the parameterized type, but that explanation didn’t make sense. I posted a question to the Scala forum on Nabble and got a response back in short order from Daniel Sobral.

The culprit? Drumroll…

Boxing. Well, boxing, unboxing, and a peculiarity of parameterized types. To review, here is our ArrayWrapper class:

class ArrayWrapper[A](length: Int) {
  private val array = new Array[A](length)
  def apply(x: Int) = array(x)
  def update(x: Int, value: A) = array(x) = value
  override def toString(): String = array.toString
}

We declared ‘array’ to be an Array[A], which is to say an Array of who-knows-what. When the Array is defined in this way, with a type parameter of unknown type, the Array must be an array of object references! It cannot be an array of Java int primitives. That’s the peculiarity of parameterized types. That’s why the default values for the members of the array were null instead of 0. The underlying array is actually an array of java.lang.Integer objects.

When we ran ‘val x = a(0)’, Scala retrieved the value at index 0, which was null. The apply method has Int return type in our example, and null is not a legal value of an Int. Int is Scala’s version of the Java int primitive type. So the null was converted (unboxed) to Int value 0. Then it could be stored in val x, etc. Once it’s safely unboxed, it behaves like a normal Scala Int value.

So, why did a(0).toString not work? Shouldn’t the null returned from a(0) be unboxed to Int 0, then re-boxed for the toString call? Apparently it doesn’t work that way. The unboxing hasn’t happened at the time the toString call is executed, so that toString is called on the null, giving us the NullPointerException. I don’t know whether this behavior is imposed by the JVM or the Scala language. Either way, it seems to me like a violation of the Principle of Least Astonishment and an opportunity for improvement.

Once we call a(0) = 0, then the underlying array is populated with a boxed version of 0, which is to say an instance of java.lang.Integer. After it’s populated with a non-null it works normally.

Again, this only happens for Arrays with parameterized types. If we make ArrayWrapper non-parameterized and declare ‘array’ as an Array[Int] then the problem goes away.

scala> class ArrayWrapper(length: Int) {
     |   private val array = new Array[Int](length)
     |   def apply(x: Int) = array(x)
     |   def update(x: Int, value: Int) = array(x) = value
     |   override def toString(): String = array.toString
     | }
defined class ArrayWrapper

scala> val a = new ArrayWrapper(5)
a: ArrayWrapper = Array(0, 0, 0, 0, 0)

scala> a(0).toString
res0: java.lang.String = 0
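If the wrapper has to stay generic, one possible approach in later Scala versions (2.10 onward, to the best of my knowledge) is to ask for a ClassTag, so that the array can be created with the right element type at runtime. This is only a sketch under that assumption, not what the class above did; the mkString call is also my substitution, because array.toString does not print the elements in recent versions.

```scala
import scala.reflect.ClassTag

// Assumption: a Scala version that provides scala.reflect.ClassTag.
// The context bound lets `new Array[A]` build a primitive int[] when A is Int.
class ArrayWrapper[A: ClassTag](length: Int) {
  private val array = new Array[A](length)
  def apply(x: Int): A = array(x)
  def update(x: Int, value: A): Unit = { array(x) = value }
  // mkString prints the elements explicitly instead of relying on Array's toString.
  override def toString: String = array.mkString("Array(", ", ", ")")
}

val a = new ArrayWrapper[Int](5)
val firstAsString = a(0).toString
```

With the element type known, the array starts out full of zeros and a(0).toString no longer hits a null.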

There are a few lessons in all this for the Scala developer:

  • Be vigilant about Array initialization. Initialize them explicitly, especially when dealing with primitives like Int, Long, Float, Double, Byte and Char. Don’t trust the default values.
  • Beware parameterized Arrays. They are flawed. Consider specifying their type or using another collection instead, such as a List or Map which can’t contain un-initialized values.
  • Unit test all your code, even those parts that look too simple to screw up.

In my last post I reviewed the implementation of scala.List’s foldLeft and foldRight methods. That post included a couple of simple examples, but today I’d like to give you a whole lot more. The foldLeft method is extremely versatile. It can do thousands of jobs. Of course, it’s not the best tool for EVERY job, but when working on a list problem it’s a good idea to stop and think, “Should I be using foldLeft?”

Below, I’ll present a list of problem descriptions and solutions. I thought about listing all the problems first, and then the solutions, so the reader could work on his own solution and then scroll down to compare. But this would be very annoying for those who refuse, against my strenuous urging, to start up a Scala interpreter and try to write their own solution to each problem before reading my solution.

Sum

Write a function called ‘sum’ which takes a List[Int] and returns the sum of the Ints in the list. Don’t forget to use foldLeft.

def sum(list: List[Int]): Int = list.foldLeft(0)((r,c) => r+c)
def sum(list: List[Int]): Int = list.foldLeft(0)(_+_)

I’ll explain this first example in a bit more depth than the others, just to make sure we all know how foldLeft works.

These two definitions above are equivalent. Let’s examine the first one. The foldLeft method is called on the list parameter. The first parameter is 0. This is the starting value, and the value that will be returned if list is empty. The second parameter is a function literal. It takes parameters ‘r’ (for result) and ‘c’ (for current) and returns the sum of these two values. Scala is smart enough to figure out that since the first parameter (0) is an Int, the ‘r’ parameter must also be an Int. The initial value is always the same type as ‘r’. Scala can also tell that since ‘list’ is a List[Int] the ‘c’ parameter must also be an Int, so we don’t have to specify their types in the parameter list.

The foldLeft method takes that initial value, 0, and the function literal, and it begins to apply the function on each member of the list (parameter ‘c’), updating the result value (parameter ‘r’) each time. That result value that we call ‘r’ is sometimes called the accumulator, since it accumulates the results of the function calls.
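To watch the accumulator change, it can help to run the same function through scanLeft, which (in Scala 2.8 and later) returns every intermediate result instead of only the last one. The names ‘numbers’, ‘steps’ and ‘total’ are just for illustration.

```scala
val numbers = List(1, 2, 3, 4)

// scanLeft applies the same function as foldLeft,
// but keeps the initial value and every intermediate accumulator.
val steps = numbers.scanLeft(0)((r, c) => r + c)

// foldLeft returns only the final accumulator.
val total = numbers.foldLeft(0)((r, c) => r + c)
```

Here ‘steps’ is List(0, 1, 3, 6, 10) and ‘total’ is 10, the last element of ‘steps’.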

In the first definition, foldLeft’s second parameter (a function literal) uses explicitly named parameters. Notice that ‘r’ and ‘c’ are each referred to exactly once in the function literal, and in the same order as the parameter list. When function literal parameters are used in this way (once each, same order) you can use the shorthand demonstrated in the second definition. The first ‘_’ stands for ‘r’, and the second one stands for ‘c’.

Product

Now that you’ve got the idea, try this one. Write a function that takes a List[Int], and returns the product (of multiplication) of all the Ints in the list. It will be similar to the ‘sum’ function, but with a couple of differences.

def product(list: List[Int]): Int = list.foldLeft(1)(_*_)

Did you get it? It’s the same as ‘sum’ with two exceptions. The initial value is now 1, and the function literal’s parameters are multiplied instead of added. If the initial value were 0, as in ‘sum’, then the function would always return 0.

Count

This one’s a little different. Write a function that takes a List[Any] and returns the number of items in the list. Don’t just call list.length()! Implement it using foldLeft.

def count(list: List[Any]): Int =
  list.foldLeft(0)((sum,_) => sum + 1)

First, we pick our initial value. Remember that this is the value that will be returned for an empty list. An empty list has 0 elements, so we use 0. What function do we want to apply for every item in the list? We just want to increase the result value by one. We call that parameter ‘sum’ in this solution. We don’t care about the actual value of each list element, so we call the second parameter ‘_’, which means it should be discarded.

Average

Here’s a fun one. Write a function that takes a List[Double] and returns the average of the list’s values. There are two ways to go about this one. You could combine two of the previous solutions, using two foldLeft calls, or you could combine them into a single foldLeft. Try to find both solutions.

def average(list: List[Double]): Double =
  list.foldLeft(0.0)(_+_) / list.foldLeft(0.0)((r,c) => r+1)

def average(list: List[Double]): Double = list match {
  case head :: tail => tail.foldLeft( (head,1.0) )((r,c) =>
    ((r._1 + (c/r._2)) * r._2 / (r._2+1), r._2+1) )._1
  case Nil => Double.NaN
}

The first solution is pretty easy and combines the ‘sum’ and ‘count’ solutions. In real life, of course, you wouldn’t use foldLeft to find the length of the list. You’d just use the length() method. Other than that, though, this is a perfectly sensible solution.

The second solution is more complex. First, the list is matched against two patterns. It is either interpreted as a head item followed by a tail, or as an empty list (Nil). If it’s empty, the function returns the same thing as the first solution, NaN (Not a Number) because you can’t divide by 0.

If the list is not empty, we use a Pair as our initial value. A Pair is just an ordered pair of values. It’s a convenient way to bundle values together. We use it when we need to keep track of more than one accumulator value. In this case, we want to keep track of the average “so far” and also the number of values that the average represents. If the function literal were just passed the average so far, it wouldn’t know how to weight the next value. Members of a Pair are accessed using special methods called ‘_1’ and ‘_2’. You can have groupings longer than 2, also. These are named Tuple3, Tuple4, and so on. In fact, Pair is just an alias of Tuple2. Notice that we didn’t use the word Pair or Tuple2 anywhere in the code. If you enclose a comma-delimited series of values in parentheses, Scala converts that series into the appropriate TupleX.

After we have built up the result, it is a Pair containing the average and the number of items in the list. We only want to return the average so we call ‘_1’ on the result of foldLeft.
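The same two-accumulator idea also works if the Pair holds the running sum and the running count rather than the running average. This is a sketch of that variant, not one of the two solutions above; the name ‘averageSumCount’ is made up for the example.

```scala
// Single foldLeft that tracks (sum so far, count so far) in a tuple.
def averageSumCount(list: List[Double]): Double = {
  val (sum, count) = list.foldLeft((0.0, 0)) {
    case ((s, n), c) => (s + c, n + 1)
  }
  // For an empty list this is 0.0 / 0, which is NaN for Doubles.
  sum / count
}
```

The arithmetic is easier to check by eye, at the cost of letting the sum grow large for long lists.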

Last

Whew! That one was a little tough. Here’s an easier one. Given a List[A] return the last value in the list. Again, no using List’s last() method.

def last[A](list: List[A]): A =
  list.foldLeft[A](list.head)((_, c) => c)

Easy! Mostly. You’ll notice that we’re using a type parameter, A, in this one. If you’re not familiar with type parameters, too bad. I can’t explain them here. Suffice it to say that our use of A here allows us to take a list of any type of contents, and return a result of just that type. So Scala knows that when this is called on a List[Int], it will return an Int. When it’s called on a List[String], it returns a String.

First, we pick an initial value. For the empty list the concept of a last item doesn’t make any sense, so forget that. We can use any value, so long as it’s of type A. list.head is convenient, so that’s our initial value. The function literal is the simplest we’ve seen. For each item in the list, it just returns that item itself. So when it gets to the end of the list, the accumulator holds the last item. We don’t use the accumulator value in the function literal, so it gets parameter name ‘_’.
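Because list.head is evaluated up front, the function above throws on an empty list. If that matters, one alternative is to fold into an Option, so that the empty list has a sensible answer. This is a sketch; ‘lastOption’ here is a hand-written function for illustration.

```scala
// Start with None; every element replaces the accumulator with Some(element).
def lastOption[A](list: List[A]): Option[A] =
  list.foldLeft(Option.empty[A])((_, c) => Some(c))
```

Calling lastOption(List(1, 2, 3)) gives Some(3), and lastOption(Nil) gives None.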

Penultimate

Write a function called ‘penultimate’ that takes a List[A] and returns the penultimate item (i.e. the next to last item) in the list. Hint: Use a tuple.

def penultimate[A](list: List[A]): A =
  list.foldLeft( (list.head, list.tail.head) )((r, c) => (r._2, c) )._1

This one is very much like the function ‘last’, but instead of keeping just the current item it keeps a Pair containing the previous and current items. When foldLeft completes, its result is a Pair containing the next-to-last and last items. The “_1” method returns just the penultimate item.

Contains

Write a function called ‘contains’ that takes a List[A] and an item of type A, and returns true if the item is one of the members of the list, and false if it isn’t.

def contains[A](list: List[A], item: A): Boolean =
  list.foldLeft(false)(_ || _==item)

We choose an initial value of false. That is, we’ll assume the item is not in the list until we can prove otherwise. We use each of the two parameters exactly once and in the proper order, so we can use the ‘_’ shorthand in our function literal. That function literal returns the result so far (a Boolean) ORed with a comparison of the current item and the target value. If the target is ever found, the accumulator becomes true and stays true as foldLeft continues.

Get

Write a function called ‘get’ that takes a List[A] and an index Int, and returns the list value at the index position. Throw an exception if the index is out of bounds.

def get[A](list: List[A], idx: Int): A =
  list.tail.foldLeft((list.head,0)) {
    (r,c) => if (r._2 == idx) r else (c,r._2+1)
  } match {
    case (result, index) if (idx == index) => result
    case _ => throw new Exception("Bad index")
  }

This one has two parts. First there’s the foldLeft, then the result is pattern matched. The foldLeft is pretty easy to follow. The accumulator is a Pair containing the current item and the current index. The current item keeps updating and the current index keeps incrementing until the current index equals the passed-in idx. Once the correct index is found the same accumulator is returned over and over. This works fine if the idx parameter is in bounds. If it’s out of bounds, though, the foldLeft just returns a Pair containing the last item and the last index. That’s where the pattern match comes in. If the Pair contains the right index then we use the result item. Otherwise, we throw an exception.

MimicToString

Write a function called ‘mimicToString’ that mimics List’s own toString method. That is, it should return a String containing a comma-delimited series of string representations of the list contents with “List(” on the left and “)” on the right.

def mimicToString[A](list: List[A]): String = list match {
  case head :: tail => tail.foldLeft("List(" + head)(_ + ", " + _) + ")"
  case Nil => "List()"
}

This one also uses a pattern match, but this time the match happens first. The pattern match just treats the empty list as a special case. For the general case (a non-empty list) we use, of course, foldLeft. The accumulator starts out as “List(” + the head item. Then each remaining item (notice foldLeft is called on tail) is appended with a leading “, ” and a final “)” is added to the result of foldLeft.
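For comparison, the standard library’s mkString method takes a prefix, a separator and a suffix, and produces the same string without a special case for the empty list. The function name below is made up for this example.

```scala
// mkString(start, sep, end) joins the elements' string representations.
def mimicWithMkString[A](list: List[A]): String =
  list.mkString("List(", ", ", ")")
```

It is a handy way to check your own foldLeft version against a known-good result.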

Reverse

This one’s kind of fun. Make sure to try it before you look at my solution. Write a function called ‘reverse’ that takes a List and returns the same list in reverse order.

def reverse[A](list: List[A]): List[A] =
  list.foldLeft(List[A]())((r,c) => c :: r)

A very simple solution! The initial value of the accumulator is just an empty list. We don’t use Nil, but instead spell out the List type so that Scala will know what type to make ‘r’. As I say, we start with the empty list which is sensible because the reverse of an empty list is an empty list. Then, as we go through the list, we place each item at the front of the accumulator. So the item at the front of list becomes the last item in the accumulator. This goes on until we reach the end of list, and that last member of list goes onto the front of the accumulator. It’s a really neat and tidy solution.
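To see the accumulator being built from the front, here is the same function literal run through scanLeft (available in Scala 2.8 and later), which keeps each intermediate accumulator. The name ‘reverseSteps’ is just for illustration.

```scala
// Each step places the current item at the front of the accumulator.
val reverseSteps = List(1, 2, 3).scanLeft(List[Int]())((r, c) => c :: r)
```

The result is List(List(), List(1), List(2, 1), List(3, 2, 1)); the last element is the reversed list.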

Unique

Write a function called ‘unique’ that takes a List and returns the same List, but with duplicated items removed.

def unique[A](list: List[A]): List[A] =
  list.foldLeft(List[A]()) { (r,c) =>
    if (r.contains(c)) r else c :: r
  }.reverse

As usual, we start with an empty list. foldLeft looks at each list item, and if it’s already contained in the accumulator then the accumulator stays as it is. If it’s not in the accumulator then it’s added to the front. This code bears a striking similarity to the ‘reverse’ function we wrote earlier except for the “if (r.contains(c)) r” part. Because of this, the foldLeft result is actually the original list with duplicates removed, but in reverse order. To keep the output in the same order as the input, we add the call to reverse. We could also have chained on the foldLeft from the ‘reverse’ function, like so:

def unique[A](list: List[A]): List[A] =
  list.foldLeft(List[A]()) { (r,c) =>
    if (r.contains(c)) r else c :: r
  }.foldLeft(List[A]())((r,c) => c :: r)

ToSet

Write a function called ‘toSet’ that takes a List and returns a Set containing the unique elements of the list.

def toSet[A](list: List[A]): Set[A] =
  list.foldLeft(Set[A]())( (r,c) => r + c)

Super easy one. You just start out with an empty Set, which would be the right answer for an empty List. Then you just add each list item to the accumulator. Since the accumulator is a Set, it takes care of eliminating duplicates for you.

Double

Write a function called ‘double’ that takes a List and returns a new List in which each item appears twice in a row. For example double(List(1, 2, 3)) should return List(1, 1, 2, 2, 3, 3).

def double[A](list: List[A]): List[A] =
  list.foldLeft(List[A]())((r,c) => c :: c :: r).reverse

Again, pretty easy. Are you starting to see a pattern? When you use foldLeft to transform one list into another, you usually end up with the reverse of what you really want.

Alternately, you could have used the foldRight method instead. This does the same thing as foldLeft, except it accumulates its result from back to front instead of front to back. I can’t recommend using it, though, due to problems I point out in my other post on foldLeft and foldRight. But here’s what it would look like:

def double[A](list: List[A]): List[A] =
  list.foldRight(List[A]())((c,r) => c :: c :: r)

InsertionSort

This one takes some thinking. Write a function called ‘insertionSort’ that uses foldLeft to sort the input List using the insertion sort algorithm. Try it on your own before you look at the solution.

Need a hint? Use List’s ‘span’ method.

Did you find a solution? Here’s mine:

def insertionSort[A <% Ordered[A]](list: List[A]): List[A] =
  list.foldLeft(List[A]()) { (r,c) =>
    val (front, back) = r.span(_ < c)
    front ::: c :: back
  }

First, the type parameter ensures that we have elements that can be arranged in order. We start, predictably, with an empty list as our initial accumulator. Then, for each item we assume the accumulator is in order (which it always will be), and use span to split it into two sub-lists: all already-sorted items less than the current item, and all already-sorted items greater than or equal to the current item. We put the current item in between these two and the accumulator remains sorted. This is, of course, not the fastest way to sort a list. But it’s a neat foldLeft trick.
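The ‘<%’ view bound is tied to older Scala versions; as far as I know it is deprecated in later 2.x releases and not available in Scala 3. A sketch of the same fold using an implicit Ordering instead might look like this. The name ‘insertionSortOrd’ is hypothetical.

```scala
// Same algorithm, but the comparison comes from an implicit Ordering[A].
def insertionSortOrd[A](list: List[A])(implicit ord: Ordering[A]): List[A] =
  list.foldLeft(List[A]()) { (r, c) =>
    // front: sorted items less than c; back: sorted items >= c
    val (front, back) = r.span(x => ord.lt(x, c))
    front ::: c :: back
  }
```

For example, insertionSortOrd(List(3, 1, 2, 1)) returns List(1, 1, 2, 3).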

Pivot

Speaking of sorting, you can implement part of quicksort with foldLeft, the pivot. Write a function called ‘pivot’ that takes a List, and returns a Tuple3 containing: (1) a list of all elements less than the original list’s first element, (2) the first element, and (3) a List of all elements greater than or equal to the first element.

def pivot[A <% Ordered[A]](list: List[A]): (List[A],A,List[A]) =
  list.tail.foldLeft[(List[A],A,List[A])]( (Nil, list.head, Nil) ) {
    (result, item) =>
    val (r1, pivot, r2) = result
    if (item < pivot) (item :: r1, pivot, r2) else (r1, pivot, item :: r2)
  }

We’re using the first element, head, as the pivot value, so we skip the head and call foldLeft on list.tail. We initialize the accumulator to a Tuple3 containing the head element with an empty list on either side. Then for each item in the list we just pick which of the two lists to add to based on a comparison with the pivot value.
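Here is a small worked example. To keep it self-contained I’ve restated the function with an implicit Ordering rather than the view bound, which is my substitution; ‘pivotOrd’ is a made-up name.

```scala
// Splits the tail around the head element, as described above.
def pivotOrd[A](list: List[A])(implicit ord: Ordering[A]): (List[A], A, List[A]) =
  list.tail.foldLeft[(List[A], A, List[A])]((Nil, list.head, Nil)) {
    case ((r1, p, r2), item) =>
      if (ord.lt(item, p)) (item :: r1, p, r2) else (r1, p, item :: r2)
  }

val parts = pivotOrd(List(3, 1, 4, 1, 5))
```

The result is (List(1, 1), 3, List(5, 4)). Notice that each side list comes out in reverse encounter order, because items are added at the front.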

If you take the additional step of turning this into a recursive call, you can implement a quicksort algorithm. It probably won’t be a very efficient one because it will involve a lot of building and rebuilding lists. Give it a try if you like, and then look at my solution:

def quicksort[A <% Ordered[A]](list: List[A]): List[A] = list match {
  case head :: _ :: _ =>
    list.foldLeft[(List[A],List[A],List[A])]( (Nil, Nil, Nil) ) {
      (result, item) =>
      val (r1, r2, r3) = result
      if      (item < head) (item :: r1, r2, r3)
      else if (item > head) (r1, r2, item :: r3)
      else                  (r1, item :: r2, r3)
    } match {
      case (list1, list2, list3) =>
        quicksort(list1) ::: list2  ::: quicksort(list3)
    }
  case _ => list
}

Basically, for all lists that have more than 1 element the function chooses the head element as the pivot value, uses foldLeft to divide the list into three (less than, equal to, and greater than the pivot), recursively sorts the less-than and greater-than lists, and knits the three together.

Encode

Ok, we got a little into the weeds with that last one. Here’s a simpler one. Write a function called ‘encode’ that takes a List and returns a list of Pairs containing the original values and the number of times they are repeated. So passing List(1, 2, 2, 2, 2, 2, 3, 2, 2) to encode will return List((1, 1), (2, 5), (3, 1), (2, 2)).

def encode[A](list: List[A]): List[(A,Int)] =
  list.foldLeft(List[(A,Int)]()){ (r,c) =>
    r match {
      case (value, count) :: tail =>
        if (value == c) (c, count+1) :: tail
        else            (c, 1) :: r
      case Nil =>
        (c, 1) :: r
    }
  }.reverse

Decode

You knew this was coming. Write a function called ‘decode’ that does the opposite of encode. Calling ‘decode(encode(list))’ should return the original list.

def decode[A](list: List[(A,Int)]): List[A] =
  list.foldLeft(List[A]()){ (r,c) =>
    var result = r
    for (_ <- 1 to c._2) result = c._1 :: result
    result
  }.reverse

Encode and decode could both have been written by using foldRight and dropping the call to reverse.
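As a rough sketch of what those foldRight versions might look like (the names ‘encodeRight’ and ‘decodeRight’ are made up, and List.fill assumes Scala 2.8 or later):

```scala
// foldRight visits items from back to front, so prepending keeps the original order.
def encodeRight[A](list: List[A]): List[(A, Int)] =
  list.foldRight(List[(A, Int)]()) { (c, r) =>
    r match {
      case (value, count) :: tail if value == c => (c, count + 1) :: tail
      case _                                    => (c, 1) :: r
    }
  }

// Each (value, count) pair expands to `count` copies placed in front of the result.
def decodeRight[A](list: List[(A, Int)]): List[A] =
  list.foldRight(List[A]()) { (c, r) => List.fill(c._2)(c._1) ::: r }

val original = List(1, 2, 2, 2, 2, 2, 3, 2, 2)
val encoded  = encodeRight(original)
val decoded  = decodeRight(encoded)
```

Here ‘encoded’ is List((1, 1), (2, 5), (3, 1), (2, 2)) and ‘decoded’ equals ‘original’.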

Group

One last example. Write a function called ‘group’ that takes a List and an Int size, and groups the elements into sublists of the specified size. So calling “group( List(1, 2, 3, 4, 5, 6, 7), 3)” should return List(List(1, 2, 3), List(4, 5, 6), List(7)). Don’t forget to make sure list items are in the right order. Try it yourself before you look at the solution below.

def group[A](list: List[A], size: Int): List[List[A]] =
  list.foldLeft( (List[List[A]](),0) ) { (r,c) => r match {
    case (head :: tail, num) =>
      if (num < size)  ( (c :: head) :: tail , num + 1 )
      else             ( List(c) :: head :: tail , 1 )
    case (Nil, num) => (List(List(c)), 1)
    }
  }._1.foldLeft(List[List[A]]())( (r,c) => c.reverse :: r)

This code uses the first foldLeft to group the items in a way that’s convenient to list operations, and that last foldLeft to fix the order, which would otherwise be wrong in both the outer and inner lists.
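If you want something to check your answer against, the standard library (Scala 2.8 and later) has a ‘grouped’ method that returns an iterator over fixed-size chunks. The value name below is just for illustration.

```scala
// grouped(n) yields chunks of n elements; the last chunk may be shorter.
val chunks = List(1, 2, 3, 4, 5, 6, 7).grouped(3).toList
```

That gives List(List(1, 2, 3), List(4, 5, 6), List(7)), the same shape the exercise asks for.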

The End!

That’s all for now. If you know of any neat foldLeft tricks, please do leave a comment. I’d be interested to hear about it.

In Java you don’t see a lot of linked lists, and if you do it’s almost always java.util.LinkedList. People never write their own lists. They don’t really need to, I suppose. The one from java.util is fine. Plenty of people are leading fulfilling software careers never having implemented their own linked list. But it’s kind of a shame. Knowing how your data structures work makes you a better programmer.

It’s even rarer for a person to implement his own linked list in Scala. Scala’s scala.List is one of the most used classes in the language, so it’s packed with functionality. It’s abstract, covariant, it has helper objects such as List and Nil and the little-known ‘::’ class, it inherits from Product, Seq, Collection, Iterable, and PartialFunction. The machinery of List pulls in Array, ListBuffer, and more. It can be hard to take it all in.

So let’s build our own linked list. We’ll start out with something very basic and un-Scala-like. Then we’ll improve it gradually until we have something a little closer to scala.List. I encourage you to fire up your Scala interpreter and follow along.

Back to Basics

First, a short review of linked lists. A linked list is a chain of nodes, each referring to exactly one other node until you get to the end of the chain. You refer to the list by its first item and you follow the chain of references to reach the other nodes.

What are the requirements for our first try? Our node should be able to hold a piece of data, and refer to the next node. It should also be able to report its length, and provide a toString method so we can visualize the list. Here we go.

class MyList(val head: Any, val tail: MyList) {
  def isEmpty = (head == null && tail == null)
  def length: Int = if (isEmpty) 0 else 1 + tail.length
  override def toString: String = if (isEmpty) "" else head + " " + tail
}

The value ‘head’ holds the data for the node, ‘tail’ refers to the next element in the chain. The ‘isEmpty’ method is true if the head and tail are both null. The length and toString methods are both defined using a similar pattern: if (isEmpty) [base result] else [data for current node + result of same method on tail].

Here’s what it looks like when we use this class:

scala> var list = new MyList(null, null)
list: MyList =

scala> list.length
res0: Int = 0

scala> list.isEmpty
res1: Boolean = true

scala> list = new MyList("ABC", list)
list: MyList = ABC

scala> list.length
res3: Int = 1

scala> list.isEmpty
res4: Boolean = false

scala> list = new MyList("XYZ", list)
list: MyList = XYZ ABC

scala> list = new MyList("123", list)
list: MyList = 123 XYZ ABC

scala> list.tail.head
res7: Any = XYZ

Not bad. It gets the job done. But it has some problems. First is the use of ‘null’. Use of null references is sloppy and increases the odds of a null pointer exception so, ideally, we don’t want to see that. It has other problems, too. It’s too verbose. It’s not typesafe. But for now let’s concentrate on getting rid of the nulls.

No Nulls Is Good Nulls

How can we do it? We’re using the null as a special value, a marker to tell us when a node is at the end of a list. So we’ll just use something else as that marker instead. What can we use? We’ll create a special object for the empty list. It will be recognized as empty just based on its identity, not on null values. So let’s try it:

class MyList(val head: Any, val tail: MyList) {
  def isEmpty = false
  def length: Int = if (isEmpty) 0 else 1 + tail.length
  override def toString: String = if (isEmpty) "" else head + " " + tail
}

object MyListNil extends MyList("arbitrary value", null) {
  override def isEmpty = true
}

That’s better. (The observant reader will note the similarity of MyListNil to scala.List’s Nil object.) We got rid of the nulls in the isEmpty method, but we still have to put something in the head and tail parameters of the MyList constructor. We put an arbitrary non-null value in head, but what do we put for tail? Either null or create a new MyList. And how can that MyList be instantiated? It also needs a tail. Vicious circle. So this solution leaves us still stuck with a null.

Earlier, the null was there to mark a special node. We factored out that usage. Now it’s there to allow us to create the MyListNil. How can we factor that out? MyListNil is required to call its parent’s constructor. What if it had no parent? Then it wouldn’t be a MyList anymore. What if it had an abstract parent? Now you’re talking. Let’s see what that would look like.

abstract class MyList {
  def head: Any
  def tail: MyList
  def isEmpty: Boolean
  def length: Int
}

class MyListImpl(val head: Any, val tail: MyList) extends MyList {
  def isEmpty = false
  def length: Int = 1 + tail.length
  override def toString: String = head + " " + tail
}

object MyListNil extends MyList {
  def head: Any = throw new Exception("head of empty list")
  def tail: MyList = throw new Exception("tail of empty list")
  def isEmpty = true
  def length = 0
  override def toString =  ""
}

It’s a little more code, but much neater. There are no nulls anywhere. Here’s how it looks when we use this new MyList:

scala> var list: MyList = MyListNil
list: MyList =

scala> list = new MyListImpl("ABC", list)
list: MyList = ABC

scala> list = new MyListImpl("XYZ", list)
list: MyList = XYZ ABC

scala> list = new MyListImpl("123", list)
list: MyList = 123 XYZ ABC

scala> list.length
res3: Int = 3

scala> list.tail.head
res4: Any = XYZ

scala> list.tail.tail.tail.head
java.lang.Exception: head of empty list
        at ...

Pretty neat. The equivalent of MyListImpl in Scala’s real List implementation is a class called ‘::’, which has that funny name, by the way, because it looks nice in pattern matching code. Sometimes ‘::’ is referred to as cons. With nulls finally eliminated, we can concentrate on other issues.
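
Here’s a small sketch of what I mean about pattern matching. It uses the standard scala.List rather than our MyList, and the describe function is just a made-up name for illustration:

```scala
object ConsPatternDemo {
  // '::' in a pattern splits a non-empty list into its head and tail;
  // Nil matches the empty list.
  def describe(xs: List[String]): String = xs match {
    case Nil          => "empty"
    case head :: Nil  => "just " + head
    case head :: tail => head + " and " + tail.length + " more"
  }
}
```

The patterns read almost like the expressions that build the list, which is the whole appeal of the funny name.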

Brevity Is The Heart Of List

The thing that I notice at this point is that a lot of typing (on the keyboard) is required to use this list. We have to type out “list = new MyListImpl(…, list)” every time we add an item. We can improve this with a new method.

abstract class MyList {
  [...]
  def add(item: Any): MyList = new MyListImpl(item, this)
}

Now we have classes referring to each other. MyList creates new MyListImpls, and MyListImpl extends MyList. So you’ll need to put these classes in a .scala file and compile them instead of just typing them into the Scala interpreter. But, wow! Look how much easier it is to use MyList now:

scala> var list = MyListNil add("ABC") add("XYZ") add("123")
list: MyList = 123 XYZ ABC

scala> list.length
res1: Int = 3

So much easier! One thing I notice, though, is that the order of items in the code is different from the order produced by toString. We can change our ‘add’ method so that it is right-associative instead of left-associative by using a method name that ends in ‘:’ (colon). We’ll use ‘::’ as the method name since that’s what scala.List uses.

abstract class MyList {
  [...]
  def ::(item: Any): MyList = new MyListImpl(item, this)
}

scala> var list = "ABC" :: "XYZ" :: "123" :: MyListNil
list: MyList = ABC XYZ 123
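
In case the right-associative trick seems like magic, here’s a minimal sketch using the standard scala.List (not our MyList). The object and val names are made up for the example. It shows that the operator form is just another way of writing ordinary method calls on the right-hand operand:

```scala
object RightAssocDemo {
  // Operator form: a method name ending in ':' binds to the operand on its right.
  val sugared: List[String] = "ABC" :: "XYZ" :: "123" :: Nil

  // The same thing written as explicit method calls, starting from Nil.
  val desugared: List[String] = Nil.::("123").::("XYZ").::("ABC")
}
```

Both vals hold the same three-item list, with "ABC" at the head.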

Now we’re really getting somewhere. This is starting to look more like scala.List. One other thing that the standard list implementation gives you is a shortcut for initializing lists. It looks like “List(1, 2, 3, 4)”. Notice there’s no ‘new’ keyword. This is done using the scala.List helper object and its ‘apply’ method. Below is our own MyList helper object.

object MyList {
  def apply(items: Any*): MyList = {
    var list: MyList = MyListNil
    for (idx <- 0 until items.length reverse)
      list = items(idx) :: list
    list
  }
}

scala> var list = MyList("ABC", "XYZ", "123")
list: MyList = ABC XYZ 123

scala> list = "Cool" :: list
list: MyList = Cool ABC XYZ 123

Better Type-Safe Than Sorry

Better. Our code looks much neater now when we use MyList. I’ll introduce just one more improvement to MyList. It still has a rather glaring problem. It provides no type information. It keeps all of its data using a reference to Any. If you don’t see why this is a problem, let’s see what happens when we want to get the length of some items in a MyList:

scala> var list = MyList("ABC", 12345, "WXYZ")
list: MyList = ABC 12345 WXYZ

scala> list.head.length
<console>:6: error: value length is not a member of Any
       list.head.length
                 ^

scala> list.head.asInstanceOf[String].length
res10: Int = 3

scala> list.tail.head.asInstanceOf[String].length
java.lang.ClassCastException: java.lang.Integer cannot be cast to java.lang.String
        at .<init>(<console>:6)

Ouch! First, when we try to call method ‘length’ on list.head Scala complains that list.head is a reference to Any. Any doesn’t have a length method. This isn’t a dynamically typed language like, say, Ruby. An object has to have the right type before we can start calling methods. Second, casting our way around the problem only works until we guess wrong: the second item is an Int, so the cast to String blows up at runtime with a ClassCastException. What to do? You could implement a MyStringList where the head has type String. But then you’ll need a MyIntList, MyDoubleList, etc. What we need is a way to specify the type of data in the list when we create the MyList instance. What we need is a type parameter.

Here’s the complete MyList code using a type parameter, and a little demonstration code:

abstract class MyList[A] {
  def head: A
  def tail: MyList[A]
  def isEmpty: Boolean
  def length: Int
  def ::(item: A): MyList[A] = new MyListImpl[A](item, this)
}

class MyListImpl[A](val head: A, val tail: MyList[A]) extends MyList[A] {
  def isEmpty = false
  def length: Int = 1 + tail.length
  override def toString: String = head + " " + tail
}

object MyListNil extends MyList[Nothing] {
  def head: Nothing = throw new Exception("head of empty list")
  def tail: MyList[Nothing] = throw new Exception("tail of empty list")
  def isEmpty = true
  def length = 0
  override def toString =  ""
}

object MyList {
  def apply[A](items: A*): MyList[A] = {
    var list: MyList[A] = MyListNil.asInstanceOf[MyList[A]]
    for (idx <- 0 until items.length reverse)
      list = items(idx) :: list
    list
  }
}

scala> var list = MyList("ABC", "WXYZ", "123")
list: MyList[java.lang.String] = ABC WXYZ 123

scala> list.head.length
res0: Int = 3

scala> 3.14159 :: list
<console>:6: error: type mismatch;
 found   : Double
 required: java.lang.String
       3.14159 :: list
               ^

scala> var list = MyList("ABC", 123, 3.14159)
list: MyList[Any] = ABC 123 3.14159

Look at the first command in the session. The “MyList(…)” returns a MyList[String]. Scala figures out from the parameters what type to use. In the next command (list.head.length), you can see how much easier it is to use the list contents when you know the type at compile time.

If you try to mix types, as in the last command of the session, Scala determines the nearest common ancestor of the types (Any, in this case) and uses that. However, if the type parameter is already determined, as in the “3.14159 :: list” command, it won’t change when you try to add data of a different type. To make “3.14159 :: list” work, we can make a small change to the ‘::’ method:

abstract class MyList[A] {
  [...]
  def ::[B >: A](item: B): MyList[B] = 
    new MyListImpl(item, this.asInstanceOf[MyList[B]])
}

scala> var list = MyList("ABC", "XYZ")
list: MyList[java.lang.String] = ABC XYZ

scala> 3.14159 :: list
res0: MyList[Any] = 3.14159 ABC XYZ

This says that ‘::’ takes a parameter of type B, which is either A or a superclass of A, and returns a MyList[B]. So if you have a MyList[String] and you call ‘::’ on it with a Double parameter, Scala figures out that although Double is not a superclass of String, String and Double are both descendants of Any, and it returns a MyList[Any].
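
The standard scala.List declares ‘::’ with the same kind of lower bound, so you can watch the widening happen there too. This is only a sketch against scala.List, not our MyList, and the names are invented for the example:

```scala
object WideningDemo {
  val strings: List[String] = List("ABC", "XYZ")

  // B is resolved to Any, the nearest common supertype of Double and String.
  val mixed: List[Any] = 3.14159 :: strings

  // Prepending another String keeps the element type as String.
  val stillStrings: List[String] = "123" :: strings
}
```

Note that the original list is untouched: strings is still a List[String], and only the new list has the wider type.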

Conclusion

That’s a good stopping point for now. Obviously you can take the MyList class a lot further and add a lot more methods, but we’ve created some code that approximates the basics provided by scala.List. In fact, you could take several of the scala.List methods (foldLeft, for example) and basically drop them right into MyList and they’d work fine.

One of my favorite functional programming tricks is folding. The fold left and fold right functions can do a lot of complicated things with a small amount of code. Today, I’d like to (1) introduce folding, (2) make note of some surprising, nay, shocking fold behavior, (3) review the folding code used in Scala’s List class, and (4) make some presumptuous suggestions on how to improve List.

Update: I’ve created a new post in which I list lots and lots of foldLeft examples in case you’d like to learn more about what folding can accomplish.

Know When to Hold ‘Em, Know When to Fold ‘Em

In case you’re not familiar with folding, I’ll describe it as briefly as I can.

Here’s the signature of the foldLeft function from List[A], a list of items of type A:

def foldLeft[B](z: B)(f: (B, A) => B): B

Firstly, foldLeft is a curried function (so is foldRight). If you don’t know about currying, that’s ok; this function just takes its two parameters (z and f) in two sets of parentheses instead of one. Currying isn’t the important part anyway.

The first parameter, z, is of type B, which is to say it can be different from the list contents type. The second parameter, f, is a function that takes a B and an A (a list item) as parameters, and it returns a value of type B. So the purpose of function f is to take a value of type B, use a list item to modify that value and return it.

The foldLeft function goes through the whole List, from head to tail, and passes each value to f. For the first list item, that first parameter, z, is used as the first parameter to f. For the second list item, the result of the first call to f is used as the B type parameter.

For example, say we had a list of Ints 1, 2, and 3. We could call foldLeft(“X”)((b,a) => b + a). For the first item, 1, the function we define would add string “X” to Int 1, returning string “X1”. For the second list item, 2, the function would add string “X1” to Int 2, returning “X12”. And for the final list item, 3, the function would add “X12” to 3 and return “X123”.

Here are a few more examples.

list.foldLeft(0)((b,a) => b+a)
list.foldLeft(1)((b,a) => b*a)
list.foldLeft(List[Int]())((b,a) => a :: b)

The first line is super simple. It’s almost like the example I described above, but the z value is the number 0 instead of string “X”. This fold combines the elements of the list by addition instead of concatenation. So the fold returns the sum of all Ints in the list. Line 2 combines them through multiplication. Do you see why the z value is 1 in this case?

Line 3 is a little more complex. Can you guess what it does? It starts out with an empty list of Ints and adds each item to the accumulator (We call the b parameter of our function the accumulator because it accumulates data from each of our list items). Because it starts with the head and adds to the beginning of the accumulator list until it gets to the last item of the original list, it returns the original list in reverse order.
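
If you’d like to check those folds yourself, here’s a sketch you can paste into the interpreter. The list contents and the val names are my own choices for the example:

```scala
object FoldExamples {
  val list = List(1, 2, 3, 4)

  // "X" is the starting accumulator; each item is concatenated onto it in turn.
  val concatenated: String = list.foldLeft("X")((b, a) => b + a)

  // Start from 0 and add, or start from 1 and multiply.
  val sum: Int = list.foldLeft(0)((b, a) => b + a)
  val product: Int = list.foldLeft(1)((b, a) => b * a)

  // Prepending each item to the accumulator list reverses the order.
  val reversed: List[Int] = list.foldLeft(List[Int]())((b, a) => a :: b)
}
```

With this list, concatenated is "X1234", sum is 10, product is 24, and reversed is List(4, 3, 2, 1).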

The foldRight function works in much the same way as foldLeft. Can you guess the difference? You got it. It starts at the end of the list and works its way up to the head.
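
One way to see the difference is to fold with a function that just records how the calls nest. This is a sketch with names I made up, using string interpolation to draw the picture:

```scala
object FoldDirectionDemo {
  val list = List(1, 2, 3)

  // foldLeft: the accumulator b starts on the left and absorbs items head first.
  val left: String = list.foldLeft("z")((b, a) => s"($b+$a)")

  // foldRight: the accumulator b starts on the right and absorbs the last item first.
  val right: String = list.foldRight("z")((a, b) => s"($a+$b)")
}
```

Here left comes out as "(((z+1)+2)+3)" and right as "(1+(2+(3+z)))". Notice that the two folds also take their function arguments in opposite order.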

Folds can be used for MUCH more than I’ve shown here. With folds, you can solve lots of different problems with a standard construct. You should read up on them if you’re just starting out in functional programming.

All That Glitters Is Not Fold

Now for the moment you’ve been waiting for. Fold’s dirty little secret! The below is taken from a scala interpreter session.

scala> var shortList = 1 to 10 toList
shortList: List[Int] = List(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)

scala> var longList = 1 to 325000 toList
longList: List[Int] = List(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 
15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, ...

scala> shortList.foldLeft("")((x,y) => "X")
res1: java.lang.String = X

scala> shortList.foldRight("")((x,y) => "X")
res2: java.lang.String = X

scala> longList.foldLeft("")((x,y) => "X")
res3: java.lang.String = X

scala> longList.foldRight("")((x,y) => "X")
java.lang.StackOverflowError
        at scala.List.foldRight(List.scala:1079)
        at scala.List.foldRight(List.scala:1081)
        at scala.List.foldRight(List.scala:1081)
        at scala.List.foldRight(List.scala:1081)
        at scala.List.foldRight(List.scala:1081)
        at scala.List.foldRight(List.scala:1081)
        at scala.List.foldRight(List.scala:1081)
        at scala.List.foldRight(List.scala:1081)
        at scala.List.foldRig...

We created two lists: shortList with 10 items, and longList with 325,000 items. Then we perform a trivial foldLeft and foldRight on shortList. It’s trivial because the passed-in function always returns the string “X”; it doesn’t even use the list data.

Then we do a foldLeft on longList. This goes off without a hitch. Finally we try to do a foldRight, the same foldRight that succeeded on the shorter list, and it fails! The foldLeft worked. Why didn’t the foldRight work? It’s a perfectly reasonable call against a perfectly reasonable List. Something funny is going on here.

The error message says there was a stack overflow, and the stack trace shows a long list of calls at List.scala:1081. If you’ve read my post about tail-recursion, then you probably suspect that some recursive code is to blame.

Let’s look into List.scala, maybe the single most important Scala source file.

Fool’s Fold

Without further ado, here’s the code for foldLeft and foldRight from List.scala:

override def foldLeft[B](z: B)(f: (B, A) => B): B = {
  var acc = z
  var these = this
  while (!these.isEmpty) {
    acc = f(acc, these.head)
    these = these.tail
  }
  acc
}

override def foldRight[B](z: B)(f: (A, B) => B): B = this match {
  case Nil => z
  case x :: xs => f(x, xs.foldRight(z)(f))
}

Wow! Those two definitions are very different!

The foldLeft function is the one that worked for short and long lists. Can you see why? It isn’t head-recursive. In fact, it isn’t recursive at all. It is implemented as a while loop. On each iteration, the next list item is passed to the function f and the accumulator (called acc) is updated. When there are no more list items, the accumulator is returned. No recursion means no stack overflows.

The foldRight function is implemented in a totally different way. If the list is empty, the z parameter is returned. Otherwise, a recursive call is made on the tail (the whole list minus the first item) of this list, and that result is passed to the function f. Study the foldRight definition. Do you understand how it works? It’s an elegant recursive solution, and the code really is quite pretty, but it’s not tail recursive so it fails for large lists.

Why didn’t Mr Odersky just write foldRight using a while loop, too? Then this problem wouldn’t exist, right? The reason is that Scala’s List is implemented as a singly-linked list. Each list element has access to the next item in the list, but not to the previous item. You can only traverse a list in one direction! This works fine for foldLeft, which goes from head to tail, but foldRight has to start at the end of the list and work its way forward to the head. If foldRight uses recursion, it must recurse all the way to the end and then use the results of those recursive calls as the accumulator passed into function f.

See? The results of the recursive call must be used for further calculation, so the recursive call can’t be the last thing that happens, so it can’t be written as a tail-recursive function. If you don’t know what I’m talking about, read my introduction to tail-recursion.

Out With The Fold, In With The New

So is that it for foldRight? Is it hopeless? I say no!

There is a way to get the same result as foldRight, but using foldLeft. Can you guess what it is? Here’s how:

list.foldRight("X")((a,b) => a + b)
list.reverse.foldLeft("X")((b,a) => a + b)

These two lines are equivalent! They give the same result no matter what’s in list. Since foldRight processes list elements from last to first, that’s the same as processing the reversed list from first to last.
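
Here’s a quick sketch to back that up. The helper name foldRightViaReverse is mine, not something from the Scala library, and I convert each Int with toString so the concatenation is explicit:

```scala
object ReverseFoldDemo {
  // foldRight expressed as a reverse followed by a foldLeft.
  // Note that the argument order of f is swapped inside the lambda.
  def foldRightViaReverse[A, B](xs: List[A])(z: B)(f: (A, B) => B): B =
    xs.reverse.foldLeft(z)((b, a) => f(a, b))

  val small = List(1, 2, 3)
  val viaFoldRight: String = small.foldRight("X")((a, b) => a.toString + b)
  val viaReverse: String = foldRightViaReverse(small)("X")((a, b) => a.toString + b)

  // The reverse-then-foldLeft version does not recurse once per element,
  // so a long list is no problem.
  val longList: List[Int] = (1 to 325000).toList
  val longSum: Long = foldRightViaReverse(longList)(0L)((a, b) => a + b)
}
```

Both viaFoldRight and viaReverse come out as "123X", and the 325,000-item sum finishes without touching the stack limit.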

Here are three possible implementations of foldRight that could replace the current one.

def foldRight[B](z: B)(f: (A, B) => B): B = 
  reverse.foldLeft(z)((b,a) => f(a,b))

def foldRight[B](z: B)(f: (A, B) => B): B =
  if (length > 50) reverse.foldLeft(z)((b,a) => f(a,b))
  else             originalFoldRight(z)(f)

def foldRight[B](z: B)(f: (A, B) => B): B =
  try {
    originalFoldRight(z)(f)
  } catch {
    case e1: StackOverflowError => reverse.foldLeft(z)((b,a) => f(a,b))
  }

The first one simply replaces the original recursive logic with the equivalent call to reverse and foldLeft. Why wasn’t foldRight implemented this way to begin with? It may be, in part, that the authors thought the extra overhead of reversing the list was unwarranted. To me, it doesn’t seem that bad. The original foldRight and foldLeft functions are O(n), meaning they run in an amount of time roughly proportional to the number of items in the list. If you look at the source for the reverse function, you’ll see it’s also O(n). So running reverse followed by foldLeft is O(n).

The second implementation is a compromise. It uses the original recursive version of foldRight (referred to as originalFoldRight in the above code) only when the list is shorter than 50 elements. The reverse.foldLeft is used for lists of 50 elements or longer. 50 is just an arbitrary number, just a guess at a sensible limit on the number of recursive calls to allow.

The third implementation tries the original foldRight logic first and if the call stack overflows then it uses reverse.foldLeft. This solution is, of course, completely ridiculous, but even this would be better than a foldRight which sometimes crashes your program.

That’s All, Folds!

As I pointed out before, the reverse.foldLeft implementation of foldRight is O(n), same as the original recursive version. The original foldRight may work just fine when your Scala application is young and working with small data sets. Over time more customers are added, more products are created, more orders are placed, and then one day, *POOF*, a runtime error! It’s a ticking time-bomb.

As you may well guess, I would like to see the reverse.foldLeft logic used instead of the recursive version. That would prevent the stack overflow errors. But I would settle for just deprecating foldRight. It would be better to eliminate foldRight and force the coder to work around it than to leave it in its current state. In fact, I don’t think any head-recursive functions belong in the List class.

Do any readers have any insight into why foldRight is coded the way it is?

I was playing a game on my iPhone called Scramble the other day. It’s a great game. You are presented with a 4×4 grid of letters, and your job is to find words by chaining together adjacent letters. It bears a passing similarity to Boggle. I was playing online and I noticed that many other players play much better than I do. Well, a software developer doesn’t take a thing like this lying down. I decided it was time to write a Scramble (or Boggle) solver!

And since I’m trying to pick up the Scala language whenever I have an opportunity, I wrote the whole thing in Scala. I hope it will be, for the reader, an interesting study of a complete and useful (though simple) example Scala program.

First, I had to decide on a strategy. If we try every path on the game board, that’s inefficient. For example, if we try a path XQ then we can stop right there. XQ is not a word and no word starts with the letters XQ. This suggests using some sort of spell-checker-like logic. So let’s first create a dictionary data structure that allows us to look up words and eliminate dead ends. What we need is a tree. Here’s the data structure I came up with.

import scala.collection.mutable.HashMap

class LetterTree {
    private val nodes: HashMap[Char,LetterTree] = new HashMap[Char,LetterTree]
    var terminal: Boolean = false
    def addWord(word: String): Unit = addWord(word.toList)
    def addWord(word: List[Char]): Unit = word match {
        case Nil          => terminal = true
        case head :: tail => nodes.getOrElseUpdate(head, new LetterTree).addWord(tail)
    }
    def getSubTree(letter: Char): Option[LetterTree] =
        if (nodes.contains(letter)) Some(nodes(letter)) else None
}

It’s a mutable class. I toyed with some ideas for an immutable class, but it just complicated things more than I wanted to cope with. Our dictionary will be a tree in which LetterTree is the node class. You can see that a LetterTree has two data members: a hashmap of child LetterTrees indexed by Char, and a flag called terminal. So a LetterTree is a tree node whose child nodes are indexed by letter, and which can be marked (via the terminal flag) as ending a word. This is a pretty efficient way to store words.

There are two addWord methods for populating the tree. One takes a list of Chars. The other takes a String and is included for convenience. It converts its String parameter to a list of Chars and calls the other addWord method. If the addWord method is passed an empty list (Nil), then it has reached the end of a word and sets the terminal flag. Otherwise, it takes the first Char in the list and looks up (or else creates) the LetterTree mapped to that Char. The method then adds the remainder of the Char list to that mapped LetterTree.

Finally, there’s the getSubTree method. This will be useful when we start using the tree to look up potential word matches.

So imagine that we create a LetterTree and add the following words: item, its, it. We get a structure like this:

Figure 1

Each of the boxes represents a LetterTree instance. Each arrow, labeled with a letter, represents an entry in the ‘nodes’ HashMap. The top level LetterTree does not itself represent a letter. It’s just a starting point. Its ‘nodes’ member has just one mapping: letter ‘i’ maps to a second LetterTree. That second LetterTree doesn’t have its ‘terminal’ flag set, so we haven’t made a word yet. Its ‘nodes’ map has a single entry mapping ‘t’ to a third LetterTree. This third LetterTree is a terminal node, so we know that the sequence ‘it’ forms a word. The third LetterTree’s ‘nodes’ map has entries for ‘s’ and ‘e’. The LetterTree mapped to ‘s’ is terminal, and the one mapped to ‘e’ is not. The ‘e’ node has a child node ‘m’ that is a terminal.

This is a pretty efficient way to store words. We got to use the sequence ‘it’ 3 times! Also, note that there is exactly one terminal node for each word. If we add a word, we will either append a new leaf node which will be terminal, or we will make one of the existing nodes terminal.
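
To make the picture concrete, here’s a sketch that builds that same three-word tree and pokes at it. It repeats the LetterTree class so the example stands on its own (with getSubTree shortened to nodes.get, which does the same lookup). The lookup, isWord, and isPrefix helpers are hypothetical additions of mine, not part of the solver:

```scala
import scala.collection.mutable

class LetterTree {
  private val nodes = mutable.HashMap.empty[Char, LetterTree]
  var terminal: Boolean = false
  def addWord(word: String): Unit = addWord(word.toList)
  def addWord(word: List[Char]): Unit = word match {
    case Nil          => terminal = true
    case head :: tail => nodes.getOrElseUpdate(head, new LetterTree).addWord(tail)
  }
  def getSubTree(letter: Char): Option[LetterTree] = nodes.get(letter)
}

object LetterTreeDemo {
  // Hypothetical helper: follow the tree one letter at a time.
  // The result becomes None as soon as a letter has no matching child node.
  def lookup(tree: LetterTree, letters: String): Option[LetterTree] =
    letters.foldLeft(Option(tree))((node, ch) => node.flatMap(_.getSubTree(ch)))

  // A word is a path that ends on a terminal node.
  def isWord(tree: LetterTree, s: String): Boolean = lookup(tree, s).exists(_.terminal)

  // A prefix is any path that exists in the tree at all.
  def isPrefix(tree: LetterTree, s: String): Boolean = lookup(tree, s).isDefined

  val tree = new LetterTree
  for (w <- List("item", "its", "it")) tree.addWord(w)
}
```

With this tree, ‘it’ and ‘item’ are words, ‘ite’ is a prefix but not a word, and ‘x’ is a dead end right at the top. That dead-end check is what will let the solver give up early on hopeless paths.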

Now, all we have to do is create a LetterTree and call the addWord method for every word we can think of. That could be annoying. Let’s create an improved LetterTree that will read a list of words from a file.

import java.io.File
import scala.io.Source

class FileLetterTree(path: String) extends LetterTree {
    val file = new File(path)
    for (line <- Source.fromFile(file).getLines) addWord(line.trim)
}

Can you believe how easy that was?! We just extend LetterTree, take a file path as a constructor parameter, use the scala.io.Source class to get all lines, and add each line as a word. Now just find a text file containing all English words. You should be able to google it.

You might want to ensure the file contains only words 3 letters or longer (as required by the rules of the game), only lower case, and only the 26 English letters. Since Scramble has no ‘Q’ piece (neither does Boggle) but only a ‘Qu’ piece, you might want to do a global replace on your word file, replacing all instances of ‘qu’ with ‘q’.
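
If you’d rather do that cleanup in Scala than in a text editor, something like the following sketch would work. WordFilter and normalize are names I made up, and I’m assuming the length rule is applied before ‘qu’ is collapsed to ‘q’:

```scala
object WordFilter {
  def normalize(words: Seq[String]): Seq[String] = {
    // Trim whitespace and force lower case.
    val lowered = words.map(_.trim.toLowerCase)
    // Keep only words of 3 or more letters made up entirely of a-z.
    val legal = lowered.filter(w => w.length >= 3 && w.forall(c => c >= 'a' && c <= 'z'))
    // The board has a single 'Qu' piece, which we store as a plain 'q'.
    legal.map(_.replace("qu", "q"))
  }
}
```

For example, normalize(Seq("Quit", "it", "don't", " item ")) keeps only "qit" and "item". Words with a ‘q’ that isn’t followed by ‘u’ are left alone by this sketch, so you may want to drop those separately.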

Ok, now we have a data structure for looking up words.  Now we need to create some code that represents the game board.  We’ll assume a 4-by-4 board.

class GameBoard(lettersStr: String) {
    private val ltrStr = lettersStr.toLowerCase()
    if (!ltrStr.matches("^[a-z]{16}$"))
        throw new Exception("Exactly 16 letters a-z are required.")

    override def toString: String =
        ltrStr.substring(0,4)  + "\n" + ltrStr.substring(4,8) + "\n" +
        ltrStr.substring(8,12) + "\n" + ltrStr.substring(12,16)

    case class Letter(letter: Char) {
        var neighbors = List[Letter]()
        def addNeighbor(nbr: Letter) = { neighbors = nbr :: neighbors }
        override def toString = letter.toString
    }

    val letters = new Array[Array[Letter]](4,4)
    for (idx <- 0 until ltrStr.length)
        letters(idx/4)(idx % 4) = Letter(ltrStr(idx))

    for ( idx <- 0 to 3; jdx <- 0 to 3; iOff <- -1 to 1; jOff <- -1 to 1;
          if (iOff != 0 || jOff != 0) &&
          idx + iOff >= 0 && idx + iOff < 4 &&
          jdx + jOff >= 0 && jdx + jOff < 4 )
        letters(idx)(jdx).addNeighbor(letters(idx + iOff)(jdx + jOff))
}

That’s a lot of new code, but it is actually pretty simple if we break it down:

The first three lines of the class body just ensure that we have good input.  The code converts the constructor parameter to lower case, and then confirms that it is composed of exactly 16 letters from a to z.

The next section is a toString method.  It just returns the 16-letter string in 4-letter chunks separated by newlines.

Next, there is an inner class called Letter.  It encapsulates one letter on the game board, and includes a way to keep track of neighboring Letters (think of it as a graph node).

The ‘letters’ val holds the game board, a 4-by-4 array of Letters.  The for-loop that follows initializes each of the 16 Letter objects using the corresponding letters from the 16-letter string.

All that’s left is to tell each Letter who his neighbors are.  This is done with the somewhat complex for-loop at the end of the class.  This structure loops through two indices, idx and jdx, as well as two offsets, iOff and jOff.  It’s like four nested loops combined into one.  It loops through idx and jdx values from 0 to 3, so it runs for each of the 16 Letters in the grid.  It also loops through iOff and jOff from -1 to 1, so it looks at each neighbor of each Letter.  See?  For each of the 16 Letters in the grid, it looks at each of the 9 Letters around that Letter (including the Letter itself).

The last section of the for-loop header (the ‘if’ guard) defines which combinations of idx, jdx, iOff, and jOff are valid and should be processed in the loop body.  You see why, right?  First of all, there’s no need for a Letter to add itself to its list of neighbors.  It’s against the game rules to use the same grid position twice in the same word, so any iterations in which iOff and jOff are both 0 are eliminated by the first condition of the guard.

Also, there are grid positions at the edges and corners.  Those positions will have fewer neighbors.   So for the Letter at position idx=0, jdx=0, offsets of iOff=-1 OR jOff=-1 make no sense.  There are no neighboring Letters in the -1 direction at that position.  These restrictions are made by the remaining four bounds checks in the guard.

Finally, if the indices and offsets pass the tests, the Letter at location (idx, jdx) is assigned a new neighbor, the Letter at (idx+iOff, jdx+jOff).
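
To convince yourself that the filtering does the right thing, here’s a standalone sketch that computes neighbor coordinates the same way, without the Letter class. NeighborDemo, Size, and neighbors are names I made up for the example:

```scala
object NeighborDemo {
  val Size = 4

  // All valid (row, column) positions adjacent to (idx, jdx) on a Size-by-Size grid.
  def neighbors(idx: Int, jdx: Int): Seq[(Int, Int)] =
    for {
      iOff <- -1 to 1
      jOff <- -1 to 1
      if iOff != 0 || jOff != 0                    // skip the position itself
      if idx + iOff >= 0 && idx + iOff < Size      // stay inside the rows
      if jdx + jOff >= 0 && jdx + jOff < Size      // stay inside the columns
    } yield (idx + iOff, jdx + jOff)
}
```

A corner such as (0, 0) ends up with 3 neighbors, an edge position such as (0, 1) with 5, and an interior position such as (1, 1) with all 8.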

Of course, this isn’t the only way to write a program like this.  We could have used a simple array of characters and put the neighbor-finding logic in the solver code, but I chose to divide up the responsibilities like this.

We’ve finished setting up the game board.  We have a dictionary of legal words.  Now we’re ready to play.  Let’s add a function called findWords to the GameBoard class.

def findWords(tree: LetterTree): List[String] = {
    def findWords(tree: LetterTree, letter: Letter, sofar: List[Letter]): List[String] = {
        tree.getSubTree(letter.letter) match {
          case Some(subTree) =>
            var words: List[String] = Nil
            if (subTree.terminal) words = (letter :: sofar).foldLeft("")((c,n) => n+c) :: words
            for (nextLetter <- letter.neighbors if !sofar.contains(nextLetter))
                words = findWords(subTree, nextLetter, letter :: sofar) ::: words
            words
          case None => Nil
        }
    }
    var words: List[String] = Nil
    for (idx <- 0 to 3; jdx <- 0 to 3)
        words = words ++ findWords(tree, letters(idx)(jdx), Nil)
    words
}

This is the heart of our little program. This is the code that does the real work. The first thing that happens in this function is that we define an inner function.  We’ll look at that in a second.  First, look at the code below the inner function.  We create an empty List of Strings, then for each Letter in the 4-by-4 grid we call the inner findWords function.  We pass in the dictionary tree (at the top tree level), the current Letter, and an empty list.  That’s the list of letters used so far.  At this point, no letters have been used yet, so it’s empty (Nil).

The result of the inner function call is a list of all valid words that start with the given Letter.  That list is added to the list of all valid words.  We do this for each of the 16 Letters.

Now, for that inner function.  Let’s first examine the parameter list. First is tree of type LetterTree. This is the dictionary class we built earlier. On the first call, we pass in the whole dictionary, the top level tree. As we search for words, though, we will pass sub-trees, sub-sub-trees and so forth. The second parameter is a letter of type Letter. That’s just the current letter that we’re evaluating.

The last parameter is called sofar and it has type List[Letter]. Why is it called sofar? Because it’s the ordered list of connected Letters from the game board that we know begin words in the dictionary. The sofar list might contain the letters (c, o, m, p) because this is a prefix to words like compute, computer, computing, comparison, etc. We would never be passed a sofar list containing (c, o, m, p, x) because although this sequence could show up on the game board, this is not a prefix to any word in the dictionary. This parameter will grow longer with each recursive call to the inner function. This is why on the first call to the inner function we pass the empty list, Nil, as the sofar parameter.

We first look up the sub-tree beginning with the letter value of the Letter parameter.  If we find that there is no sub-tree for this letter (the result of getSubTree matches None) then we know that there are no words beginning with the letter. It’s a dead end, so we return an empty list and we do not recurse any further.

If we do find a sub-tree, though, that means that there are words that begin with the sofar list followed by this letter. Next, we check subTree.terminal to see if the node for the current letter is a terminal node in our dictionary. If it is, then we have a legal word. We add the current letter to the front of the sofar list (which holds the most recent letter first), fold the list into a string in reading order, and put it in the local words list.

In the next line, we loop through all of the current Letter’s neighbors, excluding any that are already in the sofar list. For each unused neighbor, we make a recursive call. This time, the parameters are the dictionary sub-tree, the neighbor letter generated by the for loop, and the sofar list with the current Letter added to the front. The list of words returned by each recursive call is combined with the local word list. After all the neighboring letters have been checked, the word list is returned from the function.
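
One detail that’s easy to miss: because each letter is added to the front of sofar with ‘::’, the list holds the path in reverse, and the foldLeft puts it back in reading order. Here’s a tiny sketch using plain Chars instead of Letter objects (the names are made up for illustration):

```scala
object SofarDemo {
  // Suppose we walked i -> t -> e and are now standing on 'm'.
  val sofar: List[Char] = List('e', 't', 'i')   // most recent letter first
  val current: Char = 'm'

  // Same shape as the fold in findWords: each item n is put in front of
  // the string built so far, which undoes the reversal.
  val word: String = (current :: sofar).foldLeft("")((c, n) => n.toString + c)
}
```

The fold visits ‘m’, ‘e’, ‘t’, ‘i’ in that order, building "m", "em", "tem", and finally "item".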

This is basically a graph problem. The game board is an undirected graph. The dictionary is a tree, which is a type of graph. We are taking circuit-free paths along the game board graph and mapping them to paths on the tree, accumulating a list of matching paths for which the end is marked terminal on the tree. There are Java (and maybe Scala) libraries for dealing with graphs, and when I get a chance I’d like to see if I can get a more tidy implementation using one of these libraries.

Now, we’ll just pull all this together in a neat little package:

import scala.collection.immutable.HashSet

class PuzzleSolver(dictionaryPath: String) {
    val tree = new FileLetterTree(dictionaryPath)
    def solve(letters: String) = {
        val board = new GameBoard(letters)
        val wordSet = new HashSet[String]() ++ board.findWords(tree)
        val sortedWords = wordSet.toList.sort{ (a,b) =>
            a.length > b.length || (a.length == b.length && a > b)
        }
        println(sortedWords)
    }
}

Now we can create an instance of the PuzzleSolver class for a dictionary file that we specify. Then we can call the solve function for game board configurations. This class finds all the legal words contained in the game board, sorts them longest first (and in reverse alphabetical order within each length), and prints them out.
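
Here’s the comparison on its own so you can see the ordering it produces. This is a sketch with made-up names. I’m using sortWith, which, as far as I know, is what newer Scala versions offer in place of List.sort:

```scala
object SortDemo {
  // Longer words come first; among words of equal length, the
  // alphabetically greater word comes first.
  def sortWords(words: Set[String]): List[String] =
    words.toList.sortWith { (a, b) =>
      a.length > b.length || (a.length == b.length && a > b)
    }
}
```

For example, sortWords(Set("vest", "tinge", "vent", "store")) gives List(tinge, store, vest, vent).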

Here’s a sample session in the Scala interpreter:

scala> val solver = new PuzzleSolver("./words.txt")
solver: PuzzleSolver = PuzzleSolver@1876e5d

scala> solver.solve("TestingTheSolver")
List(storeen, torsel, tinsel, tensor, seeing, nestor, inseer, 
ingest, verst, verso, torse, tinge, store, soree, soget, snite, 
sneer, rotse, rotge, rosel, reest, orsel, inset, insee, ingot, 
hinge, gorse, geest, vlei, vest, vent, vein, veer, veen, tore, 
togs, ting, tine, tien, teng, stog, sore, snee, sero, sent, 
seit, sego, seer, seen, seel, rose, rest, rees, reen, reel, 
ogee, neti, nest, neer, lest, lent, lens, lehi, lees, leer, 
iten, hint, hing, hest, hent, hein, heer, gore, goes, goer, 
gest, gent, gens, gein, eros, vei, vee, tor, tog, toe, tin, 
tie, ten, teg, sot, sog, soe, set, ser, sen, seg, see, rot, 
rog, roe, rev, ree, ose, ore, oes, oer, nit, net, nei, nee, 
lev, les, len, lei, leg, lee, ing, hit, hin, hie, hen, hei, 
got, gos, gor, get, ges, gen, gel, gee, ers, ens, ego, eer, 
eel)

scala>

It works! Here’s the complete source code:

import java.io.File  
import scala.io.Source  
import scala.collection.mutable.HashMap  
import scala.collection.immutable.HashSet  
  
class LetterTree {  
    private val nodes: HashMap[Char,LetterTree] = new HashMap[Char,LetterTree]  
    var terminal: Boolean = false  
    def addWord(word: String): Unit = addWord(word.toList)  
    def addWord(word: List[Char]): Unit = word match {  
        case Nil          => terminal = true  
        case head :: tail => nodes.getOrElseUpdate(head, new LetterTree).addWord(tail)  
    }  
    def getSubTree(letter: Char): Option[LetterTree] =  
        if (nodes.contains(letter)) Some(nodes(letter)) else None  
}  

class FileLetterTree(path: String) extends LetterTree {  
    val file = new File(path)  
    for (line <- Source.fromFile(file).getLines) addWord(line.trim)  
}  


class GameBoard(lettersStr: String) {  
    private val ltrStr = lettersStr.toLowerCase()  
    if (!ltrStr.matches("^[a-z]{16}$"))
        throw new Exception("Exactly 16 letters a-z are required.")
  
    override def toString: String =  
        ltrStr.substring(0,4)  + "\n" + ltrStr.substring(4,8) + "\n" +  
        ltrStr.substring(8,12) + "\n" + ltrStr.substring(12,16)  
  
    case class Letter(letter: Char) {  
        var neighbors = List[Letter]()  
        def addNeighbor(nbr: Letter) = { neighbors = nbr :: neighbors }  
        override def toString = letter.toString  
    }  
  
    val letters = new Array[Array[Letter]](4,4)  
    for (idx <- 0 until ltrStr.length)  
        letters(idx/4)(idx % 4) = Letter(ltrStr(idx))  
  
    for ( idx <- 0 to 3; jdx <- 0 to 3; iOff <- -1 to 1; jOff <- -1 to 1;  
          if (iOff != 0 || jOff != 0) &&  
          idx + iOff >= 0 && idx + iOff < 4 &&  
          jdx + jOff >= 0 && jdx + jOff < 4 )  
        letters(idx)(jdx).addNeighbor(letters(idx + iOff)(jdx + jOff))  

  def findWords(tree: LetterTree): List[String] = {  
      def findWords(tree: LetterTree, letter: Letter, sofar: List[Letter]): List[String] = {  
          tree.getSubTree(letter.letter) match {  
            case Some(subTree) =>  
              var words: List[String] = Nil  
              if (subTree.terminal) words = (letter :: sofar).foldLeft("")((c,n) => n+c) :: words  
              for (nextLetter <- letter.neighbors if !sofar.contains(nextLetter))
                words = findWords(subTree, nextLetter, letter :: sofar) ::: words
              words  
            case None => Nil  
          }  
      }  
      var words: List[String] = Nil  
      for (idx <- 0 to 3; jdx <- 0 to 3)  
          words = words ++ findWords(tree, letters(idx)(jdx), Nil)  
      words  
  }  
}  

class PuzzleSolver(dictionaryPath: String) {  
    val tree = new FileLetterTree(dictionaryPath)  
    def solve(letters: String) = {  
        val board = new GameBoard(letters)  
        val wordSet = new HashSet[String]() ++ board.findWords(tree)  
        val sortedWords = wordSet.toList.sort{ (a,b) =>  
            a.length > b.length || (a.length == b.length && a > b)  
        }  
        println(sortedWords)  
    }  
}

One last thing: To be clear, no I don’t actually use this to cheat at online games. Just knowing that I could is satisfying enough for me.

In my last post (Tail-Recursion Basics In Scala), I went on and on about how scala, as a functional programming language, was written to accommodate recursion. When you need to do a task over and over the procedural way is to iterate, and the functional way is to recurse. That’s what I thought. But I was browsing the source for the scala List class and I noticed something weird. It’s chock full of iteration and procedural code!

Well, I don’t have to tell you I was a little disappointed. I never claimed to have unreservedly swallowed the FP hype, but I figured the authors of the scala.List must surely practice what they preach. I’m puzzled. To show you what I mean, here are a few methods written in a procedural style just as they appear in scala.List, alongside my attempt to write them in a functional, tail-recursive fashion. This isn’t an exhaustive list by any means. There are tons of methods in class scala.List like these.

Mark the following! I am not trying to correct the authors of scala.List or point their code out as inferior. On the contrary. They know what they are doing. I do not. They implemented a general-use functional/OO programming language. I have not. I am merely wondering aloud why scala.List was implemented the way it was.

I’m using http://lampsvn.epfl.ch/trac/scala/browser/scala/tags/R_2_7_2_RC2/src/library/scala/List.scala as my source.

List.length

I used a length function as an example in my last post on recursion. The length method of List returns a count of the items in the List. First, here’s the actual code for the length method:

sealed abstract class List[+A] extends Seq[A] with Product {
...
  def length: Int = {
    var these = this
    var len = 0
    while (!these.isEmpty) {
      len += 1
      these = these.tail
    }
    len
  }
...
}

See why I called this code procedural? Those vars and that while loop. Vars are used for variables whose values change. There’s nothing wrong with it, but it’s a feature of procedural code. Iterative loops are also more procedural than functional. You could argue that the while is not a procedural loop, but a functional closure. But I know iteration when I see it. Why wasn’t it written using recursion? Here’s how it could be implemented in a more functional style.

object List {
...
  def length(list: List[_], len: Int): Int = {
    if (list.isEmpty) len
    else length(list.tail, len + 1)
  }
...
}

sealed abstract class List[+A] extends Seq[A] with Product {
...
  def length: Int = List.length(this, 0)
...
}

There is a List class as well as a List object. I put the recursive function in the List helper object because I couldn’t get the tail recursion optimization when I implemented it as a recursive member function within class List.

List.indices

The indices method returns a List of the zero-based indices of a List. If a list has 4 items, its indices method returns List(0, 1, 2, 3). Once again, we’ll first look at the actual code for the indices method:

sealed abstract class List[+A] extends Seq[A] with Product {
...
  def indices: List[Int] = {
    val b = new ListBuffer[Int]
    var i = 0
    var these = this
    while (!these.isEmpty) {
      b += i
      i += 1
      these = these.tail
    }
    b.toList
  }
...
}

As before, this code steps through the list in a while loop. Here’s how it could be done using recursion:

object List {
...
  def indices(list: List[_], idxList: List[Int], curIdx: Int): List[Int] = {
    if (list.isEmpty) idxList
    else List.indices(list.tail, idxList + curIdx, curIdx + 1)
  }
...
}

sealed abstract class List[+A] extends Seq[A] with Product {
...
  def indices: List[Int] = List.indices(this, Nil, 0)
...
}
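One wrinkle in the recursive version above: idxList + curIdx adds to the end of an immutable list, which as far as I can tell means copying the list on every step, and current Scala versions no longer offer + on List at all. Here is a minimal sketch of an alternative that prepends and reverses once at the end. It is written as a standalone function rather than as part of scala.List, and the names indicesOf and loop are made up for illustration.

```scala
def indicesOf(list: List[_]): List[Int] = {
  // acc collects the indices in reverse order, because prepending with :: is cheap
  def loop(rest: List[_], acc: List[Int], curIdx: Int): List[Int] =
    if (rest.isEmpty) acc.reverse
    else loop(rest.tail, curIdx :: acc, curIdx + 1)
  loop(list, Nil, 0)
}

println(indicesOf(List("a", "b", "c", "d"))) // List(0, 1, 2, 3)
```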

List.last

This is a very simple method. It just returns the last item in the List. Again, here’s the actual code for the List.last method:

sealed abstract class List[+A] extends Seq[A] with Product {
...
  override def last: A =
    if (isEmpty) throw new Predef.NoSuchElementException("Nil.last")
    else {
      var cur = this
      var next = this.tail
      while (!next.isEmpty) {
        cur = next
        next = next.tail
      }
      cur.head
    }
...
}

Here’s what it might look like in a functional style.

object List {
...
  def last[A](list: List[A]): A =
    if (list.isEmpty) throw new Predef.NoSuchElementException("Nil.last")
    else {
      if (list.tail.isEmpty) list.head
      else last(list.tail)
    }
...
}

sealed abstract class List[+A] extends Seq[A] with Product {
...
  override def last: A = List.last(this)
...
}

List.foreach

Ok, last one. This method runs a function once for each item in the List, passing the item as a parameter each time. Here’s the real code:

sealed abstract class List[+A] extends Seq[A] with Product {
...
  final override def foreach(f: A => Unit) {
    var these = this
    while (!these.isEmpty) {
      f(these.head)
      these = these.tail
    }
  }
...
}

And, again, here’s the code rendered in a more functional style.

object List {
...
  def foreach[A](list: List[A], f: A => Unit) {
    if (!list.isEmpty) {
      f(list.head)
      foreach(list.tail, f)
    }
  }
...
}

sealed abstract class List[+A] extends Seq[A] with Product {
...
  final override def foreach(f: A => Unit) = List.foreach(this, f)
...
}

So What?

So what? What does it matter that while loops are used instead of tail recursion? Well, it doesn’t really particularly matter. Not to me, anyway. But it’s surprising that proponents, nay, authors of a functional language would write (perfectly good) procedural code like this.

I don’t know why. Maybe it’s for the benefit of the Java coders whom scala’s creators hope will become early adopters. Maybe the functions in the List object are frowned upon for some design reason. More likely it’s some technical issue that my inexperienced eye can’t discern. Is there some reason my examples would not work? If you know, please share.

Update: By the way, Debasish Ghosh wrote a related article today. It includes a comparison of the List implementation in Erlang and Scala.

Recursion 101

We all know what recursion is, right?  A function calls itself.  Or function A calls function B which calls function A.  Or A calls B, which calls C, which calls A, etc.  By far the most common situation is that in which a function calls itself.

You don’t see a lot of recursive code in Java.  There are a couple reasons for this.  First, recursion is hard.  It’s not intuitive.  With iteration (the main alternative to recursion) you see the big picture.  You can look at the whole loop and it’s easy to understand even for the beginner.  With recursion, you see one layer and you have to imagine what happens when those layers stack up.  Like anything else, if you practice recursion then it becomes easier, but compared to iteration, recursion is hard to learn.

Second, Java is not designed to accommodate recursion.  It’s designed to accommodate iteration.  Java gives you for loops, for-each loops, while loops, do loops, arrays, Iterators, ResultSets, etc.  These constructs are all about iteration.  Plus, recursion in Java has an Achilles’ heel: the call stack.

Generally, when you call a function in any language a new level is added to the call stack.  The call stack, we all know, is what keeps track of local variables, what function has called what, etc.  It’s no different in Java.  But the stack has a finite and limited size.  Recursion is fine if you know for certain that you’ll never go more than a few dozen levels deep.  But if the recursion goes too deep, you run out of stack space and your program goes kaput.  That doesn’t happen with iteration, so it’s safer to just avoid recursion altogether.

Recursion In Scala

Scala, however, being a functional language is very much geared toward recursion rather than iteration.  So how does it overcome the limitations of the call stack?  Let’s look at an example recursive function in scala.

def listLength1(list: List[_]): Int = {
  if (list == Nil) 0
  else 1 + listLength1(list.tail)
}

var list1 = List(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
var list2 = List(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
1 to 15 foreach( x => list2 = list2 ++ list2 )

println( listLength1( list1 ) )
println( listLength1( list2 ) )

Function listLength1 recursively counts the number of items in a list.  Try running this in the interpreter.  It works fine for list1, the short list, but the longer list exhausts the stack.  Recursion is a functional language’s bread and butter, but we see here that even scala, a functional language, is subject to call stack limitations.

Don’t give up on recursion yet, though.  Scala has a very important optimization that will allow you to recurse without limit provided you use the right kind of recursion.

Head Recursion And Tail Recursion

There are two basic kinds of recursion: head recursion and tail recursion.  In head recursion, a function makes its recursive call and then performs some more calculations, maybe using the result of the recursive call, for example.  In a tail recursive function, all calculations happen first and the recursive call is the last thing that happens.

The importance of this distinction doesn’t jump out at you, but it’s extremely important!  Imagine a tail recursive function.  It runs.  It completes all its computation.  As its very last action, it is ready to make its recursive call.  What, at this point, is the use of the stack frame?  None at all.  We don’t need our local variables anymore because we’re done with all computations.  We don’t need to know which function we’re in because we’re just going to re-enter the very same function.  Scala, in the case of tail recursion, can eliminate the creation of a new stack frame and just re-use the current stack frame.  The stack never gets any deeper, no matter how many times the recursive call is made.  That’s the voodoo that makes tail recursion special in scala.

Incidentally, some languages achieve a similar end by converting tail recursion into iteration rather than by manipulating the stack.

This won’t work with head recursion.  Do you see why?  Imagine a head recursive function.  First it does some work, then it makes its recursive call, then it does a little more work.  We can’t just re-use the current stack frame when we make that recursive call.  We’re going to NEED that stack frame info after the recursive call completes.  It has our local variables, including the result (if any) returned by the recursive call.

Here’s a question for you.  Is the example function listLength1 head recursive or tail recursive?  Well, what does it do?  (A) It checks whether its parameter is Nil.  (B) If so, it returns 0 since Nil has 0 length.  (C) If not, it returns 1 plus the result of a recursive call.  The recursive call is the last thing we typed before ending the function.  That’s tail recursion, right?  Wrong.  The recursive call is made, and THEN 1 is added to the result, and this sum is returned.  This is actually head recursion (or middle recursion, if you like) because the recursive call is not the very last thing that happens.
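To make the pending work visible, here is a small sketch that traces the evaluation in comments. The function is the same shape as listLength1; the name headLength and the three-item list are just for illustration.

```scala
// Same shape as listLength1: the addition happens
// after the recursive call returns.
def headLength(list: List[_]): Int =
  if (list == Nil) 0
  else 1 + headLength(list.tail)

// Evaluation of headLength(List("a", "b", "c")):
//   1 + headLength(List("b", "c"))
//   1 + (1 + headLength(List("c")))
//   1 + (1 + (1 + headLength(Nil)))
//   1 + (1 + (1 + 0))
//   3
// Each "1 + ..." is an addition still waiting on a stack frame.
println(headLength(List("a", "b", "c"))) // 3
```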

Tail Recursion Example

When you write a recursive function in scala, your aim is to encourage the compiler to make tail recursion optimizations.  Now let’s rewrite that function using tail recursion.

def listLength2(list: List[_]): Int = {
  def listLength2Helper(list: List[_], len: Int): Int = {
    if (list == Nil) len
    else listLength2Helper(list.tail, len + 1)
  }
  listLength2Helper(list, 0)
}

println( listLength2( list1 ) )
println( listLength2( list2 ) )

I wrote this as two functions (listLength2 and an internal helper function) to preserve the one-parameter interface used in the earlier example.  It would be great if we could specify a default value for a parameter.  We could then write this as one function, but I don’t know a way to do it.  Long story short: listLength2 just calls listLength2Helper which does the real work and is the recursive function.
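For readers on Scala 2.8 or later, default arguments are part of the language, so the two functions can be folded into one. This is a sketch that assumes such a version; listLength3 is a made-up name, and whether the compiler optimizes the recursion depends on where and how the function is defined.

```scala
// len defaults to 0, so callers can keep using the one-argument form
def listLength3(list: List[_], len: Int = 0): Int =
  if (list == Nil) len
  else listLength3(list.tail, len + 1)

println(listLength3(List(1, 2, 3))) // 3
```

The trade-off is that the accumulator becomes part of the public signature, so a caller could pass a starting value other than 0. The nested helper keeps that detail hidden.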

Is listLength2Helper head recursive or tail recursive?  When the recursive call is made are we really finished with the stack frame, allowing scala to optimize for tail recursion?  Just like in listLength1, this function first checks for a Nil list, but this time returns the len parameter rather than 0.  If list is non-Nil then we make the recursive call.  But there’s still an addition going on here, len + 1.  Does that ruin the tail recursion?  No.  That term is evaluated first.  Only after all the parameters have been evaluated does the recursive call happen.  This function does indeed qualify for tail recursion optimization.
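One way to picture the optimization: the optimized helper behaves much like a hand-written loop that reassigns its parameters and jumps back to the top. Here is a rough sketch of that equivalent. It is an illustration only, not the compiler's actual output, and listLengthLoop is a made-up name.

```scala
def listLengthLoop(list: List[_]): Int = {
  var rest: List[_] = list   // stands in for the helper's list parameter
  var len = 0                // stands in for the helper's len parameter
  while (rest != Nil) {
    // the "recursive call": update both parameters and go around again
    rest = rest.tail
    len = len + 1
  }
  len
}

println(listLengthLoop(List(1, 2, 3))) // 3
```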

By The Way…

The two examples in this post are very elementary and it’s easy to predict visually that tail recursion optimization will be applied.  But it’s not always so simple.  If you’re trying to write a complex recursive function and you require that it be optimized as tail recursive then you have a problem.  How do you verify that you have succeeded?  The only ways that I know of are (A) examine the resulting object code or (B) write a unit test to verify that the optimization was made.  The problem with this is that you open yourself up to bugs in your unit tests.  You have to be able to write a test that you know will fail if the function is not optimized.  I’m new to scala, so maybe there’s a better way to do this.  If there is, please educate me.

What I would like to see is an annotation that marks a function as requiring tail recursion optimization.  Such an annotation would trigger a compiler error if the compiler were unable to make the optimization.  The annotation I propose would work like this:

@TailRecursive
def listLength1(list: List[_]): Int = {
  if (list == Nil) 0
  else 1 + listLength1(list.tail)
}

def listLength2(list: List[_]): Int = {
  @TailRecursive
  def listLength2Helper(list: List[_], len: Int): Int = {
    if (list == Nil) len
    else listLength2Helper(list.tail, len + 1)
  }
  listLength2Helper(list, 0)
}

The compiler would succeed for listLength2Helper, but would report an error for listLength1, where it could not apply the requested tail recursion optimization.  This doesn’t let the developer off the hook on unit testing, nor does it relieve the developer from having to code carefully.  What it does is provide early, infallible verification that a crucial feature was implemented.  Why force the developer to examine the classfile or write a bug-vulnerable unit test when the compiler knows the answer and could just cough up the information at compile time?
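Scala 2.8 and later ship an annotation along these lines, scala.annotation.tailrec. Here is a sketch of how the proposal maps onto it, assuming a Scala version that includes the annotation.

```scala
import scala.annotation.tailrec

def listLength2(list: List[_]): Int = {
  // Compiles: the recursive call is in tail position
  @tailrec
  def listLength2Helper(list: List[_], len: Int): Int =
    if (list == Nil) len
    else listLength2Helper(list.tail, len + 1)
  listLength2Helper(list, 0)
}

// Annotating listLength1 the same way is rejected at compile time,
// because the addition happens after the recursive call returns:
//
// @tailrec
// def listLength1(list: List[_]): Int =
//   if (list == Nil) 0
//   else 1 + listLength1(list.tail)

println(listLength2(List.fill(100000)(0))) // 100000
```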
