mirror of
https://github.com/romanz/amodem.git
synced 2026-04-21 13:46:30 +08:00
protocol: add docstrings and replace custom exceptions
This commit is contained in:
@@ -1,3 +1,12 @@
|
|||||||
|
"""
|
||||||
|
SSH-agent protocol implementation library.
|
||||||
|
|
||||||
|
See https://github.com/openssh/openssh-portable/blob/master/PROTOCOL.agent and
|
||||||
|
http://ptspts.blogspot.co.il/2010/06/how-to-use-ssh-agent-programmatically.html
|
||||||
|
for more details.
|
||||||
|
The server's source code can be found here:
|
||||||
|
https://github.com/openssh/openssh-portable/blob/master/authfd.c
|
||||||
|
"""
|
||||||
import binascii
|
import binascii
|
||||||
import io
|
import io
|
||||||
import logging
|
import logging
|
||||||
@@ -15,21 +24,15 @@ SSH2_AGENTC_SIGN_REQUEST = 13
|
|||||||
SSH2_AGENT_SIGN_RESPONSE = 14
|
SSH2_AGENT_SIGN_RESPONSE = 14
|
||||||
|
|
||||||
|
|
||||||
class Error(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class BadSignature(Error):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class MissingKey(Error):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class Handler(object):
|
class Handler(object):
|
||||||
|
"""ssh-agent protocol handler."""
|
||||||
|
|
||||||
def __init__(self, keys, signer, debug=False):
|
def __init__(self, keys, signer, debug=False):
|
||||||
|
"""
|
||||||
|
Create a protocol handler with specified public keys.
|
||||||
|
|
||||||
|
Use specified signer function to sign SSH authentication requests.
|
||||||
|
"""
|
||||||
self.public_keys = keys
|
self.public_keys = keys
|
||||||
self.signer = signer
|
self.signer = signer
|
||||||
self.debug = debug
|
self.debug = debug
|
||||||
@@ -41,6 +44,7 @@ class Handler(object):
|
|||||||
}
|
}
|
||||||
|
|
||||||
def handle(self, msg):
|
def handle(self, msg):
|
||||||
|
"""Handle SSH message from the SSH client and return the response."""
|
||||||
debug_msg = ': {!r}'.format(msg) if self.debug else ''
|
debug_msg = ': {!r}'.format(msg) if self.debug else ''
|
||||||
log.debug('request: %d bytes%s', len(msg), debug_msg)
|
log.debug('request: %d bytes%s', len(msg), debug_msg)
|
||||||
buf = io.BytesIO(msg)
|
buf = io.BytesIO(msg)
|
||||||
@@ -54,14 +58,14 @@ class Handler(object):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def legacy_pubs(buf):
|
def legacy_pubs(buf):
|
||||||
''' SSH v1 public keys are not supported '''
|
"""SSH v1 public keys are not supported."""
|
||||||
assert not buf.read()
|
assert not buf.read()
|
||||||
code = util.pack('B', SSH_AGENT_RSA_IDENTITIES_ANSWER)
|
code = util.pack('B', SSH_AGENT_RSA_IDENTITIES_ANSWER)
|
||||||
num = util.pack('L', 0) # no SSH v1 keys
|
num = util.pack('L', 0) # no SSH v1 keys
|
||||||
return util.frame(code, num)
|
return util.frame(code, num)
|
||||||
|
|
||||||
def list_pubs(self, buf):
|
def list_pubs(self, buf):
|
||||||
''' SSH v2 public keys are serialized and returned. '''
|
"""SSH v2 public keys are serialized and returned."""
|
||||||
assert not buf.read()
|
assert not buf.read()
|
||||||
keys = self.public_keys
|
keys = self.public_keys
|
||||||
code = util.pack('B', SSH2_AGENT_IDENTITIES_ANSWER)
|
code = util.pack('B', SSH2_AGENT_IDENTITIES_ANSWER)
|
||||||
@@ -73,7 +77,12 @@ class Handler(object):
|
|||||||
return util.frame(code, num, *pubs)
|
return util.frame(code, num, *pubs)
|
||||||
|
|
||||||
def sign_message(self, buf):
|
def sign_message(self, buf):
|
||||||
''' SSH v2 public key authentication is performed. '''
|
"""
|
||||||
|
SSH v2 public key authentication is performed.
|
||||||
|
|
||||||
|
If the required key is not supported, raise KeyError
|
||||||
|
If the signature is invalid, rause ValueError
|
||||||
|
"""
|
||||||
key = formats.parse_pubkey(util.read_frame(buf))
|
key = formats.parse_pubkey(util.read_frame(buf))
|
||||||
log.debug('looking for %s', key['fingerprint'])
|
log.debug('looking for %s', key['fingerprint'])
|
||||||
blob = util.read_frame(buf)
|
blob = util.read_frame(buf)
|
||||||
@@ -86,7 +95,7 @@ class Handler(object):
|
|||||||
key = k
|
key = k
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
raise MissingKey('key not found')
|
raise KeyError('key not found')
|
||||||
|
|
||||||
log.debug('signing %d-byte blob', len(blob))
|
log.debug('signing %d-byte blob', len(blob))
|
||||||
label = key['name'].decode('ascii') # label should be a string
|
label = key['name'].decode('ascii') # label should be a string
|
||||||
@@ -98,7 +107,7 @@ class Handler(object):
|
|||||||
log.info('signature status: OK')
|
log.info('signature status: OK')
|
||||||
except formats.ecdsa.BadSignatureError:
|
except formats.ecdsa.BadSignatureError:
|
||||||
log.exception('signature status: ERROR')
|
log.exception('signature status: ERROR')
|
||||||
raise BadSignature('invalid ECDSA signature')
|
raise ValueError('invalid ECDSA signature')
|
||||||
|
|
||||||
log.debug('signature size: %d bytes', len(sig_bytes))
|
log.debug('signature size: %d bytes', len(sig_bytes))
|
||||||
|
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ def test_ecdsa_sign():
|
|||||||
def test_sign_missing():
|
def test_sign_missing():
|
||||||
h = protocol.Handler(keys=[], signer=ecdsa_signer)
|
h = protocol.Handler(keys=[], signer=ecdsa_signer)
|
||||||
|
|
||||||
with pytest.raises(protocol.MissingKey):
|
with pytest.raises(KeyError):
|
||||||
h.handle(NIST256_SIGN_MSG)
|
h.handle(NIST256_SIGN_MSG)
|
||||||
|
|
||||||
|
|
||||||
@@ -51,7 +51,7 @@ def test_sign_wrong():
|
|||||||
key = formats.import_public_key(NIST256_KEY)
|
key = formats.import_public_key(NIST256_KEY)
|
||||||
h = protocol.Handler(keys=[key], signer=wrong_signature)
|
h = protocol.Handler(keys=[key], signer=wrong_signature)
|
||||||
|
|
||||||
with pytest.raises(protocol.BadSignature):
|
with pytest.raises(ValueError):
|
||||||
h.handle(NIST256_SIGN_MSG)
|
h.handle(NIST256_SIGN_MSG)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user