refactor recv

This commit is contained in:
Roman Zeyde
2014-07-04 18:26:28 +03:00
parent 2e2ebd6280
commit 50ffdfb1dd

24
recv.py
View File

@@ -9,10 +9,10 @@ log = logging.getLogger(__name__)
import sigproc
from common import *
COHERENCE_THRESHOLD = 0.9
COHERENCE_THRESHOLD = 0.95
CARRIER_DURATION = 300
CARRIER_THRESHOLD = int(0.9 * CARRIER_DURATION)
CARRIER_THRESHOLD = int(0.95 * CARRIER_DURATION)
def power(x):
return np.dot(x.conj(), x).real / len(x)
@@ -29,7 +29,8 @@ def coherence(x, freq):
def detect(x, freq):
counter = 0
for offset, coeff in iterate(x, Nsym, advance=Nsym, func=lambda x: coherence(x, Fc)):
for offset, buf in iterate(x, Nsym, advance=Nsym):
coeff = coherence(buf, Fc)
if abs(coeff) > COHERENCE_THRESHOLD:
counter += 1
else:
@@ -48,7 +49,7 @@ def find_start(x, start):
Hc = exp_iwt(Fc, len(x_))
P = np.abs(Hc.conj() * x_) ** 2
cumsumP = P.cumsum()
start = np.argmax(cumsumP[length:] - cumsumP[:-length]) + begin
start = begin + np.argmax(cumsumP[length:] - cumsumP[:-length])
log.info('Carrier starts at {:.3f} ms'.format(start * Tsym * 1e3 / Nsym))
return start
@@ -60,15 +61,14 @@ def extract_symbols(x, freq, offset=0):
def demodulate(x, freq, filt):
S = extract_symbols(x, freq)
S = np.array(list(filt.apply(S)))
#constellation(S)
S = filt(S)
for bits in sigproc.modulator.decode(S): # list of bit tuples
yield bits
def equalize(x, freqs):
def receive(x, freqs):
prefix = [1]*300 + [0]*100
symbols = list(itertools.islice(extract_symbols(x, Fc), len(prefix)))
bits = np.round(np.abs(symbols))
symbols = itertools.islice(extract_symbols(x, Fc), len(prefix))
bits = np.round(np.abs(list(symbols)))
bits = np.array(bits, dtype=int)
if all(bits[:len(prefix)] != prefix):
return None
@@ -80,10 +80,10 @@ def equalize(x, freqs):
training = ([1]*10 + [0]*10)*20 + [0]*100
S = list(itertools.islice(extract_symbols(x, freq), len(training)))
filt = sigproc.Filter.train(S, training)
filt = sigproc.train(S, training)
filters[freq] = filt
S = list(filt.apply(S))
S = list(filt(S))
y = np.array(S).real
train_result = y > 0.5
@@ -149,7 +149,7 @@ def main(fname):
if peak > SATURATION_THRESHOLD:
raise ValueError('Saturation detected: {:.3f}'.format(peak))
data_bits = equalize(x / amp, frequencies)
data_bits = receive(x / amp, frequencies)
if data_bits is None:
log.info('Cannot demodulate symbols!')
else: