ssh: use new device package (instead of factory)

This commit is contained in:
Roman Zeyde
2016-10-28 10:53:53 +03:00
parent 946ab633d4
commit 0f79b5ff2e
7 changed files with 79 additions and 229 deletions

View File

@@ -7,14 +7,14 @@ import re
import subprocess import subprocess
import sys import sys
from . import client, formats, protocol, server, util from . import client, device, formats, protocol, server, util
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
def ssh_args(label): def ssh_args(label):
"""Create SSH command for connecting specified server.""" """Create SSH command for connecting specified server."""
identity = util.string_to_identity(label, identity_type=dict) identity = device.interface.string_to_identity(label)
args = [] args = []
if 'port' in identity: if 'port' in identity:
@@ -125,27 +125,28 @@ def run_agent(client_factory=client.Client):
args = create_agent_parser().parse_args() args = create_agent_parser().parse_args()
util.setup_logging(verbosity=args.verbose) util.setup_logging(verbosity=args.verbose)
with client_factory(curve=args.ecdsa_curve_name) as conn: d = device.detect(identity_str=args.identity,
label = args.identity curve_name=args.ecdsa_curve_name)
command = args.command conn = client_factory(device=d)
public_key = conn.get_public_key(label=label) command = args.command
public_key = conn.get_public_key()
if args.connect: if args.connect:
command = ssh_args(label) + args.command command = ssh_args(args.identity) + args.command
log.debug('SSH connect: %r', command) log.debug('SSH connect: %r', command)
use_shell = bool(args.shell) use_shell = bool(args.shell)
if use_shell: if use_shell:
command = os.environ['SHELL'] command = os.environ['SHELL']
log.debug('using shell: %r', command) log.debug('using shell: %r', command)
if not command: if not command:
sys.stdout.write(public_key) sys.stdout.write(public_key)
return return
return run_server(conn=conn, public_key=public_key, command=command, return run_server(conn=conn, public_key=public_key, command=command,
debug=args.debug, timeout=args.timeout) debug=args.debug, timeout=args.timeout)
@handle_connection_error @handle_connection_error

View File

