From ee2f6b75dce52ad0969f8bac8d34d837a1b97d9b Mon Sep 17 00:00:00 2001 From: Roman Zeyde Date: Sun, 23 Oct 2016 17:05:20 +0300 Subject: [PATCH 1/2] server: log SSH version for debugging --- trezor_agent/server.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/trezor_agent/server.py b/trezor_agent/server.py index d4059ac..3f47a42 100644 --- a/trezor_agent/server.py +++ b/trezor_agent/server.py @@ -116,6 +116,9 @@ def serve(handler, sock_path=None, timeout=UNIX_SOCKET_TIMEOUT): If no connection is made during the specified timeout, retry until the context is over. """ + ssh_version = subprocess.check_output(['ssh', '-V'], + stderr=subprocess.STDOUT) + log.debug('local SSH version: %r', ssh_version) if sock_path is None: sock_path = tempfile.mktemp(prefix='ssh-agent-') From 97efdf4a45d503ec776d89c427f1b4a1522f2aa4 Mon Sep 17 00:00:00 2001 From: Roman Zeyde Date: Sun, 23 Oct 2016 17:35:12 +0300 Subject: [PATCH 2/2] ssh: handle connections concurrently --- trezor_agent/server.py | 25 +++++++++++++++++-------- trezor_agent/tests/test_server.py | 12 +++++++----- 2 files changed, 24 insertions(+), 13 deletions(-) diff --git a/trezor_agent/server.py b/trezor_agent/server.py index 3f47a42..6f1ce88 100644 --- a/trezor_agent/server.py +++ b/trezor_agent/server.py @@ -43,19 +43,24 @@ def unix_domain_socket_server(sock_path): remove_file(sock_path) -def handle_connection(conn, handler): +def handle_connection(conn, handler, mutex): """ Handle a single connection using the specified protocol handler in a loop. + Since this function may be called concurrently from server_thread, + the specified mutex is used to synchronize the device handling. + Exit when EOFError is raised. All other exceptions are logged as warnings. """ try: log.debug('welcome agent') - while True: - msg = util.read_frame(conn) - reply = handler.handle(msg=msg) - util.send(conn, reply) + with contextlib.closing(conn): + while True: + msg = util.read_frame(conn) + with mutex: + reply = handler.handle(msg=msg) + util.send(conn, reply) except EOFError: log.debug('goodbye agent') except Exception as e: # pylint: disable=broad-except @@ -94,8 +99,9 @@ def server_thread(sock, handle_conn, quit_event): except StopIteration: log.debug('server stopped') break - with contextlib.closing(conn): - handle_conn(conn) + # Handle connections from SSH concurrently. + threading.Thread(target=handle_conn, + kwargs=dict(conn=conn)).start() log.debug('server thread stopped') @@ -123,10 +129,13 @@ def serve(handler, sock_path=None, timeout=UNIX_SOCKET_TIMEOUT): sock_path = tempfile.mktemp(prefix='ssh-agent-') environ = {'SSH_AUTH_SOCK': sock_path, 'SSH_AGENT_PID': str(os.getpid())} + device_mutex = threading.Lock() with unix_domain_socket_server(sock_path) as sock: sock.settimeout(timeout) quit_event = threading.Event() - handle_conn = functools.partial(handle_connection, handler=handler) + handle_conn = functools.partial(handle_connection, + handler=handler, + mutex=device_mutex) kwargs = dict(sock=sock, handle_conn=handle_conn, quit_event=quit_event) diff --git a/trezor_agent/tests/test_server.py b/trezor_agent/tests/test_server.py index af7f27d..0b4a510 100644 --- a/trezor_agent/tests/test_server.py +++ b/trezor_agent/tests/test_server.py @@ -38,30 +38,32 @@ class FakeSocket(object): def test_handle(): + mutex = threading.Lock() + handler = protocol.Handler(keys=[], signer=None) conn = FakeSocket() - server.handle_connection(conn, handler) + server.handle_connection(conn, handler, mutex) msg = bytearray([protocol.msg_code('SSH_AGENTC_REQUEST_RSA_IDENTITIES')]) conn = FakeSocket(util.frame(msg)) - server.handle_connection(conn, handler) + server.handle_connection(conn, handler, mutex) assert conn.tx.getvalue() == b'\x00\x00\x00\x05\x02\x00\x00\x00\x00' msg = bytearray([protocol.msg_code('SSH2_AGENTC_REQUEST_IDENTITIES')]) conn = FakeSocket(util.frame(msg)) - server.handle_connection(conn, handler) + server.handle_connection(conn, handler, mutex) assert conn.tx.getvalue() == b'\x00\x00\x00\x05\x0C\x00\x00\x00\x00' msg = bytearray([protocol.msg_code('SSH2_AGENTC_ADD_IDENTITY')]) conn = FakeSocket(util.frame(msg)) - server.handle_connection(conn, handler) + server.handle_connection(conn, handler, mutex) conn.tx.seek(0) reply = util.read_frame(conn.tx) assert reply == util.pack('B', protocol.msg_code('SSH_AGENT_FAILURE')) conn_mock = mock.Mock(spec=FakeSocket) conn_mock.recv.side_effect = [Exception, EOFError] - server.handle_connection(conn=conn_mock, handler=None) + server.handle_connection(conn=conn_mock, handler=None, mutex=mutex) def test_server_thread():