detect: refactor receiver for large frequency drifts (~0.1%)

This commit is contained in:
Roman Zeyde
2015-01-15 18:14:16 +02:00
parent 5401206178
commit 1da258ebf8
5 changed files with 74 additions and 53 deletions

View File

@@ -20,13 +20,14 @@ class Detector(object):
TIMEOUT = 10.0 # [seconds] TIMEOUT = 10.0 # [seconds]
def __init__(self, config): def __init__(self, config, pylab):
self.freq = config.Fc self.freq = config.Fc
self.omega = 2 * np.pi * self.freq / config.Fs self.omega = 2 * np.pi * self.freq / config.Fs
self.Nsym = config.Nsym self.Nsym = config.Nsym
self.Tsym = config.Tsym self.Tsym = config.Tsym
self.maxlen = config.baud # 1 second of symbols self.maxlen = config.baud # 1 second of symbols
self.max_offset = self.TIMEOUT * config.Fs self.max_offset = self.TIMEOUT * config.Fs
self.plt = pylab
def _wait(self, samples): def _wait(self, samples):
counter = 0 counter = 0
@@ -71,16 +72,46 @@ class Detector(object):
bufs.append(np.array(trailing)) bufs.append(np.array(trailing))
buf = np.concatenate(bufs) buf = np.concatenate(bufs)
offset = self.find_start(buf, self.CARRIER_DURATION*self.Nsym) offset = self.find_start(buf, duration=self.CARRIER_DURATION)
start_time += (offset / self.Nsym - self.SEARCH_WINDOW) * self.Tsym start_time += (offset / self.Nsym - self.SEARCH_WINDOW) * self.Tsym
log.debug('Carrier starts at %.3f ms', start_time * 1e3) log.debug('Carrier starts at %.3f ms', start_time * 1e3)
return itertools.chain(buf[offset:], samples), amplitude buf = buf[offset:]
def find_start(self, buf, length): prefix_length = self.CARRIER_DURATION * self.Nsym
N = len(buf) amplitude, freq_err = self.estimate(buf[:prefix_length])
carrier = dsp.exp_iwt(self.omega, N) return itertools.chain(buf, samples), amplitude, freq_err
z = np.cumsum(buf * carrier)
z = np.concatenate([[0], z]) def find_start(self, buf, duration):
correlations = np.abs(z[length:] - z[:-length]) filt = dsp.FIR(dsp.exp_iwt(self.omega, self.Nsym))
return np.argmax(correlations) p = np.abs(list(filt(buf))) ** 2
p = np.cumsum(p)[self.Nsym-1:]
p = np.concatenate([[0], p])
length = (duration - 1) * self.Nsym
correlations = np.abs(p[length:] - p[:-length])
offset = np.argmax(correlations)
return offset
def estimate(self, buf, skip=5):
filt = dsp.exp_iwt(-self.omega, self.Nsym) / (0.5 * self.Nsym)
frames = common.iterate(buf, self.Nsym)
symbols = [np.dot(filt, frame) for frame in frames]
symbols = np.array(symbols[skip:-skip])
amplitude = np.mean(np.abs(symbols))
log.debug('Carrier symbols amplitude : %.3f', amplitude)
phase = np.unwrap(np.angle(symbols)) / (2 * np.pi)
indices = np.arange(len(phase))
a, b = dsp.linear_regression(indices, phase)
self.plt.figure()
self.plt.plot(indices, phase, ':')
self.plt.plot(indices, a * indices + b)
freq_err = a / (self.Tsym * self.freq)
last_phase = a * indices[-1] + b
log.debug('Current phase on carrier: %.3f', last_phase)
log.debug('Frequency error: %.2f ppm', freq_err * 1e6)
self.plt.title('Frequency drift: {0:.3f} ppm'.format(freq_err * 1e6))
return amplitude, freq_err

View File

