From 7f36097c15a91ed1bbaf6da6a53a783b2287e698 Mon Sep 17 00:00:00 2001 From: Roman Zeyde Date: Fri, 22 Jan 2016 12:04:24 +0200 Subject: [PATCH] tests: refactor mocks and fakes --- trezor_agent/tests/test_factory.py | 32 ++++++++++++++++-------------- trezor_agent/tests/test_trezor.py | 4 ++-- trezor_agent/tests/test_utils.py | 4 ++-- 3 files changed, 21 insertions(+), 19 deletions(-) diff --git a/trezor_agent/tests/test_factory.py b/trezor_agent/tests/test_factory.py index 5371d41..9460826 100644 --- a/trezor_agent/tests/test_factory.py +++ b/trezor_agent/tests/test_factory.py @@ -29,11 +29,13 @@ def test_load(): factory.load(loaders=[double]) -factory_load_client = factory._load_client # pylint: disable=protected-access +def factory_load_client(**kwargs): + # pylint: disable=protected-access + return list(factory._load_client(**kwargs)) def test_load_nothing(): - hid_transport = mock.Mock() + hid_transport = mock.Mock(spec_set=['enumerate']) hid_transport.enumerate.return_value = [] result = factory_load_client( name=None, @@ -42,36 +44,36 @@ def test_load_nothing(): passphrase_ack=None, identity_type=None, required_version=None) - assert list(result) == [] + assert result == [] def create_client_type(version): - conn = mock.Mock() - conn.features = mock.Mock() + conn = mock.Mock(spec=[]) + conn.features = mock.Mock(spec=[]) major, minor, patch = version.split('.') + conn.features.device_id = 'DEVICE_ID' + conn.features.label = 'LABEL' + conn.features.vendor = 'VENDOR' conn.features.major_version = major conn.features.minor_version = minor conn.features.patch_version = patch conn.features.revision = b'\x12\x34\x56\x78' - client_type = mock.Mock() - client_type.return_value = conn - return client_type + return mock.Mock(spec_set=[], return_value=conn) def test_load_single(): - hid_transport = mock.Mock() + hid_transport = mock.Mock(spec_set=['enumerate']) hid_transport.enumerate.return_value = [0] for version in ('1.3.4', '1.3.5', '1.4.0', '2.0.0'): - passphrase_ack = mock.Mock() + passphrase_ack = mock.Mock(spec_set=[]) client_type = create_client_type(version) - result = factory_load_client( + client_wrapper, = factory_load_client( name='DEVICE_NAME', client_type=client_type, hid_transport=hid_transport, passphrase_ack=passphrase_ack, identity_type=None, required_version='>=1.3.4') - client_wrapper, = result assert client_wrapper.connection is client_type.return_value assert client_wrapper.device_name == 'DEVICE_NAME' client_wrapper.connection.callback_PassphraseRequest('MESSAGE') @@ -79,14 +81,14 @@ def test_load_single(): def test_load_old(): - hid_transport = mock.Mock() + hid_transport = mock.Mock(spec_set=['enumerate']) hid_transport.enumerate.return_value = [0] for version in ('1.3.3', '1.2.5', '1.1.0', '0.9.9'): with pytest.raises(ValueError): - next(factory_load_client( + factory_load_client( name='DEVICE_NAME', client_type=create_client_type(version), hid_transport=hid_transport, passphrase_ack=None, identity_type=None, - required_version='>=1.3.4')) + required_version='>=1.3.4') diff --git a/trezor_agent/tests/test_trezor.py b/trezor_agent/tests/test_trezor.py index 0af55ca..f14e3f0 100644 --- a/trezor_agent/tests/test_trezor.py +++ b/trezor_agent/tests/test_trezor.py @@ -15,7 +15,7 @@ PUBKEY_TEXT = ('ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzd' 'VUfhvrGljR2Z/CMRONY6ejB+9PnpUOPuzYqi8= ssh://localhost:22\n') -class ConnectionMock(object): +class FakeConnection(object): def __init__(self): self.closed = False @@ -51,7 +51,7 @@ def identity_type(**kwargs): def load_client(): - return factory.ClientWrapper(connection=ConnectionMock(), + return factory.ClientWrapper(connection=FakeConnection(), identity_type=identity_type, device_name='DEVICE_NAME') diff --git a/trezor_agent/tests/test_utils.py b/trezor_agent/tests/test_utils.py index 2c9e9bd..26f66a9 100644 --- a/trezor_agent/tests/test_utils.py +++ b/trezor_agent/tests/test_utils.py @@ -24,7 +24,7 @@ def test_frames(): assert util.read_frame(io.BytesIO(f)) == b''.join(msgs) -class SocketMock(object): +class FakeSocket(object): def __init__(self): self.buf = io.BytesIO() @@ -36,7 +36,7 @@ class SocketMock(object): def test_send_recv(): - s = SocketMock() + s = FakeSocket() util.send(s, b'123') util.send(s, data=[42], fmt='B') assert s.buf.getvalue() == b'123*'