Commit 9f4a7884 authored by Sam Rushing's avatar Sam Rushing

add support for known_hosts with specified port

diffie-hellman: do not send_newkeys() in response to kexdh_reply
parent 21bcdfbf
...@@ -189,7 +189,7 @@ class Diffie_Hellman_Group1_SHA1(SSH_Key_Exchange): ...@@ -189,7 +189,7 @@ class Diffie_Hellman_Group1_SHA1(SSH_Key_Exchange):
self.transport.send_disconnect(constants.SSH_DISCONNECT_KEY_EXCHANGE_FAILED, 'Key exchange did not succeed: Signature did not match.') self.transport.send_disconnect(constants.SSH_DISCONNECT_KEY_EXCHANGE_FAILED, 'Key exchange did not succeed: Signature did not match.')
# Finished... # Finished...
self.transport.send_newkeys() #self.transport.send_newkeys()
KEXDH_REPLY_PAYLOAD = (ssh_packet.BYTE, KEXDH_REPLY_PAYLOAD = (ssh_packet.BYTE,
ssh_packet.STRING, # public host key and certificates (K_S) ssh_packet.STRING, # public host key and certificates (K_S)
......
...@@ -353,11 +353,11 @@ class OpenSSH_Key_Storage(key_storage.SSH_Key_Storage): ...@@ -353,11 +353,11 @@ class OpenSSH_Key_Storage(key_storage.SSH_Key_Storage):
get_authorized_keys_filename = staticmethod(get_authorized_keys_filename) get_authorized_keys_filename = staticmethod(get_authorized_keys_filename)
def verify(self, host_id, server_key_types, public_host_key, username=None): def verify(self, host_id, server_key_types, public_host_key, username=None, port=22):
for key in server_key_types: for key in server_key_types:
if public_host_key.name == key.name: if public_host_key.name == key.name:
# This is a supported key type. # This is a supported key type.
if self._verify_contains(host_id, public_host_key, username): if self._verify_contains(host_id, public_host_key, username, port):
return 1 return 1
return 0 return 0
...@@ -365,7 +365,7 @@ class OpenSSH_Key_Storage(key_storage.SSH_Key_Storage): ...@@ -365,7 +365,7 @@ class OpenSSH_Key_Storage(key_storage.SSH_Key_Storage):
verify = classmethod(verify) verify = classmethod(verify)
def _verify_contains(host_id, key, username): def _verify_contains(host_id, key, username, port):
"""_verify_contains(host_id, key, username) -> boolean """_verify_contains(host_id, key, username) -> boolean
Checks whether <key> is in the known_hosts file. Checks whether <key> is in the known_hosts file.
""" """
...@@ -373,7 +373,7 @@ class OpenSSH_Key_Storage(key_storage.SSH_Key_Storage): ...@@ -373,7 +373,7 @@ class OpenSSH_Key_Storage(key_storage.SSH_Key_Storage):
if not isinstance(host_id, remote_host.IPv4_Remote_Host_ID): if not isinstance(host_id, remote_host.IPv4_Remote_Host_ID):
return 0 return 0
hostfile = openssh_known_hosts.OpenSSH_Known_Hosts() hostfile = openssh_known_hosts.OpenSSH_Known_Hosts()
return hostfile.check_for_host(host_id, key, username) return hostfile.check_for_host(host_id, key, username, port)
_verify_contains = staticmethod(_verify_contains) _verify_contains = staticmethod(_verify_contains)
......
...@@ -75,7 +75,7 @@ class OpenSSH_Known_Hosts: ...@@ -75,7 +75,7 @@ class OpenSSH_Known_Hosts:
user_known_hosts_filename = os.path.join(home_dir, '.ssh', 'known_hosts') user_known_hosts_filename = os.path.join(home_dir, '.ssh', 'known_hosts')
return user_known_hosts_filename return user_known_hosts_filename
def check_for_host(self, host_id, key, username=None): def check_for_host(self, host_id, key, username=None, port=22):
"""check_for_host(self, host_id, key, username=None) -> boolean """check_for_host(self, host_id, key, username=None) -> boolean
Checks if the given key is in the known_hosts file. Checks if the given key is in the known_hosts file.
Returns true if it is, otherwise returns false. Returns true if it is, otherwise returns false.
...@@ -88,6 +88,7 @@ class OpenSSH_Known_Hosts: ...@@ -88,6 +88,7 @@ class OpenSSH_Known_Hosts:
<key> - A SSH_Public_Private_Key instance. <key> - A SSH_Public_Private_Key instance.
""" """
if not isinstance(host_id, IPv4_Remote_Host_ID): if not isinstance(host_id, IPv4_Remote_Host_ID):
raise TypeError, host_id raise TypeError, host_id
...@@ -103,7 +104,7 @@ class OpenSSH_Known_Hosts: ...@@ -103,7 +104,7 @@ class OpenSSH_Known_Hosts:
for filename in self.get_known_hosts_filenames(username): for filename in self.get_known_hosts_filenames(username):
for host in hosts: for host in hosts:
try: try:
if self._check_for_host(filename, host_id, host, key): if self._check_for_host(filename, host_id, host, port, key):
return 1 return 1
except Host_Key_Changed_Error, e: except Host_Key_Changed_Error, e:
changed = e changed = e
...@@ -113,7 +114,7 @@ class OpenSSH_Known_Hosts: ...@@ -113,7 +114,7 @@ class OpenSSH_Known_Hosts:
else: else:
raise changed raise changed
def _check_for_host(self, filename, host_id, host, key): def _check_for_host(self, filename, host_id, host, port, key):
try: try:
f = open(filename) f = open(filename)
except IOError: except IOError:
...@@ -129,7 +130,7 @@ class OpenSSH_Known_Hosts: ...@@ -129,7 +130,7 @@ class OpenSSH_Known_Hosts:
m = openssh_key_formats.ssh2_known_hosts_entry.match(line) m = openssh_key_formats.ssh2_known_hosts_entry.match(line)
if m: if m:
if key.name == m.group('keytype'): if key.name == m.group('keytype'):
if self._match_host(host, m.group('list_of_hosts')): if self._match_host(host, port, m.group('list_of_hosts')):
if self._match_key(key, m.group('base64_key')): if self._match_key(key, m.group('base64_key')):
return 1 return 1
else: else:
...@@ -145,12 +146,12 @@ class OpenSSH_Known_Hosts: ...@@ -145,12 +146,12 @@ class OpenSSH_Known_Hosts:
else: else:
raise changed raise changed
def _match_host(self, host, pattern): def _match_host(self, host, port, pattern):
patterns = pattern.split(',') patterns = pattern.split(',')
# Negated_Pattern is used to terminate the checks. # Negated_Pattern is used to terminate the checks.
try: try:
for p in patterns: for p in patterns:
if self._match_pattern(host, p): if self._match_pattern(host, port, p):
return 1 return 1
except OpenSSH_Known_Hosts.Negated_Pattern: except OpenSSH_Known_Hosts.Negated_Pattern:
return 0 return 0
...@@ -159,7 +160,9 @@ class OpenSSH_Known_Hosts: ...@@ -159,7 +160,9 @@ class OpenSSH_Known_Hosts:
class Negated_Pattern(Exception): class Negated_Pattern(Exception):
pass pass
def _match_pattern(self, host, pattern): host_with_port = re.compile ('^\\[([^\\]]+)\\]:([0-9]+)')
def _match_pattern(self, host, port, pattern):
# XXX: OpenSSH does not do any special work to check IP addresses. # XXX: OpenSSH does not do any special work to check IP addresses.
# It just assumes that it will match character-for-character. # It just assumes that it will match character-for-character.
# Thus, 001.002.003.004 != 1.2.3.4 even though those are technically # Thus, 001.002.003.004 != 1.2.3.4 even though those are technically
...@@ -174,6 +177,17 @@ class OpenSSH_Known_Hosts: ...@@ -174,6 +177,17 @@ class OpenSSH_Known_Hosts:
raise OpenSSH_Known_Hosts.Negated_Pattern raise OpenSSH_Known_Hosts.Negated_Pattern
else: else:
return 1 return 1
# check for host port
port_probe = self.host_with_port.match (pattern)
if port_probe:
# host with port
host0, port0 = port_probe.groups()
port0 = int (port0)
if host == host0 and port == port0:
if negate:
raise OpenSSH_Known_Hosts.Negated_Pattern
else:
return 1
# Check for wildcards. # Check for wildcards.
# XXX: Lazy # XXX: Lazy
# XXX: We could potentially escape other RE-special characters. # XXX: We could potentially escape other RE-special characters.
......
...@@ -137,6 +137,9 @@ class coro_socket_transport(l4_transport.Transport): ...@@ -137,6 +137,9 @@ class coro_socket_transport(l4_transport.Transport):
def get_host_id(self): def get_host_id(self):
return remote_host.IPv4_Remote_Host_ID(self.ip, self.get_hostname()) return remote_host.IPv4_Remote_Host_ID(self.ip, self.get_hostname())
def get_port(self):
return self.port
# obviously ipv4 only # obviously ipv4 only
def to_in_addr_arpa (ip): def to_in_addr_arpa (ip):
octets = ip.split ('.') octets = ip.split ('.')
......
...@@ -84,7 +84,7 @@ def doit (ip, port): ...@@ -84,7 +84,7 @@ def doit (ip, port):
if not is_ip (ip): if not is_ip (ip):
ip = coro.get_resolver().resolve_ipv4 (ip) ip = coro.get_resolver().resolve_ipv4 (ip)
debug = coro.ssh.util.debug.Debug() debug = coro.ssh.util.debug.Debug()
debug.level = coro.ssh.util.debug.DEBUG_1 debug.level = coro.ssh.util.debug.DEBUG_3
client = coro.ssh.transport.client.SSH_Client_Transport(debug=debug) client = coro.ssh.transport.client.SSH_Client_Transport(debug=debug)
transport = coro.ssh.l4_transport.coro_socket_transport.coro_socket_transport(ip, port=port) transport = coro.ssh.l4_transport.coro_socket_transport.coro_socket_transport(ip, port=port)
client.connect(transport) client.connect(transport)
......
...@@ -166,7 +166,8 @@ class SSH_Client_Transport(transport.SSH_Transport): ...@@ -166,7 +166,8 @@ class SSH_Client_Transport(transport.SSH_Transport):
Raises Invalid_Server_Public_Host_Key exception if it does not match. Raises Invalid_Server_Public_Host_Key exception if it does not match.
""" """
host_id = self.transport.get_host_id() host_id = self.transport.get_host_id()
port = self.transport.get_port()
for storage in self.supported_key_storages: for storage in self.supported_key_storages:
if storage.verify(host_id, self.c2s.supported_server_keys, public_host_key, username): if storage.verify(host_id, self.c2s.supported_server_keys, public_host_key, username, port):
return return
raise key_storage.Invalid_Server_Public_Host_Key(host_id, public_host_key) raise key_storage.Invalid_Server_Public_Host_Key(host_id, public_host_key)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment