Commit ed50edca authored by Julien Muchembled's avatar Julien Muchembled

Simplify API to establish connections and accept mix of IPv4/IPv6

parent c2c97752
...@@ -21,7 +21,6 @@ from neo.lib.connection import ListeningConnection ...@@ -21,7 +21,6 @@ from neo.lib.connection import ListeningConnection
from neo.lib.exception import PrimaryFailure from neo.lib.exception import PrimaryFailure
from .handler import AdminEventHandler, MasterEventHandler, \ from .handler import AdminEventHandler, MasterEventHandler, \
MasterRequestEventHandler MasterRequestEventHandler
from neo.lib.connector import getConnectorHandler
from neo.lib.bootstrap import BootstrapManager from neo.lib.bootstrap import BootstrapManager
from neo.lib.pt import PartitionTable from neo.lib.pt import PartitionTable
from neo.lib.protocol import ClusterStates, Errors, \ from neo.lib.protocol import ClusterStates, Errors, \
...@@ -39,8 +38,7 @@ class Application(object): ...@@ -39,8 +38,7 @@ class Application(object):
self.name = config.getCluster() self.name = config.getCluster()
self.server = config.getBind() self.server = config.getBind()
self.master_addresses, connector_name = config.getMasters() self.master_addresses = config.getMasters()
self.connector_handler = getConnectorHandler(connector_name)
logging.debug('IP address is %s, port is %d', *self.server) logging.debug('IP address is %s, port is %d', *self.server)
# The partition table is initialized after getting the number of # The partition table is initialized after getting the number of
...@@ -87,8 +85,7 @@ class Application(object): ...@@ -87,8 +85,7 @@ class Application(object):
# Make a listening port. # Make a listening port.
handler = AdminEventHandler(self) handler = AdminEventHandler(self)
self.listening_conn = ListeningConnection(self.em, handler, self.listening_conn = ListeningConnection(self.em, handler, self.server)
addr=self.server, connector=self.connector_handler())
while self.cluster_state != ClusterStates.STOPPING: while self.cluster_state != ClusterStates.STOPPING:
self.connectToPrimary() self.connectToPrimary()
...@@ -120,7 +117,7 @@ class Application(object): ...@@ -120,7 +117,7 @@ class Application(object):
# search, find, connect and identify to the primary master # search, find, connect and identify to the primary master
bootstrap = BootstrapManager(self, self.name, NodeTypes.ADMIN, bootstrap = BootstrapManager(self, self.name, NodeTypes.ADMIN,
self.uuid, self.server) self.uuid, self.server)
data = bootstrap.getPrimaryConnection(self.connector_handler) data = bootstrap.getPrimaryConnection()
(node, conn, uuid, num_partitions, num_replicas) = data (node, conn, uuid, num_partitions, num_replicas) = data
nm.update([(node.getType(), node.getAddress(), node.getUUID(), nm.update([(node.getType(), node.getAddress(), node.getUUID(),
NodeStates.RUNNING)]) NodeStates.RUNNING)])
......
...@@ -36,7 +36,6 @@ from neo.lib.util import makeChecksum, dump ...@@ -36,7 +36,6 @@ from neo.lib.util import makeChecksum, dump
from neo.lib.locking import Lock from neo.lib.locking import Lock
from neo.lib.connection import MTClientConnection, ConnectionClosed from neo.lib.connection import MTClientConnection, ConnectionClosed
from neo.lib.node import NodeManager from neo.lib.node import NodeManager
from neo.lib.connector import getConnectorHandler
from .exception import NEOStorageError, NEOStorageCreationUndoneError from .exception import NEOStorageError, NEOStorageCreationUndoneError
from .exception import NEOStorageNotFoundError from .exception import NEOStorageNotFoundError
from .handlers import storage, master from .handlers import storage, master
...@@ -80,8 +79,6 @@ class Application(object): ...@@ -80,8 +79,6 @@ class Application(object):
# Internal Attributes common to all thread # Internal Attributes common to all thread
self._db = None self._db = None
self.name = name self.name = name
master_addresses, connector_name = parseMasterList(master_nodes)
self.connector_handler = getConnectorHandler(connector_name)
self.dispatcher = Dispatcher(self.poll_thread) self.dispatcher = Dispatcher(self.poll_thread)
self.nm = NodeManager(dynamic_master_list) self.nm = NodeManager(dynamic_master_list)
self.cp = ConnectionPool(self) self.cp = ConnectionPool(self)
...@@ -90,7 +87,7 @@ class Application(object): ...@@ -90,7 +87,7 @@ class Application(object):
self.trying_master_node = None self.trying_master_node = None
# load master node list # load master node list
for address in master_addresses: for address in parseMasterList(master_nodes):
self.nm.createMaster(address=address) self.nm.createMaster(address=address)
# no self-assigned UUID, primary master will supply us one # no self-assigned UUID, primary master will supply us one
...@@ -290,7 +287,6 @@ class Application(object): ...@@ -290,7 +287,6 @@ class Application(object):
conn = MTClientConnection(self.em, conn = MTClientConnection(self.em,
self.notifications_handler, self.notifications_handler,
node=self.trying_master_node, node=self.trying_master_node,
connector=self.connector_handler(),
dispatcher=self.dispatcher) dispatcher=self.dispatcher)
# Query for primary master node # Query for primary master node
if conn.getConnector() is None: if conn.getConnector() is None:
......
...@@ -54,7 +54,7 @@ class ConnectionPool(object): ...@@ -54,7 +54,7 @@ class ConnectionPool(object):
app = self.app app = self.app
logging.debug('trying to connect to %s - %s', node, node.getState()) logging.debug('trying to connect to %s - %s', node, node.getState())
conn = MTClientConnection(app.em, app.storage_event_handler, node, conn = MTClientConnection(app.em, app.storage_event_handler, node,
connector=app.connector_handler(), dispatcher=app.dispatcher) dispatcher=app.dispatcher)
p = Packets.RequestIdentification(NodeTypes.CLIENT, p = Packets.RequestIdentification(NodeTypes.CLIENT,
app.uuid, None, app.name) app.uuid, None, app.name)
try: try:
......
...@@ -116,7 +116,7 @@ class BootstrapManager(EventHandler): ...@@ -116,7 +116,7 @@ class BootstrapManager(EventHandler):
logging.info('Got a new UUID: %s', uuid_str(self.uuid)) logging.info('Got a new UUID: %s', uuid_str(self.uuid))
self.accepted = True self.accepted = True
def getPrimaryConnection(self, connector_handler): def getPrimaryConnection(self):
""" """
Primary lookup/connection process. Primary lookup/connection process.
Returns when the connection is made. Returns when the connection is made.
...@@ -140,8 +140,7 @@ class BootstrapManager(EventHandler): ...@@ -140,8 +140,7 @@ class BootstrapManager(EventHandler):
sleep(1) sleep(1)
if conn is None: if conn is None:
# open the connection # open the connection
conn = ClientConnection(em, self, self.current, conn = ClientConnection(em, self, self.current)
connector_handler())
# still processing # still processing
em.poll(1) em.poll(1)
return (self.current, conn, self.uuid, self.num_partitions, return (self.current, conn, self.uuid, self.num_partitions,
......
...@@ -206,6 +206,7 @@ class BaseConnection(object): ...@@ -206,6 +206,7 @@ class BaseConnection(object):
Timeouts in HandlerSwitcher are only there to prioritize some packets. Timeouts in HandlerSwitcher are only there to prioritize some packets.
""" """
from .connector import SocketConnector as ConnectorClass
KEEP_ALIVE = 60 KEEP_ALIVE = 60
def __init__(self, event_manager, handler, connector, addr=None): def __init__(self, event_manager, handler, connector, addr=None):
...@@ -318,19 +319,18 @@ attributeTracker.track(BaseConnection) ...@@ -318,19 +319,18 @@ attributeTracker.track(BaseConnection)
class ListeningConnection(BaseConnection): class ListeningConnection(BaseConnection):
"""A listen connection.""" """A listen connection."""
def __init__(self, event_manager, handler, addr, connector, **kw): def __init__(self, event_manager, handler, addr):
logging.debug('listening to %s:%d', *addr) logging.debug('listening to %s:%d', *addr)
BaseConnection.__init__(self, event_manager, handler, connector = self.ConnectorClass(addr)
addr=addr, connector=connector) BaseConnection.__init__(self, event_manager, handler, connector, addr)
self.connector.makeListeningConnection(addr) connector.makeListeningConnection()
def readable(self): def readable(self):
try: try:
new_s, addr = self.connector.getNewConnection() connector, addr = self.connector.accept()
logging.debug('accepted a connection from %s:%d', *addr) logging.debug('accepted a connection from %s:%d', *addr)
handler = self.getHandler() handler = self.getHandler()
new_conn = ServerConnection(self.em, handler, new_conn = ServerConnection(self.em, handler, connector, addr)
connector=new_s, addr=addr)
handler.connectionAccepted(new_conn) handler.connectionAccepted(new_conn)
except ConnectorTryAgainException: except ConnectorTryAgainException:
pass pass
...@@ -668,14 +668,15 @@ class ClientConnection(Connection): ...@@ -668,14 +668,15 @@ class ClientConnection(Connection):
connecting = True connecting = True
client = True client = True
def __init__(self, event_manager, handler, node, connector): def __init__(self, event_manager, handler, node):
addr = node.getAddress() addr = node.getAddress()
connector = self.ConnectorClass(addr)
Connection.__init__(self, event_manager, handler, connector, addr) Connection.__init__(self, event_manager, handler, connector, addr)
node.setConnection(self) node.setConnection(self)
handler.connectionStarted(self) handler.connectionStarted(self)
try: try:
try: try:
self.connector.makeClientConnection(addr) connector.makeClientConnection()
except ConnectorInProgressException: except ConnectorInProgressException:
event_manager.addWriter(self) event_manager.addWriter(self)
else: else:
......
...@@ -19,52 +19,51 @@ import errno ...@@ -19,52 +19,51 @@ import errno
# Global connector registry. # Global connector registry.
# Fill by calling registerConnectorHandler. # Fill by calling registerConnectorHandler.
# Read by calling getConnectorHandler. # Read by calling SocketConnector.__new__
connector_registry = {} connector_registry = {}
DEFAULT_CONNECTOR = 'SocketConnectorIPv4'
def registerConnectorHandler(connector_handler): def registerConnectorHandler(connector_handler):
connector_registry[connector_handler.__name__] = connector_handler connector_registry[connector_handler.af_type] = connector_handler
def getConnectorHandler(connector=None):
if connector is None:
connector = DEFAULT_CONNECTOR
if isinstance(connector, basestring):
connector_handler = connector_registry.get(connector)
else:
# Allow to directly provide a handler class without requiring to
# register it first.
connector_handler = connector
return connector_handler
class SocketConnector: class SocketConnector(object):
""" This class is a wrapper for a socket """ """ This class is a wrapper for a socket """
is_listening = False is_closed = is_server = None
remote_addr = None
is_closed = None
def __init__(self, s=None, accepted_from=None): def __new__(cls, addr, s=None):
self.accepted_from = accepted_from
if accepted_from is not None:
self.remote_addr = accepted_from
self.is_listening = False
self.is_closed = False
if s is None: if s is None:
self.socket = socket.socket(self.af_type, socket.SOCK_STREAM) host, port = addr
for af_type, cls in connector_registry.iteritems():
try :
socket.inet_pton(af_type, host)
break
except socket.error:
pass
else: else:
raise ValueError("Unknown type of host", host)
self = object.__new__(cls)
self.addr = cls._normAddress(addr)
if s is None:
s = socket.socket(af_type, socket.SOCK_STREAM)
else:
self.is_server = True
self.is_closed = False
self.socket = s self.socket = s
self.socket_fd = self.socket.fileno() self.socket_fd = s.fileno()
# always use non-blocking sockets # always use non-blocking sockets
self.socket.setblocking(0) s.setblocking(0)
# disable Nagle algorithm to reduce latency # disable Nagle algorithm to reduce latency
self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) s.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
return self
def makeClientConnection(self, addr): # Threaded tests monkey-patch the following 2 operations.
self.is_closed = False _connect = lambda self, addr: self.socket.connect(addr)
self.remote_addr = addr _bind = lambda self, addr: self.socket.bind(addr)
def makeClientConnection(self):
assert self.is_closed is None
self.is_server = self.is_closed = False
try: try:
self.socket.connect(addr) self._connect(self.addr)
except socket.error, (err, errmsg): except socket.error, (err, errmsg):
if err == errno.EINPROGRESS: if err == errno.EINPROGRESS:
raise ConnectorInProgressException raise ConnectorInProgressException
...@@ -73,12 +72,12 @@ class SocketConnector: ...@@ -73,12 +72,12 @@ class SocketConnector:
raise ConnectorException, 'makeClientConnection to %s failed:' \ raise ConnectorException, 'makeClientConnection to %s failed:' \
' %s:%s' % (addr, err, errmsg) ' %s:%s' % (addr, err, errmsg)
def makeListeningConnection(self, addr): def makeListeningConnection(self):
assert self.is_closed is None
self.is_closed = False self.is_closed = False
self.is_listening = True
try: try:
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.socket.bind(addr) self._bind(self.addr)
self.socket.listen(5) self.socket.listen(5)
except socket.error, (err, errmsg): except socket.error, (err, errmsg):
self.socket.close() self.socket.close()
...@@ -94,15 +93,22 @@ class SocketConnector: ...@@ -94,15 +93,22 @@ class SocketConnector:
# in epoll # in epoll
return self.socket_fd return self.socket_fd
def getNewConnection(self): @staticmethod
def _normAddress(addr):
return addr
def getAddress(self):
return self._normAddress(self.socket.getsockname())
def accept(self):
try: try:
(new_s, addr) = self._accept() s, addr = self.socket.accept()
new_s = self.__class__(new_s, accepted_from=addr) s = self.__class__(addr, s)
return (new_s, addr) return s, s.addr
except socket.error, (err, errmsg): except socket.error, (err, errmsg):
if err == errno.EAGAIN: if err == errno.EAGAIN:
raise ConnectorTryAgainException raise ConnectorTryAgainException
raise ConnectorException, 'getNewConnection failed: %s:%s' % \ raise ConnectorException, 'accept failed: %s:%s' % \
(err, errmsg) (err, errmsg)
def receive(self): def receive(self):
...@@ -139,14 +145,14 @@ class SocketConnector: ...@@ -139,14 +145,14 @@ class SocketConnector:
state = 'closed ' state = 'closed '
else: else:
state = 'opened ' state = 'opened '
if self.is_listening: if self.is_server is None:
state += 'listening' state += 'listening'
else: else:
if self.accepted_from is None: if self.is_server:
state += 'to '
else:
state += 'from ' state += 'from '
state += str(self.remote_addr) else:
state += 'to '
state += str(self.addr)
return '<%s at 0x%x fileno %s %s, %s>' % (self.__class__.__name__, return '<%s at 0x%x fileno %s %s, %s>' % (self.__class__.__name__,
id(self), '?' if self.is_closed else self.socket_fd, id(self), '?' if self.is_closed else self.socket_fd,
self.getAddress(), state) self.getAddress(), state)
...@@ -155,22 +161,13 @@ class SocketConnectorIPv4(SocketConnector): ...@@ -155,22 +161,13 @@ class SocketConnectorIPv4(SocketConnector):
" Wrapper for IPv4 sockets" " Wrapper for IPv4 sockets"
af_type = socket.AF_INET af_type = socket.AF_INET
def _accept(self):
return self.socket.accept()
def getAddress(self):
return self.socket.getsockname()
class SocketConnectorIPv6(SocketConnector): class SocketConnectorIPv6(SocketConnector):
" Wrapper for IPv6 sockets" " Wrapper for IPv6 sockets"
af_type = socket.AF_INET6 af_type = socket.AF_INET6
def _accept(self): @staticmethod
new_s, addr = self.socket.accept() def _normAddress(addr):
return new_s, addr[:2] return addr[:2]
def getAddress(self):
return self.socket.getsockname()[:2]
registerConnectorHandler(SocketConnectorIPv4) registerConnectorHandler(SocketConnectorIPv4)
registerConnectorHandler(SocketConnectorIPv6) registerConnectorHandler(SocketConnectorIPv6)
......
...@@ -19,12 +19,8 @@ import sys ...@@ -19,12 +19,8 @@ import sys
import traceback import traceback
from cStringIO import StringIO from cStringIO import StringIO
from struct import Struct from struct import Struct
try:
from .util import getAddressType
except ImportError:
pass
PROTOCOL_VERSION = 2 PROTOCOL_VERSION = 3
# Size restrictions. # Size restrictions.
MIN_PACKET_SIZE = 10 MIN_PACKET_SIZE = 10
...@@ -449,65 +445,6 @@ class PEnum(PStructItem): ...@@ -449,65 +445,6 @@ class PEnum(PStructItem):
enum = self._enum.__class__.__name__ enum = self._enum.__class__.__name__
raise ValueError, 'Invalid code for %s enum: %r' % (enum, code) raise ValueError, 'Invalid code for %s enum: %r' % (enum, code)
class PAddressIPGeneric(PStructItem):
def __init__(self, name, format):
PStructItem.__init__(self, name, format)
def encode(self, writer, address):
host, port = address
host = socket.inet_pton(self.af_type, host)
writer(self.pack(host, port))
def decode(self, reader):
data = reader(self.size)
address = self.unpack(data)
host, port = address
host = socket.inet_ntop(self.af_type, host)
return (host, port)
class PAddressIPv4(PAddressIPGeneric):
af_type = socket.AF_INET
def __init__(self, name):
PAddressIPGeneric.__init__(self, name, '!4sH')
class PAddressIPv6(PAddressIPGeneric):
af_type = socket.AF_INET6
def __init__(self, name):
PAddressIPGeneric.__init__(self, name, '!16sH')
class PAddress(PStructItem):
"""
An host address (IPv4/IPv6)
"""
address_format_dict = {
socket.AF_INET: PAddressIPv4('ipv4'),
socket.AF_INET6: PAddressIPv6('ipv6'),
}
def __init__(self, name):
PStructItem.__init__(self, name, '!L')
def _encode(self, writer, address):
if address is None:
writer(self.pack(INVALID_ADDRESS_TYPE))
return
af_type = getAddressType(address)
writer(self.pack(af_type))
encoder = self.address_format_dict[af_type]
encoder.encode(writer, address)
def _decode(self, reader):
af_type = self.unpack(reader(self.size))[0]
if af_type == INVALID_ADDRESS_TYPE:
return None
decoder = self.address_format_dict[af_type]
host, port = decoder.decode(reader)
return (host, port)
class PString(PStructItem): class PString(PStructItem):
""" """
A variable-length string A variable-length string
...@@ -523,6 +460,29 @@ class PString(PStructItem): ...@@ -523,6 +460,29 @@ class PString(PStructItem):
length = self.unpack(reader(self.size))[0] length = self.unpack(reader(self.size))[0]
return reader(length) return reader(length)
class PAddress(PString):
"""
An host address (IPv4/IPv6)
"""
def __init__(self, name):
PString.__init__(self, name)
self._port = Struct('!H')
def _encode(self, writer, address):
if address:
host, port = address
PString._encode(self, writer, host)
writer(self._port.pack(port))
else:
PString._encode(self, writer, '')
def _decode(self, reader):
host = PString._decode(self, reader)
if host:
p = self._port
return host, p.unpack(reader(p.size))[0]
class PBoolean(PStructItem): class PBoolean(PStructItem):
""" """
A boolean value, encoded as a single byte A boolean value, encoded as a single byte
......
...@@ -23,11 +23,6 @@ from Queue import deque ...@@ -23,11 +23,6 @@ from Queue import deque
from struct import pack, unpack from struct import pack, unpack
from time import gmtime from time import gmtime
SOCKET_CONNECTORS_DICT = {
socket.AF_INET : 'SocketConnectorIPv4',
socket.AF_INET6: 'SocketConnectorIPv6',
}
TID_LOW_OVERFLOW = 2**32 TID_LOW_OVERFLOW = 2**32
TID_LOW_MAX = TID_LOW_OVERFLOW - 1 TID_LOW_MAX = TID_LOW_OVERFLOW - 1
SECOND_PER_TID_LOW = 60.0 / TID_LOW_OVERFLOW SECOND_PER_TID_LOW = 60.0 / TID_LOW_OVERFLOW
...@@ -125,25 +120,6 @@ def makeChecksum(s): ...@@ -125,25 +120,6 @@ def makeChecksum(s):
return sha1(s).digest() return sha1(s).digest()
def getAddressType(address):
"Return the type (IPv4 or IPv6) of an ip"
(host, port) = address
for af_type in SOCKET_CONNECTORS_DICT:
try :
socket.inet_pton(af_type, host)
except:
continue
else:
break
else:
raise ValueError("Unknown type of host", host)
return af_type
def getConnectorFromAddress(address):
address_type = getAddressType(address)
return SOCKET_CONNECTORS_DICT[address_type]
def parseNodeAddress(address, port_opt=None): def parseNodeAddress(address, port_opt=None):
if address[:1] == '[': if address[:1] == '[':
(host, port) = address[1:].split(']') (host, port) = address[1:].split(']')
...@@ -164,24 +140,12 @@ def parseNodeAddress(address, port_opt=None): ...@@ -164,24 +140,12 @@ def parseNodeAddress(address, port_opt=None):
def parseMasterList(masters, except_node=None): def parseMasterList(masters, except_node=None):
assert masters, 'At least one master must be defined' assert masters, 'At least one master must be defined'
# load master node list
socket_connector = None
master_node_list = [] master_node_list = []
for node in masters.split(' '): for node in masters.split():
if not node:
continue
address = parseNodeAddress(node) address = parseNodeAddress(node)
if address != except_node:
if (address != except_node):
master_node_list.append(address) master_node_list.append(address)
return master_node_list
socket_connector_temp = getConnectorFromAddress(address)
if socket_connector is None:
socket_connector = socket_connector_temp
elif socket_connector != socket_connector_temp:
raise TypeError("Wrong connector type : you're trying to use "
"ipv6 and ipv4 simultaneously")
return master_node_list, socket_connector
class ReadBuffer(object): class ReadBuffer(object):
......
...@@ -18,7 +18,6 @@ import sys, weakref ...@@ -18,7 +18,6 @@ import sys, weakref
from time import time from time import time
from neo.lib import logging from neo.lib import logging
from neo.lib.connector import getConnectorHandler
from neo.lib.debug import register as registerLiveDebugger from neo.lib.debug import register as registerLiveDebugger
from neo.lib.protocol import uuid_str, UUID_NAMESPACES, ZERO_TID from neo.lib.protocol import uuid_str, UUID_NAMESPACES, ZERO_TID
from neo.lib.protocol import ClusterStates, NodeStates, NodeTypes, Packets from neo.lib.protocol import ClusterStates, NodeStates, NodeTypes, Packets
...@@ -59,9 +58,7 @@ class Application(object): ...@@ -59,9 +58,7 @@ class Application(object):
self.autostart = config.getAutostart() self.autostart = config.getAutostart()
self.storage_readiness = set() self.storage_readiness = set()
master_addresses, connector_name = config.getMasters() for master_address in config.getMasters():
self.connector_handler = getConnectorHandler(connector_name)
for master_address in master_addresses:
self.nm.createMaster(address=master_address) self.nm.createMaster(address=master_address)
logging.debug('IP address is %s, port is %d', *self.server) logging.debug('IP address is %s, port is %d', *self.server)
...@@ -102,7 +99,7 @@ class Application(object): ...@@ -102,7 +99,7 @@ class Application(object):
raise ValueError("upstream cluster name must be" raise ValueError("upstream cluster name must be"
" different from cluster name") " different from cluster name")
self.backup_app = BackupApplication(self, upstream_cluster, self.backup_app = BackupApplication(self, upstream_cluster,
*config.getUpstreamMasters()) config.getUpstreamMasters())
self.administration_handler = administration.AdministrationHandler( self.administration_handler = administration.AdministrationHandler(
self) self)
...@@ -141,8 +138,7 @@ class Application(object): ...@@ -141,8 +138,7 @@ class Application(object):
def _run(self): def _run(self):
"""Make sure that the status is sane and start a loop.""" """Make sure that the status is sane and start a loop."""
# Make a listening port. # Make a listening port.
self.listening_conn = ListeningConnection(self.em, None, self.listening_conn = ListeningConnection(self.em, None, self.server)
addr=self.server, connector=self.connector_handler())
# Start a normal operation. # Start a normal operation.
while self.cluster_state != ClusterStates.STOPPING: while self.cluster_state != ClusterStates.STOPPING:
...@@ -196,8 +192,7 @@ class Application(object): ...@@ -196,8 +192,7 @@ class Application(object):
ClientConnection(self.em, client_handler, ClientConnection(self.em, client_handler,
# XXX: Ugly, but the whole election code will be # XXX: Ugly, but the whole election code will be
# replaced soon # replaced soon
node=getByAddress(addr), getByAddress(addr))
connector=self.connector_handler())
self.unconnected_master_node_set.clear() self.unconnected_master_node_set.clear()
self.em.poll(1) self.em.poll(1)
except ElectionFailure, m: except ElectionFailure, m:
...@@ -381,9 +376,7 @@ class Application(object): ...@@ -381,9 +376,7 @@ class Application(object):
# Reconnect to primary master node. # Reconnect to primary master node.
primary_handler = secondary.PrimaryHandler(self) primary_handler = secondary.PrimaryHandler(self)
ClientConnection(self.em, primary_handler, ClientConnection(self.em, primary_handler, self.primary_master_node)
node=self.primary_master_node,
connector=self.connector_handler())
# and another for the future incoming connections # and another for the future incoming connections
self.listening_conn.setHandler( self.listening_conn.setHandler(
......
...@@ -19,7 +19,6 @@ from bisect import bisect ...@@ -19,7 +19,6 @@ from bisect import bisect
from collections import defaultdict from collections import defaultdict
from neo.lib import logging from neo.lib import logging
from neo.lib.bootstrap import BootstrapManager from neo.lib.bootstrap import BootstrapManager
from neo.lib.connector import getConnectorHandler
from neo.lib.exception import PrimaryFailure from neo.lib.exception import PrimaryFailure
from neo.lib.handler import EventHandler from neo.lib.handler import EventHandler
from neo.lib.node import NodeManager from neo.lib.node import NodeManager
...@@ -67,11 +66,10 @@ class BackupApplication(object): ...@@ -67,11 +66,10 @@ class BackupApplication(object):
pt = None pt = None
def __init__(self, app, name, master_addresses, connector_name): def __init__(self, app, name, master_addresses):
self.app = weakref.proxy(app) self.app = weakref.proxy(app)
self.name = name self.name = name
self.nm = NodeManager() self.nm = NodeManager()
self.connector_handler = getConnectorHandler(connector_name)
for master_address in master_addresses: for master_address in master_addresses:
self.nm.createMaster(address=master_address) self.nm.createMaster(address=master_address)
...@@ -107,7 +105,7 @@ class BackupApplication(object): ...@@ -107,7 +105,7 @@ class BackupApplication(object):
break break
poll(1) poll(1)
node, conn, uuid, num_partitions, num_replicas = \ node, conn, uuid, num_partitions, num_replicas = \
bootstrap.getPrimaryConnection(self.connector_handler) bootstrap.getPrimaryConnection()
try: try:
app.changeClusterState(ClusterStates.BACKINGUP) app.changeClusterState(ClusterStates.BACKINGUP)
del bootstrap, node del bootstrap, node
......
...@@ -14,11 +14,9 @@ ...@@ -14,11 +14,9 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from neo.lib.connector import getConnectorHandler
from neo.lib.connection import ClientConnection from neo.lib.connection import ClientConnection
from neo.lib.event import EventManager from neo.lib.event import EventManager
from neo.lib.protocol import ClusterStates, NodeStates, ErrorCodes, Packets from neo.lib.protocol import ClusterStates, NodeStates, ErrorCodes, Packets
from neo.lib.util import getConnectorFromAddress
from neo.lib.node import NodeManager from neo.lib.node import NodeManager
from .handler import CommandEventHandler from .handler import CommandEventHandler
...@@ -31,8 +29,6 @@ class NeoCTL(object): ...@@ -31,8 +29,6 @@ class NeoCTL(object):
connected = False connected = False
def __init__(self, address): def __init__(self, address):
connector_name = getConnectorFromAddress(address)
self.connector_handler = getConnectorHandler(connector_name)
self.nm = nm = NodeManager() self.nm = nm = NodeManager()
self.server = nm.createAdmin(address=address) self.server = nm.createAdmin(address=address)
self.em = EventManager() self.em = EventManager()
...@@ -47,7 +43,7 @@ class NeoCTL(object): ...@@ -47,7 +43,7 @@ class NeoCTL(object):
def __getConnection(self): def __getConnection(self):
if not self.connected: if not self.connected:
self.connection = ClientConnection(self.em, self.handler, self.connection = ClientConnection(self.em, self.handler,
node=self.server, connector=self.connector_handler()) self.server)
while not self.connected: while not self.connected:
self.em.poll(1) self.em.poll(1)
if self.connection is None: if self.connection is None:
......
...@@ -24,7 +24,6 @@ from neo.lib.node import NodeManager ...@@ -24,7 +24,6 @@ from neo.lib.node import NodeManager
from neo.lib.event import EventManager from neo.lib.event import EventManager
from neo.lib.connection import ListeningConnection from neo.lib.connection import ListeningConnection
from neo.lib.exception import OperationFailure, PrimaryFailure from neo.lib.exception import OperationFailure, PrimaryFailure
from neo.lib.connector import getConnectorHandler
from neo.lib.pt import PartitionTable from neo.lib.pt import PartitionTable
from neo.lib.util import dump from neo.lib.util import dump
from neo.lib.bootstrap import BootstrapManager from neo.lib.bootstrap import BootstrapManager
...@@ -54,9 +53,7 @@ class Application(object): ...@@ -54,9 +53,7 @@ class Application(object):
) )
# load master nodes # load master nodes
master_addresses, connector_name = config.getMasters() for master_address in config.getMasters():
self.connector_handler = getConnectorHandler(connector_name)
for master_address in master_addresses :
self.nm.createMaster(address=master_address) self.nm.createMaster(address=master_address)
# set the bind address # set the bind address
...@@ -177,8 +174,7 @@ class Application(object): ...@@ -177,8 +174,7 @@ class Application(object):
# Make a listening port # Make a listening port
handler = identification.IdentificationHandler(self) handler = identification.IdentificationHandler(self)
self.listening_conn = ListeningConnection(self.em, handler, self.listening_conn = ListeningConnection(self.em, handler, self.server)
addr=self.server, connector=self.connector_handler())
self.server = self.listening_conn.getAddress() self.server = self.listening_conn.getAddress()
# Connect to a primary master node, verify data, and # Connect to a primary master node, verify data, and
...@@ -234,7 +230,7 @@ class Application(object): ...@@ -234,7 +230,7 @@ class Application(object):
# search, find, connect and identify to the primary master # search, find, connect and identify to the primary master
bootstrap = BootstrapManager(self, self.name, bootstrap = BootstrapManager(self, self.name,
NodeTypes.STORAGE, self.uuid, self.server) NodeTypes.STORAGE, self.uuid, self.server)
data = bootstrap.getPrimaryConnection(self.connector_handler) data = bootstrap.getPrimaryConnection()
(node, conn, uuid, num_partitions, num_replicas) = data (node, conn, uuid, num_partitions, num_replicas) = data
self.master_node = node self.master_node = node
self.master_conn = conn self.master_conn = conn
......
...@@ -46,7 +46,7 @@ class Checker(object): ...@@ -46,7 +46,7 @@ class Checker(object):
conn.asClient() conn.asClient()
else: else:
conn = ClientConnection(app.em, StorageOperationHandler(app), conn = ClientConnection(app.em, StorageOperationHandler(app),
node=node, connector=app.connector_handler()) node)
conn.ask(Packets.RequestIdentification( conn.ask(Packets.RequestIdentification(
NodeTypes.STORAGE, uuid, app.server, name)) NodeTypes.STORAGE, uuid, app.server, name))
self.conn_dict[conn] = node.isIdentified() self.conn_dict[conn] = node.isIdentified()
......
...@@ -254,8 +254,7 @@ class Replicator(object): ...@@ -254,8 +254,7 @@ class Replicator(object):
self.fetchTransactions() self.fetchTransactions()
else: else:
assert name or node.getUUID() != app.uuid, "loopback connection" assert name or node.getUUID() != app.uuid, "loopback connection"
conn = ClientConnection(app.em, StorageOperationHandler(app), conn = ClientConnection(app.em, StorageOperationHandler(app), node)
node=node, connector=app.connector_handler())
conn.ask(Packets.RequestIdentification(NodeTypes.STORAGE, conn.ask(Packets.RequestIdentification(NodeTypes.STORAGE,
None if name else app.uuid, app.server, name or app.name)) None if name else app.uuid, app.server, name or app.name))
if previous_node is not None and previous_node.isConnected(): if previous_node is not None and previous_node.isConnected():
......
...@@ -30,7 +30,6 @@ from functools import wraps ...@@ -30,7 +30,6 @@ from functools import wraps
from mock import Mock from mock import Mock
from neo.lib import debug, logging, protocol from neo.lib import debug, logging, protocol
from neo.lib.protocol import NodeTypes, Packets, UUID_NAMESPACES from neo.lib.protocol import NodeTypes, Packets, UUID_NAMESPACES
from neo.lib.util import getAddressType
from time import time from time import time
from struct import pack, unpack from struct import pack, unpack
from unittest.case import _ExpectedFailure, _UnexpectedSuccess from unittest.case import _ExpectedFailure, _UnexpectedSuccess
...@@ -203,8 +202,7 @@ class NeoUnitTestBase(NeoTestBase): ...@@ -203,8 +202,7 @@ class NeoUnitTestBase(NeoTestBase):
return Mock({ return Mock({
'getCluster': cluster, 'getCluster': cluster,
'getBind': masters[0], 'getBind': masters[0],
'getMasters': (masters, getAddressType(( 'getMasters': masters,
self.local_ip, 0))),
'getReplicas': replicas, 'getReplicas': replicas,
'getPartitions': partitions, 'getPartitions': partitions,
'getUUID': uuid, 'getUUID': uuid,
...@@ -226,8 +224,7 @@ class NeoUnitTestBase(NeoTestBase): ...@@ -226,8 +224,7 @@ class NeoUnitTestBase(NeoTestBase):
return Mock({ return Mock({
'getCluster': cluster, 'getCluster': cluster,
'getBind': (masters[0], 10020 + index), 'getBind': (masters[0], 10020 + index),
'getMasters': (masters, getAddressType(( 'getMasters': masters,
self.local_ip, 0))),
'getDatabase': db, 'getDatabase': db,
'getUUID': uuid, 'getUUID': uuid,
'getReset': False, 'getReset': False,
...@@ -554,29 +551,5 @@ class Patch(object): ...@@ -554,29 +551,5 @@ class Patch(object):
self.__del__() self.__del__()
connector_cpt = 0
class DoNothingConnector(Mock):
def __init__(self, s=None):
logging.info("initializing connector")
global connector_cpt
self.desc = connector_cpt
connector_cpt += 1
self.packet_cpt = 0
Mock.__init__(self)
def getAddress(self):
return self.addr
def makeClientConnection(self, addr):
self.addr = addr
def makeListeningConnection(self, addr):
self.addr = addr
def getDescriptor(self):
return self.desc
__builtin__.pdb = lambda depth=0: \ __builtin__.pdb = lambda depth=0: \
debug.getPdb().set_trace(sys._getframe(depth+1)) debug.getPdb().set_trace(sys._getframe(depth+1))
...@@ -25,7 +25,7 @@ from neo.client.cache import test as testCache ...@@ -25,7 +25,7 @@ from neo.client.cache import test as testCache
from neo.client.exception import NEOStorageError, NEOStorageNotFoundError from neo.client.exception import NEOStorageError, NEOStorageNotFoundError
from neo.lib.protocol import NodeTypes, Packets, Errors, \ from neo.lib.protocol import NodeTypes, Packets, Errors, \
INVALID_PARTITION, UUID_NAMESPACES INVALID_PARTITION, UUID_NAMESPACES
from neo.lib.util import makeChecksum, SOCKET_CONNECTORS_DICT from neo.lib.util import makeChecksum
import time import time
class Dispatcher(object): class Dispatcher(object):
...@@ -95,10 +95,9 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -95,10 +95,9 @@ class ClientApplicationTests(NeoUnitTestBase):
return txn_context return txn_context
def getApp(self, master_nodes=None, name='test', **kw): def getApp(self, master_nodes=None, name='test', **kw):
connector = SOCKET_CONNECTORS_DICT[ADDRESS_TYPE]
if master_nodes is None: if master_nodes is None:
master_nodes = '%s:10010' % buildUrlFromString(self.local_ip) master_nodes = '%s:10010' % buildUrlFromString(self.local_ip)
app = Application(master_nodes, name, connector, **kw) app = Application(master_nodes, name, **kw)
self._to_stop_list.append(app) self._to_stop_list.append(app)
app.dispatcher = Mock({ }) app.dispatcher = Mock({ })
return app return app
...@@ -750,7 +749,6 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -750,7 +749,6 @@ class ClientApplicationTests(NeoUnitTestBase):
# the third will not be ready # the third will not be ready
# after the third, the partition table will be operational # after the third, the partition table will be operational
# (as if it was connected to the primary master node) # (as if it was connected to the primary master node)
from .. import DoNothingConnector
# will raise IndexError at the third iteration # will raise IndexError at the third iteration
app = self.getApp('127.0.0.1:10010 127.0.0.1:10011') app = self.getApp('127.0.0.1:10010 127.0.0.1:10011')
# TODO: test more connection failure cases # TODO: test more connection failure cases
...@@ -797,7 +795,6 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -797,7 +795,6 @@ class ClientApplicationTests(NeoUnitTestBase):
app.nm.getByAddress(conn.getAddress())._connection = None app.nm.getByAddress(conn.getAddress())._connection = None
app._ask = _ask_base app._ask = _ask_base
# faked environnement # faked environnement
app.connector_handler = DoNothingConnector
app.em = Mock({'getConnectionList': []}) app.em = Mock({'getConnectionList': []})
app.pt = Mock({ 'operational': False}) app.pt = Mock({ 'operational': False})
app.master_conn = app._connectToPrimaryNode() app.master_conn = app._connectToPrimaryNode()
......
This diff is collapsed.
...@@ -15,35 +15,12 @@ ...@@ -15,35 +15,12 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest import unittest
import socket import socket
from . import NeoUnitTestBase, IP_VERSION_FORMAT_DICT from . import NeoUnitTestBase
from neo.lib.util import ReadBuffer, getAddressType, parseNodeAddress, \ from neo.lib.util import ReadBuffer, parseNodeAddress
getConnectorFromAddress, SOCKET_CONNECTORS_DICT
class UtilTests(NeoUnitTestBase): class UtilTests(NeoUnitTestBase):
def test_getConnectorFromAddress(self):
""" Connector name must correspond to address type """
connector = getConnectorFromAddress((
IP_VERSION_FORMAT_DICT[socket.AF_INET], 0))
self.assertEqual(connector, SOCKET_CONNECTORS_DICT[socket.AF_INET])
connector = getConnectorFromAddress((
IP_VERSION_FORMAT_DICT[socket.AF_INET6], 0))
self.assertEqual(connector, SOCKET_CONNECTORS_DICT[socket.AF_INET6])
self.assertRaises(ValueError, getConnectorFromAddress, ('', 0))
self.assertRaises(ValueError, getConnectorFromAddress, ('test', 0))
def test_getAddressType(self):
""" Get the type on an IP Address """
self.assertRaises(ValueError, getAddressType, ('', 0))
address_type = getAddressType(('::1', 0))
self.assertEqual(address_type, socket.AF_INET6)
address_type = getAddressType(('0.0.0.0', 0))
self.assertEqual(address_type, socket.AF_INET)
address_type = getAddressType(('127.0.0.1', 0))
self.assertEqual(address_type, socket.AF_INET)
def test_parseNodeAddress(self): def test_parseNodeAddress(self):
""" Parsing of addesses """ """ Parsing of addesses """
def test(parsed, *args): def test(parsed, *args):
......
...@@ -35,7 +35,7 @@ from neo.lib.connector import SocketConnector, \ ...@@ -35,7 +35,7 @@ from neo.lib.connector import SocketConnector, \
ConnectorConnectionRefusedException, ConnectorTryAgainException ConnectorConnectionRefusedException, ConnectorTryAgainException
from neo.lib.event import EventManager from neo.lib.event import EventManager
from neo.lib.protocol import CellStates, ClusterStates, NodeStates, NodeTypes from neo.lib.protocol import CellStates, ClusterStates, NodeStates, NodeTypes
from neo.lib.util import SOCKET_CONNECTORS_DICT, parseMasterList, p64 from neo.lib.util import parseMasterList, p64
from .. import NeoTestBase, Patch, getTempDirectory, setupMySQLdb, \ from .. import NeoTestBase, Patch, getTempDirectory, setupMySQLdb, \
ADDRESS_TYPE, IP_VERSION_FORMAT_DICT, DB_PREFIX, DB_USER ADDRESS_TYPE, IP_VERSION_FORMAT_DICT, DB_PREFIX, DB_USER
...@@ -166,7 +166,7 @@ class SerializedEventManager(EventManager): ...@@ -166,7 +166,7 @@ class SerializedEventManager(EventManager):
class Node(object): class Node(object):
def getConnectionList(self, *peers): def getConnectionList(self, *peers):
addr = lambda c: c and (c.accepted_from or c.getAddress()) addr = lambda c: c and (c.addr if c.is_server else c.getAddress())
addr_set = {addr(c.connector) for peer in peers addr_set = {addr(c.connector) for peer in peers
for c in peer.em.connection_dict.itervalues() for c in peer.em.connection_dict.itervalues()
if isinstance(c, Connection)} if isinstance(c, Connection)}
...@@ -467,10 +467,8 @@ class ConnectionFilter(object): ...@@ -467,10 +467,8 @@ class ConnectionFilter(object):
class NEOCluster(object): class NEOCluster(object):
BaseConnection_getTimeout = staticmethod(BaseConnection.getTimeout) BaseConnection_getTimeout = staticmethod(BaseConnection.getTimeout)
SocketConnector_makeClientConnection = staticmethod( SocketConnector_bind = staticmethod(SocketConnector._bind)
SocketConnector.makeClientConnection) SocketConnector_connect = staticmethod(SocketConnector._connect)
SocketConnector_makeListeningConnection = staticmethod(
SocketConnector.makeListeningConnection)
SocketConnector_receive = staticmethod(SocketConnector.receive) SocketConnector_receive = staticmethod(SocketConnector.receive)
SocketConnector_send = staticmethod(SocketConnector.send) SocketConnector_send = staticmethod(SocketConnector.send)
_patch_count = 0 _patch_count = 0
...@@ -489,12 +487,6 @@ class NEOCluster(object): ...@@ -489,12 +487,6 @@ class NEOCluster(object):
cls._patch_count += 1 cls._patch_count += 1
if cls._patch_count > 1: if cls._patch_count > 1:
return return
def makeClientConnection(self, addr):
real_addr = ServerNode.resolv(addr)
try:
return cls.SocketConnector_makeClientConnection(self, real_addr)
finally:
self.remote_addr = addr
def send(self, msg): def send(self, msg):
result = cls.SocketConnector_send(self, msg) result = cls.SocketConnector_send(self, msg)
if type(Serialized.pending) is not frozenset: if type(Serialized.pending) is not frozenset:
...@@ -518,9 +510,10 @@ class NEOCluster(object): ...@@ -518,9 +510,10 @@ class NEOCluster(object):
# safely started even if the cluster isn't. # safely started even if the cluster isn't.
bootstrap.sleep = lambda seconds: None bootstrap.sleep = lambda seconds: None
BaseConnection.getTimeout = lambda self: None BaseConnection.getTimeout = lambda self: None
SocketConnector.makeClientConnection = makeClientConnection SocketConnector._bind = lambda self, addr: \
SocketConnector.makeListeningConnection = lambda self, addr: \ cls.SocketConnector_bind(self, BIND)
cls.SocketConnector_makeListeningConnection(self, BIND) SocketConnector._connect = lambda self, addr: \
cls.SocketConnector_connect(self, ServerNode.resolv(addr))
SocketConnector.receive = receive SocketConnector.receive = receive
SocketConnector.send = send SocketConnector.send = send
Serialized.init() Serialized.init()
...@@ -534,10 +527,8 @@ class NEOCluster(object): ...@@ -534,10 +527,8 @@ class NEOCluster(object):
return return
bootstrap.sleep = time.sleep bootstrap.sleep = time.sleep
BaseConnection.getTimeout = cls.BaseConnection_getTimeout BaseConnection.getTimeout = cls.BaseConnection_getTimeout
SocketConnector.makeClientConnection = \ SocketConnector._bind = cls.SocketConnector_bind
cls.SocketConnector_makeClientConnection SocketConnector._connect = cls.SocketConnector_connect
SocketConnector.makeListeningConnection = \
cls.SocketConnector_makeListeningConnection
SocketConnector.receive = cls.SocketConnector_receive SocketConnector.receive = cls.SocketConnector_receive
SocketConnector.send = cls.SocketConnector_send SocketConnector.send = cls.SocketConnector_send
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment