From 803e3bb738dcec3440a29aedb003871ff1dfa2b6 Mon Sep 17 00:00:00 2001 From: Roman Zeyde Date: Fri, 4 Sep 2015 13:07:35 +0300 Subject: [PATCH] client: require TREZOR v1.3.4 firmware for SSH NIST256P1 curve support --- trezor_agent/tests/test_trezor.py | 23 ++++++++++++++++++----- trezor_agent/trezor/client.py | 9 ++++++++- 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/trezor_agent/tests/test_trezor.py b/trezor_agent/tests/test_trezor.py index 3479fa9..231b092 100644 --- a/trezor_agent/tests/test_trezor.py +++ b/trezor_agent/tests/test_trezor.py @@ -2,6 +2,7 @@ from ..trezor import client from .. import formats import mock +import pytest ADDR = [2147483661, 2810943954, 3938368396, 3454558782, 3848009040] @@ -16,14 +17,14 @@ PUBKEY_TEXT = ('ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzd' class ConnectionMock(object): - def __init__(self): + def __init__(self, version): self.features = mock.Mock(spec=[]) self.features.device_id = '123456789' self.features.label = 'mywallet' self.features.vendor = 'mock' - self.features.major_version = 1 - self.features.minor_version = 2 - self.features.patch_version = 3 + self.features.major_version = version[0] + self.features.minor_version = version[1] + self.features.patch_version = version[2] self.features.revision = b'456' self.closed = False @@ -51,7 +52,7 @@ class FactoryMock(object): @staticmethod def client(): - return ConnectionMock() + return ConnectionMock(version=(1, 3, 4)) @staticmethod def identity_type(**kwargs): @@ -120,3 +121,15 @@ def test_utils(): url = 'https://user@host:443/path' assert client.identity_to_string(identity) == url + + +def test_old_version(): + + class OldFactoryMock(FactoryMock): + + @staticmethod + def client(): + return ConnectionMock(version=(1, 2, 3)) + + with pytest.raises(ValueError): + client.Client(factory=OldFactoryMock) diff --git a/trezor_agent/trezor/client.py b/trezor_agent/trezor/client.py index ea60d9a..a256e70 100644 --- a/trezor_agent/trezor/client.py +++ b/trezor_agent/trezor/client.py @@ -13,6 +13,8 @@ log = logging.getLogger(__name__) class Client(object): + MIN_VERSION = [1, 3, 4] + def __init__(self, factory=TrezorFactory): self.factory = factory self.client = self.factory.client() @@ -21,8 +23,13 @@ class Client(object): log.debug('label : %s', f.label) log.debug('vendor : %s', f.vendor) version = [f.major_version, f.minor_version, f.patch_version] - log.debug('version : %s', '.'.join([str(v) for v in version])) + version_str = '.'.join([str(v) for v in version]) + log.debug('version : %s', version_str) log.debug('revision : %s', binascii.hexlify(f.revision)) + if version < self.MIN_VERSION: + fmt = 'Please upgrade your TREZOR to v{}+ firmware' + version_str = '.'.join([str(v) for v in self.MIN_VERSION]) + raise ValueError(fmt.format(version_str)) def __enter__(self): msg = 'Hello World!'