mirror of
https://github.com/romanz/amodem.git
synced 2026-03-18 07:36:02 +08:00
Split the package into a shared library and separate per-device packages
This commit is contained in:
166
libagent/server.py
Normal file
166
libagent/server.py
Normal file
@@ -0,0 +1,166 @@
|
||||
"""UNIX-domain socket server for ssh-agent implementation."""
|
||||
import contextlib
|
||||
import functools
|
||||
import logging
|
||||
import os
|
||||
import socket
|
||||
import subprocess
|
||||
import tempfile
|
||||
import threading
|
||||
|
||||
from . import util
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
UNIX_SOCKET_TIMEOUT = 0.1
|
||||
|
||||
|
||||
def remove_file(path, remove=os.remove, exists=os.path.exists):
|
||||
"""Remove file, and raise OSError if still exists."""
|
||||
try:
|
||||
remove(path)
|
||||
except OSError:
|
||||
if exists(path):
|
||||
raise
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def unix_domain_socket_server(sock_path):
|
||||
"""
|
||||
Create UNIX-domain socket on specified path.
|
||||
|
||||
Listen on it, and delete it after the generated context is over.
|
||||
"""
|
||||
log.debug('serving on %s', sock_path)
|
||||
remove_file(sock_path)
|
||||
|
||||
server = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
server.bind(sock_path)
|
||||
server.listen(1)
|
||||
try:
|
||||
yield server
|
||||
finally:
|
||||
remove_file(sock_path)
|
||||
|
||||
|
||||
def handle_connection(conn, handler, mutex):
|
||||
"""
|
||||
Handle a single connection using the specified protocol handler in a loop.
|
||||
|
||||
Since this function may be called concurrently from server_thread,
|
||||
the specified mutex is used to synchronize the device handling.
|
||||
|
||||
Exit when EOFError is raised.
|
||||
All other exceptions are logged as warnings.
|
||||
"""
|
||||
try:
|
||||
log.debug('welcome agent')
|
||||
with contextlib.closing(conn):
|
||||
while True:
|
||||
msg = util.read_frame(conn)
|
||||
with mutex:
|
||||
reply = handler.handle(msg=msg)
|
||||
util.send(conn, reply)
|
||||
except EOFError:
|
||||
log.debug('goodbye agent')
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
log.warning('error: %s', e, exc_info=True)
|
||||
|
||||
|
||||
def retry(func, exception_type, quit_event):
|
||||
"""
|
||||
Run the function, retrying when the specified exception_type occurs.
|
||||
|
||||
Poll quit_event on each iteration, to be responsive to an external
|
||||
exit request.
|
||||
"""
|
||||
while True:
|
||||
if quit_event.is_set():
|
||||
raise StopIteration
|
||||
try:
|
||||
return func()
|
||||
except exception_type:
|
||||
pass
|
||||
|
||||
|
||||
def server_thread(sock, handle_conn, quit_event):
|
||||
"""Run a server on the specified socket."""
|
||||
log.debug('server thread started')
|
||||
|
||||
def accept_connection():
|
||||
conn, _ = sock.accept()
|
||||
conn.settimeout(None)
|
||||
return conn
|
||||
|
||||
while True:
|
||||
log.debug('waiting for connection on %s', sock.getsockname())
|
||||
try:
|
||||
conn = retry(accept_connection, socket.timeout, quit_event)
|
||||
except StopIteration:
|
||||
log.debug('server stopped')
|
||||
break
|
||||
# Handle connections from SSH concurrently.
|
||||
threading.Thread(target=handle_conn,
|
||||
kwargs=dict(conn=conn)).start()
|
||||
log.debug('server thread stopped')
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def spawn(func, kwargs):
|
||||
"""Spawn a thread, and join it after the context is over."""
|
||||
t = threading.Thread(target=func, kwargs=kwargs)
|
||||
t.start()
|
||||
yield
|
||||
t.join()
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def serve(handler, sock_path=None, timeout=UNIX_SOCKET_TIMEOUT):
|
||||
"""
|
||||
Start the ssh-agent server on a UNIX-domain socket.
|
||||
|
||||
If no connection is made during the specified timeout,
|
||||
retry until the context is over.
|
||||
"""
|
||||
ssh_version = subprocess.check_output(['ssh', '-V'],
|
||||
stderr=subprocess.STDOUT)
|
||||
log.debug('local SSH version: %r', ssh_version)
|
||||
if sock_path is None:
|
||||
sock_path = tempfile.mktemp(prefix='trezor-ssh-agent-')
|
||||
|
||||
environ = {'SSH_AUTH_SOCK': sock_path, 'SSH_AGENT_PID': str(os.getpid())}
|
||||
device_mutex = threading.Lock()
|
||||
with unix_domain_socket_server(sock_path) as sock:
|
||||
sock.settimeout(timeout)
|
||||
quit_event = threading.Event()
|
||||
handle_conn = functools.partial(handle_connection,
|
||||
handler=handler,
|
||||
mutex=device_mutex)
|
||||
kwargs = dict(sock=sock,
|
||||
handle_conn=handle_conn,
|
||||
quit_event=quit_event)
|
||||
with spawn(server_thread, kwargs):
|
||||
try:
|
||||
yield environ
|
||||
finally:
|
||||
log.debug('closing server')
|
||||
quit_event.set()
|
||||
|
||||
|
||||
def run_process(command, environ):
|
||||
"""
|
||||
Run the specified process and wait until it finishes.
|
||||
|
||||
Use environ dict for environment variables.
|
||||
"""
|
||||
log.info('running %r with %r', command, environ)
|
||||
env = dict(os.environ)
|
||||
env.update(environ)
|
||||
try:
|
||||
p = subprocess.Popen(args=command, env=env)
|
||||
except OSError as e:
|
||||
raise OSError('cannot run %r: %s' % (command, e))
|
||||
log.debug('subprocess %d is running', p.pid)
|
||||
ret = p.wait()
|
||||
log.debug('subprocess %d exited: %d', p.pid, ret)
|
||||
return ret
|
||||
Reference in New Issue
Block a user