detect: refactor receiver for large frequency drifts (~0.1%)

This commit is contained in:
Roman Zeyde
2015-01-15 18:14:16 +02:00
parent 5401206178
commit 1da258ebf8
5 changed files with 74 additions and 53 deletions

View File

@@ -20,13 +20,14 @@ class Detector(object):
TIMEOUT = 10.0 # [seconds]
def __init__(self, config):
def __init__(self, config, pylab):
self.freq = config.Fc
self.omega = 2 * np.pi * self.freq / config.Fs
self.Nsym = config.Nsym
self.Tsym = config.Tsym
self.maxlen = config.baud # 1 second of symbols
self.max_offset = self.TIMEOUT * config.Fs
self.plt = pylab
def _wait(self, samples):
counter = 0
@@ -71,16 +72,46 @@ class Detector(object):
bufs.append(np.array(trailing))
buf = np.concatenate(bufs)
offset = self.find_start(buf, self.CARRIER_DURATION*self.Nsym)
offset = self.find_start(buf, duration=self.CARRIER_DURATION)
start_time += (offset / self.Nsym - self.SEARCH_WINDOW) * self.Tsym
log.debug('Carrier starts at %.3f ms', start_time * 1e3)
return itertools.chain(buf[offset:], samples), amplitude
buf = buf[offset:]
def find_start(self, buf, length):
N = len(buf)
carrier = dsp.exp_iwt(self.omega, N)
z = np.cumsum(buf * carrier)
z = np.concatenate([[0], z])
correlations = np.abs(z[length:] - z[:-length])
return np.argmax(correlations)
prefix_length = self.CARRIER_DURATION * self.Nsym
amplitude, freq_err = self.estimate(buf[:prefix_length])
return itertools.chain(buf, samples), amplitude, freq_err
def find_start(self, buf, duration):
filt = dsp.FIR(dsp.exp_iwt(self.omega, self.Nsym))
p = np.abs(list(filt(buf))) ** 2
p = np.cumsum(p)[self.Nsym-1:]
p = np.concatenate([[0], p])
length = (duration - 1) * self.Nsym
correlations = np.abs(p[length:] - p[:-length])
offset = np.argmax(correlations)
return offset
def estimate(self, buf, skip=5):
filt = dsp.exp_iwt(-self.omega, self.Nsym) / (0.5 * self.Nsym)
frames = common.iterate(buf, self.Nsym)
symbols = [np.dot(filt, frame) for frame in frames]
symbols = np.array(symbols[skip:-skip])
amplitude = np.mean(np.abs(symbols))
log.debug('Carrier symbols amplitude : %.3f', amplitude)
phase = np.unwrap(np.angle(symbols)) / (2 * np.pi)
indices = np.arange(len(phase))
a, b = dsp.linear_regression(indices, phase)
self.plt.figure()
self.plt.plot(indices, phase, ':')
self.plt.plot(indices, a * indices + b)
freq_err = a / (self.Tsym * self.freq)
last_phase = a * indices[-1] + b
log.debug('Current phase on carrier: %.3f', last_phase)
log.debug('Frequency error: %.2f ppm', freq_err * 1e6)
self.plt.title('Frequency drift: {0:.3f} ppm'.format(freq_err * 1e6))
return amplitude, freq_err

View File

