Archive for September, 2008

In my last post (Tail-Recursion Basics In Scala), I went on and on about how scala, as a functional programming language, was written to accommodate recursion. When you need to do a task over and over, the procedural way is to iterate and the functional way is to recurse. That’s what I thought. But I was browsing the source for the scala List class and I noticed something weird. It’s chock full of iteration and procedural code!

Well, I don’t have to tell you I was a little disappointed. I never claimed to have unreservedly swallowed the FP hype, but I figured the authors of scala.List must surely practice what they preach. I’m puzzled. To show you what I mean, here are a few methods written in a procedural style just as they appear in scala.List, alongside my attempt to write them in a functional, tail-recursive fashion. This isn’t an exhaustive list by any means. There are tons of methods in class scala.List like these.

Mark the following! I am not trying to correct the authors of scala.List or point their code out as inferior. On the contrary. They know what they are doing. I do not. They implemented a general-use functional/OO programming language. I have not. I am merely wondering aloud why scala.List was implemented the way it was.

I’m using http://lampsvn.epfl.ch/trac/scala/browser/scala/tags/R_2_7_2_RC2/src/library/scala/List.scala as my source.

List.length

I used a length function as an example in my last post on recursion. The length method of List returns a count of the items in the List. First, here’s the actual code for the length method:

sealed abstract class List[+A] extends Seq[A] with Product {
...
  def length: Int = {
    var these = this
    var len = 0
    while (!these.isEmpty) {
      len += 1
      these = these.tail
    }
    len
  }
...
}

See why I called this code procedural? Look at those vars and that while loop. Vars declare variables whose values change. There’s nothing wrong with that, but it’s a feature of procedural code. Iterative loops are also more procedural than functional. You could argue that the while is not a procedural loop, but a functional closure. But I know iteration when I see it. Why wasn’t it written using recursion? Here’s how it could be implemented in a more functional style.

object List {
...
  def length(list: List[_], len: Int): Int = {
    if (list.isEmpty) len
    else length(list.tail, len + 1)
  }
...
}

sealed abstract class List[+A] extends Seq[A] with Product {
...
  def length: Int = List.length(this, 0)
...
}

There is a List class as well as a List object. I put the recursive function in the List helper object because I couldn’t get the tail recursion optimization when I implemented it as a recursive member function within class List.
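The likely reason: the Scala compiler only rewrites a self-recursive tail call when the method cannot be overridden, that is, when it is final, private, local to another function, or defined in an object. An ordinary member of an abstract class such as List could be overridden by a subclass, so the optimization is skipped. The sketch below illustrates this on a hypothetical MyList type (not scala.List), assuming Scala 2.8 or later for the @tailrec annotation, which makes the compiler reject a method it cannot optimize.

```scala
import scala.annotation.tailrec

// Hypothetical mini list, NOT scala.List; it only shows where a member
// method can receive the tail call optimization.
sealed abstract class MyList[A] {
  def isEmpty: Boolean
  def head: A
  def tail: MyList[A]

  // Counts the cells of `rest`, starting from `len`. Because the method is
  // `final`, no subclass can override it, so the compiler may turn the
  // self-call into a jump. Without `final`, @tailrec reports a compile error.
  @tailrec
  final def lengthFrom(rest: MyList[A], len: Int): Int =
    if (rest.isEmpty) len
    else lengthFrom(rest.tail, len + 1)

  def length: Int = lengthFrom(this, 0)
}

final case class MyCons[A](head: A, tail: MyList[A]) extends MyList[A] {
  def isEmpty: Boolean = false
}

final case class MyNil[A]() extends MyList[A] {
  def isEmpty: Boolean = true
  def head: A = throw new NoSuchElementException("MyNil.head")
  def tail: MyList[A] = throw new UnsupportedOperationException("MyNil.tail")
}
```

Removing `final` from lengthFrom should turn this into a compile error, matching the behavior described above for a member of class List.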

List.indices

The indices method returns a List of the zero-based indices of a List. If a list has 4 items, its indices method returns List(0, 1, 2, 3). Once again, we’ll first look at the actual code for the indices method:

sealed abstract class List[+A] extends Seq[A] with Product {
...
  def indices: List[Int] = {
    val b = new ListBuffer[Int]
    var i = 0
    var these = this
    while (!these.isEmpty) {
      b += i
      i += 1
      these = these.tail
    }
    b.toList
  }
...
}

As before, this code steps through the list in a while loop. Here’s how it could be done using recursion:

object List {
...
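  def indices(list: List[_], idxList: List[Int], curIdx: Int): List[Int] = {
    // Prepend each index (constant time) and reverse once at the end;
    // appending to idxList on every step would make the whole method O(n^2).
    if (list.isEmpty) idxList.reverse
    else List.indices(list.tail, curIdx :: idxList, curIdx + 1)
  }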
...
}

sealed abstract class List[+A] extends Seq[A] with Product {
...
  def indices: List[Int] = List.indices(this, Nil, 0)
...
}

List.last

This is a very simple method. It just returns the last item in the List. Again, here’s the actual code for the List.last method:

sealed abstract class List[+A] extends Seq[A] with Product {
...
  override def last: A =
    if (isEmpty) throw new Predef.NoSuchElementException("Nil.last")
    else {
      var cur = this
      var next = this.tail
      while (!next.isEmpty) {
        cur = next
        next = next.tail
      }
      cur.head
    }
...
}

Here’s what it might look like in a functional style.

object List {
...
  def last[A](list: List[A]): A =
    if (list.isEmpty) throw new NoSuchElementException("Nil.last")
    else if (list.tail.isEmpty) list.head
    else last(list.tail)
...
}

sealed abstract class List[+A] extends Seq[A] with Product {
...
  override def last: A = List.last(this)
...
}
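Idiomatic Scala would often express the same recursion with pattern matching on the shape of the list. A minimal sketch, assuming Scala 2.8 or later for @tailrec; LastExample and lastElem are hypothetical names chosen to avoid clashing with List’s own last:

```scala
import scala.annotation.tailrec

object LastExample {
  // Returns the final element of `list`, or throws on an empty list,
  // mirroring the behavior of List.last.
  @tailrec
  def lastElem[A](list: List[A]): A = list match {
    case Nil       => throw new NoSuchElementException("Nil.last")
    case x :: Nil  => x               // one element left: it is the last
    case _ :: rest => lastElem(rest)  // tail call: nothing happens afterwards
  }
}
```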

List.foreach

Ok, last one. This method runs a function once for each item in the List, passing the item as a parameter each time. Here’s the real code:

sealed abstract class List[+A] extends Seq[A] with Product {
...
  final override def foreach(f: A => Unit) {
    var these = this
    while (!these.isEmpty) {
      f(these.head)
      these = these.tail
    }
  }
...
}

And, again, here’s the code rendered in a more functional style.

object List {
...
  def foreach[A](list: List[A], f: A => Unit): Unit = {
    if (!list.isEmpty) {
      f(list.head)
      foreach(list.tail, f)
    }
  }
...
}

sealed abstract class List[+A] extends Seq[A] with Product {
...
  final override def foreach(f: A => Unit) = List.foreach(this, f)
...
}

So What?

So what? What does it matter that while loops are used instead of tail recursion? Well, it doesn’t particularly matter. Not to me, anyway. But it’s surprising that proponents, nay, authors of a functional language would write (perfectly good) procedural code like this.

I don’t know why. Maybe it’s for the benefit of the Java coders who scala’s creators hope will become early adopters. Maybe the functions in the List object are frowned upon for some design reason. More likely it’s some technical issue that my inexperienced eye can’t discern. Is there some reason my examples would not work? If you know, please share.

Update: By the way, Debasish Ghosh wrote a related article today. It includes a comparison of the List implementation in Erlang and Scala.

Recursion 101

We all know what recursion is, right?  A function calls itself.  Or function A calls function B which calls function A.  Or A calls B, which calls C, which calls A, etc.  By far the most common situation is that in which a function calls itself.

You don’t see a lot of recursive code in Java.  There are a couple of reasons for this.  First, recursion is hard.  It’s not intuitive.  With iteration (the main alternative to recursion) you see the big picture.  You can look at the whole loop and it’s easy to understand, even for a beginner.  With recursion, you see one layer and you have to imagine what happens when those layers stack up.  Like anything else, recursion becomes easier with practice, but compared to iteration, it is hard to learn.

Second, Java is not designed to accommodate recursion.  It’s designed to accommodate iteration.  Java gives you for loops, for-each loops, while loops, do loops, arrays, Iterators, ResultSets, etc.  These constructs are all about iteration.  Plus, recursion in Java has an Achilles’ heel: the call stack.

Generally, when you call a function in any language a new level is added to the call stack.  The call stack, we all know, is what keeps track of local variables, what function has called what, etc.  It’s no different in Java.  But the stack has a finite size.  Recursion is fine if you know for certain that you’ll never go more than a few dozen levels deep.  But if the recursion goes too deep, you run out of stack space and your program goes kaput (in Java, with a StackOverflowError).  That doesn’t happen with iteration, so it’s safer to just avoid recursion altogether.

Recursion In Scala

Scala, however, being a functional language, is very much geared toward recursion rather than iteration.  So how does it overcome the limitations of the call stack?  Let’s look at an example recursive function in scala.

def listLength1(list: List[_]): Int = {
  if (list == Nil) 0
  else 1 + listLength1(list.tail)
}

var list1 = List(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
var list2 = List(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
1 to 15 foreach( x => list2 = list2 ++ list2 )

println( listLength1( list1 ) )
println( listLength1( list2 ) )

Function listLength1 recursively counts the number of items in a list.  Try running this in the interpreter.  It works fine for list1, the short list, but the longer list (list2, doubled 15 times to 327,680 items) exhausts the stack.  Recursion is a functional language’s bread and butter, but we see here that even scala, a functional language, is subject to call stack limitations.

Don’t give up on recursion yet, though.  Scala has a very important optimization that will allow you to recurse without limit provided you use the right kind of recursion.

Head Recursion And Tail Recursion

There are two basic kinds of recursion: head recursion and tail recursion.  In head recursion, a function makes its recursive call and then performs some more calculations, maybe using the result of the recursive call, for example.  In a tail recursive function, all calculations happen first and the recursive call is the last thing that happens.

The importance of this distinction doesn’t jump out at you, but it’s extremely important!  Imagine a tail recursive function.  It runs.  It completes all its computation.  As its very last action, it is ready to make its recursive call.  What, at this point, is the use of the stack frame?  None at all.  We don’t need our local variables anymore because we’re done with all computations.  We don’t need to know which function we’re in because we’re just going to re-enter the very same function.  Scala, in the case of tail recursion, can eliminate the creation of a new stack frame and just re-use the current stack frame.  The stack never gets any deeper, no matter how many times the recursive call is made.  That’s the voodoo that makes tail recursion special in scala.

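Incidentally, re-using the stack frame amounts to converting tail recursion into iteration, and that is essentially what the scala compiler does: a self-recursive tail call is compiled into a jump back to the start of the method.  Because the JVM has no general tail-call instruction, this only works when a function calls itself directly; mutual recursion (A calls B, which calls A) is not optimized.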
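To make the "tail recursion into iteration" idea concrete, here is a hand-written sketch of what a tail-recursive length function amounts to once its self-call is turned into a loop. It illustrates the idea only and is not the compiler’s literal output; LoopExample, lengthRec, and lengthLoop are hypothetical names, and @tailrec requires Scala 2.8 or later.

```scala
import scala.annotation.tailrec

object LoopExample {
  // Tail-recursive form: the self-call is the very last action.
  @tailrec
  def lengthRec[A](list: List[A], len: Int): Int =
    if (list.isEmpty) len
    else lengthRec(list.tail, len + 1)

  // Roughly what the rewrite amounts to: the parameters become mutable
  // locals, and each "recursive call" reassigns them and loops back.
  def lengthLoop[A](list0: List[A], len0: Int): Int = {
    var list = list0
    var len = len0
    while (!list.isEmpty) {
      // Compute the new argument values, then "re-enter" with them.
      val nextList = list.tail
      val nextLen = len + 1
      list = nextList
      len = nextLen
    }
    len
  }
}
```

Both versions use constant stack space, which is why the tail-recursive one can handle arbitrarily long lists.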

This won’t work with head recursion.  Do you see why?  Imagine a head recursive function.  First it does some work, then it makes its recursive call, then it does a little more work.  We can’t just re-use the current stack frame when we make that recursive call.  We’re going to NEED that stack frame info after the recursive call completes.  It has our local variables, including the result (if any) returned by the recursive call.

Here’s a question for you.  Is the example function listLength1 head recursive or tail recursive?  Well, what does it do?  (A) It checks whether its parameter is Nil.  (B) If so, it returns 0 since Nil has 0 length.  (C) If not, it returns 1 plus the result of a recursive call.  The recursive call is the last thing we typed before ending the function.  That’s tail recursion, right?  Wrong.  The recursive call is made, and THEN 1 is added to the result, and this sum is returned.  This is actually head recursion (or middle recursion, if you like) because the recursive call is not the very last thing that happens.
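Spelling out the order of evaluation makes this visible. The sketch below is equivalent to listLength1 but names the intermediate result; OrderExample and listLength1Explicit are hypothetical names:

```scala
object OrderExample {
  // Same behavior as listLength1, with the hidden step made explicit.
  def listLength1Explicit[A](list: List[A]): Int = {
    if (list == Nil) 0
    else {
      val restLength = listLength1Explicit(list.tail) // recursive call first...
      1 + restLength                                  // ...then more work: not a tail call
    }
  }
}
```

Written this way, the addition plainly happens after the recursive call returns, so the current frame has to survive the call.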

Tail Recursion Example

When you write a recursive function in scala, your aim is to encourage the compiler to make tail recursion optimizations.  Now let’s rewrite that function using tail recursion.

def listLength2(list: List[_]): Int = {
  def listLength2Helper(list: List[_], len: Int): Int = {
    if (list == Nil) len
    else listLength2Helper(list.tail, len + 1)
  }
  listLength2Helper(list, 0)
}

println( listLength2( list1 ) )
println( listLength2( list2 ) )

I wrote this as two functions (listLength2 and an internal helper function) to preserve the one-parameter interface used in the earlier example.  It would be great if we could specify a default value for a parameter.  We could then write this as one function, but I don’t know a way to do it.  Long story short: listLength2 just calls listLength2Helper which does the real work and is the recursive function.
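As it turned out, Scala 2.8 (released after this post was written) added default parameter values, so the helper can now be folded into one function. A minimal sketch, assuming Scala 2.8 or later; DefaultArgExample and listLength3 are hypothetical names:

```scala
import scala.annotation.tailrec

object DefaultArgExample {
  // `len` defaults to 0, so callers can still write listLength3(list).
  // The method lives in an object, so it cannot be overridden and the
  // tail call can be optimized; @tailrec makes the compiler check that.
  @tailrec
  def listLength3[A](list: List[A], len: Int = 0): Int =
    if (list == Nil) len
    else listLength3(list.tail, len + 1)
}
```

The trade-off is that the accumulator becomes part of the public signature, so a caller could pass a nonzero starting value; the nested-helper version keeps it hidden.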

Is listLength2Helper head recursive or tail recursive?  When the recursive call is made are we really finished with the stack frame, allowing scala to optimize for tail recursion?  Just like in listLength1, this function first checks for a Nil list, but this time returns the len parameter rather than 0.  If list is non-Nil then we make the recursive call.  But there’s still an addition going on here, len + 1.  Does that ruin the tail recursion?  No.  That term is evaluated first.  Only after all the parameters have been evaluated does the recursive call happen.  This function does indeed qualify for tail recursion optimization.
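The same trick of naming intermediate values shows why the len + 1 does no harm. This sketch is equivalent to listLength2Helper; ArgsFirstExample and listLength2Explicit are hypothetical names, and @tailrec (Scala 2.8 or later) confirms the call is still in tail position:

```scala
import scala.annotation.tailrec

object ArgsFirstExample {
  // Same behavior as listLength2Helper, with argument evaluation spelled out.
  @tailrec
  def listLength2Explicit[A](list: List[A], len: Int): Int = {
    if (list == Nil) len
    else {
      val newLen = len + 1                   // the argument is computed first...
      listLength2Explicit(list.tail, newLen) // ...so the call really is the last action
    }
  }
}
```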

By The Way…

The two examples in this post are very elementary and it’s easy to predict visually whether tail recursion optimization will be applied.  But it’s not always so simple.  If you’re trying to write a complex recursive function and you require that it be optimized as tail recursive, then you have a problem.  How do you verify that you have succeeded?  The only ways that I know of are (A) examining the resulting object code or (B) writing a unit test to verify that the optimization was made.  The problem with the unit test approach is that you open yourself up to bugs in your unit tests.  You have to be able to write a test that you know will fail if the function is not optimized.  I’m new to scala, so maybe there’s a better way to do this.  If there is, please educate me.
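Option (B) might look something like the sketch below: run the function on a list far deeper than a JVM stack could hold frames for, and treat an escaping StackOverflowError as failure. It uses plain assertions rather than a test framework. The depth of 1,000,000 is an assumption (well beyond typical default stack sizes, but stacks vary with JVM settings), and DeepListCheck, lengthTail, lengthHead, and survivesDeepList are hypothetical names. The head-recursive control case shows the test can actually fail, and it also shows the weakness mentioned above: with too small a depth, the test would pass even for an unoptimized function.

```scala
object DeepListCheck {
  // Candidate that should be tail-call optimized. No @tailrec on purpose:
  // the point is to check the optimization without the annotation.
  def lengthTail[A](list: List[A], len: Int): Int =
    if (list.isEmpty) len else lengthTail(list.tail, len + 1)

  // Head-recursive control case: the addition happens after the call.
  def lengthHead[A](list: List[A]): Int =
    if (list.isEmpty) 0 else 1 + lengthHead(list.tail)

  // True if `f` completes on `input` without a StackOverflowError.
  def survivesDeepList[A](input: List[A])(f: List[A] => Int): Boolean =
    try { f(input); true }
    catch { case _: StackOverflowError => false }

  // Assumed depth: far more frames than a default JVM stack holds.
  val deepList: List[Int] = List.fill(1000000)(0)
}
```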

What I would like to see is an annotation that marks a function as requiring tail recursion optimization.  Such an annotation would trigger a compiler error if the compiler were unable to make the optimization.  The annotation I propose would work like this:

@TailRecursive
def listLength1(list: List[_]): Int = {
  if (list == Nil) 0
  else 1 + listLength1(list.tail)
}

def listLength2(list: List[_]): Int = {
  @TailRecursive
  def listLength2Helper(list: List[_], len: Int): Int = {
    if (list == Nil) len
    else listLength2Helper(list.tail, len + 1)
  }
  listLength2Helper(list, 0)
}

The compiler would succeed for listLength2Helper, but would report an error for listLength1, where it cannot apply the requested tail recursion optimization.  This doesn’t let the developer off the hook on unit testing, nor does it relieve the developer from having to code carefully.  What it does is provide early, infallible verification that a crucial feature was implemented.  Why force the developer to examine the classfile or write a bug-vulnerable unit test when the compiler knows the answer and could just cough up the information at compile time?
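An annotation very much like the proposed one was later added to the language: since Scala 2.8, scala.annotation.tailrec makes compilation fail if the compiler cannot optimize the annotated method’s self-calls. A minimal sketch (TailrecExample is a hypothetical wrapper name; the functions mirror the examples above):

```scala
import scala.annotation.tailrec

object TailrecExample {
  def listLength2[A](list: List[A]): Int = {
    // Compiles: the self-call is in tail position, so it gets optimized.
    @tailrec
    def listLength2Helper(list: List[A], len: Int): Int =
      if (list == Nil) len
      else listLength2Helper(list.tail, len + 1)
    listLength2Helper(list, 0)
  }

  // Putting @tailrec here would be rejected at compile time, because the
  // recursive call is followed by the addition, so it is left unannotated.
  def listLength1[A](list: List[A]): Int =
    if (list == Nil) 0
    else 1 + listLength1(list.tail)
}
```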