refactor QAM object into MODEM

This commit is contained in:
Roman Zeyde
2014-09-30 14:13:07 +03:00
parent e94ccdd8de
commit 07e530cee2
9 changed files with 117 additions and 129 deletions

View File

@@ -33,3 +33,10 @@ symbols = symbols / np.max(np.abs(symbols))
Nsym = int(Tsym / Ts)
baud = int(1/Tsym)
bits_per_symbol = np.log2(Npoints)
bits_per_baud = bits_per_symbol * Nfreq
modem_bps = baud * bits_per_baud
carriers = np.array([
np.exp(2j * np.pi * f * np.arange(0, Nsym) * Ts) for f in frequencies
])

View File

@@ -4,8 +4,8 @@ import logging
log = logging.getLogger(__name__)
from .config import Ts, Nsym
from .qam import QAM
from . import config
from . import common
class IIR(object):
@@ -65,16 +65,17 @@ def estimate(x, y, order, lookahead=0):
class Demux(object):
def __init__(self, sampler, freqs):
self.sampler = sampler
Nsym = config.Nsym
self.filters = [exp_iwt(-f, Nsym) / (0.5*Nsym) for f in freqs]
self.filters = np.array(self.filters)
self.sampler = sampler
def __iter__(self):
return self
def next(self):
frame = self.sampler.take(size=Nsym)
if len(frame) == Nsym:
frame = self.sampler.take(size=config.Nsym)
if len(frame) == config.Nsym:
return np.dot(self.filters, frame)
else:
raise StopIteration
@@ -82,27 +83,8 @@ class Demux(object):
__next__ = next
class MODEM(object):
def __init__(self, config):
self.qam = QAM(config.symbols)
self.baud = config.baud
self.freqs = config.frequencies
self.bits_per_baud = self.qam.bits_per_symbol * len(self.freqs)
self.modem_bps = self.baud * self.bits_per_baud
self.carriers = np.array([
np.exp(2j * np.pi * freq * np.arange(0, Nsym) * Ts)
for freq in self.freqs
])
def __repr__(self):
return '<{:.3f} kbps, {:d}-QAM, {:d} carriers>'.format(
self.modem_bps / 1e3, len(self.qam.symbols), len(self.carriers))
__str__ = __repr__
def exp_iwt(freq, n):
iwt = 2j * np.pi * freq * np.arange(n) * Ts
iwt = 2j * np.pi * freq * np.arange(n) * config.Ts
return np.exp(iwt)
@@ -128,3 +110,50 @@ def linear_regression(x, y):
M = np.array([x, ones]).T
a, b = linalg.lstsq(M, y)[0]
return a, b
class MODEM(object):
buf_size = 16
def __init__(self, symbols):
self.encode_map = {}
symbols = np.array(list(symbols))
bits_per_symbol = np.log2(len(symbols))
bits_per_symbol = np.round(bits_per_symbol)
N = (2 ** bits_per_symbol)
assert N == len(symbols)
bits_per_symbol = int(bits_per_symbol)
for i, v in enumerate(symbols):
bits = [int(i & (1 << j) != 0) for j in range(bits_per_symbol)]
self.encode_map[tuple(bits)] = v
self.symbols = symbols
self.bits_per_symbol = bits_per_symbol
bits_map = {symbol: bits for bits, symbol in self.encode_map.items()}
self.decode_list = [(s, bits_map[s]) for s in self.symbols]
def encode(self, bits):
for bits_tuple in common.iterate(bits, self.bits_per_symbol, tuple):
yield self.encode_map[bits_tuple]
def decode(self, symbols, error_handler=None):
''' Maximum-likelihood decoding, using naive nearest-neighbour. '''
symbols_vec = self.symbols
_dec = self.decode_list
for syms in common.iterate(symbols, self.buf_size, truncate=False):
for received in syms:
error = np.abs(symbols_vec - received)
index = np.argmin(error)
decoded, bits = _dec[index]
if error_handler:
error_handler(received=received, decoded=decoded)
yield bits
def __repr__(self):
return '<{:.3f} kbps, {:d}-QAM, {:d} carriers>'.format(
config.modem_bps / 1e3, len(self.symbols), len(config.carriers))
__str__ = __repr__

View File

@@ -9,7 +9,6 @@ import itertools
import random
_constellation = [1, 1j, -1, -1j]
modem = dsp.MODEM(config)
def train_symbols(length, seed=0, Nfreq=config.Nfreq):
@@ -18,7 +17,8 @@ def train_symbols(length, seed=0, Nfreq=config.Nfreq):
return np.array([choose() for i in range(length)])
def modulator(symbols, carriers=modem.carriers):
def modulator(symbols):
carriers = config.carriers
gain = 1.0 / len(carriers)
result = []
for s in symbols:
@@ -37,7 +37,7 @@ def demodulator(signal, size):
def equalize_symbols(signal, symbols, order, lookahead=0):
Nsym = config.Nsym
Nfreq = config.Nfreq
carriers = modem.carriers
carriers = config.carriers
assert symbols.shape[1] == Nfreq
length = symbols.shape[0]

