m (Changed some formatting.)
(A new section on scanl)
Revision as of 03:12, 11 October 2007
1 FoldsFirst, read TailRecursive. If you are not writing your code tail-recursively, then that is why you are getting stack overflows. However, as the bottom of that page suggests, making code tail-recursive in a lazy language is not quite the same as in a eager language. This page is more geared to the latter case using foldr/l as the prime culprit/example. As such WhatIsaFold may be helpful, but isn't too critical. Also knowing what
The definitions of the three folds we'll be looking at are as follows,
foldr f z  = z foldr f z (x:xs) = f x (foldr f z xs) foldl f z  = z foldl f z (x:xs) = foldl f (f z x) xs foldl' f z  = z foldl' f z (x:xs) = (foldl' f $! f z x) xs foldl' (found in e.g. Data.List) is just a stricter version of foldl.
The one-line summary for folds: if the binary operation is strict use foldl' otherwise use foldr.
Common newbie stack overflowing code:
mysum :: [Integer] -> Integer mysum = foldr (+) 0 main = print (mysum [1..1000000])
If you've read TailRecursive, you should immediately see the problem from the definition of foldr above. Quite simply, foldr isn't tail-recursive! But,
concat xss = foldr (++) [] xss
This is from the Haskell Report. Surely they know what they are doing! And sure enough,
main = print (length (concat [[x] | x <- [1..1000000]]))
Common less newbie stack overflowing code:
mysum :: [Integer] -> Integer mysum = foldl (+) 0 main = print (mysum [1..1000000])
So what's going on here. Looking at the code for foldl, it looks tail-recursive. Well, much like you can see the problem with a non-tail-recursive factorial by unfolding a few iterations, let's do the same for our foldl definition of sum, but making sure to use a call-by-name/need evaluation order. Here is the unfolding,
mysum [1..10] -> foldl (+) 0 (1:[2..10]) -> foldl (+) (0+1) (2:[3..10]) -> foldl (+) (0+1+2) (3:[4..10]) -> foldl (+) (0+1+2+3) (4:[5..10]) -> ...
I think you get the idea. The problem is that we are building up a chain of thunks that will evaluate the sum instead of just maintaining a running sum. What we need to do is to force the addition before recursing. This is exactly what foldl' does.
Just to check,
mysum :: [Integer] -> Integer mysum = foldl' (+) 0 main = print (mysum [1..1000000])
Now let's go back to the foldr sum and concat. What's the difference between sum and concat that makes the sum definition wrong, but the concat definition right. Again, let's evaluate each by hand.
mysum (+) 0 [1..10] -> foldr (+) 0 (1:[2..10]) -> 1+foldr (+) 0 (2:[3..10]) -> 1+(2+foldr (+) 0 (3:[4..10])) -> ...
Okay, no surprise there.
concat [,,,...] -> foldr (++) [] (:[,,...]) -> (1:)++foldr (++) [] [,,...] -> 1:(++foldr (++) [] [,,...])
Notice that there is no '-> ...' at the end. That was the complete evaluation. There is no reason to do anything more unless we look at the more of the result. We may well GC the 1 before we look at the tail, and GC the first cons cell before we look at the second. So, concat runs in a constant amount of stack and further can handle infinite lists (as a note, it's immediately obvious foldl(') can never work on infinite lists because we'll always be in the (:) case and that always immediately recurses). The differentiator between mysum and concat is that (++) is not strict* in its second argument; we don't have to evaluate the rest of the foldr to know the beginning of concat. In mysum, since (+) is strict in its second argument we need the results of the whole foldr before we can compute the final result.
So, we arrive at the one-line summary: A function strict in its second argument will always* require linear stack space with foldr, so foldl' should be used instead in that case. If the function is lazy/non-strict in its second argument we should use foldr to 1) support infinite lists and 2) to allow a streaming use of the input list where only part of it needs to be in memory at a time.Okay, both here and in the one-line summary, there is no mention of foldl. When should foldl be used? The pragmatic answer is: by and far it shouldn't be used. A case where it makes a difference is if the function is conditionally strict in its first argument depending on its second, where I use conditionally strict to mean a function that is strict or not in one argument depending on another argument(s). For an example, consider a definition of
- A strict function is a function, f such that f ⊥ = ⊥. Typically, we think of a function "being strict" in an argument as a function that "forces" its argument, but the above definition of strict should immediately suggest another function that is strict and doesn't "force" it's argument in the intuitive sense, namely id. (++) = id and therefore is a strict function. Sure enough, if you were to evaluate (concat (repeat )) it would not terminate. As such (++) is a conditionally strict function. This also makes the "always" slightly imprecise, a function that is strict because it just returns it's argument will not use up stack space (but is, as I mentioned, still an issue for infinitely long lists).
A subtle stack-overflow surprise comes when
print (scanl (+) 0 [1..1000000])
completes successfully but
print (last (scanl (+) 0 [1..1000000]))
causes a stack overflow.
The latter stack overflow is explained exactly as before, namely,
last (scanl (+) 0 [1..5]) -> ... several steps ... -> ((((0+1)+2)+3)+4)+5
Most puzzling is why the former succeeds without a stack overflow. This is caused by a combination of two factors:
- thunks in the list produced by enjoy sharing: late thunks build upon early thunksscanl
- printing a list of numbers evaluates early thunks and then late thunks
To exemplify, here is an abridged progression. I use this pseudo format to depict sharing of thunks
expr where var=expr, var=expr
although in reality it is more like a pointer graph.
print (scanl (+) 0 [1..1000000]) -> print (a : case [1..1000000] of ... x:xs -> scanl (+) (a+x) xs) where a=0 -> ... evaluate a to 0 for printing, I/O, some more steps ... -> print (scanl (+) (a+1) [2..1000000]) where a=0 -> print (b : case [2..1000000] of ... x:xs -> scanl (+) (b+x) xs) where a=0, b=a+1 -> ... evaluate b to 1 for printing, I/O, some more steps ... -> print (scanl (+) (b+2) [3..1000000]) where b=1 -> print (c : case [3..1000000] of ... x:xs -> scanl (+) (c+x) xs) where b=1, c=b+2 -> ... evaluate c to 3 for printing, I/O, some more steps ... -> print (scanl (+) (c+3) [4..1000000]) where c=3 -> print (d : case [4..1000000] of ... x:xs -> scanl (+) (d+x) xs) where c=3, d=c+3 -> ... evaluate d to 6 for printing, I/O, some more steps ... -> print (scanl (+) (d+4) [5..1000000]) where d=6 -> etc.