Merge branch 'ssh-ids'

This commit is contained in:
Roman Zeyde
2016-11-04 16:07:18 +02:00
13 changed files with 140 additions and 153 deletions

View File

@@ -31,7 +31,6 @@ setup(
],
entry_points={'console_scripts': [
'trezor-agent = trezor_agent.__main__:run_agent',
'trezor-git = trezor_agent.__main__:run_git',
'trezor-gpg-create = trezor_agent.gpg.__main__:main_create',
'trezor-gpg-agent = trezor_agent.gpg.__main__:main_agent',
'trezor-gpg-unlock = trezor_agent.gpg.__main__:auto_unlock',

View File

@@ -93,13 +93,11 @@ def git_host(remote_name, attributes):
return '{user}@{host}'.format(**match.groupdict())
def run_server(conn, public_key, command, debug, timeout):
def run_server(conn, public_keys, command, debug, timeout):
"""Common code for run_agent and run_git below."""
try:
signer = conn.sign_ssh_challenge
public_key = formats.import_public_key(public_key)
log.info('using SSH public key: %s', public_key['fingerprint'])
handler = protocol.Handler(keys=[public_key], signer=signer,
handler = protocol.Handler(keys=public_keys, signer=signer,
debug=debug)
with server.serve(handler=handler, timeout=timeout) as env:
return server.run_process(command=command, environ=env)
@@ -119,18 +117,33 @@ def handle_connection_error(func):
return wrapper
def parse_config(fname):
"""Parse config file into a list of Identity objects."""
contents = open(fname).read()
for identity_str, curve_name in re.findall(r'\<(.*?)\|(.*?)\>', contents):
yield device.interface.Identity(identity_str=identity_str,
curve_name=curve_name)
@handle_connection_error
def run_agent(client_factory=client.Client):
"""Run ssh-agent using given hardware client factory."""
args = create_agent_parser().parse_args()
util.setup_logging(verbosity=args.verbose)
d = device.detect(identity_str=args.identity,
curve_name=args.ecdsa_curve_name)
conn = client_factory(device=d)
conn = client_factory(device=device.detect())
if args.identity.startswith('/'):
identities = list(parse_config(fname=args.identity))
else:
identities = [device.interface.Identity(
identity_str=args.identity, curve_name=args.ecdsa_curve_name)]
for index, identity in enumerate(identities):
identity.identity_dict['proto'] = 'ssh'
log.info('identity #%d: %s', index, identity)
command = args.command
public_key = conn.get_public_key()
public_keys = [conn.get_public_key(i) for i in identities]
if args.connect:
command = ssh_args(args.identity) + args.command
@@ -142,36 +155,12 @@ def run_agent(client_factory=client.Client):
log.debug('using shell: %r', command)
if not command:
sys.stdout.write(public_key)
for pk in public_keys:
sys.stdout.write(pk)
return
return run_server(conn=conn, public_key=public_key, command=command,
public_keys = [formats.import_public_key(pk) for pk in public_keys]
for pk, identity in zip(public_keys, identities):
pk['identity'] = identity
return run_server(conn=conn, public_keys=public_keys, command=command,
debug=args.debug, timeout=args.timeout)
@handle_connection_error
def run_git(client_factory=client.Client):
"""Run git under ssh-agent using given hardware client factory."""
args = create_git_parser().parse_args()
util.setup_logging(verbosity=args.verbose)
with client_factory(curve=args.ecdsa_curve_name) as conn:
label = git_host(args.remote, ['pushurl', 'url'])
if not label:
log.error('Could not find "%s" SSH remote in .git/config',
args.remote)
return
public_key = conn.get_public_key(label=label)
if not args.test:
if args.command:
command = ['git'] + args.command
else:
sys.stdout.write(public_key)
return
else:
command = ['ssh', '-T', label]
return run_server(conn=conn, public_key=public_key, command=command,
debug=args.debug, timeout=args.timeout)

View File

@@ -16,34 +16,34 @@ class Client(object):
def __init__(self, device):
"""Connect to hardware device."""
device.identity_dict['proto'] = 'ssh'
self.device = device
def get_public_key(self):
def get_public_key(self, identity):
"""Get SSH public key from the device."""
with self.device:
pubkey = self.device.pubkey()
pubkey = self.device.pubkey(identity)
vk = formats.decompress_pubkey(pubkey=pubkey,
curve_name=self.device.curve_name)
curve_name=identity.curve_name)
return formats.export_public_key(vk=vk,
label=self.device.identity_str())
label=str(identity))
def sign_ssh_challenge(self, blob):
def sign_ssh_challenge(self, blob, identity):
"""Sign given blob using a private key on the device."""
msg = _parse_ssh_blob(blob)
log.debug('%s: user %r via %r (%r)',
msg['conn'], msg['user'], msg['auth'], msg['key_type'])
log.debug('nonce: %r', msg['nonce'])
log.debug('fingerprint: %s', msg['public_key']['fingerprint'])
fp = msg['public_key']['fingerprint']
log.debug('fingerprint: %s', fp)
log.debug('hidden challenge size: %d bytes', len(blob))
log.info('please confirm user "%s" login to "%s" using %s...',
msg['user'].decode('ascii'), self.device.identity_str(),
msg['user'].decode('ascii'), identity,
self.device)
with self.device:
return self.device.sign(blob=blob)
return self.device.sign(blob=blob, identity=identity)
def _parse_ssh_blob(data):

View File

@@ -16,13 +16,12 @@ DEVICE_TYPES = [
]
def detect(identity_str, curve_name):
def detect():
"""Detect the first available device and return it to the user."""
for device_type in DEVICE_TYPES:
try:
with device_type(identity_str, curve_name) as d:
with device_type() as d:
return d
except interface.NotFoundError as e:
log.debug('device not found: %s', e)
raise IOError('No device found: "{}" ({})'.format(identity_str,
curve_name))
raise IOError('No device found!')

View File

@@ -45,20 +45,6 @@ def identity_to_string(identity_dict):
return ''.join(result)
def get_bip32_address(identity_dict, ecdh=False):
"""Compute BIP32 derivation address according to SLIP-0013/0017."""
index = struct.pack('<L', identity_dict.get('index', 0))
addr = index + identity_to_string(identity_dict).encode('ascii')
log.debug('bip32 address string: %r', addr)
digest = hashlib.sha256(addr).digest()
s = io.BytesIO(bytearray(digest))
hardened = 0x80000000
addr_0 = [13, 17][bool(ecdh)]
address_n = [addr_0] + list(util.recv(s, '<LLLL'))
return [(hardened | value) for value in address_n]
class Error(Exception):
"""Device-related error."""
@@ -71,18 +57,49 @@ class DeviceError(Error):
""""Error during device operation."""
class Device(object):
"""Abstract cryptographic hardware device interface."""
class Identity(object):
"""Represent SLIP-0013 identity, together with a elliptic curve choice."""
def __init__(self, identity_str, curve_name):
"""Configure for specific identity and elliptic curve usage."""
self.identity_dict = string_to_identity(identity_str)
self.curve_name = curve_name
self.conn = None
def identity_str(self):
def items(self):
"""Return a copy of identity_dict items."""
return self.identity_dict.items()
def __str__(self):
"""Return identity serialized to string."""
return identity_to_string(self.identity_dict)
return '<{}|{}>'.format(identity_to_string(self.identity_dict), self.curve_name)
def get_bip32_address(self, ecdh=False):
"""Compute BIP32 derivation address according to SLIP-0013/0017."""
index = struct.pack('<L', self.identity_dict.get('index', 0))
addr = index + identity_to_string(self.identity_dict).encode('ascii')
log.debug('bip32 address string: %r', addr)
digest = hashlib.sha256(addr).digest()
s = io.BytesIO(bytearray(digest))
hardened = 0x80000000
addr_0 = [13, 17][bool(ecdh)]
address_n = [addr_0] + list(util.recv(s, '<LLLL'))
return [(hardened | value) for value in address_n]
def get_curve_name(self, ecdh=False):
"""Return correct curve name for device operations."""
if ecdh:
return formats.get_ecdh_curve_name(self.curve_name)
else:
return self.curve_name
class Device(object):
"""Abstract cryptographic hardware device interface."""
def __init__(self):
"""C-tor."""
self.conn = None
def connect(self):
"""Connect to device, otherwise raise NotFoundError."""
@@ -101,25 +118,18 @@ class Device(object):
log.exception('close failed: %s', e)
self.conn = None
def pubkey(self, ecdh=False):
def pubkey(self, identity, ecdh=False):
"""Get public key (as bytes)."""
raise NotImplementedError()
def sign(self, blob):
def sign(self, identity, blob):
"""Sign given blob and return the signature (as bytes)."""
raise NotImplementedError()
def ecdh(self, pubkey):
def ecdh(self, identity, pubkey):
"""Get shared session key using Elliptic Curve Diffie-Hellman."""
raise NotImplementedError()
def __str__(self):
"""Human-readable representation."""
return '{}'.format(self.__class__.__name__)
def get_curve_name(self, ecdh=False):
"""Return correct curve name for device operations."""
if ecdh:
return formats.get_ecdh_curve_name(self.curve_name)
else:
return self.curve_name

View File

@@ -1,7 +1,6 @@
"""KeepKey-related code (see https://www.keepkey.com/)."""
from . import interface, trezor
from .. import formats
class KeepKey(trezor.Trezor):
@@ -11,15 +10,7 @@ class KeepKey(trezor.Trezor):
required_version = '>=1.0.4'
def connect(self):
"""No support for other than NIST256P elliptic curves."""
if self.curve_name not in {formats.CURVE_NIST256}:
fmt = 'KeepKey does not support {} curve'
raise interface.NotFoundError(fmt.format(self.curve_name))
return trezor.Trezor.connect(self)
def ecdh(self, pubkey):
def ecdh(self, identity, pubkey):
"""No support for ECDH in KeepKey firmware."""
msg = 'KeepKey does not support ECDH'
raise interface.NotFoundError(msg)

View File

@@ -44,11 +44,10 @@ class LedgerNanoS(interface.Device):
raise interface.NotFoundError(
'{} not connected: "{}"'.format(self, e))
def pubkey(self, ecdh=False):
def pubkey(self, identity, ecdh=False):
"""Get PublicKey object for specified BIP32 address and elliptic curve."""
curve_name = self.get_curve_name(ecdh)
path = _expand_path(interface.get_bip32_address(self.identity_dict,
ecdh=ecdh))
curve_name = identity.get_curve_name(ecdh)
path = _expand_path(identity.get_bip32_address(ecdh))
if curve_name == 'nist256p1':
p2 = '01'
else:
@@ -60,27 +59,26 @@ class LedgerNanoS(interface.Device):
result = bytearray(self.conn.exchange(bytes(apdu)))[1:]
return _convert_public_key(curve_name, result)
def sign(self, blob):
def sign(self, identity, blob):
"""Sign given blob and return the signature (as bytes)."""
path = _expand_path(interface.get_bip32_address(self.identity_dict,
ecdh=False))
if self.identity_dict['proto'] == 'ssh':
path = _expand_path(identity.get_bip32_address(ecdh=False))
if identity.identity_dict['proto'] == 'ssh':
ins = '04'
p1 = '00'
else:
ins = '08'
p1 = '00'
if self.curve_name == 'nist256p1':
p2 = '81' if self.identity_dict['proto'] == 'ssh' else '01'
if identity.curve_name == 'nist256p1':
p2 = '81' if identity.identity_dict['proto'] == 'ssh' else '01'
else:
p2 = '82' if self.identity_dict['proto'] == 'ssh' else '02'
p2 = '82' if identity.identity_dict['proto'] == 'ssh' else '02'
apdu = '80' + ins + p1 + p2
apdu = binascii.unhexlify(apdu)
apdu += bytearray([len(blob) + len(path) + 1])
apdu += bytearray([len(path) // 4]) + path
apdu += blob
result = bytearray(self.conn.exchange(bytes(apdu)))
if self.curve_name == 'nist256p1':
if identity.curve_name == 'nist256p1':
offset = 3
length = result[offset]
r = result[offset+1:offset+1+length]
@@ -96,11 +94,10 @@ class LedgerNanoS(interface.Device):
else:
return bytes(result[:64])
def ecdh(self, pubkey):
def ecdh(self, identity, pubkey):
"""Get shared session key using Elliptic Curve Diffie-Hellman."""
path = _expand_path(interface.get_bip32_address(self.identity_dict,
ecdh=True))
if self.curve_name == 'nist256p1':
path = _expand_path(identity.get_bip32_address(ecdh=True))
if identity.curve_name == 'nist256p1':
p2 = '01'
else:
p2 = '02'

View File

@@ -48,33 +48,31 @@ class Trezor(interface.Device):
"""Close connection."""
self.conn.close()
def pubkey(self, ecdh=False):
def pubkey(self, identity, ecdh=False):
"""Return public key."""
curve_name = self.get_curve_name(ecdh=ecdh)
curve_name = identity.get_curve_name(ecdh=ecdh)
log.debug('"%s" getting public key (%s) from %s',
interface.identity_to_string(self.identity_dict),
curve_name, self)
addr = interface.get_bip32_address(self.identity_dict, ecdh=ecdh)
identity, curve_name, self)
addr = identity.get_bip32_address(ecdh=ecdh)
result = self.conn.get_public_node(n=addr,
ecdsa_curve_name=curve_name)
log.debug('result: %s', result)
return result.node.public_key
def _identity_proto(self):
def _identity_proto(self, identity):
result = self.defs.IdentityType()
for name, value in self.identity_dict.items():
for name, value in identity.items():
setattr(result, name, value)
return result
def sign(self, blob):
def sign(self, identity, blob):
"""Sign given blob and return the signature (as bytes)."""
curve_name = self.get_curve_name(ecdh=False)
curve_name = identity.get_curve_name(ecdh=False)
log.debug('"%s" signing %r (%s) on %s',
interface.identity_to_string(self.identity_dict), blob,
curve_name, self)
identity, blob, curve_name, self)
try:
result = self.conn.sign_identity(
identity=self._identity_proto(),
identity=self._identity_proto(identity),
challenge_hidden=blob,
challenge_visual='',
ecdsa_curve_name=curve_name)
@@ -87,15 +85,14 @@ class Trezor(interface.Device):
log.debug(msg, exc_info=True)
raise interface.DeviceError(msg)
def ecdh(self, pubkey):
def ecdh(self, identity, pubkey):
"""Get shared session key using Elliptic Curve Diffie-Hellman."""
curve_name = self.get_curve_name(ecdh=True)
curve_name = identity.get_curve_name(ecdh=True)
log.debug('"%s" shared session key (%s) for %r from %s',
interface.identity_to_string(self.identity_dict),
curve_name, pubkey, self)
identity, curve_name, pubkey, self)
try:
result = self.conn.get_ecdh_session_key(
identity=self._identity_proto(),
identity=self._identity_proto(identity),
peer_public_key=pubkey,
ecdsa_curve_name=curve_name)
log.debug('result: %s', result)

View File

@@ -123,5 +123,5 @@ def auto_unlock():
args = p.parse_args()
util.setup_logging(verbosity=args.verbose)
d = device.detect(identity_str='', curve_name='')
d = device.detect()
log.info('unlocked %s device', d)

View File

@@ -12,28 +12,28 @@ class Client(object):
def __init__(self, user_id, curve_name):
"""Connect to the device and retrieve required public key."""
self.device = device.detect(identity_str='',
curve_name=curve_name)
self.device.identity_dict['proto'] = 'gpg'
self.device.identity_dict['host'] = user_id
self.device = device.detect()
self.user_id = user_id
self.identity = device.interface.Identity(
identity_str='gpg://', curve_name=curve_name)
self.identity.identity_dict['host'] = user_id
def pubkey(self, ecdh=False):
"""Return public key as VerifyingKey object."""
with self.device:
pubkey = self.device.pubkey(ecdh=ecdh)
pubkey = self.device.pubkey(ecdh=ecdh, identity=self.identity)
return formats.decompress_pubkey(
pubkey=pubkey, curve_name=self.device.curve_name)
pubkey=pubkey, curve_name=self.identity.curve_name)
def sign(self, digest):
"""Sign the digest and return a serialized signature."""
log.info('please confirm GPG signature on %s for "%s"...',
self.device, self.user_id)
if self.device.curve_name == formats.CURVE_NIST256:
if self.identity.curve_name == formats.CURVE_NIST256:
digest = digest[:32] # sign the first 256 bits
log.debug('signing digest: %s', util.hexlify(digest))
with self.device:
sig = self.device.sign(blob=digest)
sig = self.device.sign(blob=digest, identity=self.identity)
return (util.bytes2num(sig[:32]), util.bytes2num(sig[32:]))
def ecdh(self, pubkey):
@@ -41,4 +41,4 @@ class Client(object):
log.info('please confirm GPG decryption on %s for "%s"...',
self.device, self.user_id)
with self.device:
return self.device.ecdh(pubkey=pubkey)
return self.device.ecdh(pubkey=pubkey, identity=self.identity)

