Commit 9f4a7884 authored by Sam Rushing's avatar Sam Rushing

add support for known_hosts with specified port

diffie-hellman: do not send_newkeys() in response to kexdh_reply
parent 21bcdfbf
......@@ -189,7 +189,7 @@ class Diffie_Hellman_Group1_SHA1(SSH_Key_Exchange):
self.transport.send_disconnect(constants.SSH_DISCONNECT_KEY_EXCHANGE_FAILED, 'Key exchange did not succeed: Signature did not match.')
# Finished...
self.transport.send_newkeys()
#self.transport.send_newkeys()
KEXDH_REPLY_PAYLOAD = (ssh_packet.BYTE,
ssh_packet.STRING, # public host key and certificates (K_S)
......
......@@ -353,11 +353,11 @@ class OpenSSH_Key_Storage(key_storage.SSH_Key_Storage):
get_authorized_keys_filename = staticmethod(get_authorized_keys_filename)
def verify(self, host_id, server_key_types, public_host_key, username=None):
def verify(self, host_id, server_key_types, public_host_key, username=None, port=22):
for key in server_key_types:
if public_host_key.name == key.name:
# This is a supported key type.
if self._verify_contains(host_id, public_host_key, username):
if self._verify_contains(host_id, public_host_key, username, port):
return 1
return 0
......@@ -365,7 +365,7 @@ class OpenSSH_Key_Storage(key_storage.SSH_Key_Storage):
verify = classmethod(verify)
def _verify_contains(host_id, key, username):
def _verify_contains(host_id, key, username, port):
"""_verify_contains(host_id, key, username) -> boolean
Checks whether <key> is in the known_hosts file.
"""
......@@ -373,7 +373,7 @@ class OpenSSH_Key_Storage(key_storage.SSH_Key_Storage):
if not isinstance(host_id, remote_host.IPv4_Remote_Host_ID):
return 0
hostfile = openssh_known_hosts.OpenSSH_Known_Hosts()
return hostfile.check_for_host(host_id, key, username)
return hostfile.check_for_host(host_id, key, username, port)
_verify_contains = staticmethod(_verify_contains)
......
......@@ -75,7 +75,7 @@ class OpenSSH_Known_Hosts:
user_known_hosts_filename = os.path.join(home_dir, '.ssh', 'known_hosts')
return user_known_hosts_filename
def check_for_host(self, host_id, key, username=None):
def check_for_host(self, host_id, key, username=None, port=22):
"""check_for_host(self, host_id, key, username=None) -> boolean
Checks if the given key is in the known_hosts file.
Returns true if it is, otherwise returns false.
......@@ -88,6 +88,7 @@ class OpenSSH_Known_Hosts:
<key> - A SSH_Public_Private_Key instance.
"""
if not isinstance(host_id, IPv4_Remote_Host_ID):
raise TypeError, host_id
......@@ -103,7 +104,7 @@ class OpenSSH_Known_Hosts:
for filename in self.get_known_hosts_filenames(username):
for host in hosts:
try:
if self._check_for_host(filename, host_id, host, key):
if self._check_for_host(filename, host_id, host, port, key):
return 1
except Host_Key_Changed_Error, e:
changed = e
......@@ -113,7 +114,7 @@ class OpenSSH_Known_Hosts:
else:
raise changed
def _check_for_host(self, filename, host_id, host, key):
def _check_for_host(self, filename, host_id, host, port, key):
try:
f = open(filename)
except IOError:
......@@ -129,7 +130,7 @@ class OpenSSH_Known_Hosts:
m = openssh_key_formats.ssh2_known_hosts_entry.match(line)
if m:
if key.name == m.group('keytype'):
if self._match_host(host, m.group('list_of_hosts')):
if self._match_host(host, port, m.group('list_of_hosts')):
if self._match_key(key, m.group('base64_key')):
return 1
else:
......@@ -145,12 +146,12 @@ class OpenSSH_Known_Hosts:
else:
raise changed
def _match_host(self, host, pattern):
def _match_host(self, host, port, pattern):
patterns = pattern.split(',')
# Negated_Pattern is used to terminate the checks.
try:
for p in patterns:
if self._match_pattern(host, p):
if self._match_pattern(host, port, p):
return 1
except OpenSSH_Known_Hosts.Negated_Pattern:
return 0
......@@ -159,7 +160,9 @@ class OpenSSH_Known_Hosts:
class Negated_Pattern(Exception):
pass
def _match_pattern(self, host, pattern):
host_with_port = re.compile ('^\\[([^\\]]+)\\]:([0-9]+)')
def _match_pattern(self, host, port, pattern):
# XXX: OpenSSH does not do any special work to check IP addresses.
# It just assumes that it will match character-for-character.
# Thus, 001.002.003.004 != 1.2.3.4 even though those are technically
......@@ -174,6 +177,17 @@ class OpenSSH_Known_Hosts:
raise OpenSSH_Known_Hosts.Negated_Pattern
else:
return 1
# check for host port
port_probe = self.host_with_port.match (pattern)
if port_probe:
# host with port
host0, port0 = port_probe.groups()
port0 = int (port0)
if host == host0 and port == port0:
if negate:
raise OpenSSH_Known_Hosts.Negated_Pattern
else:
return 1
# Check for wildcards.
# XXX: Lazy
# XXX: We could potentially escape other RE-special characters.
......
......@@ -137,6 +137,9 @@ class coro_socket_transport(l4_transport.Transport):
def get_host_id(self):
return remote_host.IPv4_Remote_Host_ID(self.ip, self.get_hostname())
def get_port(self):
return self.port
# obviously ipv4 only
def to_in_addr_arpa (ip):
octets = ip.split ('.')
......
......@@ -84,7 +84,7 @@ def doit (ip, port):
if not is_ip (ip):
ip = coro.get_resolver().resolve_ipv4 (ip)
debug = coro.ssh.util.debug.Debug()
debug.level = coro.ssh.util.debug.DEBUG_1
debug.level = coro.ssh.util.debug.DEBUG_3
client = coro.ssh.transport.client.SSH_Client_Transport(debug=debug)
transport = coro.ssh.l4_transport.coro_socket_transport.coro_socket_transport(ip, port=port)
client.connect(transport)
......
......@@ -166,7 +166,8 @@ class SSH_Client_Transport(transport.SSH_Transport):
Raises Invalid_Server_Public_Host_Key exception if it does not match.
"""
host_id = self.transport.get_host_id()
port = self.transport.get_port()
for storage in self.supported_key_storages:
if storage.verify(host_id, self.c2s.supported_server_keys, public_host_key, username):
if storage.verify(host_id, self.c2s.supported_server_keys, public_host_key, username, port):
return
raise key_storage.Invalid_Server_Public_Host_Key(host_id, public_host_key)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment