protocol: add docstrings and replace custom exceptions

This commit is contained in:
Roman Zeyde
2016-02-19 10:48:36 +02:00
parent 566e4310e1
commit 21e89014c9
2 changed files with 28 additions and 19 deletions

View File

@@ -1,3 +1,12 @@
"""
SSH-agent protocol implementation library.
See https://github.com/openssh/openssh-portable/blob/master/PROTOCOL.agent and
http://ptspts.blogspot.co.il/2010/06/how-to-use-ssh-agent-programmatically.html
for more details.
The server's source code can be found here:
https://github.com/openssh/openssh-portable/blob/master/authfd.c
"""
import binascii import binascii
import io import io
import logging import logging
@@ -15,21 +24,15 @@ SSH2_AGENTC_SIGN_REQUEST = 13
SSH2_AGENT_SIGN_RESPONSE = 14 SSH2_AGENT_SIGN_RESPONSE = 14
class Error(Exception):
pass
class BadSignature(Error):
pass
class MissingKey(Error):
pass
class Handler(object): class Handler(object):
"""ssh-agent protocol handler."""
def __init__(self, keys, signer, debug=False): def __init__(self, keys, signer, debug=False):
"""
Create a protocol handler with specified public keys.
Use specified signer function to sign SSH authentication requests.
"""
self.public_keys = keys self.public_keys = keys
self.signer = signer self.signer = signer
self.debug = debug self.debug = debug
@@ -41,6 +44,7 @@ class Handler(object):
} }
def handle(self, msg): def handle(self, msg):
"""Handle SSH message from the SSH client and return the response."""
debug_msg = ': {!r}'.format(msg) if self.debug else '' debug_msg = ': {!r}'.format(msg) if self.debug else ''
log.debug('request: %d bytes%s', len(msg), debug_msg) log.debug('request: %d bytes%s', len(msg), debug_msg)
buf = io.BytesIO(msg) buf = io.BytesIO(msg)
@@ -54,14 +58,14 @@ class Handler(object):
@staticmethod @staticmethod
def legacy_pubs(buf): def legacy_pubs(buf):
''' SSH v1 public keys are not supported ''' """SSH v1 public keys are not supported."""
assert not buf.read() assert not buf.read()
code = util.pack('B', SSH_AGENT_RSA_IDENTITIES_ANSWER) code = util.pack('B', SSH_AGENT_RSA_IDENTITIES_ANSWER)
num = util.pack('L', 0) # no SSH v1 keys num = util.pack('L', 0) # no SSH v1 keys
return util.frame(code, num) return util.frame(code, num)
def list_pubs(self, buf): def list_pubs(self, buf):
''' SSH v2 public keys are serialized and returned. ''' """SSH v2 public keys are serialized and returned."""
assert not buf.read() assert not buf.read()
keys = self.public_keys keys = self.public_keys
code = util.pack('B', SSH2_AGENT_IDENTITIES_ANSWER) code = util.pack('B', SSH2_AGENT_IDENTITIES_ANSWER)
@@ -73,7 +77,12 @@ class Handler(object):
return util.frame(code, num, *pubs) return util.frame(code, num, *pubs)
def sign_message(self, buf): def sign_message(self, buf):
''' SSH v2 public key authentication is performed. ''' """
SSH v2 public key authentication is performed.
If the required key is not supported, raise KeyError
If the signature is invalid, rause ValueError
"""
key = formats.parse_pubkey(util.read_frame(buf)) key = formats.parse_pubkey(util.read_frame(buf))
log.debug('looking for %s', key['fingerprint']) log.debug('looking for %s', key['fingerprint'])
blob = util.read_frame(buf) blob = util.read_frame(buf)
@@ -86,7 +95,7 @@ class Handler(object):
key = k key = k
break break
else: else:
raise MissingKey('key not found') raise KeyError('key not found')
log.debug('signing %d-byte blob', len(blob)) log.debug('signing %d-byte blob', len(blob))
label = key['name'].decode('ascii') # label should be a string label = key['name'].decode('ascii') # label should be a string
@@ -98,7 +107,7 @@ class Handler(object):
log.info('signature status: OK') log.info('signature status: OK')
except formats.ecdsa.BadSignatureError: except formats.ecdsa.BadSignatureError:
log.exception('signature status: ERROR') log.exception('signature status: ERROR')
raise BadSignature('invalid ECDSA signature') raise ValueError('invalid ECDSA signature')
log.debug('signature size: %d bytes', len(sig_bytes)) log.debug('signature size: %d bytes', len(sig_bytes))

View File

@@ -38,7 +38,7 @@ def test_ecdsa_sign():
def test_sign_missing(): def test_sign_missing():
h = protocol.Handler(keys=[], signer=ecdsa_signer) h = protocol.Handler(keys=[], signer=ecdsa_signer)
with pytest.raises(protocol.MissingKey): with pytest.raises(KeyError):
h.handle(NIST256_SIGN_MSG) h.handle(NIST256_SIGN_MSG)
@@ -51,7 +51,7 @@ def test_sign_wrong():
key = formats.import_public_key(NIST256_KEY) key = formats.import_public_key(NIST256_KEY)
h = protocol.Handler(keys=[key], signer=wrong_signature) h = protocol.Handler(keys=[key], signer=wrong_signature)
with pytest.raises(protocol.BadSignature): with pytest.raises(ValueError):
h.handle(NIST256_SIGN_MSG) h.handle(NIST256_SIGN_MSG)