diff --git a/common.py b/common.py index d0e507d..cb3e7b6 100644 --- a/common.py +++ b/common.py @@ -1,4 +1,4 @@ -import cStringIO +import functools import numpy as np import logging @@ -68,6 +68,33 @@ def iterate(data, bufsize, offset=0, advance=1, func=None): buf_index = max(0, buf_index - advance) offset += advance +class Splitter(object): + def __init__(self, iterable, n): + self.iterable = iter(iterable) + self.read = [True] * n + self.last = None + self.generators = [functools.partial(self._gen, i)() for i in range(n)] + self.n = n + + def _gen(self, index): + while True: + if all(self.read): + try: + self.last = self.iterable.next() + except StopIteration: + return + + assert len(self.last) == self.n + self.read = [False] * self.n + + if self.read[index]: + raise IndexError(index) + self.read[index] = True + yield self.last[index] + +def split(iterable, n): + return Splitter(iterable, n).generators + if __name__ == '__main__': import pylab diff --git a/test_common.py b/test_common.py index 5825ced..ea6c5b2 100644 --- a/test_common.py +++ b/test_common.py @@ -16,3 +16,16 @@ def test_iterate(): assert iterlist(range(N), 2, offset=5) == [(i, [i, i+1]) for i in range(5, N-1)] assert iterlist(range(N), 1, func=lambda b: -b) == [(i, [-i]) for i in range(N)] +def test_split(): + L = [(i*2, i*2+1) for i in range(10)] + iters = common.split(L, n=2) + assert zip(*iters) == L + + for i in [0, 1]: + iters = common.split(L, n=2) + iters[i].next() + try: + iters[i].next() + assert False + except IndexError as e: + assert e.args == (i,)