mirror of
https://github.com/romanz/amodem.git
synced 2026-04-01 08:46:49 +08:00
refactor pylab usage at receiver
This commit is contained in:
355
amodem/recv.py
355
amodem/recv.py
@@ -4,7 +4,6 @@ import itertools
|
||||
import functools
|
||||
import collections
|
||||
import time
|
||||
import os
|
||||
|
||||
import bitarray
|
||||
|
||||
@@ -21,11 +20,6 @@ from . import equalizer
|
||||
|
||||
modem = dsp.MODEM(config)
|
||||
|
||||
if os.environ.get('PYLAB') == '1':
|
||||
import pylab
|
||||
else:
|
||||
pylab = common.Dummy()
|
||||
|
||||
# Plots' size (WIDTH x HEIGHT)
|
||||
HEIGHT = np.floor(np.sqrt(len(modem.freqs)))
|
||||
WIDTH = np.ceil(len(modem.freqs) / float(HEIGHT))
|
||||
@@ -92,75 +86,176 @@ def find_start(buf, length):
|
||||
return np.argmax(correlations)
|
||||
|
||||
|
||||
def receive_prefix(sampler, freq, gain=1.0, skip=5):
|
||||
symbols = dsp.Demux(sampler, [freq])
|
||||
S = common.take(symbols, len(train.prefix)).squeeze() * gain
|
||||
sliced = np.round(S)
|
||||
pylab.figure()
|
||||
constellation(S, sliced, 'Prefix')
|
||||
class Receiver(object):
|
||||
|
||||
bits = np.array(np.abs(sliced), dtype=int)
|
||||
if any(bits != train.prefix):
|
||||
raise ValueError('Incorrect prefix')
|
||||
def __init__(self, pylab=None):
|
||||
self.stats = {}
|
||||
self.pylab = pylab or common.Dummy()
|
||||
|
||||
log.info('Prefix OK')
|
||||
def _prefix(self, sampler, freq, gain=1.0, skip=5):
|
||||
symbols = dsp.Demux(sampler, [freq])
|
||||
S = common.take(symbols, len(train.prefix)).squeeze() * gain
|
||||
sliced = np.round(S)
|
||||
self.pylab.figure()
|
||||
self._constellation(S, sliced, 'Prefix')
|
||||
|
||||
nonzeros = np.array(train.prefix, dtype=bool)
|
||||
pilot_tone = S[nonzeros]
|
||||
phase = np.unwrap(np.angle(pilot_tone)) / (2 * np.pi)
|
||||
indices = np.arange(len(phase))
|
||||
a, b = dsp.linear_regression(indices[skip:-skip], phase[skip:-skip])
|
||||
pylab.figure()
|
||||
pylab.plot(indices, phase, ':')
|
||||
pylab.plot(indices, a * indices + b)
|
||||
bits = np.array(np.abs(sliced), dtype=int)
|
||||
if any(bits != train.prefix):
|
||||
raise ValueError('Incorrect prefix')
|
||||
|
||||
freq_err = a / (config.Tsym * config.Fc)
|
||||
last_phase = a * indices[-1] + b
|
||||
log.debug('Current phase on carrier: %.3f', last_phase)
|
||||
log.info('Prefix OK')
|
||||
|
||||
log.info('Frequency error: %.2f ppm', freq_err * 1e6)
|
||||
pylab.title('Frequency drift: {:.3f} ppm'.format(freq_err * 1e6))
|
||||
return freq_err
|
||||
nonzeros = np.array(train.prefix, dtype=bool)
|
||||
pilot_tone = S[nonzeros]
|
||||
phase = np.unwrap(np.angle(pilot_tone)) / (2 * np.pi)
|
||||
indices = np.arange(len(phase))
|
||||
a, b = dsp.linear_regression(indices[skip:-skip], phase[skip:-skip])
|
||||
self.pylab.figure()
|
||||
self.pylab.plot(indices, phase, ':')
|
||||
self.pylab.plot(indices, a * indices + b)
|
||||
|
||||
freq_err = a / (config.Tsym * config.Fc)
|
||||
last_phase = a * indices[-1] + b
|
||||
log.debug('Current phase on carrier: %.3f', last_phase)
|
||||
|
||||
log.info('Frequency error: %.2f ppm', freq_err * 1e6)
|
||||
self.pylab.title('Frequency drift: {:.3f} ppm'.format(freq_err * 1e6))
|
||||
return freq_err
|
||||
|
||||
def _train(self, sampler, order, lookahead):
|
||||
train_symbols = equalizer.train_symbols(train.equalizer_length)
|
||||
prefix = postfix = train.silence_length * config.Nsym
|
||||
signal_length = (train.equalizer_length * config.Nsym) + prefix + postfix
|
||||
|
||||
signal = sampler.take(signal_length + lookahead)
|
||||
unequalized = signal[prefix:-postfix]
|
||||
|
||||
coeffs = equalizer.equalize(unequalized, train_symbols, order, lookahead)
|
||||
equalization_filter = dsp.FIR(h=coeffs)
|
||||
equalized = list(equalization_filter(signal))
|
||||
equalized = equalized[prefix+lookahead:-postfix+lookahead]
|
||||
|
||||
symbols = equalizer.demodulator(equalized, train.equalizer_length)
|
||||
sliced = np.array(symbols).round()
|
||||
errors = np.array(sliced - train_symbols, dtype=np.bool)
|
||||
error_rate = errors.sum() / errors.size
|
||||
|
||||
errors = np.array(symbols - train_symbols)
|
||||
rms = lambda x: (np.mean(np.abs(x) ** 2, axis=0) ** 0.5)
|
||||
|
||||
noise_rms = rms(errors)
|
||||
signal_rms = rms(train_symbols)
|
||||
SNRs = 20.0 * np.log10(signal_rms / noise_rms)
|
||||
|
||||
self.pylab.figure()
|
||||
for i, freq, snr in zip(range(config.Nfreq), config.frequencies, SNRs):
|
||||
log.debug('%5.1f kHz: SNR = %5.2f dB', freq / 1e3, snr)
|
||||
self.pylab.subplot(HEIGHT, WIDTH, i+1)
|
||||
self._constellation(symbols[:, i], train_symbols[:, i],
|
||||
'$F_c = {} Hz$'.format(freq))
|
||||
|
||||
assert error_rate == 0, error_rate
|
||||
|
||||
return equalization_filter
|
||||
|
||||
def _demodulate(self, sampler, freqs):
|
||||
streams = []
|
||||
symbol_list = []
|
||||
errors = {}
|
||||
|
||||
def error_handler(received, decoded, freq):
|
||||
errors.setdefault(freq, []).append(received / decoded)
|
||||
|
||||
symbols = dsp.Demux(sampler, freqs)
|
||||
generators = common.split(symbols, n=len(freqs))
|
||||
for freq, S in zip(freqs, generators):
|
||||
equalized = []
|
||||
S = common.icapture(S, result=equalized)
|
||||
symbol_list.append(equalized)
|
||||
|
||||
freq_handler = functools.partial(error_handler, freq=freq)
|
||||
bits = modem.qam.decode(S, freq_handler) # list of bit tuples
|
||||
streams.append(bits) # stream per frequency
|
||||
|
||||
self.stats['symbol_list'] = symbol_list
|
||||
self.stats['rx_bits'] = 0
|
||||
self.stats['rx_start'] = time.time()
|
||||
|
||||
log.info('Demodulation started')
|
||||
for i, block in enumerate(izip(streams)): # block per frequency
|
||||
for bits in block:
|
||||
self.stats['rx_bits'] = self.stats['rx_bits'] + len(bits)
|
||||
yield bits
|
||||
|
||||
if i and i % config.baud == 0:
|
||||
err = np.array([e for v in errors.values() for e in v])
|
||||
correction = np.mean(np.angle(err)) / (2*np.pi) if len(err) else 0.0
|
||||
duration = time.time() - self.stats['rx_start']
|
||||
log.debug('%10.1f kB, realtime: %6.2f%%, sampling error: %+.3f%%',
|
||||
self.stats['rx_bits'] / 8e3,
|
||||
duration * 100.0 / (i*config.Tsym),
|
||||
correction * 1e2)
|
||||
errors.clear()
|
||||
sampler.freq -= 0.01 * correction / config.Fc
|
||||
sampler.offset -= correction
|
||||
|
||||
def start(self, signal, freqs, gain=1.0):
|
||||
sampler = sampling.Sampler(signal, sampling.Interpolator())
|
||||
|
||||
freq_err = self._prefix(sampler, freq=freqs[0], gain=gain)
|
||||
sampler.freq -= freq_err
|
||||
|
||||
filt = self._train(sampler, order=11, lookahead=5)
|
||||
sampler.equalizer = lambda x: list(filt(x))
|
||||
|
||||
data_bits = self._demodulate(sampler, freqs)
|
||||
self.bits = itertools.chain.from_iterable(data_bits)
|
||||
|
||||
def decode(self, output):
|
||||
chunks = ecc.decode(_blocks(self.bits))
|
||||
self.size = 0
|
||||
for chunk in chunks:
|
||||
output.write(chunk)
|
||||
self.size = self.size + len(chunk)
|
||||
|
||||
def report(self):
|
||||
duration = time.time() - self.stats['rx_start']
|
||||
audio_time = self.stats['rx_bits'] / float(modem.modem_bps)
|
||||
log.debug('Demodulated %.3f kB @ %.3f seconds (%.1f%% realtime)',
|
||||
self.stats['rx_bits'] / 8e3, duration,
|
||||
100 * duration / audio_time)
|
||||
|
||||
log.info('Received %.3f kB @ %.3f seconds = %.3f kB/s',
|
||||
self.size * 1e-3, duration, self.size * 1e-3 / duration)
|
||||
|
||||
self.pylab.figure()
|
||||
symbol_list = np.array(self.stats['symbol_list'])
|
||||
for i, freq in enumerate(modem.freqs):
|
||||
self.pylab.subplot(HEIGHT, WIDTH, i+1)
|
||||
self._constellation(symbol_list[i], modem.qam.symbols,
|
||||
'$F_c = {} Hz$'.format(freq))
|
||||
|
||||
self.pylab.show()
|
||||
|
||||
def _constellation(self, y, symbols, title):
|
||||
theta = np.linspace(0, 2*np.pi, 1000)
|
||||
y = np.array(y)
|
||||
self.pylab.plot(y.real, y.imag, '.')
|
||||
self.pylab.plot(np.cos(theta), np.sin(theta), ':')
|
||||
points = np.array(symbols)
|
||||
self.pylab.plot(points.real, points.imag, '+')
|
||||
self.pylab.grid('on')
|
||||
self.pylab.axis('equal')
|
||||
self.pylab.title(title)
|
||||
|
||||
|
||||
def train_receiver(sampler, order, lookahead):
|
||||
train_symbols = equalizer.train_symbols(train.equalizer_length)
|
||||
prefix = postfix = train.silence_length * config.Nsym
|
||||
signal_length = (train.equalizer_length * config.Nsym) + prefix + postfix
|
||||
|
||||
signal = sampler.take(signal_length + lookahead)
|
||||
unequalized = signal[prefix:-postfix]
|
||||
|
||||
coeffs = equalizer.equalize(unequalized, train_symbols, order, lookahead)
|
||||
equalization_filter = dsp.FIR(h=coeffs)
|
||||
equalized = list(equalization_filter(signal))
|
||||
equalized = equalized[prefix+lookahead:-postfix+lookahead]
|
||||
|
||||
symbols = equalizer.demodulator(equalized, train.equalizer_length)
|
||||
sliced = np.array(symbols).round()
|
||||
errors = np.array(sliced - train_symbols, dtype=np.bool)
|
||||
error_rate = errors.sum() / errors.size
|
||||
|
||||
errors = np.array(symbols - train_symbols)
|
||||
rms = lambda x: (np.mean(np.abs(x) ** 2, axis=0) ** 0.5)
|
||||
|
||||
noise_rms = rms(errors)
|
||||
signal_rms = rms(train_symbols)
|
||||
SNRs = 20.0 * np.log10(signal_rms / noise_rms)
|
||||
|
||||
pylab.figure()
|
||||
for i, freq, snr in zip(range(config.Nfreq), config.frequencies, SNRs):
|
||||
log.debug('%5.1f kHz: SNR = %5.2f dB', freq / 1e3, snr)
|
||||
pylab.subplot(HEIGHT, WIDTH, i+1)
|
||||
constellation(symbols[:, i], train_symbols[:, i],
|
||||
'$F_c = {} Hz$'.format(freq))
|
||||
|
||||
assert error_rate == 0, error_rate
|
||||
|
||||
return equalization_filter
|
||||
|
||||
|
||||
stats = {}
|
||||
def _blocks(bits):
|
||||
while True:
|
||||
block = bitarray.bitarray(endian='little')
|
||||
block.extend(itertools.islice(bits, 8 * ecc.BLOCK_SIZE))
|
||||
if not block:
|
||||
break
|
||||
yield bytearray(block.tobytes())
|
||||
|
||||
|
||||
def izip(streams):
|
||||
@@ -169,128 +264,26 @@ def izip(streams):
|
||||
yield [next(i) for i in iters]
|
||||
|
||||
|
||||
def demodulate(sampler, freqs):
|
||||
streams = []
|
||||
symbol_list = []
|
||||
errors = {}
|
||||
|
||||
def error_handler(received, decoded, freq):
|
||||
errors.setdefault(freq, []).append(received / decoded)
|
||||
|
||||
symbols = dsp.Demux(sampler, freqs)
|
||||
generators = common.split(symbols, n=len(freqs))
|
||||
for freq, S in zip(freqs, generators):
|
||||
equalized = []
|
||||
S = common.icapture(S, result=equalized)
|
||||
symbol_list.append(equalized)
|
||||
|
||||
freq_handler = functools.partial(error_handler, freq=freq)
|
||||
bits = modem.qam.decode(S, freq_handler) # list of bit tuples
|
||||
streams.append(bits) # stream per frequency
|
||||
|
||||
stats['symbol_list'] = symbol_list
|
||||
stats['rx_bits'] = 0
|
||||
stats['rx_start'] = time.time()
|
||||
|
||||
log.info('Demodulation started')
|
||||
for i, block in enumerate(izip(streams)): # block per frequency
|
||||
for bits in block:
|
||||
stats['rx_bits'] = stats['rx_bits'] + len(bits)
|
||||
yield bits
|
||||
|
||||
if i and i % config.baud == 0:
|
||||
err = np.array([e for v in errors.values() for e in v])
|
||||
correction = np.mean(np.angle(err)) / (2*np.pi) if len(err) else 0.0
|
||||
duration = time.time() - stats['rx_start']
|
||||
log.debug('%10.1f kB, realtime: %6.2f%%, sampling error: %+.3f%%',
|
||||
stats['rx_bits'] / 8e3,
|
||||
duration * 100.0 / (i*config.Tsym),
|
||||
correction * 1e2)
|
||||
errors.clear()
|
||||
sampler.freq -= 0.01 * correction / config.Fc
|
||||
sampler.offset -= correction
|
||||
|
||||
|
||||
def receive(signal, freqs, gain=1.0):
|
||||
sampler = sampling.Sampler(signal, sampling.Interpolator())
|
||||
|
||||
freq_err = receive_prefix(sampler, freq=freqs[0], gain=gain)
|
||||
sampler.freq -= freq_err
|
||||
|
||||
filt = train_receiver(sampler, order=11, lookahead=5)
|
||||
sampler.equalizer = lambda x: list(filt(x))
|
||||
|
||||
data_bits = demodulate(sampler, freqs)
|
||||
return itertools.chain.from_iterable(data_bits)
|
||||
|
||||
|
||||
def decode(bits_iterator):
|
||||
def blocks():
|
||||
while True:
|
||||
bits = itertools.islice(bits_iterator, 8 * ecc.BLOCK_SIZE)
|
||||
block = bitarray.bitarray(endian='little')
|
||||
block.extend(bits)
|
||||
if not block:
|
||||
break
|
||||
yield bytearray(block.tobytes())
|
||||
|
||||
return ecc.decode(blocks())
|
||||
|
||||
|
||||
def iread(fd, skip):
|
||||
reader = stream.Reader(fd, data_type=common.loads)
|
||||
signal = itertools.chain.from_iterable(reader)
|
||||
|
||||
skipped = common.take(signal, skip)
|
||||
log.debug('Skipping %.3f seconds', len(skipped) / float(modem.baud))
|
||||
|
||||
reader.check = common.check_saturation
|
||||
return signal
|
||||
|
||||
|
||||
def main(args):
|
||||
log.info('Running MODEM @ {:.1f} kbps'.format(modem.modem_bps / 1e3))
|
||||
|
||||
signal = iread(args.input, args.skip)
|
||||
reader = stream.Reader(args.input, data_type=common.loads)
|
||||
signal = itertools.chain.from_iterable(reader)
|
||||
|
||||
skipped = common.take(signal, args.skip)
|
||||
log.debug('Skipping %.3f seconds', len(skipped) / float(modem.baud))
|
||||
|
||||
reader.check = common.check_saturation
|
||||
|
||||
size = 0
|
||||
signal, amplitude = detect(signal, config.Fc)
|
||||
bits = receive(signal, modem.freqs, gain=1.0/amplitude)
|
||||
receiver = Receiver(args.pylab)
|
||||
receiver.start(signal, modem.freqs, gain=1.0/amplitude)
|
||||
success = False
|
||||
try:
|
||||
for chunk in decode(bits):
|
||||
args.output.write(chunk)
|
||||
size = size + len(chunk)
|
||||
receiver.decode(args.output)
|
||||
success = True
|
||||
except Exception:
|
||||
log.exception('Decoding failed')
|
||||
|
||||
duration = time.time() - stats['rx_start']
|
||||
audio_time = stats['rx_bits'] / float(modem.modem_bps)
|
||||
log.debug('Demodulated %.3f kB @ %.3f seconds (%.1f%% realtime)',
|
||||
stats['rx_bits'] / 8e3, duration, 100 * duration / audio_time)
|
||||
|
||||
log.info('Received %.3f kB @ %.3f seconds = %.3f kB/s',
|
||||
size * 1e-3, duration, size * 1e-3 / duration)
|
||||
|
||||
pylab.figure()
|
||||
symbol_list = np.array(stats['symbol_list'])
|
||||
for i, freq in enumerate(modem.freqs):
|
||||
pylab.subplot(HEIGHT, WIDTH, i+1)
|
||||
constellation(symbol_list[i], modem.qam.symbols,
|
||||
'$F_c = {} Hz$'.format(freq))
|
||||
|
||||
pylab.show()
|
||||
receiver.report()
|
||||
return success
|
||||
|
||||
|
||||
def constellation(y, symbols, title):
|
||||
theta = np.linspace(0, 2*np.pi, 1000)
|
||||
y = np.array(y)
|
||||
pylab.plot(y.real, y.imag, '.')
|
||||
pylab.plot(np.cos(theta), np.sin(theta), ':')
|
||||
points = np.array(symbols)
|
||||
pylab.plot(points.real, points.imag, '+')
|
||||
pylab.grid('on')
|
||||
pylab.axis('equal')
|
||||
pylab.title(title)
|
||||
|
||||
@@ -8,6 +8,8 @@ import argparse
|
||||
p = argparse.ArgumentParser()
|
||||
p.add_argument('--skip', type=int, default=128,
|
||||
help='skip initial N samples, due to spurious spikes')
|
||||
p.add_argument('--pylab', action='store_true', default=False,
|
||||
help='plot results using pylab module')
|
||||
p.add_argument('-i', '--input', type=argparse.FileType('rb'),
|
||||
default=sys.stdin)
|
||||
p.add_argument('-o', '--output', type=argparse.FileType('wb'),
|
||||
@@ -15,4 +17,7 @@ p.add_argument('-o', '--output', type=argparse.FileType('wb'),
|
||||
args = p.parse_args()
|
||||
|
||||
from amodem.recv import main
|
||||
if args.pylab:
|
||||
import pylab
|
||||
args.pylab = pylab
|
||||
main(args)
|
||||
|
||||
@@ -17,6 +17,7 @@ logging.basicConfig(level=logging.DEBUG,
|
||||
class Args(object):
|
||||
def __init__(self, **kwargs):
|
||||
self.__dict__.update(kwargs)
|
||||
self.pylab = None
|
||||
|
||||
|
||||
def run(size, chan=None, df=0, success=True):
|
||||
|
||||
@@ -26,12 +26,13 @@ def test_prefix():
|
||||
signal = np.concatenate([c * symbol for c in train.prefix])
|
||||
|
||||
sampler = sampling.Sampler(signal)
|
||||
freq_err = recv.receive_prefix(sampler, freq=config.Fc)
|
||||
r = recv.Receiver()
|
||||
freq_err = r._prefix(sampler, freq=config.Fc)
|
||||
assert abs(freq_err) < 1e-16
|
||||
|
||||
try:
|
||||
silence = 0 * signal
|
||||
recv.receive_prefix(sampling.Sampler(silence), freq=config.Fc)
|
||||
r._prefix(sampling.Sampler(silence), freq=config.Fc)
|
||||
assert False
|
||||
except ValueError:
|
||||
pass
|
||||
@@ -52,5 +53,5 @@ def test_find_start():
|
||||
assert expected == start
|
||||
|
||||
|
||||
def test_decode():
|
||||
assert list(recv.decode([])) == []
|
||||
def test_blocks():
|
||||
assert list(recv._blocks([])) == []
|
||||
|
||||
Reference in New Issue
Block a user