From 838df004f07ce38112537f7c8e5e81dd09ea6930 Mon Sep 17 00:00:00 2001 From: Roman Zeyde Date: Thu, 13 Aug 2015 18:31:37 +0300 Subject: [PATCH] trezor: fix protocol defaults --- sshagent/__main__.py | 4 ++-- sshagent/trezor.py | 15 +++++++++------ 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/sshagent/__main__.py b/sshagent/__main__.py index 77f3baf..f756056 100644 --- a/sshagent/__main__.py +++ b/sshagent/__main__.py @@ -79,7 +79,7 @@ def trezor_agent(): if command: command = ['git'] + command - identity = client.get_identity(label=label) + identity = client.get_identity(label=label, protocol='ssh') public_key = client.get_public_key(identity=identity) use_shell = False @@ -104,7 +104,7 @@ def trezor_agent(): return def signer(label, blob): - identity = client.get_identity(label=label) + identity = client.get_identity(label=label, protocol='ssh') return client.sign_ssh_challenge(identity=identity, blob=blob) try: diff --git a/sshagent/trezor.py b/sshagent/trezor.py index e33972d..51015e1 100644 --- a/sshagent/trezor.py +++ b/sshagent/trezor.py @@ -54,8 +54,12 @@ class Client(object): log.info('disconnected from Trezor') self.client.close() - def get_identity(self, label): - return _string_to_identity(label, self.factory.identity_type) + def get_identity(self, label, protocol=None): + identity = _string_to_identity(label, self.factory.identity_type) + if protocol is not None: + identity.proto = protocol + + return identity def get_public_key(self, identity): assert identity.proto == 'ssh' @@ -152,16 +156,15 @@ _identity_regexp = re.compile(''.join([ def _string_to_identity(s, identity_type): m = _identity_regexp.match(s) result = m.groupdict() - if not result.get('proto'): - result['proto'] = 'ssh' # otherwise, Trezor will use SECP256K1 curve - log.debug('parsed identity: %s', result) kwargs = {k: v for k, v in result.items() if v} return identity_type(**kwargs) def _identity_to_string(identity): - result = [identity.proto + '://'] + result = [] + if identity.proto: + result.append(identity.proto + '://') if identity.user: result.append(identity.user + '@') result.append(identity.host)