mirror of
https://github.com/romanz/amodem.git
synced 2026-03-06 14:55:56 +08:00
fix equalization PoC
This commit is contained in:
@@ -3,6 +3,7 @@ from numpy.linalg import norm, lstsq
|
||||
import numpy as np
|
||||
import itertools
|
||||
|
||||
|
||||
def test_fir():
|
||||
a = [1, 0.8, -0.1, 0, 0]
|
||||
tx = train.equalizer
|
||||
@@ -29,11 +30,13 @@ import random
|
||||
|
||||
_constellation = [1, 1j, -1, -1j]
|
||||
|
||||
def train_symbols(length, seed=0):
|
||||
|
||||
def train_symbols(length, seed=0, Nfreq=config.Nfreq):
|
||||
r = random.Random(seed)
|
||||
choose = lambda: [r.choice(_constellation) for j in range(config.Nfreq)]
|
||||
choose = lambda: [r.choice(_constellation) for j in range(Nfreq)]
|
||||
return np.array([choose() for i in range(length)])
|
||||
|
||||
|
||||
def modulator(length):
|
||||
symbols = train_symbols(length)
|
||||
carriers = send.sym.carrier
|
||||
@@ -45,16 +48,19 @@ def modulator(length):
|
||||
assert np.max(np.abs(result)) <= 1
|
||||
return result
|
||||
|
||||
|
||||
def demodulator(signal):
|
||||
signal = itertools.chain(signal, itertools.repeat(0))
|
||||
return dsp.Demux(signal, config.frequencies)
|
||||
|
||||
|
||||
def test_training():
|
||||
L = 1000
|
||||
t1 = train_symbols(L)
|
||||
t2 = train_symbols(L)
|
||||
assert (t1 == t2).all()
|
||||
|
||||
|
||||
def test_commutation():
|
||||
x = np.random.RandomState(seed=0).normal(size=1000)
|
||||
b = [1, 1j, -1, -1j]
|
||||
@@ -69,66 +75,48 @@ def test_commutation():
|
||||
z_ = dsp.lfilter(x=x, b=b, a=[1])
|
||||
assert norm(z - z_) < 1e-10
|
||||
|
||||
def equalize(signal, carriers, symbols, order):
|
||||
''' symbols[k] = (signal * h) * filters[k] '''
|
||||
signal = np.array(signal)
|
||||
scaling = (2.0/config.Nsym)
|
||||
carriers = np.array(carriers).conj() * scaling
|
||||
symbol_stream = []
|
||||
for i in range(len(signal) - config.Nsym + 1):
|
||||
frame = signal[i:i+config.Nsym]
|
||||
symbol_stream.append(np.dot(carriers, frame))
|
||||
symbol_stream = np.array(symbol_stream)
|
||||
assert symbol_stream.shape[1] == config.Nfreq
|
||||
LHS = []
|
||||
RHS = []
|
||||
offsets = range(0, len(symbol_stream) - order + 1, config.Nsym)
|
||||
for j in range(config.Nfreq):
|
||||
for i, offset in enumerate(offsets):
|
||||
row = list(symbol_stream[offset:offset+order, j])
|
||||
LHS.append(row)
|
||||
RHS.append(symbols[i, j])
|
||||
|
||||
LHS = np.array(LHS)
|
||||
RHS = np.array(RHS)
|
||||
return lstsq(LHS, RHS)[0]
|
||||
|
||||
def test_modem():
|
||||
L = 1000
|
||||
sent = train_symbols(L)
|
||||
gain = len(send.sym.carrier)
|
||||
x = modulator(L) * gain
|
||||
y = dsp.lfilter(x=x, b=[0, 4], a=[1])
|
||||
h_ = equalize(signal=y, carriers=send.sym.carrier, symbols=sent, order=2)
|
||||
assert norm(h_ - [0, 0.25]) < 1e-10
|
||||
|
||||
s = demodulator(x)
|
||||
received = np.array(list(itertools.islice(s, L)))
|
||||
err = sent - received
|
||||
assert norm(err) < 1e-10
|
||||
|
||||
def test_concept():
|
||||
|
||||
def test_equalizer():
|
||||
N = 32
|
||||
s = [1] * 10 + [-1] * 10
|
||||
s = train_symbols(length=100, Nfreq=1).real.squeeze()
|
||||
x = [v for v in s for i in range(N)]
|
||||
matched = [1.0 / N] * N
|
||||
den = np.array([1, -0.1])
|
||||
num = np.array([1])
|
||||
y = dsp.lfilter(x=x, b=num, a=den)
|
||||
y1 = dsp.lfilter(x=y, b=matched, a=[1])
|
||||
#y2 = dsp.lfilter(x=y1, b=den/num, a=[1])
|
||||
z = dsp.lfilter(x=x, b=matched, a=[1])
|
||||
assert norm(z[N-1::N] - s) < 1e-12
|
||||
|
||||
den = np.array([1, 0.125])
|
||||
num = np.array([1])
|
||||
y = dsp.lfilter(x=x, b=num, a=den)
|
||||
y = dsp.lfilter(x=y, b=matched, a=[1])
|
||||
|
||||
A = []
|
||||
b = []
|
||||
|
||||
r = 2
|
||||
for i in range(len(s)):
|
||||
offset = (i+1)*N
|
||||
row = y1[offset-r:offset]
|
||||
row = y[offset-r:offset]
|
||||
A.append(row)
|
||||
b.append(s[i])
|
||||
A = np.array(A)
|
||||
b = np.array(b)
|
||||
h = lstsq(A, b)[0][::-1]
|
||||
assert norm(h - den) < 1e-12
|
||||
h, residuals, rank, sv = lstsq(A, b)
|
||||
h = h[::-1]
|
||||
print(h)
|
||||
|
||||
y1 = dsp.lfilter(x=x, b=num, a=den)
|
||||
y2 = dsp.lfilter(x=y1, b=h, a=[1])
|
||||
y3 = dsp.lfilter(x=y2, b=matched, a=[1])
|
||||
z = y3[N-1::N]
|
||||
assert norm(z - s) < 1e-12
|
||||
|
||||
Reference in New Issue
Block a user