View File

@@ -140,7 +140,7 @@ class Handler(object):
label = key['name'].decode('ascii') # label should be a string
log.debug('signing %d-byte blob with "%s" key', len(blob), label)
try:
signature = self.signer(blob=blob)
signature = self.signer(blob=blob, identity=key['identity'])
except IOError:
return failure()
log.debug('signature: %r', signature)

View File

@@ -12,7 +12,7 @@ PUBKEY = (b'\x03\xd8(\xb5\xa6`\xbet0\x95\xac:[;]\xdc,\xbd\xdc?\xd7\xc0\xec'
b'\xdd\xbc+\xfar~\x9dAis')
PUBKEY_TEXT = ('ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzd'
'HAyNTYAAABBBNgotaZgvnQwlaw6Wztd3Cy93D/XwOzdvCv6cn6dQWlzNMEQeW'
'VUfhvrGljR2Z/CMRONY6ejB+9PnpUOPuzYqi8= ssh://localhost:22\n')
'VUfhvrGljR2Z/CMRONY6ejB+9PnpUOPuzYqi8= <localhost:22|nist256p1>\n')
class MockDevice(device.interface.Device): # pylint: disable=abstract-method
@@ -23,11 +23,11 @@ class MockDevice(device.interface.Device): # pylint: disable=abstract-method
def close(self):
self.conn = None
def pubkey(self, ecdh=False): # pylint: disable=unused-argument
def pubkey(self, identity, ecdh=False): # pylint: disable=unused-argument
assert self.conn
return PUBKEY
def sign(self, blob):
def sign(self, identity, blob):
"""Sign given blob and return the signature (as bytes)."""
assert self.conn
assert blob == BLOB
@@ -59,11 +59,11 @@ SIG = (b'R\x19T\xf2\x84$\xef#\x0e\xee\x04X\xc6\xc3\x99T`\xd1\xd8\xf7!'
def test_ssh_agent():
identity_str = 'localhost:22'
c = client.Client(device=MockDevice(identity_str=identity_str,
curve_name=CURVE))
assert c.get_public_key() == PUBKEY_TEXT
signature = c.sign_ssh_challenge(blob=BLOB)
identity = device.interface.Identity(identity_str='localhost:22',
curve_name=CURVE)
c = client.Client(device=MockDevice())
assert c.get_public_key(identity) == PUBKEY_TEXT
signature = c.sign_ssh_challenge(blob=BLOB, identity=identity)
key = formats.import_public_key(PUBKEY_TEXT)
serialized_sig = key['verifier'](sig=signature, msg=BLOB)
@@ -77,9 +77,9 @@ def test_ssh_agent():
assert r[1:] + s[1:] == SIG
# pylint: disable=unused-argument
def cancel_sign(blob):
def cancel_sign(identity, blob):
raise IOError(42, 'ERROR')
c.device.sign = cancel_sign
with pytest.raises(IOError):
c.sign_ssh_challenge(blob=BLOB)
c.sign_ssh_challenge(blob=BLOB, identity=identity)

View File

@@ -1,6 +1,6 @@
import pytest
from .. import formats, protocol
from .. import device, formats, protocol
# pylint: disable=line-too-long
@@ -17,6 +17,7 @@ NIST256_SIGN_REPLY = b'\x00\x00\x00j\x0e\x00\x00\x00e\x00\x00\x00\x13ecdsa-sha2-
def test_list():
key = formats.import_public_key(NIST256_KEY)
key['identity'] = device.interface.Identity('ssh://localhost', 'nist256p1')
h = protocol.Handler(keys=[key], signer=None)
reply = h.handle(LIST_MSG)
assert reply == LIST_NIST256_REPLY
@@ -28,13 +29,15 @@ def test_unsupported():
assert reply == b'\x00\x00\x00\x01\x05'
def ecdsa_signer(blob):
def ecdsa_signer(identity, blob):
assert str(identity) == '<ssh://localhost|nist256p1>'
assert blob == NIST256_BLOB
return NIST256_SIG
def test_ecdsa_sign():
key = formats.import_public_key(NIST256_KEY)
key['identity'] = device.interface.Identity('ssh://localhost', 'nist256p1')
h = protocol.Handler(keys=[key], signer=ecdsa_signer)
reply = h.handle(NIST256_SIGN_MSG)
assert reply == NIST256_SIGN_REPLY
@@ -42,30 +45,30 @@ def test_ecdsa_sign():
def test_sign_missing():
h = protocol.Handler(keys=[], signer=ecdsa_signer)
with pytest.raises(KeyError):
h.handle(NIST256_SIGN_MSG)
def test_sign_wrong():
def wrong_signature(blob):
def wrong_signature(identity, blob):
assert str(identity) == '<ssh://localhost|nist256p1>'
assert blob == NIST256_BLOB
return b'\x00' * 64
key = formats.import_public_key(NIST256_KEY)
key['identity'] = device.interface.Identity('ssh://localhost', 'nist256p1')
h = protocol.Handler(keys=[key], signer=wrong_signature)
with pytest.raises(ValueError):
h.handle(NIST256_SIGN_MSG)
def test_sign_cancel():
def cancel_signature(blob): # pylint: disable=unused-argument
def cancel_signature(identity, blob): # pylint: disable=unused-argument
raise IOError()
key = formats.import_public_key(NIST256_KEY)
key['identity'] = device.interface.Identity('ssh://localhost', 'nist256p1')
h = protocol.Handler(keys=[key], signer=cancel_signature)
assert h.handle(NIST256_SIGN_MSG) == protocol.failure()
@@ -77,13 +80,15 @@ ED25519_BLOB = b'''\x00\x00\x00 i3\xae}yk\\\xa1L\xb9\xe1\xbf\xbc\x8e\x87\r\x0e\x
ED25519_SIG = b'''\x8eb)\xa6\xe9P\x83VE\xfbq\xc6\xbf\x1dV3\xe3<O\x11\xc0\xfa\xe4\xed\xb8\x81.\x81\xc8\xa6\xba\x10RA'a\xbc\xa9\xd3\xdb\x98\x07\xf0\x1a\x9c4\x84<\xaf\x99\xb7\xe5G\xeb\xf7$\xc1\r\x86f\x16\x8e\x08\x05''' # nopep8
def ed25519_signer(blob):
def ed25519_signer(identity, blob):
assert str(identity) == '<ssh://localhost|ed25519>'
assert blob == ED25519_BLOB
return ED25519_SIG
def test_ed25519_sign():
key = formats.import_public_key(ED25519_KEY)
key['identity'] = device.interface.Identity('ssh://localhost', 'ed25519')
h = protocol.Handler(keys=[key], signer=ed25519_signer)
reply = h.handle(ED25519_SIGN_MSG)
assert reply == ED25519_SIGN_REPLY