mirror of
https://github.com/romanz/amodem.git
synced 2026-04-01 17:26:49 +08:00
ecc.decode() should generate chunks
This commit is contained in:
60
ecc.py
60
ecc.py
@@ -1,57 +1,49 @@
|
||||
''' Reed-Solomon CODEC. '''
|
||||
from reedsolo import rs_encode_msg, rs_correct_msg, ReedSolomonError
|
||||
from reedsolo import rs_encode_msg, rs_correct_msg
|
||||
|
||||
import struct
|
||||
import logging
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
import common
|
||||
|
||||
DEFAULT_NSYM = 10
|
||||
BLOCK_SIZE = 255
|
||||
|
||||
LEN_FMT = '<I'
|
||||
|
||||
def end_of_stream(size):
|
||||
return bytearray([BLOCK_SIZE]) + b'\x00' * size
|
||||
|
||||
|
||||
def encode(data, nsym=DEFAULT_NSYM):
|
||||
log.debug('Encoded {} bytes'.format(len(data)))
|
||||
data = bytearray(struct.pack(LEN_FMT, len(data)) + data)
|
||||
chunk_size = BLOCK_SIZE - nsym
|
||||
chunk_size = BLOCK_SIZE - nsym - 1
|
||||
|
||||
enc = bytearray()
|
||||
for i in range(0, len(data), chunk_size):
|
||||
chunk = data[i:i+chunk_size]
|
||||
if len(chunk) < chunk_size:
|
||||
padding = b'\x00' * (chunk_size - len(chunk))
|
||||
chunk = bytearray(data[i:i+chunk_size])
|
||||
|
||||
size = len(chunk)
|
||||
if size < chunk_size:
|
||||
padding = b'\x00' * (chunk_size - size)
|
||||
chunk.extend(padding)
|
||||
|
||||
chunk = bytearray([size]) + chunk
|
||||
enc.extend(rs_encode_msg(chunk, nsym))
|
||||
|
||||
enc.extend(rs_encode_msg(end_of_stream(chunk_size), nsym))
|
||||
return enc
|
||||
|
||||
|
||||
def decode(data, nsym=DEFAULT_NSYM):
|
||||
data = bytearray(data)
|
||||
dec = bytearray()
|
||||
for i in range(0, len(data), BLOCK_SIZE):
|
||||
chunk = data[i:i+BLOCK_SIZE]
|
||||
try:
|
||||
dec.extend(rs_correct_msg(chunk, nsym))
|
||||
log.debug('Decoded %d blocks = %d bytes',
|
||||
(i+1) / BLOCK_SIZE, len(dec))
|
||||
except ReedSolomonError as e:
|
||||
log.debug('Decoding stopped: %s', e)
|
||||
break
|
||||
|
||||
if not dec:
|
||||
return None
|
||||
last_chunk = end_of_stream(BLOCK_SIZE - nsym - 1)
|
||||
for _, chunk in common.iterate(data, BLOCK_SIZE):
|
||||
chunk = bytearray(rs_correct_msg(chunk, nsym))
|
||||
if chunk == last_chunk:
|
||||
return # end of stream
|
||||
|
||||
overhead = (i - len(dec)) / float(i)
|
||||
blocks = i / BLOCK_SIZE
|
||||
log.debug('Decoded %d blocks = %d bytes (ECC overhead %.1f%%)',
|
||||
blocks, len(dec), overhead * 100)
|
||||
size = chunk[0]
|
||||
chunk = chunk[1:]
|
||||
if size > len(chunk):
|
||||
raise ValueError('Invalid chunk', size, len(chunk), chunk)
|
||||
|
||||
n = struct.calcsize(LEN_FMT)
|
||||
payload, length = dec[n:], dec[:n]
|
||||
length, = struct.unpack(LEN_FMT, length)
|
||||
if length > len(payload):
|
||||
log.warning('%d bytes are missing!', length - len(payload))
|
||||
return None
|
||||
|
||||
return payload[:length]
|
||||
yield chunk[:size]
|
||||
|
||||
12
test_ecc.py
12
test_ecc.py
@@ -1,15 +1,23 @@
|
||||
import ecc
|
||||
import random
|
||||
import itertools
|
||||
|
||||
|
||||
def concat(chunks):
|
||||
return bytearray(itertools.chain.from_iterable(chunks))
|
||||
|
||||
|
||||
def test_random():
|
||||
r = random.Random(0)
|
||||
x = bytearray(r.randrange(0, 256) for i in range(16 * 1024))
|
||||
y = ecc.encode(x)
|
||||
assert len(y) % ecc.BLOCK_SIZE == 0
|
||||
x_ = ecc.decode(y)
|
||||
x_ = concat(ecc.decode(y))
|
||||
assert x_[:len(x)] == x
|
||||
assert all(v == 0 for v in x_[len(x):])
|
||||
|
||||
|
||||
def test_file():
|
||||
data = open('data.send').read()
|
||||
assert ecc.decode(ecc.encode(data)) == data
|
||||
enc = ecc.encode(data)
|
||||
assert concat(ecc.decode(enc)) == data
|
||||
|
||||
Reference in New Issue
Block a user