@@ -3,11 +3,10 @@ Connection to hardware authentication device.
It is used for getting SSH public keys and ECDSA signing of server requests. It is used for getting SSH public keys and ECDSA signing of server requests.
""" """
import binascii
import io import io
import logging import logging
from . import factory, formats, util from . import formats, util
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@@ -15,79 +14,36 @@ log = logging.getLogger(__name__)
class Client(object): class Client(object):
"""Client wrapper for SSH authentication device.""" """Client wrapper for SSH authentication device."""
def __init__(self, loader=factory.load, curve=formats.CURVE_NIST256): def __init__(self, device):
"""Connect to hardware device.""" """Connect to hardware device."""
client_wrapper = loader() device.identity_dict['proto'] = 'ssh'
self.client = client_wrapper.connection self.device = device
self.identity_type = client_wrapper.identity_type
self.device_name = client_wrapper.device_name
self.call_exception = client_wrapper.call_exception
self.curve = curve
def __enter__(self): def get_public_key(self):
"""Start a session, and test connection.""" """Get SSH public key from the device."""
msg = 'Hello World!' with self.device:
assert self.client.ping(msg) == msg pubkey = self.device.pubkey()
return self
def __exit__(self, *args): vk = formats.decompress_pubkey(pubkey=pubkey,
"""Keep the session open (doesn't forget PIN).""" curve_name=self.device.curve_name)
log.info('disconnected from %s', self.device_name) return formats.export_public_key(vk=vk,
self.client.close() label=self.device.identity_str())
def get_identity(self, label, index=0): def sign_ssh_challenge(self, blob):
"""Parse label string into Identity protobuf.""" """Sign given blob using a private key on the device."""
identity = util.string_to_identity(label, self.identity_type)
identity.proto = 'ssh'
identity.index = index
return identity
def get_public_key(self, label):
"""Get SSH public key corresponding to specified by label."""
identity = self.get_identity(label=label)
label = util.identity_to_string(identity) # canonize key label
log.info('getting "%s" public key (%s) from %s...',
label, self.curve, self.device_name)
addr = util.get_bip32_address(identity)
node = self.client.get_public_node(n=addr,
ecdsa_curve_name=self.curve)
pubkey = node.node.public_key
vk = formats.decompress_pubkey(pubkey=pubkey, curve_name=self.curve)
return formats.export_public_key(vk=vk, label=label)
def sign_ssh_challenge(self, label, blob):
"""Sign given blob using a private key, specified by the label."""
identity = self.get_identity(label=label)
msg = _parse_ssh_blob(blob) msg = _parse_ssh_blob(blob)
log.debug('%s: user %r via %r (%r)', log.debug('%s: user %r via %r (%r)',
msg['conn'], msg['user'], msg['auth'], msg['key_type']) msg['conn'], msg['user'], msg['auth'], msg['key_type'])
log.debug('nonce: %s', binascii.hexlify(msg['nonce'])) log.debug('nonce: %r', msg['nonce'])
log.debug('fingerprint: %s', msg['public_key']['fingerprint']) log.debug('fingerprint: %s', msg['public_key']['fingerprint'])
log.debug('hidden challenge size: %d bytes', len(blob)) log.debug('hidden challenge size: %d bytes', len(blob))
log.info('please confirm user "%s" login to "%s" using %s...', log.info('please confirm user "%s" login to "%s" using %s...',
msg['user'].decode('ascii'), label, self.device_name) msg['user'].decode('ascii'), self.device.identity_str(),
self.device)
try: with self.device:
result = self.client.sign_identity(identity=identity, return self.device.sign(blob=blob)
challenge_hidden=blob,
challenge_visual='',
ecdsa_curve_name=self.curve)
except self.call_exception as e:
code, msg = e.args
log.warning('%s error #%s: %s', self.device_name, code, msg)
raise IOError(msg) # close current connection, keep server open
verifying_key = formats.decompress_pubkey(pubkey=result.public_key,
curve_name=self.curve)
key_type, blob = formats.serialize_verifying_key(verifying_key)
assert blob == msg['public_key']['blob']
assert key_type == msg['key_type']
assert len(result.signature) == 65
assert result.signature[:1] == bytearray([0])
return result.signature[1:]
def _parse_ssh_blob(data): def _parse_ssh_blob(data):

View File

@@ -76,7 +76,6 @@ class Device(object):
def __init__(self, identity_str, curve_name): def __init__(self, identity_str, curve_name):
"""Configure for specific identity and elliptic curve usage.""" """Configure for specific identity and elliptic curve usage."""
self.identity_dict = string_to_identity(identity_str) self.identity_dict = string_to_identity(identity_str)
assert curve_name in formats.SUPPORTED_CURVES
self.curve_name = curve_name self.curve_name = curve_name
self.conn = None self.conn = None

View File

@@ -7,7 +7,6 @@ for more details.
The server's source code can be found here: The server's source code can be found here:
https://github.com/openssh/openssh-portable/blob/master/authfd.c https://github.com/openssh/openssh-portable/blob/master/authfd.c
""" """
import binascii
import io import io
import logging import logging
@@ -138,13 +137,13 @@ class Handler(object):
else: else:
raise KeyError('key not found') raise KeyError('key not found')
log.debug('signing %d-byte blob', len(blob))
label = key['name'].decode('ascii') # label should be a string label = key['name'].decode('ascii') # label should be a string
log.debug('signing %d-byte blob with "%s" key', len(blob), label)
try: try:
signature = self.signer(label=label, blob=blob) signature = self.signer(blob=blob)
except IOError: except IOError:
return failure() return failure()
log.debug('signature: %s', binascii.hexlify(signature)) log.debug('signature: %r', signature)
try: try:
sig_bytes = key['verifier'](sig=signature, msg=blob) sig_bytes = key['verifier'](sig=signature, msg=blob)

View File