View File

@@ -1,42 +0,0 @@
import numpy as np
from . import common
class QAM(object):
buf_size = 16
def __init__(self, symbols):
self.encode_map = {}
symbols = np.array(list(symbols))
bits_per_symbol = np.log2(len(symbols))
bits_per_symbol = np.round(bits_per_symbol)
N = (2 ** bits_per_symbol)
assert N == len(symbols)
bits_per_symbol = int(bits_per_symbol)
for i, v in enumerate(symbols):
bits = [int(i & (1 << j) != 0) for j in range(bits_per_symbol)]
self.encode_map[tuple(bits)] = v
self.symbols = symbols
self.bits_per_symbol = bits_per_symbol
bits_map = {symbol: bits for bits, symbol in self.encode_map.items()}
self.decode_list = [(s, bits_map[s]) for s in self.symbols]
def encode(self, bits):
for bits_tuple in common.iterate(bits, self.bits_per_symbol, tuple):
yield self.encode_map[bits_tuple]
def decode(self, symbols, error_handler=None):
symbols_vec = self.symbols
_dec = self.decode_list
for syms in common.iterate(symbols, self.buf_size, truncate=False):
for received in syms:
error = np.abs(symbols_vec - received)
index = np.argmin(error)
decoded, bits = _dec[index]
if error_handler:
error_handler(received=received, decoded=decoded)
yield bits

View File

