diff --git a/sigproc.py b/sigproc.py index 776e108..270c1fb 100644 --- a/sigproc.py +++ b/sigproc.py @@ -101,11 +101,11 @@ def extract_symbols(x, freq, offset=0): yield np.dot(Hc, symbol) -def drift(S): - x = np.arange(len(S)) - x = x - np.mean(x) - y = np.unwrap(np.angle(S)) / (2*np.pi) - mean_y = np.mean(y) - y = y - mean_y - a = np.dot(x, y) / np.dot(x, x) - return a, mean_y +def linear_regression(x, y): + ''' Find (a,b) such that y = a*x + b. ''' + x = np.array(x) + y = np.array(y) + ones = np.ones(len(x)) + M = np.array([x, ones]).T + a, b = linalg.lstsq(M, y)[0] + return a, b diff --git a/test_sigproc.py b/test_sigproc.py index 8c4f0f1..5cbdf0d 100644 --- a/test_sigproc.py +++ b/test_sigproc.py @@ -1,12 +1,12 @@ import sigproc import itertools -import common +import config import numpy as np import random def test_qam(): - q = sigproc.QAM(common.symbols) + q = sigproc.QAM(config.symbols) r = random.Random(0) m = q.bits_per_symbol bits = [tuple(r.randint(0, 1) for j in range(m)) for i in range(1024)] @@ -16,12 +16,10 @@ def test_qam(): assert decoded == bits -def test_drift(): - fc = 10e3 - df = 1.23 - f = fc + df - x = np.cos(2 * np.pi * f / common.Fs * np.arange(common.Fs)) - S = sigproc.extract_symbols(x, fc) - S = np.array(list(S)) - df_ = sigproc.drift(S) / common.Tsym - assert abs(df - df_) < 1e-5, (df, df_) +def test_linreg(): + x = np.array([1, 3, 2, 8, 4, 6, 9, 7, 0, 5]) + a, b = 12.3, 4.56 + y = a * x + b + a_, b_ = sigproc.linear_regression(x, y) + assert abs(a - a_) < 1e-10 + assert abs(b - b_) < 1e-10