From 5508ffc0d3ae7ed1863f97bc427b9cb675d9faf9 Mon Sep 17 00:00:00 2001 From: Roman Zeyde Date: Fri, 1 Aug 2014 10:41:47 +0300 Subject: [PATCH] Fix common.iterate() --- common.py | 3 +-- sigproc.py | 6 ++---- test_common.py | 2 +- 3 files changed, 4 insertions(+), 7 deletions(-) diff --git a/common.py b/common.py index 7c70d23..da9344c 100644 --- a/common.py +++ b/common.py @@ -65,8 +65,7 @@ def iterate(data, size, func=None, truncate=True): return done = True - buf = np.array(buf) - result = func(buf) if func else buf + result = func(buf) if func else np.array(buf) yield offset, result offset += size diff --git a/sigproc.py b/sigproc.py index c777fc9..776e108 100644 --- a/sigproc.py +++ b/sigproc.py @@ -96,11 +96,9 @@ def coherence(x, freq): def extract_symbols(x, freq, offset=0): Hc = exp_iwt(-freq, Nsym) / (0.5*Nsym) - func = lambda y: np.dot(Hc, y) - iterator = common.iterate(x, Nsym, func=func) - for _, symbol in iterator: - yield symbol + for _, symbol in common.iterate(x, Nsym): + yield np.dot(Hc, symbol) def drift(S): diff --git a/test_common.py b/test_common.py index 3f985cc..bff44be 100644 --- a/test_common.py +++ b/test_common.py @@ -18,7 +18,7 @@ def test_iterate(): assert iterlist(range(N), 3) == [ (i, [i, i+1, i+2]) for i in range(0, N-2, 3)] - assert iterlist(range(N), 1, func=lambda b: -b) == [ + assert iterlist(range(N), 1, func=lambda b: -np.array(b)) == [ (i, [-i]) for i in range(N)]