mirror of
https://github.com/romanz/amodem.git
synced 2026-04-21 05:36:42 +08:00
protocol: fail on unsupported commands
This commit is contained in:
@@ -15,13 +15,57 @@ from . import formats, util
|
|||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
SSH_AGENTC_REQUEST_RSA_IDENTITIES = 1
|
|
||||||
SSH_AGENT_RSA_IDENTITIES_ANSWER = 2
|
|
||||||
|
|
||||||
SSH2_AGENTC_REQUEST_IDENTITIES = 11
|
# Taken from https://github.com/openssh/openssh-portable/blob/master/authfd.h
|
||||||
SSH2_AGENT_IDENTITIES_ANSWER = 12
|
COMMANDS = dict(
|
||||||
SSH2_AGENTC_SIGN_REQUEST = 13
|
SSH_AGENTC_REQUEST_RSA_IDENTITIES=1,
|
||||||
SSH2_AGENT_SIGN_RESPONSE = 14
|
SSH_AGENT_RSA_IDENTITIES_ANSWER=2,
|
||||||
|
SSH_AGENTC_RSA_CHALLENGE=3,
|
||||||
|
SSH_AGENT_RSA_RESPONSE=4,
|
||||||
|
SSH_AGENT_FAILURE=5,
|
||||||
|
SSH_AGENT_SUCCESS=6,
|
||||||
|
SSH_AGENTC_ADD_RSA_IDENTITY=7,
|
||||||
|
SSH_AGENTC_REMOVE_RSA_IDENTITY=8,
|
||||||
|
SSH_AGENTC_REMOVE_ALL_RSA_IDENTITIES=9,
|
||||||
|
SSH2_AGENTC_REQUEST_IDENTITIES=11,
|
||||||
|
SSH2_AGENT_IDENTITIES_ANSWER=12,
|
||||||
|
SSH2_AGENTC_SIGN_REQUEST=13,
|
||||||
|
SSH2_AGENT_SIGN_RESPONSE=14,
|
||||||
|
SSH2_AGENTC_ADD_IDENTITY=17,
|
||||||
|
SSH2_AGENTC_REMOVE_IDENTITY=18,
|
||||||
|
SSH2_AGENTC_REMOVE_ALL_IDENTITIES=19,
|
||||||
|
SSH_AGENTC_ADD_SMARTCARD_KEY=20,
|
||||||
|
SSH_AGENTC_REMOVE_SMARTCARD_KEY=21,
|
||||||
|
SSH_AGENTC_LOCK=22,
|
||||||
|
SSH_AGENTC_UNLOCK=23,
|
||||||
|
SSH_AGENTC_ADD_RSA_ID_CONSTRAINED=24,
|
||||||
|
SSH2_AGENTC_ADD_ID_CONSTRAINED=25,
|
||||||
|
SSH_AGENTC_ADD_SMARTCARD_KEY_CONSTRAINED=26,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def msg_code(name):
|
||||||
|
"""Convert string name into a integer message code."""
|
||||||
|
return COMMANDS[name]
|
||||||
|
|
||||||
|
|
||||||
|
def msg_name(code):
|
||||||
|
"""Convert integer message code into a string name."""
|
||||||
|
ids = {v: k for k, v in COMMANDS.items()}
|
||||||
|
return ids[code]
|
||||||
|
|
||||||
|
|
||||||
|
def _fail():
|
||||||
|
error_msg = util.pack('B', msg_code('SSH_AGENT_FAILURE'))
|
||||||
|
return util.frame(error_msg)
|
||||||
|
|
||||||
|
|
||||||
|
def _legacy_pubs(buf):
|
||||||
|
"""SSH v1 public keys are not supported."""
|
||||||
|
assert not buf.read()
|
||||||
|
code = util.pack('B', msg_code('SSH_AGENT_RSA_IDENTITIES_ANSWER'))
|
||||||
|
num = util.pack('L', 0) # no SSH v1 keys
|
||||||
|
return util.frame(code, num)
|
||||||
|
|
||||||
|
|
||||||
class Handler(object):
|
class Handler(object):
|
||||||
@@ -38,9 +82,9 @@ class Handler(object):
|
|||||||
self.debug = debug
|
self.debug = debug
|
||||||
|
|
||||||
self.methods = {
|
self.methods = {
|
||||||
SSH_AGENTC_REQUEST_RSA_IDENTITIES: Handler.legacy_pubs,
|
msg_code('SSH_AGENTC_REQUEST_RSA_IDENTITIES'): _legacy_pubs,
|
||||||
SSH2_AGENTC_REQUEST_IDENTITIES: self.list_pubs,
|
msg_code('SSH2_AGENTC_REQUEST_IDENTITIES'): self.list_pubs,
|
||||||
SSH2_AGENTC_SIGN_REQUEST: self.sign_message,
|
msg_code('SSH2_AGENTC_SIGN_REQUEST'): self.sign_message,
|
||||||
}
|
}
|
||||||
|
|
||||||
def handle(self, msg):
|
def handle(self, msg):
|
||||||
@@ -49,6 +93,10 @@ class Handler(object):
|
|||||||
log.debug('request: %d bytes%s', len(msg), debug_msg)
|
log.debug('request: %d bytes%s', len(msg), debug_msg)
|
||||||
buf = io.BytesIO(msg)
|
buf = io.BytesIO(msg)
|
||||||
code, = util.recv(buf, '>B')
|
code, = util.recv(buf, '>B')
|
||||||
|
if code not in self.methods:
|
||||||
|
log.warning('Unsupported command: %s (%d)', msg_name(code), code)
|
||||||
|
return _fail()
|
||||||
|
|
||||||
method = self.methods[code]
|
method = self.methods[code]
|
||||||
log.debug('calling %s()', method.__name__)
|
log.debug('calling %s()', method.__name__)
|
||||||
reply = method(buf=buf)
|
reply = method(buf=buf)
|
||||||
@@ -56,19 +104,11 @@ class Handler(object):
|
|||||||
log.debug('reply: %d bytes%s', len(reply), debug_reply)
|
log.debug('reply: %d bytes%s', len(reply), debug_reply)
|
||||||
return reply
|
return reply
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def legacy_pubs(buf):
|
|
||||||
"""SSH v1 public keys are not supported."""
|
|
||||||
assert not buf.read()
|
|
||||||
code = util.pack('B', SSH_AGENT_RSA_IDENTITIES_ANSWER)
|
|
||||||
num = util.pack('L', 0) # no SSH v1 keys
|
|
||||||
return util.frame(code, num)
|
|
||||||
|
|
||||||
def list_pubs(self, buf):
|
def list_pubs(self, buf):
|
||||||
"""SSH v2 public keys are serialized and returned."""
|
"""SSH v2 public keys are serialized and returned."""
|
||||||
assert not buf.read()
|
assert not buf.read()
|
||||||
keys = self.public_keys
|
keys = self.public_keys
|
||||||
code = util.pack('B', SSH2_AGENT_IDENTITIES_ANSWER)
|
code = util.pack('B', msg_code('SSH2_AGENT_IDENTITIES_ANSWER'))
|
||||||
num = util.pack('L', len(keys))
|
num = util.pack('L', len(keys))
|
||||||
log.debug('available keys: %s', [k['name'] for k in keys])
|
log.debug('available keys: %s', [k['name'] for k in keys])
|
||||||
for i, k in enumerate(keys):
|
for i, k in enumerate(keys):
|
||||||
@@ -112,5 +152,5 @@ class Handler(object):
|
|||||||
log.debug('signature size: %d bytes', len(sig_bytes))
|
log.debug('signature size: %d bytes', len(sig_bytes))
|
||||||
|
|
||||||
data = util.frame(util.frame(key['type']), util.frame(sig_bytes))
|
data = util.frame(util.frame(key['type']), util.frame(sig_bytes))
|
||||||
code = util.pack('B', SSH2_AGENT_SIGN_RESPONSE)
|
code = util.pack('B', msg_code('SSH2_AGENT_SIGN_RESPONSE'))
|
||||||
return util.frame(code, data)
|
return util.frame(code, data)
|
||||||
|
|||||||
@@ -22,6 +22,12 @@ def test_list():
|
|||||||
assert reply == LIST_NIST256_REPLY
|
assert reply == LIST_NIST256_REPLY
|
||||||
|
|
||||||
|
|
||||||
|
def test_unsupported():
|
||||||
|
h = protocol.Handler(keys=[], signer=None)
|
||||||
|
reply = h.handle(b'\x09')
|
||||||
|
assert reply == b'\x00\x00\x00\x01\x05'
|
||||||
|
|
||||||
|
|
||||||
def ecdsa_signer(label, blob):
|
def ecdsa_signer(label, blob):
|
||||||
assert label == 'ssh://localhost'
|
assert label == 'ssh://localhost'
|
||||||
assert blob == NIST256_BLOB
|
assert blob == NIST256_BLOB
|
||||||
|
|||||||
@@ -41,16 +41,23 @@ def test_handle():
|
|||||||
conn = FakeSocket()
|
conn = FakeSocket()
|
||||||
server.handle_connection(conn, handler)
|
server.handle_connection(conn, handler)
|
||||||
|
|
||||||
msg = bytearray([protocol.SSH_AGENTC_REQUEST_RSA_IDENTITIES])
|
msg = bytearray([protocol.msg_code('SSH_AGENTC_REQUEST_RSA_IDENTITIES')])
|
||||||
conn = FakeSocket(util.frame(msg))
|
conn = FakeSocket(util.frame(msg))
|
||||||
server.handle_connection(conn, handler)
|
server.handle_connection(conn, handler)
|
||||||
assert conn.tx.getvalue() == b'\x00\x00\x00\x05\x02\x00\x00\x00\x00'
|
assert conn.tx.getvalue() == b'\x00\x00\x00\x05\x02\x00\x00\x00\x00'
|
||||||
|
|
||||||
msg = bytearray([protocol.SSH2_AGENTC_REQUEST_IDENTITIES])
|
msg = bytearray([protocol.msg_code('SSH2_AGENTC_REQUEST_IDENTITIES')])
|
||||||
conn = FakeSocket(util.frame(msg))
|
conn = FakeSocket(util.frame(msg))
|
||||||
server.handle_connection(conn, handler)
|
server.handle_connection(conn, handler)
|
||||||
assert conn.tx.getvalue() == b'\x00\x00\x00\x05\x0C\x00\x00\x00\x00'
|
assert conn.tx.getvalue() == b'\x00\x00\x00\x05\x0C\x00\x00\x00\x00'
|
||||||
|
|
||||||
|
msg = bytearray([protocol.msg_code('SSH2_AGENTC_ADD_IDENTITY')])
|
||||||
|
conn = FakeSocket(util.frame(msg))
|
||||||
|
server.handle_connection(conn, handler)
|
||||||
|
conn.tx.seek(0)
|
||||||
|
reply = util.read_frame(conn.tx)
|
||||||
|
assert reply == util.pack('B', protocol.msg_code('SSH_AGENT_FAILURE'))
|
||||||
|
|
||||||
conn_mock = mock.Mock(spec=FakeSocket)
|
conn_mock = mock.Mock(spec=FakeSocket)
|
||||||
conn_mock.recv.side_effect = [Exception, EOFError]
|
conn_mock.recv.side_effect = [Exception, EOFError]
|
||||||
server.handle_connection(conn=conn_mock, handler=None)
|
server.handle_connection(conn=conn_mock, handler=None)
|
||||||
|
|||||||
Reference in New Issue
Block a user