diff --git a/amodem/equalizer.py b/amodem/equalizer.py index eb66684..72ee21c 100644 --- a/amodem/equalizer.py +++ b/amodem/equalizer.py @@ -34,7 +34,7 @@ def demodulator(signal, size): return np.array(list(itertools.islice(symbols, size))) -def equalize(signal, symbols, order, lookahead=0): +def equalize_symbols(signal, symbols, order, lookahead=0): Nsym = config.Nsym Nfreq = config.Nfreq carriers = modem.carriers @@ -50,14 +50,12 @@ def equalize(signal, symbols, order, lookahead=0): A = [] b = [] - index = 0 for j in range(Nfreq): for i in range(length): offset = (i+1)*Nsym row = y[offset-order:offset+lookahead, j] A.append(row) b.append(symbols[i, j]) - index += 1 A = np.array(A) b = np.array(b) @@ -65,3 +63,23 @@ def equalize(signal, symbols, order, lookahead=0): h = h[::-1].real return h + + +def equalize_signal(signal, expected, order, lookahead=0): + signal = np.concatenate([np.zeros(order-1), signal, np.zeros(lookahead)]) + length = len(expected) + + A = [] + b = [] + + for i in range(length - order): + offset = order + i + row = signal[offset-order:offset+lookahead] + A.append(np.array(row, ndmin=2)) + b.append(expected[i]) + + A = np.concatenate(A, axis=0) + b = np.array(b) + h, residuals, rank, sv = lstsq(A, b) + h = h[::-1].real + return h diff --git a/amodem/recv.py b/amodem/recv.py index e9caabc..a94343d 100644 --- a/amodem/recv.py +++ b/amodem/recv.py @@ -133,7 +133,7 @@ class Receiver(object): signal = sampler.take(signal_length + lookahead) - coeffs = equalizer.equalize( + coeffs = equalizer.equalize_symbols( signal=signal[prefix:-postfix], symbols=train_symbols, order=order, lookahead=lookahead diff --git a/tests/test_equalizer.py b/tests/test_equalizer.py index b26fb1f..713263f 100644 --- a/tests/test_equalizer.py +++ b/tests/test_equalizer.py @@ -1,4 +1,5 @@ from numpy.linalg import norm +from numpy.random import RandomState import numpy as np from amodem import dsp @@ -41,7 +42,7 @@ def test_modem(): assert_approx(sent, received) -def test_isi(): +def test_symbols(): length = 100 gain = float(config.Nfreq) @@ -54,10 +55,31 @@ def test_isi(): y = dsp.lfilter(x=x, b=num, a=den) lookahead = 2 - h = equalizer.equalize(y, symbols, order=len(den), lookahead=lookahead) + h = equalizer.equalize_symbols( + signal=y, symbols=symbols, order=len(den), lookahead=lookahead + ) assert norm(h[:lookahead]) < 1e-12 assert_approx(h[lookahead:], den / num) y = dsp.lfilter(x=y, b=h[lookahead:], a=[1]) z = equalizer.demodulator(y, size=length) assert_approx(z, symbols) + + +def test_signal(): + length = 100 + x = np.sign(RandomState(0).normal(size=length)) + den = np.array([1, -0.6, 0.1]) + num = np.array([0.5]) + y = dsp.lfilter(x=x, b=num, a=den) + + lookahead = 2 + h = equalizer.equalize_signal( + signal=y, expected=x, order=len(den), lookahead=lookahead) + assert norm(h[:lookahead]) < 1e-12 + + h = h[lookahead:] + assert_approx(h, den / num) + + x_ = dsp.lfilter(x=y, b=h, a=[1]) + assert_approx(x_, x)