In Java you don’t see a lot of linked lists, and if you do it’s almost always java.util.LinkedList. People never write their own lists. They don’t really need to, I suppose. The one from java.util is fine. Plenty of people are leading fulfilling software careers never having implemented their own linked list. But it’s kind of a shame. Knowing how your data structures work makes you a better programmer.

It’s even rarer for a person to implement their own linked list in Scala. Scala’s scala.List is one of the most used classes in the language, so it’s packed with functionality. It’s abstract and covariant; it has helper objects such as List and Nil, plus the little-known ‘::’ class; and it inherits from Product, Seq, Collection, Iterable, and PartialFunction. The machinery of List pulls in Array, ListBuffer, and more. It can be hard to take it all in.

So let’s build our own linked list. We’ll start out with something very basic and un-Scala-like. Then we’ll improve it gradually until we have something a little closer to scala.List. I encourage you to fire up your Scala interpreter and follow along.

Back to Basics

First, a short review of linked lists. A linked list is a chain of nodes, each referring to exactly one other node until you get to the end of the chain. You refer to the list by its first item and you follow the chain of references to reach the other nodes.

What are the requirements for our first try? Our node should be able to hold a piece of data, and refer to the next node. It should also be able to report its length, and provide a toString method so we can visualize the list. Here we go.

class MyList(val head: Any, val tail: MyList) {
  def isEmpty = (head == null && tail == null)
  def length: Int = if (isEmpty) 0 else 1 + tail.length
  override def toString: String = if (isEmpty) "" else head + " " + tail
}

The value ‘head’ holds the data for the node, ‘tail’ refers to the next element in the chain. The ‘isEmpty’ method is true if the head and tail are both null. The length and toString methods are both defined using a similar pattern: if (isEmpty) [base result] else [data for current node + result of same method on tail].

Here’s what it looks like when we use this class:

scala> var list = new MyList(null, null)
list: MyList =

scala> list.length
res0: Int = 0

scala> list.isEmpty
res1: Boolean = true

scala> list = new MyList("ABC", list)
list: MyList = ABC

scala> list.length
res3: Int = 1

scala> list.isEmpty
res4: Boolean = false

scala> list = new MyList("XYZ", list)
list: MyList = XYZ ABC

scala> list = new MyList("123", list)
list: MyList = 123 XYZ ABC

scala> list.tail.head
res7: Any = XYZ

Not bad. It gets the job done. But it has some problems. First is the use of ‘null’. Use of null references is sloppy and increases the odds of a null pointer exception so, ideally, we don’t want to see that. It has other problems, too. It’s too verbose. It’s not typesafe. But for now let’s concentrate on getting rid of the nulls.

No Nulls Is Good Nulls

How can we do it? We’re using the null as a special value, a marker to tell us when a node is at the end of a list. So we’ll just use something else as that marker instead. What can we use? We’ll create a special object for the empty list. It will be recognized as empty just based on its identity, not on null values. So let’s try it:

class MyList(val head: Any, val tail: MyList) {
  def isEmpty = false
  def length: Int = if (isEmpty) 0 else 1 + tail.length
  override def toString: String = if (isEmpty) "" else head + " " + tail
}

object MyListNil extends MyList("arbitrary value", null) {
  override def isEmpty = true
}

That’s better. (The observant reader will note the similarity of MyListNil to scala.List’s Nil object.) We got rid of the nulls in the isEmpty method, but we still have to put something in the head and tail parameters of the MyList constructor. We put an arbitrary non-null value in head, but what do we put for tail? Either null or a newly created MyList. And how can that MyList be instantiated? It also needs a tail. Vicious circle. So this solution still leaves us stuck with a null.

Earlier, the null was there to mark a special node. We factored out that usage. Now it’s there to allow us to create the MyListNil. How can we factor that out? MyListNil is required to call its parent’s constructor. What if it had no parent? Then it wouldn’t be a MyList anymore. What if it had an abstract parent? Now you’re talking. Let’s see what that would look like.

abstract class MyList {
  def head: Any
  def tail: MyList
  def isEmpty: Boolean
  def length: Int
}

class MyListImpl(val head: Any, val tail: MyList) extends MyList {
  def isEmpty = false
  def length: Int = 1 + tail.length
  override def toString: String = head + " " + tail
}

object MyListNil extends MyList {
  def head: Any = throw new Exception("head of empty list")
  def tail: MyList = throw new Exception("tail of empty list")
  def isEmpty = true
  def length = 0
  override def toString =  ""
}

It’s a little more code, but much neater. There are no nulls anywhere. Here’s how it looks when we use this new MyList:

scala> var list: MyList = MyListNil
list: MyList =

scala> list = new MyListImpl("ABC", list)
list: MyList = ABC

scala> list = new MyListImpl("XYZ", list)
list: MyList = XYZ ABC

scala> list = new MyListImpl("123", list)
list: MyList = 123 XYZ ABC

scala> list.length
res3: Int = 3

scala> list.tail.head
res4: Any = XYZ

scala> list.tail.tail.tail.head
java.lang.Exception: head of empty list
        at ...

Pretty neat. The equivalent of MyListImpl in Scala’s real List implementation is a class called ‘::’. By the way, it has that funny name because it reads nicely in pattern matching code such as case head :: tail. Sometimes ‘::’ is referred to as cons. With nulls finally eliminated, we can concentrate on other issues.
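To see why the name reads well, here is a minimal sketch of pattern matching on the standard scala.List (not MyList). The helper describe is hypothetical. Its patterns take a list apart with the same ‘::’ symbol that builds one.

```scala
// Hypothetical helper: each case deconstructs a List using the :: (cons) pattern.
def describe(xs: List[String]): String = xs match {
  case Nil          => "empty"
  case head :: Nil  => s"just $head"                      // a cons cell whose tail is empty
  case head :: tail => s"$head, then ${tail.length} more"
}

val built = "ABC" :: "XYZ" :: Nil   // built with the :: method
println(describe(built))            // taken apart with the :: pattern
```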

Brevity Is The Heart Of List

The thing that I notice at this point is that a lot of typing (on the keyboard) is required to use this list. We have to type out “list = new MyListImpl(…, list)” every time we add an item. We can improve this with a new method.

abstract class MyList {
  [...]
  def add(item: Any): MyList = new MyListImpl(item, this)
}

Now we have classes referring to each other. MyList creates new MyListImpls, and MyListImpl extends MyList. So you’ll need to put these classes in a .scala file and compile them (or enter them all at once with the interpreter’s :paste command) instead of typing them into the Scala interpreter one at a time. But, wow! Look how much easier it is to use MyList now:

scala> var list = MyListNil add("ABC") add("XYZ") add("123")
list: MyList = 123 XYZ ABC

scala> list.length
res1: Int = 3

So much easier! One thing I notice, though, is that the order of items in the code is the reverse of the order produced by toString. We can change our ‘add’ method so that it is right-associative instead of left-associative by using a method name that ends in ‘:’ (colon). Scala evaluates a :: b as b.::(a), calling the method on the right-hand operand. We’ll use ‘::’ as the method name since that’s what scala.List uses.
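The colon rule applies to any method name, not just ‘::’. In this sketch, Stack is a hypothetical class. Its ‘+:’ method ends in a colon, so it binds to the operand on its right, and the chained expression builds the same stack as the explicit method calls.

```scala
// Hypothetical class, only for illustrating how method names ending in ':' associate.
final class Stack(val items: List[String]) {
  // The name ends in ':', so `x +: s` is evaluated as `s.+:(x)`.
  def +:(item: String): Stack = new Stack(item :: items)
}

val empty = new Stack(Nil)
val s1 = "ABC" +: "XYZ" +: empty     // groups as "ABC" +: ("XYZ" +: empty)
val s2 = empty.+:("XYZ").+:("ABC")   // the same calls written out explicitly
```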

abstract class MyList {
  [...]
  def ::(item: Any): MyList = new MyListImpl(item, this)
}

scala> var list = "ABC" :: "XYZ" :: "123" :: MyListNil
list: MyList = ABC XYZ 123

Now we’re really getting somewhere. This is starting to look more like scala.List. One other thing that the standard list implementation gives you is a shortcut for initializing lists. It looks like “List(1, 2, 3, 4)”. Notice there’s no ‘new’ keyword. This is done using the scala.List helper object and its ‘apply’ method. Below is our own MyList helper object.

object MyList {
  def apply(items: Any*): MyList = {
    var list: MyList = MyListNil
    for (idx <- (0 until items.length).reverse)   // walk the indices from last to first
      list = items(idx) :: list
    list
  }
}

scala> var list = MyList("ABC", "XYZ", "123")
list: MyList = ABC XYZ 123

scala> list = "Cool" :: list
list: MyList = Cool ABC XYZ 123
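The "no new keyword" trick works for any class that has a companion object with an apply method. Here is a minimal sketch using a hypothetical Tags class. Writing Tags(…) is shorthand for Tags.apply(…).

```scala
// Hypothetical class, used only to show how a companion `apply` removes the need for `new`.
class Tags(val items: List[String])

object Tags {
  // `Tags("a", "b")` is rewritten by the compiler to `Tags.apply("a", "b")`.
  def apply(items: String*): Tags = new Tags(items.toList)
}

val t1 = Tags("scala", "list")
val t2 = Tags.apply("scala", "list")
```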

Better Type-Safe Than Sorry

Better. Our code looks much neater now when we use MyList. I’ll introduce just one more improvement to MyList. It still has a rather glaring problem. It provides no type information. It keeps all of its data using a reference to Any. If you don’t see why this is a problem, let’s see what happens when we want to get the length of some items in a MyList:

scala> var list = MyList("ABC", 12345, "WXYZ")
list: MyList = ABC 12345 WXYZ

scala> list.head.length
<console>:6: error: value length is not a member of Any
       list.head.length
                 ^

scala> list.head.asInstanceOf[String].length
res10: Int = 3

scala> list.tail.head.asInstanceOf[String].length
java.lang.ClassCastException: java.lang.Integer cannot be cast to java.lang.String
        at .<init>(<console>:6)

Ouch! First, when we try to call method ‘length’ on list.head Scala complains that list.head is a reference to Any. Any doesn’t have a length method. This isn’t a dynamically typed language like, say, Ruby. An object has to have the right type before we can start calling methods. What to do? You could implement a MyStringList where the head has type String. But then you’ll need a MyIntList, MyDoubleList, etc. What we need is a way to specify the type of data in the list when we create the MyList instance. What we need is a type parameter.

Here’s the complete MyList code using a type parameter, and a little demonstration code:

abstract class MyList[A] {
  def head: A
  def tail: MyList[A]
  def isEmpty: Boolean
  def length: Int
  def ::(item: A): MyList[A] = new MyListImpl[A](item, this)
}

class MyListImpl[A](val head: A, val tail: MyList[A]) extends MyList[A] {
  def isEmpty = false
  def length: Int = 1 + tail.length
  override def toString: String = head + " " + tail
}

object MyListNil extends MyList[Nothing] {
  def head: Nothing = throw new Exception("head of empty list")
  def tail: MyList[Nothing] = throw new Exception("tail of empty list")
  def isEmpty = true
  def length = 0
  override def toString =  ""
}

object MyList {
  def apply[A](items: A*): MyList[A] = {
    var list: MyList[A] = MyListNil.asInstanceOf[MyList[A]]
    for (idx <- (0 until items.length).reverse)   // walk the indices from last to first
      list = items(idx) :: list
    list
  }
}

scala> var list = MyList("ABC", "WXYZ", "123")
list: MyList[java.lang.String] = ABC WXYZ 123

scala> list.head.length
res0: Int = 3

scala> 3.14159 :: list
<console>:6: error: type mismatch;
 found   : Double
 required: java.lang.String
       3.14159 :: list
               ^

scala> var list = MyList("ABC", 123, 3.14159)
list: MyList[Any] = ABC 123 3.14159

Look at the first statement in the session. “MyList(…)” returns a MyList[String]; Scala infers the type argument from the parameters. The next statement, list.head.length, shows how much easier it is to use the list contents when the element type is known at compile time.

If you try to mix types, as in the last statement of the session, Scala determines the nearest common ancestor of the types (Any, in this case) and uses that. However, once the type parameter is fixed, as it is for list in “3.14159 :: list”, it won’t change when you try to add data of a different type. To make that statement work, we can make a small change to the ‘::’ method:

abstract class MyList[A] {
  [...]
  def ::[B >: A](item: B): MyList[B] = 
    new MyListImpl(item, this.asInstanceOf[MyList[B]])
}

scala> var list = MyList("ABC", "XYZ")
list: MyList[java.lang.String] = ABC XYZ

scala> 3.14159 :: list
res0: MyList[Any] = 3.14159 ABC XYZ

This says that ‘::’ takes a parameter of type B, which is either A or a superclass of A, and returns a MyList[B]. So if you have a MyList[String] and you call ‘::’ on it with a Double parameter, Scala figures out that although Double is not a superclass of String, String and Double are both descendants of Any, and it returns a MyList[Any].
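The version above still needs asInstanceOf casts because MyList is invariant. One way to remove them is to declare the type parameter covariant (+A), as scala.List does. This is a sketch that goes beyond the original walk-through. With covariance, MyList[Nothing] is a subtype of every MyList[A], and a MyList[A] can stand in for a MyList[B] whenever A is a subtype of B.

```scala
// Covariant sketch: `+A` makes MyList[String] a subtype of MyList[Any].
abstract class MyList[+A] {
  def head: A
  def tail: MyList[A]
  def isEmpty: Boolean
  def length: Int
  // B is A or a supertype of A; covariance lets `this` act as a MyList[B], so no cast is needed.
  def ::[B >: A](item: B): MyList[B] = new MyListImpl[B](item, this)
}

class MyListImpl[+A](val head: A, val tail: MyList[A]) extends MyList[A] {
  def isEmpty = false
  def length: Int = 1 + tail.length
  override def toString: String = s"$head $tail"
}

// MyList[Nothing] is a subtype of every MyList[A], so one empty object serves all element types.
object MyListNil extends MyList[Nothing] {
  def head: Nothing = throw new NoSuchElementException("head of empty list")
  def tail: MyList[Nothing] = throw new UnsupportedOperationException("tail of empty list")
  def isEmpty = true
  def length = 0
  override def toString = ""
}

object MyList {
  // Prepend the items from last to first, starting from the shared empty list.
  def apply[A](items: A*): MyList[A] =
    items.reverseIterator.foldLeft(MyListNil: MyList[A])((acc, item) => item :: acc)
}

val strings = MyList("ABC", "XYZ")   // MyList[String]
val mixed   = 3.14159 :: strings     // element type widens to a common supertype
```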

Conclusion

That’s a good stopping point for now. Obviously you can take the MyList class a lot further and add a lot more methods, but we’ve created some code that approximates the basics provided by scala.List. In fact, you could take several of the scala.List methods (foldLeft, for example) and basically drop them right into MyList and they’d work fine.

One of my favorite functional programming tricks is folding. The fold left and fold right functions can do a lot of complicated things with a small amount of code. Today, I’d like to (1) introduce folding, (2) make note of some surprising, nay, shocking fold behavior, (3) review the folding code used in Scala’s List class, and (4) make some presumptuous suggestions on how to improve List.

Update: I’ve created a new post in which I list lots and lots of foldLeft examples in case you’d like to learn more about what folding can accomplish.

Know When to Hold ‘Em, Know When to Fold ‘Em

In case you’re not familiar with folding, I’ll describe it as briefly as I can.

Here’s the signature of the foldLeft function from List[A], a list of items of type A:

def foldLeft[B](z: B)(f: (B, A) => B): B

Firstly, foldLeft is a curried function (so is foldRight). If you don’t know about currying, that’s OK; this function just takes its two parameters (z and f) in two sets of parentheses instead of one. Currying isn’t the important part anyway.

The first parameter, z, is of type B, which is to say it can be different from the list contents type. The second parameter, f, is a function that takes a B and an A (a list item) as parameters, and it returns a value of type B. So the purpose of function f is to take a value of type B, use a list item to modify that value and return it.

The foldLeft function goes through the whole List, from head to tail, and passes each value to f. For the first list item, z is used as the first argument to f. For the second list item, the result of the first call to f becomes that first (B-typed) argument, and so on down the list. The result of the last call to f is what foldLeft returns.

For example, say we had a list of Ints 1, 2, and 3. We could call foldLeft(“X”)((b,a) => b + a). For the first item, 1, the function we define would add string “X” to Int 1, returning string “X1”. For the second list item, 2, the function would add string “X1” to Int 2, returning “X12”. And for the final list item, 3, the function would add “X12” to 3 and return “X123”.
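Here is a minimal sketch of the same walk written by hand, to make the accumulator threading explicit. The name myFoldLeft is hypothetical; the standard library already provides List.foldLeft. The recursive call is the last thing that happens, so the @tailrec annotation lets the compiler confirm it can be turned into a loop.

```scala
import scala.annotation.tailrec

// Hypothetical hand-written foldLeft, for illustration only.
@tailrec
def myFoldLeft[A, B](xs: List[A], z: B)(f: (B, A) => B): B = xs match {
  case Nil          => z
  case head :: tail => myFoldLeft(tail, f(z, head))(f)   // f(z, head) is the new accumulator
}

val viaLibrary = List(1, 2, 3).foldLeft("X")((b, a) => b + a)
val viaSketch  = myFoldLeft(List(1, 2, 3), "X")((b, a) => b + a)
```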

Here are a few more examples.

list.foldLeft(0)((b,a) => b+a)
list.foldLeft(1)((b,a) => b*a)
list.foldLeft(List[Int]())((b,a) => a :: b)

The first line is super simple. It’s almost like the example I described above, but the z value is the number 0 instead of string “X”. This fold combines the elements of the list by addition instead of concatenation. So the fold returns the sum of all Ints in the list. Line 2 combines them through multiplication. Do you see why the z value is 1 in this case?

Line 3 is a little more complex. Can you guess what it does? It starts out with an empty list of Ints and adds each item to the accumulator (We call the b parameter of our function the accumulator because it accumulates data from each of our list items). Because it starts with the head and adds to the beginning of the accumulator list until it gets to the last item of the original list, it returns the original list in reverse order.
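To check the three examples concretely, here they are applied to a sample list, List(1, 2, 3, 4). The list is chosen only for illustration. The last value shows what would happen if the multiplication started from 0 instead of 1.

```scala
val list = List(1, 2, 3, 4)

val sum       = list.foldLeft(0)((b, a) => b + a)              // ((((0 + 1) + 2) + 3) + 4)
val product   = list.foldLeft(1)((b, a) => b * a)              // ((((1 * 1) * 2) * 3) * 4)
val reversed  = list.foldLeft(List[Int]())((b, a) => a :: b)   // each item goes to the front
val zeroStart = list.foldLeft(0)((b, a) => b * a)              // 0 times anything stays 0
```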

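The foldRight function works in much the same way as foldLeft. Can you guess the difference? You got it. It starts at the end of the list and works its way up to the head. Its function parameter also takes its arguments in the opposite order. The signature is foldRight[B](z: B)(f: (A, B) => B): B, so the list item comes first and the accumulator second.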
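A small side-by-side sketch makes the direction visible. It uses the standard List methods on the Ints 1, 2, 3 with a starting value of “X”. Each comment shows how the calls nest.

```scala
val digits = List(1, 2, 3)

// foldLeft:  f(f(f("X", 1), 2), 3); the accumulator is the FIRST argument
val left = digits.foldLeft("X")((b, a) => b + a)

// foldRight: f(1, f(2, f(3, "X"))); the accumulator is the SECOND argument
val right = digits.foldRight("X")((a, b) => a.toString + b)
```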

Folds can be used for MUCH more than I’ve shown here. With folds, you can solve lots of different problems with a standard construct. You should read up on them if you’re just starting out in functional programming.

All That Glitters Is Not Fold

Now for the moment you’ve been waiting for. Fold’s dirty little secret! The below is taken from a scala interpreter session.

scala> var shortList = 1 to 10 toList
shortList: List[Int] = List(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)

scala> var longList = 1 to 325000 toList
longList: List[Int] = List(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 
15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, ...

scala> shortList.foldLeft("")((x,y) => "X")
res1: java.lang.String = X

scala> shortList.foldRight("")((x,y) => "X")
res2: java.lang.String = X

scala> longList.foldLeft("")((x,y) => "X")
res3: java.lang.String = X

scala> longList.foldRight("")((x,y) => "X")
java.lang.StackOverflowError
        at scala.List.foldRight(List.scala:1079)
        at scala.List.foldRight(List.scala:1081)
        at scala.List.foldRight(List.scala:1081)
        at scala.List.foldRight(List.scala:1081)
        at scala.List.foldRight(List.scala:1081)
        at scala.List.foldRight(List.scala:1081)
        at scala.List.foldRight(List.scala:1081)
        at scala.List.foldRight(List.scala:1081)
        at scala.List.foldRig...

We created two lists: shortList with 10 items, and longList with 325,000 items. Then we perform a trivial foldLeft and foldRight on shortList. It’s trivial because the passed-in function always returns the string “X”; it doesn’t even use the list data.

Then we do a foldLeft on longList. This goes off without a hitch. Finally we try to do a foldRight, the same foldRight that succeeded on the shorter list, and it fails! The foldLeft worked. Why didn’t the foldRight work? It’s a perfectly reasonable call against a perfectly reasonable List. Something funny is going on here.

The error message says there was a stack overflow, and the stack trace shows a long list of calls at List.scala:1081. If you’ve read my post about tail-recursion, then you probably suspect that some recursive code is to blame.

Let’s look into List.scala, maybe the single most important Scala source file.

Fool’s Fold

Without further ado, here’s the code for foldLeft and foldRight from List.scala:

override def foldLeft[B](z: B)(f: (B, A) => B): B = {
  var acc = z
  var these = this
  while (!these.isEmpty) {
    acc = f(acc, these.head)
    these = these.tail
  }
  acc
}

override def foldRight[B](z: B)(f: (A, B) => B): B = this match {
  case Nil => z
  case x :: xs => f(x, xs.foldRight(z)(f))
}

Wow! Those two definitions are very different!

The foldLeft function is the one that worked for short and long lists. Can you see why? It isn’t head-recursive. In fact, it isn’t recursive at all. It is implemented as a while loop. On each iteration, the next list item is passed to the function f and the accumulator (called acc) is updated. When there are no more list items, the accumulator is returned. No recursion means no stack overflows.

The foldRight function is implemented in a totally different way. If the list is empty, the z parameter is returned. Otherwise, a recursive call is made on the tail (the whole list minus the first item) of this list, and that result is passed to the function f. Study the foldRight definition. Do you understand how it works? It’s an elegant recursive solution, and the code really is quite pretty, but it’s not tail recursive so it fails for large lists.

Why didn’t Mr Odersky just write foldRight using a while loop, too? Then this problem wouldn’t exist, right? The reason is that Scala’s List is implemented as a singly-linked list. Each list element has access to the next item in the list, but not to the previous item. You can only traverse a list in one direction! This works fine for foldLeft, which goes from head to tail, but foldRight has to start at the end of the list and work its way forward to the head. If foldRight uses recursion, it must recurse all the way to the end and then use the results of those recursive calls as the accumulator passed into function f.

See? The results of the recursive call must be used for further calculation, so the recursive call can’t be the last thing that happens, so it can’t be written as a tail-recursive function. If you don’t know what I’m talking about, read my introduction to tail-recursion.

Out With The Fold, In With The New

So is that it for foldRight? Is it hopeless? I say no!

There is a way to get the same result as foldRight, but using foldLeft. Can you guess what it is? Here’s how:

list.foldRight("X")((a,b) => a + b)
list.reverse.foldLeft("X")((b,a) => a + b)

These two lines are equivalent! They give the same result no matter what’s in list. Since foldRight processes list elements from last to first, that’s the same as processing the reversed list from first to last.
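A quick way to convince yourself is to use an order-sensitive function, so any difference in evaluation order would show up in the result. In this sketch, wrap is a hypothetical helper that nests its arguments in parentheses.

```scala
val words = List("a", "b", "c")

// Hypothetical order-sensitive combining function: the nesting records the call order.
def wrap(a: String, b: String): String = s"($a$b)"

val viaFoldRight   = words.foldRight("X")((a, b) => wrap(a, b))
val viaReverseLeft = words.reverse.foldLeft("X")((b, a) => wrap(a, b))   // note the swapped a, b
```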

Here are three possible implementations of foldRight that could replace the current one.

def foldRight[B](z: B)(f: (A, B) => B): B = 
  reverse.foldLeft(z)((b,a) => f(a,b))

def foldRight[B](z: B)(f: (A, B) => B): B =
  if (length > 50) reverse.foldLeft(z)((b,a) => f(a,b))
  else             originalFoldRight(z)(f)

def foldRight[B](z: B)(f: (A, B) => B): B =
  try {
    originalFoldRight(z)(f)
  } catch {
    case e1: StackOverflowError => reverse.foldLeft(z)((b,a) => f(a,b))
  }

The first one simply replaces the original recursive logic with the equivalent call to reverse and foldLeft. Why wasn’t foldRight implemented this way to begin with? It may be, in part, that the authors thought the extra overhead of reversing the list was unwarranted. To me, it doesn’t seem that bad. The original foldRight and foldLeft functions are O(n), meaning they run in an amount of time roughly proportional to the number of items in the list. If you look at the source for the reverse function, you’ll see it’s also O(n). So running reverse followed by foldLeft is still O(n). The price is that reverse builds a new list of n cells on the heap, where the recursive version uses n stack frames, and the heap is normally far larger than a thread’s stack.

The second implementation is a compromise. It uses the original recursive version of foldRight (referred to as originalFoldRight in the above code) only when the list has 50 or fewer elements. The reverse.foldLeft is used for lists of more than 50 elements. 50 is just an arbitrary number, a guess at a sensible limit on the number of recursive calls to allow. Note that calling length walks the whole list, so this version always pays for one extra traversal.

The third implementation tries the original foldRight logic first and if the call stack overflows then it uses reverse.foldLeft. This solution is, of course, completely ridiculous, but even this would be better than a foldRight which sometimes crashes your program.
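We can’t redefine List.foldRight itself, but the first replacement is easy to try as a standalone function. This is a minimal sketch, and foldRightViaLeft is a hypothetical name. It runs on a list of the same length that overflowed the stack in the session above.

```scala
// Hypothetical standalone version of the first replacement above.
def foldRightViaLeft[A, B](xs: List[A], z: B)(f: (A, B) => B): B =
  xs.reverse.foldLeft(z)((b, a) => f(a, b))   // swap the arguments to match foldRight's f

val longList = (1 to 325000).toList
// No recursion, so the call stack does not grow with the length of the list.
val total = foldRightViaLeft(longList, 0L)((a, acc) => acc + a)
val text  = foldRightViaLeft(List(1, 2, 3), "X")((a, b) => a.toString + b)
```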

That’s All, Folds!

As I pointed out before, the reverse.foldLeft implementation of foldRight is O(n), same as the original recursive version. The original foldRight may work just fine when your Scala application is young and working with small data sets. Over time more customers are added, more products are created, more orders are placed, and then one day, *POOF*, a runtime error! It’s a ticking time-bomb.

As you may well guess, I would like to see the reverse.foldLeft logic used instead of the recursive version. That would prevent the stack overflow errors. But I would settle for just deprecating foldRight. It would be better to eliminate foldRight and force the coder to work around it than to leave it in its current state. In fact, I don’t think any head-recursive functions belong in the List class.

Do any readers have any insight into why foldRight is coded the way it is?

I was playing a game on my iPhone called Scramble the other day. It’s a great game. You are presented with a 4×4 grid of letters, and your job is to find words by chaining together adjacent letters. It bears a passing similarity to Boggle. I was playing online and I noticed that many other players play much better than I do. Well, a software developer doesn’t take a thing like this lying down. I decided it was time to write a Scramble (or Boggle) solver!

And since I’m trying to pick up the Scala language whenever I have an opportunity, I wrote the whole thing in Scala. I hope it will be, for the reader, an interesting study of a complete and useful (though simple) example Scala program.

First, I had to decide on a strategy. Trying every path on the game board would be inefficient. For example, if we try a path XQ then we can stop right there: XQ is not a word and no word starts with the letters XQ. This suggests using some sort of spell-checker-like logic. So let’s first create a dictionary data structure that allows us to look up words and eliminate dead ends. What we need is a tree (specifically a prefix tree, often called a trie). Here’s the data structure I came up with.

import scala.collection.mutable.HashMap

class LetterTree {
    private val nodes: HashMap[Char,LetterTree] = new HashMap[Char,LetterTree]
    var terminal: Boolean = false
    def addWord(word: String): Unit = addWord(word.toList)
    def addWord(word: List[Char]): Unit = word match {
        case Nil          => terminal = true
        case head :: tail => nodes.getOrElseUpdate(head, new LetterTree).addWord(tail)
    }
    def getSubTree(letter: Char): Option[LetterTree] =
        nodes.get(letter)   // HashMap.get already returns Some(subtree) or None
}

It’s a mutable class. I toyed with some ideas for an immutable class, but it just complicated things more than I wanted to cope with. Our dictionary will be a tree in which LetterTree is the node class. You can see that a LetterTree has two data members: a hashmap of child LetterTrees indexed by Char, and a flag called terminal. So a LetterTree is a tree node whose child nodes are indexed by letter, and which can be marked (via the terminal flag) as ending a word.

There are two addWord methods for populating the tree. One takes a list of Chars. The other takes a String and is included for convenience. It converts its String parameter to a list of Chars and calls the other addWord method. If the addWord method is passed an empty list (Nil), then it has reached the end of a word and sets the terminal flag. Otherwise, it takes the first Char in the list and looks up (or else creates) the LetterTree mapped to that Char. The method then adds the remainder of the Char list to that mapped LetterTree.

Finally, there’s the getSubTree method. This will be useful when we start using the tree to look up potential word matches.

So imagine that we create a LetterTree and add the following words: item, its, it. We get a structure like this:

Figure 1

Each of the boxes represents a LetterTree instance. Each arrow, labeled with a letter, represents an entry in the ‘nodes’ HashMap. The top level LetterTree does not itself represent a letter. It’s just a starting point. Its ‘nodes’ member has just one mapping: letter ‘i’ maps to a second LetterTree. That second LetterTree doesn’t have its ‘terminal’ flag set, so we haven’t made a word yet. Its ‘nodes’ map has a single entry mapping ‘t’ to a third LetterTree. This third LetterTree is a terminal node, so we know that the sequence ‘it’ forms a word. The third LetterTree’s ‘nodes’ map has entries for ‘s’ and ‘e’. The LetterTree mapped to ‘s’ is terminal, and the one mapped to ‘e’ is not. The ‘e’ node has a child node ‘m’ that is a terminal.

This is a pretty efficient way to store words. We got to use the sequence ‘it’ 3 times! Also, note that there is exactly one terminal node for each word. If we add a word, we will either append a new leaf node which will be terminal, or we will make one of the existing nodes terminal.
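To check the structure in Figure 1 programmatically, here is a sketch that repeats LetterTree and adds two hypothetical helpers, isWord and isPrefix, which are not part of the original program. Both walk the tree one letter at a time with getSubTree. A word needs the final node to be terminal; a prefix only needs the path to exist.

```scala
import scala.collection.mutable.HashMap

// LetterTree as defined above (getSubTree written with HashMap.get, same behavior).
class LetterTree {
  private val nodes = new HashMap[Char, LetterTree]
  var terminal: Boolean = false
  def addWord(word: String): Unit = addWord(word.toList)
  def addWord(word: List[Char]): Unit = word match {
    case Nil          => terminal = true
    case head :: tail => nodes.getOrElseUpdate(head, new LetterTree).addWord(tail)
  }
  def getSubTree(letter: Char): Option[LetterTree] = nodes.get(letter)
}

// Hypothetical helpers: follow one child per letter, stopping at the first missing one.
def lookup(tree: LetterTree, word: String): Option[LetterTree] =
  word.foldLeft(Option(tree))((node, c) => node.flatMap(_.getSubTree(c)))

def isWord(tree: LetterTree, word: String): Boolean   = lookup(tree, word).exists(_.terminal)
def isPrefix(tree: LetterTree, word: String): Boolean = lookup(tree, word).isDefined

val dict = new LetterTree
List("item", "its", "it").foreach(w => dict.addWord(w))
```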

Now, all we have to do is create a LetterTree and call the addWord method for every word we can think of. That could be annoying. Let’s create an improved LetterTree that will read a list of words from a file.

import java.io.File
import scala.io.Source

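class FileLetterTree(path: String) extends LetterTree {
    val file = new File(path)
    private val source = Source.fromFile(file)
    try {
        // One word per line; skip blank lines so the root node is never marked terminal.
        for (line <- source.getLines() if line.trim.nonEmpty) addWord(line.trim)
    } finally {
        source.close()   // release the file handle once every word has been added
    }
}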

Can you believe how easy that was?! We just extend LetterTree, take a file path as a constructor parameter, use the scala.io.Source class to get all lines, and add each line as a word. Now just find a text file containing all English words. You should be able to google it.

You might want to ensure the file contains only words of 3 letters or longer (as required by the rules of the game), only lower case, and only the 26 English letters. Since Scramble has no ‘Q’ piece (neither does Boggle) but only a ‘Qu’ piece, you might want to do a global replace on your word file, replacing all instances of ‘qu’ with ‘q’.
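Here is one way to do that cleanup in code rather than by hand. This is a sketch under stated assumptions. The function prepareWords is hypothetical. It applies the 3-letter minimum before the ‘qu’ replacement, and it adds a rule of its own: words containing a ‘q’ not followed by ‘u’ are dropped, since the board’s ‘Qu’ tile cannot spell them.

```scala
// Hypothetical cleanup step for a raw word list, applied before the words reach addWord.
def prepareWords(words: Iterator[String]): Iterator[String] =
  words
    .map(_.trim.toLowerCase)
    .filter(w => w.length >= 3 && w.matches("[a-z]+"))   // 3+ letters, a-z only
    .filterNot(w => w.replace("qu", "").contains("q"))   // assumption: a lone 'q' cannot be played
    .map(_.replace("qu", "q"))                            // the single 'q' tile stands for "qu"

val cleaned = prepareWords(Iterator("Queen", "it", "qat", "cat's", "quit")).toList
```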

Ok, now we have a data structure for looking up words.  Now we need to create some code that represents the game board.  We’ll assume a 4-by-4 board.

class GameBoard(lettersStr: String) {
    private val ltrStr = lettersStr.toLowerCase()
    if (!ltrStr.matches("^[a-z]{16}$"))
        throw new Exception("Exactly 16 letters a-z are required.")

    override def toString: String =
        ltrStr.substring(0,4)  + "\n" + ltrStr.substring(4,8) + "\n" +
        ltrStr.substring(8,12) + "\n" + ltrStr.substring(12,16)

    class Letter(val letter: Char) {   // plain class: two cells with the same letter stay distinct
        var neighbors = List[Letter]()
        def addNeighbor(nbr: Letter) = { neighbors = nbr :: neighbors }
        override def toString = letter.toString
    }

    val letters = Array.ofDim[Letter](4, 4)   // 4 rows by 4 columns
    for (idx <- 0 until ltrStr.length)
        letters(idx/4)(idx % 4) = new Letter(ltrStr(idx))

    for ( idx <- 0 to 3; jdx <- 0 to 3; iOff <- -1 to 1; jOff <- -1 to 1;
          if (iOff != 0 || jOff != 0) &&
          idx + iOff >= 0 && idx + iOff < 4 &&
          jdx + jOff >= 0 && jdx + jOff < 4 )
        letters(idx)(jdx).addNeighbor(letters(idx + iOff)(jdx + jOff))
}

That’s a lot of new code, but it is actually pretty simple if we break it down:

The first few lines of the class body just ensure that we have good input. The code converts the constructor parameter to lower case and then confirms that it consists of exactly 16 letters from a to z, throwing an exception otherwise.

The next section is a toString method.  It just returns the 16-letter string in 4-letter chunks separated by newlines.

Next, there is an inner class called Letter.  It encapsulates one letter on the game board, and includes a way to keep track of neighboring Letters (think of it as a graph node).

The ‘letters’ value is the game board, a 4-by-4 array of Letters. The for-loop that follows initializes each of the 16 Letter objects using the corresponding letters from the 16-letter string: position idx in the string goes to row idx/4, column idx % 4.

All that’s left is to tell each Letter who its neighbors are. This is done with the somewhat complex for-loop at the end of the class. This structure loops through two indices, idx and jdx, as well as two offsets, iOff and jOff. It’s like four nested loops combined into one. It loops through idx and jdx values from 0 to 3, so it runs for each of the 16 Letters in the grid. It also loops through iOff and jOff from -1 to 1, so it looks at each neighbor of each Letter. See? For each of the 16 Letters in the grid, it looks at each of the 9 positions in the 3-by-3 block centered on that Letter (including the Letter itself).

The guard at the end of the for-loop header (the ‘if’ condition) defines which combinations of idx, jdx, iOff, and jOff are valid and should be processed in the loop body. You see why, right? First of all, there’s no need for a Letter to add itself to its list of neighbors. It’s against the game rules to use the same grid position twice in the same word, so the first condition, (iOff != 0 || jOff != 0), eliminates any iteration in which iOff and jOff are both 0.

Also, there are grid positions at the edges and corners. Those positions have fewer neighbors. So for the Letter at position idx=0, jdx=0, an offset of iOff=-1 or jOff=-1 makes no sense: there are no neighboring Letters in the -1 direction at that position. The remaining two conditions, which keep idx + iOff and jdx + jOff between 0 and 3, enforce these restrictions.

Finally, if the indices and offsets pass the tests, the Letter at location (idx, jdx) is assigned a new neighbor, the Letter at (idx+iOff, jdx+jOff).
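To see how the guard plays out at corners, edges, and the interior, here is a standalone sketch of the same neighbor rule. The function neighborCoords is hypothetical. It returns grid coordinates instead of Letter objects so the counts are easy to check: corners have 3 neighbors, edge cells 5, interior cells 8.

```scala
// Hypothetical standalone version of the neighbor rule, returning (row, column) pairs.
def neighborCoords(idx: Int, jdx: Int): Seq[(Int, Int)] =
  for {
    iOff <- -1 to 1
    jOff <- -1 to 1
    if iOff != 0 || jOff != 0                // a cell is not its own neighbor
    if idx + iOff >= 0 && idx + iOff < 4     // stay inside the rows
    if jdx + jOff >= 0 && jdx + jOff < 4     // stay inside the columns
  } yield (idx + iOff, jdx + jOff)

val totalNeighbors = (for (i <- 0 to 3; j <- 0 to 3) yield neighborCoords(i, j).size).sum
```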

Of course, this isn’t the only way to write a program like this.  We could have used a simple array of characters and put the neighbor-finding logic in the solver code, but I chose to divide up the responsibilities like this.

We’ve finished setting up the game board.  We have a dictionary of legal words.  Now we’re ready to play.  Let’s add a function called findWords to the GameBoard class.

def findWords(tree: LetterTree): List[String] = {
    def findWords(tree: LetterTree, letter: Letter, sofar: List[Letter]): List[String] = {
        tree.getSubTree(letter.letter) match {
          case Some(subTree) =>
            var words: List[String] = Nil
            if (subTree.terminal) words = (letter :: sofar).foldLeft("")((c,n) => n+c) :: words
            for (nextLetter <- letter.neighbors if !sofar.contains(nextLetter))
                words = findWords(subTree, nextLetter, letter :: sofar) ::: words
            words
          case None => Nil
        }
    }
    var words: List[String] = Nil
    for (idx <- 0 to 3; jdx <- 0 to 3)
        words = words ++ findWords(tree, letters(idx)(jdx), Nil)
    words
}

This is the heart of our little program. This is the code that does the real work. The first thing that happens in this function is that we define an inner function.  We’ll look at that in a second.  First, look at the code below the inner function.  We create an empty List of Strings, then for each Letter in the 4-by-4 grid we call the inner findWords function.  We pass in the dictionary tree (at the top tree level), the current Letter, and an empty list.  That’s the list of letters used so far.  At this point, no letters have been used yet, so it’s empty (Nil).

The result of the inner function call is a list of all valid words that start with the given Letter.  That list is added to the list of all valid words.  We do this for each of the 16 Letters.

Now, for that inner function.  Let’s first examine the parameter list. First is tree of type LetterTree. This is the dictionary class we built earlier. On the first call, we pass in the whole dictionary, the top level tree. As we search for words, though, we will pass sub-trees, sub-sub-trees and so forth. The second parameter is a letter of type Letter. That’s just the current letter that we’re evaluating.

The last parameter is called sofar and it has type List[Letter]. Why is it called sofar? Because it’s the ordered list of connected Letters from the game board that we know begin words in the dictionary. The sofar list might contain the letters (c, o, m, p) because this is a prefix to words like compute, computer, computing, comparison, etc. We would never be passed a sofar list containing (c, o, m, p, x) because although this sequence could show up on the game board, this is not a prefix to any word in the dictionary. This parameter will grow longer with each recursive call to the inner function. This is why on the first call to the inner function we pass the empty list, Nil, as the sofar parameter.

We first look up the sub-tree beginning with the letter value of the Letter parameter.  If we find that there is no sub-tree for this letter (the result of getSubTree matches None) then we know that there are no words beginning with the letter. It’s a dead end, so we return an empty list and we do not recurse any further.

If we do find a sub-tree, though, that means that there are words that begin with the sofar list followed by this letter. In the line that tests subTree.terminal, we check to see if the node for the current letter is a terminal node in our dictionary. If it is, then we have a legal word. We put the current letter on the front of the sofar list, which keeps the letters most recent first. The foldLeft then turns that reversed list back into a correctly ordered string, and the string goes into the local words list.

In the next line, we loop through all of the current Letter’s neighbors, excluding any that are already in the sofar list. For each unused neighbor, we make a recursive call. This time, the parameters are the dictionary sub-tree, the neighbor letter generated by the for loop, and the sofar list with the current Letter added to its front. The list of words returned by each recursive call is combined with the local word list. After all the neighboring letters have been checked, the word list is returned from the function.
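The string-building step described above is easy to misread, so here it is in isolation. This minimal sketch uses plain Chars in place of Letter objects, and the object name FoldSketch is hypothetical. Because sofar is kept most recent first, foldLeft meets the newest letter first. Prepending each character to the accumulator puts the word back in reading order.

```scala
object FoldSketch {
  // The path c -> o -> m -> p, stored the way sofar stores it: newest letter first.
  val newestFirst: List[Char] = List('p', 'm', 'o', 'c')

  // Same shape as (letter :: sofar).foldLeft("")((c, n) => n + c):
  // each character is put in front of the string built so far.
  val viaFold: String = newestFirst.foldLeft("")((acc, ch) => ch.toString + acc)

  // An equivalent, arguably more direct, spelling.
  val viaReverse: String = newestFirst.reverse.mkString
}
```

Both values are the string "comp".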

This is basically a graph problem. The game board is an undirected graph. The dictionary is a tree, which is a type of graph. We are taking circuit-free paths along the game board graph and mapping them to paths on the tree, accumulating a list of matching paths for which the end is marked terminal on the tree. There are Java (and maybe Scala) libraries for dealing with graphs, and when I get a chance I’d like to see if I can get a more tidy implementation using one of these libraries.
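Here is a compact sketch of that idea in modern Scala (2.13 or 3). It retells the approach and is not the program above. Trie, PathSearchSketch, and walk are hypothetical names, the board is a vector of strings of any size, and visited squares are tracked as (row, column) pairs. Tracking positions rather than letters keeps two squares that hold the same letter distinct, so a path like t-e-s-t can use a second t from a different square.

```scala
import scala.collection.mutable

object PathSearchSketch {
  // Minimal dictionary trie: children keyed by character plus an end-of-word flag.
  final class Trie {
    val children: mutable.Map[Char, Trie] = mutable.Map.empty
    var terminal: Boolean = false
    def add(word: String): Unit = {
      var node: Trie = this
      for (ch <- word) node = node.children.getOrElseUpdate(ch, new Trie)
      node.terminal = true
    }
  }

  def findWords(grid: Vector[String], dict: Trie): Set[String] = {
    // In-bounds squares touching (r, c), excluding (r, c) itself.
    def neighbors(r: Int, c: Int): Seq[(Int, Int)] =
      for {
        dr <- -1 to 1
        dc <- -1 to 1
        if dr != 0 || dc != 0
        if r + dr >= 0 && r + dr < grid.length
        if c + dc >= 0 && c + dc < grid(r + dr).length
      } yield (r + dr, c + dc)

    // Depth-first walk: `path` holds the squares used so far, newest first.
    def walk(node: Trie, pos: (Int, Int), path: List[(Int, Int)]): Set[String] = {
      val (r, c) = pos
      node.children.get(grid(r)(c)) match {
        case None => Set.empty // no dictionary word continues this way
        case Some(sub) =>
          val here = pos :: path
          val word = here.reverse.map { case (rr, cc) => grid(rr)(cc) }.mkString
          val found = if (sub.terminal) Set(word) else Set.empty[String]
          neighbors(r, c)
            .filterNot(p => here.contains(p)) // never reuse a square
            .foldLeft(found)((acc, next) => acc ++ walk(sub, next, here))
      }
    }

    (for {
      r <- grid.indices
      c <- grid(r).indices
    } yield walk(dict, (r, c), Nil)).flatten.toSet
  }
}
```

On the two-by-two board "te"/"st" with a dictionary of test, set, and tee, this finds test and set. It does not find tee because the board has only one e.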

Now, we’ll just pull all this together in a neat little package:

class PuzzleSolver(dictionaryPath: String) {
    val tree = new FileLetterTree(dictionaryPath)
    def solve(letters: String) = {
        val board = new GameBoard(letters)
        val wordSet = new HashSet[String]() ++ board.findWords(tree)
        val sortedWords = wordSet.toList.sort{ (a,b) =>
            a.length > b.length || (a.length == b.length && a > b)
        }
        println(sortedWords)
    }
}

Now we can create an instance of the PuzzleSolver class for a dictionary file that we specify. Then we can call the solve function for game board configurations. This class finds all the legal words contained in the game board and drops duplicates by collecting them in a HashSet. It then sorts them longest first and, among words of equal length, in reverse alphabetical order (that is what the a > b test in the comparator does), and prints them out.
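The List.sort method used in solve belongs to the Scala version this post was written against. In Scala 2.8 it was deprecated in favor of sortWith, which takes the same kind of "comes before" function. Here is a minimal sketch of the same ordering for newer versions. SortSketch and its method names are hypothetical. The sketch uses compareTo, which for Strings gives the same answer as the a > b test.

```scala
object SortSketch {
  // Longer words first; among words of equal length, reverse alphabetical
  // order, which is what the original comparator's a > b test produces.
  def before(a: String, b: String): Boolean =
    a.length > b.length || (a.length == b.length && a.compareTo(b) > 0)

  def sortWords(words: Iterable[String]): List[String] =
    words.toList.sortWith(before)
}
```

For example, sortWords(Set("set", "test", "sets", "tee")) gives List(test, sets, tee, set). To list equal-length words alphabetically instead, flip the comparison to a.compareTo(b) < 0.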

Here’s a sample session in the Scala interpreter:

scala> val solver = new PuzzleSolver("./words.txt")
solver: PuzzleSolver = PuzzleSolver@1876e5d

scala> solver.solve("TestingTheSolver")
List(storeen, torsel, tinsel, tensor, seeing, nestor, inseer, 
ingest, verst, verso, torse, tinge, store, soree, soget, snite, 
sneer, rotse, rotge, rosel, reest, orsel, inset, insee, ingot, 
hinge, gorse, geest, vlei, vest, vent, vein, veer, veen, tore, 
togs, ting, tine, tien, teng, stog, sore, snee, sero, sent, 
seit, sego, seer, seen, seel, rose, rest, rees, reen, reel, 
ogee, neti, nest, neer, lest, lent, lens, lehi, lees, leer, 
iten, hint, hing, hest, hent, hein, heer, gore, goes, goer, 
gest, gent, gens, gein, eros, vei, vee, tor, tog, toe, tin, 
tie, ten, teg, sot, sog, soe, set, ser, sen, seg, see, rot, 
rog, roe, rev, ree, ose, ore, oes, oer, nit, net, nei, nee, 
lev, les, len, lei, leg, lee, ing, hit, hin, hie, hen, hei, 
got, gos, gor, get, ges, gen, gel, gee, ers, ens, ego, eer, 
eel)

scala>

It works! Here’s the complete source code:

import java.io.File  
import scala.io.Source  
import scala.collection.mutable.HashMap  
import scala.collection.immutable.HashSet  
  
class LetterTree {  
    private val nodes: HashMap[Char,LetterTree] = new HashMap[Char,LetterTree]  
    var terminal: Boolean = false  
    def addWord(word: String): Unit = addWord(word.toList)  
    def addWord(word: List[Char]): Unit = word match {  
        case Nil          => terminal = true  
        case head :: tail => nodes.getOrElseUpdate(head, new LetterTree).addWord(tail)  
    }  
    def getSubTree(letter: Char): Option[LetterTree] =  
        if (nodes.contains(letter)) Some(nodes(letter)) else None  
}  

class FileLetterTree(path: String) extends LetterTree {  
    val file = new File(path)  
    for (line <- Source.fromFile(file).getLines) addWord(line.trim)  
}  


class GameBoard(lettersStr: String) {  
    private val ltrStr = lettersStr.toLowerCase()  
    if (!ltrStr.matches("^[a-z]{16}$"))
        throw new Exception("Exactly 16 letters a-z are required.")
  
    override def toString: String =  
        ltrStr.substring(0,4)  + "\n" + ltrStr.substring(4,8) + "\n" +  
        ltrStr.substring(8,12) + "\n" + ltrStr.substring(12,16)  
  
    case class Letter(letter: Char) {  
        var neighbors = List[Letter]()  
        def addNeighbor(nbr: Letter) = { neighbors = nbr :: neighbors }  
        override def toString = letter.toString  
    }  
  
    val letters = new Array[Array[Letter]](4,4)  
    for (idx <- 0 until ltrStr.length)  
        letters(idx/4)(idx % 4) = Letter(ltrStr(idx))  
  
    for ( idx <- 0 to 3; jdx <- 0 to 3; iOff <- -1 to 1; jOff <- -1 to 1;  
          if (iOff != 0 || jOff != 0) &&  
          idx + iOff >= 0 && idx + iOff < 4 &&  
          jdx + jOff >= 0 && jdx + jOff < 4 )  
        letters(idx)(jdx).addNeighbor(letters(idx + iOff)(jdx + jOff))  

  def findWords(tree: LetterTree): List[String] = {  
      def findWords(tree: LetterTree, letter: Letter, sofar: List[Letter]): List[String] = {  
          tree.getSubTree(letter.letter) match {  
            case Some(subTree) =>  
              var words: List[String] = Nil  
              if (subTree.terminal) words = (letter :: sofar).foldLeft("")((c,n) => n+c) :: words  
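              // eq compares identity: Letter is a case class, so == would also match
              // a different square that happens to hold the same letter
              for (nextLetter <- letter.neighbors if !sofar.exists(_ eq nextLetter))
                  words = findWords(subTree, nextLetter, letter :: sofar) ::: words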
              words  
            case None => Nil  
          }  
      }  
      var words: List[String] = Nil  
      for (idx <- 0 to 3; jdx <- 0 to 3)  
          words = words ++ findWords(tree, letters(idx)(jdx), Nil)  
      words  
  }  
}  

class PuzzleSolver(dictionaryPath: String) {  
    val tree = new FileLetterTree(dictionaryPath)  
    def solve(letters: String) = {  
        val board = new GameBoard(letters)  
        val wordSet = new HashSet[String]() ++ board.findWords(tree)  
        val sortedWords = wordSet.toList.sort{ (a,b) =>  
            a.length > b.length || (a.length == b.length && a > b)  
        }  
        println(sortedWords)  
    }  
}

One last thing: To be clear, no I don’t actually use this to cheat at online games. Just knowing that I could is satisfying enough for me.

In my last post (Tail-Recursion Basics In Scala), I went on and on about how scala, as a functional programming language, was written to accommodate recursion. When you need to do a task over and over, the procedural way is to iterate, and the functional way is to recurse. That’s what I thought. But I was browsing the source for the scala List class and I noticed something weird. It’s chock full of iteration and procedural code!

Well, I don’t have to tell you I was a little disappointed. I never claimed to have unreservedly swallowed the FP hype, but I figured the authors of scala.List must surely practice what they preach. I’m puzzled. To show you what I mean, here are a few methods written in a procedural style just as they appear in scala.List, alongside my attempt to write them in a functional, tail-recursive fashion. This isn’t an exhaustive list by any means. There are tons of methods in class scala.List like these.

Mark the following! I am not trying to correct the authors of scala.List or point their code out as inferior. On the contrary. They know what they are doing. I do not. They implemented a general-use functional/OO programming language. I have not. I am merely wondering aloud why scala.List was implemented the way it was.

I’m using http://lampsvn.epfl.ch/trac/scala/browser/scala/tags/R_2_7_2_RC2/src/library/scala/List.scala as my source.

List.length

I used a length function as an example in my last post on recursion. The length method of List returns a count of the items in the List. First, here’s the actual code for the length method:

sealed abstract class List[+A] extends Seq[A] with Product {
...
  def length: Int = {
    var these = this
    var len = 0
    while (!these.isEmpty) {
      len += 1
      these = these.tail
    }
    len
  }
...
}

See why I called this code procedural? Those vars and that while loop. Vars are used for variables whose values change. There’s nothing wrong with it, but it’s a feature of procedural code. Iterative loops are also more procedural than functional. You could argue that the while is not a procedural loop, but a functional closure. But I know iteration when I see it. Why wasn’t it written using recursion? Here’s how it could be implemented in a more functional style.

object List {
...
  def length(list: List[_], len: Int): Int = {
    if (list.isEmpty) len
    else length(list.tail, len + 1)
  }
...
}

sealed abstract class List[+A] extends Seq[A] with Product {
...
  def length: Int = List.length(this, 0)
...
}

There is a List class as well as a List object. I put the recursive function in the List helper object because I couldn’t get the tail recursion optimization when I implemented it as a recursive member function within class List.
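Another option is to keep the recursion inside the member method as a local helper. The sketch below shows this with a small hypothetical list type; SimpleList, SCons, and SNil are made-up names, not scala.List. The compiler only optimizes a self-recursive tail call when the method cannot be overridden, and a local function never can be. So this keeps the one-method interface and still runs in constant stack space. The @tailrec annotation is available in Scala 2.8 and later, and it makes compilation fail if the optimization cannot be applied.

```scala
import scala.annotation.tailrec

// Hypothetical minimal immutable list, used only to illustrate the technique.
sealed abstract class SimpleList[+A] {
  def isEmpty: Boolean
  def head: A
  def tail: SimpleList[A]

  def length: Int = {
    // A local function cannot be overridden, so this self-call can become a loop.
    @tailrec
    def loop(rest: SimpleList[_], acc: Int): Int =
      if (rest.isEmpty) acc else loop(rest.tail, acc + 1)
    loop(this, 0)
  }
}

final case class SCons[+A](head: A, tail: SimpleList[A]) extends SimpleList[A] {
  def isEmpty: Boolean = false
}

case object SNil extends SimpleList[Nothing] {
  def isEmpty: Boolean = true
  def head: Nothing = throw new NoSuchElementException("SNil.head")
  def tail: SimpleList[Nothing] = throw new NoSuchElementException("SNil.tail")
}
```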

List.indices

The indices method returns a List of the zero-based indices of a List. If a list has 4 items, its indices method returns List(0, 1, 2, 3). Once again, we’ll first look at the actual code for the indices method:

sealed abstract class List[+A] extends Seq[A] with Product {
...
  def indices: List[Int] = {
    val b = new ListBuffer[Int]
    var i = 0
    var these = this
    while (!these.isEmpty) {
      b += i
      i += 1
      these = these.tail
    }
    b.toList
  }
...
}

As before, this code steps through the list in a while loop. Here’s how it could be done using recursion:

object List {
...
  def indices(list: List[_], idxList: List[Int], curIdx: Int): List[Int] = {
    // idxList is built back to front with ::, so reverse it once at the end
    if (list.isEmpty) idxList.reverse
    else List.indices(list.tail, curIdx :: idxList, curIdx + 1)
  }
...
}

sealed abstract class List[+A] extends Seq[A] with Product {
...
  def indices: List[Int] = List.indices(this, Nil, 0)
...
}

List.last

This is a very simple method. It just returns the last item in the List. Again, here’s the actual code for the List.last method:

sealed abstract class List[+A] extends Seq[A] with Product {
...
  override def last: A =
    if (isEmpty) throw new Predef.NoSuchElementException("Nil.last")
    else {
      var cur = this
      var next = this.tail
      while (!next.isEmpty) {
        cur = next
        next = next.tail
      }
      cur.head
    }
...
}

Here’s what it might look like in a functional style.

object List {
...
  def last[A](list: List[A]): A =
    if (list.isEmpty) throw new Predef.NoSuchElementException("Nil.last")
    else {
      if (list.tail.isEmpty) list.head
      else last(list.tail)
    }
...
}

sealed abstract class List[+A] extends Seq[A] with Product {
...
  override def last: A = List.last(this)
...
}

List.foreach

Ok, last one. This method runs a function once for each item in the List, passing the item as a parameter each time. Here’s the real code:

sealed abstract class List[+A] extends Seq[A] with Product {
...
  final override def foreach(f: A => Unit) {
    var these = this
    while (!these.isEmpty) {
      f(these.head)
      these = these.tail
    }
  }
...
}

And, again, here’s the code rendered in a more functional style.

object List {
...
  def foreach[A](list: List[A], f: A => Unit) {
    if (!list.isEmpty) {
      f(list.head)
      foreach(list.tail, f)
    }
  }
...
}

sealed abstract class List[+A] extends Seq[A] with Product {
...
  final override def foreach(f: A => Unit) = List.foreach(this, f)
...
}

So What?

So what? What does it matter that while loops are used instead of tail recursion? Well, it doesn’t particularly matter. Not to me, anyway. But it’s surprising that proponents, nay, authors of a functional language would write (perfectly good) procedural code like this.

I don’t know why. Maybe it’s for the benefit of the Java coders who scala’s creators hope will become early adopters. Maybe the functions in the List object are frowned upon for some design reason. More likely it’s some technical issue that my inexperienced eye can’t discern. Is there some reason my examples would not work? If you know, please share.

Update: By the way, Debasish Ghosh wrote a related article today. It includes a comparison of the List implementation in Erlang and Scala.

Recursion 101

We all know what recursion is, right?  A function calls itself.  Or function A calls function B which calls function A.  Or A calls B, which calls C, which calls A, etc.  By far the most common situation is that in which a function calls itself.

You don’t see a lot of recursive code in Java.  There are a couple of reasons for this.  First, recursion is hard.  It’s not intuitive.  With iteration (the main alternative to recursion) you see the big picture.  You can look at the whole loop and it’s easy to understand even for the beginner.  With recursion, you see one layer and you have to imagine what happens when those layers stack up.  Like anything else, recursion becomes easier with practice, but compared to iteration, recursion is hard to learn.

Second, Java is not designed to accommodate recursion.  It’s designed to accommodate iteration.  Java gives you for loops, for-each loops, while loops, do loops, arrays, Iterators, ResultSets, etc.  These constructs are all about iteration.  Plus, recursion in Java has an Achilles’ heel: the call stack.

Generally, when you call a function in any language a new level is added to the call stack.  The call stack, we all know, is what keeps track of local variables, what function has called what, etc.  It’s no different in Java.  But the stack has a finite and limited size.  Recursion is fine if you know for certain that you’ll never go more than a few dozen levels deep.  But if the recursion goes too deep, you run out of stack space and your program goes kaput.  That doesn’t happen with iteration, so it’s safer to just avoid recursion altogether.

Recursion In Scala

Scala, however, being a functional language, is very much geared toward recursion rather than iteration.  So how does it overcome the limitations of the call stack?  Let’s look at an example recursive function in scala.

def listLength1(list: List[_]): Int = {
  if (list == Nil) 0
  else 1 + listLength1(list.tail)
}

var list1 = List(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
var list2 = List(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
1 to 15 foreach( x => list2 = list2 ++ list2 )

println( listLength1( list1 ) )
println( listLength1( list2 ) )

Function listLength1 recursively counts the number of items in a list.  Try running this in the interpreter.  It works fine for list1, the short list.  But list2 has been doubled 15 times (10 × 2^15 = 327,680 items), and needing one stack frame per item exhausts the stack with a java.lang.StackOverflowError.  Recursion is a functional language’s bread and butter, but we see here that even scala, a functional language, is subject to call stack limitations.

Don’t give up on recursion yet, though.  Scala has a very important optimization that will allow you to recurse without limit provided you use the right kind of recursion.

Head Recursion And Tail Recursion

There are two basic kinds of recursion: head recursion and tail recursion.  In head recursion, a function makes its recursive call and then performs some more calculations, maybe using the result of the recursive call, for example.  In a tail recursive function, all calculations happen first and the recursive call is the last thing that happens.

The importance of this distinction doesn’t jump out at you, but it’s extremely important!  Imagine a tail recursive function.  It runs.  It completes all its computation.  As its very last action, it is ready to make its recursive call.  What, at this point, is the use of the stack frame?  None at all.  We don’t need our local variables anymore because we’re done with all computations.  We don’t need to know which function we’re in because we’re just going to re-enter the very same function.  Scala, in the case of tail recursion, can eliminate the creation of a new stack frame and just re-use the current stack frame.  The stack never gets any deeper, no matter how many times the recursive call is made.  That’s the voodoo that makes tail recursion special in scala.

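Incidentally, the Scala compiler gets this effect by converting tail recursion into iteration: it rewrites a self-recursive tail call as a jump back to the top of the method, so the recursion runs as a loop.  The JVM itself does not eliminate tail calls, so the optimization only applies when a function calls itself directly, and only when the method cannot be overridden (a local function, a private or final method, or a method of an object).  Calls that bounce between functions (A calls B, which calls A) still use a stack frame each.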

This won’t work with head recursion.  Do you see why?  Imagine a head recursive function.  First it does some work, then it makes its recursive call, then it does a little more work.  We can’t just re-use the current stack frame when we make that recursive call.  We’re going to NEED that stack frame info after the recursive call completes.  It has our local variables, including the result (if any) returned by the recursive call.

Here’s a question for you.  Is the example function listLength1 head recursive or tail recursive?  Well, what does it do?  (A) It checks whether its parameter is Nil.  (B) If so, it returns 0 since Nil has 0 length.  (C) If not, it returns 1 plus the result of a recursive call.  The recursive call is the last thing we typed before ending the function.  That’s tail recursion, right?  Wrong.  The recursive call is made, and THEN 1 is added to the result, and this sum is returned.  This is actually head recursion (or middle recursion, if you like) because the recursive call is not the very last thing that happens.

Tail Recursion Example

When you write a recursive function in scala, your aim is to encourage the compiler to make tail recursion optimizations.  Now let’s rewrite that function using tail recursion.

def listLength2(list: List[_]): Int = {
  def listLength2Helper(list: List[_], len: Int): Int = {
    if (list == Nil) len
    else listLength2Helper(list.tail, len + 1)
  }
  listLength2Helper(list, 0)
}

println( listLength2( list1 ) )
println( listLength2( list2 ) )

I wrote this as two functions (listLength2 and an internal helper function) to preserve the one-parameter interface used in the earlier example.  It would be great if we could specify a default value for a parameter.  We could then write this as one function, but I don’t know a way to do it.  Long story short: listLength2 just calls listLength2Helper which does the real work and is the recursive function.
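Scala 2.8 and later do support default parameter values, so on a newer compiler the helper can be folded into a single function. Here is a minimal sketch; the names DefaultArgSketch and listLength3 are hypothetical.

```scala
object DefaultArgSketch {
  // len defaults to 0, so callers can simply write listLength3(xs).
  // As a member of an object the method cannot be overridden, which lets the
  // compiler turn the self-call into a loop.
  def listLength3(list: List[_], len: Int = 0): Int =
    if (list == Nil) len
    else listLength3(list.tail, len + 1)
}
```

The trade-off is that len becomes part of the public signature, so a caller could pass a starting count other than 0. The nested helper in listLength2 keeps that detail private.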

Is listLength2Helper head recursive or tail recursive?  When the recursive call is made are we really finished with the stack frame, allowing scala to optimize for tail recursion?  Just like in listLength1, this function first checks for a Nil list, but this time returns the len parameter rather than 0.  If list is non-Nil then we make the recursive call.  But there’s still an addition going on here, len + 1.  Does that ruin the tail recursion?  No.  That term is evaluated first.  Only after all the parameters have been evaluated does the recursive call happen.  This function does indeed qualify for tail recursion optimization.
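One way to picture what the optimization buys: once it is applied, listLength2Helper behaves roughly like the hand-written loop below. Each "recursive call" just overwrites the parameters and starts over. This is an illustrative equivalent, not actual compiler output, and LoopSketch and listLength2Loop are hypothetical names.

```scala
object LoopSketch {
  // Behaves like the optimized listLength2Helper(list, 0): the two parameters
  // become local variables, and each "recursive call" is one more trip around the loop.
  def listLength2Loop(start: List[_]): Int = {
    var list: List[Any] = start
    var len = 0
    while (list != Nil) {
      len = len + 1    // the len + 1 argument
      list = list.tail // the list.tail argument
    }
    len                // the base case: return len
  }
}
```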

By The Way…

The two examples in this post are very elementary and it’s easy to predict visually that tail recursion optimization will be applied.  But it’s not always so simple.  If you’re trying to write a complex recursive function and you require that it be optimized as tail recursive then you have a problem.  How do you verify that you have succeeded?  The only ways that I know of are (A) examine the resulting object code or (B) write a unit test to verify that the optimization was made.  The problem with this is that you open yourself up to bugs in your unit tests.  You have to be able to write a test that you know will fail if the function is not optimized.  I’m new to scala, so maybe there’s a better way to do this.  If there is, please educate me.

What I would like to see is an annotation that marks a function as requiring tail recursion optimization.  Such an annotation would trigger a compiler error if the compiler were unable to make the optimization.  The annotation I propose would work like this:

@TailRecursive
def listLength1(list: List[_]): Int = {
  if (list == Nil) 0
  else 1 + listLength1(list.tail)
}

def listLength2(list: List[_]): Int = {
  @TailRecursive
  def listLength2Helper(list: List[_], len: Int): Int = {
    if (list == Nil) len
    else listLength2Helper(list.tail, len + 1)
  }
  listLength2Helper(list, 0)
}

The compiler would succeed for listLength2Helper, but would report an error for listLength1, because its recursive call is followed by an addition and so cannot be optimized.  This doesn’t let the developer off the hook on unit testing, nor does it relieve the developer from having to code carefully.  What it does is provide early, infallible verification that a crucial feature was implemented.  Why force the developer to examine the classfile or write a bug-vulnerable unit test when the compiler knows the answer and could just cough up the information at compile time?
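Scala 2.8 and later ship an annotation that does this job: scala.annotation.tailrec. The sketch below applies it to the helper from listLength2; the object name TailrecSketch is hypothetical. Putting the same annotation on listLength1 makes compilation fail, because the addition happens after the recursive call returns. For that reason the listLength1 version appears only in a comment.

```scala
import scala.annotation.tailrec

object TailrecSketch {
  def listLength2(list: List[_]): Int = {
    // Compilation fails if this helper's self-call cannot be turned into a loop.
    @tailrec
    def listLength2Helper(list: List[_], len: Int): Int =
      if (list == Nil) len
      else listLength2Helper(list.tail, len + 1)
    listLength2Helper(list, 0)
  }

  // Rejected by the compiler if uncommented: the recursive call is not in tail position.
  // @tailrec
  // def listLength1(list: List[_]): Int =
  //   if (list == Nil) 0
  //   else 1 + listLength1(list.tail)
}
```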

As I’ve marveled before, the Scala type system is extremely expressive.  In this post, I want to look at one particular aspect of the type system.  I want to look at a few of the ways to pass behaviors into functions.  Using classes, traits, function types, and a construct that looks like an ad hoc interface (and probably other ways I haven’t learned yet) you can exercise precise control over what units of behavior your functions will accept and use.

First, let’s define a zoo of classes that we will experiment on:

trait TestTrait {
  def func(s1: String): Unit
}

class TestClass1 {
  def func(s1: String): Unit = println("TestClass1: "+s1)
}

class TestClass2 {
  def func(s1: String): Int = { println("TestClass2: "+s1); 999; }
}

class TestClass3 extends TestTrait {
  def func(s1: String) = println("TestClass3: "+s1)
}

class TestClass4 {
  def func(a1: Any) = println("TestClass4: "+a1)
}

class TestClass5 {
  def otherFunc(s1: String) = println("TestClass5: "+s1)
}

class TestClass6 {
  def wrongFunc(list: List[Int]) = println("TestClass6: "+list)
}

There.  These are very simple, but I’ll just point out the salient points of each:

  • TestTrait – Basically, an interface with one method.
  • TestClass1 – One method, func, takes a String and returns Unit (void, to Java coders).
  • TestClass2 – Same as TestClass1, but returns Int.
  • TestClass3 – Same as TestClass1, but also extends TestTrait.
  • TestClass4 – Same as TestClass1, but takes an Any parameter.
  • TestClass5 – Same as TestClass1, but its function has a different name.
  • TestClass6 – One method, different name, different parameter type, does not extend TestTrait.

Each of these (except TestTrait) defines some behavior that we might want to pass into a function.  That behavior is trivial in these classes, but I’m trying to keep the examples short.  Let’s now define several functions that might be able to use some of this behavior.

def test1(fn : (String) => Unit) = { fn("test1"); }

def test2(fn : (String) => Int)  = { fn("test2"); }

def test3(obj: { def func(x: String): Unit }) = { obj.func("test3") }

def test4(obj: TestClass1) = { obj.func("test4") }

def test5(obj: TestTrait)  = { obj.func("test5") }

def test6(obj: { def func(x: Any): Unit }) = { obj.func("test6") }

Now we just need to instantiate each class and start passing parameters!  Create an instance of each class, TestClass1 to TestClass6, and name them tc1 to tc6.  Then pass each instance’s function into test1 and test2, and pass the instance itself into test3 through test6.  Let’s do this in the Scala interpreter, rather than compiling.  Like so:

val tc1 = new TestClass1
val tc2 = new TestClass2
val tc3 = new TestClass3
val tc4 = new TestClass4
val tc5 = new TestClass5
val tc6 = new TestClass6

test1(tc1.func)
test1(tc2.func)
test1(tc3.func)
test1(tc4.func)
test1(tc5.otherFunc)
test1(tc6.wrongFunc)

test2(tc1.func)
test2(tc2.func)
test2(tc3.func)
test2(tc4.func)
test2(tc5.otherFunc)
test2(tc6.wrongFunc)

test3(tc1)
test3(tc2)
test3(tc3)
test3(tc4)
test3(tc5)
test3(tc6)

test4(tc1)
test4(tc2)
test4(tc3)
test4(tc4)
test4(tc5)
test4(tc6)

test5(tc1)
test5(tc2)
test5(tc3)
test5(tc4)
test5(tc5)
test5(tc6)

test6(tc1)
test6(tc2)
test6(tc3)
test6(tc4)
test6(tc5)
test6(tc6)

Did you run that?  It didn’t work, did it?  At least not all of it.  Now we have something to discuss.  Here’s a chart of what did and didn’t work:

       test1  test2  test3  test4  test5  test6
tc1    OK     FAIL   OK     OK     FAIL   FAIL
tc2    OK     OK     FAIL   FAIL   FAIL   FAIL
tc3    OK     FAIL   OK     FAIL   OK     FAIL
tc4    OK     FAIL   FAIL   FAIL   FAIL   OK
tc5    OK     FAIL   FAIL   FAIL   FAIL   FAIL
tc6    FAIL   FAIL   FAIL   FAIL   FAIL   FAIL

Taking these results one test function at a time:

test1(fn : (String) => Unit)

All but one of the test objects has at least one member function that can be passed in to this function.  It’s easy to see why this works for tc1, tc3, and tc5.  They each have a function that matches the type of parameter fn exactly.  It’s also easy to see why it doesn’t work for tc6.  TestClass6.wrongFunc takes a List[Int] parameter, which is unrelated to fn’s expected String parameter.

But what about tc2 and tc4?  TestClass2.func takes a String parameter, just like test1’s fn parameter, but it returns an Int instead of Unit.  But it still works.  Instead of requiring a Unit return type for fn, Scala accepts a method that returns a value.  When it turns TestClass2.func into a function for fn, it simply throws away the returned value and treats the passed-in function as though it returned Unit.  Scala’s type-safety requirements in this case are very permissive!  This can be convenient, but be aware that you have very little control of what functions can be passed in when you use this kind of parameter type.

TestClass4.func returns Unit, but it takes a parameter of type Any.  Since test1 knows parameter fn as a function with a String parameter, test1 will only ever pass Strings into fn.  A String “is-a” Any so Scala says it’s ok to pass Strings into TestClass4.func.

This all goes back to variance.  In Scala, functions are contravariant with respect to parameters and covariant with respect to return type.  If you’re not familiar with variance, this means the following.  If function f2’s parameters are simple supertypes (simple, meaning a supertype or the same type) of function f1’s parameters, and f2’s return type is a simple subtype of f1’s return type, then f2 is a subtype of f1 and can be used anywhere f1’s type is required.  This makes sense.  If a function takes type X, then you can pass in any value of type X or a subtype of X.  If a function returns type Y, then the result can be treated as Y or any supertype of Y.  Because Any is a simple supertype of String and Unit is a simple subtype of Unit, the function type “(Any) => Unit” is a subtype of “(String) => Unit” and can be used wherever a “(String) => Unit” is required.  That covers tc4.  It does not cover tc2: Int is not a subtype of Unit, so “(String) => Int” is not a subtype of “(String) => Unit”.  TestClass2.func gets in only because of the result-discarding conversion described above.
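The two variance rules can be checked directly with function values. In this minimal sketch, all names are hypothetical. The assignments that compile are the ones variance allows. The assignment that variance forbids is left in a comment, next to the method-passing case, which gets through anyway because the result is discarded.

```scala
import scala.collection.mutable.ListBuffer

object VarianceSketch {
  val seen = ListBuffer.empty[Any]

  // Contravariant parameter: a function that accepts Any can be used
  // wherever a (String) => Unit is expected.
  val acceptsAny: Any => Unit = a => seen += a
  val asStringFn: String => Unit = acceptsAny

  // Covariant result: a function returning String can be used
  // wherever an (Int) => AnyRef is expected.
  val makesString: Int => String = n => "n=" + n
  val asAnyRefFn: Int => AnyRef = makesString

  // Int is not a subtype of Unit, so this assignment would not compile:
  //   val intFn: String => Int = s => s.length
  //   val bad: String => Unit = intFn
  // Passing a method is different: the compiler builds a new function for the
  // expected type and throws the Int result away.
  def returnsInt(s: String): Int = { seen += s; s.length }
  val discarded: String => Unit = returnsInt
}
```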

test2(fn : (String) => Int)

This one only works for tc2.  TestClass2.func matches the type of parameter fn exactly.  None of the other functions can be used here, because none of them takes a parameter that is a simple supertype of String and returns a simple subtype of Int.  In particular, the functions that return Unit are rejected.  Unit is not a subtype of Int, and there is no conversion that would produce an Int from them.

A quick note about the use of functions in this way:  If you have an instance and pass a member function of that instance as a parameter, the function still has access to that instance’s data.  Moreover, if the passed function changes the state of the object of which it is a member, then the function that receives it as a parameter can mutate the object.  I don’t think I’m being clear.  Here’s some code to demonstrate what I mean.

class TestClass(name: String) {
  var memberStr: String = ""
  def func(str: String) = { memberStr += str }
  override def toString = name + ": " + memberStr
}

val testA = new TestClass("A")
val testB = new TestClass("B")

def test(fn : (String) => Unit, num: Int): Unit = {
  if (num > 0) { fn("X"); test(fn,num-1); }
}

test(testA.func, 5)
test(testB.func, 2)

println(testA + ", " + testB)

See?  Function test just takes a function as a parameter (and an Int).  It does NOT take a TestClass as a parameter, but it’s still able to alter those two instances of TestClass by way of the mutator member function TestClass.func.  So be aware that when you declare a function parameter like this, you never know whether you’re getting a bare function, or a member function.

test3(obj: { def func(x: String): Unit })

Here’s a construct that’s much more strict than the function types used in test1 and test2.  The test3 function takes an ad-hoc interface (Scala calls this a structural type) as a parameter.  It’s not a named interface (a trait), so the parameter’s type doesn’t have to be declared as extending anything.  The parameter is accepted if it has a member function of the given name and type.

Now, in this case, remember, the parameter isn’t the function, it’s the object.  So the permissive behavior that Scala allows in test1 and test2 does not carry over here.  There is no result-discarding conversion for members, and member parameters are not contravariant.  If an object of type B has a func whose parameter type is a supertype of the parameter type of A’s func, that does NOT make B a subtype of A.  The names and parameter types of the members declared in the ad-hoc interface must match exactly.  The only leeway is that a member’s return type may be a subtype of the declared return type.  The parameter can have additional members, but it must at least provide all the members listed in the ad-hoc interface.

In this case, TestClass1 and TestClass3 are the only classes that have functions named func that take a String parameter and return Unit.  The other classes have either the wrong function name, parameter type, or return type.
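A small sketch makes the matching rule for ad-hoc interfaces visible at compile time. The type alias HasFunc and both classes are hypothetical. Only assignments are involved, so there is no reflective call and no extra import is needed. A member whose return type is a subtype of the declared one is accepted. A member whose parameter type is a supertype is not.

```scala
object StructuralSketch {
  // An ad-hoc interface (structural type) with a single member.
  type HasFunc = { def func(x: String): AnyRef }

  class ReturnsString { def func(x: String): String = "got " + x }
  class TakesAny { def func(x: Any): AnyRef = "got " + x }

  // Compiles: same name and parameter type, and String is a subtype of AnyRef.
  val accepted: HasFunc = new ReturnsString

  // Would not compile if uncommented: the parameter type must be exactly
  // String, even though Any is a supertype of String.
  //   val rejected: HasFunc = new TakesAny
}
```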

test4(obj: TestClass1)

This is an easy one.  You know what’s happening here.  Function test4 states explicitly that it requires an object of type TestClass1.  Only a TestClass1 or a subtype of TestClass1 will do.

test5(obj: TestTrait)

Function test5 expects a TestTrait parameter. Any object that has the trait TestTrait can be passed in. In this case, TestClass3 is the only class that extends TestTrait. We could have declared TestClass1 as extending TestTrait, but we didn’t. Therefore, it doesn’t matter that TestClass1 has a function with the right name, the right parameter types, and the right return type. It isn’t explicitly declared as extending TestTrait, so it isn’t allowed.
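To see the difference between a named trait and an ad-hoc interface side by side, here’s a sketch with hypothetical classes: Lookalike has exactly the same func as the trait, but only the structural parameter accepts it.

```scala
// Hypothetical names; a sketch, not the post's TestTrait/TestClass code.
trait HasFunc { def func(x: String): Unit }

// Declares the trait, so it is a HasFunc by name.
class Declared extends HasFunc { def func(x: String): Unit = () }

// Same member, but never declares HasFunc.
class Lookalike { def func(x: String): Unit = () }

def nominal(obj: HasFunc): String = "accepted"
def structural(obj: { def func(x: String): Unit }): String = "accepted"

nominal(new Declared)         // compiles
// nominal(new Lookalike)     // rejected: Lookalike doesn't extend HasFunc
structural(new Declared)      // compiles
structural(new Lookalike)     // compiles: it has a func with the right signature
```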

test6(obj: { def func(x: Any): Unit })

This is really just a more extreme example of the phenomenon demonstrated in test3. This is another ad-hoc interface. Viewed as a function type, that member function, Any => Unit, is a subtype of ALL function types that have one parameter and return Unit. ALL of them. You see why? Any is the supertype of every type, and function types are contravariant with respect to their parameters (and covariant with respect to their return types), so a function that accepts Any can stand in wherever a function accepting a String, a Double, or anything else is expected. But, of course, only our instance of TestClass4 is accepted, because this parameter doesn’t have a function type but an ad-hoc interface type. An object’s type does not vary with the parameter types of its member functions. TestClass4 is the only one that has a function with the right name, parameter type, and return type.
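Here’s a sketch of that contrast with hypothetical classes (acceptsAnyFunc plays the role of test6). The same Any => Unit function value can be assigned to a String => Unit, a Double => Unit, or a List[Int] => Unit, but the ad-hoc interface still rejects an object whose func takes a String.

```scala
// Hypothetical stand-ins for the post's TestClass4 and a String-taking class.
class Flexible { def func(x: Any): Unit = () }
class Narrow   { def func(x: String): Unit = () }

// As a function value, Any => Unit conforms to every one-parameter function
// type that returns Unit, because Function1 is contravariant in its parameter.
val anyFn: Any => Unit = (new Flexible).func
val strFn: String => Unit = anyFn
val dblFn: Double => Unit = anyFn
val listFn: List[Int] => Unit = anyFn

// As an ad-hoc interface, though, the parameter type must match exactly.
def acceptsAnyFunc(obj: { def func(x: Any): Unit }): String = "accepted"

acceptsAnyFunc(new Flexible)     // compiles
// acceptsAnyFunc(new Narrow)    // rejected: Narrow's func takes a String
```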

Variance in Java

If you’re a Java developer then you probably know a thing or two about subtyping. B is a subtype of A if B extends or implements A (I’ll use this convention throughout this post). A is the supertype, B is the subtype. But what about arrays? Or generic collections? Is an array of B a subtype of an array of A? Is List<B> a subtype of List<A>? Let’s do some experiments:

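Object testObj = null;

// Arrays are covariant: a String[] can be used as an Object[].
String[] arrayB = { "a", "b", "c" };
Object[] arrayA = arrayB;
testObj = arrayA[0];

// A raw-type cast forces List<String> into List<Object> (unchecked warning only).
List<String> listB = new ArrayList<String>();
listB.add("123");
List<Object> listA = (List)listB;
testObj = listA.get(0);

// The same trick puts a Double into a List<String>...
List<Double> listC = new ArrayList<Double>();
listC.add(Double.valueOf(10.0));
listB = (List)listC;
// ...and the compiler's hidden cast to String here throws ClassCastException.
System.out.println(listB.get(0));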

I’ve left out the boilerplate for brevity. This code compiles and it runs fine, mostly. This tells us three things. First, Java treats an array of B as a subtype of an array of A. We know this because we can use a reference to an array of Objects to refer to an array of Strings. This means an array of Strings “is-a” array of Objects.

Second, we know that an ArrayList<String> is a subtype of a List<String>. This makes sense, because ArrayList implements interface List. An ArrayList<String> “is-a” List<String>. We can even finagle the List<String> into a List<Object>, but we have to make an explicit cast.

Third, we see a ClassCastException during the call to listB.get(0). We use the explicit cast again to assign a List<Double> to a reference to List<String>. The compiler allows this! The failure only shows up later, when the compiler’s hidden cast to String is applied to the Double that listB.get(0) returns. This kind of sloppy typing is one of the main complaints of Java generics’ detractors (there are even erroneous assignments that don’t require an explicit cast!). Now, let’s look at some of the things Java doesn’t allow.

Object[] arrayA = { new Object() };
String[] arrayB = arrayA;          // compile error: incompatible types

List<String> listB = new ArrayList<String>();
List<Object> listA = listB;        // compile error: incompatible types

This causes two compiler errors. The first occurs when we try to assign an array of Object to a reference to an array of String. This is sensible. An array of Objects could contain any object: Strings, Doubles, Threads, anything. We can’t treat such an array as an array of Strings.

The second error occurs when we try to assign a list of Strings to a reference to a List of Objects. This is the same code as in the previous snippet, but without the explicit cast. The compiler doesn’t allow this. Since lists are mutable (we can add and remove items) we could add non-String members to listA, which means those non-String members would also be in listB. But if we don’t add any items to listA, the assignment is perfectly safe. Java can’t figure out in all cases when it’s safe to do an assignment and when it is not.

Why do Java generics work the way they do? I think it’s mainly due to two factors: backward compatibility and promoting adoption of the feature. When generics were introduced, there were already millions of lines of code out there that depended on regular, non-generic, mutable collections. To make the new code compatible with legacy code, the type parameters are erased during compilation, and allowing those dangerous casts lets developers work around the fact that List<B> is not a subtype of List<A>. To make it easier to use generics in new code and to convert to non-generic types for integration with old code, the compiler rules were made fairly permissive.

Variance Terminology

If you’re not familiar with the term “variance”, here’s what it means with respect to Java. Java arrays are covariant. That means that an array of B is a subtype of an array of A, provided that B is a subtype of A. The type-subtype relationship of the arrays follows the relationship of the contents. Lists in Java are invariant (some say nonvariant). A List of Strings has no subtype relationship to a List of Objects. You can explicitly cast a List of Strings to a List of Objects, but you can force the conversion in the opposite direction, too (yuck), so that doesn’t count. Now consider some hypothetical generic class X such that X<A> is a subtype of X<B> if B is a subtype of A. That’s the opposite of the way arrays work. The hypothetical generic X is contravariant.

One more detail of terminology: a generic class could theoretically be covariant with respect to one type parameter and contravariant with respect to another (not in Java, just theoretically). So, say you have a generic class X that is covariant with respect to its first type parameter and contravariant with respect to its second. Then X<B,I> is a subtype of X<A,J> only if B is a subtype of A and J is a subtype of I. Weird, huh? This actually happens in Scala.
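Here’s what that looks like in Scala, as a sketch with a hypothetical class X that mirrors the example above. The standard library’s Function1, the type behind every one-argument function type like A => B, is declared Function1[-T1, +R], so it mixes the two variances too (contravariant in the parameter, covariant in the result).

```scala
// Hypothetical generic class: covariant in its first type parameter and
// contravariant in its second, like X<B,I> in the text.
class X[+First, -Second] {
  override def toString = "X"
}

// Allowed: String <: AnyRef lets the covariant First widen to AnyRef,
// and String <: AnyRef lets the contravariant Second narrow to String.
val x: X[AnyRef, String] = new X[String, AnyRef]
// val bad: X[String, AnyRef] = new X[AnyRef, String]  // does not compile

// Function1[-T1, +R] follows the same pattern, with the variances swapped.
val describe: AnyRef => String = obj => "got " + obj
val widened: String => AnyRef = describe
```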

Variance in Scala

In Scala, variance is not left to chance. There are very strict rules. Variance with respect to type parameters is spelled out explicitly for each class (or trait). The same conventions are used for Arrays, Lists, or any generic class! The variance system (indeed, the whole type system) in Scala is more complicated and has a steeper learning curve than in Java, but it affords you the ability to write very expressive code that behaves in a more intuitive fashion. Here are three generic classes that use each of the variance types.

class InVar[T]     { override def toString = "InVar" }
class CoVar[+T]     { override def toString = "CoVar" }
class ContraVar[-T] { override def toString = "ContraVar" }
/************ Regular Assignment ************/
val test1: InVar[String] = new InVar[String]
val test2: CoVar[String] = new CoVar[String]
val test3: ContraVar[String] = new ContraVar[String]

The ‘+’ denotes covariance with respect to the type parameter, and ‘-’ denotes contravariance. The class is invariant with respect to type parameters without a plus or minus. If you run this code you can see that when the type parameters are the same on both sides, the assignments work fine. Now, let’s see what happens when we test assignment for different type parameters.

scala> /************ Invariant Subtyping ************/

scala> val test1: InVar[String] = new InVar[AnyRef]
<console>:5: error: type mismatch;
 found   : InVar[AnyRef]
 required: InVar[String]
       val test1: InVar[String] = new InVar[AnyRef]
                                   ^

scala> val test2: InVar[AnyRef] = new InVar[String]
<console>:5: error: type mismatch;
 found   : InVar[String]
 required: InVar[AnyRef]
       val test2: InVar[AnyRef] = new InVar[String]
                                   ^

scala> /************ Covariant Subtyping ************/

scala> val test3: CoVar[String] = new CoVar[AnyRef]
<console>:5: error: type mismatch;
 found   : CoVar[AnyRef]
 required: CoVar[String]
       val test3: CoVar[String] = new CoVar[AnyRef]
                                  ^

scala> val test4: CoVar[AnyRef] = new CoVar[String]
test4: CoVar[AnyRef] = CoVar

scala> /************ Contravariant Subtyping ************/

scala> val test5: ContraVar[String] = new ContraVar[AnyRef]
test5: ContraVar[String] = ContraVar

scala> val test6: ContraVar[AnyRef] = new ContraVar[String]
<console>:5: error: type mismatch;
 found   : ContraVar[String]
 required: ContraVar[AnyRef]
       val test6: ContraVar[AnyRef] = new ContraVar[String]
                                      ^

Now you can see the difference in the three classes. The invariant class doesn’t allow assignment in either direction, regardless of whether their type parameters have a subtype relationship. The covariant class allows an assignment from subtype to supertype. String is a subtype of AnyRef, so CoVar[String] is a subtype of CoVar[AnyRef]. The contravariant class allows an assignment from supertype to subtype. String, again, is a subtype of AnyRef, so ContraVar[AnyRef] is a subtype of ContraVar[String].

So, it’s as simple as that, right? Sorry. There’s a little more to it. Once you’ve declared a type parameter as covariant or contravariant, there are some restrictions on where this type can be used. Why? Scala is not a purely functional language: it allows objects to alter their internal state. It allows mutability. Mutability throws a monkey wrench into variance. Say, for example, you had a Scala implementation of a linked list like so:

class LinkedList[+A] {
  private var next: LinkedList[A] = null
  def add(item: A): Unit = { ... }
  def get(index: Int): A = { ... }
}

val strList = new LinkedList[String]
strList.add("str1")
val anyList: LinkedList[Any] = strList
anyList.add(new Double(1.0))
val str: String = strList.get(1)

This code won’t compile. Do you see the problem? This code, if it worked, would allow us to create a list of Strings and then add a Double to that list! If we allow that then there’s a disaster when we get to the last line. A LinkedList of Strings returns a Double. That’s no good. That’s the same problem as we saw in Java. Scala nips this sort of code in the bud by disallowing covariant types in certain places including member function parameter types. Places where covariant types are allowed and contravariant types are forbidden are called covariant positions. And the reverse is true, too. If contravariant types are allowed and covariant types are forbidden in some position, this is called a contravariant position.

The above code causes a compiler error for the add method. You can use covariant types in most other places, including constructor parameter types, member val types, and method return types. You can also use covariant types as type parameters of other generic types, but only where covariant types themselves are allowed. This means, in this example, you could add a method that returns a List[A], because List is itself covariant, but you could not add a method that takes a List[A] as a parameter, because that puts A back in a parameter position, where you could use the passed-in List[A] to alter the state. (Returning a Set[A] wouldn’t work either: Scala’s Set is invariant in its type parameter, so A would land in an invariant position.)
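So how does an immutable, covariant list add an item at all? Here’s a minimal sketch, assuming a hypothetical CoList class rather than anything from the standard library. Instead of add(item: A), prepend takes a type parameter B with the lower bound B >: A, returns a new CoList[B], and leaves the original alone; scala.List’s :: method uses the same kind of signature, def ::[B >: A](elem: B): List[B]. The toList method shows the List[A] return type discussed above.

```scala
// A hypothetical immutable, covariant list; a sketch, not scala.List.
// Nothing is ever reassigned, so A only needs covariant positions.
abstract class CoList[+A] {
  def isEmpty: Boolean
  def head: A
  def tail: CoList[A]

  // add(item: A) would put A in a contravariant position. The lower bound
  // B >: A avoids that: we build a new list of the wider type B and leave
  // this one untouched.
  def prepend[B >: A](item: B): CoList[B] = new CoCons(item, this)

  // Returning List[A] is fine because List is itself covariant.
  def toList: List[A] = if (isEmpty) Nil else head :: tail.toList
}

final class CoCons[+A](val head: A, val tail: CoList[A]) extends CoList[A] {
  def isEmpty: Boolean = false
}

case object CoNil extends CoList[Nothing] {
  def isEmpty: Boolean = true
  def head: Nothing = throw new NoSuchElementException("head of empty CoList")
  def tail: CoList[Nothing] = throw new UnsupportedOperationException("tail of empty CoList")
}

val strs: CoList[String] = CoNil.prepend("b").prepend("a")
val anys: CoList[Any] = strs                 // covariance: CoList[String] <: CoList[Any]
val mixed: CoList[Any] = strs.prepend(1.0)   // B = Any; strs itself is unchanged
```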

Contravariant types can be used as constructor parameter types, member function parameter types, and as type parameters in each of those positions.
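For the contravariant side, here’s a sketch with a hypothetical Printer class that only ever consumes values of its type parameter. Because A appears only as a method parameter type, -A is legal, and a Printer[Any] can be used wherever a Printer[String] is expected (but not the other way around).

```scala
// Hypothetical consumer class: A appears only as a method parameter type,
// which is a contravariant position, so -A is legal here.
class Printer[-A](prefix: String) {
  def format(item: A): String = prefix + item
}

val anyPrinter = new Printer[Any]("value: ")

// A Printer[Any] can format any value, so it may stand in for a Printer[String].
val strPrinter: Printer[String] = anyPrinter
// val bad: Printer[Any] = new Printer[String]("s: ")  // does not compile
```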

Here’s a fun exercise for the reader. Experiment with using type parameters of the different kinds of variances in different positions. Below is some example code to get you started. For each usage that fails compilation, why is it not allowed? Can you think of a way that such a usage could cause an inconsistency (such as allowing a Double in a String collection, for example)?

class InVar[T](param1: T) {
  def method1(param2: T) = { }
  def method2: T = { param1 }
  def method3: List[T] = { List[T](param1) }
  def method4[U >: T]: List[U] = { List[U](param1) }
  val val1: T = method2
  val val2: Any = param1
  var var1: T = method2
  var var2: Any = param1
}

class CoVar[+T](param1: T) {
  def method1(param2: T) = { }
  def method2: T = { param1 }
  def method3: List[T] = { List[T](param1) }
  def method4[U >: T]: List[U] = { List[U](param1) }
  val val1: T = method2
  val val2: Any = param1
  var var1: T = method2
  var var2: Any = param1
}

class ContraVar[-T](param1: T) {
  def method1(param2: T) = { }
  def method2: T = { param1 }
  def method3: List[T] = { List[T](param1) }
  def method4[U >: T]: List[U] = { List[U](param1) }
  val val1: T = method2
  val val2: Any = param1
  var var1: T = method2
  var var2: Any = param1
}

Maybe in a future post, I’ll do a more thorough analysis of all the places type parameters of the different variances can be used. If you’d be interested in reading such a thing, please do leave a comment.

Conclusion

What was the point of all that? We still don’t get the mutable, covariant collections we were hoping for. But we do get two things we don’t get in Java. We get invariant, typesafe, mutable collections that won’t wind up holding objects of the wrong type, and we get covariant, typesafe, immutable collections. If you’re new to functional programming (like I am), your first thought might be, “What good is an immutable list? Or an immutable array? If it’s immutable, I can’t add any items to it, right?”

Take the Scala List as an example. It’s declared as a “class List[+A]” and it has a method for adding new items that’s declared “def + [B >: A](x : B) : List[B]”. So List is covariant with respect to its type parameter A. That’s great! So a List[BigInt] “is-a” List[Number] and it “is-a” List[Object]. To add items to a List use the ‘+’ function. It doesn’t change the List it was called on, but it does return a new List.

Also, you can add anything to a List. What you add affects what gets returned. Look at the function definition again: “def + [B >: A](x : B) : List[B]”. It takes a parameter of type B, where B is any supertype of A (or A itself), and it returns a List[B]. Let’s consider the simple case, and then something more complex.

scala> var strList = List[String]("abc")
strList: List[String] = List(abc)

scala> strList = strList + "xyz"
strList: List[String] = List(abc, xyz)

scala> var objList = strList + new Object()
objList: List[java.lang.Object] = List(abc, xyz, java.lang.Object@156ee8e)

scala> var anyList = strList + 3.1416
anyList: List[Any] = List(abc, xyz, 3.1416)

First we create a List[String] called strList containing one item. Then we add a second String and store the resulting List back in the strList variable. Then we call the ‘+’ function with an Object parameter. This is allowed because the parameter must have type B where B is a supertype of A, and Object is indeed a superclass of String. The call to ‘+’ returns a List[Object] which contains both the Strings and the Object. Simple.

Then we call the ‘+’ method on strList again. Remember, strList still contains just the two Strings because it’s immutable and we didn’t assign the last result back to strList. We couldn’t have. The result was a List[Object], not a List[String], and List is covariant, not contravariant. This time we supply a parameter of type Double. But Double isn’t even a supertype of String! That’s ok. Scala determines the nearest common ancestor, the type Any, and uses that as B. The ‘+’ function returns a List[Any] that contains the Strings and the Double. That’s handy. And covariant and typesafe.
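The ‘+’ method on List shown above comes from an older Scala release; current versions (2.13 and Scala 3) don’t have it, and the append method is spelled :+, with the same kind of lower bound (def :+[B >: A](elem: B): List[B]). Here’s a sketch of the same session under that assumption, with the result types written out explicitly. Note that appending walks the whole list, so idiomatic code usually prepends with :: instead.

```scala
var strList = List[String]("abc")
strList = strList :+ "xyz"                    // still a List[String]

val objList: List[AnyRef] = strList :+ new Object()
val anyList: List[Any] = strList :+ 3.1416   // B is the common supertype, Any

// Appending never modified strList; each call built a new list.
println(strList)   // List(abc, xyz)
println(anyList)   // List(abc, xyz, 3.1416)
```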
