Trezor: restructure code to support python-trezor 0.11

This commit is contained in:
matejcik
2018-12-10 16:09:55 +01:00
parent a30cab1156
commit 2cb64991c3
3 changed files with 83 additions and 83 deletions

View File

@@ -3,11 +3,9 @@
import binascii import binascii
import logging import logging
import mnemonic
import semver import semver
from . import interface from . import interface
from .. import util
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@@ -28,66 +26,8 @@ class Trezor(interface.Device):
required_version = '>=1.4.0' required_version = '>=1.4.0'
ui = None # can be overridden by device's users ui = None # can be overridden by device's users
def _override_pin_handler(self, conn):
if self.ui is None:
return
def new_handler(_):
try:
scrambled_pin = self.ui.get_pin()
result = self._defs.PinMatrixAck(pin=scrambled_pin)
if not set(scrambled_pin).issubset('123456789'):
raise self._defs.PinException(
None, 'Invalid scrambled PIN: {!r}'.format(result.pin))
return result
except: # noqa
conn.init_device()
raise
conn.callback_PinMatrixRequest = new_handler
cached_passphrase_ack = util.ExpiringCache(seconds=float('inf'))
cached_state = None cached_state = None
def _override_passphrase_handler(self, conn):
if self.ui is None:
return
def new_handler(msg):
try:
if msg.on_device is True:
return self._defs.PassphraseAck()
ack = self.__class__.cached_passphrase_ack.get()
if ack:
log.debug('re-using cached %s passphrase', self)
return ack
passphrase = self.ui.get_passphrase()
passphrase = mnemonic.Mnemonic.normalize_string(passphrase)
ack = self._defs.PassphraseAck(passphrase=passphrase)
length = len(ack.passphrase)
if length > 50:
msg = 'Too long passphrase ({} chars)'.format(length)
raise ValueError(msg)
self.__class__.cached_passphrase_ack.set(ack)
return ack
except: # noqa
conn.init_device()
raise
conn.callback_PassphraseRequest = new_handler
def _override_state_handler(self, conn):
def callback_PassphraseStateRequest(msg):
log.debug('caching state from %r', msg)
self.__class__.cached_state = msg.state
return self._defs.PassphraseStateAck()
conn.callback_PassphraseStateRequest = callback_PassphraseStateRequest
def _verify_version(self, connection): def _verify_version(self, connection):
f = connection.features f = connection.features
log.debug('connected to %s %s', self, f.device_id) log.debug('connected to %s %s', self, f.device_id)
@@ -113,10 +53,8 @@ class Trezor(interface.Device):
log.debug('using transport: %s', transport) log.debug('using transport: %s', transport)
for _ in range(5): # Retry a few times in case of PIN failures for _ in range(5): # Retry a few times in case of PIN failures
connection = self._defs.Client(transport=transport, connection = self._defs.Client(transport=transport,
ui=self.ui,
state=self.__class__.cached_state) state=self.__class__.cached_state)
self._override_pin_handler(connection)
self._override_passphrase_handler(connection)
self._override_state_handler(connection)
self._verify_version(connection) self._verify_version(connection)
try: try:
@@ -132,7 +70,8 @@ class Trezor(interface.Device):
def close(self): def close(self):
"""Close connection.""" """Close connection."""
self.conn.close() self.__class__.cached_state = self.conn.state
super().close()
def pubkey(self, identity, ecdh=False): def pubkey(self, identity, ecdh=False):
"""Return public key.""" """Return public key."""
@@ -140,8 +79,10 @@ class Trezor(interface.Device):
log.debug('"%s" getting public key (%s) from %s', log.debug('"%s" getting public key (%s) from %s',
identity.to_string(), curve_name, self) identity.to_string(), curve_name, self)
addr = identity.get_bip32_address(ecdh=ecdh) addr = identity.get_bip32_address(ecdh=ecdh)
result = self.conn.get_public_node( result = self._defs.get_public_node(
n=addr, ecdsa_curve_name=curve_name) self.conn,
n=addr,
ecdsa_curve_name=curve_name)
log.debug('result: %s', result) log.debug('result: %s', result)
return bytes(result.node.public_key) return bytes(result.node.public_key)
@@ -157,7 +98,8 @@ class Trezor(interface.Device):
log.debug('"%s" signing %r (%s) on %s', log.debug('"%s" signing %r (%s) on %s',
identity.to_string(), blob, curve_name, self) identity.to_string(), blob, curve_name, self)
try: try:
result = self.conn.sign_identity( result = self._defs.sign_identity(
self.conn,
identity=self._identity_proto(identity), identity=self._identity_proto(identity),
challenge_hidden=blob, challenge_hidden=blob,
challenge_visual='', challenge_visual='',
@@ -166,7 +108,7 @@ class Trezor(interface.Device):
assert len(result.signature) == 65 assert len(result.signature) == 65
assert result.signature[:1] == b'\x00' assert result.signature[:1] == b'\x00'
return bytes(result.signature[1:]) return bytes(result.signature[1:])
except self._defs.CallException as e: except self._defs.TrezorFailure as e:
msg = '{} error: {}'.format(self, e) msg = '{} error: {}'.format(self, e)
log.debug(msg, exc_info=True) log.debug(msg, exc_info=True)
raise interface.DeviceError(msg) raise interface.DeviceError(msg)
@@ -177,7 +119,8 @@ class Trezor(interface.Device):
log.debug('"%s" shared session key (%s) for %r from %s', log.debug('"%s" shared session key (%s) for %r from %s',
identity.to_string(), curve_name, pubkey, self) identity.to_string(), curve_name, pubkey, self)
try: try:
result = self.conn.get_ecdh_session_key( result = self._defs.get_ecdh_session_key(
self.conn,
identity=self._identity_proto(identity), identity=self._identity_proto(identity),
peer_public_key=pubkey, peer_public_key=pubkey,
ecdsa_curve_name=curve_name) ecdsa_curve_name=curve_name)
@@ -185,7 +128,7 @@ class Trezor(interface.Device):
assert len(result.session_key) in {65, 33} # NIST256 or Curve25519 assert len(result.session_key) in {65, 33} # NIST256 or Curve25519
assert result.session_key[:1] == b'\x04' assert result.session_key[:1] == b'\x04'
return bytes(result.session_key) return bytes(result.session_key)
except self._defs.CallException as e: except self._defs.TrezorFailure as e:
msg = '{} error: {}'.format(self, e) msg = '{} error: {}'.format(self, e)
log.debug(msg, exc_info=True) log.debug(msg, exc_info=True)
raise interface.DeviceError(msg) raise interface.DeviceError(msg)

View File

@@ -4,19 +4,72 @@
import os import os
import logging import logging
from trezorlib.client import CallException, PinException import mnemonic
from trezorlib.client import TrezorClient as Client import semver
from trezorlib.messages import IdentityType, PassphraseAck, PinMatrixAck, PassphraseStateAck import trezorlib
try:
from trezorlib.transport import get_transport
except ImportError:
from trezorlib.device import TrezorDevice
get_transport = TrezorDevice.find_by_path
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
if semver.match(trezorlib.__version__, ">=0.11.0"):
from trezorlib.client import TrezorClient as Client
from trezorlib.exceptions import TrezorFailure, PinException
from trezorlib.transport import get_transport
from trezorlib.messages import IdentityType
from trezorlib.btc import get_public_node
from trezorlib.misc import sign_identity, get_ecdh_session_key
else:
from trezorlib.client import (TrezorClient, CallException as TrezorFailure,
PinException)
from trezorlib.messages import IdentityType
from trezorlib import messages
from trezorlib.transport import get_transport
get_public_node = TrezorClient.get_public_node
sign_identity = TrezorClient.sign_identity
get_ecdh_session_key = TrezorClient.get_ecdh_session_key
class Client(TrezorClient):
def __init__(self, transport, ui, state=None):
super().__init__(transport, state=state)
self.ui = ui
def callback_PinMatrixRequest(self, msg):
try:
pin = self.ui.get_pin(msg.type)
if not pin.isdigit():
raise PinException(
None, 'Invalid scrambled PIN: {!r}'.format(pin))
return messages.PinMatrixAck(pin=pin)
except: # noqa
self.init_device()
raise
def callback_PassphraseRequest(self, msg):
try:
if msg.on_device is True:
return messages.PassphraseAck()
passphrase = self.ui.get_passphrase()
passphrase = mnemonic.Mnemonic.normalize_string(passphrase)
length = len(passphrase)
if length > 50:
msg = 'Too long passphrase ({} chars)'.format(length)
raise ValueError(msg)
return messages.PassphraseAck(passphrase=passphrase)
except: # noqa
self.init_device()
raise
def callback_PassphraseStateRequest(self, msg):
self.state = msg.state
return messages.PassphraseStateAck()
def find_device(): def find_device():
"""Selects a transport based on `TREZOR_PATH` environment variable. """Selects a transport based on `TREZOR_PATH` environment variable.

View File

@@ -24,7 +24,7 @@ class UI:
self.options_getter = create_default_options_getter() self.options_getter = create_default_options_getter()
self.device_name = device_type.__name__ self.device_name = device_type.__name__
def get_pin(self, name=None): def get_pin(self, _code=None):
"""Ask the user for (scrambled) PIN.""" """Ask the user for (scrambled) PIN."""
description = ( description = (
'Use the numeric keypad to describe number positions.\n' 'Use the numeric keypad to describe number positions.\n'
@@ -33,21 +33,25 @@ class UI:
' 4 5 6\n' ' 4 5 6\n'
' 1 2 3') ' 1 2 3')
return interact( return interact(
title='{} PIN'.format(name or self.device_name), title='{} PIN'.format(self.device_name),
prompt='PIN:', prompt='PIN:',
description=description, description=description,
binary=self.pin_entry_binary, binary=self.pin_entry_binary,
options=self.options_getter()) options=self.options_getter())
def get_passphrase(self, name=None): def get_passphrase(self):
"""Ask the user for passphrase.""" """Ask the user for passphrase."""
return interact( return interact(
title='{} passphrase'.format(name or self.device_name), title='{} passphrase'.format(self.device_name),
prompt='Passphrase:', prompt='Passphrase:',
description=None, description=None,
binary=self.passphrase_entry_binary, binary=self.passphrase_entry_binary,
options=self.options_getter()) options=self.options_getter())
def button_request(self, _code=None):
# XXX: show notification to the user?
pass
def create_default_options_getter(): def create_default_options_getter():
"""Return current TTY and DISPLAY settings for GnuPG pinentry.""" """Return current TTY and DISPLAY settings for GnuPG pinentry."""