2015-11-27

How to compute the intersection of two sorted lists (in Python)

This blog post explains how to compute the sorted intersection of two sorted lists, and it shows a fast Python implementation. The time complexity is O(min(n + m, n · log(m))), where n is the minimum of the list lengths, and m is the maximum of the list lengths.

The first idea is to take the first element of both lists, and, if they are different, discard the smaller one of the two. (The smaller one can't be in the intersection because it's smaller than all remaining elements of the other list.) If the first elements are equal, then emit the value as part of the intersection, and discard both first elements. Continue this until one of the lists runs out. If discarding is implemented by incrementing index variables, this method finishes in O(n + m) time, because each iteration discards at least one element.
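
As a minimal sketch of this first idea (assuming Python 3; the name intersect_linear and the step counter are made up for illustration), the loop below counts its iterations to show that they never exceed the combined length of the two lists:

```python
def intersect_linear(a1, a2):
  """Returns (intersection, number of loop iterations) for sorted lists."""
  i1 = i2 = steps = 0
  result = []
  while i1 < len(a1) and i2 < len(a2):
    steps += 1
    v1, v2 = a1[i1], a2[i2]
    if v1 == v2:  # Common element: emit it and discard it from both lists.
      result.append(v1)
      i1 += 1
      i2 += 1
    elif v1 < v2:  # v1 is smaller than everything left in a2: discard it.
      i1 += 1
    else:  # v2 is smaller than everything left in a1: discard it.
      i2 += 1
  return result, steps


result, steps = intersect_linear([1, 3, 4, 7, 9], [2, 3, 7, 8, 10, 11])
print(result, steps)  # Prints: [3, 7] 7
```

Here the loop needs 7 iterations for 5 + 6 = 11 input elements, because every iteration advances at least one of the two indexes.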

We can improve on the first idea by noticing that it's possible to discard multiple elements at once when the two first elements are different: all elements which are smaller than the larger one of the two first elements can be discarded. Depending on the input, this can be a lot of elements. For example, if all elements of the first list are smaller than all elements of the second list, then the entire first list is discarded in one step. In the general case, binary search can be used to figure out how many elements to discard. However, each binary search takes logarithmic time, and there can be a search in every iteration, so a simple bound on the total execution time is O(m · log(m)), which can be faster or slower than the O(n + m) of the previous solution. In fact, by more careful analysis of the number of runs (blocks which are discarded), it's possible to show that the execution time is just O(n · log(m)), but that can still be slower than the previous solution.
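
To illustrate the skipping step, here is a small sketch using bisect.bisect_left from the Python standard library. The helper name skip_smaller is made up for this example. bisect_left(a, x, lo) returns the index of the first element at or after position lo which is not smaller than x, and that index is the new value of the index variable:

```python
import bisect


def skip_smaller(a, i, x):
  """Returns the smallest index j >= i with a[j] >= x, or len(a) if none.

  Assumes that a is sorted.
  """
  return bisect.bisect_left(a, x, i)


a1 = [1, 2, 3, 4, 5, 6, 20]
a2 = [15, 20, 30]
# The first elements differ (1 < 15), so discard everything in a1 below 15.
print(skip_smaller(a1, 0, a2[0]))  # Prints: 6
# If all elements are smaller, the entire list is discarded in one step.
print(skip_smaller([1, 2, 3], 0, 10))  # Prints: 3
```

A single binary search replaces six increments of the index in the first call.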

It's possible to combine the two solutions: estimate the execution time of the 2nd solution, and if the estimate is smaller than n + m, execute the 2nd solution, otherwise execute the first solution. Please note that the estimate also takes into account the constants (not only big-O). The resulting Python (≥2.4 and 3.x) code looks like:
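
Before the full code, here is a sketch of the estimate on its own, to show which inputs select which solution. The function name binary_search_is_cheaper is hypothetical; the comparison is the same as the one in the code below:

```python
import math


def binary_search_is_cheaper(s1, s2):
  """Returns True if min * log2(max) is estimated to be below s1 + s2."""
  # 1.4426950408889634 converts the natural logarithm to base 2.
  return bool(s1) and (
      s1 + s2 > min(s1, s2) * math.log(max(s1, s2)) * 1.4426950408889634)


print(binary_search_is_cheaper(10, 1000000))  # Prints: True
print(binary_search_is_cheaper(1000, 1000))   # Prints: False
```

For lengths 10 and 1000000 the estimate is about 10 · 20 = 200 steps against 1000010, so binary search wins. For two lists of length 1000 it is about 1000 · 10 = 10000 steps against 2000, so the linear solution wins.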

def intersect_sorted(a1, a2):
  """Yields the intersection of sorted lists a1 and a2, without deduplication.

  Execution time is O(min(lo + hi, lo * log(hi))), where lo == min(len(a1),
  len(a2)) and hi == max(len(a1), len(a2)). It can be faster depending on
  the data.
  """
  import bisect, math
  s1, s2 = len(a1), len(a2)
  i1 = i2 = 0
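  # Use binary search only if its estimated cost, min * log2(max), is lower
  # than the linear cost s1 + s2. Checking s1 first avoids math.log(0) when
  # both lists are empty.
  if s1 and s1 + s2 > min(s1, s2) * math.log(max(s1, s2)) * 1.4426950408889634:
    bi = bisect.bisect_left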
    while i1 < s1 and i2 < s2:
      v1, v2 = a1[i1], a2[i2]
      if v1 == v2:
        yield v1
        i1 += 1
        i2 += 1
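      elif v1 < v2:
        # Skip all elements of a1 smaller than v2 with one binary search.
        i1 = bi(a1, v2, i1)
      else:
        # Skip all elements of a2 smaller than v1 with one binary search.
        i2 = bi(a2, v1, i2)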
  else:  # The linear solution is faster.
    while i1 < s1 and i2 < s2:
      v1, v2 = a1[i1], a2[i2]
      if v1 == v2:
        yield v1
        i1 += 1
        i2 += 1
      elif v1 < v2:
        i1 += 1
      else:
        i2 += 1

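The numeric constant 1.4426950408889634 in the code above is 1/math.log(2). Multiplying the natural logarithm by it gives the base-2 logarithm, which matches the halving steps of a binary search.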
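
A note on "without deduplication" in the docstring: equal elements are paired up one by one, so a value which occurs several times in both lists is emitted as many times as its smaller count. A slow but simple reference for cross-checking that behavior could look like the sketch below. It assumes hashable elements and Python 3, and the name intersect_reference is hypothetical (it is not taken from the GitHub code):

```python
import collections


def intersect_reference(a1, a2):
  """Returns the sorted multiset intersection of a1 and a2 as a list."""
  # The & operator of Counter keeps the minimum count of each element.
  common = collections.Counter(a1) & collections.Counter(a2)
  return sorted(common.elements())


print(intersect_reference([1, 1, 2, 5, 5, 5], [1, 1, 1, 3, 5, 5]))
# Prints: [1, 1, 5, 5]
```

On sorted inputs, list(intersect_sorted(a1, a2)) should give the same result.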

The code with some tests and with support for merging multiple sequences is available on GitHub here.

Related questions on StackOverflow: this and this.

2 comments:

francis gan said...

Could you please give more comments in your code -- so that we can understand what you are thinking?

Thanks a lot

pts said...

@francis gan: The text of the blog post contains a full explanation on how the code works.

If you think that some of that explanation can be moved to comments, please write those comments, and send a pull request on GitHub.