Compare commits

..

84 Commits

Author SHA1 Message Date
Roman Zeyde
b3a6c76631 bump version 2016-03-12 21:08:52 +02:00
Roman Zeyde
f056f1fac5 fixup lint errors 2016-03-12 21:07:10 +02:00
Roman Zeyde
716dc82312 bump version 2016-03-12 20:58:38 +02:00
Roman Zeyde
0e2a19f7ce client: fixup UT 2016-03-12 20:57:16 +02:00
Roman Zeyde
2cdbc89d28 protocol: fixup UT 2016-03-12 20:57:09 +02:00
Roman Zeyde
1022e54d6a protocol: fail gracefully on cancellation 2016-03-12 20:42:14 +02:00
Roman Zeyde
ea88f425f5 protocol: fail on unsupported commands 2016-03-12 20:40:09 +02:00
Roman Zeyde
000860feaf main: add --test flag for verifying SSH configuration
https://help.github.com/articles/testing-your-ssh-connection/
2016-03-12 15:32:29 +02:00
Roman Zeyde
2a5196003e tests: update for CallException handling 2016-03-06 22:06:45 +02:00
Roman Zeyde
e10b42bbb5 client: catch CallException for cancellation handling 2016-03-06 21:59:17 +02:00
Roman Zeyde
b07d7e6535 server: handle IOError gracefully 2016-03-06 21:58:39 +02:00
Roman Zeyde
4838030be5 factory: add CallException type 2016-03-06 21:58:11 +02:00
Roman Zeyde
c9f341a42b main: handle 'pushurl' and 'url' remote settings 2016-03-06 21:21:25 +02:00
Roman Zeyde
bdd2568b2c main: log pubkey fingerprint on INFO level 2016-03-05 20:49:14 +02:00
Roman Zeyde
ae20ae4a04 bump version 2016-03-05 19:54:51 +02:00
Roman Zeyde
f15c2c7236 README: add trezor-git screencast 2016-03-05 15:12:30 +02:00
Roman Zeyde
e6ccc324a0 main: ignore path from git remote URL
It's much easier to use single keypair per user@host
2016-03-05 14:56:58 +02:00
Roman Zeyde
7c102e435e setup: add more classifiers 2016-03-05 11:29:05 +02:00
Roman Zeyde
7f6bb12b24 bump version 2016-03-05 11:20:11 +02:00
Roman Zeyde
98e875562e main: add trezor-git entry point 2016-03-05 11:18:24 +02:00
Roman Zeyde
4384b93c19 main: remove unneeded use_shell parameter 2016-03-05 11:03:10 +02:00
Roman Zeyde
8a90a8cd84 main: split git from ssh 2016-03-05 10:56:30 +02:00
Roman Zeyde
1e86983782 main: split argument parser 2016-03-05 10:46:36 +02:00
Roman Zeyde
c63201c90c client: show visual challenge 2016-03-05 10:39:47 +02:00
Roman Zeyde
19b00dc427 client: add logging for challenge sizes 2016-02-27 20:09:03 +02:00
Roman Zeyde
aa35981980 README: add 'apt-get' to installation section 2016-02-27 09:49:15 +02:00
Roman Zeyde
8909b38107 main: use command-line for git interaction 2016-02-20 18:24:14 +02:00
Roman Zeyde
6d9aa9cb8a README: license badge is broken most of the time 2016-02-19 20:54:36 +02:00
Roman Zeyde
d6532311b9 fix PEP8 & docstrings 2016-02-19 20:52:59 +02:00
Roman Zeyde
41b30b42b5 main: add git identity via "origin" remote 2016-02-19 20:48:16 +02:00
Roman Zeyde
5b0e56697f travis: add pydocstyle 2016-02-19 11:41:05 +02:00
Roman Zeyde
0e6d998b4c tox: add pydocstyle 2016-02-19 11:39:12 +02:00
Roman Zeyde
2c7fabfa35 tests: add docstrings 2016-02-19 11:35:34 +02:00
Roman Zeyde
1adccdbfe6 __init__: add docstrings 2016-02-19 11:35:27 +02:00
Roman Zeyde
04f4bbf2ac main: add docstrings 2016-02-19 11:35:16 +02:00
Roman Zeyde
bbe963d0ff util: rename UTs 2016-02-19 11:34:58 +02:00
Roman Zeyde
c49514754b util: add docstrings 2016-02-19 11:34:20 +02:00
Roman Zeyde
2ebefff909 server: add docstrings 2016-02-19 11:19:01 +02:00
Roman Zeyde
21e89014c9 protocol: add docstrings and replace custom exceptions 2016-02-19 10:49:39 +02:00
Roman Zeyde
566e4310e1 formats: add docstrings 2016-02-19 10:40:39 +02:00
Roman Zeyde
e1441518d4 factory: add docstrings 2016-02-19 10:08:36 +02:00
Roman Zeyde
5cb12a43de client: add docstrings 2016-02-19 10:07:33 +02:00
Roman Zeyde
df607f3665 pylint: add 'no-member' check 2016-02-18 14:28:16 +02:00
Roman Zeyde
d712509a4e client: show current time instead of identity.path 2016-02-17 15:04:10 +02:00
Roman Zeyde
40e2d9fb2c fixup imports order
isort -rc trezor_agent
2016-02-15 20:53:14 +02:00
Roman Zeyde
cd4cc059d6 main: remove git-config parsing code 2016-02-15 20:52:44 +02:00
Roman Zeyde
2b047f0525 main: refactor shell flag 2016-02-15 20:38:34 +02:00
Roman Zeyde
64776fd294 rename client test 2016-02-15 17:22:57 +02:00
Roman Zeyde
231995bd1a remove trezor module 2016-02-15 17:22:01 +02:00
Roman Zeyde
ff76f17c02 client: elaborate SSH blob parsing 2016-02-13 20:26:23 +02:00
Roman Zeyde
963e80b49b client: move logging from parsing code 2016-02-06 18:32:51 +02:00
Roman Zeyde
dee13b75ea client: remove unneeded 'if' 2016-02-06 18:27:46 +02:00
Roman Zeyde
be86507e00 client: pass index as default argument 2016-02-06 17:52:49 +02:00
Roman Zeyde
2f2663ef94 client: set identity index explicitly 2016-02-06 17:51:57 +02:00
Roman Zeyde
cafa218e19 server: pass handler and add debug option 2016-01-26 21:14:52 +02:00
Roman Zeyde
50b627ed45 protocol: allow debugging SSH message handler 2016-01-26 21:14:27 +02:00
Roman Zeyde
7f36097c15 tests: refactor mocks and fakes 2016-01-22 12:04:24 +02:00
Roman Zeyde
a4b905cd6f bump version 2016-01-19 22:56:54 +02:00
Roman Zeyde
2eff21f96c factory: refactor for easier testing 2016-01-19 22:52:52 +02:00
Roman Zeyde
9afd07e867 server: make sure accepted UNIX sockets are blocking
It was a problem on Mac OS X, where sometimes we got EAGAIN
errors from calling socket.recv() on them.
2016-01-18 22:49:27 +02:00
Roman Zeyde
b101281a5b main: add command-line argument for setting UNIX socket timeout 2016-01-16 22:14:36 +02:00
Roman Zeyde
8c6ac43cf4 Merge Trezor and KeepKey functionality 2016-01-15 13:20:38 +02:00
Kenneth Heutmaker
5932a89dc5 Make it work with KeepKey 2016-01-14 13:28:32 -08:00
Roman Zeyde
2009160ff2 Revert "travis: test with tox"
This reverts commit 3d8072522c.
2016-01-09 17:46:07 +02:00
Roman Zeyde
3d8072522c travis: test with tox 2016-01-09 17:41:17 +02:00
Roman Zeyde
0c63aef719 sort imports using isort tool 2016-01-09 16:06:47 +02:00
Roman Zeyde
c454114c4e README: add gitter chat 2016-01-09 12:15:43 +02:00
Roman Zeyde
f9133f7e05 README: fixup license link 2016-01-09 11:19:33 +02:00
Roman Zeyde
33a6951a96 server: don't crash after single exception 2016-01-08 20:46:49 +02:00
Roman Zeyde
fb0d0a5f61 server: stop the server via a threading.Event
It seems that Mac OS does not support calling socket.shutdown(socket.SHUT_RD)
on a listening socket (see https://github.com/romanz/trezor-agent/issues/6).
The following implementation will set the accept() timeout to 0.1s and stop
the server if a threading.Event (named "quit_event") is set by the main thread.
2016-01-08 20:28:38 +02:00
Roman Zeyde
7ea20c7009 test_trezor: verify serialized signature 2016-01-08 17:30:08 +02:00
Roman Zeyde
4247558166 README: add subshell demo 2016-01-08 16:07:29 +02:00
Roman Zeyde
fe1e1d2bb9 server: log command with INFO level 2016-01-08 16:04:57 +02:00
Roman Zeyde
1a5b8118ad setup.py: support for Python 3.4 2016-01-05 20:46:55 +02:00
Roman Zeyde
3a806c6d77 beta release 2016-01-05 19:54:20 +02:00
Roman Zeyde
3b61f86c25 README: fixup license to match the repository 2016-01-05 18:49:49 +02:00
Roman Zeyde
06d84c387c bump version 2016-01-04 22:49:28 +02:00
Roman Zeyde
8347142a99 setup.py: fixup license to match the repository 2016-01-04 21:26:17 +02:00
Roman Zeyde
7dabe2c555 test_protocol: fix bytes->str 2016-01-04 21:03:46 +02:00
Roman Zeyde
d6ee3d8995 tox: add py34 2016-01-04 21:03:27 +02:00
Roman Zeyde
c3fa79e450 Fix a few pylint issues 2016-01-04 19:21:56 +02:00
Roman Zeyde
15b10c9a7e bump version 2016-01-04 19:05:43 +02:00
Roman Zeyde
e19d76398e formats: verify public key according to requested ECDSA curve 2015-12-18 16:04:20 +02:00
Roman Zeyde
535b4d50c7 Fix SSH connection arguments handling 2015-11-27 17:26:06 +02:00
23 changed files with 898 additions and 426 deletions

View File

@@ -1,2 +1,2 @@
[MESSAGES CONTROL]
disable=invalid-name, missing-docstring, locally-disabled,no-member
disable=invalid-name, missing-docstring, locally-disabled

View File

@@ -5,12 +5,13 @@ python:
- "3.4"
install:
- pip install ecdsa ed25519 # test without trezorlib for now
- pip install pylint coverage pep8
- pip install ecdsa ed25519 semver # test without trezorlib for now
- pip install pylint coverage pep8 pydocstyle
script:
- pep8 trezor_agent
- pylint --report=no --rcfile .pylintrc trezor_agent
- pylint --reports=no --rcfile .pylintrc trezor_agent
- pydocstyle trezor_agent
- coverage run --source trezor_agent/ -m py.test -v
after_success:

View File

@@ -2,10 +2,10 @@
[![Build Status](https://travis-ci.org/romanz/trezor-agent.svg?branch=master)](https://travis-ci.org/romanz/trezor-agent)
[![Python Versions](https://img.shields.io/pypi/pyversions/trezor_agent.svg)](https://pypi.python.org/pypi/trezor_agent/)
[![License](https://img.shields.io/pypi/l/trezor_agent.svg)](https://pypi.python.org/pypi/trezor_agent/)
[![Package Version](https://img.shields.io/pypi/v/trezor_agent.svg)](https://pypi.python.org/pypi/trezor_agent/)
[![Development Status](https://img.shields.io/pypi/status/trezor_agent.svg)](https://pypi.python.org/pypi/trezor_agent/)
[![Downloads](https://img.shields.io/pypi/dm/trezor_agent.svg)](https://pypi.python.org/pypi/trezor_agent/)
[![Chat](https://badges.gitter.im/romanz/trezor-agent.svg)](https://gitter.im/romanz/trezor-agent)
See SatoshiLabs' blog post about this feature:
@@ -13,13 +13,21 @@ See SatoshiLabs' blog post about this feature:
## Screencast demo usage
### Simple usage (single SSH session)
[![Demo](https://asciinema.org/a/22959.png)](https://asciinema.org/a/22959)
### Advanced usage (multiple SSH sessions from a sub-shell)
[![Subshell](https://asciinema.org/a/33240.png)](https://asciinema.org/a/33240)
### Using for GitHub SSH authentication (via `trezor-git` utility)
[![GitHub](https://asciinema.org/a/38337.png)](https://asciinema.org/a/38337)
## Installation
First, make sure that the latest `trezorlib` Python package
is installed correctly (at least v0.6.6):
$ apt-get install python-dev libusb-1.0-0-dev libudev-dev
$ pip install Cython trezor
Then, install the latest `trezor_agent` package:

View File

@@ -3,27 +3,32 @@ from setuptools import setup
setup(
name='trezor_agent',
version='0.5.0',
version='0.6.5',
description='Using Trezor as hardware SSH agent',
author='Roman Zeyde',
author_email='roman.zeyde@gmail.com',
license='MIT',
url='http://github.com/romanz/trezor-agent',
packages=['trezor_agent', 'trezor_agent.trezor'],
install_requires=['ecdsa>=0.13', 'ed25519>=1.4', 'Cython>=0.23.4', 'trezor>=0.6.6'],
packages=['trezor_agent'],
install_requires=['ecdsa>=0.13', 'ed25519>=1.4', 'Cython>=0.23.4', 'trezor>=0.6.6', 'keepkey>=0.7.0', 'semver>=2.2'],
platforms=['POSIX'],
classifiers=[
'Development Status :: 3 - Alpha',
'Environment :: Console',
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',
'Intended Audience :: Information Technology',
'License :: OSI Approved :: MIT License',
'Intended Audience :: System Administrators',
'License :: OSI Approved :: GNU Lesser General Public License v3 (LGPLv3)',
'Operating System :: POSIX',
'Programming Language :: Python :: 2.7',
'Programming Language :: Python :: 3.4',
'Topic :: Software Development :: Libraries :: Python Modules',
'Topic :: System :: Networking',
'Topic :: Communications',
'Topic :: Security',
'Topic :: Utilities',
],
entry_points={'console_scripts': [
'trezor-agent = trezor_agent.__main__:trezor_agent'
'trezor-agent = trezor_agent.__main__:run_agent',
'trezor-git = trezor_agent.__main__:run_git',
]},
)

View File

@@ -1,5 +1,5 @@
[tox]
envlist = py27
envlist = py27,py34
[testenv]
deps=
pytest
@@ -7,9 +7,12 @@ deps=
pep8
coverage
pylint
six
semver
pydocstyle
commands=
pep8 trezor_agent
pylint --report=no --rcfile .pylintrc trezor_agent
pylint --reports=no --rcfile .pylintrc trezor_agent
pydocstyle trezor_agent
coverage run --omit='trezor_agent/__main__.py' --source trezor_agent -m py.test -v trezor_agent
coverage report
coverage html

View File

@@ -0,0 +1 @@
"""SSH-agent implementation using hardware authentication devices."""

View File

@@ -1,68 +1,81 @@
import os
import re
import sys
"""SSH-agent implementation using hardware authentication devices."""
import argparse
import subprocess
import functools
from . import trezor
from . import server
from . import formats
import logging
import re
import os
import subprocess
import sys
import time
from . import client, formats, protocol, server
log = logging.getLogger(__name__)
def identity_from_gitconfig():
out = subprocess.check_output(args='git config --list --local'.split())
config = [line.split('=', 1) for line in out.strip().split('\n')]
config_dict = dict(item for item in config if len(item) == 2)
def ssh_args(label):
"""Create SSH command for connecting specified server."""
identity = client.string_to_identity(label, identity_type=dict)
name_regex = re.compile(r'^remote\..*\.trezor$')
names = [item[0] for item in config if name_regex.match(item[0])]
if len(names) != 1:
log.error('please add "trezor" key '
'to a single remote section at .git/config')
sys.exit(1)
key_name = names[0]
identity_label = config_dict.get(key_name)
if identity_label:
return identity_label
args = []
if 'port' in identity:
args += ['-p', identity['port']]
if 'user' in identity:
args += ['-l', identity['user']]
# extract remote name marked as TREZOR's
section_name, _ = key_name.rsplit('.', 1)
key_name = section_name + '.url'
url = config_dict[key_name]
log.debug('using "%s=%s" from git-config', key_name, url)
user, url = url.split('@', 1)
host, path = url.split(':', 1)
return 'ssh://{0}@{1}/{2}'.format(user, host, path)
return ['ssh'] + args + [identity['host']]
def create_agent_parser():
def create_parser():
"""Create argparse.ArgumentParser for this tool."""
p = argparse.ArgumentParser()
p.add_argument('-v', '--verbose', default=0, action='count')
p.add_argument('identity', type=str, default=None,
help='proto://[user@]host[:port][/path]')
curve_names = [name.decode('ascii') for name in formats.SUPPORTED_CURVES]
curve_names = ', '.join(sorted(curve_names))
p.add_argument('-e', '--ecdsa-curve-name', metavar='CURVE',
default=formats.CURVE_NIST256,
help='specify ECDSA curve name: ' + curve_names)
p.add_argument('--timeout',
default=server.UNIX_SOCKET_TIMEOUT, type=float,
help='Timeout for accepting SSH client connections')
p.add_argument('--debug', default=False, action='store_true',
help='Log SSH protocol messages for debugging.')
return p
def create_agent_parser():
"""Specific parser for SSH connection."""
p = create_parser()
g = p.add_mutually_exclusive_group()
g.add_argument('-s', '--shell', default=False, action='store_true',
help='run ${SHELL} as subprocess under SSH agent')
g.add_argument('-c', '--connect', default=False, action='store_true',
help='connect to specified host via SSH')
curves = ', '.join(sorted(formats.SUPPORTED_CURVES))
p.add_argument('-e', '--ecdsa-curve-name', metavar='CURVE',
default=formats.CURVE_NIST256,
help='specify ECDSA curve name: ' + curves)
p.add_argument('identity', type=str, default=None,
help='proto://[user@]host[:port][/path]')
p.add_argument('command', type=str, nargs='*', metavar='ARGUMENT',
help='command to run under the SSH agent')
return p
def create_git_parser():
"""Specific parser for git commands."""
p = create_parser()
p.add_argument('-r', '--remote', default='origin',
help='use this git remote URL to generate SSH identity')
p.add_argument('-t', '--test', action='store_true',
help='test connection using `ssh -T user@host` command')
p.add_argument('command', type=str, nargs='*', metavar='ARGUMENT',
help='Git command to run under the SSH agent')
return p
def setup_logging(verbosity):
"""Configure logging for this tool."""
fmt = ('%(asctime)s %(levelname)-12s %(message)-100s '
'[%(filename)s:%(lineno)d]')
levels = [logging.WARNING, logging.INFO, logging.DEBUG]
@@ -70,47 +83,95 @@ def setup_logging(verbosity):
logging.basicConfig(format=fmt, level=level)
def ssh_sign(client, label, blob):
return client.sign_ssh_challenge(label=label, blob=blob)
def git_host(remote_name, attributes):
"""Extract git SSH host for specified remote name."""
try:
output = subprocess.check_output('git config --local --list'.split())
except subprocess.CalledProcessError:
return
for attribute in attributes:
name = r'remote.{0}.{1}'.format(remote_name, attribute)
matches = re.findall(re.escape(name) + '=(.*)', output)
log.debug('%r: %r', name, matches)
if not matches:
continue
url = matches[0].strip()
user, url = url.split('@', 1)
host, _ = url.split(':', 1) # skip unused path (1 key per user@host)
return '{}@{}'.format(user, host)
def run_agent(client_factory):
def ssh_sign(conn, label, blob):
"""Perform SSH signature using given hardware device connection."""
now = time.strftime('%Y-%m-%d %H:%M:%S')
return conn.sign_ssh_challenge(label=label, blob=blob, visual=now)
def run_server(conn, public_key, command, debug, timeout):
"""Common code for run_agent and run_git below."""
try:
signer = functools.partial(ssh_sign, conn=conn)
public_key = formats.import_public_key(public_key)
log.info('using SSH public key: %s', public_key['fingerprint'])
handler = protocol.Handler(keys=[public_key], signer=signer,
debug=debug)
with server.serve(handler=handler, timeout=timeout) as env:
return server.run_process(command=command, environ=env)
except KeyboardInterrupt:
log.info('server stopped')
def run_agent(client_factory=client.Client):
"""Run ssh-agent using given hardware client factory."""
args = create_agent_parser().parse_args()
setup_logging(verbosity=args.verbose)
with client_factory(curve=args.ecdsa_curve_name) as client:
with client_factory(curve=args.ecdsa_curve_name) as conn:
label = args.identity
command = args.command
if label == 'git':
label = identity_from_gitconfig()
log.info('using identity %r for git command %r', label, command)
if command:
command = ['git'] + command
public_key = conn.get_public_key(label=label)
public_key = client.get_public_key(label=label)
use_shell = False
if args.connect:
command = ['ssh', label] + args.command
command = ssh_args(label) + args.command
log.debug('SSH connect: %r', command)
if args.shell:
command, use_shell = os.environ['SHELL'], True
use_shell = bool(args.shell)
if use_shell:
command = os.environ['SHELL']
log.debug('using shell: %r', command)
if not command:
sys.stdout.write(public_key)
return
try:
signer = functools.partial(ssh_sign, client=client)
with server.serve(public_keys=[public_key], signer=signer) as env:
return server.run_process(command=command, environ=env,
use_shell=use_shell)
except KeyboardInterrupt:
log.info('server stopped')
return run_server(conn=conn, public_key=public_key, command=command,
debug=args.debug, timeout=args.timeout)
def trezor_agent():
run_agent(trezor.Client)
def run_git(client_factory=client.Client):
"""Run git under ssh-agent using given hardware client factory."""
args = create_git_parser().parse_args()
setup_logging(verbosity=args.verbose)
with client_factory(curve=args.ecdsa_curve_name) as conn:
label = git_host(args.remote, ['pushurl', 'url'])
if not label:
log.error('Could not find "%s" remote in .git/config', args.remote)
return
public_key = conn.get_public_key(label=label)
if not args.test:
if args.command:
command = ['git'] + args.command
else:
sys.stdout.write(public_key)
return
else:
command = ['ssh', '-T', label]
return run_server(conn=conn, public_key=public_key, command=command,
debug=args.debug, timeout=args.timeout)

157
trezor_agent/client.py Normal file
View File

@@ -0,0 +1,157 @@
"""
Connection to hardware authentication device.
It is used for getting SSH public keys and ECDSA signing of server requests.
"""
import binascii
import io
import logging
import re
import struct
from . import factory, formats, util
log = logging.getLogger(__name__)
class Client(object):
"""Client wrapper for SSH authentication device."""
def __init__(self, loader=factory.load, curve=formats.CURVE_NIST256):
"""Connect to hardware device."""
client_wrapper = loader()
self.client = client_wrapper.connection
self.identity_type = client_wrapper.identity_type
self.device_name = client_wrapper.device_name
self.call_exception = client_wrapper.call_exception
self.curve = curve
def __enter__(self):
"""Start a session, and test connection."""
msg = 'Hello World!'
assert self.client.ping(msg) == msg
return self
def __exit__(self, *args):
"""Forget PIN, shutdown screen and disconnect."""
log.info('disconnected from %s', self.device_name)
self.client.clear_session()
self.client.close()
def get_identity(self, label, index=0):
"""Parse label string into Identity protobuf."""
identity = string_to_identity(label, self.identity_type)
identity.proto = 'ssh'
identity.index = index
return identity
def get_public_key(self, label):
"""Get SSH public key corresponding to specified by label."""
identity = self.get_identity(label=label)
label = identity_to_string(identity) # canonize key label
log.info('getting "%s" public key (%s) from %s...',
label, self.curve, self.device_name)
addr = _get_address(identity)
node = self.client.get_public_node(n=addr,
ecdsa_curve_name=self.curve)
pubkey = node.node.public_key
vk = formats.decompress_pubkey(pubkey=pubkey, curve_name=self.curve)
return formats.export_public_key(vk=vk, label=label)
def sign_ssh_challenge(self, label, blob, visual=''):
"""Sign given blob using a private key, specified by the label."""
identity = self.get_identity(label=label)
msg = _parse_ssh_blob(blob)
log.debug('%s: user %r via %r (%r)',
msg['conn'], msg['user'], msg['auth'], msg['key_type'])
log.debug('nonce: %s', binascii.hexlify(msg['nonce']))
log.debug('fingerprint: %s', msg['public_key']['fingerprint'])
log.debug('hidden challenge size: %d bytes', len(blob))
log.debug('visual challenge size: %d bytes = %r', len(visual), visual)
log.info('please confirm user "%s" login to "%s" using %s...',
msg['user'], label, self.device_name)
try:
result = self.client.sign_identity(identity=identity,
challenge_hidden=blob,
challenge_visual=visual,
ecdsa_curve_name=self.curve)
except self.call_exception as e:
code, msg = e.args
log.warning('%s error #%s: %s', self.device_name, code, msg)
raise IOError(msg) # close current connection, keep server open
verifying_key = formats.decompress_pubkey(pubkey=result.public_key,
curve_name=self.curve)
key_type, blob = formats.serialize_verifying_key(verifying_key)
assert blob == msg['public_key']['blob']
assert key_type == msg['key_type']
assert len(result.signature) == 65
assert result.signature[:1] == bytearray([0])
return result.signature[1:]
_identity_regexp = re.compile(''.join([
'^'
r'(?:(?P<proto>.*)://)?',
r'(?:(?P<user>.*)@)?',
r'(?P<host>.*?)',
r'(?::(?P<port>\w*))?',
r'(?P<path>/.*)?',
'$'
]))
def string_to_identity(s, identity_type):
"""Parse string into Identity protobuf."""
m = _identity_regexp.match(s)
result = m.groupdict()
log.debug('parsed identity: %s', result)
kwargs = {k: v for k, v in result.items() if v}
return identity_type(**kwargs)
def identity_to_string(identity):
"""Dump Identity protobuf into its string representation."""
result = []
if identity.proto:
result.append(identity.proto + '://')
if identity.user:
result.append(identity.user + '@')
result.append(identity.host)
if identity.port:
result.append(':' + identity.port)
if identity.path:
result.append(identity.path)
return ''.join(result)
def _get_address(identity):
index = struct.pack('<L', identity.index)
addr = index + identity_to_string(identity).encode('ascii')
log.debug('address string: %r', addr)
digest = formats.hashfunc(addr).digest()
s = io.BytesIO(bytearray(digest))
hardened = 0x80000000
address_n = [13] + list(util.recv(s, '<LLLL'))
return [(hardened | value) for value in address_n]
def _parse_ssh_blob(data):
res = {}
i = io.BytesIO(data)
res['nonce'] = util.read_frame(i)
i.read(1) # SSH2_MSG_USERAUTH_REQUEST == 50 (from ssh2.h, line 108)
res['user'] = util.read_frame(i)
res['conn'] = util.read_frame(i)
res['auth'] = util.read_frame(i)
i.read(1) # have_sig == 1 (from sshconnect2.c, line 1056)
res['key_type'] = util.read_frame(i)
public_key = util.read_frame(i)
res['public_key'] = formats.parse_pubkey(public_key)
assert not i.read()
return res

92
trezor_agent/factory.py Normal file
View File

@@ -0,0 +1,92 @@
"""Thin wrapper around trezor/keepkey libraries."""
import binascii
import collections
import logging
import semver
log = logging.getLogger(__name__)
ClientWrapper = collections.namedtuple(
'ClientWrapper',
['connection', 'identity_type', 'device_name', 'call_exception'])
# pylint: disable=too-many-arguments
def _load_client(name, client_type, hid_transport,
passphrase_ack, identity_type,
required_version, call_exception):
def empty_passphrase_handler(_):
return passphrase_ack(passphrase='')
for d in hid_transport.enumerate():
connection = client_type(hid_transport(d))
connection.callback_PassphraseRequest = empty_passphrase_handler
f = connection.features
log.debug('connected to %s %s', name, f.device_id)
log.debug('label : %s', f.label)
log.debug('vendor : %s', f.vendor)
current_version = '{}.{}.{}'.format(f.major_version,
f.minor_version,
f.patch_version)
log.debug('version : %s', current_version)
log.debug('revision : %s', binascii.hexlify(f.revision))
if not semver.match(current_version, required_version):
fmt = 'Please upgrade your {} firmware to {} version (current: {})'
raise ValueError(fmt.format(name,
required_version,
current_version))
yield ClientWrapper(connection=connection,
identity_type=identity_type,
device_name=name,
call_exception=call_exception)
def _load_trezor():
# pylint: disable=import-error
from trezorlib.client import TrezorClient, CallException
from trezorlib.transport_hid import HidTransport
from trezorlib.messages_pb2 import PassphraseAck
from trezorlib.types_pb2 import IdentityType
return _load_client(name='Trezor',
client_type=TrezorClient,
hid_transport=HidTransport,
passphrase_ack=PassphraseAck,
identity_type=IdentityType,
required_version='>=1.3.4',
call_exception=CallException)
def _load_keepkey():
# pylint: disable=import-error
from keepkeylib.client import KeepKeyClient, CallException
from keepkeylib.transport_hid import HidTransport
from keepkeylib.messages_pb2 import PassphraseAck
from keepkeylib.types_pb2 import IdentityType
return _load_client(name='KeepKey',
client_type=KeepKeyClient,
hid_transport=HidTransport,
passphrase_ack=PassphraseAck,
identity_type=IdentityType,
required_version='>=1.0.4',
call_exception=CallException)
LOADERS = [
_load_trezor,
_load_keepkey
]
def load(loaders=None):
"""Load a single device, via specified loaders' list."""
loaders = loaders if loaders is not None else LOADERS
device_list = []
for loader in loaders:
device_list.extend(loader())
if len(device_list) == 1:
return device_list[0]
msg = '{:d} devices found'.format(len(device_list))
raise IOError(msg)

View File

@@ -1,12 +1,14 @@
import io
import hashlib
"""SSH format parsing and formatting tools."""
import base64
import hashlib
import io
import logging
import ecdsa
import ed25519
from . import util
import logging
log = logging.getLogger(__name__)
# Supported ECDSA curves
@@ -26,11 +28,23 @@ hashfunc = hashlib.sha256
def fingerprint(blob):
"""
Compute SSH fingerprint for specified blob.
See https://en.wikipedia.org/wiki/Public_key_fingerprint for details.
"""
digest = hashlib.md5(blob).digest()
return ':'.join('{:02x}'.format(c) for c in bytearray(digest))
def parse_pubkey(blob):
"""
Parse SSH public key from given blob.
Cnstruct a verifier for ECDSA signatures.
The verifier returns the signatures in the required SSH format.
Currently, NIST256P1 and ED25519 elliptic curves are supported.
"""
fp = fingerprint(blob)
s = io.BytesIO(blob)
key_type = util.read_frame(s)
@@ -52,11 +66,11 @@ def parse_pubkey(blob):
curve = ecdsa.NIST256p
point = ecdsa.ellipticcurve.Point(curve.curve, *coords)
vk = ecdsa.VerifyingKey.from_public_point(point, curve, hashfunc)
def ecdsa_verifier(sig, msg):
assert len(sig) == 2 * size
sig_decode = ecdsa.util.sigdecode_string
vk = ecdsa.VerifyingKey.from_public_point(point, curve, hashfunc)
vk.verify(signature=sig, data=msg, sigdecode=sig_decode)
parts = [sig[:size], sig[size:]]
return b''.join([util.frame(b'\x00' + p) for p in parts])
@@ -67,10 +81,10 @@ def parse_pubkey(blob):
if key_type == SSH_ED25519_KEY_TYPE:
pubkey = util.read_frame(s)
assert s.read() == b''
vk = ed25519.VerifyingKey(pubkey)
def ed25519_verify(sig, msg):
assert len(sig) == 64
vk = ed25519.VerifyingKey(pubkey)
vk.verify(sig, msg)
return sig
@@ -79,29 +93,65 @@ def parse_pubkey(blob):
return result
def decompress_pubkey(pub):
if pub[:1] == b'\x00':
def _decompress_ed25519(pubkey):
"""Load public key from the serialized blob (stripping the prefix byte)."""
if pubkey[:1] == b'\x00':
# set by Trezor fsm_msgSignIdentity() and fsm_msgGetPublicKey()
return ed25519.VerifyingKey(pub[1:])
return ed25519.VerifyingKey(pubkey[1:])
if pub[:1] in {b'\x02', b'\x03'}: # set by ecdsa_get_public_key33()
def _decompress_nist256(pubkey):
"""
Load public key from the serialized blob.
The leading byte least-significant bit is used to decide how to recreate
the y-coordinate from the specified x-coordinate. See bitcoin/main.py#L198
(from https://github.com/vbuterin/pybitcointools/) for details.
"""
if pubkey[:1] in {b'\x02', b'\x03'}: # set by ecdsa_get_public_key33()
curve = ecdsa.NIST256p
P = curve.curve.p()
A = curve.curve.a()
B = curve.curve.b()
x = util.bytes2num(pub[1:33])
x = util.bytes2num(pubkey[1:33])
beta = pow(int(x * x * x + A * x + B), int((P + 1) // 4), int(P))
p0 = util.bytes2num(pub[:1])
p0 = util.bytes2num(pubkey[:1])
y = (P - beta) if ((beta + p0) % 2) else beta
point = ecdsa.ellipticcurve.Point(curve.curve, x, y)
return ecdsa.VerifyingKey.from_public_point(point, curve=curve,
hashfunc=hashfunc)
raise ValueError('invalid {!r}', pub)
def decompress_pubkey(pubkey, curve_name):
"""
Load public key from the serialized blob.
Raise ValueError on parsing error.
"""
vk = None
if len(pubkey) == 33:
decompress = {
CURVE_NIST256: _decompress_nist256,
CURVE_ED25519: _decompress_ed25519
}[curve_name]
vk = decompress(pubkey)
if not vk:
msg = 'invalid {!s} public key: {!r}'.format(curve_name, pubkey)
raise ValueError(msg)
return vk
def serialize_verifying_key(vk):
"""
Serialize a public key into SSH format (for exporting to text format).
Currently, NIST256P1 and ED25519 elliptic curves are supported.
Raise TypeError on unsupported key format.
"""
if isinstance(vk, ed25519.keys.VerifyingKey):
pubkey = vk.to_bytes()
key_type = SSH_ED25519_KEY_TYPE
@@ -119,17 +169,21 @@ def serialize_verifying_key(vk):
raise TypeError('unsupported {!r}'.format(vk))
def export_public_key(pubkey, label):
assert len(pubkey) == 33
key_type, blob = serialize_verifying_key(decompress_pubkey(pubkey))
def export_public_key(vk, label):
"""
Export public key to text format.
The resulting string can be written into a .pub file or
appended to the ~/.ssh/authorized_keys file.
"""
key_type, blob = serialize_verifying_key(vk)
log.debug('fingerprint: %s', fingerprint(blob))
b64 = base64.b64encode(blob).decode('ascii')
return '{} {} {}\n'.format(key_type.decode('ascii'), b64, label)
def import_public_key(line):
''' Parse public key textual format, as saved at .pub file '''
"""Parse public key textual format, as saved at a .pub file."""
log.debug('loading SSH public key: %r', line)
file_type, base64blob, name = line.split()
blob = base64.b64decode(base64blob)

View File

@@ -1,68 +1,115 @@
import io
"""
SSH-agent protocol implementation library.
See https://github.com/openssh/openssh-portable/blob/master/PROTOCOL.agent and
http://ptspts.blogspot.co.il/2010/06/how-to-use-ssh-agent-programmatically.html
for more details.
The server's source code can be found here:
https://github.com/openssh/openssh-portable/blob/master/authfd.c
"""
import binascii
from . import util
from . import formats
import io
import logging
from . import formats, util
log = logging.getLogger(__name__)
SSH_AGENTC_REQUEST_RSA_IDENTITIES = 1
SSH_AGENT_RSA_IDENTITIES_ANSWER = 2
SSH2_AGENTC_REQUEST_IDENTITIES = 11
SSH2_AGENT_IDENTITIES_ANSWER = 12
SSH2_AGENTC_SIGN_REQUEST = 13
SSH2_AGENT_SIGN_RESPONSE = 14
# Taken from https://github.com/openssh/openssh-portable/blob/master/authfd.h
COMMANDS = dict(
SSH_AGENTC_REQUEST_RSA_IDENTITIES=1,
SSH_AGENT_RSA_IDENTITIES_ANSWER=2,
SSH_AGENTC_RSA_CHALLENGE=3,
SSH_AGENT_RSA_RESPONSE=4,
SSH_AGENT_FAILURE=5,
SSH_AGENT_SUCCESS=6,
SSH_AGENTC_ADD_RSA_IDENTITY=7,
SSH_AGENTC_REMOVE_RSA_IDENTITY=8,
SSH_AGENTC_REMOVE_ALL_RSA_IDENTITIES=9,
SSH2_AGENTC_REQUEST_IDENTITIES=11,
SSH2_AGENT_IDENTITIES_ANSWER=12,
SSH2_AGENTC_SIGN_REQUEST=13,
SSH2_AGENT_SIGN_RESPONSE=14,
SSH2_AGENTC_ADD_IDENTITY=17,
SSH2_AGENTC_REMOVE_IDENTITY=18,
SSH2_AGENTC_REMOVE_ALL_IDENTITIES=19,
SSH_AGENTC_ADD_SMARTCARD_KEY=20,
SSH_AGENTC_REMOVE_SMARTCARD_KEY=21,
SSH_AGENTC_LOCK=22,
SSH_AGENTC_UNLOCK=23,
SSH_AGENTC_ADD_RSA_ID_CONSTRAINED=24,
SSH2_AGENTC_ADD_ID_CONSTRAINED=25,
SSH_AGENTC_ADD_SMARTCARD_KEY_CONSTRAINED=26,
)
class Error(Exception):
pass
def msg_code(name):
"""Convert string name into a integer message code."""
return COMMANDS[name]
class BadSignature(Error):
pass
def msg_name(code):
"""Convert integer message code into a string name."""
ids = {v: k for k, v in COMMANDS.items()}
return ids[code]
class MissingKey(Error):
pass
def failure():
"""Return error code to SSH binary."""
error_msg = util.pack('B', msg_code('SSH_AGENT_FAILURE'))
return util.frame(error_msg)
def _legacy_pubs(buf):
"""SSH v1 public keys are not supported."""
assert not buf.read()
code = util.pack('B', msg_code('SSH_AGENT_RSA_IDENTITIES_ANSWER'))
num = util.pack('L', 0) # no SSH v1 keys
return util.frame(code, num)
class Handler(object):
"""ssh-agent protocol handler."""
def __init__(self, keys, signer):
def __init__(self, keys, signer, debug=False):
"""
Create a protocol handler with specified public keys.
Use specified signer function to sign SSH authentication requests.
"""
self.public_keys = keys
self.signer = signer
self.debug = debug
self.methods = {
SSH_AGENTC_REQUEST_RSA_IDENTITIES: Handler.legacy_pubs,
SSH2_AGENTC_REQUEST_IDENTITIES: self.list_pubs,
SSH2_AGENTC_SIGN_REQUEST: self.sign_message,
msg_code('SSH_AGENTC_REQUEST_RSA_IDENTITIES'): _legacy_pubs,
msg_code('SSH2_AGENTC_REQUEST_IDENTITIES'): self.list_pubs,
msg_code('SSH2_AGENTC_SIGN_REQUEST'): self.sign_message,
}
def handle(self, msg):
log.debug('request: %d bytes', len(msg))
"""Handle SSH message from the SSH client and return the response."""
debug_msg = ': {!r}'.format(msg) if self.debug else ''
log.debug('request: %d bytes%s', len(msg), debug_msg)
buf = io.BytesIO(msg)
code, = util.recv(buf, '>B')
if code not in self.methods:
log.warning('Unsupported command: %s (%d)', msg_name(code), code)
return failure()
method = self.methods[code]
log.debug('calling %s()', method.__name__)
reply = method(buf=buf)
log.debug('reply: %d bytes', len(reply))
debug_reply = ': {!r}'.format(reply) if self.debug else ''
log.debug('reply: %d bytes%s', len(reply), debug_reply)
return reply
@staticmethod
def legacy_pubs(buf):
''' SSH v1 public keys are not supported '''
assert not buf.read()
code = util.pack('B', SSH_AGENT_RSA_IDENTITIES_ANSWER)
num = util.pack('L', 0) # no SSH v1 keys
return util.frame(code, num)
def list_pubs(self, buf):
''' SSH v2 public keys are serialized and returned. '''
"""SSH v2 public keys are serialized and returned."""
assert not buf.read()
keys = self.public_keys
code = util.pack('B', SSH2_AGENT_IDENTITIES_ANSWER)
code = util.pack('B', msg_code('SSH2_AGENT_IDENTITIES_ANSWER'))
num = util.pack('L', len(keys))
log.debug('available keys: %s', [k['name'] for k in keys])
for i, k in enumerate(keys):
@@ -71,7 +118,12 @@ class Handler(object):
return util.frame(code, num, *pubs)
def sign_message(self, buf):
''' SSH v2 public key authentication is performed. '''
"""
SSH v2 public key authentication is performed.
If the required key is not supported, raise KeyError
If the signature is invalid, rause ValueError
"""
key = formats.parse_pubkey(util.read_frame(buf))
log.debug('looking for %s', key['fingerprint'])
blob = util.read_frame(buf)
@@ -84,10 +136,14 @@ class Handler(object):
key = k
break
else:
raise MissingKey('key not found')
raise KeyError('key not found')
log.debug('signing %d-byte blob', len(blob))
signature = self.signer(label=key['name'], blob=blob)
label = key['name'].decode('ascii') # label should be a string
try:
signature = self.signer(label=label, blob=blob)
except IOError:
return failure()
log.debug('signature: %s', binascii.hexlify(signature))
try:
@@ -95,10 +151,10 @@ class Handler(object):
log.info('signature status: OK')
except formats.ecdsa.BadSignatureError:
log.exception('signature status: ERROR')
raise BadSignature('invalid ECDSA signature')
raise ValueError('invalid ECDSA signature')
log.debug('signature size: %d bytes', len(sig_bytes))
data = util.frame(util.frame(key['type']), util.frame(sig_bytes))
code = util.pack('B', SSH2_AGENT_SIGN_RESPONSE)
code = util.pack('B', msg_code('SSH2_AGENT_SIGN_RESPONSE'))
return util.frame(code, data)

View File

@@ -1,19 +1,21 @@
import socket
"""UNIX-domain socket server for ssh-agent implementation."""
import contextlib
import logging
import os
import socket
import subprocess
import tempfile
import contextlib
import threading
from . import protocol
from . import formats
from . import util
import logging
log = logging.getLogger(__name__)
UNIX_SOCKET_TIMEOUT = 0.1
def remove_file(path, remove=os.remove, exists=os.path.exists):
"""Remove file, and raise OSError if still exists."""
try:
remove(path)
except OSError:
@@ -23,6 +25,11 @@ def remove_file(path, remove=os.remove, exists=os.path.exists):
@contextlib.contextmanager
def unix_domain_socket_server(sock_path):
"""
Create UNIX-domain socket on specified path.
Listen on it, and delete it after the generated context is over.
"""
log.debug('serving on SSH_AUTH_SOCK=%s', sock_path)
remove_file(sock_path)
@@ -36,6 +43,12 @@ def unix_domain_socket_server(sock_path):
def handle_connection(conn, handler):
"""
Handle a single connection using the specified protocol handler in a loop.
Exit when EOFError is raised.
All other exceptions are logged as warnings.
"""
try:
log.debug('welcome agent')
while True:
@@ -44,19 +57,41 @@ def handle_connection(conn, handler):
util.send(conn, reply)
except EOFError:
log.debug('goodbye agent')
except:
log.exception('error')
raise
except Exception as e: # pylint: disable=broad-except
log.warning('error: %s', e, exc_info=True)
def server_thread(server, handler):
log.debug('server thread started')
def retry(func, exception_type, quit_event):
"""
Run the function, retrying when the specified exception_type occurs.
Poll quit_event on each iteration, to be responsive to an external
exit request.
"""
while True:
log.debug('waiting for connection on %s', server.getsockname())
if quit_event.is_set():
raise StopIteration
try:
conn, _ = server.accept()
except socket.error as e:
log.debug('server stopped: %s', e)
return func()
except exception_type:
pass
def server_thread(sock, handler, quit_event):
"""Run a server on the specified socket."""
log.debug('server thread started')
def accept_connection():
conn, _ = sock.accept()
conn.settimeout(None)
return conn
while True:
log.debug('waiting for connection on %s', sock.getsockname())
try:
conn = retry(accept_connection, socket.timeout, quit_event)
except StopIteration:
log.debug('server stopped')
break
with contextlib.closing(conn):
handle_connection(conn, handler)
@@ -64,7 +99,8 @@ def server_thread(server, handler):
@contextlib.contextmanager
def spawn(func, **kwargs):
def spawn(func, kwargs):
"""Spawn a thread, and join it after the context is over."""
t = threading.Thread(target=func, kwargs=kwargs)
t.start()
yield
@@ -72,28 +108,40 @@ def spawn(func, **kwargs):
@contextlib.contextmanager
def serve(public_keys, signer, sock_path=None):
def serve(handler, sock_path=None, timeout=UNIX_SOCKET_TIMEOUT):
"""
Start the ssh-agent server on a UNIX-domain socket.
If no connection is made during the specified timeout,
retry until the context is over.
"""
if sock_path is None:
sock_path = tempfile.mktemp(prefix='ssh-agent-')
keys = [formats.import_public_key(k) for k in public_keys]
environ = {'SSH_AUTH_SOCK': sock_path, 'SSH_AGENT_PID': str(os.getpid())}
with unix_domain_socket_server(sock_path) as server:
handler = protocol.Handler(keys=keys, signer=signer)
with spawn(server_thread, server=server, handler=handler):
with unix_domain_socket_server(sock_path) as sock:
sock.settimeout(timeout)
quit_event = threading.Event()
kwargs = dict(sock=sock, handler=handler, quit_event=quit_event)
with spawn(server_thread, kwargs):
try:
yield environ
finally:
log.debug('closing server')
server.shutdown(socket.SHUT_RD)
quit_event.set()
def run_process(command, environ, use_shell=False):
log.debug('running %r with %r', command, environ)
def run_process(command, environ):
"""
Run the specified process and wait until it finishes.
Use environ dict for environment variables.
"""
log.info('running %r with %r', command, environ)
env = dict(os.environ)
env.update(environ)
try:
p = subprocess.Popen(args=command, env=env, shell=use_shell)
p = subprocess.Popen(args=command, env=env)
except OSError as e:
raise OSError('cannot run %r: %s' % (command, e))
log.debug('subprocess %d is running', p.pid)

View File

@@ -0,0 +1 @@
"""Unit-tests for this package."""

View File

@@ -1,9 +1,9 @@
from ..trezor import client
from .. import formats
import io
import mock
import pytest
from .. import client, factory, formats, util
ADDR = [2147483661, 2810943954, 3938368396, 3454558782, 3848009040]
CURVE = 'nist256p1'
@@ -15,17 +15,9 @@ PUBKEY_TEXT = ('ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzd'
'VUfhvrGljR2Z/CMRONY6ejB+9PnpUOPuzYqi8= ssh://localhost:22\n')
class ConnectionMock(object):
class FakeConnection(object):
def __init__(self, version):
self.features = mock.Mock(spec=[])
self.features.device_id = '123456789'
self.features.label = 'mywallet'
self.features.vendor = 'mock'
self.features.major_version = version[0]
self.features.minor_version = version[1]
self.features.patch_version = version[2]
self.features.revision = b'456'
def __init__(self):
self.closed = False
def close(self):
@@ -48,21 +40,21 @@ class ConnectionMock(object):
return msg
class FactoryMock(object):
def identity_type(**kwargs):
result = mock.Mock(spec=[])
result.index = 0
result.proto = result.user = result.host = result.port = None
result.path = None
for k, v in kwargs.items():
setattr(result, k, v)
return result
@staticmethod
def client():
return ConnectionMock(version=(1, 3, 4))
@staticmethod
def identity_type(**kwargs):
result = mock.Mock(spec=[])
result.index = 0
result.proto = result.user = result.host = result.port = None
result.path = None
for k, v in kwargs.items():
setattr(result, k, v)
return result
def load_client():
return factory.ClientWrapper(connection=FakeConnection(),
identity_type=identity_type,
device_name='DEVICE_NAME',
call_exception=Exception)
BLOB = (b'\x00\x00\x00 \xce\xe0\xc9\xd5\xceu/\xe8\xc5\xf2\xbfR+x\xa1\xcf\xb0'
@@ -81,21 +73,24 @@ SIG = (b'\x00R\x19T\xf2\x84$\xef#\x0e\xee\x04X\xc6\xc3\x99T`\xd1\xd8\xf7!'
def test_ssh_agent():
label = 'localhost:22'
c = client.Client(factory=FactoryMock)
c = client.Client(loader=load_client)
ident = c.get_identity(label=label)
assert ident.host == 'localhost'
assert ident.proto == 'ssh'
assert ident.port == '22'
assert ident.user is None
assert ident.path is None
assert ident.index == 0
with c:
assert c.get_public_key(label) == PUBKEY_TEXT
def ssh_sign_identity(identity, challenge_hidden,
challenge_visual, ecdsa_curve_name):
assert (client.identity_to_string(identity) ==
client.identity_to_string(ident))
assert challenge_hidden == BLOB
assert challenge_visual == identity.path
assert challenge_visual == 'VISUAL'
assert ecdsa_curve_name == b'nist256p1'
result = mock.Mock(spec=[])
@@ -104,10 +99,30 @@ def test_ssh_agent():
return result
c.client.sign_identity = ssh_sign_identity
signature = c.sign_ssh_challenge(label=label, blob=BLOB)
signature = c.sign_ssh_challenge(label=label, blob=BLOB,
visual='VISUAL')
key = formats.import_public_key(PUBKEY_TEXT)
assert key['verifier'](sig=signature, msg=BLOB)
serialized_sig = key['verifier'](sig=signature, msg=BLOB)
stream = io.BytesIO(serialized_sig)
r = util.read_frame(stream)
s = util.read_frame(stream)
assert not stream.read()
assert r[:1] == b'\x00'
assert s[:1] == b'\x00'
assert r[1:] + s[1:] == SIG[1:]
c.client.call_exception = ValueError
# pylint: disable=unused-argument
def cancel_sign_identity(identity, challenge_hidden,
challenge_visual, ecdsa_curve_name):
raise c.client.call_exception(42, 'ERROR')
c.client.sign_identity = cancel_sign_identity
with pytest.raises(IOError):
c.sign_ssh_challenge(label=label, blob=BLOB, visual='VISUAL')
def test_utils():
@@ -120,15 +135,3 @@ def test_utils():
url = 'https://user@host:443/path'
assert client.identity_to_string(identity) == url
def test_old_version():
class OldFactoryMock(FactoryMock):
@staticmethod
def client():
return ConnectionMock(version=(1, 2, 3))
with pytest.raises(ValueError):
client.Client(factory=OldFactoryMock)

View File

@@ -0,0 +1,97 @@
import mock
import pytest
from .. import factory
def test_load():
def single():
return [0]
def nothing():
return []
def double():
return [1, 2]
assert factory.load(loaders=[single]) == 0
assert factory.load(loaders=[single, nothing]) == 0
assert factory.load(loaders=[nothing, single]) == 0
with pytest.raises(IOError):
factory.load(loaders=[])
with pytest.raises(IOError):
factory.load(loaders=[single, single])
with pytest.raises(IOError):
factory.load(loaders=[double])
def factory_load_client(**kwargs):
# pylint: disable=protected-access
return list(factory._load_client(**kwargs))
def test_load_nothing():
hid_transport = mock.Mock(spec_set=['enumerate'])
hid_transport.enumerate.return_value = []
result = factory_load_client(
name=None,
client_type=None,
hid_transport=hid_transport,
passphrase_ack=None,
identity_type=None,
required_version=None,
call_exception=None)
assert result == []
def create_client_type(version):
conn = mock.Mock(spec=[])
conn.features = mock.Mock(spec=[])
major, minor, patch = version.split('.')
conn.features.device_id = 'DEVICE_ID'
conn.features.label = 'LABEL'
conn.features.vendor = 'VENDOR'
conn.features.major_version = major
conn.features.minor_version = minor
conn.features.patch_version = patch
conn.features.revision = b'\x12\x34\x56\x78'
return mock.Mock(spec_set=[], return_value=conn)
def test_load_single():
hid_transport = mock.Mock(spec_set=['enumerate'])
hid_transport.enumerate.return_value = [0]
for version in ('1.3.4', '1.3.5', '1.4.0', '2.0.0'):
passphrase_ack = mock.Mock(spec_set=[])
client_type = create_client_type(version)
client_wrapper, = factory_load_client(
name='DEVICE_NAME',
client_type=client_type,
hid_transport=hid_transport,
passphrase_ack=passphrase_ack,
identity_type=None,
required_version='>=1.3.4',
call_exception=None)
assert client_wrapper.connection is client_type.return_value
assert client_wrapper.device_name == 'DEVICE_NAME'
client_wrapper.connection.callback_PassphraseRequest('MESSAGE')
assert passphrase_ack.mock_calls == [mock.call(passphrase='')]
def test_load_old():
hid_transport = mock.Mock(spec_set=['enumerate'])
hid_transport.enumerate.return_value = [0]
for version in ('1.3.3', '1.2.5', '1.1.0', '0.9.9'):
with pytest.raises(ValueError):
factory_load_client(
name='DEVICE_NAME',
client_type=create_client_type(version),
hid_transport=hid_transport,
passphrase_ack=None,
identity_type=None,
required_version='>=1.3.4',
call_exception=None)

View File

@@ -1,4 +1,5 @@
import binascii
import pytest
from .. import formats
@@ -35,8 +36,9 @@ def test_parse_public_key():
def test_decompress():
blob = '036236ceabde25207e81e404586e3a3af1acda1dfed2abbbb4876c1fc5b296b575'
result = formats.export_public_key(binascii.unhexlify(blob), label='home')
assert result == _public_key
vk = formats.decompress_pubkey(binascii.unhexlify(blob),
curve_name=formats.CURVE_NIST256)
assert formats.export_public_key(vk, label='home') == _public_key
def test_parse_ed25519():
@@ -57,7 +59,7 @@ def test_parse_ed25519():
def test_export_ed25519():
pub = (b'\x00P]\x17kc}#\xbc\x9c\xb2"\xef~\xa2\xb3\xe7\xf4'
b'z\xba\xb6\xf1\x14\xdc\xec)\x0c\xd7SY\xb52\x91')
vk = formats.decompress_pubkey(pub)
vk = formats.decompress_pubkey(pub, formats.CURVE_ED25519)
result = formats.serialize_verifying_key(vk)
assert result == (b'ssh-ed25519',
b'\x00\x00\x00\x0bssh-ed25519\x00\x00\x00 P]\x17kc}#\xbc'
@@ -67,7 +69,25 @@ def test_export_ed25519():
def test_decompress_error():
with pytest.raises(ValueError):
formats.decompress_pubkey('')
formats.decompress_pubkey('', formats.CURVE_NIST256)
def test_curve_mismatch():
# NIST256 public key
blob = '036236ceabde25207e81e404586e3a3af1acda1dfed2abbbb4876c1fc5b296b575'
with pytest.raises(ValueError):
formats.decompress_pubkey(binascii.unhexlify(blob),
curve_name=formats.CURVE_ED25519)
blob = '00' * 33 # Dummy public key
with pytest.raises(ValueError):
formats.decompress_pubkey(binascii.unhexlify(blob),
curve_name=formats.CURVE_NIST256)
blob = 'FF' * 33 # Unsupported prefix byte
with pytest.raises(ValueError):
formats.decompress_pubkey(binascii.unhexlify(blob),
curve_name=formats.CURVE_NIST256)
def test_serialize_error():

View File

@@ -1,8 +1,7 @@
from .. import protocol
from .. import formats
import pytest
from .. import formats, protocol
# pylint: disable=line-too-long
NIST256_KEY = 'ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBEUksojS/qRlTKBKLQO7CBX7a7oqFkysuFn1nJ6gzlR3wNuQXEgd7qb2bjmiiBHsjNxyWvH5SxVi3+fghrqODWo= ssh://localhost' # nopep8
@@ -23,8 +22,14 @@ def test_list():
assert reply == LIST_NIST256_REPLY
def test_unsupported():
h = protocol.Handler(keys=[], signer=None)
reply = h.handle(b'\x09')
assert reply == b'\x00\x00\x00\x01\x05'
def ecdsa_signer(label, blob):
assert label == b'ssh://localhost'
assert label == 'ssh://localhost'
assert blob == NIST256_BLOB
return NIST256_SIG
@@ -39,23 +44,33 @@ def test_ecdsa_sign():
def test_sign_missing():
h = protocol.Handler(keys=[], signer=ecdsa_signer)
with pytest.raises(protocol.MissingKey):
with pytest.raises(KeyError):
h.handle(NIST256_SIGN_MSG)
def test_sign_wrong():
def wrong_signature(label, blob):
assert label == b'ssh://localhost'
assert label == 'ssh://localhost'
assert blob == NIST256_BLOB
return b'\x00' * 64
key = formats.import_public_key(NIST256_KEY)
h = protocol.Handler(keys=[key], signer=wrong_signature)
with pytest.raises(protocol.BadSignature):
with pytest.raises(ValueError):
h.handle(NIST256_SIGN_MSG)
def test_sign_cancel():
def cancel_signature(label, blob): # pylint: disable=unused-argument
raise IOError()
key = formats.import_public_key(NIST256_KEY)
h = protocol.Handler(keys=[key], signer=cancel_signature)
assert h.handle(NIST256_SIGN_MSG) == protocol.failure()
ED25519_KEY = 'ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIFBdF2tjfSO8nLIi736is+f0erq28RTc7CkM11NZtTKR ssh://localhost' # nopep8
ED25519_SIGN_MSG = b'''\r\x00\x00\x003\x00\x00\x00\x0bssh-ed25519\x00\x00\x00 P]\x17kc}#\xbc\x9c\xb2"\xef~\xa2\xb3\xe7\xf4z\xba\xb6\xf1\x14\xdc\xec)\x0c\xd7SY\xb52\x91\x00\x00\x00\x94\x00\x00\x00 i3\xae}yk\\\xa1L\xb9\xe1\xbf\xbc\x8e\x87\r\x0e\xc0\x9f\x97\x0fTC!\x80\x07\x91\xdb^8\xc1\xd62\x00\x00\x00\x05roman\x00\x00\x00\x0essh-connection\x00\x00\x00\tpublickey\x01\x00\x00\x00\x0bssh-ed25519\x00\x00\x003\x00\x00\x00\x0bssh-ed25519\x00\x00\x00 P]\x17kc}#\xbc\x9c\xb2"\xef~\xa2\xb3\xe7\xf4z\xba\xb6\xf1\x14\xdc\xec)\x0c\xd7SY\xb52\x91\x00\x00\x00\x00''' # nopep8
ED25519_SIGN_REPLY = b'''\x00\x00\x00X\x0e\x00\x00\x00S\x00\x00\x00\x0bssh-ed25519\x00\x00\x00@\x8eb)\xa6\xe9P\x83VE\xfbq\xc6\xbf\x1dV3\xe3<O\x11\xc0\xfa\xe4\xed\xb8\x81.\x81\xc8\xa6\xba\x10RA'a\xbc\xa9\xd3\xdb\x98\x07\xf0\x1a\x9c4\x84<\xaf\x99\xb7\xe5G\xeb\xf7$\xc1\r\x86f\x16\x8e\x08\x05''' # nopep8
@@ -65,7 +80,7 @@ ED25519_SIG = b'''\x8eb)\xa6\xe9P\x83VE\xfbq\xc6\xbf\x1dV3\xe3<O\x11\xc0\xfa\xe4
def ed25519_signer(label, blob):
assert label == b'ssh://localhost'
assert label == 'ssh://localhost'
assert blob == ED25519_BLOB
return ED25519_SIG

View File

@@ -1,12 +1,13 @@
import tempfile
import socket
import os
import io
import os
import socket
import tempfile
import threading
import mock
import pytest
from .. import server
from .. import protocol
from .. import util
from .. import protocol, server, util
def test_socket():
@@ -16,7 +17,7 @@ def test_socket():
assert not os.path.isfile(path)
class SocketMock(object):
class FakeSocket(object):
def __init__(self, data=b''):
self.rx = io.BytesIO(data)
@@ -31,45 +32,55 @@ class SocketMock(object):
def close(self):
pass
def settimeout(self, value):
pass
def test_handle():
handler = protocol.Handler(keys=[], signer=None)
conn = SocketMock()
conn = FakeSocket()
server.handle_connection(conn, handler)
msg = bytearray([protocol.SSH_AGENTC_REQUEST_RSA_IDENTITIES])
conn = SocketMock(util.frame(msg))
msg = bytearray([protocol.msg_code('SSH_AGENTC_REQUEST_RSA_IDENTITIES')])
conn = FakeSocket(util.frame(msg))
server.handle_connection(conn, handler)
assert conn.tx.getvalue() == b'\x00\x00\x00\x05\x02\x00\x00\x00\x00'
msg = bytearray([protocol.SSH2_AGENTC_REQUEST_IDENTITIES])
conn = SocketMock(util.frame(msg))
msg = bytearray([protocol.msg_code('SSH2_AGENTC_REQUEST_IDENTITIES')])
conn = FakeSocket(util.frame(msg))
server.handle_connection(conn, handler)
assert conn.tx.getvalue() == b'\x00\x00\x00\x05\x0C\x00\x00\x00\x00'
with pytest.raises(AttributeError):
server.handle_connection(conn=None, handler=None)
msg = bytearray([protocol.msg_code('SSH2_AGENTC_ADD_IDENTITY')])
conn = FakeSocket(util.frame(msg))
server.handle_connection(conn, handler)
conn.tx.seek(0)
reply = util.read_frame(conn.tx)
assert reply == util.pack('B', protocol.msg_code('SSH_AGENT_FAILURE'))
class ServerMock(object):
def __init__(self, connections, name):
self.connections = connections
self.name = name
def getsockname(self):
return self.name
def accept(self):
if self.connections:
return self.connections.pop(), 'address'
raise socket.error('stop')
conn_mock = mock.Mock(spec=FakeSocket)
conn_mock.recv.side_effect = [Exception, EOFError]
server.handle_connection(conn=conn_mock, handler=None)
def test_server_thread():
s = ServerMock(connections=[SocketMock()], name='mock')
h = protocol.Handler(keys=[], signer=None)
server.server_thread(s, h)
connections = [FakeSocket()]
quit_event = threading.Event()
class FakeServer(object):
def accept(self): # pylint: disable=no-self-use
if connections:
return connections.pop(), 'address'
quit_event.set()
raise socket.timeout()
def getsockname(self): # pylint: disable=no-self-use
return 'fake_server'
server.server_thread(sock=FakeServer(),
handler=protocol.Handler(keys=[], signer=None),
quit_event=quit_event)
def test_spawn():
@@ -78,7 +89,7 @@ def test_spawn():
def thread(x):
obj.append(x)
with server.spawn(thread, x=1):
with server.spawn(thread, dict(x=1)):
pass
assert obj == [1]
@@ -87,17 +98,16 @@ def test_spawn():
def test_run():
assert server.run_process(['true'], environ={}) == 0
assert server.run_process(['false'], environ={}) == 1
assert server.run_process(
command='exit $X',
environ={'X': '42'},
use_shell=True) == 42
assert server.run_process(command=['bash', '-c', 'exit $X'],
environ={'X': '42'}) == 42
with pytest.raises(OSError):
server.run_process([''], environ={})
def test_serve_main():
with server.serve(public_keys=[], signer=None, sock_path=None):
handler = protocol.Handler(keys=[], signer=None)
with server.serve(handler=handler, sock_path=None):
pass

View File

@@ -1,4 +1,5 @@
import io
import pytest
from .. import util
@@ -23,7 +24,7 @@ def test_frames():
assert util.read_frame(io.BytesIO(f)) == b''.join(msgs)
class SocketMock(object):
class FakeSocket(object):
def __init__(self):
self.buf = io.BytesIO()
@@ -35,9 +36,9 @@ class SocketMock(object):
def test_send_recv():
s = SocketMock()
s = FakeSocket()
util.send(s, b'123')
util.send(s, data=[42], fmt='B')
util.send(s, b'*')
assert s.buf.getvalue() == b'123*'
s.buf.seek(0)

View File

@@ -1 +0,0 @@
from .client import Client

View File

@@ -1,23 +0,0 @@
''' Thin wrapper around trezorlib. '''
def client():
# pylint: disable=import-error
from trezorlib.client import TrezorClient
from trezorlib.transport_hid import HidTransport
from trezorlib.messages_pb2 import PassphraseAck
devices = HidTransport.enumerate()
if len(devices) != 1:
msg = '{:d} Trezor devices found'.format(len(devices))
raise IOError(msg)
t = TrezorClient(HidTransport(devices[0]))
t.callback_PassphraseRequest = lambda msg: PassphraseAck(passphrase='')
return t
def identity_type(**kwargs):
# pylint: disable=import-error
from trezorlib.types_pb2 import IdentityType
return IdentityType(**kwargs)

View File

@@ -1,148 +0,0 @@
import io
import re
import struct
import binascii
from .. import util
from .. import formats
from . import _factory as TrezorFactory
import logging
log = logging.getLogger(__name__)
class Client(object):
MIN_VERSION = [1, 3, 4]
def __init__(self, factory=TrezorFactory, curve=formats.CURVE_NIST256):
self.curve = curve
self.factory = factory
self.client = self.factory.client()
f = self.client.features
log.debug('connected to Trezor %s', f.device_id)
log.debug('label : %s', f.label)
log.debug('vendor : %s', f.vendor)
version = [f.major_version, f.minor_version, f.patch_version]
version_str = '.'.join([str(v) for v in version])
log.debug('version : %s', version_str)
log.debug('revision : %s', binascii.hexlify(f.revision))
if version < self.MIN_VERSION:
fmt = 'Please upgrade your TREZOR to v{}+ firmware'
version_str = '.'.join([str(v) for v in self.MIN_VERSION])
raise ValueError(fmt.format(version_str))
def __enter__(self):
msg = 'Hello World!'
assert self.client.ping(msg) == msg
return self
def __exit__(self, *args):
log.info('disconnected from Trezor')
self.client.clear_session() # forget PIN and shutdown screen
self.client.close()
def get_identity(self, label):
identity = string_to_identity(label, self.factory.identity_type)
identity.proto = 'ssh'
return identity
def get_public_key(self, label):
identity = self.get_identity(label=label)
label = identity_to_string(identity) # canonize key label
log.info('getting "%s" public key (%s) from Trezor...',
label, self.curve)
addr = _get_address(identity)
node = self.client.get_public_node(n=addr,
ecdsa_curve_name=self.curve)
pubkey = node.node.public_key
return formats.export_public_key(pubkey=pubkey, label=label)
def sign_ssh_challenge(self, label, blob):
identity = self.get_identity(label=label)
msg = _parse_ssh_blob(blob)
log.info('please confirm user "%s" login to "%s" using Trezor...',
msg['user'], label)
visual = identity.path # not signed when proto='ssh'
result = self.client.sign_identity(identity=identity,
challenge_hidden=blob,
challenge_visual=visual,
ecdsa_curve_name=self.curve)
verifying_key = formats.decompress_pubkey(result.public_key)
key_type, blob = formats.serialize_verifying_key(verifying_key)
assert blob == msg['public_key']['blob']
assert key_type == msg['key_type']
assert len(result.signature) == 65
assert result.signature[:1] == bytearray([0])
return result.signature[1:]
_identity_regexp = re.compile(''.join([
'^'
r'(?:(?P<proto>.*)://)?',
r'(?:(?P<user>.*)@)?',
r'(?P<host>.*?)',
r'(?::(?P<port>\w*))?',
r'(?P<path>/.*)?',
'$'
]))
def string_to_identity(s, identity_type):
m = _identity_regexp.match(s)
result = m.groupdict()
log.debug('parsed identity: %s', result)
kwargs = {k: v for k, v in result.items() if v}
return identity_type(**kwargs)
def identity_to_string(identity):
result = []
if identity.proto:
result.append(identity.proto + '://')
if identity.user:
result.append(identity.user + '@')
result.append(identity.host)
if identity.port:
result.append(':' + identity.port)
if identity.path:
result.append(identity.path)
return ''.join(result)
def _get_address(identity):
index = struct.pack('<L', identity.index)
addr = index + identity_to_string(identity).encode('ascii')
log.debug('address string: %r', addr)
digest = formats.hashfunc(addr).digest()
s = io.BytesIO(bytearray(digest))
hardened = 0x80000000
address_n = [13] + list(util.recv(s, '<LLLL'))
return [(hardened | value) for value in address_n]
def _parse_ssh_blob(data):
res = {}
if data:
i = io.BytesIO(data)
res['nonce'] = util.read_frame(i)
i.read(1) # TBD
res['user'] = util.read_frame(i)
res['conn'] = util.read_frame(i)
res['auth'] = util.read_frame(i)
i.read(1) # TBD
res['key_type'] = util.read_frame(i)
public_key = util.read_frame(i)
res['public_key'] = formats.parse_pubkey(public_key)
assert not i.read()
log.debug('%s: user %r via %r (%r)',
res['conn'], res['user'], res['auth'], res['key_type'])
log.debug('nonce: %s', binascii.hexlify(res['nonce']))
log.debug('fingerprint: %s', res['public_key']['fingerprint'])
return res

View File

@@ -1,14 +1,20 @@
import struct
"""Various I/O and serialization utilities."""
import io
import struct
def send(conn, data, fmt=None):
if fmt:
data = struct.pack(fmt, *data)
def send(conn, data):
"""Send data blob to connection socket."""
conn.sendall(data)
def recv(conn, size):
"""
Receive bytes from connection socket or stream.
If size is struct.calcsize()-compatible format, use it to unpack the data.
Otherwise, return the plain blob as bytes.
"""
try:
fmt = size
size = struct.calcsize(fmt)
@@ -34,11 +40,13 @@ def recv(conn, size):
def read_frame(conn):
"""Read size-prefixed frame from connection."""
size, = recv(conn, '>L')
return recv(conn, size)
def bytes2num(s):
"""Convert MSB-first bytes to an unsigned integer."""
res = 0
for i, c in enumerate(reversed(bytearray(s))):
res += c << (i * 8)
@@ -46,6 +54,7 @@ def bytes2num(s):
def num2bytes(value, size):
"""Convert an unsigned integer to MSB-first bytes with specified size."""
res = []
for _ in range(size):
res.append(value & 0xFF)
@@ -55,10 +64,12 @@ def num2bytes(value, size):
def pack(fmt, *args):
"""Serialize MSB-first message."""
return struct.pack('>' + fmt, *args)
def frame(*msgs):
"""Serialize MSB-first length-prefixed frame."""
res = io.BytesIO()
for msg in msgs:
res.write(msg)