mirror of
https://github.com/romanz/amodem.git
synced 2026-04-21 22:06:27 +08:00
Add RS ECC
This commit is contained in:
15
common.py
15
common.py
@@ -1,4 +1,6 @@
|
||||
import numpy as np
|
||||
import reedsolo as rs
|
||||
|
||||
import hashlib
|
||||
import struct
|
||||
import logging
|
||||
@@ -21,17 +23,8 @@ SATURATION_THRESHOLD = 1.0
|
||||
|
||||
LENGTH_FORMAT = '<I'
|
||||
|
||||
def pack(data):
|
||||
log.info('Sending {} bytes'.format(len(data)))
|
||||
return data
|
||||
|
||||
def unpack(data):
|
||||
log.info('Received {} bytes'.format(len(data)))
|
||||
return data
|
||||
|
||||
def to_bits(chars):
|
||||
for c in chars:
|
||||
val = ord(c)
|
||||
def to_bits(bytes_list):
|
||||
for val in bytes_list:
|
||||
for i in range(8):
|
||||
mask = 1 << i
|
||||
yield (1 if (val & mask) else 0)
|
||||
|
||||
52
ecc.py
Normal file
52
ecc.py
Normal file
@@ -0,0 +1,52 @@
|
||||
''' Reed-Solomon CODEC. '''
|
||||
from reedsolo import rs_encode_msg, rs_correct_msg, ReedSolomonError
|
||||
|
||||
import struct
|
||||
import logging
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_NSYM = 25
|
||||
BLOCK_SIZE = 255
|
||||
|
||||
LEN_FMT = '<I'
|
||||
|
||||
def encode(data, nsym=DEFAULT_NSYM):
|
||||
log.info('Encoded {} bytes'.format(len(data)))
|
||||
data = bytearray(struct.pack(LEN_FMT, len(data)) + data)
|
||||
chunk_size = BLOCK_SIZE - nsym
|
||||
enc = bytearray()
|
||||
for i in range(0, len(data), chunk_size):
|
||||
chunk = data[i:i+chunk_size]
|
||||
if len(chunk) < chunk_size:
|
||||
padding = b'\x00' * (chunk_size - len(chunk))
|
||||
chunk.extend(padding)
|
||||
enc.extend(rs_encode_msg(chunk, nsym))
|
||||
|
||||
return enc
|
||||
|
||||
def decode(data, nsym=DEFAULT_NSYM):
|
||||
data = bytearray(data)
|
||||
dec = bytearray()
|
||||
for i in range(0, len(data), BLOCK_SIZE):
|
||||
chunk = data[i:i+BLOCK_SIZE]
|
||||
try:
|
||||
dec.extend(rs_correct_msg(chunk, nsym))
|
||||
except ReedSolomonError:
|
||||
break
|
||||
|
||||
n = struct.calcsize(LEN_FMT)
|
||||
payload, length = dec[n:], dec[:n]
|
||||
length, = struct.unpack(LEN_FMT, length)
|
||||
assert length <= len(payload)
|
||||
log.info('Decoded {} bytes'.format(length))
|
||||
return payload[:length]
|
||||
|
||||
|
||||
def test_codec():
|
||||
import os
|
||||
x = bytearray(os.urandom(1024))
|
||||
y = encode(x)
|
||||
assert len(y) % BLOCK_SIZE == 0
|
||||
x_ = decode(y)
|
||||
assert x_[:len(x)] == x
|
||||
assert all(v == 0 for v in x_[len(x):])
|
||||
13
errors.py
13
errors.py
@@ -2,16 +2,15 @@ import common
|
||||
import sys
|
||||
|
||||
tx, rx = sys.argv[1:]
|
||||
tx = open(tx).read()
|
||||
rx = open(rx).read()
|
||||
tx = bytearray(open(tx).read())
|
||||
rx = bytearray(open(rx).read())
|
||||
|
||||
L = min(len(tx), len(rx))
|
||||
rx = list(common.to_bits(rx[:L]))
|
||||
tx = list(common.to_bits(tx[:L]))
|
||||
indices = [index for index, (r, t) in enumerate(zip(rx, tx)) if r != t]
|
||||
|
||||
if indices:
|
||||
total = L*8
|
||||
errors = len(indices)
|
||||
print('{}/{} bit error rate: {:.3f}%'.format(errors, total, (100.0 * errors) / total))
|
||||
sys.exit(1)
|
||||
total = L*8
|
||||
errors = len(indices)
|
||||
print('{}/{} bit error rate: {:.3f}%'.format(errors, total, (100.0 * errors) / total))
|
||||
sys.exit(int(errors > 0))
|
||||
|
||||
4
recv.py
4
recv.py
@@ -9,6 +9,7 @@ logging.basicConfig(level=0, format='%(message)s')
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
import sigproc
|
||||
import ecc
|
||||
import show
|
||||
from common import *
|
||||
|
||||
@@ -173,8 +174,7 @@ def main(t, x):
|
||||
else:
|
||||
data = iterate(data_bits, bufsize=8, advance=8, func=to_bytes)
|
||||
data = ''.join(c for _, c in data)
|
||||
log.info( 'Demodulated {} payload bytes'.format(len(data)) )
|
||||
data = unpack(data)
|
||||
data = ecc.decode(data)
|
||||
with file('data.recv', 'wb') as f:
|
||||
f.write(data)
|
||||
|
||||
|
||||
3
send.py
3
send.py
@@ -12,6 +12,7 @@ import itertools
|
||||
logging.basicConfig(level=0, format='%(message)s')
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
import ecc
|
||||
import sigproc
|
||||
from common import *
|
||||
|
||||
@@ -70,7 +71,7 @@ if __name__ == '__main__':
|
||||
for c in sym.carrier:
|
||||
train(sig, c)
|
||||
|
||||
bits = to_bits(pack(data))
|
||||
bits = to_bits(ecc.encode(data))
|
||||
modulate(sig, bits)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user