Recursion 101

We all know what recursion is, right?  A function calls itself.  Or function A calls function B which calls function A.  Or A calls B, which calls C, which calls A, etc.  By far the most common situation is that in which a function calls itself.

You don’t see a lot of recursive code in Java.  There are a couple of reasons for this.  First, recursion is hard.  It’s not intuitive.  With iteration (the main alternative to recursion) you see the big picture.  You can look at the whole loop, and it’s easy to understand even for a beginner.  With recursion, you see one layer and you have to imagine what happens when those layers stack up.  Like anything else, recursion becomes easier with practice, but compared to iteration it is hard to learn.

Second, Java is not designed to accommodate recursion.  It’s designed to accommodate iteration.  Java gives you for loops, for-each loops, while loops, do-while loops, arrays, Iterators, ResultSets, etc.  These constructs are all about iteration.  Plus, recursion in Java has an Achilles’ heel: the call stack.

Generally, when you call a function in any language, a new frame is added to the call stack.  The call stack, as we all know, is what keeps track of local variables, which function has called which, etc.  It’s no different in Java.  But the stack has a limited size.  Recursion is fine if you know for certain that you’ll never go more than a few dozen levels deep.  But if the recursion goes too deep, you run out of stack space and your program goes kaput with a StackOverflowError.  That doesn’t happen with iteration, so it’s safer to just avoid recursion altogether.

Recursion In Scala

Scala, however, being a functional language, is very much geared toward recursion rather than iteration.  So how does it overcome the limitations of the call stack?  Let’s look at an example recursive function in Scala.

def listLength1(list: List[_]): Int = {
  if (list == Nil) 0
  else 1 + listLength1(list.tail)
}

val list1 = List(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
var list2 = List(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
// Double list2 fifteen times: 10 * 2^15 = 327,680 elements
(1 to 15).foreach( _ => list2 = list2 ++ list2 )

println( listLength1( list1 ) )
println( listLength1( list2 ) )

Function listLength1 recursively counts the number of items in a list.  Try running this in the interpreter.  It works fine for list1, the short list, but with a default-sized stack the longer list exhausts it and the call dies with a StackOverflowError.  Recursion is a functional language’s bread and butter, but we see here that even Scala, a functional language, is subject to call stack limitations.

Don’t give up on recursion yet, though.  Scala has a very important optimization that will allow you to recurse without limit provided you use the right kind of recursion.

Head Recursion And Tail Recursion

There are two basic kinds of recursion: head recursion and tail recursion.  In head recursion, a function makes its recursive call and then performs some more calculations, maybe using the result of the recursive call, for example.  In a tail recursive function, all calculations happen first and the recursive call is the last thing that happens.

The significance of this distinction doesn’t jump out at you, but it’s extremely important!  Imagine a tail recursive function.  It runs.  It completes all its computation.  As its very last action, it is ready to make its recursive call.  What, at this point, is the use of the stack frame?  None at all.  We don’t need our local variables anymore because we’re done with all computations.  We don’t need to know which function we’re in because we’re just going to re-enter the very same function.  Scala, in the case of tail recursion, can eliminate the creation of a new stack frame and just re-use the current stack frame.  The stack never gets any deeper, no matter how many times the recursive call is made.  That’s the voodoo that makes tail recursion special in Scala.

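Incidentally, the usual way to achieve this is to convert the tail call into iteration rather than to manipulate the stack at run time.  That is what the Scala compiler does: the JVM has no tail-call instruction, so a self-recursive tail call is compiled into a jump back to the top of the function, which amounts to a loop.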

This won’t work with head recursion.  Do you see why?  Imagine a head recursive function.  First it does some work, then it makes its recursive call, then it does a little more work.  We can’t just re-use the current stack frame when we make that recursive call.  We’re going to NEED that stack frame info after the recursive call completes.  It has our local variables, including the result (if any) returned by the recursive call.

Here’s a question for you.  Is the example function listLength1 head recursive or tail recursive?  Well, what does it do?  (A) It checks whether its parameter is Nil.  (B) If so, it returns 0 since Nil has 0 length.  (C) If not, it returns 1 plus the result of a recursive call.  The recursive call is the last thing we typed before ending the function.  That’s tail recursion, right?  Wrong.  The recursive call is made, and THEN 1 is added to the result, and this sum is returned.  This is actually head recursion (or middle recursion, if you like) because the recursive call is not the very last thing that happens.
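
To make the pending addition concrete, here is a hand expansion of the call on a three-element list.  It’s only a sketch; the function is repeated from above so that the snippet runs on its own, and the name traced is just for illustration.

```scala
// listLength1 repeated from above so that this snippet runs on its own.
def listLength1(list: List[_]): Int = {
  if (list == Nil) 0
  else 1 + listLength1(list.tail)
}

// Hand expansion of listLength1(List(7, 8, 9)).  Each "1 +" is work that is
// still pending in a stack frame while the inner call runs:
//
//   listLength1(List(7, 8, 9))
//   1 + listLength1(List(8, 9))
//   1 + (1 + listLength1(List(9)))
//   1 + (1 + (1 + listLength1(Nil)))
//   1 + (1 + (1 + 0))
//
// Only once Nil is reached can the additions run, innermost first.
val traced = listLength1(List(7, 8, 9))
println(traced)  // prints 3
```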

Tail Recursion Example

When you write a recursive function in Scala, your aim is to encourage the compiler to make tail recursion optimizations.  Now let’s rewrite that function using tail recursion.

def listLength2(list: List[_]): Int = {
  def listLength2Helper(list: List[_], len: Int): Int = {
    if (list == Nil) len
    else listLength2Helper(list.tail, len + 1)
  }
  listLength2Helper(list, 0)
}

println( listLength2( list1 ) )
println( listLength2( list2 ) )

I wrote this as two functions (listLength2 and an internal helper function) to preserve the one-parameter interface used in the earlier example.  Since Scala 2.8 you can specify a default value for a parameter, so this could also be written as one function, though that would expose the len accumulator to callers.  Long story short: listLength2 just calls listLength2Helper, which does the real work and is the recursive function.
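
A minimal sketch of such a one-function version, assuming a compiler that supports default parameter values (Scala 2.8 or later); the name listLength3 is made up for this example:

```scala
// Assumes Scala 2.8 or later, where parameters can have default values.
// listLength3 is a hypothetical name, not one of the examples above.
def listLength3(list: List[_], len: Int = 0): Int = {
  if (list == Nil) len
  else listLength3(list.tail, len + 1)  // the recursive call is still last
}

println( listLength3( List(1, 2, 3) ) )  // prints 3
```

Nothing stops a caller from writing listLength3(List("a"), 10), which returns 11 and is rarely what anyone wants.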

Is listLength2Helper head recursive or tail recursive?  When the recursive call is made, are we really finished with the stack frame, allowing Scala to optimize for tail recursion?  Just like in listLength1, this function first checks for a Nil list, but this time returns the len parameter rather than 0.  If list is non-Nil then we make the recursive call.  But there’s still an addition going on here, len + 1.  Does that ruin the tail recursion?  No.  That argument is evaluated first.  Only after all the arguments have been evaluated does the recursive call happen.  This function does indeed qualify for tail recursion optimization.
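
To see why the frame can be re-used, it may help to write the same logic as a loop by hand.  The sketch below is an illustration of the idea, not actual compiler output, and listLengthLoop is a hypothetical name: each pass overwrites the two variables with the values the recursive call would have received.

```scala
// Hand-written loop with the same behavior as listLength2Helper.
// Illustration only: this is not the code the compiler generates.
def listLengthLoop(list: List[_]): Int = {
  var rest: List[_] = list  // plays the role of the list parameter
  var len = 0               // plays the role of the len parameter
  while (rest != Nil) {
    // Same updates as the call listLength2Helper(list.tail, len + 1)
    rest = rest.tail
    len = len + 1
  }
  len  // same result as the "if (list == Nil) len" branch
}

println( listLengthLoop( List(1, 2, 3) ) )  // prints 3
```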

By The Way…

The two examples in this post are very elementary and it’s easy to predict visually that tail recursion optimization will be applied.  But it’s not always so simple.  If you’re trying to write a complex recursive function and you require that it be optimized as tail recursive then you have a problem.  How do you verify that you have succeeded?  The only ways that I know of are (A) examine the resulting object code or (B) write a unit test to verify that the optimization was made.  The problem with a unit test is that you open yourself up to bugs in the test itself.  You have to be able to write a test that you know will fail if the function is not optimized.  I’m new to Scala, so maybe there’s a better way to do this.  If there is, please educate me.
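
As a rough idea of option (B), the sketch below feeds the function a list far too long for ordinary recursion and treats a StackOverflowError as a failure.  The one-million-element size and the name survivesDeepRecursion are arbitrary choices for illustration, and the check is only as trustworthy as the assumption that an unoptimized version would overflow at that depth.

```scala
// listLength2 repeated from above so that this snippet runs on its own.
def listLength2(list: List[_]): Int = {
  def listLength2Helper(list: List[_], len: Int): Int = {
    if (list == Nil) len
    else listLength2Helper(list.tail, len + 1)
  }
  listLength2Helper(list, 0)
}

// Hypothetical check: returns true if f can process a list of the given size
// without exhausting the stack.  The size must be large enough that plain
// recursion would overflow, which depends on the JVM's stack settings.
def survivesDeepRecursion(f: List[_] => Int, size: Int): Boolean = {
  val bigList = List.fill(size)(0)
  try {
    f(bigList) == size
  } catch {
    case _: StackOverflowError => false
  }
}

println( survivesDeepRecursion(listLength2, 1000000) )
```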

What I would like to see is an annotation that marks a function as requiring tail recursion optimization.  Such an annotation would trigger a compiler error if the compiler were unable to make the optimization.  The annotation I propose would work like this:

@TailRecursive
def listLength1(list: List[_]): Int = {
  if (list == Nil) 0
  else 1 + listLength1(list.tail)
}

def listLength2(list: List[_]): Int = {
  @TailRecursive
  def listLength2Helper(list: List[_], len: Int): Int = {
    if (list == Nil) len
    else listLength2Helper(list.tail, len + 1)
  }
  listLength2Helper(list, 0)
}

The compiler would succeed for listLength2Helper, but would report an error for listLength1, where it could not apply the requested tail recursion optimization.  This doesn’t let the developer off the hook on unit testing, nor does it relieve the developer from having to code carefully.  What it does is provide early, infallible verification that a crucial feature was implemented.  Why force the developer to examine the classfile or write a bug-vulnerable unit test when the compiler knows the answer and could just cough up the information at compile time?
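
As it happens, Scala 2.8 and later include an annotation very much like this: scala.annotation.tailrec.  Applied to the helper, it compiles; applied to listLength1, the compiler rejects the function because the recursive call is not in tail position.  A minimal sketch, with the failing case left commented out so that the snippet runs:

```scala
import scala.annotation.tailrec

def listLength2(list: List[_]): Int = {
  // Compiles: the recursive call is the last thing the helper does.
  @tailrec
  def listLength2Helper(list: List[_], len: Int): Int = {
    if (list == Nil) len
    else listLength2Helper(list.tail, len + 1)
  }
  listLength2Helper(list, 0)
}

// Uncommenting the annotation below produces a compile error, because
// the "1 + ..." still has to run after the recursive call returns.
// @tailrec
def listLength1(list: List[_]): Int = {
  if (list == Nil) 0
  else 1 + listLength1(list.tail)
}

println( listLength2( List.fill(1000000)(0) ) )  // prints 1000000
```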