Compare commits

...

88 Commits
v0.4 ... v0.6.2

Author SHA1 Message Date
Roman Zeyde
7f6bb12b24 bump version 2016-03-05 11:20:11 +02:00
Roman Zeyde
98e875562e main: add trezor-git entry point 2016-03-05 11:18:24 +02:00
Roman Zeyde
4384b93c19 main: remove unneeded use_shell parameter 2016-03-05 11:03:10 +02:00
Roman Zeyde
8a90a8cd84 main: split git from ssh 2016-03-05 10:56:30 +02:00
Roman Zeyde
1e86983782 main: split argument parser 2016-03-05 10:46:36 +02:00
Roman Zeyde
c63201c90c client: show visual challenge 2016-03-05 10:39:47 +02:00
Roman Zeyde
19b00dc427 client: add logging for challenge sizes 2016-02-27 20:09:03 +02:00
Roman Zeyde
aa35981980 README: add 'apt-get' to installation section 2016-02-27 09:49:15 +02:00
Roman Zeyde
8909b38107 main: use command-line for git interaction 2016-02-20 18:24:14 +02:00
Roman Zeyde
6d9aa9cb8a README: license badge is broken most of the time 2016-02-19 20:54:36 +02:00
Roman Zeyde
d6532311b9 fix PEP8 & docstrings 2016-02-19 20:52:59 +02:00
Roman Zeyde
41b30b42b5 main: add git identity via "origin" remote 2016-02-19 20:48:16 +02:00
Roman Zeyde
5b0e56697f travis: add pydocstyle 2016-02-19 11:41:05 +02:00
Roman Zeyde
0e6d998b4c tox: add pydocstyle 2016-02-19 11:39:12 +02:00
Roman Zeyde
2c7fabfa35 tests: add docstrings 2016-02-19 11:35:34 +02:00
Roman Zeyde
1adccdbfe6 __init__: add docstrings 2016-02-19 11:35:27 +02:00
Roman Zeyde
04f4bbf2ac main: add docstrings 2016-02-19 11:35:16 +02:00
Roman Zeyde
bbe963d0ff util: rename UTs 2016-02-19 11:34:58 +02:00
Roman Zeyde
c49514754b util: add docstrings 2016-02-19 11:34:20 +02:00
Roman Zeyde
2ebefff909 server: add docstrings 2016-02-19 11:19:01 +02:00
Roman Zeyde
21e89014c9 protocol: add docstrings and replace custom exceptions 2016-02-19 10:49:39 +02:00
Roman Zeyde
566e4310e1 formats: add docstrings 2016-02-19 10:40:39 +02:00
Roman Zeyde
e1441518d4 factory: add docstrings 2016-02-19 10:08:36 +02:00
Roman Zeyde
5cb12a43de client: add docstrings 2016-02-19 10:07:33 +02:00
Roman Zeyde
df607f3665 pylint: add 'no-member' check 2016-02-18 14:28:16 +02:00
Roman Zeyde
d712509a4e client: show current time instead of identity.path 2016-02-17 15:04:10 +02:00
Roman Zeyde
40e2d9fb2c fixup imports order
isort -rc trezor_agent
2016-02-15 20:53:14 +02:00
Roman Zeyde
cd4cc059d6 main: remove git-config parsing code 2016-02-15 20:52:44 +02:00
Roman Zeyde
2b047f0525 main: refactor shell flag 2016-02-15 20:38:34 +02:00
Roman Zeyde
64776fd294 rename client test 2016-02-15 17:22:57 +02:00
Roman Zeyde
231995bd1a remove trezor module 2016-02-15 17:22:01 +02:00
Roman Zeyde
ff76f17c02 client: elaborate SSH blob parsing 2016-02-13 20:26:23 +02:00
Roman Zeyde
963e80b49b client: move logging from parsing code 2016-02-06 18:32:51 +02:00
Roman Zeyde
dee13b75ea client: remove unneeded 'if' 2016-02-06 18:27:46 +02:00
Roman Zeyde
be86507e00 client: pass index as default argument 2016-02-06 17:52:49 +02:00
Roman Zeyde
2f2663ef94 client: set identity index explicitly 2016-02-06 17:51:57 +02:00
Roman Zeyde
cafa218e19 server: pass handler and add debug option 2016-01-26 21:14:52 +02:00
Roman Zeyde
50b627ed45 protocol: allow debugging SSH message handler 2016-01-26 21:14:27 +02:00
Roman Zeyde
7f36097c15 tests: refactor mocks and fakes 2016-01-22 12:04:24 +02:00
Roman Zeyde
a4b905cd6f bump version 2016-01-19 22:56:54 +02:00
Roman Zeyde
2eff21f96c factory: refactor for easier testing 2016-01-19 22:52:52 +02:00
Roman Zeyde
9afd07e867 server: make sure accepted UNIX sockets are blocking
It was a problem on Mac OS X, where sometimes we got EAGAIN
errors from calling socket.recv() on them.
2016-01-18 22:49:27 +02:00
Roman Zeyde
b101281a5b main: add command-line argument for setting UNIX socket timeout 2016-01-16 22:14:36 +02:00
Roman Zeyde
8c6ac43cf4 Merge Trezor and KeepKey functionality 2016-01-15 13:20:38 +02:00
Kenneth Heutmaker
5932a89dc5 Make it work with KeepKey 2016-01-14 13:28:32 -08:00
Roman Zeyde
2009160ff2 Revert "travis: test with tox"
This reverts commit 3d8072522c.
2016-01-09 17:46:07 +02:00
Roman Zeyde
3d8072522c travis: test with tox 2016-01-09 17:41:17 +02:00
Roman Zeyde
0c63aef719 sort imports using isort tool 2016-01-09 16:06:47 +02:00
Roman Zeyde
c454114c4e README: add gitter chat 2016-01-09 12:15:43 +02:00
Roman Zeyde
f9133f7e05 README: fixup license link 2016-01-09 11:19:33 +02:00
Roman Zeyde
33a6951a96 server: don't crash after single exception 2016-01-08 20:46:49 +02:00
Roman Zeyde
fb0d0a5f61 server: stop the server via a threading.Event
It seems that Mac OS does not support calling socket.shutdown(socket.SHUT_RD)
on a listening socket (see https://github.com/romanz/trezor-agent/issues/6).
The following implementation will set the accept() timeout to 0.1s and stop
the server if a threading.Event (named "quit_event") is set by the main thread.
2016-01-08 20:28:38 +02:00
Roman Zeyde
7ea20c7009 test_trezor: verify serialized signature 2016-01-08 17:30:08 +02:00
Roman Zeyde
4247558166 README: add subshell demo 2016-01-08 16:07:29 +02:00
Roman Zeyde
fe1e1d2bb9 server: log command with INFO level 2016-01-08 16:04:57 +02:00
Roman Zeyde
1a5b8118ad setup.py: support for Python 3.4 2016-01-05 20:46:55 +02:00
Roman Zeyde
3a806c6d77 beta release 2016-01-05 19:54:20 +02:00
Roman Zeyde
3b61f86c25 README: fixup license to match the repository 2016-01-05 18:49:49 +02:00
Roman Zeyde
06d84c387c bump version 2016-01-04 22:49:28 +02:00
Roman Zeyde
8347142a99 setup.py: fixup license to match the repository 2016-01-04 21:26:17 +02:00
Roman Zeyde
7dabe2c555 test_protocol: fix bytes->str 2016-01-04 21:03:46 +02:00
Roman Zeyde
d6ee3d8995 tox: add py34 2016-01-04 21:03:27 +02:00
Roman Zeyde
c3fa79e450 Fix a few pylint issues 2016-01-04 19:21:56 +02:00
Roman Zeyde
15b10c9a7e bump version 2016-01-04 19:05:43 +02:00
Roman Zeyde
e19d76398e formats: verify public key according to requested ECDSA curve 2015-12-18 16:04:20 +02:00
Roman Zeyde
535b4d50c7 Fix SSH connection arguments handling 2015-11-27 17:26:06 +02:00
Roman Zeyde
461f38d599 travis: fix up dependency 2015-10-27 19:57:53 +02:00
Roman Zeyde
60571e65dd trezor: add support for Ed25519 SSH keys 2015-10-27 19:49:30 +02:00
Roman Zeyde
34cecb276a README: fix URL 2015-09-19 14:31:40 +03:00
Roman Zeyde
903ba919b3 README: fix whitespace 2015-09-19 14:15:17 +03:00
Roman Zeyde
3184d34440 README: update badges and blog post 2015-09-19 14:14:50 +03:00
Roman Zeyde
d7099cb863 bump version 2015-09-16 22:03:15 +03:00
Roman Zeyde
e3f04f3389 Merge pull request #2 from romanz/pr
trezor: don't ask for passphrase (always use empty one)
2015-09-16 21:59:31 +03:00
Roman Zeyde
e59404737d trezor: fix PEP8 2015-09-16 21:57:48 +03:00
Pavol Rusnak
ca30707789 don't ask for passphrase (always use empty one similarly to TREZOR Connect) 2015-09-16 15:32:47 +02:00
Roman Zeyde
5449411d09 README: update trezorlib version 2015-09-06 11:50:45 +03:00
Roman Zeyde
697d22fede bump version 2015-09-06 11:48:32 +03:00
Roman Zeyde
4f94c9459c setup.py: require up-to-date ecdsa and trezor packages 2015-09-06 11:47:19 +03:00
Roman Zeyde
f5577e1c15 README: verify firmware version 2015-09-04 22:20:33 +03:00
Roman Zeyde
803e3bb738 client: require TREZOR v1.3.4 firmware for SSH NIST256P1 curve support 2015-09-04 13:07:35 +03:00
Roman Zeyde
c11245ea69 README: fixup SSH example 2015-09-02 15:16:21 +03:00
Roman Zeyde
7b5dd3a51b README: update SSH pubkey handling demo 2015-09-02 15:15:06 +03:00
Roman Zeyde
4199c79074 README: update SSH example 2015-09-02 15:12:33 +03:00
Roman Zeyde
38fd938fd4 travis: test on Python 3.4 2015-08-24 16:07:39 +03:00
Roman Zeyde
ad35e03a9f README: add travis badge 2015-08-24 15:14:46 +03:00
Roman Zeyde
dd6fded82d travis: test without trezorlib 2015-08-24 15:13:28 +03:00
Roman Zeyde
8547d00b33 README: fix naming 2015-08-24 14:49:51 +03:00
Roman Zeyde
a4a0c6a802 README: expand 2015-08-24 14:49:05 +03:00
22 changed files with 1021 additions and 468 deletions

19
.travis.yml Normal file
View File

@@ -0,0 +1,19 @@
sudo: false
language: python
python:
- "2.7"
- "3.4"
install:
- pip install ecdsa ed25519 semver # test without trezorlib for now
- pip install pylint coverage pep8 pydocstyle
script:
- pep8 trezor_agent
- pylint --reports=no --rcfile .pylintrc trezor_agent
- pydocstyle trezor_agent
- coverage run --source trezor_agent/ -m py.test -v
after_success:
- coverage report

View File

@@ -1,4 +1,76 @@
# Using Trezor as a hardware SSH agent
# Using TREZOR as a hardware SSH agent
[![Build Status](https://travis-ci.org/romanz/trezor-agent.svg?branch=master)](https://travis-ci.org/romanz/trezor-agent)
[![Python Versions](https://img.shields.io/pypi/pyversions/trezor_agent.svg)](https://pypi.python.org/pypi/trezor_agent/)
[![Package Version](https://img.shields.io/pypi/v/trezor_agent.svg)](https://pypi.python.org/pypi/trezor_agent/)
[![Development Status](https://img.shields.io/pypi/status/trezor_agent.svg)](https://pypi.python.org/pypi/trezor_agent/)
[![Downloads](https://img.shields.io/pypi/dm/trezor_agent.svg)](https://pypi.python.org/pypi/trezor_agent/)
[![Chat](https://badges.gitter.im/romanz/trezor-agent.svg)](https://gitter.im/romanz/trezor-agent)
See SatoshiLabs' blog post about this feature:
- https://medium.com/@satoshilabs/trezor-firmware-1-3-4-enables-ssh-login-86a622d7e609
## Screencast demo usage
### Simple usage (single SSH session)
[![Demo](https://asciinema.org/a/22959.png)](https://asciinema.org/a/22959)
### Advanced usage (multiple SSH sessions from a sub-shell)
[![Subshell](https://asciinema.org/a/33240.png)](https://asciinema.org/a/33240)
## Installation
First, make sure that the latest `trezorlib` Python package
is installed correctly (at least v0.6.6):
$ apt-get install python-dev libusb-1.0-0-dev libudev-dev
$ pip install Cython trezor
Then, install the latest `trezor_agent` package:
$ pip install trezor_agent
Finally, verify that you are running the latest TREZOR firmware version (at least v1.3.4):
$ trezorctl get_features
vendor: "bitcointrezor.com"
major_version: 1
minor_version: 3
patch_version: 4
...
## Public key generation
Run:
/tmp $ trezor-agent ssh.hostname.com -v > hostname.pub
2015-09-02 15:03:18,929 INFO getting "ssh://ssh.hostname.com" public key from Trezor...
2015-09-02 15:03:23,342 INFO disconnected from Trezor
/tmp $ cat hostname.pub
ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBGSevcDwmT+QaZPUEWUUjTeZRBICChxMKuJ7dRpBSF8+qt+8S1GBK5Zj8Xicc8SHG/SE/EXKUL2UU3kcUzE7ADQ= ssh://ssh.hostname.com
Append `hostname.pub` contents to `~/.ssh/authorized_keys`
configuration file at `ssh.hostname.com`, so the remote server
would allow you to login using the corresponding private key signature.
## Usage
Run:
/tmp $ trezor-agent ssh.hostname.com -v -c
2015-09-02 15:09:39,782 INFO getting "ssh://ssh.hostname.com" public key from Trezor...
2015-09-02 15:09:44,430 INFO please confirm user "roman" login to "ssh://ssh.hostname.com" using Trezor...
2015-09-02 15:09:46,152 INFO signature status: OK
Linux lmde 3.16.0-4-amd64 #1 SMP Debian 3.16.7-ckt11-1+deb8u3 (2015-08-04) x86_64
The programs included with the Debian GNU/Linux system are free software;
the exact distribution terms for each program are described in the
individual files in /usr/share/doc/*/copyright.
Debian GNU/Linux comes with ABSOLUTELY NO WARRANTY, to the extent
permitted by applicable law.
Last login: Tue Sep 1 15:57:05 2015 from localhost
~ $
Make sure to confirm SSH signature on the Trezor device when requested.

View File

@@ -3,27 +3,28 @@ from setuptools import setup
setup(
name='trezor_agent',
version='0.4',
version='0.6.2',
description='Using Trezor as hardware SSH agent',
author='Roman Zeyde',
author_email='roman.zeyde@gmail.com',
license='MIT',
url='http://github.com/romanz/trezor-agent',
packages=['trezor_agent', 'trezor_agent.trezor'],
install_requires=['ecdsa', 'trezor'],
packages=['trezor_agent'],
install_requires=['ecdsa>=0.13', 'ed25519>=1.4', 'Cython>=0.23.4', 'trezor>=0.6.6', 'keepkey>=0.7.0', 'semver>=2.2'],
platforms=['POSIX'],
classifiers=[
'Development Status :: 3 - Alpha',
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',
'Intended Audience :: Information Technology',
'License :: OSI Approved :: MIT License',
'License :: OSI Approved :: GNU Lesser General Public License v3 (LGPLv3)',
'Operating System :: POSIX',
'Programming Language :: Python :: 2.7',
'Programming Language :: Python :: 3.4',
'Topic :: Software Development :: Libraries :: Python Modules',
'Topic :: System :: Networking',
'Topic :: Communications',
],
entry_points={'console_scripts': [
'trezor-agent = trezor_agent.__main__:trezor_agent'
'trezor-agent = trezor_agent.__main__:run_agent',
'trezor-git = trezor_agent.__main__:run_git',
]},
)

11
tox.ini
View File

@@ -1,6 +1,5 @@
[tox]
envlist = py27,py34
skipsdist = True
[testenv]
deps=
pytest
@@ -8,10 +7,12 @@ deps=
pep8
coverage
pylint
six
ecdsa
semver
pydocstyle
commands=
pep8 trezor_agent
pylint --report=no --rcfile .pylintrc trezor_agent
coverage run --omit='trezor_agent/__main__.py,trezor_agent/trezor/_library.py' --source trezor_agent/ -m py.test -v
pylint --reports=no --rcfile .pylintrc trezor_agent
pydocstyle trezor_agent
coverage run --omit='trezor_agent/__main__.py' --source trezor_agent -m py.test -v trezor_agent
coverage report
coverage html

View File

@@ -0,0 +1 @@
"""SSH-agent implementation using hardware authentication devices."""

View File

@@ -1,62 +1,79 @@
import os
import re
import sys
"""SSH-agent implementation using hardware authentication devices."""
import argparse
import subprocess
from . import trezor
from . import server
import functools
import logging
import re
import os
import subprocess
import sys
import time
from . import client, formats, protocol, server
log = logging.getLogger(__name__)
def identity_from_gitconfig():
out = subprocess.check_output(args='git config --list --local'.split())
config = [line.split('=', 1) for line in out.strip().split('\n')]
config_dict = dict(item for item in config if len(item) == 2)
def ssh_args(label):
"""Create SSH command for connecting specified server."""
identity = client.string_to_identity(label, identity_type=dict)
name_regex = re.compile(r'^remote\..*\.trezor$')
names = [item[0] for item in config if name_regex.match(item[0])]
if len(names) != 1:
log.error('please add "trezor" key '
'to a single remote section at .git/config')
sys.exit(1)
key_name = names[0]
identity_label = config_dict.get(key_name)
if identity_label:
return identity_label
args = []
if 'port' in identity:
args += ['-p', identity['port']]
if 'user' in identity:
args += ['-l', identity['user']]
# extract remote name marked as TREZOR's
section_name, _ = key_name.rsplit('.', 1)
key_name = section_name + '.url'
url = config_dict[key_name]
log.debug('using "%s=%s" from git-config', key_name, url)
user, url = url.split('@', 1)
host, path = url.split(':', 1)
return 'ssh://{0}@{1}/{2}'.format(user, host, path)
return ['ssh'] + args + [identity['host']]
def create_agent_parser():
def create_parser():
"""Create argparse.ArgumentParser for this tool."""
p = argparse.ArgumentParser()
p.add_argument('-v', '--verbose', default=0, action='count')
p.add_argument('identity', type=str, default=None,
help='proto://[user@]host[:port][/path]')
curve_names = [name.decode('ascii') for name in formats.SUPPORTED_CURVES]
curve_names = ', '.join(sorted(curve_names))
p.add_argument('-e', '--ecdsa-curve-name', metavar='CURVE',
default=formats.CURVE_NIST256,
help='specify ECDSA curve name: ' + curve_names)
p.add_argument('--timeout',
default=server.UNIX_SOCKET_TIMEOUT, type=float,
help='Timeout for accepting SSH client connections')
p.add_argument('--debug', default=False, action='store_true',
help='Log SSH protocol messages for debugging.')
return p
def create_agent_parser():
"""Specific parser for SSH connection."""
p = create_parser()
g = p.add_mutually_exclusive_group()
g.add_argument('-s', '--shell', default=False, action='store_true',
help='run $SHELL as subprocess under SSH agent')
help='run ${SHELL} as subprocess under SSH agent')
g.add_argument('-c', '--connect', default=False, action='store_true',
help='connect to specified host via SSH')
p.add_argument('identity', type=str, default=None,
help='proto://[user@]host[:port][/path]')
p.add_argument('command', type=str, nargs='*', metavar='ARGUMENT',
help='command to run under the SSH agent')
return p
def create_git_parser():
"""Specific parser for git commands."""
p = create_parser()
p.add_argument('-r', '--remote', default='origin',
help='use this git remote URL to generate SSH identity')
p.add_argument('command', type=str, nargs='*', metavar='ARGUMENT',
help='Git command to run under the SSH agent')
return p
def setup_logging(verbosity):
"""Configure logging for this tool."""
fmt = ('%(asctime)s %(levelname)-12s %(message)-100s '
'[%(filename)s:%(lineno)d]')
levels = [logging.WARNING, logging.INFO, logging.DEBUG]
@@ -64,53 +81,81 @@ def setup_logging(verbosity):
logging.basicConfig(format=fmt, level=level)
def ssh_command(identity):
command = ['ssh', identity.host]
if identity.user:
command += ['-l', identity.user]
if identity.port:
command += ['-p', identity.port]
return command
def git_host(remote_name):
"""Extract git SSH host for specified remote name."""
output = subprocess.check_output('git config --local --list'.split())
pattern = r'remote\.{}\.url=(.*)'.format(remote_name)
matches = re.findall(pattern, output)
log.debug('git remote "%r": %r', remote_name, matches)
if len(matches) != 1:
raise ValueError('{:d} git remotes found: %s', matches)
url = matches[0].strip()
user, url = url.split('@', 1)
host, path = url.split(':', 1)
return 'ssh://{}@{}/{}'.format(user, host, path)
def trezor_agent():
def ssh_sign(conn, label, blob):
"""Perform SSH signature using given hardware device connection."""
now = time.strftime('%Y-%m-%d %H:%M:%S')
return conn.sign_ssh_challenge(label=label, blob=blob, visual=now)
def run_server(conn, public_key, command, debug, timeout):
"""Common code for run_agent and run_git below."""
try:
signer = functools.partial(ssh_sign, conn=conn)
public_keys = [formats.import_public_key(public_key)]
handler = protocol.Handler(keys=public_keys, signer=signer,
debug=debug)
with server.serve(handler=handler, timeout=timeout) as env:
return server.run_process(command=command, environ=env)
except KeyboardInterrupt:
log.info('server stopped')
def run_agent(client_factory=client.Client):
"""Run ssh-agent using given hardware client factory."""
args = create_agent_parser().parse_args()
setup_logging(verbosity=args.verbose)
with trezor.Client() as client:
with client_factory(curve=args.ecdsa_curve_name) as conn:
label = args.identity
command = args.command
if label == 'git':
label = identity_from_gitconfig()
log.info('using identity %r for git command %r', label, command)
if command:
command = ['git'] + command
public_key = conn.get_public_key(label=label)
identity = client.get_identity(label=label)
public_key = client.get_public_key(identity=identity)
use_shell = False
if args.connect:
command = ssh_command(identity) + args.command
command = ssh_args(label) + args.command
log.debug('SSH connect: %r', command)
if args.shell:
command, use_shell = os.environ['SHELL'], True
use_shell = bool(args.shell)
if use_shell:
command = os.environ['SHELL']
log.debug('using shell: %r', command)
if not command:
sys.stdout.write(public_key)
return
def signer(label, blob):
identity = client.get_identity(label=label)
return client.sign_ssh_challenge(identity=identity, blob=blob)
return run_server(conn=conn, public_key=public_key, command=command,
debug=args.debug, timeout=args.timeout)
try:
with server.serve(public_keys=[public_key], signer=signer) as env:
return server.run_process(command=command, environ=env,
use_shell=use_shell)
except KeyboardInterrupt:
log.info('server stopped')
def run_git(client_factory=client.Client):
"""Run git under ssh-agent using given hardware client factory."""
args = create_git_parser().parse_args()
setup_logging(verbosity=args.verbose)
with client_factory(curve=args.ecdsa_curve_name) as conn:
label = git_host(args.remote)
public_key = conn.get_public_key(label=label)
if not args.command:
sys.stdout.write(public_key)
return
return run_server(conn=conn, public_key=public_key,
command=(['git'] + args.command),
debug=args.debug, timeout=args.timeout)

151
trezor_agent/client.py Normal file
View File

@@ -0,0 +1,151 @@
"""
Connection to hardware authentication device.
It is used for getting SSH public keys and ECDSA signing of server requests.
"""
import binascii
import io
import logging
import re
import struct
from . import factory, formats, util
log = logging.getLogger(__name__)
class Client(object):
"""Client wrapper for SSH authentication device."""
def __init__(self, loader=factory.load, curve=formats.CURVE_NIST256):
"""Connect to hardware device."""
client_wrapper = loader()
self.client = client_wrapper.connection
self.identity_type = client_wrapper.identity_type
self.device_name = client_wrapper.device_name
self.curve = curve
def __enter__(self):
"""Start a session, and test connection."""
msg = 'Hello World!'
assert self.client.ping(msg) == msg
return self
def __exit__(self, *args):
"""Forget PIN, shutdown screen and disconnect."""
log.info('disconnected from %s', self.device_name)
self.client.clear_session()
self.client.close()
def get_identity(self, label, index=0):
"""Parse label string into Identity protobuf."""
identity = string_to_identity(label, self.identity_type)
identity.proto = 'ssh'
identity.index = index
return identity
def get_public_key(self, label):
"""Get SSH public key corresponding to specified by label."""
identity = self.get_identity(label=label)
label = identity_to_string(identity) # canonize key label
log.info('getting "%s" public key (%s) from %s...',
label, self.curve, self.device_name)
addr = _get_address(identity)
node = self.client.get_public_node(n=addr,
ecdsa_curve_name=self.curve)
pubkey = node.node.public_key
vk = formats.decompress_pubkey(pubkey=pubkey, curve_name=self.curve)
return formats.export_public_key(vk=vk, label=label)
def sign_ssh_challenge(self, label, blob, visual=''):
"""Sign given blob using a private key, specified by the label."""
identity = self.get_identity(label=label)
msg = _parse_ssh_blob(blob)
log.debug('%s: user %r via %r (%r)',
msg['conn'], msg['user'], msg['auth'], msg['key_type'])
log.debug('nonce: %s', binascii.hexlify(msg['nonce']))
log.debug('fingerprint: %s', msg['public_key']['fingerprint'])
log.debug('hidden challenge size: %d bytes', len(blob))
log.debug('visual challenge size: %d bytes = %r', len(visual), visual)
log.info('please confirm user "%s" login to "%s" using %s...',
msg['user'], label, self.device_name)
result = self.client.sign_identity(identity=identity,
challenge_hidden=blob,
challenge_visual=visual,
ecdsa_curve_name=self.curve)
verifying_key = formats.decompress_pubkey(pubkey=result.public_key,
curve_name=self.curve)
key_type, blob = formats.serialize_verifying_key(verifying_key)
assert blob == msg['public_key']['blob']
assert key_type == msg['key_type']
assert len(result.signature) == 65
assert result.signature[:1] == bytearray([0])
return result.signature[1:]
_identity_regexp = re.compile(''.join([
'^'
r'(?:(?P<proto>.*)://)?',
r'(?:(?P<user>.*)@)?',
r'(?P<host>.*?)',
r'(?::(?P<port>\w*))?',
r'(?P<path>/.*)?',
'$'
]))
def string_to_identity(s, identity_type):
"""Parse string into Identity protobuf."""
m = _identity_regexp.match(s)
result = m.groupdict()
log.debug('parsed identity: %s', result)
kwargs = {k: v for k, v in result.items() if v}
return identity_type(**kwargs)
def identity_to_string(identity):
"""Dump Identity protobuf into its string representation."""
result = []
if identity.proto:
result.append(identity.proto + '://')
if identity.user:
result.append(identity.user + '@')
result.append(identity.host)
if identity.port:
result.append(':' + identity.port)
if identity.path:
result.append(identity.path)
return ''.join(result)
def _get_address(identity):
index = struct.pack('<L', identity.index)
addr = index + identity_to_string(identity).encode('ascii')
log.debug('address string: %r', addr)
digest = formats.hashfunc(addr).digest()
s = io.BytesIO(bytearray(digest))
hardened = 0x80000000
address_n = [13] + list(util.recv(s, '<LLLL'))
return [(hardened | value) for value in address_n]
def _parse_ssh_blob(data):
res = {}
i = io.BytesIO(data)
res['nonce'] = util.read_frame(i)
i.read(1) # SSH2_MSG_USERAUTH_REQUEST == 50 (from ssh2.h, line 108)
res['user'] = util.read_frame(i)
res['conn'] = util.read_frame(i)
res['auth'] = util.read_frame(i)
i.read(1) # have_sig == 1 (from sshconnect2.c, line 1056)
res['key_type'] = util.read_frame(i)
public_key = util.read_frame(i)
res['public_key'] = formats.parse_pubkey(public_key)
assert not i.read()
return res

88
trezor_agent/factory.py Normal file
View File

@@ -0,0 +1,88 @@
"""Thin wrapper around trezor/keepkey libraries."""
import binascii
import collections
import logging
import semver
log = logging.getLogger(__name__)
ClientWrapper = collections.namedtuple(
'ClientWrapper',
['connection', 'identity_type', 'device_name'])
# pylint: disable=too-many-arguments
def _load_client(name, client_type, hid_transport,
passphrase_ack, identity_type, required_version):
def empty_passphrase_handler(_):
return passphrase_ack(passphrase='')
for d in hid_transport.enumerate():
connection = client_type(hid_transport(d))
connection.callback_PassphraseRequest = empty_passphrase_handler
f = connection.features
log.debug('connected to %s %s', name, f.device_id)
log.debug('label : %s', f.label)
log.debug('vendor : %s', f.vendor)
current_version = '{}.{}.{}'.format(f.major_version,
f.minor_version,
f.patch_version)
log.debug('version : %s', current_version)
log.debug('revision : %s', binascii.hexlify(f.revision))
if not semver.match(current_version, required_version):
fmt = 'Please upgrade your {} firmware to {} version (current: {})'
raise ValueError(fmt.format(name,
required_version,
current_version))
yield ClientWrapper(connection=connection,
identity_type=identity_type,
device_name=name)
def _load_trezor():
# pylint: disable=import-error
from trezorlib.client import TrezorClient
from trezorlib.transport_hid import HidTransport
from trezorlib.messages_pb2 import PassphraseAck
from trezorlib.types_pb2 import IdentityType
return _load_client(name='Trezor',
client_type=TrezorClient,
hid_transport=HidTransport,
passphrase_ack=PassphraseAck,
identity_type=IdentityType,
required_version='>=1.3.4')
def _load_keepkey():
# pylint: disable=import-error
from keepkeylib.client import KeepKeyClient
from keepkeylib.transport_hid import HidTransport
from keepkeylib.messages_pb2 import PassphraseAck
from keepkeylib.types_pb2 import IdentityType
return _load_client(name='KeepKey',
client_type=KeepKeyClient,
hid_transport=HidTransport,
passphrase_ack=PassphraseAck,
identity_type=IdentityType,
required_version='>=1.0.4')
LOADERS = [
_load_trezor,
_load_keepkey
]
def load(loaders=None):
"""Load a single device, via specified loaders' list."""
loaders = loaders if loaders is not None else LOADERS
device_list = []
for loader in loaders:
device_list.extend(loader())
if len(device_list) == 1:
return device_list[0]
msg = '{:d} devices found'.format(len(device_list))
raise IOError(msg)

View File

@@ -1,92 +1,194 @@
import io
import hashlib
"""SSH format parsing and formatting tools."""
import base64
import hashlib
import io
import logging
import ecdsa
import ed25519
from . import util
import logging
log = logging.getLogger(__name__)
DER_OCTET_STRING = b'\x04'
ECDSA_KEY_PREFIX = b'ecdsa-sha2-'
ECDSA_CURVE_NAME = b'nistp256'
# Supported ECDSA curves
CURVE_NIST256 = b'nist256p1'
CURVE_ED25519 = b'ed25519'
SUPPORTED_CURVES = {CURVE_NIST256, CURVE_ED25519}
# SSH key types
SSH_NIST256_DER_OCTET = b'\x04'
SSH_NIST256_KEY_PREFIX = b'ecdsa-sha2-'
SSH_NIST256_CURVE_NAME = b'nistp256'
SSH_NIST256_KEY_TYPE = SSH_NIST256_KEY_PREFIX + SSH_NIST256_CURVE_NAME
SSH_ED25519_KEY_TYPE = b'ssh-ed25519'
SUPPORTED_KEY_TYPES = {SSH_NIST256_KEY_TYPE, SSH_ED25519_KEY_TYPE}
hashfunc = hashlib.sha256
def fingerprint(blob):
"""
Compute SSH fingerprint for specified blob.
See https://en.wikipedia.org/wiki/Public_key_fingerprint for details.
"""
digest = hashlib.md5(blob).digest()
return ':'.join('{:02x}'.format(c) for c in bytearray(digest))
def parse_pubkey(blob, curve=ecdsa.NIST256p):
def parse_pubkey(blob):
"""
Parse SSH public key from given blob.
Cnstruct a verifier for ECDSA signatures.
The verifier returns the signatures in the required SSH format.
Currently, NIST256P1 and ED25519 elliptic curves are supported.
"""
fp = fingerprint(blob)
s = io.BytesIO(blob)
key_type = util.read_frame(s)
log.debug('key type: %s', key_type)
curve_name = util.read_frame(s)
log.debug('curve name: %s', curve_name)
point = util.read_frame(s)
assert s.read() == b''
_type, point = point[:1], point[1:]
assert _type == DER_OCTET_STRING
size = len(point) // 2
assert len(point) == 2 * size
coords = (util.bytes2num(point[:size]), util.bytes2num(point[size:]))
log.debug('coordinates: %s', coords)
fp = fingerprint(blob)
assert key_type in SUPPORTED_KEY_TYPES, key_type
result = {'blob': blob, 'type': key_type, 'fingerprint': fp}
if key_type == SSH_NIST256_KEY_TYPE:
curve_name = util.read_frame(s)
log.debug('curve name: %s', curve_name)
point = util.read_frame(s)
assert s.read() == b''
_type, point = point[:1], point[1:]
assert _type == SSH_NIST256_DER_OCTET
size = len(point) // 2
assert len(point) == 2 * size
coords = (util.bytes2num(point[:size]), util.bytes2num(point[size:]))
curve = ecdsa.NIST256p
point = ecdsa.ellipticcurve.Point(curve.curve, *coords)
def ecdsa_verifier(sig, msg):
assert len(sig) == 2 * size
sig_decode = ecdsa.util.sigdecode_string
vk = ecdsa.VerifyingKey.from_public_point(point, curve, hashfunc)
vk.verify(signature=sig, data=msg, sigdecode=sig_decode)
parts = [sig[:size], sig[size:]]
return b''.join([util.frame(b'\x00' + p) for p in parts])
result.update(point=coords, curve=CURVE_NIST256,
verifier=ecdsa_verifier)
if key_type == SSH_ED25519_KEY_TYPE:
pubkey = util.read_frame(s)
assert s.read() == b''
def ed25519_verify(sig, msg):
assert len(sig) == 64
vk = ed25519.VerifyingKey(pubkey)
vk.verify(sig, msg)
return sig
result.update(curve=CURVE_ED25519, verifier=ed25519_verify)
point = ecdsa.ellipticcurve.Point(curve.curve, *coords)
vk = ecdsa.VerifyingKey.from_public_point(point, curve, hashfunc)
result = {
'point': coords,
'curve': curve_name,
'fingerprint': fp,
'type': key_type,
'blob': blob,
'size': size,
'verifying_key': vk
}
return result
def decompress_pubkey(pub, curve=ecdsa.NIST256p):
P = curve.curve.p()
A = curve.curve.a()
B = curve.curve.b()
x = util.bytes2num(pub[1:33])
beta = pow(int(x*x*x+A*x+B), int((P+1)//4), int(P))
def _decompress_ed25519(pubkey):
"""Load public key from the serialized blob (stripping the prefix byte)."""
if pubkey[:1] == b'\x00':
# set by Trezor fsm_msgSignIdentity() and fsm_msgGetPublicKey()
return ed25519.VerifyingKey(pubkey[1:])
p0 = util.bytes2num(pub[:1])
y = (P-beta) if ((beta + p0) % 2) else beta
point = ecdsa.ellipticcurve.Point(curve.curve, x, y)
return ecdsa.VerifyingKey.from_public_point(point, curve=curve,
hashfunc=hashfunc)
def _decompress_nist256(pubkey):
"""
Load public key from the serialized blob.
The leading byte least-significant bit is used to decide how to recreate
the y-coordinate from the specified x-coordinate. See bitcoin/main.py#L198
(from https://github.com/vbuterin/pybitcointools/) for details.
"""
if pubkey[:1] in {b'\x02', b'\x03'}: # set by ecdsa_get_public_key33()
curve = ecdsa.NIST256p
P = curve.curve.p()
A = curve.curve.a()
B = curve.curve.b()
x = util.bytes2num(pubkey[1:33])
beta = pow(int(x * x * x + A * x + B), int((P + 1) // 4), int(P))
p0 = util.bytes2num(pubkey[:1])
y = (P - beta) if ((beta + p0) % 2) else beta
point = ecdsa.ellipticcurve.Point(curve.curve, x, y)
return ecdsa.VerifyingKey.from_public_point(point, curve=curve,
hashfunc=hashfunc)
def decompress_pubkey(pubkey, curve_name):
"""
Load public key from the serialized blob.
Raise ValueError on parsing error.
"""
vk = None
if len(pubkey) == 33:
decompress = {
CURVE_NIST256: _decompress_nist256,
CURVE_ED25519: _decompress_ed25519
}[curve_name]
vk = decompress(pubkey)
if not vk:
msg = 'invalid {!s} public key: {!r}'.format(curve_name, pubkey)
raise ValueError(msg)
return vk
def serialize_verifying_key(vk):
key_type = ECDSA_KEY_PREFIX + ECDSA_CURVE_NAME
curve_name = ECDSA_CURVE_NAME
key_blob = DER_OCTET_STRING + vk.to_string()
parts = [key_type, curve_name, key_blob]
return b''.join([util.frame(p) for p in parts])
"""
Serialize a public key into SSH format (for exporting to text format).
Currently, NIST256P1 and ED25519 elliptic curves are supported.
Raise TypeError on unsupported key format.
"""
if isinstance(vk, ed25519.keys.VerifyingKey):
pubkey = vk.to_bytes()
key_type = SSH_ED25519_KEY_TYPE
blob = util.frame(SSH_ED25519_KEY_TYPE) + util.frame(pubkey)
return key_type, blob
if isinstance(vk, ecdsa.keys.VerifyingKey):
curve_name = SSH_NIST256_CURVE_NAME
key_blob = SSH_NIST256_DER_OCTET + vk.to_string()
parts = [SSH_NIST256_KEY_TYPE, curve_name, key_blob]
key_type = SSH_NIST256_KEY_TYPE
blob = b''.join([util.frame(p) for p in parts])
return key_type, blob
raise TypeError('unsupported {!r}'.format(vk))
def export_public_key(pubkey, label):
blob = serialize_verifying_key(decompress_pubkey(pubkey))
def export_public_key(vk, label):
"""
Export public key to text format.
The resulting string can be written into a .pub file or
appended to the ~/.ssh/authorized_keys file.
"""
key_type, blob = serialize_verifying_key(vk)
log.debug('fingerprint: %s', fingerprint(blob))
b64 = base64.b64encode(blob).decode('ascii')
key_type = ECDSA_KEY_PREFIX + ECDSA_CURVE_NAME
return '{} {} {}\n'.format(key_type.decode('ascii'), b64, label)
def import_public_key(line):
''' Parse public key textual format, as saved at .pub file '''
"""Parse public key textual format, as saved at a .pub file."""
log.debug('loading SSH public key: %r', line)
file_type, base64blob, name = line.split()
blob = base64.b64decode(base64blob)
result = parse_pubkey(blob)
result['name'] = name.encode('ascii')
assert result['type'] == file_type.encode('ascii')
log.debug('loaded %s %s', file_type, result['fingerprint'])
log.debug('loaded %s public key: %s', file_type, result['fingerprint'])
return result

View File

@@ -1,42 +1,41 @@
"""
SSH-agent protocol implementation library.
See https://github.com/openssh/openssh-portable/blob/master/PROTOCOL.agent and
http://ptspts.blogspot.co.il/2010/06/how-to-use-ssh-agent-programmatically.html
for more details.
The server's source code can be found here:
https://github.com/openssh/openssh-portable/blob/master/authfd.c
"""
import binascii
import io
from . import util
from . import formats
import logging
from . import formats, util
log = logging.getLogger(__name__)
SSH_AGENTC_REQUEST_RSA_IDENTITIES = 1
SSH_AGENT_RSA_IDENTITIES_ANSWER = 2
SSH_AGENTC_REMOVE_ALL_RSA_IDENTITIES = 9
SSH2_AGENTC_REQUEST_IDENTITIES = 11
SSH2_AGENT_IDENTITIES_ANSWER = 12
SSH2_AGENTC_SIGN_REQUEST = 13
SSH2_AGENT_SIGN_RESPONSE = 14
SSH2_AGENTC_ADD_IDENTITY = 17
SSH2_AGENTC_REMOVE_IDENTITY = 18
SSH2_AGENTC_REMOVE_ALL_IDENTITIES = 19
class Error(Exception):
pass
class BadSignature(Error):
pass
class MissingKey(Error):
pass
class Handler(object):
"""ssh-agent protocol handler."""
def __init__(self, keys, signer):
def __init__(self, keys, signer, debug=False):
"""
Create a protocol handler with specified public keys.
Use specified signer function to sign SSH authentication requests.
"""
self.public_keys = keys
self.signer = signer
self.debug = debug
self.methods = {
SSH_AGENTC_REQUEST_RSA_IDENTITIES: Handler.legacy_pubs,
@@ -45,25 +44,28 @@ class Handler(object):
}
def handle(self, msg):
log.debug('request: %d bytes', len(msg))
"""Handle SSH message from the SSH client and return the response."""
debug_msg = ': {!r}'.format(msg) if self.debug else ''
log.debug('request: %d bytes%s', len(msg), debug_msg)
buf = io.BytesIO(msg)
code, = util.recv(buf, '>B')
method = self.methods[code]
log.debug('calling %s()', method.__name__)
reply = method(buf=buf)
log.debug('reply: %d bytes', len(reply))
debug_reply = ': {!r}'.format(reply) if self.debug else ''
log.debug('reply: %d bytes%s', len(reply), debug_reply)
return reply
@staticmethod
def legacy_pubs(buf):
''' SSH v1 public keys are not supported '''
"""SSH v1 public keys are not supported."""
assert not buf.read()
code = util.pack('B', SSH_AGENT_RSA_IDENTITIES_ANSWER)
num = util.pack('L', 0) # no SSH v1 keys
return util.frame(code, num)
def list_pubs(self, buf):
''' SSH v2 public keys are serialized and returned. '''
"""SSH v2 public keys are serialized and returned."""
assert not buf.read()
keys = self.public_keys
code = util.pack('B', SSH2_AGENT_IDENTITIES_ANSWER)
@@ -75,7 +77,12 @@ class Handler(object):
return util.frame(code, num, *pubs)
def sign_message(self, buf):
''' SSH v2 public key authentication is performed. '''
"""
SSH v2 public key authentication is performed.
If the required key is not supported, raise KeyError
If the signature is invalid, rause ValueError
"""
key = formats.parse_pubkey(util.read_frame(buf))
log.debug('looking for %s', key['fingerprint'])
blob = util.read_frame(buf)
@@ -88,26 +95,20 @@ class Handler(object):
key = k
break
else:
raise MissingKey('key not found')
raise KeyError('key not found')
log.debug('signing %d-byte blob', len(blob))
r, s = self.signer(label=key['name'], blob=blob)
signature = (r, s)
log.debug('signature: %s', signature)
label = key['name'].decode('ascii') # label should be a string
signature = self.signer(label=label, blob=blob)
log.debug('signature: %s', binascii.hexlify(signature))
try:
key['verifying_key'].verify(signature=signature, data=blob,
sigdecode=lambda sig, _: sig)
sig_bytes = key['verifier'](sig=signature, msg=blob)
log.info('signature status: OK')
except formats.ecdsa.BadSignatureError:
log.exception('signature status: ERROR')
raise BadSignature('invalid ECDSA signature')
raise ValueError('invalid ECDSA signature')
sig_bytes = io.BytesIO()
for x in signature:
x_frame = util.frame(b'\x00' + util.num2bytes(x, key['size']))
sig_bytes.write(x_frame)
sig_bytes = sig_bytes.getvalue()
log.debug('signature size: %d bytes', len(sig_bytes))
data = util.frame(util.frame(key['type']), util.frame(sig_bytes))

View File

@@ -1,19 +1,21 @@
import socket
"""UNIX-domain socket server for ssh-agent implementation."""
import contextlib
import logging
import os
import socket
import subprocess
import tempfile
import contextlib
import threading
from . import protocol
from . import formats
from . import util
import logging
log = logging.getLogger(__name__)
UNIX_SOCKET_TIMEOUT = 0.1
def remove_file(path, remove=os.remove, exists=os.path.exists):
"""Remove file, and raise OSError if still exists."""
try:
remove(path)
except OSError:
@@ -23,6 +25,11 @@ def remove_file(path, remove=os.remove, exists=os.path.exists):
@contextlib.contextmanager
def unix_domain_socket_server(sock_path):
"""
Create UNIX-domain socket on specified path.
Listen on it, and delete it after the generated context is over.
"""
log.debug('serving on SSH_AUTH_SOCK=%s', sock_path)
remove_file(sock_path)
@@ -36,6 +43,12 @@ def unix_domain_socket_server(sock_path):
def handle_connection(conn, handler):
"""
Handle a single connection using the specified protocol handler in a loop.
Exit when EOFError is raised.
All other exceptions are logged as warnings.
"""
try:
log.debug('welcome agent')
while True:
@@ -44,19 +57,41 @@ def handle_connection(conn, handler):
util.send(conn, reply)
except EOFError:
log.debug('goodbye agent')
except:
log.exception('error')
raise
except Exception as e: # pylint: disable=broad-except
log.warning('error: %s', e, exc_info=True)
def server_thread(server, handler):
log.debug('server thread started')
def retry(func, exception_type, quit_event):
"""
Run the function, retrying when the specified exception_type occurs.
Poll quit_event on each iteration, to be responsive to an external
exit request.
"""
while True:
log.debug('waiting for connection on %s', server.getsockname())
if quit_event.is_set():
raise StopIteration
try:
conn, _ = server.accept()
except socket.error as e:
log.debug('server stopped: %s', e)
return func()
except exception_type:
pass
def server_thread(sock, handler, quit_event):
"""Run a server on the specified socket."""
log.debug('server thread started')
def accept_connection():
conn, _ = sock.accept()
conn.settimeout(None)
return conn
while True:
log.debug('waiting for connection on %s', sock.getsockname())
try:
conn = retry(accept_connection, socket.timeout, quit_event)
except StopIteration:
log.debug('server stopped')
break
with contextlib.closing(conn):
handle_connection(conn, handler)
@@ -64,7 +99,8 @@ def server_thread(server, handler):
@contextlib.contextmanager
def spawn(func, **kwargs):
def spawn(func, kwargs):
"""Spawn a thread, and join it after the context is over."""
t = threading.Thread(target=func, kwargs=kwargs)
t.start()
yield
@@ -72,28 +108,40 @@ def spawn(func, **kwargs):
@contextlib.contextmanager
def serve(public_keys, signer, sock_path=None):
def serve(handler, sock_path=None, timeout=UNIX_SOCKET_TIMEOUT):
"""
Start the ssh-agent server on a UNIX-domain socket.
If no connection is made during the specified timeout,
retry until the context is over.
"""
if sock_path is None:
sock_path = tempfile.mktemp(prefix='ssh-agent-')
keys = [formats.import_public_key(k) for k in public_keys]
environ = {'SSH_AUTH_SOCK': sock_path, 'SSH_AGENT_PID': str(os.getpid())}
with unix_domain_socket_server(sock_path) as server:
handler = protocol.Handler(keys=keys, signer=signer)
with spawn(server_thread, server=server, handler=handler):
with unix_domain_socket_server(sock_path) as sock:
sock.settimeout(timeout)
quit_event = threading.Event()
kwargs = dict(sock=sock, handler=handler, quit_event=quit_event)
with spawn(server_thread, kwargs):
try:
yield environ
finally:
log.debug('closing server')
server.shutdown(socket.SHUT_RD)
quit_event.set()
def run_process(command, environ, use_shell=False):
log.debug('running %r with %r', command, environ)
def run_process(command, environ):
"""
Run the specified process and wait until it finishes.
Use environ dict for environment variables.
"""
log.info('running %r with %r', command, environ)
env = dict(os.environ)
env.update(environ)
try:
p = subprocess.Popen(args=command, env=env, shell=use_shell)
p = subprocess.Popen(args=command, env=env)
except OSError as e:
raise OSError('cannot run %r: %s' % (command, e))
log.debug('subprocess %d is running', p.pid)

View File

@@ -0,0 +1 @@
"""Unit-tests for this package."""

View File

@@ -1,8 +1,8 @@
from ..trezor import client
from .. import formats
import io
import mock
from .. import client, factory, formats, util
ADDR = [2147483661, 2810943954, 3938368396, 3454558782, 3848009040]
CURVE = 'nist256p1'
@@ -14,17 +14,9 @@ PUBKEY_TEXT = ('ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzd'
'VUfhvrGljR2Z/CMRONY6ejB+9PnpUOPuzYqi8= ssh://localhost:22\n')
class ConnectionMock(object):
class FakeConnection(object):
def __init__(self):
self.features = mock.Mock(spec=[])
self.features.device_id = '123456789'
self.features.label = 'mywallet'
self.features.vendor = 'mock'
self.features.major_version = 1
self.features.minor_version = 2
self.features.patch_version = 3
self.features.revision = b'456'
self.closed = False
def close(self):
@@ -33,10 +25,10 @@ class ConnectionMock(object):
def clear_session(self):
self.closed = True
def get_public_node(self, n, ecdsa_curve_name='secp256k1'):
def get_public_node(self, n, ecdsa_curve_name=b'secp256k1'):
assert not self.closed
assert n == ADDR
assert ecdsa_curve_name in {'secp256k1', 'nist256p1'}
assert ecdsa_curve_name in {b'secp256k1', b'nist256p1'}
result = mock.Mock(spec=[])
result.node = mock.Mock(spec=[])
result.node.public_key = PUBKEY
@@ -47,21 +39,20 @@ class ConnectionMock(object):
return msg
class FactoryMock(object):
def identity_type(**kwargs):
result = mock.Mock(spec=[])
result.index = 0
result.proto = result.user = result.host = result.port = None
result.path = None
for k, v in kwargs.items():
setattr(result, k, v)
return result
@staticmethod
def client():
return ConnectionMock()
@staticmethod
def identity_type(**kwargs):
result = mock.Mock(spec=[])
result.index = 0
result.proto = result.user = result.host = result.port = None
result.path = None
for k, v in kwargs.items():
setattr(result, k, v)
return result
def load_client():
return factory.ClientWrapper(connection=FakeConnection(),
identity_type=identity_type,
device_name='DEVICE_NAME')
BLOB = (b'\x00\x00\x00 \xce\xe0\xc9\xd5\xceu/\xe8\xc5\xf2\xbfR+x\xa1\xcf\xb0'
@@ -79,23 +70,26 @@ SIG = (b'\x00R\x19T\xf2\x84$\xef#\x0e\xee\x04X\xc6\xc3\x99T`\xd1\xd8\xf7!'
def test_ssh_agent():
c = client.Client(factory=FactoryMock)
ident = c.get_identity(label='localhost:22')
label = 'localhost:22'
c = client.Client(loader=load_client)
ident = c.get_identity(label=label)
assert ident.host == 'localhost'
assert ident.proto == 'ssh'
assert ident.port == '22'
assert ident.user is None
assert ident.path is None
assert ident.index == 0
with c:
assert c.get_public_key(ident) == PUBKEY_TEXT
assert c.get_public_key(label) == PUBKEY_TEXT
def ssh_sign_identity(identity, challenge_hidden,
challenge_visual, ecdsa_curve_name):
assert identity is ident
assert (client.identity_to_string(identity) ==
client.identity_to_string(ident))
assert challenge_hidden == BLOB
assert challenge_visual == identity.path
assert ecdsa_curve_name == 'nist256p1'
assert challenge_visual == 'VISUAL'
assert ecdsa_curve_name == b'nist256p1'
result = mock.Mock(spec=[])
result.public_key = PUBKEY
@@ -103,11 +97,19 @@ def test_ssh_agent():
return result
c.client.sign_identity = ssh_sign_identity
signature = c.sign_ssh_challenge(identity=ident, blob=BLOB)
signature = c.sign_ssh_challenge(label=label, blob=BLOB,
visual='VISUAL')
key = formats.import_public_key(PUBKEY_TEXT)
assert key['verifying_key'].verify(signature=signature, data=BLOB,
sigdecode=lambda sig, _: sig)
serialized_sig = key['verifier'](sig=signature, msg=BLOB)
stream = io.BytesIO(serialized_sig)
r = util.read_frame(stream)
s = util.read_frame(stream)
assert not stream.read()
assert r[:1] == b'\x00'
assert s[:1] == b'\x00'
assert r[1:] + s[1:] == SIG[1:]
def test_utils():

View File

@@ -0,0 +1,94 @@
import mock
import pytest
from .. import factory
def test_load():
def single():
return [0]
def nothing():
return []
def double():
return [1, 2]
assert factory.load(loaders=[single]) == 0
assert factory.load(loaders=[single, nothing]) == 0
assert factory.load(loaders=[nothing, single]) == 0
with pytest.raises(IOError):
factory.load(loaders=[])
with pytest.raises(IOError):
factory.load(loaders=[single, single])
with pytest.raises(IOError):
factory.load(loaders=[double])
def factory_load_client(**kwargs):
# pylint: disable=protected-access
return list(factory._load_client(**kwargs))
def test_load_nothing():
hid_transport = mock.Mock(spec_set=['enumerate'])
hid_transport.enumerate.return_value = []
result = factory_load_client(
name=None,
client_type=None,
hid_transport=hid_transport,
passphrase_ack=None,
identity_type=None,
required_version=None)
assert result == []
def create_client_type(version):
conn = mock.Mock(spec=[])
conn.features = mock.Mock(spec=[])
major, minor, patch = version.split('.')
conn.features.device_id = 'DEVICE_ID'
conn.features.label = 'LABEL'
conn.features.vendor = 'VENDOR'
conn.features.major_version = major
conn.features.minor_version = minor
conn.features.patch_version = patch
conn.features.revision = b'\x12\x34\x56\x78'
return mock.Mock(spec_set=[], return_value=conn)
def test_load_single():
hid_transport = mock.Mock(spec_set=['enumerate'])
hid_transport.enumerate.return_value = [0]
for version in ('1.3.4', '1.3.5', '1.4.0', '2.0.0'):
passphrase_ack = mock.Mock(spec_set=[])
client_type = create_client_type(version)
client_wrapper, = factory_load_client(
name='DEVICE_NAME',
client_type=client_type,
hid_transport=hid_transport,
passphrase_ack=passphrase_ack,
identity_type=None,
required_version='>=1.3.4')
assert client_wrapper.connection is client_type.return_value
assert client_wrapper.device_name == 'DEVICE_NAME'
client_wrapper.connection.callback_PassphraseRequest('MESSAGE')
assert passphrase_ack.mock_calls == [mock.call(passphrase='')]
def test_load_old():
hid_transport = mock.Mock(spec_set=['enumerate'])
hid_transport.enumerate.return_value = [0]
for version in ('1.3.3', '1.2.5', '1.1.0', '0.9.9'):
with pytest.raises(ValueError):
factory_load_client(
name='DEVICE_NAME',
client_type=create_client_type(version),
hid_transport=hid_transport,
passphrase_ack=None,
identity_type=None,
required_version='>=1.3.4')

View File

@@ -1,5 +1,7 @@
import binascii
import pytest
from .. import formats
@@ -27,13 +29,67 @@ def test_parse_public_key():
assert key['name'] == b'home'
assert key['point'] == _point
assert key['curve'] == b'nistp256'
assert key['curve'] == b'nist256p1'
assert key['fingerprint'] == '4b:19:bc:0f:c8:7e:dc:fa:1a:e3:c2:ff:6f:e0:80:a2' # nopep8
assert key['type'] == b'ecdsa-sha2-nistp256'
assert key['size'] == 32
def test_decompress():
blob = '036236ceabde25207e81e404586e3a3af1acda1dfed2abbbb4876c1fc5b296b575'
result = formats.export_public_key(binascii.unhexlify(blob), label='home')
assert result == _public_key
vk = formats.decompress_pubkey(binascii.unhexlify(blob),
curve_name=formats.CURVE_NIST256)
assert formats.export_public_key(vk, label='home') == _public_key
def test_parse_ed25519():
pubkey = ('ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIFBdF2tj'
'fSO8nLIi736is+f0erq28RTc7CkM11NZtTKR hello\n')
p = formats.import_public_key(pubkey)
assert p['name'] == b'hello'
assert p['curve'] == b'ed25519'
BLOB = (b'\x00\x00\x00\x0bssh-ed25519\x00\x00\x00 P]\x17kc}#'
b'\xbc\x9c\xb2"\xef~\xa2\xb3\xe7\xf4z\xba\xb6\xf1\x14'
b'\xdc\xec)\x0c\xd7SY\xb52\x91')
assert p['blob'] == BLOB
assert p['fingerprint'] == '6b:b0:77:af:e5:3a:21:6d:17:82:9b:06:19:03:a1:97' # nopep8
assert p['type'] == b'ssh-ed25519'
def test_export_ed25519():
pub = (b'\x00P]\x17kc}#\xbc\x9c\xb2"\xef~\xa2\xb3\xe7\xf4'
b'z\xba\xb6\xf1\x14\xdc\xec)\x0c\xd7SY\xb52\x91')
vk = formats.decompress_pubkey(pub, formats.CURVE_ED25519)
result = formats.serialize_verifying_key(vk)
assert result == (b'ssh-ed25519',
b'\x00\x00\x00\x0bssh-ed25519\x00\x00\x00 P]\x17kc}#\xbc'
b'\x9c\xb2"\xef~\xa2\xb3\xe7\xf4z\xba\xb6\xf1\x14\xdc'
b'\xec)\x0c\xd7SY\xb52\x91')
def test_decompress_error():
with pytest.raises(ValueError):
formats.decompress_pubkey('', formats.CURVE_NIST256)
def test_curve_mismatch():
# NIST256 public key
blob = '036236ceabde25207e81e404586e3a3af1acda1dfed2abbbb4876c1fc5b296b575'
with pytest.raises(ValueError):
formats.decompress_pubkey(binascii.unhexlify(blob),
curve_name=formats.CURVE_ED25519)
blob = '00' * 33 # Dummy public key
with pytest.raises(ValueError):
formats.decompress_pubkey(binascii.unhexlify(blob),
curve_name=formats.CURVE_NIST256)
blob = 'FF' * 33 # Unsupported prefix byte
with pytest.raises(ValueError):
formats.decompress_pubkey(binascii.unhexlify(blob),
curve_name=formats.CURVE_NIST256)
def test_serialize_error():
with pytest.raises(TypeError):
formats.serialize_verifying_key(None)

View File

@@ -1,56 +1,76 @@
from .. import protocol
from .. import formats
import pytest
from .. import formats, protocol
# pylint: disable=line-too-long
KEY = 'ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBEUksojS/qRlTKBKLQO7CBX7a7oqFkysuFn1nJ6gzlR3wNuQXEgd7qb2bjmiiBHsjNxyWvH5SxVi3+fghrqODWo= ssh://localhost' # nopep8
BLOB = b'\x00\x00\x00 !S^\xe7\xf8\x1cKN\xde\xcbo\x0c\x83\x9e\xc48\r\xac\xeb,]"\xc1\x9bA\x0eit\xc1\x81\xd4E2\x00\x00\x00\x05roman\x00\x00\x00\x0essh-connection\x00\x00\x00\tpublickey\x01\x00\x00\x00\x13ecdsa-sha2-nistp256\x00\x00\x00h\x00\x00\x00\x13ecdsa-sha2-nistp256\x00\x00\x00\x08nistp256\x00\x00\x00A\x04E$\xb2\x88\xd2\xfe\xa4eL\xa0J-\x03\xbb\x08\x15\xfbk\xba*\x16L\xac\xb8Y\xf5\x9c\x9e\xa0\xceTw\xc0\xdb\x90\\H\x1d\xee\xa6\xf6n9\xa2\x88\x11\xec\x8c\xdcrZ\xf1\xf9K\x15b\xdf\xe7\xe0\x86\xba\x8e\rj' # nopep8
SIG = (61640221631134565789126560951398335114074531708367858563384221818711312348703, 51535548700089687831159696283235534298026173963719263249292887877395159425513) # nopep8
NIST256_KEY = 'ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBEUksojS/qRlTKBKLQO7CBX7a7oqFkysuFn1nJ6gzlR3wNuQXEgd7qb2bjmiiBHsjNxyWvH5SxVi3+fghrqODWo= ssh://localhost' # nopep8
NIST256_BLOB = b'\x00\x00\x00 !S^\xe7\xf8\x1cKN\xde\xcbo\x0c\x83\x9e\xc48\r\xac\xeb,]"\xc1\x9bA\x0eit\xc1\x81\xd4E2\x00\x00\x00\x05roman\x00\x00\x00\x0essh-connection\x00\x00\x00\tpublickey\x01\x00\x00\x00\x13ecdsa-sha2-nistp256\x00\x00\x00h\x00\x00\x00\x13ecdsa-sha2-nistp256\x00\x00\x00\x08nistp256\x00\x00\x00A\x04E$\xb2\x88\xd2\xfe\xa4eL\xa0J-\x03\xbb\x08\x15\xfbk\xba*\x16L\xac\xb8Y\xf5\x9c\x9e\xa0\xceTw\xc0\xdb\x90\\H\x1d\xee\xa6\xf6n9\xa2\x88\x11\xec\x8c\xdcrZ\xf1\xf9K\x15b\xdf\xe7\xe0\x86\xba\x8e\rj' # nopep8
NIST256_SIG = b'\x88G!\x0c\n\x16:\xbeF\xbe\xb9\xd2\xa9&e\x89\xad\xc4}\x10\xf8\xbc\xdc\xef\x0e\x8d_\x8a6.\xb6\x1fq\xf0\x16>,\x9a\xde\xe7(\xd6\xd7\x93\x1f\xed\xf9\x94ddw\xfe\xbdq\x13\xbb\xfc\xa9K\xea\x9dC\xa1\xe9' # nopep8
LIST_MSG = b'\x0b'
LIST_REPLY = b'\x00\x00\x00\x84\x0c\x00\x00\x00\x01\x00\x00\x00h\x00\x00\x00\x13ecdsa-sha2-nistp256\x00\x00\x00\x08nistp256\x00\x00\x00A\x04E$\xb2\x88\xd2\xfe\xa4eL\xa0J-\x03\xbb\x08\x15\xfbk\xba*\x16L\xac\xb8Y\xf5\x9c\x9e\xa0\xceTw\xc0\xdb\x90\\H\x1d\xee\xa6\xf6n9\xa2\x88\x11\xec\x8c\xdcrZ\xf1\xf9K\x15b\xdf\xe7\xe0\x86\xba\x8e\rj\x00\x00\x00\x0fssh://localhost' # nopep8
LIST_NIST256_REPLY = b'\x00\x00\x00\x84\x0c\x00\x00\x00\x01\x00\x00\x00h\x00\x00\x00\x13ecdsa-sha2-nistp256\x00\x00\x00\x08nistp256\x00\x00\x00A\x04E$\xb2\x88\xd2\xfe\xa4eL\xa0J-\x03\xbb\x08\x15\xfbk\xba*\x16L\xac\xb8Y\xf5\x9c\x9e\xa0\xceTw\xc0\xdb\x90\\H\x1d\xee\xa6\xf6n9\xa2\x88\x11\xec\x8c\xdcrZ\xf1\xf9K\x15b\xdf\xe7\xe0\x86\xba\x8e\rj\x00\x00\x00\x0fssh://localhost' # nopep8
SIGN_MSG = b'\r\x00\x00\x00h\x00\x00\x00\x13ecdsa-sha2-nistp256\x00\x00\x00\x08nistp256\x00\x00\x00A\x04E$\xb2\x88\xd2\xfe\xa4eL\xa0J-\x03\xbb\x08\x15\xfbk\xba*\x16L\xac\xb8Y\xf5\x9c\x9e\xa0\xceTw\xc0\xdb\x90\\H\x1d\xee\xa6\xf6n9\xa2\x88\x11\xec\x8c\xdcrZ\xf1\xf9K\x15b\xdf\xe7\xe0\x86\xba\x8e\rj\x00\x00\x00\xd1\x00\x00\x00 !S^\xe7\xf8\x1cKN\xde\xcbo\x0c\x83\x9e\xc48\r\xac\xeb,]"\xc1\x9bA\x0eit\xc1\x81\xd4E2\x00\x00\x00\x05roman\x00\x00\x00\x0essh-connection\x00\x00\x00\tpublickey\x01\x00\x00\x00\x13ecdsa-sha2-nistp256\x00\x00\x00h\x00\x00\x00\x13ecdsa-sha2-nistp256\x00\x00\x00\x08nistp256\x00\x00\x00A\x04E$\xb2\x88\xd2\xfe\xa4eL\xa0J-\x03\xbb\x08\x15\xfbk\xba*\x16L\xac\xb8Y\xf5\x9c\x9e\xa0\xceTw\xc0\xdb\x90\\H\x1d\xee\xa6\xf6n9\xa2\x88\x11\xec\x8c\xdcrZ\xf1\xf9K\x15b\xdf\xe7\xe0\x86\xba\x8e\rj\x00\x00\x00\x00' # nopep8
SIGN_REPLY = b'\x00\x00\x00j\x0e\x00\x00\x00e\x00\x00\x00\x13ecdsa-sha2-nistp256\x00\x00\x00J\x00\x00\x00!\x00\x88G!\x0c\n\x16:\xbeF\xbe\xb9\xd2\xa9&e\x89\xad\xc4}\x10\xf8\xbc\xdc\xef\x0e\x8d_\x8a6.\xb6\x1f\x00\x00\x00!\x00q\xf0\x16>,\x9a\xde\xe7(\xd6\xd7\x93\x1f\xed\xf9\x94ddw\xfe\xbdq\x13\xbb\xfc\xa9K\xea\x9dC\xa1\xe9' # nopep8
NIST256_SIGN_MSG = b'\r\x00\x00\x00h\x00\x00\x00\x13ecdsa-sha2-nistp256\x00\x00\x00\x08nistp256\x00\x00\x00A\x04E$\xb2\x88\xd2\xfe\xa4eL\xa0J-\x03\xbb\x08\x15\xfbk\xba*\x16L\xac\xb8Y\xf5\x9c\x9e\xa0\xceTw\xc0\xdb\x90\\H\x1d\xee\xa6\xf6n9\xa2\x88\x11\xec\x8c\xdcrZ\xf1\xf9K\x15b\xdf\xe7\xe0\x86\xba\x8e\rj\x00\x00\x00\xd1\x00\x00\x00 !S^\xe7\xf8\x1cKN\xde\xcbo\x0c\x83\x9e\xc48\r\xac\xeb,]"\xc1\x9bA\x0eit\xc1\x81\xd4E2\x00\x00\x00\x05roman\x00\x00\x00\x0essh-connection\x00\x00\x00\tpublickey\x01\x00\x00\x00\x13ecdsa-sha2-nistp256\x00\x00\x00h\x00\x00\x00\x13ecdsa-sha2-nistp256\x00\x00\x00\x08nistp256\x00\x00\x00A\x04E$\xb2\x88\xd2\xfe\xa4eL\xa0J-\x03\xbb\x08\x15\xfbk\xba*\x16L\xac\xb8Y\xf5\x9c\x9e\xa0\xceTw\xc0\xdb\x90\\H\x1d\xee\xa6\xf6n9\xa2\x88\x11\xec\x8c\xdcrZ\xf1\xf9K\x15b\xdf\xe7\xe0\x86\xba\x8e\rj\x00\x00\x00\x00' # nopep8
NIST256_SIGN_REPLY = b'\x00\x00\x00j\x0e\x00\x00\x00e\x00\x00\x00\x13ecdsa-sha2-nistp256\x00\x00\x00J\x00\x00\x00!\x00\x88G!\x0c\n\x16:\xbeF\xbe\xb9\xd2\xa9&e\x89\xad\xc4}\x10\xf8\xbc\xdc\xef\x0e\x8d_\x8a6.\xb6\x1f\x00\x00\x00!\x00q\xf0\x16>,\x9a\xde\xe7(\xd6\xd7\x93\x1f\xed\xf9\x94ddw\xfe\xbdq\x13\xbb\xfc\xa9K\xea\x9dC\xa1\xe9' # nopep8
def test_list():
key = formats.import_public_key(KEY)
key = formats.import_public_key(NIST256_KEY)
h = protocol.Handler(keys=[key], signer=None)
reply = h.handle(LIST_MSG)
assert reply == LIST_REPLY
assert reply == LIST_NIST256_REPLY
def signer(label, blob):
assert label == b'ssh://localhost'
assert blob == BLOB
return SIG
def ecdsa_signer(label, blob):
assert label == 'ssh://localhost'
assert blob == NIST256_BLOB
return NIST256_SIG
def test_sign():
key = formats.import_public_key(KEY)
h = protocol.Handler(keys=[key], signer=signer)
reply = h.handle(SIGN_MSG)
assert reply == SIGN_REPLY
def test_ecdsa_sign():
key = formats.import_public_key(NIST256_KEY)
h = protocol.Handler(keys=[key], signer=ecdsa_signer)
reply = h.handle(NIST256_SIGN_MSG)
assert reply == NIST256_SIGN_REPLY
def test_sign_missing():
h = protocol.Handler(keys=[], signer=signer)
h = protocol.Handler(keys=[], signer=ecdsa_signer)
with pytest.raises(protocol.MissingKey):
h.handle(SIGN_MSG)
with pytest.raises(KeyError):
h.handle(NIST256_SIGN_MSG)
def test_sign_wrong():
def wrong_signature(label, blob):
assert label == b'ssh://localhost'
assert blob == BLOB
return (0, 0)
assert label == 'ssh://localhost'
assert blob == NIST256_BLOB
return b'\x00' * 64
key = formats.import_public_key(KEY)
key = formats.import_public_key(NIST256_KEY)
h = protocol.Handler(keys=[key], signer=wrong_signature)
with pytest.raises(protocol.BadSignature):
h.handle(SIGN_MSG)
with pytest.raises(ValueError):
h.handle(NIST256_SIGN_MSG)
ED25519_KEY = 'ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIFBdF2tjfSO8nLIi736is+f0erq28RTc7CkM11NZtTKR ssh://localhost' # nopep8
ED25519_SIGN_MSG = b'''\r\x00\x00\x003\x00\x00\x00\x0bssh-ed25519\x00\x00\x00 P]\x17kc}#\xbc\x9c\xb2"\xef~\xa2\xb3\xe7\xf4z\xba\xb6\xf1\x14\xdc\xec)\x0c\xd7SY\xb52\x91\x00\x00\x00\x94\x00\x00\x00 i3\xae}yk\\\xa1L\xb9\xe1\xbf\xbc\x8e\x87\r\x0e\xc0\x9f\x97\x0fTC!\x80\x07\x91\xdb^8\xc1\xd62\x00\x00\x00\x05roman\x00\x00\x00\x0essh-connection\x00\x00\x00\tpublickey\x01\x00\x00\x00\x0bssh-ed25519\x00\x00\x003\x00\x00\x00\x0bssh-ed25519\x00\x00\x00 P]\x17kc}#\xbc\x9c\xb2"\xef~\xa2\xb3\xe7\xf4z\xba\xb6\xf1\x14\xdc\xec)\x0c\xd7SY\xb52\x91\x00\x00\x00\x00''' # nopep8
ED25519_SIGN_REPLY = b'''\x00\x00\x00X\x0e\x00\x00\x00S\x00\x00\x00\x0bssh-ed25519\x00\x00\x00@\x8eb)\xa6\xe9P\x83VE\xfbq\xc6\xbf\x1dV3\xe3<O\x11\xc0\xfa\xe4\xed\xb8\x81.\x81\xc8\xa6\xba\x10RA'a\xbc\xa9\xd3\xdb\x98\x07\xf0\x1a\x9c4\x84<\xaf\x99\xb7\xe5G\xeb\xf7$\xc1\r\x86f\x16\x8e\x08\x05''' # nopep8
ED25519_BLOB = b'''\x00\x00\x00 i3\xae}yk\\\xa1L\xb9\xe1\xbf\xbc\x8e\x87\r\x0e\xc0\x9f\x97\x0fTC!\x80\x07\x91\xdb^8\xc1\xd62\x00\x00\x00\x05roman\x00\x00\x00\x0essh-connection\x00\x00\x00\tpublickey\x01\x00\x00\x00\x0bssh-ed25519\x00\x00\x003\x00\x00\x00\x0bssh-ed25519\x00\x00\x00 P]\x17kc}#\xbc\x9c\xb2"\xef~\xa2\xb3\xe7\xf4z\xba\xb6\xf1\x14\xdc\xec)\x0c\xd7SY\xb52\x91''' # nopep8
ED25519_SIG = b'''\x8eb)\xa6\xe9P\x83VE\xfbq\xc6\xbf\x1dV3\xe3<O\x11\xc0\xfa\xe4\xed\xb8\x81.\x81\xc8\xa6\xba\x10RA'a\xbc\xa9\xd3\xdb\x98\x07\xf0\x1a\x9c4\x84<\xaf\x99\xb7\xe5G\xeb\xf7$\xc1\r\x86f\x16\x8e\x08\x05''' # nopep8
def ed25519_signer(label, blob):
assert label == 'ssh://localhost'
assert blob == ED25519_BLOB
return ED25519_SIG
def test_ed25519_sign():
key = formats.import_public_key(ED25519_KEY)
h = protocol.Handler(keys=[key], signer=ed25519_signer)
reply = h.handle(ED25519_SIGN_MSG)
assert reply == ED25519_SIGN_REPLY

View File

@@ -1,12 +1,13 @@
import tempfile
import socket
import os
import io
import os
import socket
import tempfile
import threading
import mock
import pytest
from .. import server
from .. import protocol
from .. import util
from .. import protocol, server, util
def test_socket():
@@ -16,7 +17,7 @@ def test_socket():
assert not os.path.isfile(path)
class SocketMock(object):
class FakeSocket(object):
def __init__(self, data=b''):
self.rx = io.BytesIO(data)
@@ -31,45 +32,48 @@ class SocketMock(object):
def close(self):
pass
def settimeout(self, value):
pass
def test_handle():
handler = protocol.Handler(keys=[], signer=None)
conn = SocketMock()
conn = FakeSocket()
server.handle_connection(conn, handler)
msg = bytearray([protocol.SSH_AGENTC_REQUEST_RSA_IDENTITIES])
conn = SocketMock(util.frame(msg))
conn = FakeSocket(util.frame(msg))
server.handle_connection(conn, handler)
assert conn.tx.getvalue() == b'\x00\x00\x00\x05\x02\x00\x00\x00\x00'
msg = bytearray([protocol.SSH2_AGENTC_REQUEST_IDENTITIES])
conn = SocketMock(util.frame(msg))
conn = FakeSocket(util.frame(msg))
server.handle_connection(conn, handler)
assert conn.tx.getvalue() == b'\x00\x00\x00\x05\x0C\x00\x00\x00\x00'
with pytest.raises(AttributeError):
server.handle_connection(conn=None, handler=None)
class ServerMock(object):
def __init__(self, connections, name):
self.connections = connections
self.name = name
def getsockname(self):
return self.name
def accept(self):
if self.connections:
return self.connections.pop(), 'address'
raise socket.error('stop')
conn_mock = mock.Mock(spec=FakeSocket)
conn_mock.recv.side_effect = [Exception, EOFError]
server.handle_connection(conn=conn_mock, handler=None)
def test_server_thread():
s = ServerMock(connections=[SocketMock()], name='mock')
h = protocol.Handler(keys=[], signer=None)
server.server_thread(s, h)
connections = [FakeSocket()]
quit_event = threading.Event()
class FakeServer(object):
def accept(self): # pylint: disable=no-self-use
if connections:
return connections.pop(), 'address'
quit_event.set()
raise socket.timeout()
def getsockname(self): # pylint: disable=no-self-use
return 'fake_server'
server.server_thread(sock=FakeServer(),
handler=protocol.Handler(keys=[], signer=None),
quit_event=quit_event)
def test_spawn():
@@ -78,7 +82,7 @@ def test_spawn():
def thread(x):
obj.append(x)
with server.spawn(thread, x=1):
with server.spawn(thread, dict(x=1)):
pass
assert obj == [1]
@@ -87,17 +91,16 @@ def test_spawn():
def test_run():
assert server.run_process(['true'], environ={}) == 0
assert server.run_process(['false'], environ={}) == 1
assert server.run_process(
command='exit $X',
environ={'X': '42'},
use_shell=True) == 42
assert server.run_process(command=['bash', '-c', 'exit $X'],
environ={'X': '42'}) == 42
with pytest.raises(OSError):
server.run_process([''], environ={})
def test_serve_main():
with server.serve(public_keys=[], signer=None, sock_path=None):
handler = protocol.Handler(keys=[], signer=None)
with server.serve(handler=handler, sock_path=None):
pass

View File

@@ -1,4 +1,5 @@
import io
import pytest
from .. import util
@@ -23,7 +24,7 @@ def test_frames():
assert util.read_frame(io.BytesIO(f)) == b''.join(msgs)
class SocketMock(object):
class FakeSocket(object):
def __init__(self):
self.buf = io.BytesIO()
@@ -35,9 +36,9 @@ class SocketMock(object):
def test_send_recv():
s = SocketMock()
s = FakeSocket()
util.send(s, b'123')
util.send(s, data=[42], fmt='B')
util.send(s, b'*')
assert s.buf.getvalue() == b'123*'
s.buf.seek(0)

View File

@@ -1 +0,0 @@
from .client import Client

View File

@@ -1,18 +0,0 @@
''' Thin wrapper around trezorlib. '''
def client():
# pylint: disable=import-error
from trezorlib.client import TrezorClient
from trezorlib.transport_hid import HidTransport
devices = HidTransport.enumerate()
if len(devices) != 1:
msg = '{:d} Trezor devices found'.format(len(devices))
raise IOError(msg)
return TrezorClient(HidTransport(devices[0]))
def identity_type(**kwargs):
# pylint: disable=import-error
from trezorlib.types_pb2 import IdentityType
return IdentityType(**kwargs)

View File

@@ -1,145 +0,0 @@
import io
import re
import struct
import binascii
from .. import util
from .. import formats
from . import _factory as TrezorFactory
import logging
log = logging.getLogger(__name__)
class Client(object):
def __init__(self, factory=TrezorFactory):
self.factory = factory
self.client = self.factory.client()
f = self.client.features
log.debug('connected to Trezor %s', f.device_id)
log.debug('label : %s', f.label)
log.debug('vendor : %s', f.vendor)
version = [f.major_version, f.minor_version, f.patch_version]
log.debug('version : %s', '.'.join([str(v) for v in version]))
log.debug('revision : %s', binascii.hexlify(f.revision))
def __enter__(self):
msg = 'Hello World!'
assert self.client.ping(msg) == msg
return self
def __exit__(self, *args):
log.info('disconnected from Trezor')
self.client.clear_session() # forget PIN and shutdown screen
self.client.close()
def get_identity(self, label):
identity = string_to_identity(label, self.factory.identity_type)
identity.proto = 'ssh'
return identity
def get_public_key(self, identity):
assert identity.proto == 'ssh'
label = identity_to_string(identity)
log.info('getting "%s" public key from Trezor...', label)
addr = _get_address(identity)
node = self.client.get_public_node(n=addr,
ecdsa_curve_name='nist256p1')
pubkey = node.node.public_key
return formats.export_public_key(pubkey=pubkey, label=label)
def sign_ssh_challenge(self, identity, blob):
assert identity.proto == 'ssh'
label = identity_to_string(identity)
msg = _parse_ssh_blob(blob)
log.info('please confirm user "%s" login to "%s" using Trezor...',
msg['user'], label)
visual = identity.path # not signed when proto='ssh'
result = self.client.sign_identity(identity=identity,
challenge_hidden=blob,
challenge_visual=visual,
ecdsa_curve_name='nist256p1')
verifying_key = formats.decompress_pubkey(result.public_key)
public_key_blob = formats.serialize_verifying_key(verifying_key)
assert public_key_blob == msg['public_key']['blob']
assert len(result.signature) == 65
assert result.signature[:1] == bytearray([0])
return parse_signature(result.signature)
def parse_signature(blob):
sig = blob[1:]
r = util.bytes2num(sig[:32])
s = util.bytes2num(sig[32:])
return (r, s)
_identity_regexp = re.compile(''.join([
'^'
r'(?:(?P<proto>.*)://)?',
r'(?:(?P<user>.*)@)?',
r'(?P<host>.*?)',
r'(?::(?P<port>\w*))?',
r'(?P<path>/.*)?',
'$'
]))
def string_to_identity(s, identity_type):
m = _identity_regexp.match(s)
result = m.groupdict()
log.debug('parsed identity: %s', result)
kwargs = {k: v for k, v in result.items() if v}
return identity_type(**kwargs)
def identity_to_string(identity):
result = []
if identity.proto:
result.append(identity.proto + '://')
if identity.user:
result.append(identity.user + '@')
result.append(identity.host)
if identity.port:
result.append(':' + identity.port)
if identity.path:
result.append(identity.path)
return ''.join(result)
def _get_address(identity):
index = struct.pack('<L', identity.index)
addr = index + identity_to_string(identity).encode('ascii')
log.debug('address string: %r', addr)
digest = formats.hashfunc(addr).digest()
s = io.BytesIO(bytearray(digest))
hardened = 0x80000000
address_n = [13] + list(util.recv(s, '<LLLL'))
return [(hardened | value) for value in address_n]
def _parse_ssh_blob(data):
res = {}
if data:
i = io.BytesIO(data)
res['nonce'] = util.read_frame(i)
i.read(1) # TBD
res['user'] = util.read_frame(i)
res['conn'] = util.read_frame(i)
res['auth'] = util.read_frame(i)
i.read(1) # TBD
res['key_type'] = util.read_frame(i)
public_key = util.read_frame(i)
res['public_key'] = formats.parse_pubkey(public_key)
assert not i.read()
log.debug('%s: user %r via %r (%r)',
res['conn'], res['user'], res['auth'], res['key_type'])
log.debug('nonce: %s', binascii.hexlify(res['nonce']))
log.debug('fingerprint: %s', res['public_key']['fingerprint'])
return res

View File

@@ -1,14 +1,20 @@
import struct
"""Various I/O and serialization utilities."""
import io
import struct
def send(conn, data, fmt=None):
if fmt:
data = struct.pack(fmt, *data)
def send(conn, data):
"""Send data blob to connection socket."""
conn.sendall(data)
def recv(conn, size):
"""
Receive bytes from connection socket or stream.
If size is struct.calcsize()-compatible format, use it to unpack the data.
Otherwise, return the plain blob as bytes.
"""
try:
fmt = size
size = struct.calcsize(fmt)
@@ -34,11 +40,13 @@ def recv(conn, size):
def read_frame(conn):
"""Read size-prefixed frame from connection."""
size, = recv(conn, '>L')
return recv(conn, size)
def bytes2num(s):
"""Convert MSB-first bytes to an unsigned integer."""
res = 0
for i, c in enumerate(reversed(bytearray(s))):
res += c << (i * 8)
@@ -46,6 +54,7 @@ def bytes2num(s):
def num2bytes(value, size):
"""Convert an unsigned integer to MSB-first bytes with specified size."""
res = []
for _ in range(size):
res.append(value & 0xFF)
@@ -55,10 +64,12 @@ def num2bytes(value, size):
def pack(fmt, *args):
"""Serialize MSB-first message."""
return struct.pack('>' + fmt, *args)
def frame(*msgs):
"""Serialize MSB-first length-prefixed frame."""
res = io.BytesIO()
for msg in msgs:
res.write(msg)