diff --git a/amodem/sigproc.py b/amodem/sigproc.py index ddd6751..17de93c 100644 --- a/amodem/sigproc.py +++ b/amodem/sigproc.py @@ -76,14 +76,16 @@ class QAM(object): imags = np.array(list(sorted(set(symbols.imag)))) self.real_factor = 1.0 / np.mean(np.diff(reals)) self.imag_factor = 1.0 / np.mean(np.diff(imags)) - self.real_offset = reals[0] - self.imag_offset = imags[0] + self.bias = reals[0] + 1j * imags[0] self.symbols_map = {} for S in symbols: - real_index = round(S.real * self.real_factor + self.real_offset) - imag_index = round(S.imag * self.imag_factor + self.imag_offset) + s = S - self.bias + real_index = round(s.real * self.real_factor) + imag_index = round(s.imag * self.imag_factor) self.symbols_map[real_index, imag_index] = (S, self._dec[S]) + self.real_max = max(k[0] for k in self.symbols_map) + self.imag_max = max(k[1] for k in self.symbols_map) def encode(self, bits): for _, bits_tuple in common.iterate(bits, self.bits_per_symbol, tuple): @@ -91,16 +93,18 @@ class QAM(object): def decode(self, symbols, error_handler=None): real_factor = self.real_factor - real_offset = self.real_offset - imag_factor = self.imag_factor - imag_offset = self.imag_offset + real_max = self.real_max + imag_max = self.imag_max + bias = self.bias symbols_map = self.symbols_map for S in symbols: - real_index = round(S.real * real_factor + real_offset) - imag_index = round(S.imag * imag_factor + imag_offset) - decoded_symbol, bits = symbols_map[real_index, imag_index] + s = S - bias + real_index = min(max(s.real * real_factor, 0), real_max) + imag_index = min(max(s.imag * imag_factor, 0), imag_max) + key = (round(real_index), round(imag_index)) + decoded_symbol, bits = symbols_map[key] if error_handler: error_handler(received=S, decoded=decoded_symbol) yield bits diff --git a/tests/test_sigproc.py b/tests/test_sigproc.py index d4608b2..bc3cd87 100644 --- a/tests/test_sigproc.py +++ b/tests/test_sigproc.py @@ -13,10 +13,29 @@ def test_qam(): m = q.bits_per_symbol bits = [tuple(r.randint(0, 1) for j in range(m)) for i in range(1024)] stream = itertools.chain(*bits) - S = q.encode(list(stream)) - decoded = list(q.decode(list(S))) + S = list(q.encode(list(stream))) + decoded = list(q.decode(S)) assert decoded == bits + noise = lambda A: A*(r.uniform(-1, 1) + 1j*r.uniform(-1, 1)) + noised_symbols = [(s + noise(1e-3)) for s in S] + decoded = list(q.decode(noised_symbols)) + assert decoded == bits + + +def quantize(q, s): + bits, = list(q.decode([s])) + r, = q.encode(bits) + index = np.argmin(np.abs(s - q.symbols)) + expected = q.symbols[index] + assert r == expected + +def test_overflow(): + q = sigproc.QAM(config.symbols) + r = np.random.RandomState(seed=0) + for i in range(10000): + s = 10*(r.normal() + 1j * r.normal()) + quantize(q, s) def test_linreg(): x = np.array([1, 3, 2, 8, 4, 6, 9, 7, 0, 5])