From 6a2e320808b86b5e82528d893979bace50717991 Mon Sep 17 00:00:00 2001 From: Roman Zeyde Date: Wed, 18 Feb 2015 18:14:57 +0200 Subject: [PATCH] equalizer: replace Least-Square solver by Levinson-Durbin recursion --- amodem/equalizer.py | 40 ++++++++++++++++------------------------ amodem/levinson.py | 30 ++++++++++++++++++++++++++++++ amodem/recv.py | 3 ++- tests/test_equalizer.py | 4 +++- 4 files changed, 51 insertions(+), 26 deletions(-) create mode 100644 amodem/levinson.py diff --git a/amodem/equalizer.py b/amodem/equalizer.py index de43363..ff21f54 100644 --- a/amodem/equalizer.py +++ b/amodem/equalizer.py @@ -1,9 +1,8 @@ from . import dsp from . import sampling +from . import levinson import numpy as np -from numpy.linalg import lstsq - import itertools @@ -44,28 +43,21 @@ class Equalizer(object): return np.array(list(itertools.islice(symbols, size))) -def train(signal, expected, order, lookahead=0): - signal = [np.zeros(order-1), signal, np.zeros(lookahead)] - signal = np.concatenate(signal) - length = len(expected) - - A = [] - b = [] - # construct Ah=b over-constrained equation system, - # used for least-squares estimation of the filter. - for i in range(length - order): - offset = order + i - row = signal[offset-order:offset+lookahead] - A.append(np.array(row, ndmin=2)) - b.append(expected[i]) - - A = np.concatenate(A, axis=0) - b = np.array(b) - h = lstsq(A, b)[0] - h = h[::-1].real - return h - - prefix = [1]*400 + [0]*50 equalizer_length = 500 silence_length = 100 + + +def train(signal, expected, order, lookahead=0): + padding = np.zeros(lookahead) + assert len(signal) == len(expected) + x = np.concatenate([signal, padding]) + y = np.concatenate([padding, expected]) + + N = order + lookahead # filter length + Rxx = np.zeros(N) + Rxy = np.zeros(N) + for i in range(N): + Rxx[i] = np.dot(x[i:], x[:len(x)-i]) + Rxy[i] = np.dot(y[i:], x[:len(x)-i]) + return levinson.solver(t=Rxx, y=Rxy) diff --git a/amodem/levinson.py b/amodem/levinson.py new file mode 100644 index 0000000..d2f2cd8 --- /dev/null +++ b/amodem/levinson.py @@ -0,0 +1,30 @@ +import numpy as np + + +def solver(t, y): + ''' Solve Mx = y for x, where M[i,j] = t[|i-j|], in O(N^2) steps. + See http://en.wikipedia.org/wiki/Levinson_recursion for details. + ''' + N = len(t) + assert len(y) == N + + t0 = np.array([1.0 / t[0]]) + f = [t0] # forward vectors + b = [t0] # backward vectors + for n in range(1, N): + prev_f = f[-1] + prev_b = b[-1] + ef = sum(t[n-i] * prev_f[i] for i in range(n)) + eb = sum(t[i+1] * prev_b[i] for i in range(n)) + f_ = np.concatenate([f[-1], [0]]) + b_ = np.concatenate([[0], b[-1]]) + det = 1.0 - ef * eb + f.append((f_ - ef * b_) / det) + b.append((b_ - eb * f_) / det) + + x = [] + for n in range(N): + x = np.concatenate([x, [0]]) + ef = sum(t[n-i] * x[i] for i in range(n)) + x = x + (y[n] - ef) * b[n] + return x diff --git a/amodem/recv.py b/amodem/recv.py index 350ed4e..7bfcce9 100644 --- a/amodem/recv.py +++ b/amodem/recv.py @@ -60,7 +60,7 @@ class Receiver(object): coeffs = equalizer.train( signal=signal[prefix:-postfix], - expected=train_signal, + expected=np.concatenate([train_signal, np.zeros(lookahead)]), order=order, lookahead=lookahead ) @@ -68,6 +68,7 @@ class Receiver(object): self.plt.plot(np.arange(order+lookahead), coeffs) equalization_filter = dsp.FIR(h=coeffs) + # Pre-load equalization filter with the signal (+lookahead) equalized = list(equalization_filter(signal)) equalized = equalized[prefix+lookahead:-postfix+lookahead] self._verify_training(equalized, train_symbols) diff --git a/tests/test_equalizer.py b/tests/test_equalizer.py index 3caa800..3811887 100644 --- a/tests/test_equalizer.py +++ b/tests/test_equalizer.py @@ -4,6 +4,7 @@ import numpy as np import utils from amodem import equalizer +from amodem import levinson from amodem import config config = config.fastest() @@ -46,8 +47,9 @@ def test_modem(): def test_signal(): - length = 100 + length = 120 x = np.sign(RandomState(0).normal(size=length)) + x[-20:] = 0 # make sure the signal has bounded support den = np.array([1, -0.6, 0.1]) num = np.array([0.5]) y = utils.lfilter(x=x, b=num, a=den)