fix pylint and tests

This commit is contained in:
Roman Zeyde
2016-11-03 23:29:45 +02:00
parent ac4a86d312
commit dbed773e54
6 changed files with 38 additions and 38 deletions

View File

@@ -120,7 +120,7 @@ def handle_connection_error(func):
def parse_config(fname): def parse_config(fname):
"""Parse config file into a list of Identity objects.""" """Parse config file into a list of Identity objects."""
contents = open(fname).read() contents = open(fname).read()
for identity_str, curve_name in re.findall('\<(.*?)\|(.*?)\>', contents): for identity_str, curve_name in re.findall(r'\<(.*?)\|(.*?)\>', contents):
yield device.interface.Identity(identity_str=identity_str, yield device.interface.Identity(identity_str=identity_str,
curve_name=curve_name) curve_name=curve_name)

View File

@@ -24,5 +24,4 @@ def detect():
return d return d
except interface.NotFoundError as e: except interface.NotFoundError as e:
log.debug('device not found: %s', e) log.debug('device not found: %s', e)
raise IOError('No device found: "{}" ({})'.format(identity_str, raise IOError('No device found!')
curve_name))

View File

@@ -1,7 +1,6 @@
"""KeepKey-related code (see https://www.keepkey.com/).""" """KeepKey-related code (see https://www.keepkey.com/)."""
from . import interface, trezor from . import interface, trezor
from .. import formats
class KeepKey(trezor.Trezor): class KeepKey(trezor.Trezor):

View File

