server: stop the server via a threading.Event

It seems that Mac OS does not support calling socket.shutdown(socket.SHUT_RD)
on a listening socket (see https://github.com/romanz/trezor-agent/issues/6).
The following implementation will set the accept() timeout to 0.1s and stop
the server if a threading.Event (named "quit_event") is set by the main thread.
This commit is contained in:
Roman Zeyde
2016-01-08 20:28:38 +02:00
parent 7ea20c7009
commit fb0d0a5f61
3 changed files with 52 additions and 34 deletions

View File

@@ -1,5 +1,6 @@
import tempfile
import socket
import threading
import os
import io
import pytest
@@ -16,7 +17,7 @@ def test_socket():
assert not os.path.isfile(path)
class SocketMock(object):
class FakeSocket(object):
def __init__(self, data=b''):
self.rx = io.BytesIO(data)
@@ -34,16 +35,16 @@ class SocketMock(object):
def test_handle():
handler = protocol.Handler(keys=[], signer=None)
conn = SocketMock()
conn = FakeSocket()
server.handle_connection(conn, handler)
msg = bytearray([protocol.SSH_AGENTC_REQUEST_RSA_IDENTITIES])
conn = SocketMock(util.frame(msg))
conn = FakeSocket(util.frame(msg))
server.handle_connection(conn, handler)
assert conn.tx.getvalue() == b'\x00\x00\x00\x05\x02\x00\x00\x00\x00'
msg = bytearray([protocol.SSH2_AGENTC_REQUEST_IDENTITIES])
conn = SocketMock(util.frame(msg))
conn = FakeSocket(util.frame(msg))
server.handle_connection(conn, handler)
assert conn.tx.getvalue() == b'\x00\x00\x00\x05\x0C\x00\x00\x00\x00'
@@ -51,25 +52,24 @@ def test_handle():
server.handle_connection(conn=None, handler=None)
class ServerMock(object):
def __init__(self, connections, name):
self.connections = connections
self.name = name
def getsockname(self):
return self.name
def accept(self):
if self.connections:
return self.connections.pop(), 'address'
raise socket.error('stop')
def test_server_thread():
s = ServerMock(connections=[SocketMock()], name='mock')
h = protocol.Handler(keys=[], signer=None)
server.server_thread(s, h)
connections = [FakeSocket()]
quit_event = threading.Event()
class FakeServer(object):
def accept(self): # pylint: disable=no-self-use
if connections:
return connections.pop(), 'address'
quit_event.set()
raise socket.timeout()
def getsockname(self): # pylint: disable=no-self-use
return 'fake_server'
server.server_thread(server=FakeServer(),
handler=protocol.Handler(keys=[], signer=None),
quit_event=quit_event)
def test_spawn():
@@ -78,7 +78,7 @@ def test_spawn():
def thread(x):
obj.append(x)
with server.spawn(thread, x=1):
with server.spawn(thread, dict(x=1)):
pass
assert obj == [1]