qam: use simple ML decoding

This commit is contained in:
Roman Zeyde
2014-09-28 20:29:31 +03:00
parent 09afd32f0b
commit d63a7dbe9d

View File

@@ -7,7 +7,7 @@ class QAM(object):
buf_size = 16
def __init__(self, symbols):
self._enc = {}
self.encode_map = {}
symbols = np.array(list(symbols))
bits_per_symbol = np.log2(len(symbols))
bits_per_symbol = np.round(bits_per_symbol)
@@ -17,43 +17,26 @@ class QAM(object):
for i, v in enumerate(symbols):
bits = [int(i & (1 << j) != 0) for j in range(bits_per_symbol)]
self._enc[tuple(bits)] = v
self.encode_map[tuple(bits)] = v
self._dec = {v: k for k, v in self._enc.items()}
self.symbols = symbols
self.bits_per_symbol = bits_per_symbol
reals = np.array(list(sorted(set(symbols.real))))
imags = np.array(list(sorted(set(symbols.imag))))
_mean = lambda u: float(sum(u))/len(u) if len(u) else 1.0
self.real_factor = 1.0 / _mean(np.diff(reals))
self.imag_factor = 1.0 / _mean(np.diff(imags))
self.bias = reals[0] + 1j * imags[0]
self.symbols_map = {}
for S in symbols:
s = S - self.bias
real_index = round(s.real * self.real_factor)
imag_index = round(s.imag * self.imag_factor)
self.symbols_map[real_index + 1j * imag_index] = (S, self._dec[S])
self.real_max = max(k.real for k in self.symbols_map)
self.imag_max = max(k.imag for k in self.symbols_map)
bits_map = {symbol: bits for bits, symbol in self.encode_map.items()}
self.decode_list = [(s, bits_map[s]) for s in self.symbols]
def encode(self, bits):
for bits_tuple in common.iterate(bits, self.bits_per_symbol, tuple):
yield self._enc[bits_tuple]
yield self.encode_map[bits_tuple]
def decode(self, symbols, error_handler=None):
symbols_map = self.symbols_map
symbols_vec = self.symbols
_dec = self.decode_list
for syms in common.iterate(symbols, self.buf_size, truncate=False):
s = syms - self.bias
real_index = np.clip(s.real * self.real_factor, 0, self.real_max)
imag_index = np.clip(s.imag * self.imag_factor, 0, self.imag_max)
keys = np.round(real_index + 1j * imag_index)
for key, received in zip(keys, syms):
decoded_symbol, bits = symbols_map[key]
for received in syms:
error = np.abs(symbols_vec - received)
index = np.argmin(error)
decoded, bits = _dec[index]
if error_handler:
error_handler(received=received, decoded=decoded_symbol)
error_handler(received=received, decoded=decoded)
yield bits