Trezor: restructure code to support python-trezor 0.11

This commit is contained in:
matejcik
2018-12-10 16:09:55 +01:00
parent a30cab1156
commit 2cb64991c3
3 changed files with 83 additions and 83 deletions

View File

@@ -3,11 +3,9 @@
import binascii
import logging
import mnemonic
import semver
from . import interface
from .. import util
log = logging.getLogger(__name__)
@@ -28,66 +26,8 @@ class Trezor(interface.Device):
required_version = '>=1.4.0'
ui = None # can be overridden by device's users
def _override_pin_handler(self, conn):
if self.ui is None:
return
def new_handler(_):
try:
scrambled_pin = self.ui.get_pin()
result = self._defs.PinMatrixAck(pin=scrambled_pin)
if not set(scrambled_pin).issubset('123456789'):
raise self._defs.PinException(
None, 'Invalid scrambled PIN: {!r}'.format(result.pin))
return result
except: # noqa
conn.init_device()
raise
conn.callback_PinMatrixRequest = new_handler
cached_passphrase_ack = util.ExpiringCache(seconds=float('inf'))
cached_state = None
def _override_passphrase_handler(self, conn):
if self.ui is None:
return
def new_handler(msg):
try:
if msg.on_device is True:
return self._defs.PassphraseAck()
ack = self.__class__.cached_passphrase_ack.get()
if ack:
log.debug('re-using cached %s passphrase', self)
return ack
passphrase = self.ui.get_passphrase()
passphrase = mnemonic.Mnemonic.normalize_string(passphrase)
ack = self._defs.PassphraseAck(passphrase=passphrase)
length = len(ack.passphrase)
if length > 50:
msg = 'Too long passphrase ({} chars)'.format(length)
raise ValueError(msg)
self.__class__.cached_passphrase_ack.set(ack)
return ack
except: # noqa
conn.init_device()
raise
conn.callback_PassphraseRequest = new_handler
def _override_state_handler(self, conn):
def callback_PassphraseStateRequest(msg):
log.debug('caching state from %r', msg)
self.__class__.cached_state = msg.state
return self._defs.PassphraseStateAck()
conn.callback_PassphraseStateRequest = callback_PassphraseStateRequest
def _verify_version(self, connection):
f = connection.features
log.debug('connected to %s %s', self, f.device_id)
@@ -113,10 +53,8 @@ class Trezor(interface.Device):
log.debug('using transport: %s', transport)
for _ in range(5): # Retry a few times in case of PIN failures
connection = self._defs.Client(transport=transport,
ui=self.ui,
state=self.__class__.cached_state)
self._override_pin_handler(connection)
self._override_passphrase_handler(connection)
self._override_state_handler(connection)
self._verify_version(connection)
try:
@@ -132,7 +70,8 @@ class Trezor(interface.Device):
def close(self):
"""Close connection."""
self.conn.close()
self.__class__.cached_state = self.conn.state
super().close()
def pubkey(self, identity, ecdh=False):
"""Return public key."""
@@ -140,8 +79,10 @@ class Trezor(interface.Device):
log.debug('"%s" getting public key (%s) from %s',
identity.to_string(), curve_name, self)
addr = identity.get_bip32_address(ecdh=ecdh)
result = self.conn.get_public_node(
n=addr, ecdsa_curve_name=curve_name)
result = self._defs.get_public_node(
self.conn,
n=addr,
ecdsa_curve_name=curve_name)
log.debug('result: %s', result)
return bytes(result.node.public_key)
@@ -157,7 +98,8 @@ class Trezor(interface.Device):
log.debug('"%s" signing %r (%s) on %s',
identity.to_string(), blob, curve_name, self)
try:
result = self.conn.sign_identity(
result = self._defs.sign_identity(
self.conn,
identity=self._identity_proto(identity),
challenge_hidden=blob,
challenge_visual='',
@@ -166,7 +108,7 @@ class Trezor(interface.Device):
assert len(result.signature) == 65
assert result.signature[:1] == b'\x00'
return bytes(result.signature[1:])
except self._defs.CallException as e:
except self._defs.TrezorFailure as e:
msg = '{} error: {}'.format(self, e)
log.debug(msg, exc_info=True)
raise interface.DeviceError(msg)
@@ -177,7 +119,8 @@ class Trezor(interface.Device):
log.debug('"%s" shared session key (%s) for %r from %s',
identity.to_string(), curve_name, pubkey, self)
try:
result = self.conn.get_ecdh_session_key(
result = self._defs.get_ecdh_session_key(
self.conn,
identity=self._identity_proto(identity),
peer_public_key=pubkey,
ecdsa_curve_name=curve_name)
@@ -185,7 +128,7 @@ class Trezor(interface.Device):
assert len(result.session_key) in {65, 33} # NIST256 or Curve25519
assert result.session_key[:1] == b'\x04'
return bytes(result.session_key)
except self._defs.CallException as e:
except self._defs.TrezorFailure as e:
msg = '{} error: {}'.format(self, e)
log.debug(msg, exc_info=True)
raise interface.DeviceError(msg)