Commit 3e207f4d authored by Julien Muchembled's avatar Julien Muchembled

Review API between the main loop and the various select-able objects

parent a30aec39
#!/usr/bin/python #!/usr/bin/python
import errno, httplib, logging, select, socket, time import httplib, logging, socket
from BaseHTTPServer import BaseHTTPRequestHandler from BaseHTTPServer import BaseHTTPRequestHandler
from SocketServer import ThreadingTCPServer from SocketServer import ThreadingTCPServer
from urlparse import parse_qsl from urlparse import parse_qsl
...@@ -93,33 +93,18 @@ def main(): ...@@ -93,33 +93,18 @@ def main():
def requestHandler(request, client_address, _): def requestHandler(request, client_address, _):
RequestHandler(request, client_address, server) RequestHandler(request, client_address, server)
server_list = [] server_dict = {}
if config.bind4: if config.bind4:
server_list.append(HTTPServer4((config.bind4, config.port), r = HTTPServer4((config.bind4, config.port), requestHandler)
requestHandler)) server_dict[r.fileno()] = r._handle_request_noblock
if config.bind6: if config.bind6:
server_list.append(HTTPServer6((config.bind6, config.port), r = HTTPServer6((config.bind6, config.port), requestHandler)
requestHandler)) server_dict[r.fileno()] = r._handle_request_noblock
if server_list: if server_dict:
empty_list = []
while True: while True:
while True: args = server_dict.copy(), []
next = server.timeout server.select(*args)
if next is None: utils.select(*args)
break
next -= time.time()
if next > 0:
break
server.onTimeout()
try:
r = select.select(server_list[:], empty_list, empty_list,
next)[0]
except select.error as e:
if e.args[0] != errno.EINTR:
raise
else:
for r in r:
r._handle_request_noblock()
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -90,6 +90,10 @@ class RegistryServer(object): ...@@ -90,6 +90,10 @@ class RegistryServer(object):
self.email = self.ca.get_subject().emailAddress self.email = self.ca.get_subject().emailAddress
self.onTimeout() self.onTimeout()
def select(self, r, t):
if self.timeout:
t.append((self.timeout, self.onTimeout))
def onTimeout(self): def onTimeout(self):
# XXX: Because we use threads to process requests, the statements # XXX: Because we use threads to process requests, the statements
# 'self.timeout = 1' below have no effect as long as the # 'self.timeout = 1' below have no effect as long as the
......
import logging, random, socket, subprocess, time import logging, os, random, socket, subprocess, time
from collections import defaultdict, deque from collections import defaultdict, deque
from . import plib, utils, version from . import plib, utils, version
...@@ -104,7 +104,7 @@ class Connection(object): ...@@ -104,7 +104,7 @@ class Connection(object):
class TunnelManager(object): class TunnelManager(object):
def __init__(self, write_pipe, peer_db, openvpn_args, timeout, def __init__(self, peer_db, openvpn_args, timeout,
refresh, client_count, iface_list, network, prefix, refresh, client_count, iface_list, network, prefix,
address, ip_changed, encrypt, remote_gateway, disable_proto, address, ip_changed, encrypt, remote_gateway, disable_proto,
neighbour_list=()): neighbour_list=()):
...@@ -112,7 +112,9 @@ class TunnelManager(object): ...@@ -112,7 +112,9 @@ class TunnelManager(object):
self.ovpn_args = openvpn_args self.ovpn_args = openvpn_args
self.peer_db = peer_db self.peer_db = peer_db
self.timeout = timeout self.timeout = timeout
self.write_pipe = write_pipe # Create and open read_only pipe to get server events
r, self.write_pipe = os.pipe()
self._read_pipe = os.fdopen(r)
self._connecting = set() self._connecting = set()
self._connection_dict = {} self._connection_dict = {}
self._disconnected = None self._disconnected = None
...@@ -140,7 +142,7 @@ class TunnelManager(object): ...@@ -140,7 +142,7 @@ class TunnelManager(object):
# about binding and anycast. # about binding and anycast.
self.sock.bind(('::', PORT)) self.sock.bind(('::', PORT))
self.next_refresh = time.time() self._next_refresh = time.time()
self.resetTunnelRefresh() self.resetTunnelRefresh()
self._client_count = client_count self._client_count = client_count
...@@ -184,6 +186,11 @@ class TunnelManager(object): ...@@ -184,6 +186,11 @@ class TunnelManager(object):
self._free_iface_list.append(iface) self._free_iface_list.append(iface)
del self._iface_to_prefix[iface] del self._iface_to_prefix[iface]
def select(self, r, t):
r[self._read_pipe] = self.handleTunnelEvent
r[self.sock] = self.handlePeerEvent
t.append((self._next_refresh, self.refresh))
def refresh(self): def refresh(self):
logging.debug('Checking tunnels...') logging.debug('Checking tunnels...')
self._cleanDeads() self._cleanDeads()
...@@ -205,7 +212,7 @@ class TunnelManager(object): ...@@ -205,7 +212,7 @@ class TunnelManager(object):
# to see each other. # to see each other.
#if remove and self._free_iface_list: #if remove and self._free_iface_list:
# self._tuntap(self._free_iface_list.pop()) # self._tuntap(self._free_iface_list.pop())
self.next_refresh = time.time() + 5 self._next_refresh = time.time() + 5
def _cleanDeads(self): def _cleanDeads(self):
for prefix in self._connection_dict.keys(): for prefix in self._connection_dict.keys():
...@@ -379,9 +386,9 @@ class TunnelManager(object): ...@@ -379,9 +386,9 @@ class TunnelManager(object):
for prefix in self._connection_dict.keys(): for prefix in self._connection_dict.keys():
self._kill(prefix) self._kill(prefix)
def handleTunnelEvent(self, msg): def handleTunnelEvent(self):
try: try:
msg = msg.rstrip() msg = self._read_pipe.readline().rstrip()
args = msg.split() args = msg.split()
m = getattr(self, '_ovpn_' + args.pop(0).replace('-', '_')) m = getattr(self, '_ovpn_' + args.pop(0).replace('-', '_'))
except (AttributeError, ValueError): except (AttributeError, ValueError):
......
...@@ -28,6 +28,9 @@ class Forwarder(object): ...@@ -28,6 +28,9 @@ class Forwarder(object):
raise UPnPException(str(e)) raise UPnPException(str(e))
return wraps(wrapped)(wrapper) return wraps(wrapped)(wrapper)
def select(self, r, t):
t.append((self.next_refresh, self.refresh))
def checkExternalIp(self, ip=None): def checkExternalIp(self, ip=None):
if ip: if ip:
try: try:
......
import argparse, calendar, errno, logging, os, shlex, signal, socket import argparse, calendar, errno, logging, os, select as _select, shlex, signal
import struct, subprocess, sys, textwrap, threading, time, traceback import socket, struct, subprocess, sys, textwrap, threading, time, traceback
try: try:
subprocess.CalledProcessError(0, '', '') subprocess.CalledProcessError(0, '', '')
except TypeError: # BBB: Python < 2.7 except TypeError: # BBB: Python < 2.7
...@@ -165,6 +165,21 @@ class Popen(subprocess.Popen): ...@@ -165,6 +165,21 @@ class Popen(subprocess.Popen):
return r return r
def select(R, T):
try:
r, w, _ = _select.select(R, (), (),
max(0, min(T)[0] - time.time()) if T else None)
except _select.error as e:
if e.args[0] != errno.EINTR:
raise
return
for r in r:
R[r]()
t = time.time()
for next_refresh, refresh in T:
if next_refresh <= t:
refresh()
def makedirs(path): def makedirs(path):
try: try:
os.makedirs(path) os.makedirs(path)
......
#!/usr/bin/python #!/usr/bin/python
import atexit, errno, logging, os, select, signal, socket import atexit, errno, logging, os, signal, socket
import sqlite3, subprocess, sys, time, threading import sqlite3, subprocess, sys, time, threading
from collections import deque from collections import deque
from OpenSSL import crypto from OpenSSL import crypto
...@@ -301,18 +301,16 @@ def main(): ...@@ -301,18 +301,16 @@ def main():
cleanup = [] cleanup = []
if config.client_count and not config.client: if config.client_count and not config.client:
required('registry') required('registry')
# Create and open read_only pipe to get server events
r_pipe, write_pipe = os.pipe()
read_pipe = os.fdopen(r_pipe)
peer_db = db.PeerDB(db_path, registry, config.key, network, prefix) peer_db = db.PeerDB(db_path, registry, config.key, network, prefix)
cleanup.append(lambda: peer_db.cacheMinimize(config.client_count)) cleanup.append(lambda: peer_db.cacheMinimize(config.client_count))
tunnel_manager = tunnel.TunnelManager(write_pipe, peer_db, tunnel_manager = tunnel.TunnelManager(peer_db,
config.openvpn_args, timeout, config.tunnel_refresh, config.openvpn_args, timeout, config.tunnel_refresh,
config.client_count, config.iface_list, network, prefix, config.client_count, config.iface_list, network, prefix,
address, ip_changed, config.encrypt, remote_gateway, address, ip_changed, config.encrypt, remote_gateway,
config.disable_proto, config.neighbour) config.disable_proto, config.neighbour)
cleanup.append(tunnel_manager.sock.close) cleanup.append(tunnel_manager.sock.close)
tunnel_interfaces += tunnel_manager.new_iface_list tunnel_interfaces += tunnel_manager.new_iface_list
write_pipe = tunnel_manager.write_pipe
else: else:
tunnel_manager = write_pipe = None tunnel_manager = write_pipe = None
...@@ -418,34 +416,18 @@ def main(): ...@@ -418,34 +416,18 @@ def main():
cleanup.append(utils.Popen(cmd, shell=True).stop) cleanup.append(utils.Popen(cmd, shell=True).stop)
# main loop # main loop
if tunnel_manager is None: select_list = [forwarder.select] if forwarder else []
exit.release() if tunnel_manager:
time.sleep(max(0, next_renew - time.time())) select_list.append(tunnel_manager.select)
raise ReexecException("Restart to renew certificate")
cleanup += tunnel_manager.delInterfaces, tunnel_manager.killAll cleanup += tunnel_manager.delInterfaces, tunnel_manager.killAll
exit.release() exit.release()
while True: def renew():
next = tunnel_manager.next_refresh
if forwarder:
next = min(next, forwarder.next_refresh)
r = [read_pipe, tunnel_manager.sock]
try:
r = select.select(r, [], [], max(0, next - time.time()))[0]
except select.error as e:
if e.args[0] != errno.EINTR:
raise
continue
if read_pipe in r:
tunnel_manager.handleTunnelEvent(read_pipe.readline())
if tunnel_manager.sock in r:
tunnel_manager.handlePeerEvent()
t = time.time()
if t >= tunnel_manager.next_refresh:
tunnel_manager.refresh()
if t >= next_renew:
raise ReexecException("Restart to renew certificate") raise ReexecException("Restart to renew certificate")
if forwarder and t >= forwarder.next_refresh: select_list.append(utils.select)
forwarder.refresh() while True:
args = {}, [(next_renew, renew)]
for s in select_list:
s(*args)
finally: finally:
# XXX: We have a possible race condition if a signal is handled at # XXX: We have a possible race condition if a signal is handled at
# the beginning of this clause, just before the following line. # the beginning of this clause, just before the following line.
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment