mirror of
https://github.com/romanz/amodem.git
synced 2026-04-19 12:46:00 +08:00
trezor: use identities instead of labels
This commit is contained in:
@@ -1,4 +1,5 @@
|
|||||||
import io
|
import io
|
||||||
|
import struct
|
||||||
import binascii
|
import binascii
|
||||||
|
|
||||||
from . import util
|
from . import util
|
||||||
@@ -21,10 +22,10 @@ class TrezorLibrary(object):
|
|||||||
return TrezorClient(HidTransport(devices[0]))
|
return TrezorClient(HidTransport(devices[0]))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def identity(label, proto='ssh'):
|
def parse_identity(s):
|
||||||
# pylint: disable=import-error
|
# pylint: disable=import-error
|
||||||
from trezorlib.types_pb2 import IdentityType
|
from trezorlib.types_pb2 import IdentityType
|
||||||
return IdentityType(host=label, proto=proto)
|
return IdentityType(**_string_to_identity(s))
|
||||||
|
|
||||||
|
|
||||||
class Client(object):
|
class Client(object):
|
||||||
@@ -52,19 +53,20 @@ class Client(object):
|
|||||||
self.client.close()
|
self.client.close()
|
||||||
|
|
||||||
def get_public_key(self, label):
|
def get_public_key(self, label):
|
||||||
addr = _get_address(self.factory.identity(label))
|
log.info('getting %r public key from Trezor...', label)
|
||||||
log.info('getting %r SSH public key from Trezor...', label)
|
identity = self.factory.parse_identity(label)
|
||||||
|
addr = _get_address(identity)
|
||||||
node = self.client.get_public_node(addr, self.curve_name)
|
node = self.client.get_public_node(addr, self.curve_name)
|
||||||
return node.node.public_key
|
return node.node.public_key
|
||||||
|
|
||||||
def sign_ssh_challenge(self, label, blob):
|
def sign_ssh_challenge(self, label, blob):
|
||||||
ident = self.factory.identity(label)
|
identity = self.factory.parse_identity(label)
|
||||||
msg = _parse_ssh_blob(blob)
|
msg = _parse_ssh_blob(blob)
|
||||||
request = 'user: "{user}"'.format(**msg)
|
request = 'user: "{user}"'.format(**msg)
|
||||||
|
|
||||||
log.info('confirm %s connection to %r using Trezor...',
|
log.info('confirm %s connection to %r using Trezor...',
|
||||||
request, label)
|
request, label)
|
||||||
s = self.client.sign_identity(identity=ident,
|
s = self.client.sign_identity(identity=identity,
|
||||||
challenge_hidden=blob,
|
challenge_hidden=blob,
|
||||||
challenge_visual=request,
|
challenge_visual=request,
|
||||||
ecdsa_curve_name=self.curve_name)
|
ecdsa_curve_name=self.curve_name)
|
||||||
@@ -77,9 +79,53 @@ class Client(object):
|
|||||||
return (r, s)
|
return (r, s)
|
||||||
|
|
||||||
|
|
||||||
def _get_address(ident):
|
def _lsplit(s, sep):
|
||||||
index = '\x00' * 4
|
p = None
|
||||||
addr = index + '{}://{}'.format(ident.proto, ident.host)
|
if sep in s:
|
||||||
|
p, s = s.split(sep, 1)
|
||||||
|
return (p, s)
|
||||||
|
|
||||||
|
|
||||||
|
def _rsplit(s, sep):
|
||||||
|
p = None
|
||||||
|
if sep in s:
|
||||||
|
s, p = s.rsplit(sep, 1)
|
||||||
|
return (s, p)
|
||||||
|
|
||||||
|
|
||||||
|
def _string_to_identity(s):
|
||||||
|
proto, s = _lsplit(s, '://')
|
||||||
|
user, s = _lsplit(s, '@')
|
||||||
|
s, path = _rsplit(s, '/')
|
||||||
|
host, port = _rsplit(s, ':')
|
||||||
|
|
||||||
|
if not proto:
|
||||||
|
proto = 'ssh'
|
||||||
|
|
||||||
|
result = [
|
||||||
|
('proto', proto), ('user', user), ('host', host),
|
||||||
|
('port', port), ('path', path)
|
||||||
|
]
|
||||||
|
return {k: v for k, v in result if v}
|
||||||
|
|
||||||
|
|
||||||
|
def _identity_to_string(identity):
|
||||||
|
result = []
|
||||||
|
if identity.proto:
|
||||||
|
result.append(identity.proto + '://')
|
||||||
|
if identity.user:
|
||||||
|
result.append(identity.user + '@')
|
||||||
|
result.append(identity.host)
|
||||||
|
if identity.port:
|
||||||
|
result.append(':' + identity.port)
|
||||||
|
if identity.path:
|
||||||
|
result.append('/' + identity.path)
|
||||||
|
return ''.join(result)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_address(identity):
|
||||||
|
index = struct.pack('<L', identity.index)
|
||||||
|
addr = index + _identity_to_string(identity)
|
||||||
digest = formats.hashfunc(addr).digest()
|
digest = formats.hashfunc(addr).digest()
|
||||||
s = io.BytesIO(bytearray(digest))
|
s = io.BytesIO(bytearray(digest))
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import os
|
||||||
import sys
|
import sys
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
@@ -10,12 +11,14 @@ log = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
fmt = '%(asctime)s %(levelname)-12s %(message)-100s [%(filename)s]'
|
fmt = '%(asctime)s %(levelname)-12s %(message)-100s [%(filename)s:%(lineno)d]'
|
||||||
p = argparse.ArgumentParser()
|
p = argparse.ArgumentParser()
|
||||||
p.add_argument('-k', '--key-label',
|
p.add_argument('-v', '--verbose', action='count', default=0,
|
||||||
metavar='LABEL', dest='labels', action='append', default=[])
|
help='increase the the logging verbosity')
|
||||||
p.add_argument('-v', '--verbose', action='count', default=0)
|
p.add_argument('-c', dest='command', type=str, default=None,
|
||||||
p.add_argument('command', type=str, nargs='*')
|
help='command to run under the SSH agent')
|
||||||
|
p.add_argument('identity', type=str, nargs='*',
|
||||||
|
help='proto://[user@]host[:port][/path]')
|
||||||
args = p.parse_args()
|
args = p.parse_args()
|
||||||
|
|
||||||
verbosity = [logging.WARNING, logging.INFO, logging.DEBUG]
|
verbosity = [logging.WARNING, logging.INFO, logging.DEBUG]
|
||||||
@@ -24,20 +27,21 @@ def main():
|
|||||||
|
|
||||||
with trezor.Client(factory=trezor.TrezorLibrary) as client:
|
with trezor.Client(factory=trezor.TrezorLibrary) as client:
|
||||||
key_files = []
|
key_files = []
|
||||||
for label in args.labels:
|
for label in args.identity:
|
||||||
pubkey = client.get_public_key(label=label)
|
pubkey = client.get_public_key(label)
|
||||||
key_file = formats.export_public_key(pubkey=pubkey, label=label)
|
key_file = formats.export_public_key(pubkey=pubkey, label=label)
|
||||||
key_files.append(key_file)
|
key_files.append(key_file)
|
||||||
|
|
||||||
if not args.command:
|
command = args.command
|
||||||
sys.stdout.write(''.join(key_files))
|
if not command:
|
||||||
return
|
command = os.environ['SHELL']
|
||||||
|
log.info('using %r shell', command)
|
||||||
|
|
||||||
signer = client.sign_ssh_challenge
|
signer = client.sign_ssh_challenge
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with server.serve(key_files=key_files, signer=signer) as env:
|
with server.serve(key_files=key_files, signer=signer) as env:
|
||||||
return server.run_process(command=args.command, environ=env)
|
return server.run_process(command=command, environ=env)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
log.info('server stopped')
|
log.info('server stopped')
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user