framing: handle bitstream & replace ECC by CRC-32

This commit is contained in:
Roman Zeyde
2014-09-06 14:27:18 +03:00
parent 9cdabd938a
commit 3602831a29
7 changed files with 134 additions and 90 deletions

View File

@@ -8,15 +8,6 @@ log = logging.getLogger(__name__)
scaling = 32000.0 # out of 2**15
SATURATION_THRESHOLD = (2**15 - 1) / scaling
LENGTH_FORMAT = '<I'
def to_bits(bytes_list):
for val in bytes_list:
for i in range(8):
mask = 1 << i
yield (1 if (val & mask) else 0)
class SaturationError(ValueError):
pass

View File

@@ -1,47 +1,105 @@
''' Reed-Solomon CODEC. '''
from reedsolo import rs_encode_msg, rs_correct_msg
from . import common
import bitarray
import functools
import itertools
import binascii
import struct
import logging
log = logging.getLogger(__name__)
DEFAULT_NSYM = 10
BLOCK_SIZE = 255
_crc32 = lambda x, mask: binascii.crc32(x) & mask
# (so the result will be unsigned on Python 2/3)
def end_of_stream(size):
return bytearray([BLOCK_SIZE]) + b'\x00' * size
class Checksum(object):
fmt = '>L' # unsigned longs (32-bit)
size = struct.calcsize(fmt)
func = functools.partial(_crc32, mask=0xFFFFFFFF)
def encode(self, payload):
checksum = self.func(payload)
return struct.pack(self.fmt, checksum) + payload
def decode(self, data):
received, = struct.unpack(self.fmt, data[:self.size])
payload = data[self.size:]
expected = self.func(payload)
if received != expected:
raise ValueError('invalid checksum')
return payload
def encode(data, nsym=DEFAULT_NSYM):
chunk_size = BLOCK_SIZE - nsym - 1
class Framer(object):
block_size = 1024
prefix_fmt = '>L'
prefix_len = struct.calcsize(prefix_fmt)
checksum = Checksum()
for _, chunk in common.iterate(data=data, size=chunk_size,
func=bytearray, truncate=False):
size = len(chunk)
if size < chunk_size:
padding = [0] * (chunk_size - size)
chunk.extend(padding)
EOF = b''
block = bytearray([size]) + chunk
yield rs_encode_msg(block, nsym)
def _pack(self, block):
frame = self.checksum.encode(block)
return struct.pack(self.prefix_fmt, len(frame)) + frame
yield rs_encode_msg(end_of_stream(chunk_size), nsym)
def encode(self, data):
for _, block in common.iterate(data=data, size=self.block_size,
func=bytearray, truncate=False):
yield self._pack(block=block)
yield self._pack(block=self.EOF)
def decode(self, data):
data = iter(data)
while True:
length, = self._take_fmt(data, self.prefix_fmt)
frame = self._take_len(data, length)
block = self.checksum.decode(frame)
if block == self.EOF:
return
yield block
def _take_fmt(self, data, fmt):
length = struct.calcsize(fmt)
chunk = bytearray(itertools.islice(data, length))
if len(chunk) < length:
raise StopIteration()
return struct.unpack(fmt, chunk)
def _take_len(self, data, length):
chunk = bytearray(itertools.islice(data, length))
if len(chunk) < length:
raise StopIteration()
return chunk
def decode(blocks, nsym=DEFAULT_NSYM):
def chain_wrapper(func):
@functools.wraps(func)
def wrapped(*args, **kwargs):
result = func(*args, **kwargs)
return itertools.chain.from_iterable(result)
return wrapped
last_chunk = end_of_stream(BLOCK_SIZE - nsym - 1)
for block in blocks:
assert len(block) == BLOCK_SIZE
chunk = bytearray(rs_correct_msg(block, nsym))
if chunk == last_chunk:
log.info('EOF encountered')
return # end of stream
size = chunk[0]
payload = chunk[1:]
assert size <= len(payload)
@chain_wrapper
def encode(data, framer=None):
framer = framer or Framer()
for frame in framer.encode(data):
bits = bitarray.bitarray(endian='little')
bits.frombytes(bytes(frame))
yield bits
yield payload[:size]
@chain_wrapper
def _to_bytes(bits, block_size=1):
for _, chunk in common.iterate(data=bits, size=8*block_size,
func=lambda x: x, truncate=True):
yield bitarray.bitarray(chunk, endian='little').tobytes()
@chain_wrapper
def decode(bits, framer=None):
framer = framer or Framer()
for frame in framer.decode(_to_bytes(bits)):
yield frame

View File

