diff --git a/trezor_agent/factory.py b/trezor_agent/factory.py deleted file mode 100644 index 1af7eb1..0000000 --- a/trezor_agent/factory.py +++ /dev/null @@ -1,124 +0,0 @@ -"""Thin wrapper around trezor/keepkey libraries.""" -from __future__ import absolute_import - -import binascii -import collections -import logging - -import semver - -log = logging.getLogger(__name__) - -ClientWrapper = collections.namedtuple( - 'ClientWrapper', - ['connection', 'identity_type', 'device_name', 'call_exception']) - - -# pylint: disable=too-many-arguments -def _load_client(name, client_type, hid_transport, - passphrase_ack, identity_type, - required_version, call_exception): - - def empty_passphrase_handler(_): - return passphrase_ack(passphrase='') - - for d in hid_transport.enumerate(): - connection = client_type(hid_transport(d)) - connection.callback_PassphraseRequest = empty_passphrase_handler - f = connection.features - log.debug('connected to %s %s', name, f.device_id) - log.debug('label : %s', f.label) - log.debug('vendor : %s', f.vendor) - current_version = '{}.{}.{}'.format(f.major_version, - f.minor_version, - f.patch_version) - log.debug('version : %s', current_version) - log.debug('revision : %s', binascii.hexlify(f.revision)) - if not semver.match(current_version, required_version): - fmt = 'Please upgrade your {} firmware to {} version (current: {})' - raise ValueError(fmt.format(name, - required_version, - current_version)) - yield ClientWrapper(connection=connection, - identity_type=identity_type, - device_name=name, - call_exception=call_exception) - return - - -def _load_trezor(): - try: - from trezorlib.client import TrezorClient, CallException - from trezorlib.transport_hid import HidTransport - from trezorlib.messages_pb2 import PassphraseAck - from trezorlib.types_pb2 import IdentityType - return _load_client(name='Trezor', - client_type=TrezorClient, - hid_transport=HidTransport, - passphrase_ack=PassphraseAck, - identity_type=IdentityType, - required_version='>=1.4.0', - call_exception=CallException) - except ImportError as e: - log.warning('%s: install via "pip install trezor" ' - 'if you need to support this device', e) - - -def _load_keepkey(): - try: - from keepkeylib.client import KeepKeyClient, CallException - from keepkeylib.transport_hid import HidTransport - from keepkeylib.messages_pb2 import PassphraseAck - from keepkeylib.types_pb2 import IdentityType - return _load_client(name='KeepKey', - client_type=KeepKeyClient, - hid_transport=HidTransport, - passphrase_ack=PassphraseAck, - identity_type=IdentityType, - required_version='>=1.0.4', - call_exception=CallException) - except ImportError as e: - log.warning('%s: install via "pip install keepkey" ' - 'if you need to support this device', e) - - -def _load_ledger(): - from ._ledger import LedgerClientConnection, CallException, IdentityType - try: - from ledgerblue.comm import getDongle, CommException - except ImportError as e: - log.warning('%s: install via "pip install ledgerblue" ' - 'if you need to support this device', e) - return - try: - dongle = getDongle() - except CommException: - return - - yield ClientWrapper(connection=LedgerClientConnection(dongle), - identity_type=IdentityType, - device_name="ledger", - call_exception=CallException) - - -LOADERS = [ - _load_trezor, - _load_keepkey, - _load_ledger -] - - -def load(loaders=None): - """Load a single device, via specified loaders' list.""" - loaders = loaders if loaders is not None else LOADERS - device_list = [] - for loader in loaders: - device = loader() - if device: - device_list.extend(device) - - if len(device_list) == 1: - return device_list[0] - - msg = '{:d} devices found'.format(len(device_list)) - raise IOError(msg) diff --git a/trezor_agent/tests/test_factory.py b/trezor_agent/tests/test_factory.py deleted file mode 100644 index b904666..0000000 --- a/trezor_agent/tests/test_factory.py +++ /dev/null @@ -1,97 +0,0 @@ -import mock -import pytest - -from .. import factory - - -def test_load(): - - def single(): - return [0] - - def nothing(): - return [] - - def double(): - return [1, 2] - - assert factory.load(loaders=[single]) == 0 - assert factory.load(loaders=[single, nothing]) == 0 - assert factory.load(loaders=[nothing, single]) == 0 - - with pytest.raises(IOError): - factory.load(loaders=[]) - - with pytest.raises(IOError): - factory.load(loaders=[single, single]) - - with pytest.raises(IOError): - factory.load(loaders=[double]) - - -def factory_load_client(**kwargs): - # pylint: disable=protected-access - return list(factory._load_client(**kwargs)) - - -def test_load_nothing(): - hid_transport = mock.Mock(spec_set=['enumerate']) - hid_transport.enumerate.return_value = [] - result = factory_load_client( - name=None, - client_type=None, - hid_transport=hid_transport, - passphrase_ack=None, - identity_type=None, - required_version=None, - call_exception=None) - assert result == [] - - -def create_client_type(version): - conn = mock.Mock(spec=[]) - conn.features = mock.Mock(spec=[]) - major, minor, patch = version.split('.') - conn.features.device_id = 'DEVICE_ID' - conn.features.label = 'LABEL' - conn.features.vendor = 'VENDOR' - conn.features.major_version = major - conn.features.minor_version = minor - conn.features.patch_version = patch - conn.features.revision = b'\x12\x34\x56\x78' - return mock.Mock(spec_set=[], return_value=conn) - - -def test_load_single(): - hid_transport = mock.Mock(spec_set=['enumerate']) - hid_transport.enumerate.return_value = [0] - for version in ('1.3.4', '1.3.5', '1.4.0', '2.0.0'): - passphrase_ack = mock.Mock(spec_set=[]) - client_type = create_client_type(version) - client_wrapper, = factory_load_client( - name='DEVICE_NAME', - client_type=client_type, - hid_transport=hid_transport, - passphrase_ack=passphrase_ack, - identity_type=None, - required_version='>=1.3.4', - call_exception=None) - assert client_wrapper.connection is client_type.return_value - assert client_wrapper.device_name == 'DEVICE_NAME' - client_wrapper.connection.callback_PassphraseRequest('MESSAGE') - assert passphrase_ack.mock_calls == [mock.call(passphrase='')] - - -def test_load_old(): - hid_transport = mock.Mock(spec_set=['enumerate']) - hid_transport.enumerate.return_value = [0] - for version in ('1.3.3', '1.2.5', '1.1.0', '0.9.9'): - with pytest.raises(ValueError): - factory_load_client( - name='DEVICE_NAME', - client_type=create_client_type(version), - hid_transport=hid_transport, - passphrase_ack=None, - identity_type=None, - required_version='>=1.3.4', - call_exception=None)