mirror of
https://github.com/romanz/amodem.git
synced 2026-04-21 13:46:30 +08:00
fix pylint and tests
This commit is contained in:
@@ -120,7 +120,7 @@ def handle_connection_error(func):
|
|||||||
def parse_config(fname):
|
def parse_config(fname):
|
||||||
"""Parse config file into a list of Identity objects."""
|
"""Parse config file into a list of Identity objects."""
|
||||||
contents = open(fname).read()
|
contents = open(fname).read()
|
||||||
for identity_str, curve_name in re.findall('\<(.*?)\|(.*?)\>', contents):
|
for identity_str, curve_name in re.findall(r'\<(.*?)\|(.*?)\>', contents):
|
||||||
yield device.interface.Identity(identity_str=identity_str,
|
yield device.interface.Identity(identity_str=identity_str,
|
||||||
curve_name=curve_name)
|
curve_name=curve_name)
|
||||||
|
|
||||||
|
|||||||
@@ -24,5 +24,4 @@ def detect():
|
|||||||
return d
|
return d
|
||||||
except interface.NotFoundError as e:
|
except interface.NotFoundError as e:
|
||||||
log.debug('device not found: %s', e)
|
log.debug('device not found: %s', e)
|
||||||
raise IOError('No device found: "{}" ({})'.format(identity_str,
|
raise IOError('No device found!')
|
||||||
curve_name))
|
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
"""KeepKey-related code (see https://www.keepkey.com/)."""
|
"""KeepKey-related code (see https://www.keepkey.com/)."""
|
||||||
|
|
||||||
from . import interface, trezor
|
from . import interface, trezor
|
||||||
from .. import formats
|
|
||||||
|
|
||||||
|
|
||||||
class KeepKey(trezor.Trezor):
|
class KeepKey(trezor.Trezor):
|
||||||
|
|||||||
@@ -44,11 +44,10 @@ class LedgerNanoS(interface.Device):
|
|||||||
raise interface.NotFoundError(
|
raise interface.NotFoundError(
|
||||||
'{} not connected: "{}"'.format(self, e))
|
'{} not connected: "{}"'.format(self, e))
|
||||||
|
|
||||||
def pubkey(self, ecdh=False):
|
def pubkey(self, identity, ecdh=False):
|
||||||
"""Get PublicKey object for specified BIP32 address and elliptic curve."""
|
"""Get PublicKey object for specified BIP32 address and elliptic curve."""
|
||||||
curve_name = self.get_curve_name(ecdh)
|
curve_name = identity.get_curve_name(ecdh)
|
||||||
path = _expand_path(interface.get_bip32_address(self.identity_dict,
|
path = _expand_path(identity.get_bip32_address(ecdh))
|
||||||
ecdh=ecdh))
|
|
||||||
if curve_name == 'nist256p1':
|
if curve_name == 'nist256p1':
|
||||||
p2 = '01'
|
p2 = '01'
|
||||||
else:
|
else:
|
||||||
@@ -60,27 +59,26 @@ class LedgerNanoS(interface.Device):
|
|||||||
result = bytearray(self.conn.exchange(bytes(apdu)))[1:]
|
result = bytearray(self.conn.exchange(bytes(apdu)))[1:]
|
||||||
return _convert_public_key(curve_name, result)
|
return _convert_public_key(curve_name, result)
|
||||||
|
|
||||||
def sign(self, blob):
|
def sign(self, identity, blob):
|
||||||
"""Sign given blob and return the signature (as bytes)."""
|
"""Sign given blob and return the signature (as bytes)."""
|
||||||
path = _expand_path(interface.get_bip32_address(self.identity_dict,
|
path = _expand_path(identity.get_bip32_address(ecdh=False))
|
||||||
ecdh=False))
|
if identity.identity_dict['proto'] == 'ssh':
|
||||||
if self.identity_dict['proto'] == 'ssh':
|
|
||||||
ins = '04'
|
ins = '04'
|
||||||
p1 = '00'
|
p1 = '00'
|
||||||
else:
|
else:
|
||||||
ins = '08'
|
ins = '08'
|
||||||
p1 = '00'
|
p1 = '00'
|
||||||
if self.curve_name == 'nist256p1':
|
if identity.curve_name == 'nist256p1':
|
||||||
p2 = '81' if self.identity_dict['proto'] == 'ssh' else '01'
|
p2 = '81' if identity.identity_dict['proto'] == 'ssh' else '01'
|
||||||
else:
|
else:
|
||||||
p2 = '82' if self.identity_dict['proto'] == 'ssh' else '02'
|
p2 = '82' if identity.identity_dict['proto'] == 'ssh' else '02'
|
||||||
apdu = '80' + ins + p1 + p2
|
apdu = '80' + ins + p1 + p2
|
||||||
apdu = binascii.unhexlify(apdu)
|
apdu = binascii.unhexlify(apdu)
|
||||||
apdu += bytearray([len(blob) + len(path) + 1])
|
apdu += bytearray([len(blob) + len(path) + 1])
|
||||||
apdu += bytearray([len(path) // 4]) + path
|
apdu += bytearray([len(path) // 4]) + path
|
||||||
apdu += blob
|
apdu += blob
|
||||||
result = bytearray(self.conn.exchange(bytes(apdu)))
|
result = bytearray(self.conn.exchange(bytes(apdu)))
|
||||||
if self.curve_name == 'nist256p1':
|
if identity.curve_name == 'nist256p1':
|
||||||
offset = 3
|
offset = 3
|
||||||
length = result[offset]
|
length = result[offset]
|
||||||
r = result[offset+1:offset+1+length]
|
r = result[offset+1:offset+1+length]
|
||||||
@@ -96,11 +94,10 @@ class LedgerNanoS(interface.Device):
|
|||||||
else:
|
else:
|
||||||
return bytes(result[:64])
|
return bytes(result[:64])
|
||||||
|
|
||||||
def ecdh(self, pubkey):
|
def ecdh(self, identity, pubkey):
|
||||||
"""Get shared session key using Elliptic Curve Diffie-Hellman."""
|
"""Get shared session key using Elliptic Curve Diffie-Hellman."""
|
||||||
path = _expand_path(interface.get_bip32_address(self.identity_dict,
|
path = _expand_path(identity.get_bip32_address(ecdh=True))
|
||||||
ecdh=True))
|
if identity.curve_name == 'nist256p1':
|
||||||
if self.curve_name == 'nist256p1':
|
|
||||||
p2 = '01'
|
p2 = '01'
|
||||||
else:
|
else:
|
||||||
p2 = '02'
|
p2 = '02'
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ PUBKEY = (b'\x03\xd8(\xb5\xa6`\xbet0\x95\xac:[;]\xdc,\xbd\xdc?\xd7\xc0\xec'
|
|||||||
b'\xdd\xbc+\xfar~\x9dAis')
|
b'\xdd\xbc+\xfar~\x9dAis')
|
||||||
PUBKEY_TEXT = ('ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzd'
|
PUBKEY_TEXT = ('ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzd'
|
||||||
'HAyNTYAAABBBNgotaZgvnQwlaw6Wztd3Cy93D/XwOzdvCv6cn6dQWlzNMEQeW'
|
'HAyNTYAAABBBNgotaZgvnQwlaw6Wztd3Cy93D/XwOzdvCv6cn6dQWlzNMEQeW'
|
||||||
'VUfhvrGljR2Z/CMRONY6ejB+9PnpUOPuzYqi8= ssh://localhost:22\n')
|
'VUfhvrGljR2Z/CMRONY6ejB+9PnpUOPuzYqi8= <localhost:22|nist256p1>\n')
|
||||||
|
|
||||||
|
|
||||||
class MockDevice(device.interface.Device): # pylint: disable=abstract-method
|
class MockDevice(device.interface.Device): # pylint: disable=abstract-method
|
||||||
@@ -23,11 +23,11 @@ class MockDevice(device.interface.Device): # pylint: disable=abstract-method
|
|||||||
def close(self):
|
def close(self):
|
||||||
self.conn = None
|
self.conn = None
|
||||||
|
|
||||||
def pubkey(self, ecdh=False): # pylint: disable=unused-argument
|
def pubkey(self, identity, ecdh=False): # pylint: disable=unused-argument
|
||||||
assert self.conn
|
assert self.conn
|
||||||
return PUBKEY
|
return PUBKEY
|
||||||
|
|
||||||
def sign(self, blob):
|
def sign(self, identity, blob):
|
||||||
"""Sign given blob and return the signature (as bytes)."""
|
"""Sign given blob and return the signature (as bytes)."""
|
||||||
assert self.conn
|
assert self.conn
|
||||||
assert blob == BLOB
|
assert blob == BLOB
|
||||||
@@ -59,11 +59,11 @@ SIG = (b'R\x19T\xf2\x84$\xef#\x0e\xee\x04X\xc6\xc3\x99T`\xd1\xd8\xf7!'
|
|||||||
|
|
||||||
|
|
||||||
def test_ssh_agent():
|
def test_ssh_agent():
|
||||||
identity_str = 'localhost:22'
|
identity = device.interface.Identity(identity_str='localhost:22',
|
||||||
c = client.Client(device=MockDevice(identity_str=identity_str,
|
curve_name=CURVE)
|
||||||
curve_name=CURVE))
|
c = client.Client(device=MockDevice())
|
||||||
assert c.get_public_key() == PUBKEY_TEXT
|
assert c.get_public_key(identity) == PUBKEY_TEXT
|
||||||
signature = c.sign_ssh_challenge(blob=BLOB)
|
signature = c.sign_ssh_challenge(blob=BLOB, identity=identity)
|
||||||
|
|
||||||
key = formats.import_public_key(PUBKEY_TEXT)
|
key = formats.import_public_key(PUBKEY_TEXT)
|
||||||
serialized_sig = key['verifier'](sig=signature, msg=BLOB)
|
serialized_sig = key['verifier'](sig=signature, msg=BLOB)
|
||||||
@@ -77,9 +77,9 @@ def test_ssh_agent():
|
|||||||
assert r[1:] + s[1:] == SIG
|
assert r[1:] + s[1:] == SIG
|
||||||
|
|
||||||
# pylint: disable=unused-argument
|
# pylint: disable=unused-argument
|
||||||
def cancel_sign(blob):
|
def cancel_sign(identity, blob):
|
||||||
raise IOError(42, 'ERROR')
|
raise IOError(42, 'ERROR')
|
||||||
|
|
||||||
c.device.sign = cancel_sign
|
c.device.sign = cancel_sign
|
||||||
with pytest.raises(IOError):
|
with pytest.raises(IOError):
|
||||||
c.sign_ssh_challenge(blob=BLOB)
|
c.sign_ssh_challenge(blob=BLOB, identity=identity)
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from .. import formats, protocol
|
from .. import device, formats, protocol
|
||||||
|
|
||||||
# pylint: disable=line-too-long
|
# pylint: disable=line-too-long
|
||||||
|
|
||||||
@@ -17,6 +17,7 @@ NIST256_SIGN_REPLY = b'\x00\x00\x00j\x0e\x00\x00\x00e\x00\x00\x00\x13ecdsa-sha2-
|
|||||||
|
|
||||||
def test_list():
|
def test_list():
|
||||||
key = formats.import_public_key(NIST256_KEY)
|
key = formats.import_public_key(NIST256_KEY)
|
||||||
|
key['identity'] = device.interface.Identity('ssh://localhost', 'nist256p1')
|
||||||
h = protocol.Handler(keys=[key], signer=None)
|
h = protocol.Handler(keys=[key], signer=None)
|
||||||
reply = h.handle(LIST_MSG)
|
reply = h.handle(LIST_MSG)
|
||||||
assert reply == LIST_NIST256_REPLY
|
assert reply == LIST_NIST256_REPLY
|
||||||
@@ -28,13 +29,15 @@ def test_unsupported():
|
|||||||
assert reply == b'\x00\x00\x00\x01\x05'
|
assert reply == b'\x00\x00\x00\x01\x05'
|
||||||
|
|
||||||
|
|
||||||
def ecdsa_signer(blob):
|
def ecdsa_signer(identity, blob):
|
||||||
|
assert str(identity) == '<ssh://localhost|nist256p1>'
|
||||||
assert blob == NIST256_BLOB
|
assert blob == NIST256_BLOB
|
||||||
return NIST256_SIG
|
return NIST256_SIG
|
||||||
|
|
||||||
|
|
||||||
def test_ecdsa_sign():
|
def test_ecdsa_sign():
|
||||||
key = formats.import_public_key(NIST256_KEY)
|
key = formats.import_public_key(NIST256_KEY)
|
||||||
|
key['identity'] = device.interface.Identity('ssh://localhost', 'nist256p1')
|
||||||
h = protocol.Handler(keys=[key], signer=ecdsa_signer)
|
h = protocol.Handler(keys=[key], signer=ecdsa_signer)
|
||||||
reply = h.handle(NIST256_SIGN_MSG)
|
reply = h.handle(NIST256_SIGN_MSG)
|
||||||
assert reply == NIST256_SIGN_REPLY
|
assert reply == NIST256_SIGN_REPLY
|
||||||
@@ -42,30 +45,30 @@ def test_ecdsa_sign():
|
|||||||
|
|
||||||
def test_sign_missing():
|
def test_sign_missing():
|
||||||
h = protocol.Handler(keys=[], signer=ecdsa_signer)
|
h = protocol.Handler(keys=[], signer=ecdsa_signer)
|
||||||
|
|
||||||
with pytest.raises(KeyError):
|
with pytest.raises(KeyError):
|
||||||
h.handle(NIST256_SIGN_MSG)
|
h.handle(NIST256_SIGN_MSG)
|
||||||
|
|
||||||
|
|
||||||
def test_sign_wrong():
|
def test_sign_wrong():
|
||||||
def wrong_signature(blob):
|
def wrong_signature(identity, blob):
|
||||||
|
assert str(identity) == '<ssh://localhost|nist256p1>'
|
||||||
assert blob == NIST256_BLOB
|
assert blob == NIST256_BLOB
|
||||||
return b'\x00' * 64
|
return b'\x00' * 64
|
||||||
|
|
||||||
key = formats.import_public_key(NIST256_KEY)
|
key = formats.import_public_key(NIST256_KEY)
|
||||||
|
key['identity'] = device.interface.Identity('ssh://localhost', 'nist256p1')
|
||||||
h = protocol.Handler(keys=[key], signer=wrong_signature)
|
h = protocol.Handler(keys=[key], signer=wrong_signature)
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
h.handle(NIST256_SIGN_MSG)
|
h.handle(NIST256_SIGN_MSG)
|
||||||
|
|
||||||
|
|
||||||
def test_sign_cancel():
|
def test_sign_cancel():
|
||||||
def cancel_signature(blob): # pylint: disable=unused-argument
|
def cancel_signature(identity, blob): # pylint: disable=unused-argument
|
||||||
raise IOError()
|
raise IOError()
|
||||||
|
|
||||||
key = formats.import_public_key(NIST256_KEY)
|
key = formats.import_public_key(NIST256_KEY)
|
||||||
|
key['identity'] = device.interface.Identity('ssh://localhost', 'nist256p1')
|
||||||
h = protocol.Handler(keys=[key], signer=cancel_signature)
|
h = protocol.Handler(keys=[key], signer=cancel_signature)
|
||||||
|
|
||||||
assert h.handle(NIST256_SIGN_MSG) == protocol.failure()
|
assert h.handle(NIST256_SIGN_MSG) == protocol.failure()
|
||||||
|
|
||||||
|
|
||||||
@@ -77,13 +80,15 @@ ED25519_BLOB = b'''\x00\x00\x00 i3\xae}yk\\\xa1L\xb9\xe1\xbf\xbc\x8e\x87\r\x0e\x
|
|||||||
ED25519_SIG = b'''\x8eb)\xa6\xe9P\x83VE\xfbq\xc6\xbf\x1dV3\xe3<O\x11\xc0\xfa\xe4\xed\xb8\x81.\x81\xc8\xa6\xba\x10RA'a\xbc\xa9\xd3\xdb\x98\x07\xf0\x1a\x9c4\x84<\xaf\x99\xb7\xe5G\xeb\xf7$\xc1\r\x86f\x16\x8e\x08\x05''' # nopep8
|
ED25519_SIG = b'''\x8eb)\xa6\xe9P\x83VE\xfbq\xc6\xbf\x1dV3\xe3<O\x11\xc0\xfa\xe4\xed\xb8\x81.\x81\xc8\xa6\xba\x10RA'a\xbc\xa9\xd3\xdb\x98\x07\xf0\x1a\x9c4\x84<\xaf\x99\xb7\xe5G\xeb\xf7$\xc1\r\x86f\x16\x8e\x08\x05''' # nopep8
|
||||||
|
|
||||||
|
|
||||||
def ed25519_signer(blob):
|
def ed25519_signer(identity, blob):
|
||||||
|
assert str(identity) == '<ssh://localhost|ed25519>'
|
||||||
assert blob == ED25519_BLOB
|
assert blob == ED25519_BLOB
|
||||||
return ED25519_SIG
|
return ED25519_SIG
|
||||||
|
|
||||||
|
|
||||||
def test_ed25519_sign():
|
def test_ed25519_sign():
|
||||||
key = formats.import_public_key(ED25519_KEY)
|
key = formats.import_public_key(ED25519_KEY)
|
||||||
|
key['identity'] = device.interface.Identity('ssh://localhost', 'ed25519')
|
||||||
h = protocol.Handler(keys=[key], signer=ed25519_signer)
|
h = protocol.Handler(keys=[key], signer=ed25519_signer)
|
||||||
reply = h.handle(ED25519_SIGN_MSG)
|
reply = h.handle(ED25519_SIGN_MSG)
|
||||||
assert reply == ED25519_SIGN_REPLY
|
assert reply == ED25519_SIGN_REPLY
|
||||||
|
|||||||
Reference in New Issue
Block a user