@@ -3,7 +3,7 @@ import io
import mock import mock
import pytest import pytest
from .. import client, factory, formats, util from .. import client, device, formats, util
ADDR = [2147483661, 2810943954, 3938368396, 3454558782, 3848009040] ADDR = [2147483661, 2810943954, 3938368396, 3454558782, 3848009040]
CURVE = 'nist256p1' CURVE = 'nist256p1'
@@ -15,29 +15,23 @@ PUBKEY_TEXT = ('ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzd'
'VUfhvrGljR2Z/CMRONY6ejB+9PnpUOPuzYqi8= ssh://localhost:22\n') 'VUfhvrGljR2Z/CMRONY6ejB+9PnpUOPuzYqi8= ssh://localhost:22\n')
class FakeConnection(object): class MockDevice(device.interface.Device): # pylint: disable=abstract-method
def __init__(self): def connect(self): # pylint: disable=no-self-use
self.closed = False return mock.Mock()
def close(self): def close(self):
self.closed = True self.conn = None
def clear_session(self): def pubkey(self, ecdh=False): # pylint: disable=unused-argument
self.closed = True assert self.conn
return PUBKEY
def get_public_node(self, n, ecdsa_curve_name=b'secp256k1'): def sign(self, blob):
assert not self.closed """Sign given blob and return the signature (as bytes)."""
assert n == ADDR assert self.conn
assert ecdsa_curve_name in {'secp256k1', 'nist256p1'} assert blob == BLOB
result = mock.Mock(spec=[]) return SIG
result.node = mock.Mock(spec=[])
result.node.public_key = PUBKEY
return result
def ping(self, msg):
assert not self.closed
return msg
def identity_type(**kwargs): def identity_type(**kwargs):
@@ -50,13 +44,6 @@ def identity_type(**kwargs):
return result return result
def load_client():
return factory.ClientWrapper(connection=FakeConnection(),
identity_type=identity_type,
device_name='DEVICE_NAME',
call_exception=Exception)
BLOB = (b'\x00\x00\x00 \xce\xe0\xc9\xd5\xceu/\xe8\xc5\xf2\xbfR+x\xa1\xcf\xb0' BLOB = (b'\x00\x00\x00 \xce\xe0\xc9\xd5\xceu/\xe8\xc5\xf2\xbfR+x\xa1\xcf\xb0'
b'\x8e;R\xd3)m\x96\x1b\xb4\xd8s\xf1\x99\x16\xaa2\x00\x00\x00\x05roman' b'\x8e;R\xd3)m\x96\x1b\xb4\xd8s\xf1\x99\x16\xaa2\x00\x00\x00\x05roman'
b'\x00\x00\x00\x0essh-connection\x00\x00\x00\tpublickey' b'\x00\x00\x00\x0essh-connection\x00\x00\x00\tpublickey'
@@ -66,71 +53,33 @@ BLOB = (b'\x00\x00\x00 \xce\xe0\xc9\xd5\xceu/\xe8\xc5\xf2\xbfR+x\xa1\xcf\xb0'
b'\xdd\xbc+\xfar~\x9dAis4\xc1\x10yeT~\x1b\xeb\x1aX\xd1\xd9\x9f\xc21' b'\xdd\xbc+\xfar~\x9dAis4\xc1\x10yeT~\x1b\xeb\x1aX\xd1\xd9\x9f\xc21'
b'\x13\x8dc\xa7\xa3\x07\xefO\x9e\x95\x0e>\xec\xd8\xaa/') b'\x13\x8dc\xa7\xa3\x07\xefO\x9e\x95\x0e>\xec\xd8\xaa/')
SIG = (b'\x00R\x19T\xf2\x84$\xef#\x0e\xee\x04X\xc6\xc3\x99T`\xd1\xd8\xf7!' SIG = (b'R\x19T\xf2\x84$\xef#\x0e\xee\x04X\xc6\xc3\x99T`\xd1\xd8\xf7!'
b'\x862@cx\xb8\xb9i@1\x1b3#\x938\x86]\x97*Y\xb2\x02Xa\xdf@\xecK' b'\x862@cx\xb8\xb9i@1\x1b3#\x938\x86]\x97*Y\xb2\x02Xa\xdf@\xecK'
b'\xdc\xf0H\xab\xa8\xac\xa7? \x8f=C\x88N\xe2') b'\xdc\xf0H\xab\xa8\xac\xa7? \x8f=C\x88N\xe2')
def test_ssh_agent(): def test_ssh_agent():
label = 'localhost:22' identity_str = 'localhost:22'
c = client.Client(loader=load_client) c = client.Client(device=MockDevice(identity_str=identity_str,
ident = c.get_identity(label=label) curve_name=CURVE))
assert ident.host == 'localhost' assert c.get_public_key() == PUBKEY_TEXT
assert ident.proto == 'ssh' signature = c.sign_ssh_challenge(blob=BLOB)
assert ident.port == '22'
assert ident.user is None
assert ident.path is None
assert ident.index == 0
with c: key = formats.import_public_key(PUBKEY_TEXT)
assert c.get_public_key(label) == PUBKEY_TEXT serialized_sig = key['verifier'](sig=signature, msg=BLOB)
def ssh_sign_identity(identity, challenge_hidden, stream = io.BytesIO(serialized_sig)
challenge_visual, ecdsa_curve_name): r = util.read_frame(stream)
assert (util.identity_to_string(identity) == s = util.read_frame(stream)
util.identity_to_string(ident)) assert not stream.read()
assert challenge_hidden == BLOB assert r[:1] == b'\x00'
assert challenge_visual == '' assert s[:1] == b'\x00'
assert ecdsa_curve_name == 'nist256p1' assert r[1:] + s[1:] == SIG
result = mock.Mock(spec=[]) # pylint: disable=unused-argument
result.public_key = PUBKEY def cancel_sign(blob):
result.signature = SIG raise IOError(42, 'ERROR')
return result
c.client.sign_identity = ssh_sign_identity c.device.sign = cancel_sign
signature = c.sign_ssh_challenge(label=label, blob=BLOB) with pytest.raises(IOError):
c.sign_ssh_challenge(blob=BLOB)
key = formats.import_public_key(PUBKEY_TEXT)
serialized_sig = key['verifier'](sig=signature, msg=BLOB)
stream = io.BytesIO(serialized_sig)
r = util.read_frame(stream)
s = util.read_frame(stream)
assert not stream.read()
assert r[:1] == b'\x00'
assert s[:1] == b'\x00'
assert r[1:] + s[1:] == SIG[1:]
c.client.call_exception = ValueError
# pylint: disable=unused-argument
def cancel_sign_identity(identity, challenge_hidden,
challenge_visual, ecdsa_curve_name):
raise c.client.call_exception(42, 'ERROR')
c.client.sign_identity = cancel_sign_identity
with pytest.raises(IOError):
c.sign_ssh_challenge(label=label, blob=BLOB)
def test_utils():
identity = mock.Mock(spec=[])
identity.proto = 'https'
identity.user = 'user'
identity.host = 'host'
identity.port = '443'
identity.path = '/path'
url = 'https://user@host:443/path'
assert util.identity_to_string(identity) == url

View File

@@ -28,8 +28,7 @@ def test_unsupported():
assert reply == b'\x00\x00\x00\x01\x05' assert reply == b'\x00\x00\x00\x01\x05'
def ecdsa_signer(label, blob): def ecdsa_signer(blob):
assert label == 'ssh://localhost'
assert blob == NIST256_BLOB assert blob == NIST256_BLOB
return NIST256_SIG return NIST256_SIG
@@ -49,8 +48,7 @@ def test_sign_missing():
def test_sign_wrong(): def test_sign_wrong():
def wrong_signature(label, blob): def wrong_signature(blob):
assert label == 'ssh://localhost'
assert blob == NIST256_BLOB assert blob == NIST256_BLOB
return b'\x00' * 64 return b'\x00' * 64
@@ -62,7 +60,7 @@ def test_sign_wrong():
def test_sign_cancel(): def test_sign_cancel():
def cancel_signature(label, blob): # pylint: disable=unused-argument def cancel_signature(blob): # pylint: disable=unused-argument
raise IOError() raise IOError()
key = formats.import_public_key(NIST256_KEY) key = formats.import_public_key(NIST256_KEY)
@@ -79,8 +77,7 @@ ED25519_BLOB = b'''\x00\x00\x00 i3\xae}yk\\\xa1L\xb9\xe1\xbf\xbc\x8e\x87\r\x0e\x
ED25519_SIG = b'''\x8eb)\xa6\xe9P\x83VE\xfbq\xc6\xbf\x1dV3\xe3<O\x11\xc0\xfa\xe4\xed\xb8\x81.\x81\xc8\xa6\xba\x10RA'a\xbc\xa9\xd3\xdb\x98\x07\xf0\x1a\x9c4\x84<\xaf\x99\xb7\xe5G\xeb\xf7$\xc1\r\x86f\x16\x8e\x08\x05''' # nopep8 ED25519_SIG = b'''\x8eb)\xa6\xe9P\x83VE\xfbq\xc6\xbf\x1dV3\xe3<O\x11\xc0\xfa\xe4\xed\xb8\x81.\x81\xc8\xa6\xba\x10RA'a\xbc\xa9\xd3\xdb\x98\x07\xf0\x1a\x9c4\x84<\xaf\x99\xb7\xe5G\xeb\xf7$\xc1\r\x86f\x16\x8e\x08\x05''' # nopep8
def ed25519_signer(label, blob): def ed25519_signer(blob):
assert label == 'ssh://localhost'
assert blob == ED25519_BLOB assert blob == ED25519_BLOB
return ED25519_SIG return ED25519_SIG

View File

@@ -1,10 +1,8 @@
"""Various I/O and serialization utilities.""" """Various I/O and serialization utilities."""
import binascii import binascii
import contextlib import contextlib
import hashlib
import io import io
import logging import logging
import re
import struct import struct
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@@ -180,55 +178,6 @@ class Reader(object):
self._captured = None self._captured = None
_identity_regexp = re.compile(''.join([
'^'
r'(?:(?P<proto>.*)://)?',
r'(?:(?P<user>.*)@)?',
r'(?P<host>.*?)',
r'(?::(?P<port>\w*))?',
r'(?P<path>/.*)?',
'$'
]))
def string_to_identity(s, identity_type):
"""Parse string into Identity protobuf."""
m = _identity_regexp.match(s)
result = m.groupdict()
log.debug('parsed identity: %s', result)
kwargs = {k: v for k, v in result.items() if v}
return identity_type(**kwargs)
def identity_to_string(identity):
"""Dump Identity protobuf into its string representation."""
result = []
if identity.proto:
result.append(identity.proto + '://')
if identity.user:
result.append(identity.user + '@')
result.append(identity.host)
if identity.port:
result.append(':' + identity.port)
if identity.path:
result.append(identity.path)
return ''.join(result)
def get_bip32_address(identity, ecdh=False):
"""Compute BIP32 derivation address according to SLIP-0013/0017."""
index = struct.pack('<L', identity.index)
addr = index + identity_to_string(identity).encode('ascii')
log.debug('address string: %r', addr)
digest = hashlib.sha256(addr).digest()
s = io.BytesIO(bytearray(digest))
hardened = 0x80000000
addr_0 = [13, 17][bool(ecdh)]
address_n = [addr_0] + list(recv(s, '<LLLL'))
return [(hardened | value) for value in address_n]
def setup_logging(verbosity, **kwargs): def setup_logging(verbosity, **kwargs):
"""Configure logging for this tool.""" """Configure logging for this tool."""
fmt = ('%(asctime)s %(levelname)-12s %(message)-100s ' fmt = ('%(asctime)s %(levelname)-12s %(message)-100s '