diff --git a/setup.py b/setup.py index 1c94e6e..97c7b9f 100644 --- a/setup.py +++ b/setup.py @@ -31,7 +31,6 @@ setup( ], entry_points={'console_scripts': [ 'trezor-agent = trezor_agent.__main__:run_agent', - 'trezor-git = trezor_agent.__main__:run_git', 'trezor-gpg-create = trezor_agent.gpg.__main__:main_create', 'trezor-gpg-agent = trezor_agent.gpg.__main__:main_agent', 'trezor-gpg-unlock = trezor_agent.gpg.__main__:auto_unlock', diff --git a/trezor_agent/__main__.py b/trezor_agent/__main__.py index 0a291c3..3fa81c3 100644 --- a/trezor_agent/__main__.py +++ b/trezor_agent/__main__.py @@ -93,13 +93,11 @@ def git_host(remote_name, attributes): return '{user}@{host}'.format(**match.groupdict()) -def run_server(conn, public_key, command, debug, timeout): +def run_server(conn, public_keys, command, debug, timeout): """Common code for run_agent and run_git below.""" try: signer = conn.sign_ssh_challenge - public_key = formats.import_public_key(public_key) - log.info('using SSH public key: %s', public_key['fingerprint']) - handler = protocol.Handler(keys=[public_key], signer=signer, + handler = protocol.Handler(keys=public_keys, signer=signer, debug=debug) with server.serve(handler=handler, timeout=timeout) as env: return server.run_process(command=command, environ=env) @@ -119,18 +117,33 @@ def handle_connection_error(func): return wrapper +def parse_config(fname): + """Parse config file into a list of Identity objects.""" + contents = open(fname).read() + for identity_str, curve_name in re.findall(r'\<(.*?)\|(.*?)\>', contents): + yield device.interface.Identity(identity_str=identity_str, + curve_name=curve_name) + + @handle_connection_error def run_agent(client_factory=client.Client): """Run ssh-agent using given hardware client factory.""" args = create_agent_parser().parse_args() util.setup_logging(verbosity=args.verbose) - d = device.detect(identity_str=args.identity, - curve_name=args.ecdsa_curve_name) - conn = client_factory(device=d) + conn = client_factory(device=device.detect()) + if args.identity.startswith('/'): + identities = list(parse_config(fname=args.identity)) + else: + identities = [device.interface.Identity( + identity_str=args.identity, curve_name=args.ecdsa_curve_name)] + for index, identity in enumerate(identities): + identity.identity_dict['proto'] = 'ssh' + log.info('identity #%d: %s', index, identity) command = args.command - public_key = conn.get_public_key() + + public_keys = [conn.get_public_key(i) for i in identities] if args.connect: command = ssh_args(args.identity) + args.command @@ -142,36 +155,12 @@ def run_agent(client_factory=client.Client): log.debug('using shell: %r', command) if not command: - sys.stdout.write(public_key) + for pk in public_keys: + sys.stdout.write(pk) return - return run_server(conn=conn, public_key=public_key, command=command, + public_keys = [formats.import_public_key(pk) for pk in public_keys] + for pk, identity in zip(public_keys, identities): + pk['identity'] = identity + return run_server(conn=conn, public_keys=public_keys, command=command, debug=args.debug, timeout=args.timeout) - - -@handle_connection_error -def run_git(client_factory=client.Client): - """Run git under ssh-agent using given hardware client factory.""" - args = create_git_parser().parse_args() - util.setup_logging(verbosity=args.verbose) - - with client_factory(curve=args.ecdsa_curve_name) as conn: - label = git_host(args.remote, ['pushurl', 'url']) - if not label: - log.error('Could not find "%s" SSH remote in .git/config', - args.remote) - return - - public_key = conn.get_public_key(label=label) - - if not args.test: - if args.command: - command = ['git'] + args.command - else: - sys.stdout.write(public_key) - return - else: - command = ['ssh', '-T', label] - - return run_server(conn=conn, public_key=public_key, command=command, - debug=args.debug, timeout=args.timeout) diff --git a/trezor_agent/client.py b/trezor_agent/client.py index 2e5b074..30dfb40 100644 --- a/trezor_agent/client.py +++ b/trezor_agent/client.py @@ -16,34 +16,34 @@ class Client(object): def __init__(self, device): """Connect to hardware device.""" - device.identity_dict['proto'] = 'ssh' self.device = device - def get_public_key(self): + def get_public_key(self, identity): """Get SSH public key from the device.""" with self.device: - pubkey = self.device.pubkey() + pubkey = self.device.pubkey(identity) vk = formats.decompress_pubkey(pubkey=pubkey, - curve_name=self.device.curve_name) + curve_name=identity.curve_name) return formats.export_public_key(vk=vk, - label=self.device.identity_str()) + label=str(identity)) - def sign_ssh_challenge(self, blob): + def sign_ssh_challenge(self, blob, identity): """Sign given blob using a private key on the device.""" msg = _parse_ssh_blob(blob) log.debug('%s: user %r via %r (%r)', msg['conn'], msg['user'], msg['auth'], msg['key_type']) log.debug('nonce: %r', msg['nonce']) - log.debug('fingerprint: %s', msg['public_key']['fingerprint']) + fp = msg['public_key']['fingerprint'] + log.debug('fingerprint: %s', fp) log.debug('hidden challenge size: %d bytes', len(blob)) log.info('please confirm user "%s" login to "%s" using %s...', - msg['user'].decode('ascii'), self.device.identity_str(), + msg['user'].decode('ascii'), identity, self.device) with self.device: - return self.device.sign(blob=blob) + return self.device.sign(blob=blob, identity=identity) def _parse_ssh_blob(data): diff --git a/trezor_agent/device/__init__.py b/trezor_agent/device/__init__.py index 65e915c..613c1cf 100644 --- a/trezor_agent/device/__init__.py +++ b/trezor_agent/device/__init__.py @@ -16,13 +16,12 @@ DEVICE_TYPES = [ ] -def detect(identity_str, curve_name): +def detect(): """Detect the first available device and return it to the user.""" for device_type in DEVICE_TYPES: try: - with device_type(identity_str, curve_name) as d: + with device_type() as d: return d except interface.NotFoundError as e: log.debug('device not found: %s', e) - raise IOError('No device found: "{}" ({})'.format(identity_str, - curve_name)) + raise IOError('No device found!') diff --git a/trezor_agent/device/interface.py b/trezor_agent/device/interface.py index 5baa21e..32bd44e 100644 --- a/trezor_agent/device/interface.py +++ b/trezor_agent/device/interface.py @@ -45,20 +45,6 @@ def identity_to_string(identity_dict): return ''.join(result) -def get_bip32_address(identity_dict, ecdh=False): - """Compute BIP32 derivation address according to SLIP-0013/0017.""" - index = struct.pack(''.format(identity_to_string(self.identity_dict), self.curve_name) + + def get_bip32_address(self, ecdh=False): + """Compute BIP32 derivation address according to SLIP-0013/0017.""" + index = struct.pack('\n') class MockDevice(device.interface.Device): # pylint: disable=abstract-method @@ -23,11 +23,11 @@ class MockDevice(device.interface.Device): # pylint: disable=abstract-method def close(self): self.conn = None - def pubkey(self, ecdh=False): # pylint: disable=unused-argument + def pubkey(self, identity, ecdh=False): # pylint: disable=unused-argument assert self.conn return PUBKEY - def sign(self, blob): + def sign(self, identity, blob): """Sign given blob and return the signature (as bytes).""" assert self.conn assert blob == BLOB @@ -59,11 +59,11 @@ SIG = (b'R\x19T\xf2\x84$\xef#\x0e\xee\x04X\xc6\xc3\x99T`\xd1\xd8\xf7!' def test_ssh_agent(): - identity_str = 'localhost:22' - c = client.Client(device=MockDevice(identity_str=identity_str, - curve_name=CURVE)) - assert c.get_public_key() == PUBKEY_TEXT - signature = c.sign_ssh_challenge(blob=BLOB) + identity = device.interface.Identity(identity_str='localhost:22', + curve_name=CURVE) + c = client.Client(device=MockDevice()) + assert c.get_public_key(identity) == PUBKEY_TEXT + signature = c.sign_ssh_challenge(blob=BLOB, identity=identity) key = formats.import_public_key(PUBKEY_TEXT) serialized_sig = key['verifier'](sig=signature, msg=BLOB) @@ -77,9 +77,9 @@ def test_ssh_agent(): assert r[1:] + s[1:] == SIG # pylint: disable=unused-argument - def cancel_sign(blob): + def cancel_sign(identity, blob): raise IOError(42, 'ERROR') c.device.sign = cancel_sign with pytest.raises(IOError): - c.sign_ssh_challenge(blob=BLOB) + c.sign_ssh_challenge(blob=BLOB, identity=identity) diff --git a/trezor_agent/tests/test_protocol.py b/trezor_agent/tests/test_protocol.py index 17a1001..2d8b0dc 100644 --- a/trezor_agent/tests/test_protocol.py +++ b/trezor_agent/tests/test_protocol.py @@ -1,6 +1,6 @@ import pytest -from .. import formats, protocol +from .. import device, formats, protocol # pylint: disable=line-too-long @@ -17,6 +17,7 @@ NIST256_SIGN_REPLY = b'\x00\x00\x00j\x0e\x00\x00\x00e\x00\x00\x00\x13ecdsa-sha2- def test_list(): key = formats.import_public_key(NIST256_KEY) + key['identity'] = device.interface.Identity('ssh://localhost', 'nist256p1') h = protocol.Handler(keys=[key], signer=None) reply = h.handle(LIST_MSG) assert reply == LIST_NIST256_REPLY @@ -28,13 +29,15 @@ def test_unsupported(): assert reply == b'\x00\x00\x00\x01\x05' -def ecdsa_signer(blob): +def ecdsa_signer(identity, blob): + assert str(identity) == '' assert blob == NIST256_BLOB return NIST256_SIG def test_ecdsa_sign(): key = formats.import_public_key(NIST256_KEY) + key['identity'] = device.interface.Identity('ssh://localhost', 'nist256p1') h = protocol.Handler(keys=[key], signer=ecdsa_signer) reply = h.handle(NIST256_SIGN_MSG) assert reply == NIST256_SIGN_REPLY @@ -42,30 +45,30 @@ def test_ecdsa_sign(): def test_sign_missing(): h = protocol.Handler(keys=[], signer=ecdsa_signer) - with pytest.raises(KeyError): h.handle(NIST256_SIGN_MSG) def test_sign_wrong(): - def wrong_signature(blob): + def wrong_signature(identity, blob): + assert str(identity) == '' assert blob == NIST256_BLOB return b'\x00' * 64 key = formats.import_public_key(NIST256_KEY) + key['identity'] = device.interface.Identity('ssh://localhost', 'nist256p1') h = protocol.Handler(keys=[key], signer=wrong_signature) - with pytest.raises(ValueError): h.handle(NIST256_SIGN_MSG) def test_sign_cancel(): - def cancel_signature(blob): # pylint: disable=unused-argument + def cancel_signature(identity, blob): # pylint: disable=unused-argument raise IOError() key = formats.import_public_key(NIST256_KEY) + key['identity'] = device.interface.Identity('ssh://localhost', 'nist256p1') h = protocol.Handler(keys=[key], signer=cancel_signature) - assert h.handle(NIST256_SIGN_MSG) == protocol.failure() @@ -77,13 +80,15 @@ ED25519_BLOB = b'''\x00\x00\x00 i3\xae}yk\\\xa1L\xb9\xe1\xbf\xbc\x8e\x87\r\x0e\x ED25519_SIG = b'''\x8eb)\xa6\xe9P\x83VE\xfbq\xc6\xbf\x1dV3\xe3' assert blob == ED25519_BLOB return ED25519_SIG def test_ed25519_sign(): key = formats.import_public_key(ED25519_KEY) + key['identity'] = device.interface.Identity('ssh://localhost', 'ed25519') h = protocol.Handler(keys=[key], signer=ed25519_signer) reply = h.handle(ED25519_SIGN_MSG) assert reply == ED25519_SIGN_REPLY