diff --git a/amodem/common.py b/amodem/common.py index a8382d7..ab9acf9 100644 --- a/amodem/common.py +++ b/amodem/common.py @@ -53,35 +53,13 @@ def iterate(data, size, func=None, truncate=True): offset += size -class Splitter(object): - - def __init__(self, iterable, n): - self.iterable = iter(iterable) - self.read = [True] * n - self.last = None - self.generators = [functools.partial(self._gen, i)() for i in range(n)] - self.n = n - - def _gen(self, index): - while True: - if all(self.read): - try: - self.last = next(self.iterable) - except StopIteration: - return - - assert len(self.last) == self.n - self.read = [False] * self.n - - if self.read[index]: - raise IndexError(index) - self.read[index] = True - yield self.last[index] - - def split(iterable, n): - return Splitter(iterable, n).generators + def _gen(it, index): + for item in it: + yield item[index] + iterables = itertools.tee(iterable, n) + return [_gen(it, index) for index, it in enumerate(iterables)] def icapture(iterable, result): for i in iter(iterable): diff --git a/tests/test_common.py b/tests/test_common.py index b81b3c6..8684692 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -27,15 +27,6 @@ def test_split(): iters = common.split(L, n=2) assert list(zip(*iters)) == L - for i in [0, 1]: - iters = common.split(L, n=2) - next(iters[i]) - try: - next(iters[i]) - assert False - except IndexError as e: - assert e.args == (i,) - def test_icapture(): x = range(100)