One of my favorite functional programming tricks is folding. The fold left and fold right functions can do a lot of complicated things with a small amount of code. Today, I’d like to (1) introduce folding, (2) make note of some surprising, nay, shocking fold behavior, (3) review the folding code used in Scala’s List class, and (4) make some presumptuous suggestions on how to improve List.

Update: I’ve created a new post in which I list lots and lots of foldLeft examples in case you’d like to learn more about what folding can accomplish.

Know When to Hold ‘Em, Know When to Fold ‘Em

In case you’re not familiar with folding, I’ll describe it as briefly as I can.

Here’s the signature of the foldLeft function from List[A], a list of items of type A:

def foldLeft[B](z: B)(f: (B, A) => B): B

Firstly, foldLeft is a curried function (So is foldRight). If you don’t know about currying, that’s ok; this function just takes its two parameters (z and f) in two sets of parentheses instead of one. Currying isn’t the important part anyway.

The first parameter, z, is of type B, which is to say it can be different from the list contents type. The second parameter, f, is a function that takes a B and an A (a list item) as parameters, and it returns a value of type B. So the purpose of function f is to take a value of type B, use a list item to modify that value and return it.

The foldLeft function goes through the whole List, from head to tail, and passes each value to f. For the first list item, that first parameter, z, is used as the first parameter to f. For the second list item, the result of the first call to f is used as the B type parameter.

For example, say we had a list of Ints 1, 2, and 3. We could call foldLeft("X")((b,a) => b + a). For the first item, 1, the function we define would add string “X” to Int 1, returning string “X1”. For the second list item, 2, the function would add string “X1” to Int 2, returning “X12”. And for the final list item, 3, the function would add “X12” to 3 and return “X123”.
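If you'd like to see each step of that walk-through, here is a minimal sketch. It assumes a reasonably recent Scala; the object name FoldTrace is just a wrapper I made up so the snippet can stand alone, and scanLeft is used only to expose the intermediate accumulator values that foldLeft throws away.

```scala
object FoldTrace {
  val list = List(1, 2, 3)

  // scanLeft applies the same function as foldLeft but keeps every
  // intermediate accumulator, beginning with the initial value z = "X".
  val steps: List[String] = list.scanLeft("X")((b, a) => b + a)

  // foldLeft keeps only the final accumulator.
  val result: String = list.foldLeft("X")((b, a) => b + a)

  def main(args: Array[String]): Unit = {
    println(steps)  // List(X, X1, X12, X123)
    println(result) // X123
  }
}
```

The last element of steps is always the value foldLeft returns.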

Here are a few more examples.

list.foldLeft(0)((b,a) => b+a)
list.foldLeft(1)((b,a) => b*a)
list.foldLeft(List[Int]())((b,a) => a :: b)

The first line is super simple. It’s almost like the example I described above, but the z value is the number 0 instead of string “X”. This fold combines the elements of the list by addition instead of concatenation. So the fold returns the sum of all Ints in the list. Line 2 combines them through multiplication. Do you see why the z value is 1 in this case?

Line 3 is a little more complex. Can you guess what it does? It starts out with an empty list of Ints and adds each item to the accumulator (We call the b parameter of our function the accumulator because it accumulates data from each of our list items). Because it starts with the head and adds to the beginning of the accumulator list until it gets to the last item of the original list, it returns the original list in reverse order.

The foldRight function works in much the same way as foldLeft. Can you guess the difference? You got it. It starts at the end of the list and works its way up to the head.
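Here is a small sketch that runs the three examples on a concrete list and then shows how the calls to f nest for each direction. The object name FoldShapes and the string-building trick are my own illustration, not anything from the List source.

```scala
object FoldShapes {
  val list = List(1, 2, 3)

  val sum = list.foldLeft(0)((b, a) => b + a)                 // 0 + 1 + 2 + 3
  val product = list.foldLeft(1)((b, a) => b * a)             // 1 * 1 * 2 * 3
  val reversed = list.foldLeft(List[Int]())((b, a) => a :: b) // prepend each item

  // Instead of computing a value, build a string that records the call shape.
  val leftShape = list.foldLeft("z")((b, a) => s"f($b, $a)")
  val rightShape = list.foldRight("z")((a, b) => s"f($a, $b)")

  def main(args: Array[String]): Unit = {
    println(sum)        // 6
    println(product)    // 6
    println(reversed)   // List(3, 2, 1)
    println(leftShape)  // f(f(f(z, 1), 2), 3)
    println(rightShape) // f(1, f(2, f(3, z)))
  }
}
```

Notice that z is the innermost argument in both shapes. That is why the product example starts from 1: starting from 0 would make every product 0.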

Folds can be used for MUCH more than I’ve shown here. With folds, you can solve lots of different problems with a standard construct. You should read up on them if you’re just starting out in functional programming.

All That Glitters Is Not Fold

Now for the moment you’ve been waiting for. Fold’s dirty little secret! The following is taken from a Scala interpreter session.

scala> var shortList = 1 to 10 toList
shortList: List[Int] = List(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)

scala> var longList = 1 to 325000 toList
longList: List[Int] = List(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 
15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, ...

scala> shortList.foldLeft("")((x,y) => "X")
res1: java.lang.String = X

scala> shortList.foldRight("")((x,y) => "X")
res2: java.lang.String = X

scala> longList.foldLeft("")((x,y) => "X")
res3: java.lang.String = X

scala> longList.foldRight("")((x,y) => "X")
java.lang.StackOverflowError
        at scala.List.foldRight(List.scala:1079)
        at scala.List.foldRight(List.scala:1081)
        at scala.List.foldRight(List.scala:1081)
        at scala.List.foldRight(List.scala:1081)
        at scala.List.foldRight(List.scala:1081)
        at scala.List.foldRight(List.scala:1081)
        at scala.List.foldRight(List.scala:1081)
        at scala.List.foldRight(List.scala:1081)
        at scala.List.foldRig...

We created two lists: shortList with 10 items, and longList with 325,000 items. Then we perform a trivial foldLeft and foldRight on shortList. It’s trivial because the passed-in function always returns the string “X”; it doesn’t even use the list data.

Then we do a foldLeft on longList. This goes off without a hitch. Finally we try to do a foldRight, the same foldRight that succeeded on the shorter list, and it fails! The foldLeft worked. Why didn’t the foldRight work? It’s a perfectly reasonable call against a perfectly reasonable List. Something funny is going on here.

The error message says there was a stack overflow, and the stack trace shows a long list of calls at List.scala:1081. If you’ve read my post about tail-recursion, then you probably suspect that some recursive code is to blame.

Let’s look into List.scala, maybe the single most important Scala source file.

Fool’s Fold

Without further ado, here’s the code for foldLeft and foldRight from List.scala:

override def foldLeft[B](z: B)(f: (B, A) => B): B = {
  var acc = z
  var these = this
  while (!these.isEmpty) {
    acc = f(acc, these.head)
    these = these.tail
  }
  acc
}

override def foldRight[B](z: B)(f: (A, B) => B): B = this match {
  case Nil => z
  case x :: xs => f(x, xs.foldRight(z)(f))
}

Wow! Those two definitions are very different!

The foldLeft function is the one that worked for short and long lists. Can you see why? It isn’t head-recursive. In fact, it isn’t recursive at all. It is implemented as a while loop. On each iteration, the next list item is passed to the function f and the accumulator (called acc) is updated. When there are no more list items, the accumulator is returned. No recursion means no stack overflows.

The foldRight function is implemented in a totally different way. If the list is empty, the z parameter is returned. Otherwise, a recursive call is made on the tail (the whole list minus the first item) of this list, and that result is passed to the function f. Study the foldRight definition. Do you understand how it works? It’s an elegant recursive solution, and the code really is quite pretty, but it’s not tail recursive so it fails for large lists.

Why didn’t Mr Odersky just write foldRight using a while loop, too? Then this problem wouldn’t exist, right? The reason is that Scala’s List is implemented as a singly-linked list. Each list element has access to the next item in the list, but not to the previous item. You can only traverse a list in one direction! This works fine for foldLeft, which goes from head to tail, but foldRight has to start at the end of the list and work its way forward to the head. If foldRight uses recursion, it must recurse all the way to the end and then use the results of those recursive calls as the accumulator passed into function f.

See? The results of the recursive call must be used for further calculation, so the recursive call can’t be the last thing that happens, so it can’t be written as a tail-recursive function. If you don’t know what I’m talking about, read my introduction to tail-recursion.

Out With The Fold, In With The New

So is that it for foldRight? Is it hopeless? I say no!

There is a way to get the same result as foldRight, but using foldLeft. Can you guess what it is? Here’s how:

list.foldRight("X")((a,b) => a + b)
list.reverse.foldLeft("X")((b,a) => a + b)

These two lines are equivalent! They give the same result no matter what’s in list. Since foldRight processes list elements from last to first, that’s the same as processing the reversed list from first to last.
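To convince yourself, here is a rough sketch that wraps the reverse-then-foldLeft idea in a standalone helper and compares it with foldRight on a short list, then runs it on a list the size of longList. The names FoldRightEquivalence and foldRightViaReverse are hypothetical, and I call a.toString explicitly because newer compilers grumble about adding an Int to a String.

```scala
object FoldRightEquivalence {
  // Hypothetical helper: foldRight expressed as reverse followed by foldLeft.
  // Note the argument swap: foldLeft hands us (b, a), but f expects (a, b).
  def foldRightViaReverse[A, B](list: List[A])(z: B)(f: (A, B) => B): B =
    list.reverse.foldLeft(z)((b, a) => f(a, b))

  val small = List(1, 2, 3)
  val viaFoldRight = small.foldRight("X")((a, b) => a.toString + b)
  val viaReverse = foldRightViaReverse(small)("X")((a, b) => a.toString + b)

  // Same size as longList in the interpreter session. The accumulator is a
  // Long because the sum does not fit in an Int.
  val longList = (1 to 325000).toList
  val longSum = foldRightViaReverse(longList)(0L)((a, b) => a + b)

  def main(args: Array[String]): Unit = {
    println(viaFoldRight) // 123X
    println(viaReverse)   // 123X
    println(longSum)      // 52812662500
  }
}
```

The helper never recurses, so the long list costs one extra list in memory rather than 325,000 stack frames.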

Here are three possible implementations of foldRight that could replace the current one.

def foldRight[B](z: B)(f: (A, B) => B): B = 
  reverse.foldLeft(z)((b,a) => f(a,b))

def foldRight[B](z: B)(f: (A, B) => B): B =
  if (length > 50) reverse.foldLeft(z)((b,a) => f(a,b))
  else             originalFoldRight(z)(f)

def foldRight[B](z: B)(f: (A, B) => B): B =
  try {
    originalFoldRight(z)(f)
  } catch {
    case e1: StackOverflowError => reverse.foldLeft(z)((b,a) => f(a,b))
  }

The first one simply replaces the original recursive logic with the equivalent call to reverse and foldLeft. Why wasn’t foldRight implemented this way to begin with? It may be, in part, that the authors thought the extra overhead of reversing the list was unwarranted. To me, it doesn’t seem that bad. The original foldRight and foldLeft functions are O(n), meaning they run in an amount of time roughly proportional to the number of items in the list. If you look at the source for the reverse function, you’ll see it’s also O(n). So running reverse followed by foldLeft is O(n).

The second implementation is a compromise. It uses the original recursive version of foldRight (referred to as originalFoldRight in the above code) only when the list has 50 elements or fewer. The reverse.foldLeft is used for lists longer than 50 elements. 50 is just an arbitrary number, just a guess at a sensible limit on the number of recursive calls to allow.

The third implementation tries the original foldRight logic first and if the call stack overflows then it uses reverse.foldLeft. This solution is, of course, completely ridiculous, but even this would be better than a foldRight which sometimes crashes your program.

That’s All, Folds!

As I pointed out before, the reverse.foldLeft implementation of foldRight is O(n), same as the original recursive version. The original foldRight may work just fine when your Scala application is young and working with small data sets. Over time more customers are added, more products are created, more orders are placed, and then one day, *POOF*, a runtime error! It’s a ticking time-bomb.

As you may well guess, I would like to see the reverse.foldLeft logic used instead of the recursive version. That would prevent the stack overflow errors. But I would settle for just deprecating foldRight. It would be better to eliminate foldRight and force the coder to work around it than to leave it in its current state. In fact, I don’t think any head-recursive functions belong in the List class.

Do any readers have any insight into why foldRight is coded the way it is?

Don’t forget to subscribe to my RSS feed, or follow this blog on Twitter.
Copyright © 2009 Matthew Jason Malone

I was playing a game on my iPhone called Scramble the other day. It’s a great game. You are presented with a 4×4 grid of letters, and your job is to find words by chaining together adjacent letters. It bears a passing similarity to Boggle. I was playing online and I noticed that many other players play much better than I do. Well, a software developer doesn’t take a thing like this lying down. I decided it was time to write a Scramble (or Boggle) solver!

And since I’m trying to pick up the Scala language whenever I have an opportunity, I wrote the whole thing in Scala. I hope it will be, for the reader, an interesting study of a complete and useful (though simple) example Scala program.

First, I had to decide on a strategy. If we try every path on the game board, that’s inefficient. For example, if we try a path XQ then we can stop right there. XQ is not a word and no word starts with the letters XQ. This suggests using some sort of spell-checker-like logic. So let’s first create a dictionary data structure that allows us to look up words and eliminate dead ends. What we need is a tree. Here’s the data structure I came up with.

import scala.collection.mutable.HashMap

class LetterTree {
    private val nodes: HashMap[Char,LetterTree] = new HashMap[Char,LetterTree]
    var terminal: Boolean = false
    def addWord(word: String): Unit = addWord(word.toList)
    def addWord(word: List[Char]): Unit = word match {
        case Nil          => terminal = true
        case head :: tail => nodes.getOrElseUpdate(head, new LetterTree).addWord(tail)
    }
    def getSubTree(letter: Char): Option[LetterTree] =
        if (nodes.contains(letter)) Some(nodes(letter)) else None
}

It’s a mutable class. I toyed with some ideas for an immutable class, but it just complicated things more than I wanted to cope with. Our dictionary will be a tree in which LetterTree is the node class. You can see that a LetterTree has two member data: a hashmap of child LetterTrees indexed by Char, and a flag called terminal. So a LetterTree is a tree node whose child nodes are indexed by letter, and which can be marked (via the terminal flag) as ending a word. This is a pretty efficient way to store words.

There are two addWord methods for populating the tree. One takes a list of Chars. The other takes a String and is included for convenience. It converts its String parameter to a list of Chars and calls the other addWord method. If the addWord method is passed an empty list (Nil), then it has reached the end of a word and sets the terminal flag. Otherwise, it takes the first Char in the list and looks up (or else creates) the LetterTree mapped to that Char. The method then adds the remainder of the Char list to that mapped LetterTree.

Finally, there’s the getSubTree method. This will be useful when we start using the tree to look up potential word matches.

So imagine that we create a LetterTree and add the following words: item, its, it. We get a structure like this:

Figure 1 (LetterTree containing the words item, its, and it)

Each of the boxes represents a LetterTree instance. Each arrow, labeled with a letter, represents an entry in the ‘nodes’ HashMap. The top level LetterTree does not itself represent a letter. It’s just a starting point. Its ‘nodes’ member has just one mapping: letter ‘i’ maps to a second LetterTree. That second LetterTree doesn’t have its ‘terminal’ flag set, so we haven’t made a word yet. Its ‘nodes’ map has a single entry mapping ‘t’ to a third LetterTree. This third LetterTree is a terminal node, so we know that the sequence ‘it’ forms a word. The third LetterTree’s ‘nodes’ map has entries for ‘s’ and ‘e’. The LetterTree mapped to ‘s’ is terminal, and the one mapped to ‘e’ is not. The ‘e’ node has a child node ‘m’ that is a terminal.

This is a pretty efficient way to store words. We got to use the sequence ‘it’ 3 times! Also, note that there is exactly one terminal node for each word. If we add a word, we will either append a new leaf node which will be terminal, or we will make one of the existing nodes terminal.
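Here is a quick sketch that builds exactly this tree and pokes at it. To keep the snippet self-contained I repeat a condensed copy of LetterTree inside a wrapper object; the helpers walk, isWord, and isPrefix are hypothetical additions of mine and are not part of the class above.

```scala
import scala.collection.mutable.HashMap

object LetterTreeDemo {
  // Condensed copy of the LetterTree class, so the example stands alone.
  class LetterTree {
    private val nodes = new HashMap[Char, LetterTree]
    var terminal: Boolean = false
    def addWord(word: String): Unit = addWord(word.toList)
    def addWord(word: List[Char]): Unit = word match {
      case Nil          => terminal = true
      case head :: tail => nodes.getOrElseUpdate(head, new LetterTree).addWord(tail)
    }
    def getSubTree(letter: Char): Option[LetterTree] = nodes.get(letter)
  }

  // Hypothetical helper: follow the letters down the tree one node at a
  // time, giving up (None) as soon as a letter has no sub-tree.
  def walk(tree: LetterTree, letters: String): Option[LetterTree] =
    letters.foldLeft(Option(tree))((node, ch) => node.flatMap(_.getSubTree(ch)))

  // A word ends on a terminal node; a prefix only has to reach some node.
  def isWord(tree: LetterTree, letters: String): Boolean =
    walk(tree, letters).exists(_.terminal)
  def isPrefix(tree: LetterTree, letters: String): Boolean =
    walk(tree, letters).isDefined

  val tree = new LetterTree
  List("item", "its", "it").foreach(w => tree.addWord(w))

  def main(args: Array[String]): Unit = {
    println(isWord(tree, "it"))    // true
    println(isWord(tree, "ite"))   // false: a prefix, but not a word
    println(isPrefix(tree, "ite")) // true
    println(isPrefix(tree, "ix"))  // false: dead end
  }
}
```

The "dead end" case is the one that will let the solver stop exploring a path early.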

Now, all we have to do is create a LetterTree and call the addWord method for every word we can think of. That could be annoying. Let’s create an improved LetterTree that will read a list of words from a file.

import java.io.File
import scala.io.Source

class FileLetterTree(path: String) extends LetterTree {
    val file = new File(path)
    for (line <- Source.fromFile(file).getLines) addWord(line.trim)
}

Can you believe how easy that was?! We just extend LetterTree, take a file path as a constructor parameter, use the scala.io.Source class to get all lines, and add each line as a word. Now just find a text file containing all English words. You should be able to google it.

You might want to ensure the file contains only words 3 letters or longer (as required by the rules of the game), only lower case, and only the 26 English letters. Since Scramble has no ‘Q’ piece (neither does Boggle) but only a ‘Qu’ piece, you might want to do a global replace on your word file, replacing all instances of ‘qu’ with ‘q’.
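If you would rather clean the word list in code than by hand, something like the following sketch would do it. The object WordListCleanup and the function clean are hypothetical names, and I am assuming the 3-letter minimum applies to the word as spelled, before ‘qu’ is collapsed to ‘q’.

```scala
object WordListCleanup {
  // Hypothetical helper: normalise one dictionary line, or return None
  // if the word should be dropped from the dictionary.
  def clean(line: String): Option[String] = {
    val word = line.trim.toLowerCase
    // Assumption: the length rule is checked on the original spelling.
    val valid = word.length >= 3 && word.forall(c => c >= 'a' && c <= 'z')
    // The board has a single 'Qu' piece, stored in the tree as plain 'q'.
    if (valid) Some(word.replace("qu", "q")) else None
  }

  val raw = List("Quiet", "it", "item", "don't", " eel ")
  val cleaned = raw.flatMap(w => clean(w).toList)

  def main(args: Array[String]): Unit = {
    println(cleaned) // List(qiet, item, eel)
  }
}
```

Each surviving word could then be passed to addWord instead of the raw line.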

Ok, now we have a data structure for looking up words.  Now we need to create some code that represents the game board.  We’ll assume a 4-by-4 board.

class GameBoard(lettersStr: String) {
    private val ltrStr = lettersStr.toLowerCase()
    if (!ltrStr.matches("^[a-z]{16}$"))
        throw new Exception("Exactly 16 letters a-z are required.")

    override def toString: String =
        ltrStr.substring(0,4)  + "\n" + ltrStr.substring(4,8) + "\n" +
        ltrStr.substring(8,12) + "\n" + ltrStr.substring(12,16)

    case class Letter(letter: Char) {
        var neighbors = List[Letter]()
        def addNeighbor(nbr: Letter) = { neighbors = nbr :: neighbors }
        override def toString = letter.toString
    }

    val letters = new Array[Array[Letter]](4,4)
    for (idx <- 0 until ltrStr.length)
        letters(idx/4)(idx % 4) = Letter(ltrStr(idx))

    for ( idx <- 0 to 3; jdx <- 0 to 3; iOff <- -1 to 1; jOff <- -1 to 1;
          if (iOff != 0 || jOff != 0) &&
          idx + iOff >= 0 && idx + iOff < 4 &&
          jdx + jOff >= 0 && jdx + jOff < 4 )
        letters(idx)(jdx).addNeighbor(letters(idx + iOff)(jdx + jOff))
}

That’s a lot of new code, but it is actually pretty simple if we break it down:

Lines 2-4 just ensure that we have good input.  The code converts the constructor parameter to lower case, and then confirms that it is composed of exactly 16 letters from a to z.

The next section is a toString method.  It just returns the 16-letter string in 4-letter chunks separated by newlines.

Next, there is an inner class called Letter.  It encapsulates one letter on the game board, and includes a way to keep track of neighboring Letters (think of it as a graph node).

Line 16 creates the game board, a 4-by-4 array of Letters.  The for-loop that follows initializes each of the 16 Letter objects using the corresponding letters from the 16-letter string.

All that’s left is to tell each Letter who his neighbors are.  This is done with the somewhat complex for-loop on line 20.  This structure loops through two indices, idx and jdx, as well as two offsets, iOff and jOff.  It’s like four nested loops combined into one.  It loops through idx and jdx values from 0 to 3, so it runs for each of the 16 Letters in the grid.  It also loops through iOff and jOff from -1 to 1, so it looks at each neighbor of each Letter.  See?  For each of the 16 Letters in the grid, it looks at each of the 9 Letters around that Letter (including the Letter itself).

The last section of the for-loop header (lines 21-23) defines which combinations of idx, jdx, iOff, and jOff are valid and should be processed in the loop body.  You see why, right?  First of all, there’s no need for a Letter to add itself to its list of neighbors.  It’s against the game rules to use the same grid position twice in the same word, so any iterations in which iOff and jOff are both 0 are eliminated at line 21.

Also, there are grid positions at the edges and corners.  Those positions will have fewer neighbors.   So for the Letter at position idx=0, jdx=0, offsets of iOff=-1 OR jOff=-1 make no sense.  There are no neighboring Letters in the -1 direction at that position.  These restrictions are made by lines 22 and 23.

Finally, if the indices and offsets pass the tests, the Letter at location (idx, jdx) is assigned a new neighbor, the Letter at (idx+iOff, jdx+jOff).
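To see the filter at work in isolation, here is a small sketch that applies the same conditions to a single grid position and returns coordinates instead of Letter objects. NeighborDemo and neighbors are names I invented for the illustration.

```scala
object NeighborDemo {
  // Same guard as the for-loop in GameBoard, applied to one cell of a
  // 4-by-4 grid. Returns the (row, column) pairs of the valid neighbors.
  def neighbors(idx: Int, jdx: Int): List[(Int, Int)] =
    (for {
      iOff <- -1 to 1
      jOff <- -1 to 1
      if (iOff != 0 || jOff != 0) &&
        idx + iOff >= 0 && idx + iOff < 4 &&
        jdx + jOff >= 0 && jdx + jOff < 4
    } yield (idx + iOff, jdx + jOff)).toList

  def main(args: Array[String]): Unit = {
    println(neighbors(0, 0))      // List((0,1), (1,0), (1,1))
    println(neighbors(0, 1).size) // 5 (edge)
    println(neighbors(1, 1).size) // 8 (interior)
  }
}
```

A corner has 3 neighbors, an edge position has 5, and an interior position has 8, so the loop body in GameBoard runs 84 times in total rather than 16 times 9.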

Of course, this isn’t the only way to write a program like this.  We could have used a simple array of characters and put the neighbor-finding logic in the solver code, but I chose to divide up the responsibilities like this.

We’ve finished setting up the game board.  We have a dictionary of legal words.  Now we’re ready to play.  Let’s add a function called findWords to the GameBoard class.

def findWords(tree: LetterTree): List[String] = {
    def findWords(tree: LetterTree, letter: Letter, sofar: List[Letter]): List[String] = {
        tree.getSubTree(letter.letter) match {
          case Some(subTree) =>
            var words: List[String] = Nil
            if (subTree.terminal) words = (letter :: sofar).foldLeft("")((c,n) => n+c) :: words
            for (nextLetter <- letter.neighbors if !sofar.contains(nextLetter))
                words = findWords(subTree, nextLetter, letter :: sofar) ::: words
            words
          case None => Nil
        }
    }
    var words: List[String] = Nil
    for (idx <- 0 to 3; jdx <- 0 to 3)
        words = words ++ findWords(tree, letters(idx)(jdx), Nil)
    words
}

This is the heart of our little program. This is the code that does the real work. The first thing that happens in this function is that we define an inner function.  We’ll look at that in a second.  First, look at code below the inner function.  We create an empty List of Strings, then for each Letter in the 4-by-4 grid we call the inner findWords function.  We pass in the dictionary tree (at the top tree level), the current Letter, and an empty list.  That’s the list of letters used so far.  At this point, no letters have been used yet, so it’s empty (Nil).

The result of the inner function call is a list of all valid words that start with the given Letter.  That list is added to the list of all valid words.  We do this for each of the 16 Letters.

Now, for that inner function.  Let’s first examine the parameter list. First is tree of type LetterTree. This is the dictionary class we built earlier. On the first call, we pass in the whole dictionary, the top level tree. As we search for words, though, we will pass sub-trees, sub-sub-trees and so forth. The second parameter is a letter of type Letter. That’s just the current letter that we’re evaluating.

The last parameter is called sofar and it has type List[Letter]. Why is it called sofar? Because it’s the ordered list of connected Letters from the game board that we know begin words in the dictionary. The sofar list might contain the letters (c, o, m, p) because this is a prefix to words like compute, computer, computing, comparison, etc. We would never be passed a sofar list containing (c, o, m, p, x) because although this sequence could show up on the game board, this is not a prefix to any word in the dictionary. This parameter will grow longer with each recursive call to the inner function. This is why on the first call to the inner function we pass the empty list, Nil, as the sofar parameter.

We first look up the sub-tree beginning with the letter value of the Letter parameter.  If we find that there is no sub-tree for this letter (the result of getSubTree matches None) then we know that there are no words beginning with the letter. It’s a dead end, so we return an empty list and we do not recurse any further.

If we do find a sub-tree, though, that means that there are words that begin with the sofar list followed by this letter. On line 6, we check to see if the node for the current letter is a terminal node in our dictionary. If it is, then we have a legal word. We prepend the current letter to the sofar list (which holds the path with the most recent letter first), fold the list into a string in board order, and put it in the local words list.

In the next line, we loop through all of the current Letter’s neighbors, excluding any that are already in the sofar list. For each unused neighbor, we make a recursive call. This time, the parameters are the dictionary sub-tree, the neighbor letter generated by the for loop, and the sofar list with the current Letter prepended. The list of words returned by each recursive call is combined with the local word list. After all the neighboring letters have been checked, the word list is returned from the function.
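The fold on line 6 is easy to misread, so here is a tiny sketch of what it does. I use plain Chars instead of Letter objects, and the names SofarDemo, sofar, and current are only for illustration.

```scala
object SofarDemo {
  // After visiting c, o, m in that order, sofar is List(m, o, c),
  // because each new letter is prepended with ::.
  val sofar = List('m', 'o', 'c')
  val current = 'p'

  // Same fold as in findWords: every element is placed in front of the
  // accumulator, which undoes the reversed storage order.
  val word = (current :: sofar).foldLeft("")((c, n) => n.toString + c)

  // An equivalent way to get the same string.
  val sameWord = (current :: sofar).reverse.mkString

  def main(args: Array[String]): Unit = {
    println(word)     // comp
    println(sameWord) // comp
  }
}
```

Prepending keeps each recursive call cheap, and the reversal only happens when a complete word is found.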

This is basically a graph problem. The game board is an undirected graph. The dictionary is a tree, which is a type of graph. We are taking circuit-free paths along the game board graph and mapping them to paths on the tree, accumulating a list of matching paths for which the end is marked terminal on the tree. There are Java (and maybe Scala) libraries for dealing with graphs, and when I get a chance I’d like to see if I can get a tidier implementation using one of these libraries.

Now, we’ll just pull all this together in a neat little package:

class PuzzleSolver(dictionaryPath: String) {
    val tree = new FileLetterTree(dictionaryPath)
    def solve(letters: String) = {
        val board = new GameBoard(letters)
        val wordSet = new HashSet[String]() ++ board.findWords(tree)
        val sortedWords = wordSet.toList.sort{ (a,b) =>
            a.length > b.length || (a.length == b.length && a > b)
        }
        println(sortedWords)
    }
}

Now we can create an instance of the PuzzleSolver class for a dictionary file that we specify. Then we can call the solve function for game board configurations. This class finds all the legal words contained in the game board, sorts them longest first (words of equal length come out in reverse alphabetical order, because the comparison uses a > b), and prints them out.
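Here is the comparison function on its own, as a minimal sketch. I am assuming a newer Scala in which sortWith takes the place of the older sort method used above; SortDemo and the sample words are mine.

```scala
object SortDemo {
  val words = List("eel", "vest", "store", "ego", "vent")

  // true means "a goes before b": longer words first, and for equal
  // lengths the alphabetically greater word first.
  val sorted = words.sortWith { (a, b) =>
    a.length > b.length || (a.length == b.length && a > b)
  }

  def main(args: Array[String]): Unit = {
    println(sorted) // List(store, vest, vent, ego, eel)
  }
}
```

Flipping a > b to a < b in the second clause would give ordinary alphabetical order within each length.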

Here’s a sample session in the Scala interpreter:

scala> val solver = new PuzzleSolver("./words.txt")
solver: PuzzleSolver = PuzzleSolver@1876e5d

scala> solver.solve("TestingTheSolver")
List(storeen, torsel, tinsel, tensor, seeing, nestor, inseer, 
ingest, verst, verso, torse, tinge, store, soree, soget, snite, 
sneer, rotse, rotge, rosel, reest, orsel, inset, insee, ingot, 
hinge, gorse, geest, vlei, vest, vent, vein, veer, veen, tore, 
togs, ting, tine, tien, teng, stog, sore, snee, sero, sent, 
seit, sego, seer, seen, seel, rose, rest, rees, reen, reel, 
ogee, neti, nest, neer, lest, lent, lens, lehi, lees, leer, 
iten, hint, hing, hest, hent, hein, heer, gore, goes, goer, 
gest, gent, gens, gein, eros, vei, vee, tor, tog, toe, tin, 
tie, ten, teg, sot, sog, soe, set, ser, sen, seg, see, rot, 
rog, roe, rev, ree, ose, ore, oes, oer, nit, net, nei, nee, 
lev, les, len, lei, leg, lee, ing, hit, hin, hie, hen, hei, 
got, gos, gor, get, ges, gen, gel, gee, ers, ens, ego, eer, 
eel)

scala>

It works! Here’s the complete source code:

import java.io.File  
import scala.io.Source  
import scala.collection.mutable.HashMap  
import scala.collection.immutable.HashSet  
  
class LetterTree {  
    private val nodes: HashMap[Char,LetterTree] = new HashMap[Char,LetterTree]  
    var terminal: Boolean = false  
    def addWord(word: String): Unit = addWord(word.toList)  
    def addWord(word: List[Char]): Unit = word match {  
        case Nil          => terminal = true  
        case head :: tail => nodes.getOrElseUpdate(head, new LetterTree).addWord(tail)  
    }  
    def getSubTree(letter: Char): Option[LetterTree] =  
        if (nodes.contains(letter)) Some(nodes(letter)) else None  
}  

class FileLetterTree(path: String) extends LetterTree {  
    val file = new File(path)  
    for (line <- Source.fromFile(file).getLines) addWord(line.trim)  
}  


class GameBoard(lettersStr: String) {  
    private val ltrStr = lettersStr.toLowerCase()  
    if (!ltrStr.matches("^[a-z]{16}$"))  
        throw new Exception("Exactly 16 letters a-z are required.")
  
    override def toString: String =  
        ltrStr.substring(0,4)  + "\n" + ltrStr.substring(4,8) + "\n" +  
        ltrStr.substring(8,12) + "\n" + ltrStr.substring(12,16)  
  
    case class Letter(letter: Char) {  
        var neighbors = List[Letter]()  
        def addNeighbor(nbr: Letter) = { neighbors = nbr :: neighbors }  
        override def toString = letter.toString  
    }  
  
    val letters = new Array[Array[Letter]](4,4)  
    for (idx <- 0 until ltrStr.length)  
        letters(idx/4)(idx % 4) = Letter(ltrStr(idx))  
  
    for ( idx <- 0 to 3; jdx <- 0 to 3; iOff <- -1 to 1; jOff <- -1 to 1;  
          if (iOff != 0 || jOff != 0) &&  
          idx + iOff >= 0 && idx + iOff < 4 &&  
          jdx + jOff >= 0 && jdx + jOff < 4 )  
        letters(idx)(jdx).addNeighbor(letters(idx + iOff)(jdx + jOff))  

  def findWords(tree: LetterTree): List[String] = {  
      def findWords(tree: LetterTree, letter: Letter, sofar: List[Letter]): List[String] = {  
          tree.getSubTree(letter.letter) match {  
            case Some(subTree) =>  
              var words: List[String] = Nil  
              if (subTree.terminal) words = (letter :: sofar).foldLeft("")((c,n) => n+c) :: words  
              for (nextLetter <- letter.neighbors if !sofar.contains(nextLetter))  
                  words = findWords(subTree, nextLetter, letter :: sofar) ::: words
              words  
            case None => Nil  
          }  
      }  
      var words: List[String] = Nil  
      for (idx <- 0 to 3; jdx <- 0 to 3)  
          words = words ++ findWords(tree, letters(idx)(jdx), Nil)  
      words  
  }  
}  

class PuzzleSolver(dictionaryPath: String) {  
    val tree = new FileLetterTree(dictionaryPath)  
    def solve(letters: String) = {  
        val board = new GameBoard(letters)  
        val wordSet = new HashSet[String]() ++ board.findWords(tree)  
        val sortedWords = wordSet.toList.sort{ (a,b) =>  
            a.length > b.length || (a.length == b.length && a > b)  
        }  
        println(sortedWords)  
    }  
}

One last thing: To be clear, no I don’t actually use this to cheat at online games. Just knowing that I could is satisfying enough for me.


In my last post (Tail-Recursion Basics In Scala), I went on and on about how Scala, as a functional programming language, was written to accommodate recursion. When you need to do a task over and over, the procedural way is to iterate, and the functional way is to recurse. That’s what I thought. But I was browsing the source for the Scala List class and I noticed something weird. It’s chock full of iteration and procedural code!

Well, I don’t have to tell you I was a little disappointed. I never claimed to have unreservedly swallowed the FP hype, but I figured the authors of the scala.List must surely practice what they preach. I’m puzzled. To show you what I mean, here are a few methods written in a procedural style just as they appear in scala.List, alongside my attempt to write them in a functional, tail-recursive fashion. This isn’t an exhaustive list by any means. There are tons of methods in class scala.List like these.

Mark the following! I am not trying to correct the authors of scala.List or point their code out as inferior. On the contrary. They know what they are doing. I do not. They implemented a general-use functional/OO programming language. I have not. I am merely wondering aloud why scala.List was implemented the way it was.

I’m using http://lampsvn.epfl.ch/trac/scala/browser/scala/tags/R_2_7_2_RC2/src/library/scala/List.scala as my source.

List.length

I used a length function as an example in my last post on recursion. The length method of List returns a count of the items in the List. First, here’s the actual code for the length method:

sealed abstract class List[+A] extends Seq[A] with Product {
...
  def length: Int = {
    var these = this
    var len = 0
    while (!these.isEmpty) {
      len += 1
      these = these.tail
    }
    len
  }
...
}

See why I called this code procedural? Those vars and that while loop. Vars are used for variables whose values change. There’s nothing wrong with it, but it’s a feature of procedural code. Iterative loops are also more procedural than functional. You could argue that the while is not a procedural loop, but a functional closure. But I know iteration when I see it. Why wasn’t it written using recursion? Here’s how it could be implemented in a more functional style.

object List {
...
  def length(list: List[_], len: Int): Int = {
    if (list.isEmpty) len
    else length(list.tail, len + 1)
  }
...
}

sealed abstract class List[+A] extends Seq[A] with Product {
...
  def length: Int = List.length(this, 0)
...
}

There is a List class as well as a List object. I put the recursive function in the List helper object because I couldn’t get the tail recursion optimization when I implemented it as a recursive member function within class List.
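For anyone trying this on a newer Scala, here is a standalone sketch of the same idea. It assumes a version that has the scala.annotation.tailrec annotation, which makes the compiler reject the function if the recursive call is not in tail position. As far as I know, the optimization only applies to methods that cannot be overridden, which would explain my trouble with a plain member function. TailRecLength, lengthOf, and loop are hypothetical names.

```scala
import scala.annotation.tailrec

object TailRecLength {
  // Standalone version of the recursive length; works on any List.
  def lengthOf[A](list: List[A]): Int = {
    // A local function cannot be overridden, so the compiler is free
    // to turn this recursion into a loop. @tailrec checks that it can.
    @tailrec
    def loop(rest: List[A], len: Int): Int =
      if (rest.isEmpty) len
      else loop(rest.tail, len + 1)
    loop(list, 0)
  }

  def main(args: Array[String]): Unit = {
    println(lengthOf(List("a", "b", "c")))   // 3
    println(lengthOf((1 to 325000).toList))  // 325000
  }
}
```

Because the recursion compiles down to a loop, the 325,000-item list is no problem.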

List.indices

The indices method returns a List of the zero-based indices of a List. If a list has 4 items, its indices method returns List(0, 1, 2, 3). Once again, we’ll first look at the actual code for the indices method:

sealed abstract class List[+A] extends Seq[A] with Product {
...
  def indices: List[Int] = {
    val b = new ListBuffer[Int]
    var i = 0
    var these = this
    while (!these.isEmpty) {
      b += i
      i += 1
      these = these.tail
    }
    b.toList
  }
...
}

As before, this code steps through the list in a while loop. Here’s how it could be done using recursion:

object List {
...
  def indices(list: List[_], idxList: List[Int], curIdx: Int): List[Int] = {
    if (list.isEmpty) idxList
    else List.indices(list.tail, idxList + curIdx, curIdx + 1)
  }
...
}

sealed abstract class List[+A] extends Seq[A] with Product {
...
  def indices: List[Int] = List.indices(this, Nil, 0)
...
}
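One caveat about the recursive version: appending to the end of an immutable List has to copy the list each time, so building the indices that way gets slower as the list grows. A common workaround, sketched here as a standalone function with the hypothetical names indicesOf and loop, is to prepend with :: and reverse once at the end.

```scala
// Standalone sketch; indicesOf and loop are illustrative names, not library code.
def indicesOf(list: List[_]): List[Int] = {
  // acc holds the indices seen so far in reverse order, because prepending is cheap.
  def loop(rest: List[_], acc: List[Int], curIdx: Int): List[Int] =
    if (rest.isEmpty) acc.reverse
    else loop(rest.tail, curIdx :: acc, curIdx + 1)
  loop(list, Nil, 0)
}
```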

List.last

This is a very simple method. It just returns the last item in the List. Again, here’s the actual code for the List.last method:

sealed abstract class List[+A] extends Seq[A] with Product {
...
  override def last: A =
    if (isEmpty) throw new Predef.NoSuchElementException("Nil.last")
    else {
      var cur = this
      var next = this.tail
      while (!next.isEmpty) {
        cur = next
        next = next.tail
      }
      cur.head
    }
...
}

Here’s what it might look like in a functional style.

object List {
...
  def last[A](list: List[A]): A =
    if (list.isEmpty) throw new Predef.NoSuchElementException("Nil.last")
    else {
      if (list.tail.isEmpty) list.head
      else last(list.tail)
    }
...
}

sealed abstract class List[+A] extends Seq[A] with Product {
...
  override def last: A = List.last(this)
...
}
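The same logic can also be expressed with pattern matching instead of isEmpty checks. This is a standalone sketch under the hypothetical name lastOf; it is not how scala.List does it.

```scala
// Standalone sketch; lastOf is an illustrative name, not library code.
def lastOf[A](list: List[A]): A = list match {
  case Nil       => throw new NoSuchElementException("Nil.last")
  case x :: Nil  => x             // exactly one item left: that is the last one
  case _ :: rest => lastOf(rest)  // otherwise drop the head and keep looking
}
```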

List.foreach

Ok, last one. This method runs a function once for each item in the List, passing the item as a parameter each time. Here’s the real code:

sealed abstract class List[+A] extends Seq[A] with Product {
...
  final override def foreach(f: A => Unit) {
    var these = this
    while (!these.isEmpty) {
      f(these.head)
      these = these.tail
    }
  }
...
}

And, again, here’s the code rendered in a more functional style.

object List {
...
  def foreach[A](list: List[A], f: A => Unit) {
    if (!list.isEmpty) {
      f(list.head)
      foreach(list.tail, f)
    }
  }
...
}

sealed abstract class List[+A] extends Seq[A] with Product {
...
  final override def foreach(f: A => Unit) = List.foreach(this, f)
...
}

So What?

So what? What does it matter that while loops are used instead of tail recursion? Well, it doesn’t really particularly matter. Not to me, anyway. But it’s surprising that proponents, nay, authors of a functional language would write (perfectly good) procedural code like this.

I don’t know why. Maybe it’s for the benefit of the Java coders whom scala’s creators hope will become early adopters. Maybe the functions in the List object are frowned upon for some design reason. More likely it’s some technical issue that my inexperienced eye can’t discern. Is there some reason my examples would not work? If you know, please share.

Update: By the way, Debasish Ghosh wrote a related article today. It includes a comparison of the List implementation in Erlang and Scala.

Don’t forget to subscribe to my RSS feed, or follow this blog on Twitter.
Copyright © 2008 Matthew Jason Malone

Recursion 101

We all know what recursion is, right?  A function calls itself.  Or function A calls function B which calls function A.  Or A calls B, which calls C, which calls A, etc.  By far the most common situation is that in which a function calls itself.

You don’t see a lot of recursive code in Java.  There are a couple of reasons for this.  First, recursion is hard.  It’s not intuitive.  With iteration (the main alternative to recursion) you see the big picture.  You can look at the whole loop and it’s easy to understand even for the beginner.  With recursion, you see one layer and you have to imagine what happens when those layers stack up.  Like anything else, if you practice recursion then it becomes easier, but compared to iteration, recursion is hard to learn.

Second, Java is not designed to accommodate recursion.  It’s designed to accommodate iteration.  Java gives you for loops, for-each loops, while loops, do loops, arrays, Iterators, ResultSets, etc.  These constructs are all about iteration.  Plus, recursion in Java has an Achilles’ heel: the call stack.

Generally, when you call a function in any language a new level is added to the call stack.  The call stack, we all know, is what keeps track of local variables, what function has called what, etc.  It’s no different in Java.  But the stack has a finite and limited size.  Recursion is fine if you know for certain that you’ll never go more than a few dozen levels deep.  But if the recursion goes too deep, you run out of stack space and your program goes kaput.  That doesn’t happen with iteration, so it’s safer to just avoid recursion altogether.

Recursion In Scala

Scala, however, being a functional language is very much geared toward recursion rather than iteration.  So how does it overcome the limitations of the call stack?  Let’s look at an example recursive function in scala.

def listLength1(list: List[_]): Int = {
  if (list == Nil) 0
  else 1 + listLength1(list.tail)
}

var list1 = List(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
var list2 = List(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
1 to 15 foreach( x => list2 = list2 ++ list2 )

println( listLength1( list1 ) )
println( listLength1( list2 ) )

Function listLength1 recursively counts the number of items in a list.  Try running this in the interpreter.  It works fine for list1, the short list, but the longer list exhausts the stack.  Recursion is a functional language’s bread and butter, but we see here that even scala, a functional language, is subject to call stack limitations.

Don’t give up on recursion yet, though.  Scala has a very important optimization that will allow you to recurse without limit provided you use the right kind of recursion.

Head Recursion And Tail Recursion

There are two basic kinds of recursion: head recursion and tail recursion.  In head recursion, a function makes its recursive call and then performs some more calculations, maybe using the result of the recursive call, for example.  In a tail recursive function, all calculations happen first and the recursive call is the last thing that happens.

The importance of this distinction doesn’t jump out at you, but it’s extremely important!  Imagine a tail recursive function.  It runs.  It completes all its computation.  As its very last action, it is ready to make its recursive call.  What, at this point, is the use of the stack frame?  None at all.  We don’t need our local variables anymore because we’re done with all computations.  We don’t need to know which function we’re in because we’re just going to re-enter the very same function.  Scala, in the case of tail recursion, can eliminate the creation of a new stack frame and just re-use the current stack frame.  The stack never gets any deeper, no matter how many times the recursive call is made.  That’s the voodoo that makes tail recursion special in scala.

Incidentally, some languages achieve a similar end by converting tail recursion into iteration rather than by manipulating the stack.

This won’t work with head recursion.  Do you see why?  Imagine a head recursive function.  First it does some work, then it makes its recursive call, then it does a little more work.  We can’t just re-use the current stack frame when we make that recursive call.  We’re going to NEED that stack frame info after the recursive call completes.  It has our local variables, including the result (if any) returned by the recursive call.

Here’s a question for you.  Is the example function listLength1 head recursive or tail recursive?  Well, what does it do?  (A) It checks whether its parameter is Nil.  (B) If so, it returns 0 since Nil has 0 length.  (C) If not, it returns 1 plus the result of a recursive call.  The recursive call is the last thing we typed before ending the function.  That’s tail recursion, right?  Wrong.  The recursive call is made, and THEN 1 is added to the result, and this sum is returned.  This is actually head recursion (or middle recursion, if you like) because the recursive call is not the very last thing that happens.

Tail Recursion Example

When you write a recursive function in scala, your aim is to encourage the compiler to make tail recursion optimizations.  Now let’s rewrite that function using tail recursion.

def listLength2(list: List[_]): Int = {
  def listLength2Helper(list: List[_], len: Int): Int = {
    if (list == Nil) len
    else listLength2Helper(list.tail, len + 1)
  }
  listLength2Helper(list, 0)
}

println( listLength2( list1 ) )
println( listLength2( list2 ) )

I wrote this as two functions (listLength2 and an internal helper function) to preserve the one-parameter interface used in the earlier example.  It would be great if we could specify a default value for a parameter.  We could then write this as one function, but I don’t know a way to do it.  Long story short: listLength2 just calls listLength2Helper which does the real work and is the recursive function.

Is listLength2Helper head recursive or tail recursive?  When the recursive call is made are we really finished with the stack frame, allowing scala to optimize for tail recursion?  Just like in listLength1, this function first checks for a Nil list, but this time returns the len parameter rather than 0.  If list is non-Nil then we make the recursive call.  But there’s still an addition going on here, len + 1.  Does that ruin the tail recursion?  No.  That term is evaluated first.  Only after all the parameters have been evaluated does the recursive call happen.  This function does indeed qualify for tail recursion optimization.
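To make the frame re-use concrete, here is roughly the loop that the optimized helper behaves like.  This is a hand-written illustration, not actual compiler output, and listLengthLoop is my own name for it: the two parameters become variables that are updated in place, and the recursive call becomes a jump back to the top.

```scala
// Hand-written equivalent of the tail-recursive helper; illustrative only.
def listLengthLoop(list: List[_]): Int = {
  var rest: List[_] = list   // plays the role of the helper's list parameter
  var len = 0                // plays the role of the helper's len parameter
  while (!rest.isEmpty) {
    len += 1                 // the "len + 1" argument
    rest = rest.tail         // the "list.tail" argument
  }
  len
}
```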

By The Way…

The two examples in this post are very elementary and it’s easy to predict visually that tail recursion optimization will be applied.  But it’s not always so simple.  If you’re trying to write a complex recursive function and you require that it be optimized as tail recursive then you have a problem.  How do you verify that you have succeeded?  The only ways that I know of are (A) examine the resulting object code or (B) write a unit test to verify that the optimization was made.  The problem with this is that you open yourself up to bugs in your unit tests.  You have to be able to write a test that you know will fail if the function is not optimized.  I’m new to scala, so maybe there’s a better way to do this.  If there is, please educate me.

What I would like to see is an annotation that marks a function as requiring tail recursion optimization.  Such an annotation would trigger a compiler error if the compiler were unable to make the optimization.  The annotation I propose would work like this:

@TailRecursive
def listLength1(list: List[_]): Int = {
  if (list == Nil) 0
  else 1 + listLength1(list.tail)
}

def listLength2(list: List[_]): Int = {
  @TailRecursive
  def listLength2Helper(list: List[_], len: Int): Int = {
    if (list == Nil) len
    else listLength2Helper(list.tail, len + 1)
  }
  listLength2Helper(list, 0)
}

The compiler would succeed for listLength2Helper, but would report an error for listLength1, where it cannot apply the requested tail recursion optimization.  This doesn’t let the developer off the hook on unit testing, nor does it relieve the developer from having to code carefully.  What it does is provide early, infallible verification that a crucial feature was implemented.  Why force the developer to examine the classfile or write a bug-vulnerable unit test when the compiler knows the answer and could just cough up the information at compile time?
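For readers on Scala 2.8 or later: those releases include an annotation along these lines, scala.annotation.tailrec.  Here is a sketch of the second example using it.  The names listLength3 and helper are hypothetical, and this assumes a compiler new enough to have the annotation.

```scala
import scala.annotation.tailrec

// Assumes Scala 2.8 or later, where scala.annotation.tailrec is available.
def listLength3(list: List[_]): Int = {
  @tailrec  // compilation fails if this call cannot be turned into a loop
  def helper(rest: List[_], len: Int): Int =
    if (rest == Nil) len
    else helper(rest.tail, len + 1)
  helper(list, 0)
}

// Putting the annotation on the head-recursive version is rejected at compile time,
// because the addition still has to happen after the recursive call returns:
// @tailrec
// def listLength1(list: List[_]): Int =
//   if (list == Nil) 0 else 1 + listLength1(list.tail)
```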


As I’ve marveled before, the Scala type system is extremely expressive.  In this post, I want to look at one particular aspect of the type system.  I want to look at a few of the ways to pass behaviors into functions.  Using classes, traits, function types, and a construct that looks like an ad hoc interface (and probably other ways I haven’t learned yet) you can exercise precise control over what units of behavior your functions will accept and use.

First, let’s define a zoo of classes that we will experiment on:

trait TestTrait {
  def func(s1: String): Unit
}

class TestClass1 {
  def func(s1: String): Unit = println("TestClass1: "+s1)
}

class TestClass2 {
  def func(s1: String): Int = { println("TestClass2: "+s1); 999; }
}

class TestClass3 extends TestTrait {
  def func(s1: String) = println("TestClass3: "+s1)
}

class TestClass4 {
  def func(a1: Any) = println("TestClass4: "+a1)
}

class TestClass5 {
  def otherFunc(s1: String) = println("TestClass5: "+s1)
}

class TestClass6 {
  def wrongFunc(list: List[Int]) = println("TestClass6: "+list)
}

There.  These are very simple, but I’ll just point out the salient points of each:

  • TestTrait – Basically, an interface with one method.
  • TestClass1 – One method, func, takes a String and returns Unit (void, to Java coders).
  • TestClass2 – Same as TestClass1, but returns Int.
  • TestClass3 – Same as TestClass1, but also extends TestTrait.
  • TestClass4 – Same as TestClass1, but takes an Any parameter.
  • TestClass5 – Same as TestClass1, but its function has a different name
  • TestClass6 – One method, different name, different parameter type, does not extend TestTrait.

Each of these (except TestTrait) defines some behavior that we might want to pass into a function.  That behavior is trivial in these classes, but I’m trying to keep the examples short.  Let’s now define several functions that might be able to use some of this behavior.

def test1(fn : (String) => Unit) = { fn("test1"); }

def test2(fn : (String) => Int)  = { fn("test2"); }

def test3(obj: { def func(x: String): Unit }) = { obj.func("test3") }

def test4(obj: TestClass1) = { obj.func("test4") }

def test5(obj: TestTrait)  = { obj.func("test5") }

def test6(obj: { def func(x: Any): Unit }) = { obj.func("test6") }

Now we just need to instantiate each class and start passing parameters!  Create an instance of each class, TestClass1 to TestClass6, and name them tc1 to tc6.  Then pass each instance’s function into test1 and test2, and pass the instance itself into test3 through test6.  Let’s do this in the Scala interpreter, rather than compiling.  Like so:

val tc1 = new TestClass1
val tc2 = new TestClass2
val tc3 = new TestClass3
val tc4 = new TestClass4
val tc5 = new TestClass5
val tc6 = new TestClass6

test1(tc1.func)
test1(tc2.func)
test1(tc3.func)
test1(tc4.func)
test1(tc5.otherFunc)
test1(tc6.wrongFunc)

test2(tc1.func)
test2(tc2.func)
test2(tc3.func)
test2(tc4.func)
test2(tc5.otherFunc)
test2(tc6.wrongFunc)

test3(tc1)
test3(tc2)
test3(tc3)
test3(tc4)
test3(tc5)
test3(tc6)

test4(tc1)
test4(tc2)
test4(tc3)
test4(tc4)
test4(tc5)
test4(tc6)

test5(tc1)
test5(tc2)
test5(tc3)
test5(tc4)
test5(tc5)
test5(tc6)

test6(tc1)
test6(tc2)
test6(tc3)
test6(tc4)
test6(tc5)
test6(tc6)

Did you run that?  It didn’t work, did it?  At least not all of it.  Now we have something to discuss.  Here’s a chart of what did and didn’t work:

       test1  test2  test3  test4  test5  test6
tc1    OK     FAIL   OK     OK     FAIL   FAIL
tc2    OK     OK     FAIL   FAIL   FAIL   FAIL
tc3    OK     FAIL   OK     FAIL   OK     FAIL
tc4    OK     FAIL   FAIL   FAIL   FAIL   OK
tc5    OK     FAIL   FAIL   FAIL   FAIL   FAIL
tc6    FAIL   FAIL   FAIL   FAIL   FAIL   FAIL

Taking these results one test function at a time:

test1(fn : (String) => Unit)

All but one of the test objects has at least one member function that can be passed in to this function.  It’s easy to see why this works for tc1, tc3, and tc5.  They each have a function that matches the type of parameter fn exactly.  It’s also easy to see why it doesn’t work for tc6.  TestClass6.wrongFunc takes a List[Int] parameter, which is unrelated to fn’s expected String parameter.

But what about tc2 and tc4?  TestClass2.func takes a String parameter, just like test1’s fn parameter, but it returns an Int instead of Unit.  But it still works.  Instead of requiring a Unit return type for fn, Scala allows methods that return values.  It simply throws away any returned value and treats the passed-in function as though it returned Unit.  Scala’s type safety requirements in this case are very permissive!  This can be convenient, but be aware that you have very little control over what functions can be passed in when you use this kind of parameter type.

TestClass4.func returns Unit, but it takes a parameter of type Any.  Since test1 knows parameter fn as a function with a String parameter, test1 will only ever pass Strings into fn.  A String “is-a” Any so Scala says it’s ok to pass Strings into TestClass4.func.

This all goes back to variance.  In Scala, functions are contravariant with respect to parameters and covariant with respect to return type.  If you’re not familiar with variance, this means that if function f2’s parameters are simple supertypes (simple, meaning a supertype or the same type) of function f1’s parameters, and if f2’s return type is a simple subtype of f1’s return type then f2 is a subtype of f1 and can be used anywhere f1’s type is required.  That makes sense.  If a function takes type X, then you can pass in any value of type X or a subtype of X.  If a function returns type Y, then the result can be treated as Y or any supertype of Y.  Because Any is a simple supertype of String and Unit is a simple subtype of Unit, the function type “(Any) => Unit” is a subtype of “(String) => Unit” and can be used wherever a “(String) => Unit” is required.  That covers tc4.  Variance does not cover tc2, though: Int and Unit are unrelated types, so “(String) => Int” is not a subtype of “(String) => Unit”.  TestClass2.func gets in because it is a method, and when Scala converts a method into a function value where a Unit result is expected, it discards the result as described above.
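Here is a small sketch of function variance on its own, using plain function values.  All of the names are hypothetical; the only thing taken from the language is that one-parameter function types are contravariant in the parameter and covariant in the result.

```scala
// Hypothetical names; a function that accepts anything and returns an Int.
val measureAnything: Any => Int = x => x.toString.length

// This assignment compiles: Any is a supertype of String (contravariant parameter)
// and Int is a subtype of Any (covariant result).
val measureString: String => Any = measureAnything

// The reverse direction would not compile, because a String => Any
// cannot promise to accept every Any or to return an Int:
// val broken: Any => Int = measureString
```

Calling measureString("abcd") still runs the original function, so it returns the length of the string.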

test2(fn : (String) => Int)

This one only works for tc2.  TestClass2.func matches the type of parameter fn exactly.  None of the other functions can be used here, because none of them takes a parameter that is a simple supertype of String and returns a simple subtype of Int.

A quick note about the use of functions in this way:  If you have an instance and pass a member function of that instance as a parameter, the function still has access to that instance’s data.  Moreover, if the passed function changes the state of the object of which it is a member, then the function that receives it as a parameter can mutate the object.  I don’t think I’m being clear.  Here’s some code to demonstrate what I mean.


class TestClass(name: String) {
  var memberStr: String = ""
  def func(str: String) = { memberStr += str }
  override def toString = name + ": " + memberStr
}

val testA = new TestClass("A")
val testB = new TestClass("B")

def test(fn : (String) => Unit, num: Int): Unit = {
  if (num > 0) { fn("X"); test(fn,num-1); }
}

test(testA.func, 5)
test(testB.func, 2)

println(testA + ", " + testB)

See?  Function test just takes a function as a parameter (and an Int).  It does NOT take a TestClass as a parameter, but it’s still able to alter those two instances of TestClass by way of the mutator member function TestClass.func.  So be aware that when you declare a function parameter like this, you never know whether you’re getting a bare function, or a member function.

test3(obj: { def func(x: String): Unit })

Here’s a construct that’s much more strict than the function types used in test1 and test2.  The test3 function takes an ad-hoc interface as a parameter (Scala’s name for this is a structural type).  It’s not a named interface (a trait), so the parameter’s type doesn’t have to be declared as extending anything.  The parameter is accepted if it has a member function of the given name and type.

Now, in this case, remember, the parameter isn’t the function, it’s the object.  So the permissive behavior that Scala allows in test1 and test2 does not apply here.  If an object of type A has a function funcA, and an object of type B has a function funcB whose type is a subtype of funcA’s type, that does NOT make B a subtype of A.  The names and types of the members declared in the ad-hoc interface must match the parameter’s members exactly.  The parameter can have additional members, but it must at least provide all the members listed in the ad-hoc interface.

In this case, TestClass1 and TestClass3 are the only classes that have functions named func that take a String parameter and return Unit.  The other classes have either the wrong function name, parameter type, or return type.

test4(obj: TestClass1)

This is an easy one.  You know what’s happening here.  Function test4 states explicitly that it requires an object of type TestClass1.  Only a TestClass1 or a subtype of TestClass1 will do.

test5(obj: TestTrait)

Function test5 expects a TestTrait parameter.  Any object that has the trait TestTrait can be passed in.  In this case, TestClass3 is the only class that extends TestTrait.  We could have declared TestClass1 as extending TestTrait, but we didn’t.  Therefore, it doesn’t matter that TestClass1 has a function with the right name, the right parameter types, and the right return type.  It isn’t explicitly defined as extending TestTrait, so it isn’t allowed.

test6(obj: { def func(x: Any): Unit })

This is really just a more extreme example of the phenomenon demonstrated in test3.  This is another ad-hoc interface.  Viewed as a function type, its member function, “(Any) => Unit”, is a subtype of EVERY function type that has one parameter and returns Unit.  ALL of them.  You see why?  Any is the supertype of every type, and function types are contravariant with respect to their parameters and covariant with respect to their return types.  So a function like this could stand in for the func that test3 asks for.  But, of course, test3 still rejects tc4, and here only our instance of TestClass4 is accepted, because this parameter doesn’t have a function type, but an ad-hoc interface type.  It is not the case that all types are covariant with respect to the types of their member functions.  TestClass4 is the only one that has a function with the right name, parameter type, and return type.


Variance in Java

If you’re a Java developer then you probably know a thing or two about subtyping. B is a subtype of A if B extends or implements A (I’ll use this convention throughout this post). A is the supertype, B is the subtype. But what about arrays? Or generic collections? Is an array of B a subtype of an array of A? Is List<B> a subtype of List<A>? Let’s do some experiments:

Object testObj = null;

String[] arrayB = { "a", "b", "c" };
Object[] arrayA = arrayB;
testObj = arrayA[0];

List<String> listB = new ArrayList<String>();
listB.add("123");
List<Object> listA = (List)listB;
testObj = listA.get(0);

List<Double> listC = new ArrayList<Double>();
listC.add(new Double(10.0));
listB = (List)listC;
System.out.println(listB.get(0));

I’ve left out the boilerplate for brevity. This code compiles and it runs fine, mostly. This tells us three things. First, Java treats an array of B as a subtype of an array of A. We know this because we can use a reference to an array of Objects to refer to an array of Strings. This means an array of Strings “is-a” array of Objects.

Second, we know that an ArrayList<String> is a subtype of a List<String>. This makes sense, because ArrayList implements interface List. An ArrayList<String> “is-a” List<String>. We can even finagle the List<String> into a List<Object>, but we have to make an explicit cast.

Third, we see a ClassCastException during the final call to listB.get(0). We use the explicit cast again to assign a List<Double> to a reference to List<String>. The compiler allows this, and the mistake only surfaces at runtime when the Double is read back as a String! This kind of sloppy typing is one of the main complaints of Java generics’ detractors (this link includes an example of an erroneous assignment that doesn’t even require an explicit cast!). Now, let’s look at some of the things Java doesn’t allow.

Object[] arrayA = { new Object() };
String[] arrayB = arrayA;

List<String> listB = new ArrayList<String>();
List<Object> listA = listB;

This causes two compiler errors. The first occurs when we try to assign an array of Object to a reference to an array of String. This is sensible. An array of Objects could contain any object: Strings, Doubles, Threads, anything. We can’t treat such an array as an array of Strings.

The second error occurs when we try to assign a list of Strings to a reference to a List of Objects. This is the same code as in the previous snippet, but without the explicit cast. The compiler doesn’t allow this. Since lists are mutable (we can add and remove items) we could add non-String members to listA, which means those non-String members would also be in listB. But if we don’t add any items to listA, the assignment is perfectly safe. Java can’t figure out in all cases when it’s safe to do an assignment and when it is not.

Why do Java generics work the way they do? I think it’s mainly due to two factors: backward compatibility and promoting adoption of the feature. When generics were introduced, there were already millions of lines of code out there that depended on regular, non-generic, mutable collections. To make the new code compatible with legacy code, the type parameters are erased during compilation, and allowing those dangerous casts lets developers work around the fact that List<B> is not a subtype of List<A>. To make it easier to use generics in new code and convert to non-generics for integration with old code, the compiler rules were made fairly permissive.

Variance Terminology

If you’re not familiar with the term “variance”, here’s what it means with respect to Java. Java arrays are covariant. That means that an array of B is a subtype of an array of A, provided that B is a subtype of A. The type-subtype relationship of the arrays follows the relationship of the contents. Lists in Java are invariant (some say nonvariant). A List of Strings has no relationship to a List of Objects. You can explicitly cast a List of Strings to a List of Objects, but you can force the conversion in the opposite direction, too (Yuck.), so that doesn’t count. Now consider some hypothetical generic class X such that X<A> is a subtype of X<B> if B is a subtype of A. That’s the opposite of the way arrays work. The hypothetical generic X is contravariant.

One more detail of terminology: A generic class could theoretically be covariant with respect to one type parameter, and contravariant with respect to another (not in Java, just theoretically). So, say you have a generic class X that is covariant with respect to its first type parameter and contravariant with respect to its second. So X<B,I> is a subtype of X<A,J> only if B is a subtype of A and J is a subtype of I. Weird, huh? This actually happens in Scala.

Variance in Scala

In Scala, variance is not left to chance. There are very strict rules. Variance with respect to type parameters is spelled out explicitly for each class (or trait). The same conventions are used for Arrays, Lists, or any generic class! The variance system (indeed, the whole type system) in Scala is more complicated and has a steeper learning curve than in Java, but it affords you the ability to write very expressive code that behaves in a more intuitive fashion. Here are three generic classes that use each of the variance types.


class InVar[T]     { override def toString = "InVar" }
class CoVar[+T]     { override def toString = "CoVar" }
class ContraVar[-T] { override def toString = "ContraVar" }
/************ Regular Assignment ************/
val test1: InVar[String] = new InVar[String]
val test2: CoVar[String] = new CoVar[String]
val test3: ContraVar[String] = new ContraVar[String]

The ‘+’ denotes covariance with respect to the type parameter, and ‘-‘ denotes contravariance. The class is invariant with respect to type parameters without a plus or minus. If you run this code you can see that when the type parameters are the same on both sides, the assignments work fine. Now, let’s see what happens when we test assignment for different type parameters.


scala> /************ Invariant Subtyping ************/

scala> val test1: InVar[String] = new InVar[AnyRef]
<console>:5: error: type mismatch;
 found   : InVar[AnyRef]
 required: InVar[String]
       val test1: InVar[String] = new InVar[AnyRef]
                                   ^

scala> val test2: InVar[AnyRef] = new InVar[String]
<console>:5: error: type mismatch;
 found   : InVar[String]
 required: InVar[AnyRef]
       val test2: InVar[AnyRef] = new InVar[String]
                                   ^

scala> /************ Covariant Subtyping ************/

scala> val test3: CoVar[String] = new CoVar[AnyRef]
<console>:5: error: type mismatch;
 found   : CoVar[AnyRef]
 required: CoVar[String]
       val test3: CoVar[String] = new CoVar[AnyRef]
                                  ^

scala> val test4: CoVar[AnyRef] = new CoVar[String]
test4: CoVar[AnyRef] = CoVar

scala> /************ Contravariant Subtyping ************/

scala> val test5: ContraVar[String] = new ContraVar[AnyRef]
test5: ContraVar[String] = ContraVar

scala> val test6: ContraVar[AnyRef] = new ContraVar[String]
<console>:5: error: type mismatch;
 found   : ContraVar[String]
 required: ContraVar[AnyRef]
       val test6: ContraVar[AnyRef] = new ContraVar[String]
                                      ^

Now you can see the difference in the three classes. The invariant class doesn’t allow assignment in either direction, regardless of whether their type parameters have a subtype relationship. The covariant class allows an assignment from subtype to supertype. String is a subtype of AnyRef, so CoVar[String] is a subtype of CoVar[AnyRef]. The contravariant class allows an assignment from supertype to subtype. String, again, is a subtype of AnyRef, so ContraVar[AnyRef] is a subtype of ContraVar[String].

So, it’s as simple as that, right? Sorry. There’s a little more to it. Once you’ve declared a type parameter as covariant or contravariant there are some restrictions on where this type can be used. Why? Scala is not a purely functional language in that it allows objects to alter their internal state. It allows mutability. Mutability throws a monkey wrench into variance. Say, for example you had a Scala implementation of a linked list like so:


class LinkedList[+A] {
  private var next: LinkedList[A] = null
  def add(item: A): Unit = { ... }
  def get(index: Int): A = { ... }
}

val strList = new LinkedList[String]
strList.add("str1")
val anyList: LinkedList[Any] = strList
anyList.add(new Double(1.0))
val str: String = strList.get(1)

This code won’t compile. Do you see the problem? This code, if it worked, would allow us to create a list of Strings and then add a Double to that list! If we allow that then there’s a disaster when we get to the last line. A LinkedList of Strings returns a Double. That’s no good. That’s the same problem as we saw in Java. Scala nips this sort of code in the bud by disallowing covariant types in certain places including member function parameter types. Places where covariant types are allowed and contravariant types are forbidden are called covariant positions. And the reverse is true, too. If contravariant types are allowed and covariant types are forbidden in some position, this is called a contravariant position.

The above code causes a compiler error for the add method. You can use covariant types in most other places including constructor parameter types, member val types, and method return types. You can also use covariant types as type parameters, but only where covariant types themselves are allowed, and only when the type being parameterized is itself covariant in that parameter. This means, in this example, you could add a method that returns a List[A], but you could not add a method that takes a List[A] as a parameter, because a method parameter is a contravariant position.

Contravariant types can be used as constructor parameter types, member function parameter types, and as type parameters in each of those positions.
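To see both kinds of position side by side, here is a sketch of a trait that is contravariant in one type parameter and covariant in the other. The trait Converter and the values anyToString and intToRef are hypothetical names made up for this illustration.

```scala
// Hypothetical trait: In appears only as a method parameter (a contravariant position),
// Out appears only as a return type (a covariant position).
trait Converter[-In, +Out] {
  def convert(input: In): Out
}

val anyToString: Converter[Any, String] = new Converter[Any, String] {
  def convert(input: Any): String = "<" + input + ">"
}

// Accepted by the compiler: Any is a supertype of Int (contravariant In),
// and String is a subtype of AnyRef (covariant Out).
val intToRef: Converter[Int, AnyRef] = anyToString
```

Swapping the annotations to Converter[+In, -Out] would make the declaration of convert fail to compile, since In would then sit in a contravariant position and Out in a covariant one.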

Here’s a fun exercise for the reader. Experiment with using type parameters of the different kinds of variances in different positions. Below is some example code to get you started. For each usage that fails compilation, why is it not allowed? Can you think of a way that such a usage could cause an inconsistency (such as allowing a Double in a String collection, for example)?


class InVar[T](param1: T) {
  def method1(param2: T) = { }
  def method2: T = { param1 }
  def method3: List[T] = { List[T](param1) }
  def method4[U >: T]: List[U] = { List[U](param1) }
  val val1: T = method2
  val val2: Any = param1
  var var1: T = method2
  var var2: Any = param1
}

class CoVar[+T](param1: T) {
  def method1(param2: T) = { }
  def method2: T = { param1 }
  def method3: List[T] = { List[T](param1) }
  def method4[U >: T]: List[U] = { List[U](param1) }
  val val1: T = method2
  val val2: Any = param1
  var var1: T = method2
  var var2: Any = param1
}

class ContraVar[-T](param1: T) {
  def method1(param2: T) = { }
  def method2: T = { param1 }
  def method3: List[T] = { List[T](param1) }
  def method4[U >: T]: List[U] = { List[U](param1) }
  val val1: T = method2
  val val2: Any = param1
  var var1: T = method2
  var var2: Any = param1
}

Maybe in a future post, I’ll do a more thorough analysis of all the places type parameters of the different variances can be used. If you’d be interested in reading such a thing, please do leave a comment.

Conclusion

What was the point of all that? We still don’t get the mutable, covariant collections we were hoping for. But we do get two things we don’t get in Java. We get invariant, typesafe, mutable collections that won’t wind up holding objects of the wrong type, and we get covariant, typesafe, immutable collections. If you’re new to functional programming (like I am), your first thought might be, “What good is an immutable list? Or an immutable array? If it’s immutable, I can’t add any items to it, right?”

Take the Scala List as an example. It’s declared as a “class List[+A]” and it has a method for adding new items that’s declared “def + [B >: A](x : B) : List[B]”. So List is covariant with respect to its type parameter A. That’s great! So a List[BigInt] “is-a” List[Number] and it “is-a” List[Object]. To add items to a List use the ‘+’ function. It doesn’t change the List it was called on, but it does return a new List.

Also, you can add anything to a List. What you add affects what gets returned. Look at the function definition again: “def + [B >: A](x : B) : List[B]”. It takes a parameter of type B, where B is any supertype of A (or A itself), and it returns a List[B]. Let’s consider the simple case, and then something more complex.

scala> var strList = List[String]("abc")
strList: List[String] = List(abc)

scala> strList = strList + "xyz"
strList: List[String] = List(abc, xyz)

scala> var objList = strList + new Object()
objList: List[java.lang.Object] = List(abc, xyz, java.lang.Object@156ee8e)

scala> var anyList = strList + 3.1416
anyList: List[Any] = List(abc, xyz, 3.1416)

First we create a List[String] called strList containing one item. Then we add a second String and store the resulting List back in the strList variable. Then we call the ‘+’ function with an Object parameter. This is allowed because the parameter must have type B where B is a supertype of A, and Object is indeed a superclass of String. The call to ‘+’ returns a List[Object] which contains both the Strings and the Object. Simple.

Then we call the ‘+’ method on strList again. Remember, strList still just contains the two Strings because it’s immutable and we didn’t assign the last result back to strList. We couldn’t have. The result was a List[Object], not a List[String], and List is covariant, not contravariant. This time we supply a parameter of type Double. But Double isn’t even a supertype of String! That’s ok. Scala determines the nearest common ancestor, the type Any, and uses that as B. The ‘+’ function returns a List[Any] that contains the Strings and the Double. That’s handy. And covariant and typesafe.
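The same lower-bound trick works on your own classes. Here is a minimal sketch using a hypothetical immutable stack (Stack, Empty, Cons and push are names invented for this example, not the real List implementation). The methods available on List differ between Scala versions, so the sketch avoids relying on ‘+’ itself.

```scala
// Hypothetical immutable, covariant stack used only for illustration.
sealed trait Stack[+A] {
  // A cannot appear as a parameter type, so push takes a supertype B of A
  // and returns a Stack[B], leaving the original stack untouched.
  def push[B >: A](item: B): Stack[B] = Cons(item, this)

  // Most recently pushed item comes first.
  def toList: List[A] = this match {
    case Empty            => Nil
    case Cons(head, tail) => head :: tail.toList
  }
}
case object Empty extends Stack[Nothing]
final case class Cons[+A](head: A, tail: Stack[A]) extends Stack[A]

object StackDemo {
  val strings: Stack[String] = Empty.push("abc").push("xyz")
  // Pushing a Double widens the element type to the common supertype, Any.
  val mixed: Stack[Any] = strings.push(3.1416)
}
```

After the second push, strings still holds only the two Strings; mixed is a separate stack that shares them.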

Copyright © 2008 Matthew Jason Malone

The way Scala deals with functions is pretty interesting. If you want to use them as you would use Java functions then they’re not that complicated. You have to learn the syntax, a little about the Scala type system, and bada-bing, you’re in business. But if you start exploring the way types are implemented in Scala you find some interesting stuff.

First, we’ll briefly describe the basics.  Here are a few very simple Scala function definitions and invocations in the scala interpreter:


scala> def method1() = { println("method1") }
method1: ()Unit

scala> def method2(str: String) = { println("method2: " + str) }
method2: (String)Unit

scala> def method3(str: String): Int = {
     |   println("method3: " + str); str.length;
     | }
method3: (String)Int

scala> def method4(f: (String) => Int) = {
     |   printf("method4: " + f("method4"))
     | }
method4: ((String) => Int)Unit

scala> method1
method1

scala> method2("abc")
method2: abc

scala> method3("abcdefg")
method3: abcdefg
res13: Int = 7

scala> method4(method3)
method3: method4
method4: 7

Very basic:  method1 takes no parameters and returns nothing, method2 takes a single parameter of type String and returns nothing, method3 takes a String parameter and returns an Int, and method4 takes a parameter of type “function that takes a String parameter and returns Int” and returns nothing.

Why are we able to declare functions like this in Scala?  Didn’t I read somewhere that Scala is very object oriented?  Didn’t I read that everything is an object?  Why do we have these bare naked functions defined outside of objects?  The reason is that in Scala, everything really is an object, even functions!  That method1 we defined?  Strictly speaking a “def” is a method, but as soon as we use it as a value, Scala wraps it in an object, an instance of a special class.  I’ll declare method1 again, but with the underlying object exposed:


scala> val method1 = new Function0[Unit] {
     |   def apply: Unit = { println("method1") }
     | }
method1: java.lang.Object with () => Unit = <function>

scala> method1
res1: java.lang.Object with () => Unit = <function>

scala> method1.apply
method1

scala> method1()
method1

We create an instance of trait Function0[Unit], implement its one abstract method, called apply, and assign it to a val named method1.  Now you can see method1 is actually just a plain old Scala object.  When we type in “method1” and hit enter, the interpreter just tells us the resulting value of the statement which is an Object with trait Function0.  Hmm, that didn’t work.  Next we try calling the apply method on the object.  That works!  But it’s just a regular call to a member method.  But when we type “method1()” then Scala knows that we want to use this object as a function, and that we’re not referring to the object itself.  When you declare a function using “def” Scala assumes that when you refer to the method you want the apply method invoked, and that you don’t want to return the function object.  Neat.

That Function0[Unit], by the way, defines a function that takes 0 parameters and returns Unit (which is to say nothing as in Java void (not to be confused with Nothing)).  If you want a function that takes two parameters, an Int and a String, and returns a List of Doubles, you would use Function2[Int, String, List[Double]].  So class FunctionX takes (X+1) type parameters, the first X of which define the function parameter types, and the last of which defines the return type.
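Here is a small sketch of that last example. The function itself (repeat, which returns the length of the string a given number of times) is just something I made up to have a body; the point is the type. It also shows that the arrow syntax is shorthand for the same FunctionX type.

```scala
object FunctionTypes {
  // Written out with the trait name: two parameter types, then the result type.
  val repeat: Function2[Int, String, List[Double]] =
    new Function2[Int, String, List[Double]] {
      def apply(count: Int, text: String): List[Double] =
        List.fill(count)(text.length.toDouble)
    }

  // The arrow syntax names the same type, so this assignment needs no conversion.
  val sameType: (Int, String) => List[Double] = repeat
}
```

Calling FunctionTypes.repeat(2, "abc") goes through apply, just like method1() did above.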

So what if we go the other way?  What if we declare a method and then store it in a val?  In this case, Scala gets very picky.  Watch this:


scala> def method2 = { println("method2") }
method2: Unit

scala> val m2: () => Unit = method2
<console>:5: error: type mismatch;
 found   : Unit
 required: () => Unit
       val m2: () => Unit = method2
                            ^

scala> def method2() = { println("method2") }
method2: ()Unit

scala> val m2: () => Unit = method2
m2: () => Unit = <function>

scala> def method2 = { println("method2") }
method2: Unit

scala> val m2: () => Unit = method2 _
m2: () => Unit = <function>

Some strange stuff happens here.  First we just define a function called method2.  Nothing fancy.  Then we try to assign it to a val of type () => Unit.  It fails.  See the error message?  Found : Unit.  It parses it all wrong.  Scala thinks we’re trying to call method2 and assign the result to m2.  How can we set things straight?  Well, one way is to slightly change the way we define method2.  The only difference between the first and second definitions is the addition of an empty parameter list, that empty pair of parentheses.  For some reason, when we define the method in this apparently equivalent fashion, Scala rightly interprets our intentions and allows us to assign to m2.  There is another way, though.  In the third definition of method2, we’ve again removed the parentheses.  But this time we assign it successfully to val m2 by following method2 with an underscore.  The underscore just causes Scala to treat method2 as a Function0 object, rather than attempting to invoke it.
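How much of this conversion Scala does for you automatically has shifted between compiler versions, so here is a minimal sketch that sticks to forms I’d expect to behave the same way everywhere. The method names (greeting, farewell, shout) are hypothetical, and they return Strings rather than printing so the results are easy to check.

```scala
object MethodValues {
  // A method with an empty parameter list, like the second method2 above.
  def greeting(): String = "hello"

  // A parameterless method, like the first and third method2 above.
  def farewell: String = "goodbye"

  // Wrapping the call in an explicit function literal works for either kind.
  val f1: () => String = () => greeting()
  val f2: () => String = () => farewell

  // For a method that takes parameters, the trailing underscore asks for
  // the function object instead of a call.
  def shout(s: String): String = s.toUpperCase
  val f3: String => String = shout _
}
```

The explicit `() => ...` form is wordier than the underscore, but it leaves no doubt about whether you mean the function or its result.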

So a function is an object.  Who cares?  We’re still just calling functions.  Ah, but use your imagination.  You can do all kinds of tricks once you realize that a function is just an object.  For example:


scala> class TestClass {
     |   def f1(): Unit = { println("f1!!!"); func = f2 }
     |   def f2(): Unit = { println("f2!!!"); func = f3 }
     |   def f3(): Unit = { println("f3!!!"); func = f1 }
     |
     |   var func: () => Unit = f1
     |
     |   def test = { func() }
     | }
defined class TestClass

scala> val tc = new TestClass
tc: TestClass = TestClass@1eff71e

scala> tc.test
f1!!!

scala> tc.test
f2!!!

scala> tc.test
f3!!!

scala> tc.test
f1!!!

See what’s happening here?  We can store a reference to a function object, call the function it refers to, and re-assign it.  So the method “test” actually calls a different function each time.

Can you guess why I added the test method instead of just calling func directly?  func is declared with the var keyword, so if I entered “tc.func” instead of “tc.func()” then the interpreter would think I was referring to the function object.  Just so there’s no confusion, I wrapped the call “func()” inside a regular def-defined function called test.

Let’s see, what other neat tricks can we do?  Here’s something interesting:


scala> def printAll(str1: String, str2: String, str3: String): Unit = {
     |   println( str1 + ":" + str2 + ":" + str3 )
     | }
printAll: (String,String,String)Unit

scala> def fillInStr1(func: (String,String,String) => Unit, str1: String): (String,String) => Unit = {
     |   new Function2[String,String,Unit] {
     |     def apply(str2: String, str3: String) = {
     |       func(str1, str2, str3)
     |     }
     |   }
     | }
fillInStr1: ((String, String, String) => Unit,String)(String, String) => Unit

scala> val newPrint = fillInStr1(printAll _, "test123")
newPrint: (String, String) => Unit = <function>

scala> newPrint("abc","xyz")
test123:abc:xyz

scala> newPrint("123","456")
test123:123:456

First, we define printAll.  It’s just a function that prints out its 3 string parameters.  The next method, fillInStr1, is the interesting one.  The method signature is kind of complex.  It takes 2 parameters, func and str1.  func is a function taking 3 String parameters and returning nothing.  str1 is just a String.  fillInStr1 returns a function taking 2 String parameters and returning nothing.

Inside fillInStr1, it just creates an instance of Function2, a function that takes 2 parameters.  This function object is defined so that the apply method calls the func function and passes str1 as the first parameter.  The other two parameters are the parameters of the Function2’s apply method.  Do you see what it’s doing?  It’s taking a function on 3 strings, and transforming it into a function on only 2 strings.  We can call fillInStr1 by passing in printAll (note the underscore), and a string.  What we get back is a function that behaves just like printAll, except with the first parameter already filled in.  Neat trick!

In fact, this trick is so neat that it has a name and is actually built into the language.  This little demonstration is a very simple, non-generalized example of partial application, a close relative of a concept called currying.  The Code Commit blog has an excellent article on function currying in Scala if you’d like to know more about it.
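For comparison, here is a rough sketch of what the built-in versions look like. joinAll is a hypothetical stand-in for printAll that returns the text instead of printing it, so the results can be checked; the underscore placeholders and the curried method on function objects are the built-in parts.

```scala
object PartialDemo {
  // Same shape as printAll, but returns the text so the result is easy to check.
  def joinAll(str1: String, str2: String, str3: String): String =
    str1 + ":" + str2 + ":" + str3

  // Partial application: fix the first argument, leave placeholders for the rest.
  val withPrefix: (String, String) => String =
    joinAll("test123", _: String, _: String)

  // Currying: turn the three-argument function into a chain of
  // one-argument functions, then supply the first argument.
  val curried: String => String => String => String = (joinAll _).curried
  val afterFirst: String => String => String = curried("test123")
}
```

Both PartialDemo.withPrefix("abc", "xyz") and PartialDemo.afterFirst("abc")("xyz") produce the same text that newPrint("abc","xyz") printed above.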

This, of course, isn’t all there is to functions.  There’s a lot more!  But now you know enough to go out there and start experimenting.  See what tricks you can do, what problems you can solve with Scala’s versatile and powerful function objects.