@@ -19,7 +19,7 @@ class Receiver(object):
def __init__(self, config, pylab=None):
self.stats = {}
self.plt = pylab or common.Dummy()
self.plt = pylab
self.modem = dsp.MODEM(config.symbols)
self.frequencies = np.array(config.frequencies)
self.omegas = 2 * np.pi * self.frequencies / config.Fs
@@ -31,7 +31,7 @@ class Receiver(object):
self.carrier_index = config.carrier_index
self.output_size = 0 # number of bytes written to output stream
def _prefix(self, symbols, gain=1.0, skip=5):
def _prefix(self, symbols, gain=1.0):
S = common.take(symbols, len(equalizer.prefix))
S = S[:, self.carrier_index] * gain
sliced = np.round(np.abs(S))
@@ -45,26 +45,8 @@ class Receiver(object):
self.plt.plot(equalizer.prefix)
if any(bits != equalizer.prefix):
raise ValueError('Incorrect prefix')
log.debug('Prefix OK')
nonzeros = np.array(equalizer.prefix, dtype=bool)
pilot_tone = S[nonzeros]
phase = np.unwrap(np.angle(pilot_tone)) / (2 * np.pi)
indices = np.arange(len(phase))
a, b = dsp.linear_regression(indices[skip:-skip], phase[skip:-skip])
self.plt.figure()
self.plt.plot(indices, phase, ':')
self.plt.plot(indices, a * indices + b)
freq_err = a / (self.Tsym * self.frequencies[self.carrier_index])
last_phase = a * indices[-1] + b
log.debug('Current phase on carrier: %.3f', last_phase)
log.debug('Frequency error: %.2f ppm', freq_err * 1e6)
self.plt.title('Frequency drift: {0:.3f} ppm'.format(freq_err * 1e6))
return freq_err
def _train(self, sampler, order, lookahead):
Nfreq = len(self.frequencies)
equalizer_length = equalizer.equalizer_length
@@ -158,13 +140,11 @@ class Receiver(object):
(1.0 - sampler.freq) * 1e6
)
def run(self, signal, gain, output):
sampler = sampling.Sampler(signal, sampling.Interpolator())
def run(self, sampler, gain, output):
symbols = dsp.Demux(sampler, omegas=self.omegas, Nsym=self.Nsym)
freq_err = self._prefix(symbols, gain=gain)
sampler.freq -= freq_err
self._prefix(symbols, gain=gain)
filt = self._train(sampler, order=11, lookahead=11)
filt = self._train(sampler, order=20, lookahead=20)
sampler.equalizer = lambda x: list(filt(x))
bitstream = self._demodulate(sampler, symbols)
@@ -235,13 +215,22 @@ def main(config, src, dst, dump_audio=None, pylab=None):
log.debug('Skipping %.3f seconds', config.skip_start)
common.take(signal, to_skip)
detector = detect.Detector(config=config)
pylab = pylab or common.Dummy()
detector = detect.Detector(config=config, pylab=pylab)
receiver = Receiver(config=config, pylab=pylab)
success = False
try:
log.info('Waiting for carrier tone: %.1f kHz', config.Fc / 1e3)
signal, amplitude = detector.run(signal)
receiver.run(signal, gain=1.0/amplitude, output=dst)
signal, amplitude, freq_error = detector.run(signal)
freq = 1 / (1.0 + freq_error) # receiver's compensated frequency
log.debug('Frequency correction: %.3f ppm', (freq - 1) * 1e6)
gain = 1.0 / amplitude
log.debug('Gain correction: %.3f', gain)
sampler = sampling.Sampler(signal, sampling.Interpolator(), freq=freq)
receiver.run(sampler, gain=1.0/amplitude, output=dst)
success = True
except Exception:
log.exception('Decoding failed')

View File

@@ -31,8 +31,8 @@ class Interpolator(object):
class Sampler(object):
def __init__(self, src, interp=None):
self.freq = 1.0 # normalized
def __init__(self, src, interp=None, freq=1.0):
self.freq = freq
self.equalizer = lambda x: x # LTI equalization filter
if interp is not None:
self.interp = interp

View File

@@ -7,6 +7,7 @@ from amodem import detect
from amodem import equalizer
from amodem import sampling
from amodem import config
from amodem import common
config = config.fastest()
@@ -15,9 +16,10 @@ def test_detect():
t = np.arange(P * config.Nsym) * config.Ts
x = np.cos(2 * np.pi * config.Fc * t)
detector = detect.Detector(config)
samples, amp = detector.run(x)
detector = detect.Detector(config, pylab=common.Dummy())
samples, amp, freq_err = detector.run(x)
assert abs(1 - amp) < 1e-12
assert abs(freq_err) < 1e-16
x = np.cos(2 * np.pi * (2*config.Fc) * t)
with pytest.raises(ValueError):
@@ -36,9 +38,8 @@ def test_prefix():
def symbols_stream(signal):
sampler = sampling.Sampler(signal)
return dsp.Demux(sampler=sampler, omegas=[omega], Nsym=config.Nsym)
r = recv.Receiver(config)
freq_err = r._prefix(symbols_stream(signal))
assert abs(freq_err) < 1e-16
r = recv.Receiver(config, pylab=common.Dummy())
r._prefix(symbols_stream(signal))
with pytest.raises(ValueError):
silence = 0 * signal
@@ -47,15 +48,14 @@ def test_prefix():
def test_find_start():
sym = np.cos(2 * np.pi * config.Fc * np.arange(config.Nsym) * config.Ts)
detector = detect.Detector(config)
detector = detect.Detector(config, pylab=common.Dummy())
length = 200
prefix = postfix = np.tile(0 * sym, 50)
carrier = np.tile(sym, length)
for offset in range(10):
prefix = [0] * offset
bufs = [prefix, prefix, carrier, postfix]
for offset in range(32):
bufs = [prefix, [0] * offset, carrier, postfix]
buf = np.concatenate(bufs)
start = detector.find_start(buf, length*config.Nsym)
start = detector.find_start(buf, length)
expected = offset + len(prefix)
assert expected == start

View File

@@ -45,7 +45,8 @@ def run(size, chan=None, df=0, success=True):
rx_data = BytesIO()
d = BytesIO()
result = recv.main(config=config, src=rx_audio, dst=rx_data, dump_audio=d)
result = recv.main(config=config, src=rx_audio, dst=rx_data,
dump_audio=d)
rx_data = rx_data.getvalue()
assert data.startswith(d.getvalue())
@@ -68,8 +69,8 @@ def test_error():
run(1024, chan=lambda x: x[:-skip], success=False)
@pytest.fixture(params=[sign * (10.0 ** exp) for sign in (+1, -1)
for exp in (-1, 0, 1, 2, 3)])
@pytest.fixture(params=[sign * mag for sign in (+1, -1)
for mag in (0.1, 1, 10, 100, 1e3, 2e3)])
def freq_err(request):
return request.param * 1e-6