diff --git a/amodem/sigproc.py b/amodem/sigproc.py index b3be754..dfe818a 100644 --- a/amodem/sigproc.py +++ b/amodem/sigproc.py @@ -27,6 +27,12 @@ class Filter(object): yield y +def lfilter(b, a, x): + f = Filter(b=b, a=a) + y = list(f(x)) + return np.array(y) + + def train(S, training): A = np.array([S[1:], S[:-1], training[:-1]]).T b = training[1:] diff --git a/tests/test_full.py b/tests/test_full.py index 9d5202d..2ec6aba 100644 --- a/tests/test_full.py +++ b/tests/test_full.py @@ -36,18 +36,13 @@ def run(size, chan): assert rx_data == tx_data -def apply_filter(b, a, x): - f = sigproc.Filter(b=b, a=a) - y = list(f(list(x))) - return np.array(y) - def test_lowpass(): - run(1024, lambda x: apply_filter(b=[0.9], a=[1.0, -0.1], x=x)) + run(1024, lambda x: sigproc.lfilter(b=[0.9], a=[1.0, -0.1], x=x)) def test_highpass(): - run(1024, lambda x: apply_filter(b=[0.9], a=[1.0, 0.1], x=x)) + run(1024, lambda x: sigproc.lfilter(b=[0.9], a=[1.0, 0.1], x=x)) def test_small(): diff --git a/tests/test_sigproc.py b/tests/test_sigproc.py index 6ee85e3..cbad696 100644 --- a/tests/test_sigproc.py +++ b/tests/test_sigproc.py @@ -26,12 +26,10 @@ def test_linreg(): def test_filter(): - f = sigproc.Filter(b=[1], a=[1]) x = range(10) - y = list(f(x)) - assert [float(i) for i in x] == y + y = sigproc.lfilter(b=[1], a=[1], x=x) + assert (np.array(x) == y).all() - f = sigproc.Filter(b=[0.5], a=[1, -0.5]) x = [1] + [0] * 10 - y = list(f(x)) - assert y == [0.5 ** (i+1) for i in range(len(x))] + y = sigproc.lfilter(b=[0.5], a=[1, -0.5], x=x) + assert list(y) == [0.5 ** (i+1) for i in range(len(x))]