diff --git a/amodem/equalizer.py b/amodem/equalizer.py index 3b7abd0..5189684 100644 --- a/amodem/equalizer.py +++ b/amodem/equalizer.py @@ -38,22 +38,23 @@ class Equalizer(object): omegas=self.omegas, Nsym=self.Nsym) return np.array(list(itertools.islice(symbols, size))) - def equalize_signal(self, signal, expected, order, lookahead=0): - signal = [np.zeros(order-1), signal, np.zeros(lookahead)] - signal = np.concatenate(signal) - length = len(expected) - A = [] - b = [] +def train(signal, expected, order, lookahead=0): + signal = [np.zeros(order-1), signal, np.zeros(lookahead)] + signal = np.concatenate(signal) + length = len(expected) - for i in range(length - order): - offset = order + i - row = signal[offset-order:offset+lookahead] - A.append(np.array(row, ndmin=2)) - b.append(expected[i]) + A = [] + b = [] - A = np.concatenate(A, axis=0) - b = np.array(b) - h = lstsq(A, b)[0] - h = h[::-1].real - return h + for i in range(length - order): + offset = order + i + row = signal[offset-order:offset+lookahead] + A.append(np.array(row, ndmin=2)) + b.append(expected[i]) + + A = np.concatenate(A, axis=0) + b = np.array(b) + h = lstsq(A, b)[0] + h = h[::-1].real + return h diff --git a/amodem/recv.py b/amodem/recv.py index 5269968..1a22968 100644 --- a/amodem/recv.py +++ b/amodem/recv.py @@ -75,7 +75,7 @@ class Receiver(object): signal = sampler.take(signal_length + lookahead) - coeffs = self.equalizer.equalize_signal( + coeffs = equalizer.train( signal=signal[prefix:-postfix], expected=train_signal, order=order, lookahead=lookahead diff --git a/tests/test_equalizer.py b/tests/test_equalizer.py index 75fd5f9..a18a7ff 100644 --- a/tests/test_equalizer.py +++ b/tests/test_equalizer.py @@ -51,10 +51,9 @@ def test_signal(): den = np.array([1, -0.6, 0.1]) num = np.array([0.5]) y = dsp.lfilter(x=x, b=num, a=den) - e = equalizer.Equalizer(config) lookahead = 2 - h = e.equalize_signal( + h = equalizer.train( signal=y, expected=x, order=len(den), lookahead=lookahead) assert norm(h[:lookahead]) < 1e-12