Saturday, July 05, 2008

Designing overflow-safe algorithms

I was thinking about an old interview-type problem: "Given a set of N-1 distinct integers from the range 1 .. N, how will you detect the one missing integer?". Needless to say, the input contains no duplicates and the input is unordered. There are numerous solutions, each with its own problems and limitations:



  • Sort the input array, go through the sorted array and locate the missing element. The problem is that this is O(N log N).

  • Maintain an array (or a bit vector) of size N initialized to 0. As you read the input, mark the corresponding array element as 1. Then make one pass through the array and detect the missing element by looking for the value 0. This algorithm is O(N) in time but requires O(N) extra memory.

  • A cute solution that interviewers often look for involves summing up the input and subtracting this from the sum of the numbers from 1 to N (which is 0.5 * N * (N + 1), as any high school student knows). The difference between the two is the integer missing from the input (a minimal sketch of this approach follows the list).
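
For concreteness, the straightforward summing version might look something like the sketch below (my own illustration, not part of the original solution; the subroutine name is made up, and the sketch is in Perl simply to match the rest of the post). In a language with fixed-width integer arithmetic, the running $sum is exactly where the trouble would start.

sub naive_missing { # @_ holds N-1 distinct integers from 1 .. N
    my $n   = @_ + 1;          # N = number of input elements + 1
    my $sum = 0;
    $sum += $_ foreach @_;     # this intermediate sum is what can overflow
    return $n * ($n + 1) / 2 - $sum;
}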

The last solution is quite neat, except that if one were to write such a program in most commonly used programming languages one would run into overflow problems. One could use special libraries that support arbitrarily large numbers (e.g. Perl's Math::BigInt), but they come with overheads that are really unnecessary for the problem at hand. The solution I propose involves keeping a "running difference" between the sum of the input and sum(1 .. N), making sure that this difference always stays within the range (-N - 1 .. N - 1). Essentially we are going to iteratively calculate the value of

missing number = sum(1..N) - sum(input) = 1 + 2 + ... + (N-1) + N - sum(input).

The first-cut implementation might be something like:



i = 1; diff = 0;
while (not end of input) {
    diff += i - input element;
    ++i;
}
return (N + diff);


The problem is that this could lead to an underflow if (say) the initial numbers in the input are all very high. If we instead start from the top like:


i = 1; diff = 0;
while (not end of input) {
    diff += N - i + 1 - input element;
    ++i;
}
return (diff + 1);


we could have the opposite problem: diff could overflow on the high side.

The solution is to make sure that diff always stays within bounds: if we find that it has gone negative, we compensate by using an unused number from the higher end of the range (1..N), whereas if we find that it is growing too large, we use an unused number from the lower end of the range (1..N).

Here is my Perl implementation:



sub with_summing { # @_ holds the input
    my $high = @_ + 1; # higher end of range = #elems in input + 1
    my $low  = 1;
    my $diff = 0;
    foreach my $next (@_) {
        if ($diff < 0) {
            $diff = $diff - $next + $high; --$high;
        }
        else {
            $diff = $diff - $next + $low; ++$low;
        }
    }
    return $diff + $low; # at this point $high should equal $low
}
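
As a quick sanity check, one might exercise it with something like the following (an illustrative test harness of mine, not part of the original routine; shuffle comes from the core List::Util module):

use List::Util qw(shuffle);

my $n       = 1000;
my $missing = 537;                                    # the value we hide
my @input   = shuffle grep { $_ != $missing } 1 .. $n;
print with_summing(@input), "\n";                     # prints 537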


This technique of iteratively calculating a bounded value, without creating intermediate values that might overflow, is useful in a number of other situations. Such solutions have the additional advantage that they give visibility into the value the input is tending towards as more and more input is read. For example, we could use this technique to calculate the average. Rather than using the formula average = sum of input / number of input elements, which is susceptible to overflow, one could instead use the following iterative solution:

average_so_far = 0;
num_input_elements = 0;
while (not end of input) {
    average_so_far = average_so_far * (num_input_elements / (1 + num_input_elements))
                     + new_element / (1 + num_input_elements);
    ++num_input_elements;
}


And we could wrap this into a nice class as follows:

class average {
    public double average_so_far();
    public void add_value(double new_value);
}
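
In Perl, a minimal sketch of such a class might look like the package below (the package name and internals are my own illustration; locking for the multi-threaded use described next is left out of the sketch):

package RunningAverage;

sub new {
    my $class = shift;
    return bless { average => 0, count => 0 }, $class;
}

sub add_value {
    my ($self, $new_value) = @_;
    my $n = $self->{count};
    # Same overflow-safe update as above: scale the old average down
    # instead of accumulating a raw sum.
    $self->{average} = $self->{average} * ($n / ($n + 1))
                     + $new_value / ($n + 1);
    $self->{count} = $n + 1;
}

sub average_so_far {
    my $self = shift;
    return $self->{average};
}

1; # needed if this lives in its own .pm file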


We could then use this class in (say) a chemical plant application where two threads share the object: one thread adds readings from a data feed while another displays the running average to an operator. Notice the technique we used to update the average: we took care not to use the formula
average_so_far = (num_input_elements * average_so_far + new_element)/(num_input_elements + 1);


We actually used this technique to keep track of the results of a long-running experiment we conducted as part of our performance testing work at Aztecsoft, my previous company.

One might wonder how far one can go with this idea. How about calculating the standard deviation? My gut told me that this wasn't possible, but it turns out that my gut was wrong! Here is how we can calculate the SD without causing overflows (this may not be very readable, but I am not good at representing mathematical symbols on a blog :-( ):

sd = sqrt(sum(square(element - average)) / (num_elems - 1))

Let sq(i) = sum(square(element - average(i))) / (i - 1), where the sum runs over the first i elements.

Then it can be shown that (the full derivation is a nice exercise for the reader, mainly because I will go crazy trying to format it in this extremely hard-to-use blogging editor; a rough sketch follows the formula):

sq(i + 1) = ((i - 1)/i) * sq(i)
            + (delta) * (delta)
            + square(element(i + 1) - average(i + 1)) / i
where delta = average(i + 1) - average(i)
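
For the curious, here is a rough sketch of why this holds (my own working; write S(i) for the un-normalized sum of squares, so that sq(i) = S(i) / (i - 1)):

S(i + 1) = sum over j = 1..(i + 1) of square(element(j) - average(i + 1))
         = sum over j = 1..i of square((element(j) - average(i)) - delta)
           + square(element(i + 1) - average(i + 1))
         = S(i) + i * delta * delta + square(element(i + 1) - average(i + 1))

The cross term drops out because the deviations of the first i elements from average(i) sum to zero. Dividing through by i and substituting S(i) = (i - 1) * sq(i) gives the recurrence above.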

Notice that none of the elements of this formula are likely to cause overflow issues. This lets us write the following code to calculate the standard deviation


sub sd { # input: array of numbers; returns (standard deviation, average)
    my $average   = 0;
    my $num_elems = 0;
    my $square    = 0;   # sq(i): running sample variance
    foreach my $new_elem (@_) {
        my $old_average = $average;
        $average = $old_average * ($num_elems / ($num_elems + 1))
                 + $new_elem / ($num_elems + 1);
        my $delta      = $average - $old_average;
        my $old_square = $square;
        if ($num_elems > 0) {
            $square = (($num_elems - 1) / $num_elems) * $old_square
                    + ($delta ** 2)
                    + (($new_elem - $average) ** 2) / $num_elems;
        }
        ++$num_elems;
    }
    return ($square ** 0.5, $average);
}
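
A quick illustrative run (my own example) on readings that would stress a naive running sum:

my ($sd, $mean) = sd(1_000_000_001, 1_000_000_004, 1_000_000_007);
printf "sd = %.3f, average = %.3f\n", $sd, $mean;   # roughly sd = 3, average = 1000000004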


To sum up, many seemingly straightforward "correct" solutions can break down at high data volumes, but there is often an elegant alternative that overcomes those limitations. Be on the lookout for such opportunities.