mirror of
https://github.com/romanz/amodem.git
synced 2026-02-06 00:36:20 +08:00
refactor QAM object into MODEM
This commit is contained in:
@@ -33,3 +33,10 @@ symbols = symbols / np.max(np.abs(symbols))
|
||||
|
||||
Nsym = int(Tsym / Ts)
|
||||
baud = int(1/Tsym)
|
||||
|
||||
bits_per_symbol = np.log2(Npoints)
|
||||
bits_per_baud = bits_per_symbol * Nfreq
|
||||
modem_bps = baud * bits_per_baud
|
||||
carriers = np.array([
|
||||
np.exp(2j * np.pi * f * np.arange(0, Nsym) * Ts) for f in frequencies
|
||||
])
|
||||
|
||||
@@ -4,8 +4,8 @@ import logging
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
from .config import Ts, Nsym
|
||||
from .qam import QAM
|
||||
from . import config
|
||||
from . import common
|
||||
|
||||
|
||||
class IIR(object):
|
||||
@@ -65,16 +65,17 @@ def estimate(x, y, order, lookahead=0):
|
||||
|
||||
class Demux(object):
|
||||
def __init__(self, sampler, freqs):
|
||||
self.sampler = sampler
|
||||
Nsym = config.Nsym
|
||||
self.filters = [exp_iwt(-f, Nsym) / (0.5*Nsym) for f in freqs]
|
||||
self.filters = np.array(self.filters)
|
||||
self.sampler = sampler
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def next(self):
|
||||
frame = self.sampler.take(size=Nsym)
|
||||
if len(frame) == Nsym:
|
||||
frame = self.sampler.take(size=config.Nsym)
|
||||
if len(frame) == config.Nsym:
|
||||
return np.dot(self.filters, frame)
|
||||
else:
|
||||
raise StopIteration
|
||||
@@ -82,27 +83,8 @@ class Demux(object):
|
||||
__next__ = next
|
||||
|
||||
|
||||
class MODEM(object):
|
||||
def __init__(self, config):
|
||||
self.qam = QAM(config.symbols)
|
||||
self.baud = config.baud
|
||||
self.freqs = config.frequencies
|
||||
self.bits_per_baud = self.qam.bits_per_symbol * len(self.freqs)
|
||||
self.modem_bps = self.baud * self.bits_per_baud
|
||||
self.carriers = np.array([
|
||||
np.exp(2j * np.pi * freq * np.arange(0, Nsym) * Ts)
|
||||
for freq in self.freqs
|
||||
])
|
||||
|
||||
def __repr__(self):
|
||||
return '<{:.3f} kbps, {:d}-QAM, {:d} carriers>'.format(
|
||||
self.modem_bps / 1e3, len(self.qam.symbols), len(self.carriers))
|
||||
|
||||
__str__ = __repr__
|
||||
|
||||
|
||||
def exp_iwt(freq, n):
|
||||
iwt = 2j * np.pi * freq * np.arange(n) * Ts
|
||||
iwt = 2j * np.pi * freq * np.arange(n) * config.Ts
|
||||
return np.exp(iwt)
|
||||
|
||||
|
||||
@@ -128,3 +110,50 @@ def linear_regression(x, y):
|
||||
M = np.array([x, ones]).T
|
||||
a, b = linalg.lstsq(M, y)[0]
|
||||
return a, b
|
||||
|
||||
|
||||
class MODEM(object):
|
||||
|
||||
buf_size = 16
|
||||
|
||||
def __init__(self, symbols):
|
||||
self.encode_map = {}
|
||||
symbols = np.array(list(symbols))
|
||||
bits_per_symbol = np.log2(len(symbols))
|
||||
bits_per_symbol = np.round(bits_per_symbol)
|
||||
N = (2 ** bits_per_symbol)
|
||||
assert N == len(symbols)
|
||||
bits_per_symbol = int(bits_per_symbol)
|
||||
|
||||
for i, v in enumerate(symbols):
|
||||
bits = [int(i & (1 << j) != 0) for j in range(bits_per_symbol)]
|
||||
self.encode_map[tuple(bits)] = v
|
||||
|
||||
self.symbols = symbols
|
||||
self.bits_per_symbol = bits_per_symbol
|
||||
|
||||
bits_map = {symbol: bits for bits, symbol in self.encode_map.items()}
|
||||
self.decode_list = [(s, bits_map[s]) for s in self.symbols]
|
||||
|
||||
def encode(self, bits):
|
||||
for bits_tuple in common.iterate(bits, self.bits_per_symbol, tuple):
|
||||
yield self.encode_map[bits_tuple]
|
||||
|
||||
def decode(self, symbols, error_handler=None):
|
||||
''' Maximum-likelihood decoding, using naive nearest-neighbour. '''
|
||||
symbols_vec = self.symbols
|
||||
_dec = self.decode_list
|
||||
for syms in common.iterate(symbols, self.buf_size, truncate=False):
|
||||
for received in syms:
|
||||
error = np.abs(symbols_vec - received)
|
||||
index = np.argmin(error)
|
||||
decoded, bits = _dec[index]
|
||||
if error_handler:
|
||||
error_handler(received=received, decoded=decoded)
|
||||
yield bits
|
||||
|
||||
def __repr__(self):
|
||||
return '<{:.3f} kbps, {:d}-QAM, {:d} carriers>'.format(
|
||||
config.modem_bps / 1e3, len(self.symbols), len(config.carriers))
|
||||
|
||||
__str__ = __repr__
|
||||
|
||||
@@ -9,7 +9,6 @@ import itertools
|
||||
import random
|
||||
|
||||
_constellation = [1, 1j, -1, -1j]
|
||||
modem = dsp.MODEM(config)
|
||||
|
||||
|
||||
def train_symbols(length, seed=0, Nfreq=config.Nfreq):
|
||||
@@ -18,7 +17,8 @@ def train_symbols(length, seed=0, Nfreq=config.Nfreq):
|
||||
return np.array([choose() for i in range(length)])
|
||||
|
||||
|
||||
def modulator(symbols, carriers=modem.carriers):
|
||||
def modulator(symbols):
|
||||
carriers = config.carriers
|
||||
gain = 1.0 / len(carriers)
|
||||
result = []
|
||||
for s in symbols:
|
||||
@@ -37,7 +37,7 @@ def demodulator(signal, size):
|
||||
def equalize_symbols(signal, symbols, order, lookahead=0):
|
||||
Nsym = config.Nsym
|
||||
Nfreq = config.Nfreq
|
||||
carriers = modem.carriers
|
||||
carriers = config.carriers
|
||||
|
||||
assert symbols.shape[1] == Nfreq
|
||||
length = symbols.shape[0]
|
||||
|
||||
@@ -1,42 +0,0 @@
|
||||
import numpy as np
|
||||
from . import common
|
||||
|
||||
|
||||
class QAM(object):
|
||||
|
||||
buf_size = 16
|
||||
|
||||
def __init__(self, symbols):
|
||||
self.encode_map = {}
|
||||
symbols = np.array(list(symbols))
|
||||
bits_per_symbol = np.log2(len(symbols))
|
||||
bits_per_symbol = np.round(bits_per_symbol)
|
||||
N = (2 ** bits_per_symbol)
|
||||
assert N == len(symbols)
|
||||
bits_per_symbol = int(bits_per_symbol)
|
||||
|
||||
for i, v in enumerate(symbols):
|
||||
bits = [int(i & (1 << j) != 0) for j in range(bits_per_symbol)]
|
||||
self.encode_map[tuple(bits)] = v
|
||||
|
||||
self.symbols = symbols
|
||||
self.bits_per_symbol = bits_per_symbol
|
||||
|
||||
bits_map = {symbol: bits for bits, symbol in self.encode_map.items()}
|
||||
self.decode_list = [(s, bits_map[s]) for s in self.symbols]
|
||||
|
||||
def encode(self, bits):
|
||||
for bits_tuple in common.iterate(bits, self.bits_per_symbol, tuple):
|
||||
yield self.encode_map[bits_tuple]
|
||||
|
||||
def decode(self, symbols, error_handler=None):
|
||||
symbols_vec = self.symbols
|
||||
_dec = self.decode_list
|
||||
for syms in common.iterate(symbols, self.buf_size, truncate=False):
|
||||
for received in syms:
|
||||
error = np.abs(symbols_vec - received)
|
||||
index = np.argmin(error)
|
||||
decoded, bits = _dec[index]
|
||||
if error_handler:
|
||||
error_handler(received=received, decoded=decoded)
|
||||
yield bits
|
||||
@@ -16,11 +16,11 @@ from . import config
|
||||
from . import framing
|
||||
from . import equalizer
|
||||
|
||||
modem = dsp.MODEM(config)
|
||||
modem = dsp.MODEM(config.symbols)
|
||||
|
||||
# Plots' size (WIDTH x HEIGHT)
|
||||
HEIGHT = np.floor(np.sqrt(len(modem.freqs)))
|
||||
WIDTH = np.ceil(len(modem.freqs) / float(HEIGHT))
|
||||
HEIGHT = np.floor(np.sqrt(config.Nfreq))
|
||||
WIDTH = np.ceil(config.Nfreq / float(HEIGHT))
|
||||
|
||||
COHERENCE_THRESHOLD = 0.99
|
||||
|
||||
@@ -187,7 +187,7 @@ class Receiver(object):
|
||||
symbol_list.append(equalized)
|
||||
|
||||
freq_handler = functools.partial(error_handler, freq=freq)
|
||||
bits = modem.qam.decode(S, freq_handler) # list of bit tuples
|
||||
bits = modem.decode(S, freq_handler) # list of bit tuples
|
||||
streams.append(bits) # stream per frequency
|
||||
|
||||
self.stats['symbol_list'] = symbol_list
|
||||
@@ -238,7 +238,7 @@ class Receiver(object):
|
||||
def report(self):
|
||||
if self.stats:
|
||||
duration = time.time() - self.stats['rx_start']
|
||||
audio_time = self.stats['rx_bits'] / float(modem.modem_bps)
|
||||
audio_time = self.stats['rx_bits'] / float(config.modem_bps)
|
||||
log.debug('Demodulated %.3f kB @ %.3f seconds (%.1f%% realtime)',
|
||||
self.stats['rx_bits'] / 8e3, duration,
|
||||
100 * duration / audio_time if audio_time else 0)
|
||||
@@ -248,9 +248,9 @@ class Receiver(object):
|
||||
|
||||
self.plt.figure()
|
||||
symbol_list = np.array(self.stats['symbol_list'])
|
||||
for i, freq in enumerate(modem.freqs):
|
||||
for i, freq in enumerate(config.frequencies):
|
||||
self.plt.subplot(HEIGHT, WIDTH, i+1)
|
||||
self._constellation(symbol_list[i], modem.qam.symbols,
|
||||
self._constellation(symbol_list[i], config.symbols,
|
||||
'$F_c = {} Hz$'.format(freq))
|
||||
self.plt.show()
|
||||
|
||||
@@ -263,6 +263,7 @@ class Receiver(object):
|
||||
self.plt.plot(points.real, points.imag, '+')
|
||||
self.plt.grid('on')
|
||||
self.plt.axis('equal')
|
||||
self.plt.axis(np.array([-1, 1, -1, 1])*1.1)
|
||||
self.plt.title(title)
|
||||
|
||||
|
||||
@@ -271,7 +272,7 @@ def main(args):
|
||||
signal = itertools.chain.from_iterable(reader)
|
||||
|
||||
skipped = common.take(signal, args.skip)
|
||||
log.debug('Skipping %.3f seconds', len(skipped) / float(modem.baud))
|
||||
log.debug('Skipping %.3f seconds', len(skipped) / float(config.baud))
|
||||
|
||||
reader.check = common.check_saturation
|
||||
|
||||
@@ -280,7 +281,7 @@ def main(args):
|
||||
try:
|
||||
log.info('Waiting for carrier tone: %.1f kHz', config.Fc / 1e3)
|
||||
signal, amplitude = detect(signal, config.Fc)
|
||||
receiver.start(signal, modem.freqs, gain=1.0/amplitude)
|
||||
receiver.start(signal, config.frequencies, gain=1.0/amplitude)
|
||||
receiver.run(args.output)
|
||||
success = True
|
||||
except Exception:
|
||||
|
||||
@@ -9,12 +9,12 @@ from . import wave
|
||||
|
||||
from . import common
|
||||
from . import config
|
||||
from . import dsp
|
||||
from . import stream
|
||||
from . import framing
|
||||
from . import equalizer
|
||||
from . import dsp
|
||||
|
||||
modem = dsp.MODEM(config)
|
||||
modem = dsp.MODEM(config.symbols)
|
||||
|
||||
|
||||
class Writer(object):
|
||||
@@ -29,7 +29,7 @@ class Writer(object):
|
||||
self.offset += len(data)
|
||||
|
||||
def start(self):
|
||||
carrier = modem.carriers[config.carrier_index]
|
||||
carrier = config.carriers[config.carrier_index]
|
||||
for value in train.prefix:
|
||||
self.write(carrier * value)
|
||||
|
||||
@@ -41,10 +41,10 @@ class Writer(object):
|
||||
self.write(silence)
|
||||
|
||||
def modulate(self, bits):
|
||||
padding = [0] * modem.bits_per_baud
|
||||
padding = [0] * config.bits_per_baud
|
||||
bits = itertools.chain(bits, padding)
|
||||
symbols_iter = modem.qam.encode(bits)
|
||||
carriers = modem.carriers / config.Nfreq
|
||||
symbols_iter = modem.encode(bits)
|
||||
carriers = config.carriers / config.Nfreq
|
||||
for i, symbols in common.iterate(symbols_iter,
|
||||
size=config.Nfreq, enumerate=True):
|
||||
symbols = np.array(list(symbols))
|
||||
@@ -52,7 +52,7 @@ class Writer(object):
|
||||
|
||||
data_duration = (i / config.Nfreq + 1) * config.Tsym
|
||||
if data_duration % 1 == 0:
|
||||
bits_size = data_duration * modem.modem_bps
|
||||
bits_size = data_duration * config.modem_bps
|
||||
log.debug('Sent %8.1f kB', bits_size / 8e3)
|
||||
|
||||
|
||||
|
||||
@@ -17,10 +17,7 @@ from amodem import recv
|
||||
from amodem import send
|
||||
from amodem import wave
|
||||
from amodem import calib
|
||||
from amodem import dsp
|
||||
|
||||
|
||||
modem = dsp.MODEM(config)
|
||||
null = open('/dev/null', 'wb')
|
||||
|
||||
|
||||
@@ -50,7 +47,7 @@ def FileType(mode, process=None):
|
||||
def main():
|
||||
fmt = ('Audio OFDM MODEM: {:.1f} kb/s ({:d}-QAM x {:d} carriers) '
|
||||
'Fs={:.1f} kHz')
|
||||
description = fmt.format(modem.modem_bps / 1e3, len(config.symbols),
|
||||
description = fmt.format(config.modem_bps / 1e3, len(config.symbols),
|
||||
config.Nfreq, config.Fs / 1e3)
|
||||
p = argparse.ArgumentParser(description=description)
|
||||
g = p.add_mutually_exclusive_group()
|
||||
|
||||
@@ -5,6 +5,9 @@ from amodem import dsp
|
||||
from amodem import config
|
||||
from amodem import sampling
|
||||
|
||||
import random
|
||||
import itertools
|
||||
|
||||
|
||||
def test_linreg():
|
||||
x = np.array([1, 3, 2, 8, 4, 6, 9, 7, 0, 5])
|
||||
@@ -63,3 +66,35 @@ def test_demux():
|
||||
res = dsp.Demux(sampling.Sampler(sig.real), freqs)
|
||||
res = np.array(list(res))
|
||||
assert np.max(np.abs(res - syms)) < 1e-12
|
||||
|
||||
|
||||
def test_qam():
|
||||
q = dsp.MODEM(config.symbols)
|
||||
r = random.Random(0)
|
||||
m = q.bits_per_symbol
|
||||
bits = [tuple(r.randint(0, 1) for j in range(m)) for i in range(1024)]
|
||||
stream = itertools.chain(*bits)
|
||||
S = list(q.encode(list(stream)))
|
||||
decoded = list(q.decode(S))
|
||||
assert decoded == bits
|
||||
|
||||
noise = lambda A: A*(r.uniform(-1, 1) + 1j*r.uniform(-1, 1))
|
||||
noised_symbols = [(s + noise(1e-3)) for s in S]
|
||||
decoded = list(q.decode(noised_symbols))
|
||||
assert decoded == bits
|
||||
|
||||
|
||||
def quantize(q, s):
|
||||
bits, = list(q.decode([s]))
|
||||
r, = q.encode(bits)
|
||||
index = np.argmin(np.abs(s - q.symbols))
|
||||
expected = q.symbols[index]
|
||||
assert r == expected
|
||||
|
||||
|
||||
def test_overflow():
|
||||
q = dsp.MODEM(config.symbols)
|
||||
r = np.random.RandomState(seed=0)
|
||||
for i in range(10000):
|
||||
s = 10*(r.normal() + 1j * r.normal())
|
||||
quantize(q, s)
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
import random
|
||||
import itertools
|
||||
|
||||
import numpy as np
|
||||
|
||||
from amodem import qam
|
||||
from amodem import config
|
||||
|
||||
|
||||
def test_qam():
|
||||
q = qam.QAM(config.symbols)
|
||||
r = random.Random(0)
|
||||
m = q.bits_per_symbol
|
||||
bits = [tuple(r.randint(0, 1) for j in range(m)) for i in range(1024)]
|
||||
stream = itertools.chain(*bits)
|
||||
S = list(q.encode(list(stream)))
|
||||
decoded = list(q.decode(S))
|
||||
assert decoded == bits
|
||||
|
||||
noise = lambda A: A*(r.uniform(-1, 1) + 1j*r.uniform(-1, 1))
|
||||
noised_symbols = [(s + noise(1e-3)) for s in S]
|
||||
decoded = list(q.decode(noised_symbols))
|
||||
assert decoded == bits
|
||||
|
||||
|
||||
def quantize(q, s):
|
||||
bits, = list(q.decode([s]))
|
||||
r, = q.encode(bits)
|
||||
index = np.argmin(np.abs(s - q.symbols))
|
||||
expected = q.symbols[index]
|
||||
assert r == expected
|
||||
|
||||
|
||||
def test_overflow():
|
||||
q = qam.QAM(config.symbols)
|
||||
r = np.random.RandomState(seed=0)
|
||||
for i in range(10000):
|
||||
s = 10*(r.normal() + 1j * r.normal())
|
||||
quantize(q, s)
|
||||
Reference in New Issue
Block a user