Win bigger statistical fights with a better jackknife

(Summary: I’ve developed some algorithms for a statistical technique called the jackknife that run in O(n) time instead of O(n²).)

In statistics, an estimation technique called “the jackknife” has been widely used for over half a century. It’s a mainstay for taking a quick look at the quality of an estimator of a sample. (An estimator is a summary function over a sample, such as its mean or variance.)

Suppose we have a noisy sample. Our first stopping point might be to look at the variance of the sample, to get a sense of how much the values in the sample “spread out” around the average.

If the variance is not close to zero, then we know that the sample is somewhat noisy. But our curiosity may persist: is the variance unduly influenced by a few big spikes, or is the sample consistently noisy? The jackknife is a simple analytic tool that lets us quickly answer questions like this. There are more accurate, sophisticated approaches to this kind of problem, but they’re not nearly so easy to understand and use, so the jackknife has stayed popular since the 1950s.

The jackknife is easy to describe. We take the original sample, drop the first value out, and calculate the variance (or whatever the estimator is) over this subsample. We repeat this, dropping out only the second value, and continue. For an original sample with n elements, we end up with a collection of n jackknifed estimates, one for each subsample with a single element left out. Once we’re done, there’s an optional last step: we compute the mean of these jackknifed estimates, which gives us the jackknifed estimate (in this case, the jackknifed variance).

For example, suppose we have the sample [1,3,2,1]. (I’m going to write all my examples in Haskell for brevity, but the code in this post should be easy to port to any statistical language.)

The simplest way to compute variance is as follows:

var xs = (sum (map (^2) xs) - sum xs ^ 2 / n) / n
  where n = fromIntegral (length xs)

Using this method, the variance of [1,3,2,1] is 0.6875.

To jackknife the variance:

var [1,3,2,1]  == 0.6875

-- leave out each element in succession
-- (I'm using ".." to denote repeating expansions)
var [  3,2,1]  == 0.6666..
var [1,  2,1]  == 0.2222..
var [1,3,  1]  == 0.8888..
var [1,3,2  ]  == 0.6666..

-- compute the mean of the estimates over the subsamples
mean [0.6666,0.2222,0.8888,0.6666]
               == 0.6111..

Since 0.6111 is quite different from 0.6875, we can see that the variance of this sample is affected rather a lot by bias.

While the jackknife is simple, it’s also slow. We can easily see that the approach outlined above takes O(n²) time, which means that we can’t jackknife samples above a modest size in a reasonable amount of time.
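For concreteness, here’s a minimal sketch of that naive approach (jackknifeNaive is my name for it, not a standard function); each subsample is rebuilt and re-summarized from scratch, which is where the quadratic cost comes from.

-- Apply an estimator to every leave-one-out subsample.
jackknifeNaive :: ([a] -> b) -> [a] -> [b]
jackknifeNaive est xs =
    [ est (take i xs ++ drop (i+1) xs) | i <- [0 .. length xs - 1] ]

For example, jackknifeNaive var [1,3,2,1] reproduces the four subsample variances above.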

This approach to the jackknife is the one everybody actually uses. Nevertheless, it’s possible to improve the time complexity of the jackknife for some important estimators from O(n²) to O(n). Here’s how.

Jackknifing the mean

Let’s start with the simple case of the mean. Here’s the obvious way to measure the mean of a sample.

mean xs = sum xs / n
  where n = fromIntegral (length xs)

And here are the computations we need to perform during the naive approach to jackknifing the mean.

-- n = fromIntegral (length xs - 1)
sum [  3,2,1] / n
sum [1,  2,1] / n
sum [1,3,  1] / n
sum [1,3,2  ] / n

Let’s decompose the sum operations into two triangles as follows, and see what jumps out:

sum [  3,2,1] = sum [] + sum [3,2,1]
sum [1,  2,1] = sum [1]  + sum [2,1]
sum [1,3,  1] = sum [1,3]  + sum [1]
sum [1,3,2  ] = sum [1,3,2] + sum []

From this perspective, we’re doing a lot of redundant work. For example, to calculate sum [1,3,2], it would be very helpful if we could reuse the work we did in the previous calculation to calculate sum [1,3].

Prefix sums

We can achieve our desired reuse of earlier work if we store each intermediate sum in a separate list. This technique is called prefix summation, or (if you’re a Haskeller) scanning.

Here’s the bottom left triangle of sums we want to calculate.

sum [] {- + sum [3,2,1] -}
sum [1]  {- + sum [2,1] -}
sum [1,3]  {- + sum [1] -}
sum [1,3,2] {- + sum [] -}

We can prefix-sum these using Haskell’s standard scanl function.

>>> init (scanl (+) 0 [1,3,2,1])
[0,1,4,6]

{- e.g. [0,
         0 + 1,
         0 + 1 + 3,
         0 + 1 + 3 + 2]   -}

(We use init to drop out the final term, which we don’t want.)

And here’s the top right of the triangle.

{- sum [] + -} sum [3,2,1]
{- sum [1] + -}  sum [2,1]
{- sum [1,3] + -}  sum [1]
{- sum [1,3,2] + -} sum []

To prefix-sum these, we can use scanr, which scans “from the right”.

>>> tail (scanr (+) 0 [1,3,2,1])
[6,3,1,0]

{- e.g. [3 + 2 + 1 + 0,
         2 + 1 + 0,
         1 + 0,
         0]               -}

(As in the previous case, we use tail to drop out the first term, which we don’t want.)

Now we have two lists:

[0,1,4,6]
[6,3,1,0]

Next, we sum the lists pairwise, which gives us exactly the sums we need:

sum [  3,2,1]  == 0 + 6 == 6
sum [1,  2,1]  == 1 + 3 == 4
sum [1,3,  1]  == 4 + 1 == 5
sum [1,3,2  ]  == 6 + 0 == 6

Divide each sum by n-1, and we have the four subsample means we were hoping for—but in linear time, not quadratic time!

Here’s the complete method for jackknifing the mean in O(n) time.

jackknifeMean :: Fractional a => [a] -> [a]
jackknifeMean xs =
    map (/ n) $
    zipWith (+)
    (init (scanl (+) 0 xs))
    (tail (scanr (+) 0 xs))
  where n = fromIntegral (length xs - 1)

If we’re jackknifing the mean, there’s no point in taking the extra step of computing the mean of the jackknifed subsamples to estimate the bias. Since the mean is an unbiased estimator, the mean of the jackknifed means should be the same as the sample mean, so the bias will always be zero.

However, the jackknifed subsamples do serve a useful purpose: each one tells us how much its corresponding left-out data point affects the sample mean. Let’s see what this means.

>>> mean [1,3,2,1]
1.75

The sample mean is 1.75, and let’s see which subsample mean is farthest from this value:

>>> jackknifeMean [1,3,2,1]
[2, 1.3333, 1.6666, 2]

So if we left out 1 from the sample, the mean would be 2, but if we left out 3, the mean would become 1.3333. Clearly, this is the subsample mean that is farthest from the sample mean, so 3 is the most significant outlier in our estimate of the mean.
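If we wanted to automate that eyeballing, a small helper could pair each data point with how far its leave-one-out mean strays from the sample mean. This is only a sketch (biggestInfluence is a hypothetical name, and it assumes the mean and jackknifeMean functions defined above):

import Data.List (maximumBy)
import Data.Ord  (comparing)

-- Pair each element with the shift its removal causes in the mean,
-- and return the element whose removal moves the mean the most.
biggestInfluence :: (Fractional a, Ord a) => [a] -> (a, a)
biggestInfluence xs = maximumBy (comparing snd) (zip xs shifts)
  where shifts = [ abs (mean xs - j) | j <- jackknifeMean xs ]

For our sample, this picks out 3, the value whose removal shifts the mean the most.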

Prefix sums and variance

Let’s look again at the naive formula for calculating variance:

var xs = (sum (map (^2) xs) - sum xs ^ 2 / n) / n
  where n = fromIntegral (length xs)

Since this approach is based on sums, it looks like maybe we can use the same prefix summation technique to compute the variance in O(n) time.

Because we’re computing both a sum of squares and an ordinary sum, we need to perform two pairs of prefix sum computations:

  • Two to compute the sum of squares, one from the left and another from the right

  • And two more to compute the plain sums, which we square inside the variance formula

import Data.List (zipWith4)

jackknifeVar :: Fractional a => [a] -> [a]
jackknifeVar xs =
    zipWith4 var squaresLeft squaresRight sumsLeft sumsRight
  where
    var l2 r2 l r = ((l2 + r2) - (l + r) ^ 2 / n) / n
    squares       = map (^2) xs
    squaresLeft   = init (scanl (+) 0 squares)
    squaresRight  = tail (scanr (+) 0 squares)
    sumsLeft      = init (scanl (+) 0 xs)
    sumsRight     = tail (scanr (+) 0 xs)
    n             = fromIntegral (length xs - 1)

If we look closely, buried in the local function var above, we will see almost exactly the naive formulation for variance, only constructed from the relevant pieces of our four prefix sums.
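As a quick sanity check, the prefix-sum version reproduces the subsample variances we computed by hand earlier (digits abbreviated here with the same “..” convention as before; GHCi will print full floating point precision):

>>> jackknifeVar [1,3,2,1]
[0.6666..,0.2222..,0.8888..,0.6666..]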

Skewness, kurtosis, and more

Exactly the same prefix sum approach applies to jackknifing higher order moment statistics, such as skewness (lopsidedness of the distribution curve) and kurtosis (shape of the tails of the distribution).
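As a rough illustration, here’s what a prefix-sum jackknife of skewness might look like, built from running sums of x, x², and x³. This is my own sketch (jackknifeSkew is not the statistics library’s implementation), and it inherits all the numerical caveats discussed below.

import Data.List (zipWith6)

-- For each subsample, reassemble sum x, sum x^2 and sum x^3 from the
-- left and right scans, then form the third central moment and
-- normalize it by variance to the power 1.5.
jackknifeSkew :: Floating a => [a] -> [a]
jackknifeSkew xs =
    zipWith6 skew cubesL cubesR sqL sqR sumsL sumsR
  where
    skew c3l c3r c2l c2r c1l c1r = m3 / (v * sqrt v)
      where s1 = c1l + c1r
            s2 = c2l + c2r
            s3 = c3l + c3r
            m  = s1 / n
            v  = s2 / n - m * m
            m3 = s3 / n - 3 * m * (s2 / n) + 2 * m * m * m
    left  ys = init (scanl (+) 0 ys)
    right ys = tail (scanr (+) 0 ys)
    sumsL    = left xs
    sumsR    = right xs
    sqL      = left (map (^2) xs)
    sqR      = right (map (^2) xs)
    cubesL   = left (map (^3) xs)
    cubesR   = right (map (^3) xs)
    n        = fromIntegral (length xs - 1)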

Numerical accuracy of the jackknifed mean

When we’re dealing with a lot of floating point numbers, the ever-present concerns about numerical stability and accuracy arise.

For example, suppose we compute the sum of ten million pseudo-random floating point numbers between zero and one.

The most accurate way to sum numbers is by first converting them to Rational, summing, then converting back to Double. We’ll call this the “true sum”. The standard Haskell sum function (“basic sum” below) simply adds numbers as it goes. It manages 14 decimal digits of accuracy before losing precision.
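In code, that recipe for the “true sum” looks something like this (trueSum is my name for it):

-- Sum exactly in rational arithmetic, rounding back to Double only at
-- the very end.
trueSum :: [Double] -> Double
trueSum = fromRational . sum . map toRational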

true sum:    5000754.656937315
basic sum:   5000754.65693705
                           ^

However, Kahan’s algorithm does even better.

true sum:    5000754.656937315
kahan sum:   5000754.656937315

If you haven’t come across Kahan’s algorithm before, it looks like this.

kahanStep (sum, c) x = (sum', c')
  where y    = x - c
        sum' = sum + y
        c'   = (sum' - sum) - y

The c term maintains a running correction of the errors introduced by each addition.
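To sum a whole list with it, we can fold kahanStep over the input and keep just the running total (kahanSum is my name for this small wrapper):

import Data.List (foldl')

-- Strictly fold kahanStep over the input, starting from a zero sum and
-- a zero correction term, then discard the correction.
kahanSum :: [Double] -> Double
kahanSum = fst . foldl' kahanStep (0, 0)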

Naive summation seems to do just fine, right? Well, watch what happens if we simply add 10¹⁰ to each number, sum these, then subtract 10¹⁷ at the end.

true sum:    4999628.983274754
basic sum:    450000.0
kahan sum:   4999632.0
                  ^

The naive approach goes completely off the rails, and produces a result that is off by an order of magnitude!

This catastrophic accumulation of error is often cited as the reason why the naive formula for the mean can’t be trusted.

mean xs = sum xs / n
  where n = fromIntegral (length xs)

Thanks in large part to Don Knuth popularizing it, what is usually suggested as a replacement is Welford’s algorithm.

import Data.List (foldl')

data WelfordMean a = M !a !Int
  deriving (Show)

welfordMean = end . foldl' step zero
  where end  (M m _)   = m
        step (M m n) x = M m' n'
          where m'     = m + (x - m) / fromIntegral n'
                n'     = n + 1
        zero           = M 0 0

Here’s what we get if we compare the three approaches:

true mean:    0.49996289832747537
naive mean:   0.04500007629394531
welford mean: 0.4998035430908203

Not surprisingly, the naive mean is worse than useless, but the long-respected Welford method only gives us three decimal digits of precision. That’s not so hot.

More accurate is the Kahan mean, which is simply the sum calculated using Kahan’s algorithm, then divided by the length:

true mean:    0.49996289832747537
kahan mean:   0.4999632
welford mean: 0.4998035430908203

This at least gets us to five decimal digits of precision.

So is the Kahan mean the answer? Well, Kahan summation has its own problems. Let’s try out a test vector.

-- originally due to Tim Peters
>>> let vec = concat (replicate 1000 [1,1e100,1,-1e100])

-- accurate sum
>>> sum (map toRational vec)
2000

-- naive sum
>>> sum vec
0.0

-- Kahan sum
>>> foldl kahanStep (0,0) vec
(0.0,0.0)

Ugh, the Kahan algorithm doesn’t do any better than naive addition. Fortunately, there’s an even better summation algorithm available, called the Kahan-Babuška-Neumaier algorithm.

kbnSum = uncurry (+) . foldl' step (0,0)
  where
    step (sum, c) x = (t, c')
      where c' | abs sum >= abs x = c + ((sum - t) + x)
               | otherwise        = c + ((x - t) + sum)
            t                     = sum + x

If we try this on the same test vector, we taste sweet success! Thank goodness!

>>> kbnSum vec
2000.0

Not only is Kahan-Babuška-Neumaier (let’s call it “KBN”) more accurate than Welford’s method, it has the advantage of being directly usable in our desired prefix-sum form. The floating point error we accumulate stays O(1), instead of growing as O(n) the way it does with naive summation.
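For instance, here’s a sketch of the jackknifed mean with KBN-compensated scans. It’s only an illustration of the idea (the names are mine), not the statistics library’s actual code:

-- jackknifeMean again, but every prefix sum carries a KBN correction
-- term that we fold back in before dividing.
jackknifeMeanKBN :: [Double] -> [Double]
jackknifeMeanKBN xs =
    map (/ n) $
    zipWith (+)
    (map done (init (scanl kbnStep (0,0) xs)))
    (map done (tail (scanr (flip kbnStep) (0,0) xs)))
  where
    done (s, c) = s + c
    kbnStep (sum, c) x = (t, c')
      where c' | abs sum >= abs x = c + ((sum - t) + x)
               | otherwise        = c + ((x - t) + sum)
            t                     = sum + x
    n = fromIntegral (length xs - 1)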

Poor old Welford’s formula for the mean just can’t get a break! Not only is it less accurate than KBN, but since it’s a recurrence relation with a divisor that keeps changing, we simply can’t monkeywrench it into suitability for the same prefix-sum purpose.

Numerical accuracy of the jackknifed variance

In our jackknifed variance, we used almost exactly the same calculation as the naive variance, merely adjusted to prefix sums. Here's the plain old naive variance function once again.

var xs = (sum (map (^2) xs) - sum xs ^ 2 / n) / n
  where n = fromIntegral (length xs)

The problem with this algorithm arises as the size of the input grows. These two terms are likely to converge for large n:

sum (map (^2) xs)

sum xs ^ 2 / n

When we subtract them, floating point cancellation leads to a large error term that turns our result into nonsense.

The usual way to deal with this is to switch to a two-pass algorithm. (In case it’s not clear at first glance, the first pass below calculates the mean.)

var2 xs    = (sum (map (^2) ys) - sum ys ^ 2 / n) / n
  where n  = fromIntegral (length xs)
        ys = map (subtract (mean xs)) xs

By subtracting the mean from every term, we keep the numbers smaller, so the two sum terms are less likely to converge.

This approach poses yet another conundrum: we want to jackknife the variance. If we have to correct for the mean to avoid cancellation errors, do we need to calculate each subsample mean? Well, no. We can get away with a cheat: instead of subtracting the subsample mean, we subtract the sample mean, on the assumption that it’s “close enough” to each of the subsample means to be a good enough substitute.
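Because the variance is unchanged when every value is shifted by a constant, this cheat amounts to centring the data once and then reusing the prefix-sum function we already have. A minimal sketch, assuming the jackknifeVar and mean functions from earlier (jackknifeVar2 is my own name for the wrapper):

-- Centre the sample on its mean, then jackknife the variance of the
-- centred data. The results are mathematically identical, but the
-- smaller intermediate sums suffer far less cancellation.
jackknifeVar2 :: Fractional a => [a] -> [a]
jackknifeVar2 xs = jackknifeVar (map (subtract (mean xs)) xs)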

So. To calculate the jackknifed variance, we use KBN summation to avoid a big cumulative error penalty during addition, subtract the sample mean to avoid cancellation error when subtracting the sum terms, and then we’ve finally got a pretty reliable floating point algorithm.
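Putting the pieces together, a sketch of that recipe might look like the code below. It is only an outline of the approach, with names of my own choosing, and it assumes the mean function from earlier; the statistics library’s real implementation differs in the details.

import Data.List (zipWith4)

-- Centre on the sample mean, build KBN-compensated prefix sums of the
-- centred values and their squares, then apply the variance formula to
-- each leave-one-out combination.
jackknifeVarKBN :: [Double] -> [Double]
jackknifeVarKBN xs =
    zipWith4 var squaresLeft squaresRight sumsLeft sumsRight
  where
    var l2 r2 l r = ((l2 + r2) - (l + r) ^ 2 / n) / n
    ys            = map (subtract (mean xs)) xs
    squaresLeft   = scanKBNl (map (^2) ys)
    squaresRight  = scanKBNr (map (^2) ys)
    sumsLeft      = scanKBNl ys
    sumsRight     = scanKBNr ys
    scanKBNl      = map done . init . scanl kbnStep (0,0)
    scanKBNr      = map done . tail . scanr (flip kbnStep) (0,0)
    done (s, c)   = s + c
    kbnStep (sum, c) x = (t, c')
      where c' | abs sum >= abs x = c + ((sum - t) + x)
               | otherwise        = c + ((x - t) + sum)
            t                     = sum + x
    n             = fromIntegral (length xs - 1)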

Where can you use this?

The jackknife function in the Haskell statistics library uses all of these techniques where applicable, and the Sum module of the math-functions library provides reliable summation (including second-order Kahan-Babuška summation, if you gotta catch all those least significant bits).

(If you’re not already bored to death of summation algorithms, take a look into pairwise summation. It’s less accurate than KBN summation, but claims to be quite a bit faster—claims I found to be only barely true in my benchmarks, and not worth the loss of precision.)

Comments
  1. Colin Adams says:

    Very interesting, Bryan. Thanks.
    In the last paragraph, you talk about pairwise summation. Is this what is being used in http://prod.sandia.gov/techlib/access-control.cgi/2008/086212.pdf, or is that something different?

  2. Aleksey Khudyakov says:

    In fact, Welford’s mean could be used with the jackknife quite easily. The key insight is that the accumulator used in the fold is a monoid, so we can use mappend to merge the means of two samples. Here is a generalized jackknife: https://gist.github.com/Shimuuar/bf2ccc0fce34e3cb037b

    Also, I think it would be a good idea to add monoid instances to the numerically stable summation accumulators, if it can be done without ruining accuracy.