@@ -44,11 +44,10 @@ class LedgerNanoS(interface.Device):
raise interface.NotFoundError( raise interface.NotFoundError(
'{} not connected: "{}"'.format(self, e)) '{} not connected: "{}"'.format(self, e))
def pubkey(self, ecdh=False): def pubkey(self, identity, ecdh=False):
"""Get PublicKey object for specified BIP32 address and elliptic curve.""" """Get PublicKey object for specified BIP32 address and elliptic curve."""
curve_name = self.get_curve_name(ecdh) curve_name = identity.get_curve_name(ecdh)
path = _expand_path(interface.get_bip32_address(self.identity_dict, path = _expand_path(identity.get_bip32_address(ecdh))
ecdh=ecdh))
if curve_name == 'nist256p1': if curve_name == 'nist256p1':
p2 = '01' p2 = '01'
else: else:
@@ -60,27 +59,26 @@ class LedgerNanoS(interface.Device):
result = bytearray(self.conn.exchange(bytes(apdu)))[1:] result = bytearray(self.conn.exchange(bytes(apdu)))[1:]
return _convert_public_key(curve_name, result) return _convert_public_key(curve_name, result)
def sign(self, blob): def sign(self, identity, blob):
"""Sign given blob and return the signature (as bytes).""" """Sign given blob and return the signature (as bytes)."""
path = _expand_path(interface.get_bip32_address(self.identity_dict, path = _expand_path(identity.get_bip32_address(ecdh=False))
ecdh=False)) if identity.identity_dict['proto'] == 'ssh':
if self.identity_dict['proto'] == 'ssh':
ins = '04' ins = '04'
p1 = '00' p1 = '00'
else: else:
ins = '08' ins = '08'
p1 = '00' p1 = '00'
if self.curve_name == 'nist256p1': if identity.curve_name == 'nist256p1':
p2 = '81' if self.identity_dict['proto'] == 'ssh' else '01' p2 = '81' if identity.identity_dict['proto'] == 'ssh' else '01'
else: else:
p2 = '82' if self.identity_dict['proto'] == 'ssh' else '02' p2 = '82' if identity.identity_dict['proto'] == 'ssh' else '02'
apdu = '80' + ins + p1 + p2 apdu = '80' + ins + p1 + p2
apdu = binascii.unhexlify(apdu) apdu = binascii.unhexlify(apdu)
apdu += bytearray([len(blob) + len(path) + 1]) apdu += bytearray([len(blob) + len(path) + 1])
apdu += bytearray([len(path) // 4]) + path apdu += bytearray([len(path) // 4]) + path
apdu += blob apdu += blob
result = bytearray(self.conn.exchange(bytes(apdu))) result = bytearray(self.conn.exchange(bytes(apdu)))
if self.curve_name == 'nist256p1': if identity.curve_name == 'nist256p1':
offset = 3 offset = 3
length = result[offset] length = result[offset]
r = result[offset+1:offset+1+length] r = result[offset+1:offset+1+length]
@@ -96,11 +94,10 @@ class LedgerNanoS(interface.Device):
else: else:
return bytes(result[:64]) return bytes(result[:64])
def ecdh(self, pubkey): def ecdh(self, identity, pubkey):
"""Get shared session key using Elliptic Curve Diffie-Hellman.""" """Get shared session key using Elliptic Curve Diffie-Hellman."""
path = _expand_path(interface.get_bip32_address(self.identity_dict, path = _expand_path(identity.get_bip32_address(ecdh=True))
ecdh=True)) if identity.curve_name == 'nist256p1':
if self.curve_name == 'nist256p1':
p2 = '01' p2 = '01'
else: else:
p2 = '02' p2 = '02'

View File

@@ -12,7 +12,7 @@ PUBKEY = (b'\x03\xd8(\xb5\xa6`\xbet0\x95\xac:[;]\xdc,\xbd\xdc?\xd7\xc0\xec'
b'\xdd\xbc+\xfar~\x9dAis') b'\xdd\xbc+\xfar~\x9dAis')
PUBKEY_TEXT = ('ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzd' PUBKEY_TEXT = ('ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzd'
'HAyNTYAAABBBNgotaZgvnQwlaw6Wztd3Cy93D/XwOzdvCv6cn6dQWlzNMEQeW' 'HAyNTYAAABBBNgotaZgvnQwlaw6Wztd3Cy93D/XwOzdvCv6cn6dQWlzNMEQeW'
'VUfhvrGljR2Z/CMRONY6ejB+9PnpUOPuzYqi8= ssh://localhost:22\n') 'VUfhvrGljR2Z/CMRONY6ejB+9PnpUOPuzYqi8= <localhost:22|nist256p1>\n')
class MockDevice(device.interface.Device): # pylint: disable=abstract-method class MockDevice(device.interface.Device): # pylint: disable=abstract-method
@@ -23,11 +23,11 @@ class MockDevice(device.interface.Device): # pylint: disable=abstract-method
def close(self): def close(self):
self.conn = None self.conn = None
def pubkey(self, ecdh=False): # pylint: disable=unused-argument def pubkey(self, identity, ecdh=False): # pylint: disable=unused-argument
assert self.conn assert self.conn
return PUBKEY return PUBKEY
def sign(self, blob): def sign(self, identity, blob):
"""Sign given blob and return the signature (as bytes).""" """Sign given blob and return the signature (as bytes)."""
assert self.conn assert self.conn
assert blob == BLOB assert blob == BLOB
@@ -59,11 +59,11 @@ SIG = (b'R\x19T\xf2\x84$\xef#\x0e\xee\x04X\xc6\xc3\x99T`\xd1\xd8\xf7!'
def test_ssh_agent(): def test_ssh_agent():
identity_str = 'localhost:22' identity = device.interface.Identity(identity_str='localhost:22',
c = client.Client(device=MockDevice(identity_str=identity_str, curve_name=CURVE)
curve_name=CURVE)) c = client.Client(device=MockDevice())
assert c.get_public_key() == PUBKEY_TEXT assert c.get_public_key(identity) == PUBKEY_TEXT
signature = c.sign_ssh_challenge(blob=BLOB) signature = c.sign_ssh_challenge(blob=BLOB, identity=identity)
key = formats.import_public_key(PUBKEY_TEXT) key = formats.import_public_key(PUBKEY_TEXT)
serialized_sig = key['verifier'](sig=signature, msg=BLOB) serialized_sig = key['verifier'](sig=signature, msg=BLOB)
@@ -77,9 +77,9 @@ def test_ssh_agent():
assert r[1:] + s[1:] == SIG assert r[1:] + s[1:] == SIG
# pylint: disable=unused-argument # pylint: disable=unused-argument
def cancel_sign(blob): def cancel_sign(identity, blob):
raise IOError(42, 'ERROR') raise IOError(42, 'ERROR')
c.device.sign = cancel_sign c.device.sign = cancel_sign
with pytest.raises(IOError): with pytest.raises(IOError):
c.sign_ssh_challenge(blob=BLOB) c.sign_ssh_challenge(blob=BLOB, identity=identity)

View File

@@ -1,6 +1,6 @@
import pytest import pytest
from .. import formats, protocol from .. import device, formats, protocol
# pylint: disable=line-too-long # pylint: disable=line-too-long
@@ -17,6 +17,7 @@ NIST256_SIGN_REPLY = b'\x00\x00\x00j\x0e\x00\x00\x00e\x00\x00\x00\x13ecdsa-sha2-
def test_list(): def test_list():
key = formats.import_public_key(NIST256_KEY) key = formats.import_public_key(NIST256_KEY)
key['identity'] = device.interface.Identity('ssh://localhost', 'nist256p1')
h = protocol.Handler(keys=[key], signer=None) h = protocol.Handler(keys=[key], signer=None)
reply = h.handle(LIST_MSG) reply = h.handle(LIST_MSG)
assert reply == LIST_NIST256_REPLY assert reply == LIST_NIST256_REPLY
@@ -28,13 +29,15 @@ def test_unsupported():
assert reply == b'\x00\x00\x00\x01\x05' assert reply == b'\x00\x00\x00\x01\x05'
def ecdsa_signer(blob): def ecdsa_signer(identity, blob):
assert str(identity) == '<ssh://localhost|nist256p1>'
assert blob == NIST256_BLOB assert blob == NIST256_BLOB
return NIST256_SIG return NIST256_SIG
def test_ecdsa_sign(): def test_ecdsa_sign():
key = formats.import_public_key(NIST256_KEY) key = formats.import_public_key(NIST256_KEY)
key['identity'] = device.interface.Identity('ssh://localhost', 'nist256p1')
h = protocol.Handler(keys=[key], signer=ecdsa_signer) h = protocol.Handler(keys=[key], signer=ecdsa_signer)
reply = h.handle(NIST256_SIGN_MSG) reply = h.handle(NIST256_SIGN_MSG)
assert reply == NIST256_SIGN_REPLY assert reply == NIST256_SIGN_REPLY
@@ -42,30 +45,30 @@ def test_ecdsa_sign():
def test_sign_missing(): def test_sign_missing():
h = protocol.Handler(keys=[], signer=ecdsa_signer) h = protocol.Handler(keys=[], signer=ecdsa_signer)
with pytest.raises(KeyError): with pytest.raises(KeyError):
h.handle(NIST256_SIGN_MSG) h.handle(NIST256_SIGN_MSG)
def test_sign_wrong(): def test_sign_wrong():
def wrong_signature(blob): def wrong_signature(identity, blob):
assert str(identity) == '<ssh://localhost|nist256p1>'
assert blob == NIST256_BLOB assert blob == NIST256_BLOB
return b'\x00' * 64 return b'\x00' * 64
key = formats.import_public_key(NIST256_KEY) key = formats.import_public_key(NIST256_KEY)
key['identity'] = device.interface.Identity('ssh://localhost', 'nist256p1')
h = protocol.Handler(keys=[key], signer=wrong_signature) h = protocol.Handler(keys=[key], signer=wrong_signature)
with pytest.raises(ValueError): with pytest.raises(ValueError):
h.handle(NIST256_SIGN_MSG) h.handle(NIST256_SIGN_MSG)
def test_sign_cancel(): def test_sign_cancel():
def cancel_signature(blob): # pylint: disable=unused-argument def cancel_signature(identity, blob): # pylint: disable=unused-argument
raise IOError() raise IOError()
key = formats.import_public_key(NIST256_KEY) key = formats.import_public_key(NIST256_KEY)
key['identity'] = device.interface.Identity('ssh://localhost', 'nist256p1')
h = protocol.Handler(keys=[key], signer=cancel_signature) h = protocol.Handler(keys=[key], signer=cancel_signature)
assert h.handle(NIST256_SIGN_MSG) == protocol.failure() assert h.handle(NIST256_SIGN_MSG) == protocol.failure()
@@ -77,13 +80,15 @@ ED25519_BLOB = b'''\x00\x00\x00 i3\xae}yk\\\xa1L\xb9\xe1\xbf\xbc\x8e\x87\r\x0e\x
ED25519_SIG = b'''\x8eb)\xa6\xe9P\x83VE\xfbq\xc6\xbf\x1dV3\xe3<O\x11\xc0\xfa\xe4\xed\xb8\x81.\x81\xc8\xa6\xba\x10RA'a\xbc\xa9\xd3\xdb\x98\x07\xf0\x1a\x9c4\x84<\xaf\x99\xb7\xe5G\xeb\xf7$\xc1\r\x86f\x16\x8e\x08\x05''' # nopep8 ED25519_SIG = b'''\x8eb)\xa6\xe9P\x83VE\xfbq\xc6\xbf\x1dV3\xe3<O\x11\xc0\xfa\xe4\xed\xb8\x81.\x81\xc8\xa6\xba\x10RA'a\xbc\xa9\xd3\xdb\x98\x07\xf0\x1a\x9c4\x84<\xaf\x99\xb7\xe5G\xeb\xf7$\xc1\r\x86f\x16\x8e\x08\x05''' # nopep8
def ed25519_signer(blob): def ed25519_signer(identity, blob):
assert str(identity) == '<ssh://localhost|ed25519>'
assert blob == ED25519_BLOB assert blob == ED25519_BLOB
return ED25519_SIG return ED25519_SIG
def test_ed25519_sign(): def test_ed25519_sign():
key = formats.import_public_key(ED25519_KEY) key = formats.import_public_key(ED25519_KEY)
key['identity'] = device.interface.Identity('ssh://localhost', 'ed25519')
h = protocol.Handler(keys=[key], signer=ed25519_signer) h = protocol.Handler(keys=[key], signer=ed25519_signer)
reply = h.handle(ED25519_SIGN_MSG) reply = h.handle(ED25519_SIGN_MSG)
assert reply == ED25519_SIGN_REPLY assert reply == ED25519_SIGN_REPLY