refactor Filter at sigproc

This commit is contained in:
Roman Zeyde
2014-07-19 15:23:12 +03:00
parent 500f956c43
commit 6d46793770

View File

@@ -4,26 +4,26 @@ from numpy import linalg
import common
def lfilter(b, a, x):
b = np.array(b) / a[0]
a = np.array(a[1:]) / a[0]
class Filter(object):
def __init__(self, b, a):
self.b = np.array(b) / a[0]
self.a = np.array(a[1:]) / a[0]
x_ = [0] * len(b)
y_ = [0] * len(a)
for v in x:
x_ = [v] + x_[:-1]
u = np.dot(x_, b)
u = u - np.dot(y_, a)
y_ = [u] + y_[1:]
yield u
def __call__(self, x):
x_ = [0] * len(self.b)
y_ = [0] * len(self.a)
for v in x:
x_ = [v] + x_[:-1]
y = np.dot(x_, self.b) - np.dot(y_, self.a)
y_ = [y] + y_[1:]
yield y
def train(S, training):
A = np.array([S[1:], S[:-1], training[:-1]]).T
b = training[1:]
b0, b1, a1 = linalg.lstsq(A, b)[0]
return lambda x: lfilter(b=[b0, b1], a=[1, -a1], x=x)
return Filter(b=[b0, b1], a=[1, -a1])
class QAM(object):
@@ -63,7 +63,8 @@ class QAM(object):
modulator = QAM(common.symbols)
modem_bps = common.baud * modulator.bits_per_symbol * len(common.frequencies)
bits_per_baud = modulator.bits_per_symbol * len(common.frequencies)
modem_bps = common.baud * bits_per_baud
def clip(x, lims):