@@ -19,7 +19,7 @@ class Receiver(object):
def __init__(self, config, pylab=None): def __init__(self, config, pylab=None):
self.stats = {} self.stats = {}
self.plt = pylab or common.Dummy() self.plt = pylab
self.modem = dsp.MODEM(config.symbols) self.modem = dsp.MODEM(config.symbols)
self.frequencies = np.array(config.frequencies) self.frequencies = np.array(config.frequencies)
self.omegas = 2 * np.pi * self.frequencies / config.Fs self.omegas = 2 * np.pi * self.frequencies / config.Fs
@@ -31,7 +31,7 @@ class Receiver(object):
self.carrier_index = config.carrier_index self.carrier_index = config.carrier_index
self.output_size = 0 # number of bytes written to output stream self.output_size = 0 # number of bytes written to output stream
def _prefix(self, symbols, gain=1.0, skip=5): def _prefix(self, symbols, gain=1.0):
S = common.take(symbols, len(equalizer.prefix)) S = common.take(symbols, len(equalizer.prefix))
S = S[:, self.carrier_index] * gain S = S[:, self.carrier_index] * gain
sliced = np.round(np.abs(S)) sliced = np.round(np.abs(S))
@@ -45,26 +45,8 @@ class Receiver(object):
self.plt.plot(equalizer.prefix) self.plt.plot(equalizer.prefix)
if any(bits != equalizer.prefix): if any(bits != equalizer.prefix):
raise ValueError('Incorrect prefix') raise ValueError('Incorrect prefix')
log.debug('Prefix OK') log.debug('Prefix OK')
nonzeros = np.array(equalizer.prefix, dtype=bool)
pilot_tone = S[nonzeros]
phase = np.unwrap(np.angle(pilot_tone)) / (2 * np.pi)
indices = np.arange(len(phase))
a, b = dsp.linear_regression(indices[skip:-skip], phase[skip:-skip])
self.plt.figure()
self.plt.plot(indices, phase, ':')
self.plt.plot(indices, a * indices + b)
freq_err = a / (self.Tsym * self.frequencies[self.carrier_index])
last_phase = a * indices[-1] + b
log.debug('Current phase on carrier: %.3f', last_phase)
log.debug('Frequency error: %.2f ppm', freq_err * 1e6)
self.plt.title('Frequency drift: {0:.3f} ppm'.format(freq_err * 1e6))
return freq_err
def _train(self, sampler, order, lookahead): def _train(self, sampler, order, lookahead):
Nfreq = len(self.frequencies) Nfreq = len(self.frequencies)
equalizer_length = equalizer.equalizer_length equalizer_length = equalizer.equalizer_length
@@ -158,13 +140,11 @@ class Receiver(object):
(1.0 - sampler.freq) * 1e6 (1.0 - sampler.freq) * 1e6
) )
def run(self, signal, gain, output): def run(self, sampler, gain, output):
sampler = sampling.Sampler(signal, sampling.Interpolator())
symbols = dsp.Demux(sampler, omegas=self.omegas, Nsym=self.Nsym) symbols = dsp.Demux(sampler, omegas=self.omegas, Nsym=self.Nsym)
freq_err = self._prefix(symbols, gain=gain) self._prefix(symbols, gain=gain)
sampler.freq -= freq_err
filt = self._train(sampler, order=11, lookahead=11) filt = self._train(sampler, order=20, lookahead=20)
sampler.equalizer = lambda x: list(filt(x)) sampler.equalizer = lambda x: list(filt(x))
bitstream = self._demodulate(sampler, symbols) bitstream = self._demodulate(sampler, symbols)
@@ -235,13 +215,22 @@ def main(config, src, dst, dump_audio=None, pylab=None):
log.debug('Skipping %.3f seconds', config.skip_start) log.debug('Skipping %.3f seconds', config.skip_start)
common.take(signal, to_skip) common.take(signal, to_skip)
detector = detect.Detector(config=config) pylab = pylab or common.Dummy()
detector = detect.Detector(config=config, pylab=pylab)
receiver = Receiver(config=config, pylab=pylab) receiver = Receiver(config=config, pylab=pylab)
success = False success = False
try: try:
log.info('Waiting for carrier tone: %.1f kHz', config.Fc / 1e3) log.info('Waiting for carrier tone: %.1f kHz', config.Fc / 1e3)
signal, amplitude = detector.run(signal) signal, amplitude, freq_error = detector.run(signal)
receiver.run(signal, gain=1.0/amplitude, output=dst)
freq = 1 / (1.0 + freq_error) # receiver's compensated frequency
log.debug('Frequency correction: %.3f ppm', (freq - 1) * 1e6)
gain = 1.0 / amplitude
log.debug('Gain correction: %.3f', gain)
sampler = sampling.Sampler(signal, sampling.Interpolator(), freq=freq)
receiver.run(sampler, gain=1.0/amplitude, output=dst)
success = True success = True
except Exception: except Exception:
log.exception('Decoding failed') log.exception('Decoding failed')

View File

@@ -31,8 +31,8 @@ class Interpolator(object):
class Sampler(object): class Sampler(object):
def __init__(self, src, interp=None): def __init__(self, src, interp=None, freq=1.0):
self.freq = 1.0 # normalized self.freq = freq
self.equalizer = lambda x: x # LTI equalization filter self.equalizer = lambda x: x # LTI equalization filter
if interp is not None: if interp is not None:
self.interp = interp self.interp = interp

View File

@@ -7,6 +7,7 @@ from amodem import detect
from amodem import equalizer from amodem import equalizer
from amodem import sampling from amodem import sampling
from amodem import config from amodem import config
from amodem import common
config = config.fastest() config = config.fastest()
@@ -15,9 +16,10 @@ def test_detect():
t = np.arange(P * config.Nsym) * config.Ts t = np.arange(P * config.Nsym) * config.Ts
x = np.cos(2 * np.pi * config.Fc * t) x = np.cos(2 * np.pi * config.Fc * t)
detector = detect.Detector(config) detector = detect.Detector(config, pylab=common.Dummy())
samples, amp = detector.run(x) samples, amp, freq_err = detector.run(x)
assert abs(1 - amp) < 1e-12 assert abs(1 - amp) < 1e-12
assert abs(freq_err) < 1e-16
x = np.cos(2 * np.pi * (2*config.Fc) * t) x = np.cos(2 * np.pi * (2*config.Fc) * t)
with pytest.raises(ValueError): with pytest.raises(ValueError):
@@ -36,9 +38,8 @@ def test_prefix():
def symbols_stream(signal): def symbols_stream(signal):
sampler = sampling.Sampler(signal) sampler = sampling.Sampler(signal)
return dsp.Demux(sampler=sampler, omegas=[omega], Nsym=config.Nsym) return dsp.Demux(sampler=sampler, omegas=[omega], Nsym=config.Nsym)
r = recv.Receiver(config) r = recv.Receiver(config, pylab=common.Dummy())
freq_err = r._prefix(symbols_stream(signal)) r._prefix(symbols_stream(signal))
assert abs(freq_err) < 1e-16
with pytest.raises(ValueError): with pytest.raises(ValueError):
silence = 0 * signal silence = 0 * signal
@@ -47,15 +48,14 @@ def test_prefix():
def test_find_start(): def test_find_start():
sym = np.cos(2 * np.pi * config.Fc * np.arange(config.Nsym) * config.Ts) sym = np.cos(2 * np.pi * config.Fc * np.arange(config.Nsym) * config.Ts)
detector = detect.Detector(config) detector = detect.Detector(config, pylab=common.Dummy())
length = 200 length = 200
prefix = postfix = np.tile(0 * sym, 50) prefix = postfix = np.tile(0 * sym, 50)
carrier = np.tile(sym, length) carrier = np.tile(sym, length)
for offset in range(10): for offset in range(32):
prefix = [0] * offset bufs = [prefix, [0] * offset, carrier, postfix]
bufs = [prefix, prefix, carrier, postfix]
buf = np.concatenate(bufs) buf = np.concatenate(bufs)
start = detector.find_start(buf, length*config.Nsym) start = detector.find_start(buf, length)
expected = offset + len(prefix) expected = offset + len(prefix)
assert expected == start assert expected == start

View File

@@ -45,7 +45,8 @@ def run(size, chan=None, df=0, success=True):
rx_data = BytesIO() rx_data = BytesIO()
d = BytesIO() d = BytesIO()
result = recv.main(config=config, src=rx_audio, dst=rx_data, dump_audio=d) result = recv.main(config=config, src=rx_audio, dst=rx_data,
dump_audio=d)
rx_data = rx_data.getvalue() rx_data = rx_data.getvalue()
assert data.startswith(d.getvalue()) assert data.startswith(d.getvalue())
@@ -68,8 +69,8 @@ def test_error():
run(1024, chan=lambda x: x[:-skip], success=False) run(1024, chan=lambda x: x[:-skip], success=False)
@pytest.fixture(params=[sign * (10.0 ** exp) for sign in (+1, -1) @pytest.fixture(params=[sign * mag for sign in (+1, -1)
for exp in (-1, 0, 1, 2, 3)]) for mag in (0.1, 1, 10, 100, 1e3, 2e3)])
def freq_err(request): def freq_err(request):
return request.param * 1e-6 return request.param * 1e-6