mirror of
https://github.com/romanz/amodem.git
synced 2026-03-23 10:41:01 +08:00
Trezor: restructure code to support python-trezor 0.11
This commit is contained in:
@@ -3,11 +3,9 @@
|
||||
import binascii
|
||||
import logging
|
||||
|
||||
import mnemonic
|
||||
import semver
|
||||
|
||||
from . import interface
|
||||
from .. import util
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@@ -28,66 +26,8 @@ class Trezor(interface.Device):
|
||||
required_version = '>=1.4.0'
|
||||
|
||||
ui = None # can be overridden by device's users
|
||||
|
||||
def _override_pin_handler(self, conn):
|
||||
if self.ui is None:
|
||||
return
|
||||
|
||||
def new_handler(_):
|
||||
try:
|
||||
scrambled_pin = self.ui.get_pin()
|
||||
result = self._defs.PinMatrixAck(pin=scrambled_pin)
|
||||
if not set(scrambled_pin).issubset('123456789'):
|
||||
raise self._defs.PinException(
|
||||
None, 'Invalid scrambled PIN: {!r}'.format(result.pin))
|
||||
return result
|
||||
except: # noqa
|
||||
conn.init_device()
|
||||
raise
|
||||
|
||||
conn.callback_PinMatrixRequest = new_handler
|
||||
|
||||
cached_passphrase_ack = util.ExpiringCache(seconds=float('inf'))
|
||||
cached_state = None
|
||||
|
||||
def _override_passphrase_handler(self, conn):
|
||||
if self.ui is None:
|
||||
return
|
||||
|
||||
def new_handler(msg):
|
||||
try:
|
||||
if msg.on_device is True:
|
||||
return self._defs.PassphraseAck()
|
||||
ack = self.__class__.cached_passphrase_ack.get()
|
||||
if ack:
|
||||
log.debug('re-using cached %s passphrase', self)
|
||||
return ack
|
||||
|
||||
passphrase = self.ui.get_passphrase()
|
||||
passphrase = mnemonic.Mnemonic.normalize_string(passphrase)
|
||||
ack = self._defs.PassphraseAck(passphrase=passphrase)
|
||||
|
||||
length = len(ack.passphrase)
|
||||
if length > 50:
|
||||
msg = 'Too long passphrase ({} chars)'.format(length)
|
||||
raise ValueError(msg)
|
||||
|
||||
self.__class__.cached_passphrase_ack.set(ack)
|
||||
return ack
|
||||
except: # noqa
|
||||
conn.init_device()
|
||||
raise
|
||||
|
||||
conn.callback_PassphraseRequest = new_handler
|
||||
|
||||
def _override_state_handler(self, conn):
|
||||
def callback_PassphraseStateRequest(msg):
|
||||
log.debug('caching state from %r', msg)
|
||||
self.__class__.cached_state = msg.state
|
||||
return self._defs.PassphraseStateAck()
|
||||
|
||||
conn.callback_PassphraseStateRequest = callback_PassphraseStateRequest
|
||||
|
||||
def _verify_version(self, connection):
|
||||
f = connection.features
|
||||
log.debug('connected to %s %s', self, f.device_id)
|
||||
@@ -113,10 +53,8 @@ class Trezor(interface.Device):
|
||||
log.debug('using transport: %s', transport)
|
||||
for _ in range(5): # Retry a few times in case of PIN failures
|
||||
connection = self._defs.Client(transport=transport,
|
||||
ui=self.ui,
|
||||
state=self.__class__.cached_state)
|
||||
self._override_pin_handler(connection)
|
||||
self._override_passphrase_handler(connection)
|
||||
self._override_state_handler(connection)
|
||||
self._verify_version(connection)
|
||||
|
||||
try:
|
||||
@@ -132,7 +70,8 @@ class Trezor(interface.Device):
|
||||
|
||||
def close(self):
|
||||
"""Close connection."""
|
||||
self.conn.close()
|
||||
self.__class__.cached_state = self.conn.state
|
||||
super().close()
|
||||
|
||||
def pubkey(self, identity, ecdh=False):
|
||||
"""Return public key."""
|
||||
@@ -140,8 +79,10 @@ class Trezor(interface.Device):
|
||||
log.debug('"%s" getting public key (%s) from %s',
|
||||
identity.to_string(), curve_name, self)
|
||||
addr = identity.get_bip32_address(ecdh=ecdh)
|
||||
result = self.conn.get_public_node(
|
||||
n=addr, ecdsa_curve_name=curve_name)
|
||||
result = self._defs.get_public_node(
|
||||
self.conn,
|
||||
n=addr,
|
||||
ecdsa_curve_name=curve_name)
|
||||
log.debug('result: %s', result)
|
||||
return bytes(result.node.public_key)
|
||||
|
||||
@@ -157,7 +98,8 @@ class Trezor(interface.Device):
|
||||
log.debug('"%s" signing %r (%s) on %s',
|
||||
identity.to_string(), blob, curve_name, self)
|
||||
try:
|
||||
result = self.conn.sign_identity(
|
||||
result = self._defs.sign_identity(
|
||||
self.conn,
|
||||
identity=self._identity_proto(identity),
|
||||
challenge_hidden=blob,
|
||||
challenge_visual='',
|
||||
@@ -166,7 +108,7 @@ class Trezor(interface.Device):
|
||||
assert len(result.signature) == 65
|
||||
assert result.signature[:1] == b'\x00'
|
||||
return bytes(result.signature[1:])
|
||||
except self._defs.CallException as e:
|
||||
except self._defs.TrezorFailure as e:
|
||||
msg = '{} error: {}'.format(self, e)
|
||||
log.debug(msg, exc_info=True)
|
||||
raise interface.DeviceError(msg)
|
||||
@@ -177,7 +119,8 @@ class Trezor(interface.Device):
|
||||
log.debug('"%s" shared session key (%s) for %r from %s',
|
||||
identity.to_string(), curve_name, pubkey, self)
|
||||
try:
|
||||
result = self.conn.get_ecdh_session_key(
|
||||
result = self._defs.get_ecdh_session_key(
|
||||
self.conn,
|
||||
identity=self._identity_proto(identity),
|
||||
peer_public_key=pubkey,
|
||||
ecdsa_curve_name=curve_name)
|
||||
@@ -185,7 +128,7 @@ class Trezor(interface.Device):
|
||||
assert len(result.session_key) in {65, 33} # NIST256 or Curve25519
|
||||
assert result.session_key[:1] == b'\x04'
|
||||
return bytes(result.session_key)
|
||||
except self._defs.CallException as e:
|
||||
except self._defs.TrezorFailure as e:
|
||||
msg = '{} error: {}'.format(self, e)
|
||||
log.debug(msg, exc_info=True)
|
||||
raise interface.DeviceError(msg)
|
||||
|
||||
@@ -4,19 +4,72 @@
|
||||
import os
|
||||
import logging
|
||||
|
||||
from trezorlib.client import CallException, PinException
|
||||
from trezorlib.client import TrezorClient as Client
|
||||
from trezorlib.messages import IdentityType, PassphraseAck, PinMatrixAck, PassphraseStateAck
|
||||
|
||||
try:
|
||||
from trezorlib.transport import get_transport
|
||||
except ImportError:
|
||||
from trezorlib.device import TrezorDevice
|
||||
get_transport = TrezorDevice.find_by_path
|
||||
import mnemonic
|
||||
import semver
|
||||
import trezorlib
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
if semver.match(trezorlib.__version__, ">=0.11.0"):
|
||||
from trezorlib.client import TrezorClient as Client
|
||||
from trezorlib.exceptions import TrezorFailure, PinException
|
||||
from trezorlib.transport import get_transport
|
||||
from trezorlib.messages import IdentityType
|
||||
|
||||
from trezorlib.btc import get_public_node
|
||||
from trezorlib.misc import sign_identity, get_ecdh_session_key
|
||||
|
||||
else:
|
||||
from trezorlib.client import (TrezorClient, CallException as TrezorFailure,
|
||||
PinException)
|
||||
from trezorlib.messages import IdentityType
|
||||
from trezorlib import messages
|
||||
from trezorlib.transport import get_transport
|
||||
|
||||
get_public_node = TrezorClient.get_public_node
|
||||
sign_identity = TrezorClient.sign_identity
|
||||
get_ecdh_session_key = TrezorClient.get_ecdh_session_key
|
||||
|
||||
class Client(TrezorClient):
|
||||
def __init__(self, transport, ui, state=None):
|
||||
super().__init__(transport, state=state)
|
||||
self.ui = ui
|
||||
|
||||
def callback_PinMatrixRequest(self, msg):
|
||||
try:
|
||||
pin = self.ui.get_pin(msg.type)
|
||||
if not pin.isdigit():
|
||||
raise PinException(
|
||||
None, 'Invalid scrambled PIN: {!r}'.format(pin))
|
||||
return messages.PinMatrixAck(pin=pin)
|
||||
except: # noqa
|
||||
self.init_device()
|
||||
raise
|
||||
|
||||
def callback_PassphraseRequest(self, msg):
|
||||
try:
|
||||
if msg.on_device is True:
|
||||
return messages.PassphraseAck()
|
||||
|
||||
passphrase = self.ui.get_passphrase()
|
||||
passphrase = mnemonic.Mnemonic.normalize_string(passphrase)
|
||||
|
||||
length = len(passphrase)
|
||||
if length > 50:
|
||||
msg = 'Too long passphrase ({} chars)'.format(length)
|
||||
raise ValueError(msg)
|
||||
|
||||
return messages.PassphraseAck(passphrase=passphrase)
|
||||
except: # noqa
|
||||
self.init_device()
|
||||
raise
|
||||
|
||||
def callback_PassphraseStateRequest(self, msg):
|
||||
self.state = msg.state
|
||||
return messages.PassphraseStateAck()
|
||||
|
||||
|
||||
def find_device():
|
||||
"""Selects a transport based on `TREZOR_PATH` environment variable.
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ class UI:
|
||||
self.options_getter = create_default_options_getter()
|
||||
self.device_name = device_type.__name__
|
||||
|
||||
def get_pin(self, name=None):
|
||||
def get_pin(self, _code=None):
|
||||
"""Ask the user for (scrambled) PIN."""
|
||||
description = (
|
||||
'Use the numeric keypad to describe number positions.\n'
|
||||
@@ -33,21 +33,25 @@ class UI:
|
||||
' 4 5 6\n'
|
||||
' 1 2 3')
|
||||
return interact(
|
||||
title='{} PIN'.format(name or self.device_name),
|
||||
title='{} PIN'.format(self.device_name),
|
||||
prompt='PIN:',
|
||||
description=description,
|
||||
binary=self.pin_entry_binary,
|
||||
options=self.options_getter())
|
||||
|
||||
def get_passphrase(self, name=None):
|
||||
def get_passphrase(self):
|
||||
"""Ask the user for passphrase."""
|
||||
return interact(
|
||||
title='{} passphrase'.format(name or self.device_name),
|
||||
title='{} passphrase'.format(self.device_name),
|
||||
prompt='Passphrase:',
|
||||
description=None,
|
||||
binary=self.passphrase_entry_binary,
|
||||
options=self.options_getter())
|
||||
|
||||
def button_request(self, _code=None):
|
||||
# XXX: show notification to the user?
|
||||
pass
|
||||
|
||||
|
||||
def create_default_options_getter():
|
||||
"""Return current TTY and DISPLAY settings for GnuPG pinentry."""
|
||||
|
||||
Reference in New Issue
Block a user