mirror of
https://github.com/romanz/amodem.git
synced 2026-02-24 16:18:12 +08:00
framing: handle bitstream & replace ECC by CRC-32
This commit is contained in:
@@ -8,15 +8,6 @@ log = logging.getLogger(__name__)
|
||||
scaling = 32000.0 # out of 2**15
|
||||
SATURATION_THRESHOLD = (2**15 - 1) / scaling
|
||||
|
||||
LENGTH_FORMAT = '<I'
|
||||
|
||||
|
||||
def to_bits(bytes_list):
|
||||
for val in bytes_list:
|
||||
for i in range(8):
|
||||
mask = 1 << i
|
||||
yield (1 if (val & mask) else 0)
|
||||
|
||||
|
||||
class SaturationError(ValueError):
|
||||
pass
|
||||
|
||||
@@ -1,47 +1,105 @@
|
||||
''' Reed-Solomon CODEC. '''
|
||||
from reedsolo import rs_encode_msg, rs_correct_msg
|
||||
|
||||
from . import common
|
||||
import bitarray
|
||||
|
||||
import functools
|
||||
import itertools
|
||||
import binascii
|
||||
import struct
|
||||
import logging
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_NSYM = 10
|
||||
BLOCK_SIZE = 255
|
||||
_crc32 = lambda x, mask: binascii.crc32(x) & mask
|
||||
# (so the result will be unsigned on Python 2/3)
|
||||
|
||||
|
||||
def end_of_stream(size):
|
||||
return bytearray([BLOCK_SIZE]) + b'\x00' * size
|
||||
class Checksum(object):
|
||||
fmt = '>L' # unsigned longs (32-bit)
|
||||
size = struct.calcsize(fmt)
|
||||
func = functools.partial(_crc32, mask=0xFFFFFFFF)
|
||||
|
||||
def encode(self, payload):
|
||||
checksum = self.func(payload)
|
||||
return struct.pack(self.fmt, checksum) + payload
|
||||
|
||||
def decode(self, data):
|
||||
received, = struct.unpack(self.fmt, data[:self.size])
|
||||
payload = data[self.size:]
|
||||
expected = self.func(payload)
|
||||
if received != expected:
|
||||
raise ValueError('invalid checksum')
|
||||
return payload
|
||||
|
||||
|
||||
def encode(data, nsym=DEFAULT_NSYM):
|
||||
chunk_size = BLOCK_SIZE - nsym - 1
|
||||
class Framer(object):
|
||||
block_size = 1024
|
||||
prefix_fmt = '>L'
|
||||
prefix_len = struct.calcsize(prefix_fmt)
|
||||
checksum = Checksum()
|
||||
|
||||
for _, chunk in common.iterate(data=data, size=chunk_size,
|
||||
func=bytearray, truncate=False):
|
||||
size = len(chunk)
|
||||
if size < chunk_size:
|
||||
padding = [0] * (chunk_size - size)
|
||||
chunk.extend(padding)
|
||||
EOF = b''
|
||||
|
||||
block = bytearray([size]) + chunk
|
||||
yield rs_encode_msg(block, nsym)
|
||||
def _pack(self, block):
|
||||
frame = self.checksum.encode(block)
|
||||
return struct.pack(self.prefix_fmt, len(frame)) + frame
|
||||
|
||||
yield rs_encode_msg(end_of_stream(chunk_size), nsym)
|
||||
def encode(self, data):
|
||||
for _, block in common.iterate(data=data, size=self.block_size,
|
||||
func=bytearray, truncate=False):
|
||||
yield self._pack(block=block)
|
||||
yield self._pack(block=self.EOF)
|
||||
|
||||
def decode(self, data):
|
||||
data = iter(data)
|
||||
while True:
|
||||
length, = self._take_fmt(data, self.prefix_fmt)
|
||||
frame = self._take_len(data, length)
|
||||
block = self.checksum.decode(frame)
|
||||
if block == self.EOF:
|
||||
return
|
||||
|
||||
yield block
|
||||
|
||||
def _take_fmt(self, data, fmt):
|
||||
length = struct.calcsize(fmt)
|
||||
chunk = bytearray(itertools.islice(data, length))
|
||||
if len(chunk) < length:
|
||||
raise StopIteration()
|
||||
return struct.unpack(fmt, chunk)
|
||||
|
||||
def _take_len(self, data, length):
|
||||
chunk = bytearray(itertools.islice(data, length))
|
||||
if len(chunk) < length:
|
||||
raise StopIteration()
|
||||
return chunk
|
||||
|
||||
|
||||
def decode(blocks, nsym=DEFAULT_NSYM):
|
||||
def chain_wrapper(func):
|
||||
@functools.wraps(func)
|
||||
def wrapped(*args, **kwargs):
|
||||
result = func(*args, **kwargs)
|
||||
return itertools.chain.from_iterable(result)
|
||||
return wrapped
|
||||
|
||||
last_chunk = end_of_stream(BLOCK_SIZE - nsym - 1)
|
||||
for block in blocks:
|
||||
assert len(block) == BLOCK_SIZE
|
||||
chunk = bytearray(rs_correct_msg(block, nsym))
|
||||
if chunk == last_chunk:
|
||||
log.info('EOF encountered')
|
||||
return # end of stream
|
||||
|
||||
size = chunk[0]
|
||||
payload = chunk[1:]
|
||||
assert size <= len(payload)
|
||||
@chain_wrapper
|
||||
def encode(data, framer=None):
|
||||
framer = framer or Framer()
|
||||
for frame in framer.encode(data):
|
||||
bits = bitarray.bitarray(endian='little')
|
||||
bits.frombytes(bytes(frame))
|
||||
yield bits
|
||||
|
||||
yield payload[:size]
|
||||
|
||||
@chain_wrapper
|
||||
def _to_bytes(bits, block_size=1):
|
||||
for _, chunk in common.iterate(data=bits, size=8*block_size,
|
||||
func=lambda x: x, truncate=True):
|
||||
yield bitarray.bitarray(chunk, endian='little').tobytes()
|
||||
|
||||
|
||||
@chain_wrapper
|
||||
def decode(bits, framer=None):
|
||||
framer = framer or Framer()
|
||||
for frame in framer.decode(_to_bytes(bits)):
|
||||
yield frame
|
||||
|
||||
@@ -224,15 +224,16 @@ class Receiver(object):
|
||||
filt = self._train(sampler, order=11, lookahead=5)
|
||||
sampler.equalizer = lambda x: list(filt(x))
|
||||
|
||||
data_bits = self._demodulate(sampler, freqs)
|
||||
self.bits = itertools.chain.from_iterable(data_bits)
|
||||
bitstream = self._demodulate(sampler, freqs)
|
||||
self.bitstream = itertools.chain.from_iterable(bitstream)
|
||||
|
||||
def decode(self, output):
|
||||
chunks = framing.decode(_blocks(self.bits))
|
||||
def run(self, output):
|
||||
data = framing.decode(self.bitstream)
|
||||
self.size = 0
|
||||
for chunk in chunks:
|
||||
for _, chunk in common.iterate(data=data, size=256,
|
||||
truncate=False, func=bytearray):
|
||||
output.write(chunk)
|
||||
self.size = self.size + len(chunk)
|
||||
self.size += len(chunk)
|
||||
|
||||
def report(self):
|
||||
if self.stats:
|
||||
@@ -265,15 +266,6 @@ class Receiver(object):
|
||||
self.plt.title(title)
|
||||
|
||||
|
||||
def _blocks(bits):
|
||||
while True:
|
||||
block = bitarray.bitarray(endian='little')
|
||||
block.extend(itertools.islice(bits, 8 * framing.BLOCK_SIZE))
|
||||
if not block:
|
||||
break
|
||||
yield bytearray(block.tobytes())
|
||||
|
||||
|
||||
def izip(streams):
|
||||
iters = [iter(s) for s in streams]
|
||||
while True:
|
||||
@@ -296,7 +288,7 @@ def main(args):
|
||||
try:
|
||||
signal, amplitude = detect(signal, config.Fc)
|
||||
receiver.start(signal, modem.freqs, gain=1.0/amplitude)
|
||||
receiver.decode(args.output)
|
||||
receiver.run(args.output)
|
||||
success = True
|
||||
except Exception:
|
||||
log.exception('Decoding failed')
|
||||
|
||||
@@ -47,15 +47,13 @@ class Writer(object):
|
||||
self.write(silence)
|
||||
|
||||
def modulate(self, bits):
|
||||
padding = [0] * modem.bits_per_baud
|
||||
bits = itertools.chain(bits, padding)
|
||||
symbols_iter = modem.qam.encode(bits)
|
||||
symbols_iter = itertools.chain(symbols_iter, itertools.repeat(0))
|
||||
carriers = modem.carriers / config.Nfreq
|
||||
while True:
|
||||
symbols = itertools.islice(symbols_iter, config.Nfreq)
|
||||
for _, symbols in common.iterate(symbols_iter, size=config.Nfreq):
|
||||
symbols = np.array(list(symbols))
|
||||
self.write(np.dot(symbols, carriers))
|
||||
if all(symbols == 0): # EOF marker
|
||||
break
|
||||
|
||||
|
||||
def main(args):
|
||||
@@ -73,9 +71,10 @@ def main(args):
|
||||
log.info('%.3f seconds of training audio', training_duration)
|
||||
|
||||
reader = stream.Reader(args.input, bufsize=(64 << 10), eof=True)
|
||||
data = itertools.chain.from_iterable(reader)
|
||||
encoded = itertools.chain.from_iterable(framing.encode(data))
|
||||
writer.modulate(bits=common.to_bits(encoded))
|
||||
data = list(itertools.chain.from_iterable(reader))
|
||||
bits = list(framing.encode(data))
|
||||
data_ = list(framing.decode(bits))
|
||||
writer.modulate(bits=bits)
|
||||
|
||||
data_size = writer.offset - training_size
|
||||
log.info('%.3f seconds of data audio, for %.3f kB of data',
|
||||
|
||||
@@ -1,32 +1,35 @@
|
||||
from amodem import framing
|
||||
import random
|
||||
import itertools
|
||||
import reedsolo
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def concat(chunks):
|
||||
return bytearray(itertools.chain.from_iterable(chunks))
|
||||
def concat(iterable):
|
||||
return bytearray(itertools.chain.from_iterable(iterable))
|
||||
|
||||
r = random.Random(0)
|
||||
blob = bytearray(r.randrange(0, 256) for i in range(64 * 1024))
|
||||
|
||||
|
||||
def test_random():
|
||||
r = random.Random(0)
|
||||
x = bytearray(r.randrange(0, 256) for i in range(64 * 1024))
|
||||
y = framing.encode(x)
|
||||
x_ = concat(framing.decode(y))
|
||||
assert x_ == x
|
||||
@pytest.fixture(params=[b'', b'abc', b'1234567890', blob, blob[:12345]])
|
||||
def data(request):
|
||||
return request.param
|
||||
|
||||
|
||||
def test_errors():
|
||||
data = bytearray(range(244))
|
||||
blocks = list(framing.encode(data))
|
||||
assert len(blocks) == 2
|
||||
for i in range(framing.DEFAULT_NSYM // 2):
|
||||
blocks[0][i] = blocks[0][i] ^ 0xFF
|
||||
def test_checksum(data):
|
||||
c = framing.Checksum()
|
||||
assert c.decode(c.encode(data)) == data
|
||||
|
||||
i = framing.DEFAULT_NSYM // 2
|
||||
try:
|
||||
blocks[0][i] = blocks[0][i] ^ 0xFF
|
||||
concat(framing.decode(blocks))
|
||||
assert False
|
||||
except reedsolo.ReedSolomonError as e:
|
||||
assert e.args == ('Too many errors to correct',)
|
||||
|
||||
def test_framer(data):
|
||||
f = framing.Framer()
|
||||
encoded = concat(f.encode(data))
|
||||
decoded = concat(f.decode(encoded))
|
||||
assert decoded == data
|
||||
|
||||
|
||||
def test_main(data):
|
||||
encoded = framing.encode(data)
|
||||
decoded = framing.decode(encoded)
|
||||
assert bytearray(decoded) == data
|
||||
|
||||
@@ -51,8 +51,13 @@ def run(size, chan=None, df=0, success=True):
|
||||
assert rx_data == tx_data
|
||||
|
||||
|
||||
def test_small():
|
||||
run(1024, chan=lambda x: x)
|
||||
@pytest.fixture(params=[0, 1, 3, 10, 42, 123])
|
||||
def small_size(request):
|
||||
return request.param
|
||||
|
||||
|
||||
def test_small(small_size):
|
||||
run(small_size, chan=lambda x: x)
|
||||
|
||||
|
||||
def test_error():
|
||||
|
||||
@@ -51,7 +51,3 @@ def test_find_start():
|
||||
start = recv.find_start(buf, length*config.Nsym)
|
||||
expected = offset + len(prefix)
|
||||
assert expected == start
|
||||
|
||||
|
||||
def test_blocks():
|
||||
assert list(recv._blocks([])) == []
|
||||
|
||||
Reference in New Issue
Block a user