@@ -16,11 +16,11 @@ from . import config
from . import framing
from . import equalizer
modem = dsp.MODEM(config)
modem = dsp.MODEM(config.symbols)
# Plots' size (WIDTH x HEIGHT)
HEIGHT = np.floor(np.sqrt(len(modem.freqs)))
WIDTH = np.ceil(len(modem.freqs) / float(HEIGHT))
HEIGHT = np.floor(np.sqrt(config.Nfreq))
WIDTH = np.ceil(config.Nfreq / float(HEIGHT))
COHERENCE_THRESHOLD = 0.99
@@ -187,7 +187,7 @@ class Receiver(object):
symbol_list.append(equalized)
freq_handler = functools.partial(error_handler, freq=freq)
bits = modem.qam.decode(S, freq_handler) # list of bit tuples
bits = modem.decode(S, freq_handler) # list of bit tuples
streams.append(bits) # stream per frequency
self.stats['symbol_list'] = symbol_list
@@ -238,7 +238,7 @@ class Receiver(object):
def report(self):
if self.stats:
duration = time.time() - self.stats['rx_start']
audio_time = self.stats['rx_bits'] / float(modem.modem_bps)
audio_time = self.stats['rx_bits'] / float(config.modem_bps)
log.debug('Demodulated %.3f kB @ %.3f seconds (%.1f%% realtime)',
self.stats['rx_bits'] / 8e3, duration,
100 * duration / audio_time if audio_time else 0)
@@ -248,9 +248,9 @@ class Receiver(object):
self.plt.figure()
symbol_list = np.array(self.stats['symbol_list'])
for i, freq in enumerate(modem.freqs):
for i, freq in enumerate(config.frequencies):
self.plt.subplot(HEIGHT, WIDTH, i+1)
self._constellation(symbol_list[i], modem.qam.symbols,
self._constellation(symbol_list[i], config.symbols,
'$F_c = {} Hz$'.format(freq))
self.plt.show()
@@ -263,6 +263,7 @@ class Receiver(object):
self.plt.plot(points.real, points.imag, '+')
self.plt.grid('on')
self.plt.axis('equal')
self.plt.axis(np.array([-1, 1, -1, 1])*1.1)
self.plt.title(title)
@@ -271,7 +272,7 @@ def main(args):
signal = itertools.chain.from_iterable(reader)
skipped = common.take(signal, args.skip)
log.debug('Skipping %.3f seconds', len(skipped) / float(modem.baud))
log.debug('Skipping %.3f seconds', len(skipped) / float(config.baud))
reader.check = common.check_saturation
@@ -280,7 +281,7 @@ def main(args):
try:
log.info('Waiting for carrier tone: %.1f kHz', config.Fc / 1e3)
signal, amplitude = detect(signal, config.Fc)
receiver.start(signal, modem.freqs, gain=1.0/amplitude)
receiver.start(signal, config.frequencies, gain=1.0/amplitude)
receiver.run(args.output)
success = True
except Exception:

View File

@@ -9,12 +9,12 @@ from . import wave
from . import common
from . import config
from . import dsp
from . import stream
from . import framing
from . import equalizer
from . import dsp
modem = dsp.MODEM(config)
modem = dsp.MODEM(config.symbols)
class Writer(object):
@@ -29,7 +29,7 @@ class Writer(object):
self.offset += len(data)
def start(self):
carrier = modem.carriers[config.carrier_index]
carrier = config.carriers[config.carrier_index]
for value in train.prefix:
self.write(carrier * value)
@@ -41,10 +41,10 @@ class Writer(object):
self.write(silence)
def modulate(self, bits):
padding = [0] * modem.bits_per_baud
padding = [0] * config.bits_per_baud
bits = itertools.chain(bits, padding)
symbols_iter = modem.qam.encode(bits)
carriers = modem.carriers / config.Nfreq
symbols_iter = modem.encode(bits)
carriers = config.carriers / config.Nfreq
for i, symbols in common.iterate(symbols_iter,
size=config.Nfreq, enumerate=True):
symbols = np.array(list(symbols))
@@ -52,7 +52,7 @@ class Writer(object):
data_duration = (i / config.Nfreq + 1) * config.Tsym
if data_duration % 1 == 0:
bits_size = data_duration * modem.modem_bps
bits_size = data_duration * config.modem_bps
log.debug('Sent %8.1f kB', bits_size / 8e3)

View File

@@ -17,10 +17,7 @@ from amodem import recv
from amodem import send
from amodem import wave
from amodem import calib
from amodem import dsp
modem = dsp.MODEM(config)
null = open('/dev/null', 'wb')
@@ -50,7 +47,7 @@ def FileType(mode, process=None):
def main():
fmt = ('Audio OFDM MODEM: {:.1f} kb/s ({:d}-QAM x {:d} carriers) '
'Fs={:.1f} kHz')
description = fmt.format(modem.modem_bps / 1e3, len(config.symbols),
description = fmt.format(config.modem_bps / 1e3, len(config.symbols),
config.Nfreq, config.Fs / 1e3)
p = argparse.ArgumentParser(description=description)
g = p.add_mutually_exclusive_group()

View File

@@ -5,6 +5,9 @@ from amodem import dsp
from amodem import config
from amodem import sampling
import random
import itertools
def test_linreg():
x = np.array([1, 3, 2, 8, 4, 6, 9, 7, 0, 5])
@@ -63,3 +66,35 @@ def test_demux():
res = dsp.Demux(sampling.Sampler(sig.real), freqs)
res = np.array(list(res))
assert np.max(np.abs(res - syms)) < 1e-12
def test_qam():
q = dsp.MODEM(config.symbols)
r = random.Random(0)
m = q.bits_per_symbol
bits = [tuple(r.randint(0, 1) for j in range(m)) for i in range(1024)]
stream = itertools.chain(*bits)
S = list(q.encode(list(stream)))
decoded = list(q.decode(S))
assert decoded == bits
noise = lambda A: A*(r.uniform(-1, 1) + 1j*r.uniform(-1, 1))
noised_symbols = [(s + noise(1e-3)) for s in S]
decoded = list(q.decode(noised_symbols))
assert decoded == bits
def quantize(q, s):
bits, = list(q.decode([s]))
r, = q.encode(bits)
index = np.argmin(np.abs(s - q.symbols))
expected = q.symbols[index]
assert r == expected
def test_overflow():
q = dsp.MODEM(config.symbols)
r = np.random.RandomState(seed=0)
for i in range(10000):
s = 10*(r.normal() + 1j * r.normal())
quantize(q, s)

View File

@@ -1,39 +0,0 @@
import random
import itertools
import numpy as np
from amodem import qam
from amodem import config
def test_qam():
q = qam.QAM(config.symbols)
r = random.Random(0)
m = q.bits_per_symbol
bits = [tuple(r.randint(0, 1) for j in range(m)) for i in range(1024)]
stream = itertools.chain(*bits)
S = list(q.encode(list(stream)))
decoded = list(q.decode(S))
assert decoded == bits
noise = lambda A: A*(r.uniform(-1, 1) + 1j*r.uniform(-1, 1))
noised_symbols = [(s + noise(1e-3)) for s in S]
decoded = list(q.decode(noised_symbols))
assert decoded == bits
def quantize(q, s):
bits, = list(q.decode([s]))
r, = q.encode(bits)
index = np.argmin(np.abs(s - q.symbols))
expected = q.symbols[index]
assert r == expected
def test_overflow():
q = qam.QAM(config.symbols)
r = np.random.RandomState(seed=0)
for i in range(10000):
s = 10*(r.normal() + 1j * r.normal())
quantize(q, s)