framing: refactor a bit

This commit is contained in:
Roman Zeyde
2015-01-08 09:35:38 +02:00
parent 318a0644de
commit 3b1d193b0b

View File

@@ -7,23 +7,22 @@ import struct
import logging import logging
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
_crc32 = lambda x, mask: binascii.crc32(bytes(x)) & mask _checksum_func = lambda x: binascii.crc32(bytes(x)) & 0xFFFFFFFF
# (so the result will be unsigned on Python 2/3) # (so the result will be unsigned on Python 2/3)
class Checksum(object): class Checksum(object):
fmt = '>L' # unsigned longs (32-bit) fmt = '>L' # unsigned longs (32-bit)
size = struct.calcsize(fmt) size = struct.calcsize(fmt)
func = functools.partial(_crc32, mask=0xFFFFFFFF)
def encode(self, payload): def encode(self, payload):
checksum = self.func(payload) checksum = _checksum_func(payload)
return struct.pack(self.fmt, checksum) + payload return struct.pack(self.fmt, checksum) + payload
def decode(self, data): def decode(self, data):
received, = struct.unpack(self.fmt, bytes(data[:self.size])) received, = struct.unpack(self.fmt, bytes(data[:self.size]))
payload = data[self.size:] payload = data[self.size:]
expected = self.func(payload) expected = _checksum_func(payload)
if received != expected: if received != expected:
log.warning('Invalid checksum: %04x != %04x', received, expected) log.warning('Invalid checksum: %04x != %04x', received, expected)
raise ValueError('Invalid checksum') raise ValueError('Invalid checksum')
@@ -51,8 +50,8 @@ class Framer(object):
def decode(self, data): def decode(self, data):
data = iter(data) data = iter(data)
while True: while True:
length, = self._take_fmt(data, self.prefix_fmt) length, = _take_fmt(data, self.prefix_fmt)
frame = self._take_len(data, length) frame = _take_len(data, length)
block = self.checksum.decode(frame) block = self.checksum.decode(frame)
if block == self.EOF: if block == self.EOF:
log.debug('EOF frame detected') log.debug('EOF frame detected')
@@ -60,14 +59,16 @@ class Framer(object):
yield block yield block
def _take_fmt(self, data, fmt):
def _take_fmt(data, fmt):
length = struct.calcsize(fmt) length = struct.calcsize(fmt)
chunk = bytearray(itertools.islice(data, length)) chunk = bytearray(itertools.islice(data, length))
if len(chunk) < length: if len(chunk) < length:
raise ValueError('missing prefix data') raise ValueError('missing prefix data')
return struct.unpack(fmt, bytes(chunk)) return struct.unpack(fmt, bytes(chunk))
def _take_len(self, data, length):
def _take_len(data, length):
chunk = bytearray(itertools.islice(data, length)) chunk = bytearray(itertools.islice(data, length))
if len(chunk) < length: if len(chunk) < length:
raise ValueError('missing payload data') raise ValueError('missing payload data')