In Java you don’t see a lot of linked lists, and if you do it’s almost always java.util.LinkedList. People hardly ever write their own lists. They don’t really need to, I suppose; the one from java.util is fine. Plenty of people lead fulfilling software careers without ever implementing their own linked list. But it’s kind of a shame. Knowing how your data structures work makes you a better programmer.

It’s even rarer for a person to implement his own linked list in Scala. Scala’s scala.List is one of the most used classes in the language, so it’s packed with functionality. It’s abstract and covariant; it has helper objects such as List and Nil, plus the little-known ‘::’ class; and it inherits from Product, Seq, Collection, Iterable, and PartialFunction. The machinery of List pulls in Array, ListBuffer, and more. It can be hard to take it all in.

So let’s build our own linked list. We’ll start out with something very basic and un-Scala-like. Then we’ll improve it gradually until we have something a little closer to scala.List. I encourage you to fire up your Scala interpreter and follow along.

Back to Basics

First, a short review of linked lists. A linked list is a chain of nodes, each referring to exactly one other node until you get to the end of the chain. You refer to the list by its first item and you follow the chain of references to reach the other nodes.

What are the requirements for our first try? Our node should be able to hold a piece of data, and refer to the next node. It should also be able to report its length, and provide a toString method so we can visualize the list. Here we go.

class MyList(val head: Any, val tail: MyList) {
  def isEmpty = (head == null && tail == null)
  def length: Int = if (isEmpty) 0 else 1 + tail.length
  override def toString: String = if (isEmpty) "" else head + " " + tail
}

The value ‘head’ holds the data for the node; ‘tail’ refers to the next node in the chain. The ‘isEmpty’ method returns true if head and tail are both null. The length and toString methods follow the same pattern: if (isEmpty) [base result] else [data for the current node combined with the result of the same method on tail].

Here’s what it looks like when we use this class:

scala> var list = new MyList(null, null)
list: MyList =

scala> list.length
res0: Int = 0

scala> list.isEmpty
res1: Boolean = true

scala> list = new MyList("ABC", list)
list: MyList = ABC

scala> list.length
res3: Int = 1

scala> list.isEmpty
res4: Boolean = false

scala> list = new MyList("XYZ", list)
list: MyList = XYZ ABC

scala> list = new MyList("123", list)
list: MyList = 123 XYZ ABC

scala> list.tail.head
res7: Any = XYZ

Not bad. It gets the job done. But it has some problems. The first is the use of null. Null references are sloppy and raise the odds of a NullPointerException, so ideally we don’t want to see them at all. There are other problems, too: the code is too verbose, and it’s not type-safe. But for now let’s concentrate on getting rid of the nulls.

No Nulls Is Good Nulls

How can we do it? We’re using the null as a special value, a marker to tell us when a node is at the end of a list. So we’ll just use something else as that marker instead. What can we use? We’ll create a special object for the empty list. It will be recognized as empty just based on its identity, not on null values. So let’s try it:

class MyList(val head: Any, val tail: MyList) {
  def isEmpty = false
  def length: Int = if (isEmpty) 0 else 1 + tail.length
  override def toString: String = if (isEmpty) "" else head + " " + tail
}

object MyListNil extends MyList("arbitrary value", null) {
  override def isEmpty = true
}

That’s better. (The observant reader will note the similarity of MyListNil to scala.List’s Nil object.) We got rid of the nulls in the isEmpty method, but we still have to pass something for the head and tail parameters of the MyList constructor. We put an arbitrary non-null value in head, but what do we put in tail? Either null or a new MyList. And how would that MyList be instantiated? It needs a tail, too. Vicious circle. So this solution still leaves us stuck with a null.

Earlier, the null was there to mark a special node. We factored out that usage. Now it’s there to allow us to create MyListNil. How can we factor that out? MyListNil is required to call its parent’s constructor. What if it had no parent? Then it wouldn’t be a MyList anymore. What if it had an abstract parent? Now you’re talking. Let’s see what that would look like.

abstract class MyList {
  def head: Any
  def tail: MyList
  def isEmpty: Boolean
  def length: Int
}

class MyListImpl(val head: Any, val tail: MyList) extends MyList {
  def isEmpty = false
  def length: Int = 1 + tail.length
  override def toString: String = head + " " + tail
}

object MyListNil extends MyList {
  def head: Any = throw new Exception("head of empty list")
  def tail: MyList = throw new Exception("tail of empty list")
  def isEmpty = true
  def length = 0
  override def toString =  ""
}

It’s a little more code, but much neater. There are no nulls anywhere. Here’s how it looks when we use this new MyList:

scala> var list: MyList = MyListNil
list: MyList =

scala> list = new MyListImpl("ABC", list)
list: MyList = ABC

scala> list = new MyListImpl("XYZ", list)
list: MyList = XYZ ABC

scala> list = new MyListImpl("123", list)
list: MyList = 123 XYZ ABC

scala> list.length
res3: Int = 3

scala> list.tail.head
res4: Any = XYZ

scala> list.tail.tail.tail.head
java.lang.Exception: head of empty list
        at ...

Pretty neat. The equivalent of MyListImpl in Scala’s real List implementation is a class called ‘::’, which, by the way, has that funny name because it looks nice in pattern-matching code. ‘::’ is sometimes referred to as cons. With nulls finally eliminated, we can concentrate on other issues.
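For readers who haven’t seen it yet, here is a minimal sketch of the kind of pattern-matching code that name is designed for, using the real scala.List. The function describe is hypothetical, made up for this illustration.

```scala
// `first :: rest` matches a non-empty list (an instance of the `::` class)
// and binds its head and tail; `Nil` matches the empty list.
def describe(xs: List[String]): String = xs match {
  case Nil           => "empty"
  case only :: Nil   => s"just $only"
  case first :: rest => s"$first, then ${rest.length} more"
}

println(describe(Nil))                       // prints empty
println(describe(List("ABC")))               // prints just ABC
println(describe(List("ABC", "XYZ", "123"))) // prints ABC, then 2 more
```

The patterns read the same way a list is built: "ABC" :: "XYZ" :: Nil constructs a list, and first :: rest takes one apart.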

Brevity Is The Heart Of List

The thing that I notice at this point is that a lot of typing (on the keyboard) is required to use this list. We have to type out “list = new MyListImpl(…, list)” every time we add an item. We can improve this with a new method.

abstract class MyList {
  [...]
  def add(item: Any): MyList = new MyListImpl(item, this)
}

Now we have classes referring to each other. MyList creates new MyListImpls, and MyListImpl extends MyList. So you’ll need to put these classes in a .scala file and compile them instead of just typing them into the Scala interpreter. But, wow! Look how much easier it is to use MyList now:

scala> var list = MyListNil add("ABC") add("XYZ") add("123")
list: MyList = 123 XYZ ABC

scala> list.length
res1: Int = 3

So much easier! One thing I notice, though, is that the order of items in the code is the reverse of the order produced by toString. We can make our ‘add’ method right-associative instead of left-associative by giving it a name that ends in ‘:’ (colon). We’ll use ‘::’ as the method name, since that’s what scala.List uses.
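Here is a small sketch of the rule using scala.List: an infix operator whose name ends in a colon is invoked on its right-hand operand, so x :: xs is a call to xs.::(x). The Box class is hypothetical, included only to show that the rule applies to any method name ending in ‘:’, not just ‘::’.

```scala
// With scala.List, the infix form and the explicit method calls build the same list.
val withOperator   = 1 :: 2 :: Nil     // parsed as Nil.::(2).::(1)
val withMethodCall = Nil.::(2).::(1)
println(withOperator == withMethodCall) // prints true

// Any method whose name ends in ':' is right-associative.
// Box is a hypothetical class for illustration only.
class Box(val items: List[String]) {
  def +:(item: String): Box = new Box(item :: items)
  override def toString: String = items.mkString(" ")
}

val empty = new Box(Nil)
val box = "ABC" +: "XYZ" +: empty // same as empty.+:("XYZ").+:("ABC")
println(box)                      // prints ABC XYZ
```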

abstract class MyList {
  [...]
  def ::(item: Any): MyList = new MyListImpl(item, this)
}

scala> var list = "ABC" :: "XYZ" :: "123" :: MyListNil
list: MyList = ABC XYZ 123

Now we’re really getting somewhere. This is starting to look more like scala.List. One other thing that the standard list implementation gives you is a shortcut for initializing lists. It looks like “List(1, 2, 3, 4)”. Notice there’s no ‘new’ keyword. This is done using the scala.List helper object and its ‘apply’ method. Below is our own MyList helper object.

object MyList {
  def apply(items: Any*): MyList = {
    var list: MyList = MyListNil
    for (idx <- (0 until items.length).reverse)
      list = items(idx) :: list
    list
  }
}

scala> var list = MyList("ABC", "XYZ", "123")
list: MyList = ABC XYZ 123

scala> list = "Cool" :: list
list: MyList = Cool ABC XYZ 123

Better Type-Safe Than Sorry

Better. Our code looks much neater now when we use MyList. I’ll introduce just one more improvement to MyList. It still has a rather glaring problem. It provides no type information. It keeps all of its data using a reference to Any. If you don’t see why this is a problem, let’s see what happens when we want to get the length of some items in a MyList:

scala> var list = MyList("ABC", 12345, "WXYZ")
list: MyList = ABC 12345 WXYZ

scala> list.head.length
<console>:6: error: value length is not a member of Any
       list.head.length
                 ^

scala> list.head.asInstanceOf[String].length
res10: Int = 3

scala> list.tail.head.asInstanceOf[String].length
java.lang.ClassCastException: java.lang.Integer cannot be cast to java.lang.String
        at .<init>(<console>:6)

Ouch! First, when we try to call method ‘length’ on list.head Scala complains that list.head is a reference to Any. Any doesn’t have a length method. This isn’t a dynamically typed language like, say, Ruby. An object has to have the right type before we can start calling methods. What to do? You could implement a MyStringList where the head has type String. But then you’ll need a MyIntList, MyDoubleList, etc. What we need is a way to specify the type of data in the list when we create the MyList instance. What we need is a type parameter.

Here’s the complete MyList code using a type parameter, and a little demonstration code:

abstract class MyList[A] {
  def head: A
  def tail: MyList[A]
  def isEmpty: Boolean
  def length: Int
  def ::(item: A): MyList[A] = new MyListImpl[A](item, this)
}

class MyListImpl[A](val head: A, val tail: MyList[A]) extends MyList[A] {
  def isEmpty = false
  def length: Int = 1 + tail.length
  override def toString: String = head + " " + tail
}

object MyListNil extends MyList[Nothing] {
  def head: Nothing = throw new Exception("head of empty list")
  def tail: MyList[Nothing] = throw new Exception("tail of empty list")
  def isEmpty = true
  def length = 0
  override def toString =  ""
}

object MyList {
  def apply[A](items: A*): MyList[A] = {
    // MyListNil is a MyList[Nothing]; MyList is invariant in A, so a cast is needed
    var list: MyList[A] = MyListNil.asInstanceOf[MyList[A]]
    for (idx <- (0 until items.length).reverse)
      list = items(idx) :: list
    list
  }
}

scala> var list = MyList("ABC", "WXYZ", "123")
list: MyList[java.lang.String] = ABC WXYZ 123

scala> list.head.length
res0: Int = 3

scala> 3.14159 :: list
<console>:6: error: type mismatch;
 found   : Double
 required: java.lang.String
       3.14159 :: list
               ^

scala> var list = MyList("ABC", 123, 3.14159)
list: MyList[Any] = ABC 123 3.14159

Look at the first command: MyList("ABC", "WXYZ", "123") returns a MyList[String]. Scala figures out from the arguments which type to use. The next command, list.head.length, shows how much easier it is to use the list contents when you know their type at compile time.

If you try to mix types, as in MyList("ABC", 123, 3.14159), Scala determines the nearest common ancestor of the types (Any, in this case) and uses that. However, if the type parameter is already determined, as in 3.14159 :: list, it won’t change when you try to add data of a different type. To make 3.14159 :: list work, we can make a small change to the ‘::’ method:

abstract class MyList[A] {
  [...]
  def ::[B >: A](item: B): MyList[B] = 
    new MyListImpl(item, this.asInstanceOf[MyList[B]])
}

scala> var list = MyList("ABC", "XYZ")
list: MyList[java.lang.String] = ABC XYZ

scala> 3.14159 :: list
res0: MyList[Any] = 3.14159 ABC XYZ

This says that ‘::’ takes a parameter of type B, which is either A or a supertype of A, and returns a MyList[B]. So if you have a MyList[String] and you call ‘::’ on it with a Double argument, B can’t be Double, because Double is not a supertype of String. Instead, Scala infers the closest type that is a supertype of both String and Double, which is Any, and returns a MyList[Any].
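The version above still relies on casting: this.asInstanceOf[MyList[B]] in ‘::’ and MyListNil.asInstanceOf[MyList[A]] in apply. Those casts are needed because MyList is invariant in A. scala.List is covariant, and here is a minimal sketch of what that could look like for MyList, assuming we only ever read elements out of the list (which is all MyList does). With +A, MyListNil (a MyList[Nothing]) is a MyList[A] for every A, so the casts disappear. As before, the classes refer to each other, so compile them together or paste them into the interpreter in one go.

```scala
// Covariant MyList: MyList[String] is a subtype of MyList[Any].
abstract class MyList[+A] {
  def head: A
  def tail: MyList[A]
  def isEmpty: Boolean
  def length: Int
  // No cast needed: MyList[A] <: MyList[B] whenever B >: A.
  def ::[B >: A](item: B): MyList[B] = new MyListImpl[B](item, this)
}

class MyListImpl[A](val head: A, val tail: MyList[A]) extends MyList[A] {
  def isEmpty = false
  def length: Int = 1 + tail.length
  override def toString: String = s"$head $tail"
}

object MyListNil extends MyList[Nothing] {
  def head: Nothing = throw new Exception("head of empty list")
  def tail: MyList[Nothing] = throw new Exception("tail of empty list")
  def isEmpty = true
  def length = 0
  override def toString = ""
}

object MyList {
  def apply[A](items: A*): MyList[A] = {
    var list: MyList[A] = MyListNil // no cast: MyList[Nothing] <: MyList[A]
    for (item <- items.reverse)
      list = item :: list
    list
  }
}

val strings = MyList("ABC", "XYZ")          // a MyList[String]
val anys: MyList[Any] = strings             // compiles only because MyList is covariant
val mixed: MyList[Any] = 3.14159 :: strings // B is Any here
```

The ‘::’ method still needs the B >: A bound: the compiler rejects a covariant type parameter used directly as a method parameter type, and a lower-bounded B is the usual way around that.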

Conclusion

That’s a good stopping point for now. Obviously you can take the MyList class a lot further and add a lot more methods, but we’ve created some code that approximates the basics provided by scala.List. In fact, you could take several of the scala.List methods (foldLeft, for example) and basically drop them right into MyList and they’d work fine.
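As a rough check of that claim, here is a sketch of MyList with the while-loop foldLeft from scala.List’s source dropped in, with only the names adjusted. To keep the block self-contained it repeats a trimmed-down MyList (no toString) with the casting version of ‘::’ from above. Rewriting length in terms of foldLeft is a hypothetical extra, not part of the original code.

```scala
abstract class MyList[A] {
  def head: A
  def tail: MyList[A]
  def isEmpty: Boolean

  def ::[B >: A](item: B): MyList[B] =
    new MyListImpl(item, this.asInstanceOf[MyList[B]])

  // foldLeft as written in scala.List: a loop, not recursion.
  def foldLeft[B](z: B)(f: (B, A) => B): B = {
    var acc = z
    var these: MyList[A] = this
    while (!these.isEmpty) {
      acc = f(acc, these.head)
      these = these.tail
    }
    acc
  }

  // length expressed as a fold (an illustrative choice).
  def length: Int = foldLeft(0)((count, _) => count + 1)
}

class MyListImpl[A](val head: A, val tail: MyList[A]) extends MyList[A] {
  def isEmpty = false
}

object MyListNil extends MyList[Nothing] {
  def head: Nothing = throw new Exception("head of empty list")
  def tail: MyList[Nothing] = throw new Exception("tail of empty list")
  def isEmpty = true
}

val list = "ABC" :: "XYZ" :: "123" :: MyListNil
println(list.length)                             // prints 3
println(list.foldLeft("")((acc, s) => acc + s)) // prints ABCXYZ123
```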

Copyright © 2009 Matthew Jason Malone

One of my favorite functional programming tricks is folding. The fold left and fold right functions can do a lot of complicated things with a small amount of code. Today, I’d like to (1) introduce folding, (2) make note of some surprising, nay, shocking fold behavior, (3) review the folding code used in Scala’s List class, and (4) make some presumptuous suggestions on how to improve List.

Update: I’ve created a new post in which I list lots and lots of foldLeft examples in case you’d like to learn more about what folding can accomplish.

Know When to Hold ‘Em, Know When to Fold ‘Em

In case you’re not familiar with folding, I’ll describe it as briefly as I can.

Here’s the signature of the foldLeft function from List[A], a list of items of type A:

def foldLeft[B](z: B)(f: (B, A) => B): B

Firstly, foldLeft is a curried function (so is foldRight). If you don’t know about currying, that’s ok; it just means this function takes its two parameters (z and f) in two sets of parentheses instead of one. Currying isn’t the important part anyway.

The first parameter, z, is of type B, which is to say it can be different from the type of the list contents. The second parameter, f, is a function that takes a B and an A (a list item) and returns a new value of type B. So the purpose of function f is to combine a value of type B with a list item and return the result as a new B.

The foldLeft function goes through the whole List, from head to tail, and passes each item to f. For the first list item, z is used as the B argument to f. For the second list item, the result of the first call to f is used as the B argument, and so on. The result of the last call to f is the result of the fold.

For example, say we had a list of Ints 1, 2, and 3, and we called foldLeft("X")((b,a) => b + a). For the first item, 1, our function would concatenate string “X” and Int 1, returning string “X1”. For the second list item, 2, it would concatenate “X1” and 2, returning “X12”. And for the final list item, 3, it would concatenate “X12” and 3 and return “X123”.
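The step-by-step description above can be written out as a small recursive function. This is a minimal sketch over scala.List, not the library’s own implementation; the name myFoldLeft is hypothetical. Because the recursive call is the last thing each step does, the @tailrec annotation asks the compiler to confirm that it can be compiled into a loop.

```scala
import scala.annotation.tailrec

// acc plays the role of z on the first step and of "the result so far" afterwards.
@tailrec
def myFoldLeft[A, B](xs: List[A], acc: B)(f: (B, A) => B): B = xs match {
  case Nil          => acc
  case head :: tail => myFoldLeft(tail, f(acc, head))(f)
}

// The example from the text: "X" becomes "X1", then "X12", then "X123".
val result = myFoldLeft(List(1, 2, 3), "X")((b, a) => b + a)
println(result) // prints X123
```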

Here are a few more examples.

list.foldLeft(0)((b,a) => b+a)
list.foldLeft(1)((b,a) => b*a)
list.foldLeft(List[Int]())((b,a) => a :: b)

The first line is super simple. It’s almost like the example I described above, but the z value is the number 0 instead of string “X”. This fold combines the elements of the list by addition instead of concatenation. So the fold returns the sum of all Ints in the list. Line 2 combines them through multiplication. Do you see why the z value is 1 in this case?

Line 3 is a little more complex. Can you guess what it does? It starts out with an empty list of Ints and adds each item to the front of the accumulator. (We call the b parameter of our function the accumulator because it accumulates data from each of our list items.) Because it starts with the head and keeps prepending to the accumulator list until it reaches the last item of the original list, it returns the original list in reverse order.
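To make the three one-liners concrete, here they are applied to a sample list (the contents are arbitrary). The extra product0 line relates to the question about z: with 0 as the starting value, every multiplication yields 0.

```scala
val list = List(1, 2, 3, 4)

val sum      = list.foldLeft(0)((b, a) => b + a)            // ((((0+1)+2)+3)+4)
val product  = list.foldLeft(1)((b, a) => b * a)            // ((((1*1)*2)*3)*4)
val reversed = list.foldLeft(List[Int]())((b, a) => a :: b) // prepends each item in turn

// Starting a product at 0 instead of 1 wipes out every factor.
val product0 = list.foldLeft(0)((b, a) => b * a)

println(sum)      // prints 10
println(product)  // prints 24
println(reversed) // prints List(4, 3, 2, 1)
println(product0) // prints 0
```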

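The foldRight function works in much the same way as foldLeft. Can you guess the difference? You got it. It starts at the end of the list and works its way back to the head. Its function parameter also takes its arguments in the opposite order, f: (A, B) => B, with the list item first and the accumulator second.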
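One way to see the difference is to have each call wrap its arguments in parentheses, so the result records how the calls were nested. A small sketch using scala.List (the list and the formatting are arbitrary choices):

```scala
val list = List(1, 2, 3)

// foldLeft's function takes (accumulator, item); foldRight's takes (item, accumulator).
val left  = list.foldLeft("X")((b, a) => s"($b+$a)")
val right = list.foldRight("X")((a, b) => s"($a+$b)")

println(left)  // prints (((X+1)+2)+3)
println(right) // prints (1+(2+(3+X)))
```

In the foldRight result, the innermost call combines the last item, 3, with the starting value, which is why foldRight effectively processes the list from the end.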

Folds can be used for MUCH more than I’ve shown here. With folds, you can solve lots of different problems with a standard construct. You should read up on them if you’re just starting out in functional programming.

All That Glitters Is Not Fold

Now for the moment you’ve been waiting for: fold’s dirty little secret! The following is taken from a Scala interpreter session.

scala> var shortList = (1 to 10).toList
shortList: List[Int] = List(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)

scala> var longList = (1 to 325000).toList
longList: List[Int] = List(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 
15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, ...

scala> shortList.foldLeft("")((x,y) => "X")
res1: java.lang.String = X

scala> shortList.foldRight("")((x,y) => "X")
res2: java.lang.String = X

scala> longList.foldLeft("")((x,y) => "X")
res3: java.lang.String = X

scala> longList.foldRight("")((x,y) => "X")
java.lang.StackOverflowError
        at scala.List.foldRight(List.scala:1079)
        at scala.List.foldRight(List.scala:1081)
        at scala.List.foldRight(List.scala:1081)
        at scala.List.foldRight(List.scala:1081)
        at scala.List.foldRight(List.scala:1081)
        at scala.List.foldRight(List.scala:1081)
        at scala.List.foldRight(List.scala:1081)
        at scala.List.foldRight(List.scala:1081)
        at scala.List.foldRig...

We created two lists: shortList with 10 items, and longList with 325,000 items. Then we perform a trivial foldLeft and foldRight on shortList. It’s trivial because the passed-in function always returns the string “X”; it doesn’t even use the list data.

Then we do a foldLeft on longList. This goes off without a hitch. Finally we try to do a foldRight, the same foldRight that succeeded on the shorter list, and it fails! The foldLeft worked. Why didn’t the foldRight work? It’s a perfectly reasonable call against a perfectly reasonable List. Something funny is going on here.

The error message says there was a stack overflow, and the stack trace shows a long list of calls at List.scala:1081. If you’ve read my post about tail-recursion, then you probably suspect that some recursive code is to blame.

Let’s look into List.scala, maybe the single most important Scala source file.

Fool’s Fold

Without further ado, here’s the code for foldLeft and foldRight from List.scala:

override def foldLeft[B](z: B)(f: (B, A) => B): B = {
  var acc = z
  var these = this
  while (!these.isEmpty) {
    acc = f(acc, these.head)
    these = these.tail
  }
  acc
}

override def foldRight[B](z: B)(f: (A, B) => B): B = this match {
  case Nil => z
  case x :: xs => f(x, xs.foldRight(z)(f))
}

Wow! Those two definitions are very different!

The foldLeft function is the one that worked for both short and long lists. Can you see why? It isn’t head-recursive. In fact, it isn’t recursive at all: it’s implemented as a while loop. On each iteration, the next list item is passed to the function f and the accumulator (called acc) is updated. When there are no more list items, the accumulator is returned. No recursion means no stack overflows.

The foldRight function is implemented in a totally different way. If the list is empty, the z parameter is returned. Otherwise, a recursive call is made on the tail (the whole list minus the first item) of this list, and that result is passed to the function f. Study the foldRight definition. Do you understand how it works? It’s an elegant recursive solution, and the code really is quite pretty, but it’s not tail recursive so it fails for large lists.

Why didn’t Mr Odersky just write foldRight using a while loop, too? Then this problem wouldn’t exist, right? The reason is that Scala’s List is implemented as a singly-linked list. Each list element has access to the next item in the list, but not to the previous item. You can only traverse a list in one direction! This works fine for foldLeft, which goes from head to tail, but foldRight has to start at the end of the list and work its way back to the head. If foldRight uses recursion, it must recurse all the way to the end and then use the results of those recursive calls as the accumulator passed into function f.

See? The results of the recursive call must be used for further calculation, so the recursive call can’t be the last thing that happens, so it can’t be written as a tail-recursive function. If you don’t know what I’m talking about, read my introduction to tail-recursion.

Out With The Fold, In With The New

So is that it for foldRight? Is it hopeless? I say no!

There is a way to get the same result as foldRight, but using foldLeft. Can you guess what it is? Here’s how:

list.foldRight("X")((a,b) => a + b)
list.reverse.foldLeft("X")((b,a) => a + b)

These two lines are equivalent! They give the same result no matter what’s in list. Since foldRight processes list elements from last to first, that’s the same as processing the reversed list from first to last.
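A quick way to convince yourself is to compare both forms on a few sample lists. This sketch uses string interpolation in place of a + b, which builds the same string; the helper names are hypothetical.

```scala
// foldRight versus reverse.foldLeft, with the same combining step.
def viaFoldRight(list: List[Int]): String =
  list.foldRight("X")((a, b) => s"$a$b")

def viaReverseFoldLeft(list: List[Int]): String =
  list.reverse.foldLeft("X")((b, a) => s"$a$b")

val samples = List(Nil, List(1), List(1, 2, 3), (1 to 20).toList)
samples.foreach(xs => assert(viaFoldRight(xs) == viaReverseFoldLeft(xs)))

println(viaFoldRight(List(1, 2, 3)))       // prints 123X
println(viaReverseFoldLeft(List(1, 2, 3))) // prints 123X
```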

Here are three possible implementations of foldRight that could replace the current one.

def foldRight[B](z: B)(f: (A, B) => B): B = 
  reverse.foldLeft(z)((b,a) => f(a,b))

def foldRight[B](z: B)(f: (A, B) => B): B =
  if (length > 50) reverse.foldLeft(z)((b,a) => f(a,b))
  else             originalFoldRight(z)(f)

def foldRight[B](z: B)(f: (A, B) => B): B =
  try {
    originalFoldRight(z)(f)
  } catch {
    case e1: StackOverflowError => reverse.foldLeft(z)((b,a) => f(a,b))
  }

The first one simply replaces the original recursive logic with the equivalent call to reverse and foldLeft. Why wasn’t foldRight implemented this way to begin with? It may be, in part, that the authors thought the extra overhead of reversing the list was unwarranted. To me, it doesn’t seem that bad. The original foldRight and foldLeft functions are O(n), meaning they run in an amount of time roughly proportional to the number of items in the list. If you look at the source for the reverse function, you’ll see it’s also O(n). So running reverse followed by foldLeft is O(n) + O(n), which is still O(n). The real extra cost is that reverse allocates a new list of n elements.

The second implementation is a compromise. It uses the original recursive version of foldRight (referred to as originalFoldRight in the code above) only when the list has 50 or fewer elements; reverse.foldLeft is used for lists longer than 50 elements. 50 is an arbitrary number, just a guess at a sensible limit on the number of recursive calls to allow.

The third implementation tries the original foldRight logic first and if the call stack overflows then it uses reverse.foldLeft. This solution is, of course, completely ridiculous, but even this would be better than a foldRight which sometimes crashes your program.
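For experimenting outside List itself, here is a sketch of the first two proposals as standalone functions over scala.List. All three function names are hypothetical; recursiveFoldRight is a free-function copy of the recursive definition quoted from List.scala and plays the role of originalFoldRight. Only the non-recursive paths are exercised on the 325,000-element list.

```scala
// The recursive definition from List.scala, as a free function (not tail recursive).
def recursiveFoldRight[A, B](xs: List[A], z: B)(f: (A, B) => B): B = xs match {
  case Nil       => z
  case x :: rest => f(x, recursiveFoldRight(rest, z)(f))
}

// Proposal 1: always reverse, then fold from the left (no recursion).
def reversingFoldRight[A, B](xs: List[A], z: B)(f: (A, B) => B): B =
  xs.reverse.foldLeft(z)((b, a) => f(a, b))

// Proposal 2: recurse only for short lists; 50 is the post's arbitrary cutoff.
def hybridFoldRight[A, B](xs: List[A], z: B)(f: (A, B) => B): B =
  if (xs.length > 50) reversingFoldRight(xs, z)(f)
  else recursiveFoldRight(xs, z)(f)

val longList = (1 to 325000).toList
// Summing into a Long, since the total does not fit in an Int.
println(reversingFoldRight(longList, 0L)((a, b) => a + b)) // prints 52812662500
println(hybridFoldRight(List(1, 2, 3), "X")((a, b) => s"$a$b")) // prints 123X
```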

That’s All, Folds!

As I pointed out before, the reverse.foldLeft implementation of foldRight is O(n), same as the original recursive version. The original foldRight may work just fine when your Scala application is young and working with small data sets. Over time more customers are added, more products are created, more orders are placed, and then one day, *POOF*, a runtime error! It’s a ticking time-bomb.

As you may well guess, I would like to see the reverse.foldLeft logic used instead of the recursive version. That would prevent the stack overflow errors. But I would settle for just deprecating foldRight. It would be better to eliminate foldRight and force the coder to work around it than to leave it in its current state. In fact, I don’t think any head-recursive functions belong in the List class.

Do any readers have any insight into why foldRight is coded the way it is?
