Commit c589b0b9 authored by Joanne Hugé's avatar Joanne Hugé

Add firewall-test to test firewall configuration

parent 43972615
#!/usr/bin/python2
import sys
import argparse
import time
import os, signal
import socket, select
import threading
import logging
import traceback
import urllib
class color:
PURPLE = '\033[95m'
CYAN = '\033[96m'
DARKCYAN = '\033[36m'
BLUE = '\033[94m'
GREEN = '\033[92m'
YELLOW = '\033[93m'
RED = '\033[91m'
BOLD = '\033[1m'
UNDERLINE = '\033[4m'
END = '\033[0m'
localhost = {'ipv4': '127.0.0.1', 'ipv6': '::1'}
socket_family = {'ipv4': socket.AF_INET, 'ipv6': socket.AF_INET6}
socket_protocol = {'tcp': socket.SOCK_STREAM, 'udp': socket.SOCK_DGRAM}
dst_port_evt = threading.Event()
free_dst_port = 0
test_passed = False
request_stop = False
pid = os.getpid()
sigint = signal.getsignal(signal.SIGINT)
def sigint_handler(*x):
global request_stop
if os.getpid() == pid:
request_stop = True
sigint(*x)
for sig in filter(lambda x: x.startswith('SIG'), dir(signal)):
try:
signal.signal(getattr(signal, sig), sigint_handler)
except (ValueError, OSError, RuntimeError):
pass
def client_function(src_port, dst_port, family, protocol, remote):
global test_passed
data = '\0' * 16
remote_host = bool(remote)
remote = remote if remote else localhost[family]
try:
logging.info('Client: enter thread')
s = socket.socket(socket_family[family], socket_protocol[protocol])
s.settimeout(2)
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
if not dst_port:
if not free_dst_port:
dst_port_evt.wait(timeout=2)
dst_port = free_dst_port
if not dst_port:
logging.debug('Client: no port available')
return
logging.debug('Client: got port: %s', dst_port)
try:
if src_port:
s.bind((localhost[family], src_port))
logging.info('Client: send 16 bytes to server')
if protocol == 'tcp':
logging.debug('Client: connect')
s.connect((remote, dst_port))
logging.debug('Client: connect_done')
if remote_host:
logging.debug('Client: remote_host')
test_passed = True
else:
s.sendall(data)
else:
s.sendto(data, (remote, dst_port))
finally:
if protocol == 'tcp':
s.shutdown(socket.SHUT_RDWR)
s.close()
logging.info('Client: exit thread')
except Exception as e:
logging.info(e + '\n' + traceback.format_exc())
def server_function(dst_port, family, protocol):
global test_passed, free_dst_port
try:
logging.info('Server: enter thread')
s = socket.socket(socket_family[family], socket_protocol[protocol])
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
s.settimeout(2)
try:
s.bind((localhost[family], dst_port))
if not dst_port:
free_dst_port = s.getsockname()[1]
dst_port_evt.set()
if protocol == 'tcp':
s.listen(1)
conn, _ = s.accept()
else:
conn = s
s.setblocking(0)
try:
while not request_stop:
logging.debug('Server: select')
ready = select.select([conn], [], [], 2)
if ready[0]:
data = conn.recv(16) if protocol == 'tcp' else conn.recvfrom(16)[0]
logging.debug('Server: recieved %s bytes: %s', len(data), repr(data))
test_passed = len(data) == 16
return
finally:
if protocol == 'tcp':
conn.close()
finally:
if protocol == 'tcp':
s.shutdown(socket.SHUT_RDWR)
s.close()
logging.info('Server: exit thread')
except Exception as e:
logging.info(e + '\n' + traceback.format_exc())
def print_test(success, msg):
print (color.BOLD +
('%s' % (color.GREEN + '[OK]' if success else color.RED + '[FAILED]')) +
color.END + ' ' + msg)
def test_ports(src_port, dst_port, family, protocol, remote=''):
global test_passed, request_stop
test_passed = False
if request_stop:
return False
logging.info(color.BOLD + color.BLUE + '[Testing]' + color.END +
(' %s, %s, (%s -> %s)' % (family, protocol, src_port, dst_port)))
client = threading.Thread(target=client_function, args=(
src_port, dst_port, family, protocol, remote))
if not remote:
server = threading.Thread(target=server_function, args=(dst_port, family, protocol))
server.start()
time.sleep(0.1)
client.start()
current_time = time.time()
client.join(2)
if not remote:
server.join(2 - (time.time() - current_time))
if client.is_alive() or (not remote and server.is_alive()):
request_stop = True
# Set timeout to still be interruptible
client.join(2**31)
if not remote:
server.join(2**31)
request_stop = False
msg = ('Source port: %s' % src_port) if src_port else ''
msg += ', ' if (src_port and dst_port) else ''
msg += ('Destination port: %s' % dst_port) if dst_port else ''
msg += ', ' if (src_port or dst_port) else ''
msg += '%s%s' % (protocol.upper(), 'v6' if family == 'ipv6' else '')
msg += (', Remote host: %s' % remote) if remote else ''
print_test(test_passed, msg)
return test_passed
try:
parser = argparse
parser = argparse.ArgumentParser(
description='Firewall configuration test, should be executed before running re6st')
parser.add_argument('-v', action='store_true', required=False, help='verbose')
args = parser.parse_args()
logging.basicConfig(level=(logging.DEBUG if args.v else logging.WARNING))
all_tests = True
all_tests &= test_ports(0, 1194, 'ipv4', 'udp')
all_tests &= test_ports(0, 1194, 'ipv4', 'tcp')
all_tests &= test_ports(1194, 0, 'ipv4', 'udp')
all_tests &= test_ports(1194, 0, 'ipv4', 'tcp')
all_tests &= test_ports(0, 1194, 'ipv4', 'tcp', remote='176.31.129.213')
all_tests &= test_ports(0, 326, 'ipv6', 'udp')
all_tests &= test_ports(326, 0, 'ipv6', 'udp')
all_tests &= test_ports(0, 6696, 'ipv6', 'udp')
all_tests &= test_ports(6696, 0, 'ipv6', 'udp')
all_tests &= test_ports(1900, 0, 'ipv4', 'udp')
http_test = False
try:
http_test = ( urllib.urlopen('https://www.nexedi.com').getcode() == 200)
except Exception as e:
logging.info(e + '\n' + traceback.format_exc())
print_test(http_test, 'HTTP')
all_tests &= http_test
print(color.BOLD + ('Firewall is properly configured for re6st' if all_tests else \
'Firewall is not properly configured for re6st') + color.END)
except Exception as e:
request_stop = True
logging.info(e + '\n' + traceback.format_exc())
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment