diff --git a/sshagent/server.py b/sshagent/server.py index 488173d..1e741f5 100644 --- a/sshagent/server.py +++ b/sshagent/server.py @@ -68,7 +68,23 @@ def spawn(func, **kwargs): t.join() -def run(command, environ): +@contextlib.contextmanager +def serve(key_files, signer, sock_path=None): + if sock_path is None: + sock_path = tempfile.mktemp(prefix='ssh-agent-') + + keys = [formats.parse_public_key(k) for k in key_files] + environ = {'SSH_AUTH_SOCK': sock_path, 'SSH_AGENT_PID': str(os.getpid())} + with unix_domain_socket_server(sock_path) as server: + with spawn(server_thread, server=server, keys=keys, signer=signer): + try: + yield environ + finally: + log.debug('closing server') + server.shutdown(socket.SHUT_RD) + + +def run_process(command, environ): log.debug('running %r with %r', command, environ) env = dict(os.environ) env.update(environ) @@ -80,19 +96,3 @@ def run(command, environ): ret = p.wait() log.debug('subprocess %d exited: %d', p.pid, ret) return ret - - -def serve(key_files, command, signer, sock_path=None): - if sock_path is None: - sock_path = tempfile.mktemp(prefix='ssh-agent-') - - keys = [formats.parse_public_key(k) for k in key_files] - environ = {'SSH_AUTH_SOCK': sock_path, 'SSH_AGENT_PID': str(os.getpid())} - with unix_domain_socket_server(sock_path) as server: - with spawn(server_thread, server=server, keys=keys, signer=signer): - try: - ret = run(command=command, environ=environ) - finally: - log.debug('closing server') - server.shutdown(socket.SHUT_RD) - return ret diff --git a/sshagent/trezor_agent.py b/sshagent/trezor_agent.py index 46aa82f..97d56eb 100644 --- a/sshagent/trezor_agent.py +++ b/sshagent/trezor_agent.py @@ -22,30 +22,24 @@ def main(): level = verbosity[min(args.verbose, len(verbosity) - 1)] logging.basicConfig(level=level, format=fmt) - client = trezor.Client(factory=trezor.TrezorLibrary) + with trezor.Client(factory=trezor.TrezorLibrary) as client: + key_files = [] + for label in args.labels: + pubkey = client.get_public_key(label=label) + key_file = formats.export_public_key(pubkey=pubkey, label=label) + key_files.append(key_file) - key_files = [] - for label in args.labels: - pubkey = client.get_public_key(label=label) - key_file = formats.export_public_key(pubkey=pubkey, label=label) - key_files.append(key_file) + if not args.command: + sys.stdout.write(''.join(key_files)) + return - if not args.command: - sys.stdout.write(''.join(key_files)) - return + signer = client.sign_ssh_challenge - signer = client.sign_ssh_challenge - - ret = -1 - try: - ret = server.serve( - key_files=key_files, - command=args.command, - signer=signer) - log.info('exitcode: %d', ret) - except KeyboardInterrupt: - log.info('server stopped') - sys.exit(ret) + try: + with server.serve(key_files=key_files, signer=signer) as env: + return server.run_process(command=args.command, environ=env) + except KeyboardInterrupt: + log.info('server stopped') if __name__ == '__main__': - main() + sys.exit(main())