Fix common.iterate()

This commit is contained in:
Roman Zeyde
2014-08-01 10:41:47 +03:00
parent ce00e94974
commit 5508ffc0d3
3 changed files with 4 additions and 7 deletions

View File

@@ -65,8 +65,7 @@ def iterate(data, size, func=None, truncate=True):
return
done = True
buf = np.array(buf)
result = func(buf) if func else buf
result = func(buf) if func else np.array(buf)
yield offset, result
offset += size

View File

@@ -96,11 +96,9 @@ def coherence(x, freq):
def extract_symbols(x, freq, offset=0):
Hc = exp_iwt(-freq, Nsym) / (0.5*Nsym)
func = lambda y: np.dot(Hc, y)
iterator = common.iterate(x, Nsym, func=func)
for _, symbol in iterator:
yield symbol
for _, symbol in common.iterate(x, Nsym):
yield np.dot(Hc, symbol)
def drift(S):

View File

@@ -18,7 +18,7 @@ def test_iterate():
assert iterlist(range(N), 3) == [
(i, [i, i+1, i+2]) for i in range(0, N-2, 3)]
assert iterlist(range(N), 1, func=lambda b: -b) == [
assert iterlist(range(N), 1, func=lambda b: -np.array(b)) == [
(i, [-i]) for i in range(N)]