Designing recursive functions around Python's stack limits
Some functions can be defined clearly and succinctly using a recursive formula. There are two common examples of this.
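The factorial function has the following recursive definition:
$$n! = \begin{cases} 1 & \text{if } n = 0 \\ n \times (n-1)! & \text{if } n > 0 \end{cases}$$
The recursive rule for computing a Fibonacci number, $F_n$, has the following definition:
$$F_n = \begin{cases} 1 & \text{if } n = 0 \text{ or } n = 1 \\ F_{n-1} + F_{n-2} & \text{if } n > 1 \end{cases}$$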
Each of these involves a case that has a simple defined value and a case that involves computing the function's value, based on other values of the same function.
The problem is that Python limits the depth of the call stack, which caps how far these recursive definitions can go. While Python's integers can easily represent 1000!, the recursion limit (commonly 1,000 frames by default) prevents us from computing it with a casual recursive function.
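The limit can be observed directly. The following is a minimal sketch; the names fact_recursive() and exceeds_stack() are ours, chosen for illustration, and the actual limit may differ by platform or configuration:

```python
import sys


def fact_recursive(n: int) -> int:
    """Recursive factorial, following the mathematical definition."""
    if n == 0:
        return 1
    return n * fact_recursive(n - 1)


def exceeds_stack(n: int) -> bool:
    """Return True if computing fact_recursive(n) raises RecursionError."""
    try:
        fact_recursive(n)
    except RecursionError:
        return True
    return False


# The interpreter reports its current limit; 1000 is a common default,
# but this is an assumption about the environment, not a guarantee.
limit = sys.getrecursionlimit()
```

Small arguments work as expected, while an argument larger than the value of sys.getrecursionlimit() raises RecursionError instead of returning a result.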
Computing the Fibonacci number $F_n$ involves an additional problem. If we're not careful, we'll compute a lot of values more than once:
$$F_5 = F_4 + F_3$$
$$F_5 = (F_3 + F_2) + (F_2 + F_1)$$
And so on.
To compute F5, we'll compute F3 twice, and F2 three times. This can become extremely costly as computing one Fibonacci number involves also computing a cascading torrent of other numbers.
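We can make the repeated work visible by counting how often each value is requested. This is a small sketch; count_calls() is a hypothetical helper, and it uses the same base cases as the definition above, where the values for 0 and 1 are both 1:

```python
from collections import Counter


def count_calls(n: int) -> Counter:
    """Count how many times each Fibonacci number is computed naively."""
    calls = Counter()

    def fibo(k: int) -> int:
        calls[k] += 1  # record every request for F_k
        if k <= 1:
            return 1
        return fibo(k - 1) + fibo(k - 2)

    fibo(n)
    return calls


calls_for_5 = count_calls(5)
```

For n = 5, the counter shows two computations of the value for 3 and three computations of the value for 2, matching the expansion above. The total number of calls grows rapidly as n increases.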
Pragmatically, the filesystem is an example of a recursive data structure. Each directory contains subdirectories. The essential design for a simple numeric recursion also applies to the analysis of the directory tree. Similarly, a document serialized in JSON notation is a recursive collection of objects; often, a dictionary of dictionaries. Understanding such simple cases for recursion makes it easier to work with more complex recursive data structures.
In all of these cases, we're seeking to eliminate the recursion and replace it with iteration. In addition to recursion elimination, we'd like to preserve as much of the original mathematical clarity as we can.
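The same idea carries over to recursive data. As an illustration only, here is one way to walk a JSON-like structure with an explicit list acting as the stack, so that the depth of the document is not limited by the call stack. The function name count_leaves() and the choice to count leaf values are assumptions made for this sketch:

```python
from typing import Any


def count_leaves(document: Any) -> int:
    """Count non-container values in a JSON-like structure without recursion."""
    pending = [document]  # an explicit stack replaces the call stack
    leaves = 0
    while pending:
        node = pending.pop()
        if isinstance(node, dict):
            pending.extend(node.values())  # visit each value later
        elif isinstance(node, list):
            pending.extend(node)
        else:
            leaves += 1  # anything else is treated as a leaf
    return leaves


sample = {"a": 1, "b": {"c": [2, 3], "d": {"e": None}}}
leaf_count = count_leaves(sample)
```

Because the pending work lives in an ordinary list, a document nested thousands of levels deep is processed the same way as a shallow one.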
Getting ready
Many recursive function definitions follow the pattern set by the factorial function. This is sometimes called tail recursion because the recursive case can be written at the tail of the function body:
def fact(n: int) -> int:
    if n == 0:
        return 1
    return n*fact(n-1)
The last expression in the function refers to the function with a different argument value.
We can restate this, avoiding the recursion limits in Python.
How to do it...
A tail recursion can also be described as a reduction. We're going to start with a collection of values, and then reduce them to a single value:
- Expand the rule to show all of the details: $n! = n \times (n-1) \times (n-2) \times \dots \times 1$. This helps ensure we understand the recursive rule.
- Write a loop or generator to create all the values: $N = \{n, n-1, n-2, \dots, 1\}$. In Python, this can be as simple as range(1, n+1). In some cases, though, we might have to apply some transformation function to the base values: $N = \{f(i) : 1 \le i < n+1\}$. Applying a transformation often looks like this in Python:
N = (f(i) for i in range(1, n+1))
- Incorporate the reduction function. In this case, we're computing a large product, using multiplication. We can summarize this using $\prod$ notation. For this example, we're computing a product of values in a range: $\prod_{1 \le x < n+1} x$
Here's the implementation in Python:
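from typing import Iterable

def prod(int_iter: Iterable[int]) -> int:
    # Start from 1, the multiplicative identity, so an empty iterable gives 1.
    p = 1
    for x in int_iter:
        p *= x
    return p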
We can refactor the fact() function to use the prod() function like this:
def fact(n: int) -> int:
    return prod(range(1, n+1))
This works nicely. We've replaced the recursive solution with an iterative one by combining the prod() and fact() functions. This revision avoids the potential stack overflow problems the recursive version suffers from.
Note that the Python 3 range object is lazy: it doesn't create a big list object. The range object returns values as they are requested by the prod() function. This makes the overall computation relatively efficient.
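The reduction can also be written with tools from the standard library instead of a hand-written loop. This is an alternative sketch, not the recipe's own code: functools.reduce() with operator.mul expresses the product directly, and math.prod() (available from Python 3.8) does the same job. The function names fact_reduce() and fact_math_prod() are ours:

```python
import math
from functools import reduce
from operator import mul


def fact_reduce(n: int) -> int:
    """Factorial as a reduction; the initial value 1 covers the n == 0 case."""
    return reduce(mul, range(1, n + 1), 1)


def fact_math_prod(n: int) -> int:
    """Factorial using math.prod(), which returns 1 for an empty range."""
    return math.prod(range(1, n + 1))
```

Both versions consume the lazy range object one value at a time, and neither is affected by the recursion limit, so arguments such as 2000 are handled without difficulty.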
How it works...
A tail recursion definition is handy because it's short and easy to remember. Mathematicians like this because it can help clarify what a function means.
Compilers for many static, compiled languages optimize tail recursion in a manner similar to the technique we've shown here. There are two parts to this optimization:
- Use relatively simple algebraic rules to reorder the statements so that the recursive clause is actually last. The if clauses can be reorganized into a different physical order so that return fact(n-1) * n is last. This rearrangement is necessary for code organized like this:
def ugly_fact(n: int) -> int:
    if n > 0:
        return fact(n-1) * n
    elif n == 0:
        return 1
    else:
        raise ValueError(f"Unexpected {n=}")
- Inject a special instruction into the virtual machine's byte code—or the actual machine code—that re-evaluates the function without creating a new stack frame. Python doesn't have this feature. In effect, this special instruction transforms the recursion into a kind of while statement:
p = n
while n != 1:
    n = n-1
    p *= n
This purely mechanical transformation leads to rather ugly code, and we prefer not to do it. More importantly, in Python it tends to create code that's slower than the alternative we developed here. In other languages, the presence of the special instruction leads to code that runs quickly.
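The claim about speed is best checked on your own machine, since timings depend on the Python version and hardware. The following sketch wraps the mechanical loop in a function and compares it with the prod()-based version using timeit. The names fact_prod(), fact_while(), and compare() are hypothetical, and the guard for n == 0 is our addition because the bare loop doesn't handle that case:

```python
import timeit


def prod(values):
    """Multiply all values together, starting from 1."""
    p = 1
    for x in values:
        p *= x
    return p


def fact_prod(n: int) -> int:
    return prod(range(1, n + 1))


def fact_while(n: int) -> int:
    """Mechanically transformed version of the recursion."""
    if n == 0:  # added guard: the bare loop would never terminate for 0
        return 1
    p = n
    while n != 1:
        n = n - 1
        p *= n
    return p


def compare(n: int = 500, number: int = 200) -> dict:
    """Return elapsed seconds for each version; results vary by environment."""
    return {
        "prod": timeit.timeit(lambda: fact_prod(n), number=number),
        "while": timeit.timeit(lambda: fact_while(n), number=number),
    }
```

Calling compare() returns a dictionary of elapsed times for the two versions. We deliberately show no numbers here, because they vary from one environment to another.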
There's more...
The Fibonacci problem involves two recursions. If we write it naively as a recursion, it might look like this:
def fibo(n: int) -> int:
    if n <= 1:
        return 1
    else:
        return fibo(n-1)+fibo(n-2)
It's difficult to apply a simple mechanical transformation that turns this into a tail recursion. A problem with multiple recursions like this requires more careful design.
We have two ways to reduce the computational complexity of this:
- Use memoization
- Restate the problem
The memoization technique is easy to apply in Python. We can use functools.lru_cache() as a decorator. This function will cache previously computed values. This means that we'll only compute a value once; every other time, lru_cache will return the previously computed value.
It looks like this:
from functools import lru_cache

@lru_cache(128)
def fibo_r(n: int) -> int:
    if n < 2:
        return 1
    else:
        return fibo_r(n - 1) + fibo_r(n - 2)
Adding a decorator is a simple way to optimize a more complex multi-way recursion.
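We can confirm that each value is computed only once by inspecting the cache statistics. This sketch repeats the memoized function so that it is self-contained, and uses the cache_info() and cache_clear() methods that lru_cache adds to the decorated function:

```python
from functools import lru_cache


@lru_cache(maxsize=128)
def fibo_r(n: int) -> int:
    if n < 2:
        return 1
    return fibo_r(n - 1) + fibo_r(n - 2)


fibo_r.cache_clear()  # start from an empty cache
value = fibo_r(30)
info = fibo_r.cache_info()  # named tuple with hits, misses, maxsize, currsize
```

With an empty cache, computing the value for 30 records one miss for each argument from 0 to 30, and every other request is served from the cache. Note that memoization reduces the number of calls but not their depth: a first call with a very large argument can still reach the recursion limit.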
Restating the problem means looking at it from a new perspective. In this case, we can think of computing all Fibonacci numbers up to, and including, Fn. We only want the last value in this sequence. We compute all the intermediates because it's more efficient to do it that way. Here's a generator function that does this:
from typing import Iterator

def fibo_iter() -> Iterator[int]:
    a = 1
    b = 1
    yield a
    while True:
        yield b
        # Advance the pair: the new b is the sum of the previous two values.
        a, b = b, a + b
This function is an infinite iteration of Fibonacci numbers. It uses Python's yield so that it emits values in a lazy fashion. When a client function uses this iterator, the next number in the sequence is computed as each number is consumed.
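Because the iterator never ends, a client has to decide how many values to take. One option, shown here as a sketch, is itertools.islice(), which stops after a fixed number of items. The generator is repeated so the example is self-contained:

```python
from itertools import islice
from typing import Iterator


def fibo_iter() -> Iterator[int]:
    a = 1
    b = 1
    yield a
    while True:
        yield b
        a, b = b, a + b


# Take only the first eight values from the infinite sequence.
first_eight = list(islice(fibo_iter(), 8))
# first_eight is [1, 1, 2, 3, 5, 8, 13, 21]
```

Only the eight requested values are ever computed; the generator is suspended after the last one is consumed.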
Here's a function that consumes the values and also imposes an upper limit on the otherwise infinite iterator:
def fibo_i(n: int) -> int:
    for i, f_i in enumerate(fibo_iter()):
        if i == n:
            break
    return f_i
This function consumes each value from the fibo_iter() iterator. When the desired number has been reached, the break statement ends the for statement.
When we looked at the Avoiding a potential problem with break statements recipe in Chapter 2, Statements and Syntax, we noted that a while statement with a break may have multiple reasons for terminating. In this example, there is only one way to end the for statement.
We can always assert that i == n at the end of the loop. This simplifies the design of the function. We've also optimized the recursive solution and turned it into an iteration that avoids the potential for stack overflow.
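The post-condition can be written into the function as an explicit assert statement. The following sketch does that and also rejects negative arguments, which would otherwise cause the loop to run forever. The ValueError check is our addition, not part of the recipe:

```python
from typing import Iterator


def fibo_iter() -> Iterator[int]:
    a = 1
    b = 1
    yield a
    while True:
        yield b
        a, b = b, a + b


def fibo_i(n: int) -> int:
    if n < 0:
        # Added guard: i starts at 0, so a negative n would never match.
        raise ValueError(f"Unexpected {n=}")
    for i, f_i in enumerate(fibo_iter()):
        if i == n:
            break
    assert i == n  # the only way out of the loop is the break
    return f_i
```

Since no function calls are nested, an argument such as 5000, which is well beyond the usual recursion limit, is handled by the same simple loop.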
See also
- See the Avoiding a potential problem with break statements recipe in Chapter 2, Statements and Syntax.