parallel iterator split

This commit is contained in:
Roman Zeyde
2014-07-07 16:42:00 +03:00
committed by Roman Zeyde
parent 000e96b40d
commit 9981f280f4
2 changed files with 41 additions and 1 deletions

View File

@@ -1,4 +1,4 @@
import cStringIO
import functools
import numpy as np
import logging
@@ -68,6 +68,33 @@ def iterate(data, bufsize, offset=0, advance=1, func=None):
buf_index = max(0, buf_index - advance)
offset += advance
class Splitter(object):
    """Fan one iterable of n-item records out into n lockstep generators.

    Each item pulled from ``iterable`` must be a sequence of length ``n``.
    Generator ``i`` (``self.generators[i]``) yields element ``i`` of every
    record.  A new record is fetched from the underlying iterator only once
    every generator has consumed its element of the current one, so the
    generators have to be advanced in lockstep (e.g. via ``zip``).

    Advancing generator ``i`` a second time before the others have caught up
    raises ``IndexError(i)``.
    """

    def __init__(self, iterable, n):
        self.iterable = iter(iterable)
        # Set before the generators are created so they never observe a
        # half-initialised instance.
        self.n = n
        # read[i] is True once generator i has consumed its element of the
        # current record; all-True initially forces the first fetch.
        self.read = [True] * n
        self.last = None
        self.generators = [self._gen(i) for i in range(n)]

    def _gen(self, index):
        """Yield element ``index`` of each record pulled from the iterable.

        Raises:
            IndexError: with ``index`` as its sole argument, when this
                generator is advanced again before the remaining generators
                have consumed the current record.
        """
        while True:
            if all(self.read):
                # Everyone is done with the current record: fetch the next.
                try:
                    # Builtin next() works on both Python 2.6+ and Python 3.
                    self.last = next(self.iterable)
                except StopIteration:
                    return
                assert len(self.last) == self.n
                self.read = [False] * self.n
            if self.read[index]:
                # This generator already took its element of the record.
                raise IndexError(index)
            self.read[index] = True
            yield self.last[index]
def split(iterable, n):
    """Return the list of ``n`` generators of a ``Splitter`` over ``iterable``."""
    splitter = Splitter(iterable, n)
    return splitter.generators
if __name__ == '__main__':
import pylab

View File

@@ -16,3 +16,16 @@ def test_iterate():
assert iterlist(range(N), 2, offset=5) == [(i, [i, i+1]) for i in range(5, N-1)]
assert iterlist(range(N), 1, func=lambda b: -b) == [(i, [-i]) for i in range(N)]
def test_split():
    """Check ``common.split`` in lockstep use and when one output runs ahead."""
    L = [(i * 2, i * 2 + 1) for i in range(10)]

    # Consuming both outputs in lockstep reproduces the original pairs.
    # list() keeps the comparison valid where zip() returns an iterator.
    iters = common.split(L, n=2)
    assert list(zip(*iters)) == L

    # Advancing the same output twice in a row must raise IndexError(i).
    for i in [0, 1]:
        iters = common.split(L, n=2)
        next(iters[i])
        try:
            next(iters[i])
        except IndexError as e:
            assert e.args == (i,)
        else:
            assert False