diff --git a/trezor_agent/__main__.py b/trezor_agent/__main__.py index c94ff69..92c68be 100644 --- a/trezor_agent/__main__.py +++ b/trezor_agent/__main__.py @@ -110,12 +110,10 @@ def git_host(remote_name, attributes): return '{user}@{host}'.format(**match.groupdict()) -def run_server(conn, public_keys, command, debug, timeout): +def run_server(conn, command, debug, timeout): """Common code for run_agent and run_git below.""" try: - signer = conn.sign_ssh_challenge - handler = protocol.Handler(keys=public_keys, signer=signer, - debug=debug) + handler = protocol.Handler(conn=conn, debug=debug) with server.serve(handler=handler, timeout=timeout) as env: return server.run_process(command=command, environ=env) except KeyboardInterrupt: @@ -142,13 +140,39 @@ def parse_config(fname): curve_name=curve_name) +class JustInTimeConnection(object): + """Connect to the device just before the needed operation.""" + + def __init__(self, conn_factory, identities): + """Create a JIT connection object.""" + self.conn_factory = conn_factory + self.identities = identities + + def public_keys(self): + """Return a list of SSH public keys (in textual format).""" + conn = self.conn_factory() + return [conn.get_public_key(i) for i in self.identities] + + def parse_public_keys(self): + """Parse SSH public keys into dictionaries.""" + public_keys = [formats.import_public_key(pk) + for pk in self.public_keys()] + for pk, identity in zip(public_keys, self.identities): + pk['identity'] = identity + return public_keys + + def sign(self, blob, identity): + """Sign a given blob using the specified identity on the device.""" + conn = self.conn_factory() + return conn.sign_ssh_challenge(blob=blob, identity=identity) + + @handle_connection_error def run_agent(client_factory=client.Client): """Run ssh-agent using given hardware client factory.""" args = create_agent_parser().parse_args() util.setup_logging(verbosity=args.verbose) - conn = client_factory(device=device.detect()) if args.identity.startswith('/'): identities = list(parse_config(fname=args.identity)) else: @@ -158,8 +182,6 @@ def run_agent(client_factory=client.Client): identity.identity_dict['proto'] = 'ssh' log.info('identity #%d: %s', index, identity) - public_keys = [conn.get_public_key(i) for i in identities] - if args.connect: command = ['ssh'] + ssh_args(args.identity) + args.command elif args.mosh: @@ -171,13 +193,12 @@ def run_agent(client_factory=client.Client): if use_shell: command = os.environ['SHELL'] - if not command: - for pk in public_keys: + conn = JustInTimeConnection( + conn_factory=lambda: client_factory(device.detect()), + identities=identities) + if command: + return run_server(conn=conn, command=command, debug=args.debug, + timeout=args.timeout) + else: + for pk in conn.public_keys(): sys.stdout.write(pk) - return - - public_keys = [formats.import_public_key(pk) for pk in public_keys] - for pk, identity in zip(public_keys, identities): - pk['identity'] = identity - return run_server(conn=conn, public_keys=public_keys, command=command, - debug=args.debug, timeout=args.timeout) diff --git a/trezor_agent/protocol.py b/trezor_agent/protocol.py index 987bece..14b6ae1 100644 --- a/trezor_agent/protocol.py +++ b/trezor_agent/protocol.py @@ -71,14 +71,13 @@ def _legacy_pubs(buf): class Handler(object): """ssh-agent protocol handler.""" - def __init__(self, keys, signer, debug=False): + def __init__(self, conn, debug=False): """ Create a protocol handler with specified public keys. Use specified signer function to sign SSH authentication requests. """ - self.public_keys = keys - self.signer = signer + self.conn = conn self.debug = debug self.methods = { @@ -107,7 +106,7 @@ class Handler(object): def list_pubs(self, buf): """SSH v2 public keys are serialized and returned.""" assert not buf.read() - keys = self.public_keys + keys = self.conn.parse_public_keys() code = util.pack('B', msg_code('SSH2_AGENT_IDENTITIES_ANSWER')) num = util.pack('L', len(keys)) log.debug('available keys: %s', [k['name'] for k in keys]) @@ -129,7 +128,7 @@ class Handler(object): assert util.read_frame(buf) == b'' assert not buf.read() - for k in self.public_keys: + for k in self.conn.parse_public_keys(): if (k['fingerprint']) == (key['fingerprint']): log.debug('using key %r (%s)', k['name'], k['fingerprint']) key = k @@ -140,7 +139,7 @@ class Handler(object): label = key['name'].decode('ascii') # label should be a string log.debug('signing %d-byte blob with "%s" key', len(blob), label) try: - signature = self.signer(blob=blob, identity=key['identity']) + signature = self.conn.sign(blob=blob, identity=key['identity']) except IOError: return failure() log.debug('signature: %r', signature) diff --git a/trezor_agent/tests/test_protocol.py b/trezor_agent/tests/test_protocol.py index 2d8b0dc..1afc086 100644 --- a/trezor_agent/tests/test_protocol.py +++ b/trezor_agent/tests/test_protocol.py @@ -1,3 +1,4 @@ +import mock import pytest from .. import device, formats, protocol @@ -15,16 +16,23 @@ NIST256_SIGN_MSG = b'\r\x00\x00\x00h\x00\x00\x00\x13ecdsa-sha2-nistp256\x00\x00\ NIST256_SIGN_REPLY = b'\x00\x00\x00j\x0e\x00\x00\x00e\x00\x00\x00\x13ecdsa-sha2-nistp256\x00\x00\x00J\x00\x00\x00!\x00\x88G!\x0c\n\x16:\xbeF\xbe\xb9\xd2\xa9&e\x89\xad\xc4}\x10\xf8\xbc\xdc\xef\x0e\x8d_\x8a6.\xb6\x1f\x00\x00\x00!\x00q\xf0\x16>,\x9a\xde\xe7(\xd6\xd7\x93\x1f\xed\xf9\x94ddw\xfe\xbdq\x13\xbb\xfc\xa9K\xea\x9dC\xa1\xe9' # nopep8 +def fake_connection(keys, signer): + c = mock.Mock() + c.parse_public_keys.return_value = keys + c.sign = signer + return c + + def test_list(): key = formats.import_public_key(NIST256_KEY) key['identity'] = device.interface.Identity('ssh://localhost', 'nist256p1') - h = protocol.Handler(keys=[key], signer=None) + h = protocol.Handler(fake_connection(keys=[key], signer=None)) reply = h.handle(LIST_MSG) assert reply == LIST_NIST256_REPLY def test_unsupported(): - h = protocol.Handler(keys=[], signer=None) + h = protocol.Handler(fake_connection(keys=[], signer=None)) reply = h.handle(b'\x09') assert reply == b'\x00\x00\x00\x01\x05' @@ -38,13 +46,13 @@ def ecdsa_signer(identity, blob): def test_ecdsa_sign(): key = formats.import_public_key(NIST256_KEY) key['identity'] = device.interface.Identity('ssh://localhost', 'nist256p1') - h = protocol.Handler(keys=[key], signer=ecdsa_signer) + h = protocol.Handler(fake_connection(keys=[key], signer=ecdsa_signer)) reply = h.handle(NIST256_SIGN_MSG) assert reply == NIST256_SIGN_REPLY def test_sign_missing(): - h = protocol.Handler(keys=[], signer=ecdsa_signer) + h = protocol.Handler(fake_connection(keys=[], signer=ecdsa_signer)) with pytest.raises(KeyError): h.handle(NIST256_SIGN_MSG) @@ -57,7 +65,7 @@ def test_sign_wrong(): key = formats.import_public_key(NIST256_KEY) key['identity'] = device.interface.Identity('ssh://localhost', 'nist256p1') - h = protocol.Handler(keys=[key], signer=wrong_signature) + h = protocol.Handler(fake_connection(keys=[key], signer=wrong_signature)) with pytest.raises(ValueError): h.handle(NIST256_SIGN_MSG) @@ -68,7 +76,7 @@ def test_sign_cancel(): key = formats.import_public_key(NIST256_KEY) key['identity'] = device.interface.Identity('ssh://localhost', 'nist256p1') - h = protocol.Handler(keys=[key], signer=cancel_signature) + h = protocol.Handler(fake_connection(keys=[key], signer=cancel_signature)) assert h.handle(NIST256_SIGN_MSG) == protocol.failure() @@ -89,6 +97,6 @@ def ed25519_signer(identity, blob): def test_ed25519_sign(): key = formats.import_public_key(ED25519_KEY) key['identity'] = device.interface.Identity('ssh://localhost', 'ed25519') - h = protocol.Handler(keys=[key], signer=ed25519_signer) + h = protocol.Handler(fake_connection(keys=[key], signer=ed25519_signer)) reply = h.handle(ED25519_SIGN_MSG) assert reply == ED25519_SIGN_REPLY diff --git a/trezor_agent/tests/test_server.py b/trezor_agent/tests/test_server.py index 0b4a510..c680470 100644 --- a/trezor_agent/tests/test_server.py +++ b/trezor_agent/tests/test_server.py @@ -37,10 +37,16 @@ class FakeSocket(object): pass +def empty_device(): + c = mock.Mock(spec=['parse_public_keys']) + c.parse_public_keys.return_value = [] + return c + + def test_handle(): mutex = threading.Lock() - handler = protocol.Handler(keys=[], signer=None) + handler = protocol.Handler(conn=empty_device()) conn = FakeSocket() server.handle_connection(conn, handler, mutex) @@ -67,7 +73,6 @@ def test_handle(): def test_server_thread(): - connections = [FakeSocket()] quit_event = threading.Event() @@ -81,8 +86,10 @@ def test_server_thread(): def getsockname(self): # pylint: disable=no-self-use return 'fake_server' - handler = protocol.Handler(keys=[], signer=None), - handle_conn = functools.partial(server.handle_connection, handler=handler) + handler = protocol.Handler(conn=empty_device()), + handle_conn = functools.partial(server.handle_connection, + handler=handler, + mutex=None) server.server_thread(sock=FakeServer(), handle_conn=handle_conn, quit_event=quit_event) @@ -111,7 +118,7 @@ def test_run(): def test_serve_main(): - handler = protocol.Handler(keys=[], signer=None) + handler = protocol.Handler(conn=empty_device()) with server.serve(handler=handler, sock_path=None): pass