@@ -224,15 +224,16 @@ class Receiver(object):
filt = self._train(sampler, order=11, lookahead=5)
sampler.equalizer = lambda x: list(filt(x))
data_bits = self._demodulate(sampler, freqs)
self.bits = itertools.chain.from_iterable(data_bits)
bitstream = self._demodulate(sampler, freqs)
self.bitstream = itertools.chain.from_iterable(bitstream)
def decode(self, output):
chunks = framing.decode(_blocks(self.bits))
def run(self, output):
data = framing.decode(self.bitstream)
self.size = 0
for chunk in chunks:
for _, chunk in common.iterate(data=data, size=256,
truncate=False, func=bytearray):
output.write(chunk)
self.size = self.size + len(chunk)
self.size += len(chunk)
def report(self):
if self.stats:
@@ -265,15 +266,6 @@ class Receiver(object):
self.plt.title(title)
def _blocks(bits):
while True:
block = bitarray.bitarray(endian='little')
block.extend(itertools.islice(bits, 8 * framing.BLOCK_SIZE))
if not block:
break
yield bytearray(block.tobytes())
def izip(streams):
iters = [iter(s) for s in streams]
while True:
@@ -296,7 +288,7 @@ def main(args):
try:
signal, amplitude = detect(signal, config.Fc)
receiver.start(signal, modem.freqs, gain=1.0/amplitude)
receiver.decode(args.output)
receiver.run(args.output)
success = True
except Exception:
log.exception('Decoding failed')

View File

@@ -47,15 +47,13 @@ class Writer(object):
self.write(silence)
def modulate(self, bits):
padding = [0] * modem.bits_per_baud
bits = itertools.chain(bits, padding)
symbols_iter = modem.qam.encode(bits)
symbols_iter = itertools.chain(symbols_iter, itertools.repeat(0))
carriers = modem.carriers / config.Nfreq
while True:
symbols = itertools.islice(symbols_iter, config.Nfreq)
for _, symbols in common.iterate(symbols_iter, size=config.Nfreq):
symbols = np.array(list(symbols))
self.write(np.dot(symbols, carriers))
if all(symbols == 0): # EOF marker
break
def main(args):
@@ -73,9 +71,10 @@ def main(args):
log.info('%.3f seconds of training audio', training_duration)
reader = stream.Reader(args.input, bufsize=(64 << 10), eof=True)
data = itertools.chain.from_iterable(reader)
encoded = itertools.chain.from_iterable(framing.encode(data))
writer.modulate(bits=common.to_bits(encoded))
data = list(itertools.chain.from_iterable(reader))
bits = list(framing.encode(data))
data_ = list(framing.decode(bits))
writer.modulate(bits=bits)
data_size = writer.offset - training_size
log.info('%.3f seconds of data audio, for %.3f kB of data',

View File

@@ -1,32 +1,35 @@
from amodem import framing
import random
import itertools
import reedsolo
import pytest
def concat(chunks):
return bytearray(itertools.chain.from_iterable(chunks))
def concat(iterable):
return bytearray(itertools.chain.from_iterable(iterable))
r = random.Random(0)
blob = bytearray(r.randrange(0, 256) for i in range(64 * 1024))
def test_random():
r = random.Random(0)
x = bytearray(r.randrange(0, 256) for i in range(64 * 1024))
y = framing.encode(x)
x_ = concat(framing.decode(y))
assert x_ == x
@pytest.fixture(params=[b'', b'abc', b'1234567890', blob, blob[:12345]])
def data(request):
return request.param
def test_errors():
data = bytearray(range(244))
blocks = list(framing.encode(data))
assert len(blocks) == 2
for i in range(framing.DEFAULT_NSYM // 2):
blocks[0][i] = blocks[0][i] ^ 0xFF
def test_checksum(data):
c = framing.Checksum()
assert c.decode(c.encode(data)) == data
i = framing.DEFAULT_NSYM // 2
try:
blocks[0][i] = blocks[0][i] ^ 0xFF
concat(framing.decode(blocks))
assert False
except reedsolo.ReedSolomonError as e:
assert e.args == ('Too many errors to correct',)
def test_framer(data):
f = framing.Framer()
encoded = concat(f.encode(data))
decoded = concat(f.decode(encoded))
assert decoded == data
def test_main(data):
encoded = framing.encode(data)
decoded = framing.decode(encoded)
assert bytearray(decoded) == data

View File

@@ -51,8 +51,13 @@ def run(size, chan=None, df=0, success=True):
assert rx_data == tx_data
def test_small():
run(1024, chan=lambda x: x)
@pytest.fixture(params=[0, 1, 3, 10, 42, 123])
def small_size(request):
return request.param
def test_small(small_size):
run(small_size, chan=lambda x: x)
def test_error():

View File

@@ -51,7 +51,3 @@ def test_find_start():
start = recv.find_start(buf, length*config.Nsym)
expected = offset + len(prefix)
assert expected == start
def test_blocks():
assert list(recv._blocks([])) == []