From 460cf445c0c6a49ec360381813ac25e680c0d3a7 Mon Sep 17 00:00:00 2001 From: Yoshinori Okuji <yo@nexedi.com> Date: Wed, 29 Nov 2006 04:41:15 +0000 Subject: [PATCH] Rewrite step one git-svn-id: https://svn.erp5.org/repos/neo/branches/prototype3@22 71dcc9de-d417-0410-9af5-da40c76e7ee4 --- __init__.py | 0 connection.py | 398 +++++++++++++------------------------ event.py | 112 +++++++++++ handler.py | 173 ++++++++++++++++ master/__init__.py | 0 master.py => master/app.py | 255 ++---------------------- neo.conf | 30 ++- neomaster | 17 +- node.py | 35 ++-- protocol.py | 71 ++----- 10 files changed, 514 insertions(+), 577 deletions(-) create mode 100644 __init__.py create mode 100644 event.py create mode 100644 handler.py create mode 100644 master/__init__.py rename master.py => master/app.py (80%) diff --git a/__init__.py b/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/connection.py b/connection.py index 32f841da..3e5f91d6 100644 --- a/connection.py +++ b/connection.py @@ -2,76 +2,92 @@ import socket import errno import logging from select import select -from time import time from protocol import Packet, ProtocolError +from event import IdleEvent -class IdleEvent: - """This class represents an event called when a connection is waiting for - a message too long.""" - - def __init__(self, conn, msg_id, timeout, additional_timeout): - self._conn = conn - self._id = msg_id - t = time() - self._time = t + timeout - self._critical_time = t + timeout + additional_timeout - self._additional_timeout = additional_timeout - - def getId(self): - return self._id - - def getTime(self): - return self._time - - def getCriticalTime(self): - return self._critical_time - - def __call__(self, t): - conn = self._conn - if t > self._critical_time: - logging.info('timeout with %s:%d', conn.ip_address, conn.port) - self._conn.timeoutExpired(self) - return True - elif t > self._time: - if self._additional_timeout > 10: - self._additional_timeout -= 10 - conn.expectMessage(self._id, 10, self._additional_timeout) - # Start a keep-alive packet. - logging.info('sending a ping to %s:%d', conn.ip_address, conn.port) - msg_id = conn.getNextId() - conn.addPacket(Packet().ping(msg_id)) - conn.expectMessage(msg_id, 10, 0) - else: - conn.expectMessage(self._id, self._additional_timeout, 0) - return True - return False +class BaseConnection(object): + """A base connection.""" + def __init__(self, event_manager, handler, s = None, addr = None): + self.em = event_manager + self.s = s + self.addr = addr + self.handler = handler + if s is not None: + event_manager.register(self) + def getSocket(self): + return self.s -class Connection: - """A connection.""" + def setSocket(self, s): + if self.s is not None: + raise RuntimeError, 'cannot overwrite a socket in a connection' + if s is not None: + self.s = s + self.em.register(self) - connecting = False - from_self = False - aborted = False + def getAddress(self): + return self.addr - def __init__(self, connection_manager, s = None, addr = None): - self.s = s + def readable(self): + raise NotImplementedError + + def writable(self): + raise NotImplementedError + + def getHandler(self): + return self.handler + + def setHandler(self): + self.handler = handler + + def getEventManager(self): + return self.em + +class ListeningConnection(BaseConnection): + """A listen connection.""" + def __init__(self, event_manager, handler, addr = None, **kw): + logging.info('listening to %s:%d', *addr) + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + try: + s.setblocking(0) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + s.bind(addr) + s.listen(5) + except: + s.close() + raise + BaseConnection.__init__(self, event_manager, handler, s = s, addr = addr) + self.em.addReader(self) + + def readable(self): + try: + new_s, addr = self.s.accept() + logging.info('accepted a connection from %s:%d', *addr) + self.handler.connectionAccepted(self, new_s, addr) + except socket.error, m: + if m[0] == errno.EAGAIN: + return + raise + +class Connection(BaseConnection): + """A connection.""" + def __init__(self, event_manager, handler, s = None, addr = None): + BaseConnection.__init__(self, handler, event_manager, s = s, addr = addr) if s is not None: - connection_manager.addReader(s) - self.cm = connection_manager + event_manager.addReader(self) self.read_buf = [] self.write_buf = [] self.cur_id = 0 self.event_dict = {} - if addr is None: - self.ip_address = None - self.port = None - else: - self.ip_address, self.port = addr + self.aborted = False + self.uuid = None - def getSocket(self): - return self.s + def getUUID(self): + return self.uuid + + def setUUID(self, uuid): + self.uuid = uuid def getNextId(self): next_id = self.cur_id @@ -80,46 +96,15 @@ class Connection: self.cur_id = 0 return next_id - def connect(self, ip_address, port): - """Connect to another node.""" - if self.s is not None: - raise RuntimeError, 'already connected' - - self.ip_address = ip_address - self.port = port - self.from_self = True - - try: - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - - try: - s.setblocking(0) - s.connect((ip_address, port)) - except socket.error, m: - if m[0] == errno.EINPROGRESS: - self.connecting = True - self.cm.addWriter(s) - else: - s.close() - raise - else: - self.connectionCompleted() - self.cm.addReader(s) - except socket.error: - self.connectionFailed() - return - - self.s = s - return s - def close(self): """Close the connection.""" s = self.s + em = self.em if s is not None: - logging.debug('closing a socket for %s:%d', self.ip_address, self.port) - self.cm.removeReader(s) - self.cm.removeWriter(s) - self.cm.unregister(self) + logging.debug('closing a socket for %s:%d', *(self.addr)) + em.removeReader(self) + em.removeWriter(self) + em.unregister(self) try: # This may fail if the socket is not connected. s.shutdown(socket.SHUT_RDWR) @@ -128,34 +113,22 @@ class Connection: s.close() self.s = None for event in self.event_dict.itervalues(): - self.cm.removeIdleEvent(event) + em.removeIdleEvent(event) self.event_dict.clear() def abort(self): """Abort dealing with this connection.""" - logging.debug('aborting a socket for %s:%d', self.ip_address, self.port) + logging.debug('aborting a socket for %s:%d', *(self.addr)) self.aborted = True def writable(self): """Called when self is writable.""" - if self.connecting: - err = self.s.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) - if err: - self.connectionFailed() - self.close() - return - else: - self.connecting = False - self.connectionCompleted() - self.cm.addReader(self.s) - else: - self.send() - + self.send() if not self.pending(): if self.aborted: self.close() else: - self.cm.removeWriter(self.s) + self.em.removeWriter(self) def readable(self): """Called when self is readable.""" @@ -163,7 +136,7 @@ class Connection: self.analyse() if self.aborted: - self.cm.removeReader(self.s) + self.em.removeReader(self) def analyse(self): """Analyse received data.""" @@ -177,7 +150,7 @@ class Connection: try: packet = Packet.parse(msg) except ProtocolError, m: - self.packetMalformed(*m) + self.handler.packetMalformed(self, *m) return if packet is None: @@ -188,11 +161,11 @@ class Connection: try: event = self.event_dict[msg_id] del self.event_dict[msg_id] - self.cm.removeIdleEvent(event) + self.em.removeIdleEvent(event) except KeyError: pass - self.packetReceived(packet) + self.handler.packetReceived(self, packet) msg = msg[len(packet):] if msg: @@ -210,19 +183,17 @@ class Connection: r = s.recv(4096) if not r: logging.error('cannot read') - self.connectionClosed() + self.handler.connectionClosed(self) self.close() else: self.read_buf.append(r) except socket.error, m: if m[0] == errno.EAGAIN: pass - elif m[0] == errno.ECONNRESET: - logging.error('cannot read') - self.connectionClosed() - self.close() else: - raise + logging.error('%s', m[1]) + self.handler.connectionClosed(self) + self.close() def send(self): """Send data to a socket.""" @@ -236,7 +207,7 @@ class Connection: r = s.send(msg) if not r: logging.error('cannot write') - self.connectionClosed() + self.handler.connectionClosed(self) self.close() elif r == len(msg): del self.write_buf[:] @@ -245,21 +216,24 @@ class Connection: except socket.error, m: if m[0] == errno.EAGAIN: return - raise + else: + logging.error('%s', m[1]) + self.handler.connectionClosed(self) + self.close() def addPacket(self, packet): """Add a packet into the write buffer.""" try: - self.write_buf.append(str(packet)) + self.write_buf.append(packet.encode()) except ProtocolError, m: logging.critical('trying to send a too big message') - return self.addPacket(Packet().internalError(packet.getId(), m[1])) + return self.addPacket(packet.internalError(packet.getId(), m[1])) # If this is the first time, enable polling for writing. if len(self.write_buf) == 1: - self.cm.addWriter(self.s) + self.em.addWriter(self.s) - def expectMessage(self, msg_id = None, timeout = 10, additional_timeout = 100): + def expectMessage(self, msg_id = None, timeout = 5, additional_timeout = 30): """Expect a message for a reply to a given message ID or any message. The purpose of this method is to define how much amount of time is @@ -281,139 +255,49 @@ class Connection: the callback is executed immediately.""" event = IdleEvent(self, msg_id, timeout, additional_timeout) self.event_dict[msg_id] = event - self.cm.addIdleEvent(event) - - # Hooks. - def connectionFailed(self): - """Called when a connection fails.""" - pass - - def connectionCompleted(self): - """Called when a connection is completed.""" - pass - - def connectionAccepted(self): - """Called when a connection is accepted.""" - # A request for a node identification should arrive. - self.expectMessage(timeout = 10, additional_timeout = 0) - - def connectionClosed(self): - """Called when a connection is closed.""" - pass - - def timeoutExpired(self): - """Called when a timeout event occurs.""" - self.close() - - def peerBroken(self): - """Called when a peer is broken.""" - pass - - def packetReceived(self, packet): - """Called when a packet is received.""" - pass - - def packetMalformed(self, packet, error_message): - """Called when a packet is malformed.""" - logging.info('malformed packet: %s', error_message) - self.addPacket(Packet().protocolError(packet.getId(), error_message)) - self.abort() - self.peerBroken() - -class ConnectionManager: - """This class manages connections and sockets.""" - - def __init__(self, app = None, connection_klass = Connection): - self.listening_socket = None - self.connection_dict = {} - self.reader_set = set([]) - self.writer_set = set([]) - self.exc_list = [] - self.app = app - self.klass = connection_klass - self.event_list = [] - self.prev_time = time() - - def listen(self, ip_address, port): - logging.info('listening to %s:%d', ip_address, port) - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - s.setblocking(0) - s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - s.bind((ip_address, port)) - s.listen(5) - self.listening_socket = s - self.reader_set.add(s) - - def getConnectionList(self): - return self.connection_dict.values() - - def register(self, conn): - self.connection_dict[conn.getSocket()] = conn - - def unregister(self, conn): - del self.connection_dict[conn.getSocket()] - - def connect(self, ip_address, port): - logging.info('connecting to %s:%d', ip_address, port) - conn = self.klass(self) - if conn.connect(ip_address, port) is not None: - self.register(conn) - - def poll(self, timeout = 1): - rlist, wlist, xlist = select(self.reader_set, self.writer_set, self.exc_list, - timeout) - for s in rlist: - if s == self.listening_socket: - try: - new_s, addr = s.accept() - logging.info('accepted a connection from %s:%d', addr[0], addr[1]) - conn = self.klass(self, new_s, addr) - self.register(conn) - conn.connectionAccepted() - except socket.error, m: - if m[0] == errno.EAGAIN: - continue - raise - else: - conn = self.connection_dict[s] - conn.readable() - - for s in wlist: - conn = self.connection_dict[s] - conn.writable() - - # Check idle events. Do not check them out too often, because this - # is somehow heavy. - event_list = self.event_list - if event_list: - t = time() - if t - self.prev_time >= 1: - self.prev_time = t - event_list.sort(key = lambda event: event.getTime()) - for event in tuple(event_list): - if event(t): - event_list.pop(0) - else: - break - - def addIdleEvent(self, event): - self.event_list.append(event) - - def removeIdleEvent(self, event): + self.em.addIdleEvent(event) + +class ClientConnection(Connection): + """A connection from this node to a remote node.""" + def __init__(self, event_manager, handler, addr = None, **kw): + Connection.__init__(self, event_manager, handler, addr = addr) + self.connecting = False + handler.connectionStarted(self) try: - self.event_list.remove(event) - except ValueError: - pass - - def addReader(self, s): - self.reader_set.add(s) - - def removeReader(self, s): - self.reader_set.discard(s) + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.setSocket(s) - def addWriter(self, s): - self.writer_set.add(s) + try: + s.setblocking(0) + s.connect(addr) + except socket.error, m: + if m[0] == errno.EINPROGRESS: + self.connecting = True + event_manager.addWriter(self) + else: + raise + else: + self.handler.connectionCompleted() + event_manager.addReader(self) + except: + handler.connectionFailed(self) + self.close() - def removeWriter(self, s): - self.writer_set.discard(s) + def writable(self): + """Called when self is writable.""" + if self.connecting: + err = self.s.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + if err: + self.connectionFailed() + self.close() + return + else: + self.connecting = False + self.handler.connectionCompleted(self) + self.cm.addReader(self.s) + else: + Connection.writable(self) +class ServerConnection(Connection): + """A connection from a remote node to this node.""" + pass diff --git a/event.py b/event.py new file mode 100644 index 00000000..c24a758a --- /dev/null +++ b/event.py @@ -0,0 +1,112 @@ +import logging +from select import select +from time import time + +class IdleEvent(object): + """This class represents an event called when a connection is waiting for + a message too long.""" + + def __init__(self, conn, msg_id, timeout, additional_timeout): + self._conn = conn + self._id = msg_id + t = time() + self._time = t + timeout + self._critical_time = t + timeout + additional_timeout + self._additional_timeout = additional_timeout + + def getId(self): + return self._id + + def getTime(self): + return self._time + + def getCriticalTime(self): + return self._critical_time + + def __call__(self, t): + conn = self._conn + if t > self._critical_time: + logging.info('timeout with %s:%d', *(conn.getAddress())) + conn.getHandler().timeoutExpired(conn) + conn.close() + return True + elif t > self._time: + if self._additional_timeout > 5: + self._additional_timeout -= 5 + conn.expectMessage(self._id, 5, self._additional_timeout) + # Start a keep-alive packet. + logging.info('sending a ping to %s:%d', *(conn.getAddress())) + msg_id = conn.getNextId() + conn.addPacket(Packet().ping(msg_id)) + conn.expectMessage(msg_id, 5, 0) + else: + conn.expectMessage(self._id, self._additional_timeout, 0) + return True + return False + +class EventManager(object): + """This class manages connections and events.""" + + def __init__(self): + self.connection_dict = {} + self.reader_set = set([]) + self.writer_set = set([]) + self.exc_list = [] + self.event_list = [] + self.prev_time = time() + + def getConnectionList(self): + return self.connection_dict.values() + + def register(self, conn): + self.connection_dict[conn.getSocket()] = conn + + def unregister(self, conn): + del self.connection_dict[conn.getSocket()] + + def poll(self, timeout = 1): + rlist, wlist, xlist = select(self.reader_set, self.writer_set, self.exc_list, + timeout) + for s in rlist: + conn = self.connection_dict[s] + conn.readable() + + for s in wlist: + conn = self.connection_dict[s] + conn.writable() + + # Check idle events. Do not check them out too often, because this + # is somehow heavy. + event_list = self.event_list + if event_list: + t = time() + if t - self.prev_time >= 1: + self.prev_time = t + event_list.sort(key = lambda event: event.getTime()) + for event in tuple(event_list): + if event(t): + event_list.pop(0) + else: + break + + def addIdleEvent(self, event): + self.event_list.append(event) + + def removeIdleEvent(self, event): + try: + self.event_list.remove(event) + except ValueError: + pass + + def addReader(self, conn): + self.reader_set.add(conn.getSocket()) + + def removeReader(self, conn): + self.reader_set.discard(conn.getSocket()) + + def addWriter(self, conn): + self.writer_set.add(conn.getSocket()) + + def removeWriter(self, conn): + self.writer_set.discard(conn.getSocket()) + diff --git a/handler.py b/handler.py new file mode 100644 index 00000000..90812deb --- /dev/null +++ b/handler.py @@ -0,0 +1,173 @@ +import logging + +from protocol import Packet, ProtocolError +from connection import ServerConnection + +from protocol import ERROR, REQUEST_NODE_IDENTIFICATION, ACCEPT_NODE_IDENTIFICATION, \ + PING, PONG, ASK_PRIMARY_MASTER, ANSWER_PRIMARY_MASTER, ANNOUNCE_PRIMARY_MASTER, \ + REELECT_PRIMARY_MASTER, NOTIFY_NODE_INFORMATION, START_OPERATION, \ + STOP_OPERATION, ASK_FINISHING_TRANSACTIONS, ANSWER_FINISHING_TRANSACTIONS, \ + FINISH_TRANSACTIONS, \ + NOT_READY_CODE, OID_NOT_FOUND_CODE, SERIAL_NOT_FOUND_CODE, TID_NOT_FOUND_CODE, \ + PROTOCOL_ERROR_CODE, TIMEOUT_ERROR_CODE, BROKEN_NODE_DISALLOWED_CODE, \ + INTERNAL_ERROR_CODE + +class EventHandler(object): + """This class handles events.""" + def __init__(self): + self.initPacketDispatchTable() + self.initErrorDispatchTable() + + def connectionStarted(self, conn): + """Called when a connection is started.""" + pass + + def connectionCompleted(self, conn): + """Called when a connection is completed.""" + pass + + def connectionFailed(self, conn): + """Called when a connection failed.""" + pass + + def connectionAccepted(self, conn, s, addr): + """Called when a connection is accepted.""" + new_conn = ServerConnection(conn.getEventManager(), conn.getHandler(), + s = s, addr = addr) + # A request for a node identification should arrive. + new_conn.expectMessage(timeout = 10, additional_timeout = 0) + + def timeoutExpired(self, conn): + """Called when a timeout event occurs.""" + pass + + def connectionClosed(self, conn): + """Called when a connection is closed by the peer.""" + pass + + def packetReceived(self, conn, packet): + """Called when a packet is received.""" + self.dispatch(conn, packet) + + def packetMalformed(self, conn, packet, error_message): + """Called when a packet is malformed.""" + logging.info('malformed packet: %s', error_message) + conn.addPacket(Packet().protocolError(packet.getId(), error_message)) + conn.abort() + self.peerBroken(conn) + + def peerBroken(self, conn): + """Called when a peer is broken.""" + logging.error('%s:%d is broken', *(conn.getAddress())) + + def dispatch(self, conn, packet): + """This is a helper method to handle various packet types.""" + t = packet.getType() + try: + method = self.packet_dispatch_table[t] + args = packet.decode() + method(conn, packet, *args) + except ValueError: + self.handleUnexpectedPacket(conn, packet) + except ProtocolError, m: + self.packetMalformed(conn, packet, m[1]) + + def handleUnexpectedPacket(self, conn, packet, message = None): + """Handle an unexpected packet.""" + if message is None: + message = 'unexpected packet type %d' % packet.getType() + else: + message = 'unexpected packet: ' + message + logging.info('%s', message) + conn.addPacket(Packet().protocolError(packet.getId(), message)) + conn.abort() + self.peerBroken(conn) + + # Packet handlers. + + def handleError(self, conn, packet, code, message): + try: + method = self.error_dispatch_table[code] + method(conn, packet, message) + except ValueError: + self.handleUnexpectedPacket(conn, packet, message) + + def handleRequestNodeIdentification(self, conn, packet, node_type, + uuid, ip_address, port, name): + self.handleUnexpectedPacket(conn, packet) + + def handleAcceptNodeIdentification(self, conn, packet, node_type, + uuid, ip_address, port): + self.handleUnexpectedPacket(conn, packet) + + def handlePing(self, conn, packet): + logging.info('got a ping packet; am I overloaded?') + conn.addPacket(Packet().pong(packet.getId())) + + def handlePong(self, conn, packet): + pass + + def handleAskPrimaryNode(self, conn, packet): + self.handleUnexpectedPacket(conn, packet) + + def handleAnswerPrimaryNode(self, conn, packet, primary_uuid, known_master_list): + self.handleUnexpectedPacket(conn, packet) + + def handleAnnouncePrimaryMaster(self, conn, packet): + self.handleUnexpectedPacket(conn, packet) + + def handleReelectPrimaryMaster(self, conn, packet): + self.handleUnexpectedPacket(conn, packet) + + def handleNotifyNodeInformation(self, conn, packet, node_list): + self.handleUnexpectedPacket(conn, packet) + + # Error packet handlers. + + handleNotReady = handleUnexpectedPacket + handleOidNotFound = handleUnexpectedPacket + handleSerialNotFound = handleUnexpectedPacket + handleTidNotFound = handleUnexpectedPacket + + def handleProtocolError(self, conn, packet, message): + raise RuntimeError, 'protocol error: %s' % (message,) + + def handleTimeoutError(self, conn, packet, message): + raise RuntimeError, 'timeout error: %s' % (message,) + + def handleBrokenNodeDisallowedError(self, conn, packet, message): + raise RuntimeError, 'broken node disallowed error: %s' % (message,) + + def handleInternalError(self, conn, packet, message): + self.peerBroken(conn) + conn.close() + + def initPacketDispatchTable(self): + d = {} + + d[ERROR] = self.handleError + d[REQUEST_NODE_IDENTIFICATION] = self.handleRequestNodeIdentification + d[ACCEPT_NODE_IDENTIFICATION] = self.handleAcceptNodeIdentification + d[PING] = self.handlePing + d[PONG] = self.handlePong + d[ASK_PRIMARY_MASTER] = self.handleAskPrimaryMaster + d[ANSWER_PRIMARY_MASTER] = self.handleAnswerPrimaryMaster + d[ANNOUNCE_PRIMARY_MASTER] = self.handleAnnouncePrimaryMaster + d[REELECT_PRIMARY_MASTER] = self.handleReelectPrimaryMaster + d[NOTIFY_NODE_INFORMATION] = self.handleNotifyNodeInformation + + self.packet_dispatch_table = d + + def initErrorDispatchTable(self): + d = {} + + d[NOT_READY_CODE] = self.handleNotReady + d[OID_NOT_FOUND_CODE] = self.handleOidNotFound + d[SERIAL_NOT_FOUND_CODE] = self.handleSerialNotFound + d[TID_NOT_FOUND_CODE] = self.handleTidNotFound + d[PROTOCOL_ERROR_CODE] = self.handleProtocolError + d[TIMEOUT_ERROR_CODE] = self.handleTimeoutError + d[BROKEN_NODE_DISALLOWED_CODE] = self.handleBrokenNodeDisallowedError + d[INTERNAL_ERROR_CODE] = self.handleInternalError + + self.error_dispatch_table = d diff --git a/master/__init__.py b/master/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/master.py b/master/app.py similarity index 80% rename from master.py rename to master/app.py index 184237ed..b8b12be2 100644 --- a/master.py +++ b/master/app.py @@ -5,204 +5,70 @@ from socket import inet_aton from time import time from connection import ConnectionManager -from connection import Connection as BaseConnection -from database import DatabaseManager from config import ConfigurationManager from protocol import Packet, ProtocolError, \ - MASTER_NODE_TYPE, STORAGE_NODE_TYPE, CLIENT_NODE_TYPE, \ - INVALID_UUID, INVALID_TID, INVALID_OID, \ - PROTOCOL_ERROR_CODE, TIMEOUT_ERROR_CODE, BROKEN_NODE_DISALLOWED_CODE, \ - INTERNAL_ERROR_CODE, \ - ERROR, REQUEST_NODE_IDENTIFICATION, ACCEPT_NODE_IDENTIFICATION, \ - PING, PONG, ASK_PRIMARY_MASTER, ANSWER_PRIMARY_MASTER, \ - ANNOUNCE_PRIMARY_MASTER, REELECT_PRIMARY_MASTER -from node import NodeManager, MasterNode, StorageNode, ClientNode, \ RUNNING_STATE, TEMPORARILY_DOWN_STATE, DOWN_STATE, BROKEN_STATE +from node import NodeManager, MasterNode, StorageNode, ClientNode +from handler import EventHandler +from event import EventManager from util import dump class NeoException(Exception): pass class ElectionFailure(NeoException): pass class PrimaryFailure(NeoException): pass -class RecoveryFailure(NeoException): pass -class Connection(BaseConnection): - """This class provides a master-specific connection.""" - - _uuid = None - - def setUUID(self, uuid): - self._uuid = uuid - - def getUUID(self): - return self._uuid - - # Feed callbacks to the master node. - def connectionFailed(self): - self.cm.app.connectionFailed(self) - BaseConnection.connectionFailed(self) - - def connectionCompleted(self): - self.cm.app.connectionCompleted(self) - BaseConnection.connectionCompleted(self) - - def connectionAccepted(self): - self.cm.app.connectionAccepted(self) - BaseConnection.connectionAccepted(self) - - def connectionClosed(self): - self.cm.app.connectionClosed(self) - BaseConnection.connectionClosed(self) - - def packetReceived(self, packet): - self.cm.app.packetReceived(self, packet) - BaseConnection.packetReceived(self, packet) - - def timeoutExpired(self): - self.cm.app.timeoutExpired(self) - BaseConnection.timeoutExpired(self) - - def peerBroken(self): - self.cm.app.peerBroken(self) - BaseConnection.peerBroken(self) - class Application(object): """The master node application.""" def __init__(self, file, section): config = ConfigurationManager(file, section) - self.database = config.getDatabase() - self.user = config.getUser() - self.password = config.getPassword() - logging.debug('database is %s, user is %s, password is %s', - self.database, self.user, self.password) - self.num_replicas = config.getReplicas() self.num_partitions = config.getPartitions() self.name = config.getName() logging.debug('the number of replicas is %d, the number of partitions is %d, the name is %s', self.num_replicas, self.num_partitions, self.name) - self.ip_address, self.port = config.getServer() - logging.debug('IP address is %s, port is %d', self.ip_address, self.port) + self.server = config.getServer() + logging.debug('IP address is %s, port is %d', *(self.server)) # Exclude itself from the list. - self.master_node_list = [n for n in config.getMasterNodeList() - if n != (self.ip_address, self.port)] + self.master_node_list = [n for n in config.getMasterNodeList() if n != self.server] logging.debug('master nodes are %s', self.master_node_list) # Internal attributes. - self.dm = DatabaseManager(self.database, self.user, self.password) - self.cm = ConnectionManager(app = self, connection_klass = Connection) + self.em = EventManager() self.nm = NodeManager() self.primary = None self.primary_master_node = None - self.ready = False - - # Co-operative threads. Simulated by generators. - self.thread_dict = {} - self.server_thread_method = None - self.event = None - - def initializeDatabase(self): - """Initialize a database by recreating all the tables. - - In master nodes, the database is used only to make - some data persistent. All operations are executed on memory. - Thus it is not necessary to make indices on the tables.""" - q = self.dm.query - e = MySQLdb.escape_string - - q("""DROP TABLE IF EXISTS loid, ltid, self, stn, part""") - - q("""CREATE TABLE loid ( - oid BINARY(8) NOT NULL - ) ENGINE = InnoDB COMMENT = 'Last Object ID'""") - q("""INSERT loid VALUES ('%s')""" % e(INVALID_OID)) - - q("""CREATE TABLE ltid ( - tid BINARY(8) NOT NULL - ) ENGINE = InnoDB COMMENT = 'Last Transaction ID'""") - q("""INSERT ltid VALUES ('%s')""" % e(INVALID_TID)) - - q("""CREATE TABLE self ( - uuid BINARY(16) NOT NULL - ) ENGINE = InnoDB COMMENT = 'UUID'""") - # XXX Generate an UUID for self. For now, just use a random string. # Avoid an invalid UUID. while 1: uuid = os.urandom(16) if uuid != INVALID_UUID: break + self.uuid = uuid - q("""INSERT self VALUES ('%s')""" % e(uuid)) - - q("""CREATE TABLE stn ( - nid INT UNSIGNED NOT NULL UNIQUE, - uuid BINARY(16) NOT NULL UNIQUE, - state CHAR(1) NOT NULL - ) ENGINE = InnoDB COMMENT = 'Storage Nodes'""") - - q("""CREATE TABLE part ( - pid INT UNSIGNED NOT NULL, - nid INT UNSIGNED NOT NULL, - state CHAR(1) NOT NULL - ) ENGINE = InnoDB COMMENT = 'Partition Table'""") - - def loadData(self): - """Load persistent data from a database.""" - logging.info('loading data from MySQL') - q = self.dm.query - result = q("""SELECT oid FROM loid""") - if len(result) != 1: - raise RuntimeError, 'the table loid has %d rows' % len(result) - self.loid = result[0][0] - logging.info('the last OID is %r' % dump(self.loid)) - - result = q("""SELECT tid FROM ltid""") - if len(result) != 1: - raise RuntimeError, 'the table ltid has %d rows' % len(result) - self.ltid = result[0][0] - logging.info('the last TID is %r' % dump(self.ltid)) - - result = q("""SELECT uuid FROM self""") - if len(result) != 1: - raise RuntimeError, 'the table self has %d rows' % len(result) - self.uuid = result[0][0] - logging.info('the UUID is %r' % dump(self.uuid)) - - # FIXME load storage and partition information here. - + self.loid = INVALID_OID + self.ltid = INVALID_TID def run(self): """Make sure that the status is sane and start a loop.""" - # Sanity checks. - logging.info('checking the database') - result = self.dm.query("""SHOW TABLES""") - table_set = set([r[0] for r in result]) - existing_table_list = [t for t in ('loid', 'ltid', 'self', 'stn', 'part') - if t in table_set] - if len(existing_table_list) == 0: - # Possibly this is the first time to launch... - self.initializeDatabase() - elif len(existing_table_list) != 5: - raise RuntimeError, 'database inconsistent' - - # XXX More tests are necessary (e.g. check the table structures, - # check the number of partitions, etc.). - - # Now ready to load persistent data from the database. - self.loadData() - - for ip_address, port in self.master_node_list: - self.nm.add(MasterNode(ip_address = ip_address, port = port)) + if self.num_replicas <= 0: + raise RuntimeError, 'replicas must be more than zero' + if self.num_partitions <= 0: + raise RuntimeError, 'partitions must be more than zero' + if len(self.name) == 0: + raise RuntimeError, 'cluster name must be non-empty' + + for server in self.master_node_list: + self.nm.add(MasterNode(server = server)) # Make a listening port. - self.cm.listen(self.ip_address, self.port) + ListeningConnection(self.em, None, addr = self.server) # Start the election of a primary master node. self.electPrimary() @@ -212,95 +78,18 @@ class Application(object): try: if self.primary: while 1: - try: - self.startRecovery() - except RecoveryFailure: - logging.critical('unable to recover the system; use full recovery') - raise + self.startRecovery() self.playPrimaryRole() else: self.playSecondaryRole() raise RuntimeError, 'should not reach here' except (ElectionFailure, PrimaryFailure): # Forget all connections. - for conn in cm.getConnectionList(): + for conn in self.em.getConnectionList(): conn.close() - self.thread_dict.clear() # Reelect a new primary master. self.electPrimary(bootstrap = False) - CONNECTION_FAILED = 'connection failed' - def connectionFailed(self, conn): - addr = (conn.ip_address, conn.port) - t = self.thread_dict[addr] - self.event = (self.CONNECTION_FAILED, conn) - try: - t.next() - except StopIteration: - del self.thread_dict[addr] - - CONNECTION_COMPLETED = 'connection completed' - def connectionCompleted(self, conn): - addr = (conn.ip_address, conn.port) - t = self.thread_dict[addr] - self.event = (self.CONNECTION_COMPLETED, conn) - try: - t.next() - except StopIteration: - del self.thread_dict[addr] - - CONNECTION_CLOSED = 'connection closed' - def connectionClosed(self, conn): - addr = (conn.ip_address, conn.port) - t = self.thread_dict[addr] - self.event = (self.CONNECTION_CLOSED, conn) - try: - t.next() - except StopIteration: - del self.thread_dict[addr] - - CONNECTION_ACCEPTED = 'connection accepted' - def connectionAccepted(self, conn): - addr = (conn.ip_address, conn.port) - logging.debug('making a server thread for %s:%d', conn.ip_address, conn.port) - t = self.server_thread_method() - self.thread_dict[addr] = t - self.event = (self.CONNECTION_ACCEPTED, conn) - try: - t.next() - except StopIteration: - del self.thread_dict[addr] - - TIMEOUT_EXPIRED = 'timeout expired' - def timeoutExpired(self, conn): - addr = (conn.ip_address, conn.port) - t = self.thread_dict[addr] - self.event = (self.TIMEOUT_EXPIRED, conn) - try: - t.next() - except StopIteration: - del self.thread_dict[addr] - - PEER_BROKEN = 'peer broken' - def peerBroken(self, conn): - addr = (conn.ip_address, conn.port) - t = self.thread_dict[addr] - self.event = (self.PEER_BROKEN, conn) - try: - t.next() - except StopIteration: - del self.thread_dict[addr] - - PACKET_RECEIVED = 'packet received' - def packetReceived(self, conn, packet): - addr = (conn.ip_address, conn.port) - t = self.thread_dict[addr] - self.event = (self.PACKET_RECEIVED, conn, packet) - try: - t.next() - except StopIteration: - del self.thread_dict[addr] - def electPrimaryClientIterator(self): """Handle events for a client connection.""" # The first event. This must be a connection failure or connection completion. diff --git a/neo.conf b/neo.conf index c4b167d0..736e71ad 100644 --- a/neo.conf +++ b/neo.conf @@ -1,21 +1,31 @@ +# Default parameters. [DEFAULT] +# The list of master nodes. master_nodes: 127.0.0.1:10010 127.0.0.1:10011 127.0.0.1:10012 -#replicas: 1 -#partitions: 1009 -#name: main +# The number of replicas. +replicas: 1 +# The number of partitions. +partitions: 1009 +# The name of this cluster. +name: main +# The user name for the database. +user: neo +# The password for the database. +password: neo +# The first master. [master1] -database: master1 -user: neo -#password: server: 127.0.0.1:10010 +# The second master. [master2] -database: master2 -user: neo server: 127.0.0.1:10011 +# The third master. [master3] -database: master3 -user: neo server: 127.0.0.1:10012 + +# The first storage. +[storage1] +database: neo1 +server: 127.0.0.1:10020 diff --git a/neomaster b/neomaster index e3360c0a..d404ebad 100755 --- a/neomaster +++ b/neomaster @@ -19,15 +19,13 @@ # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. from optparse import OptionParser -from master import Application +from master.app import Application import logging -# FIXME should be configurable -logging.basicConfig(level = logging.DEBUG) parser = OptionParser() -parser.add_option('-i', '--initialize', action = 'store_true', - help = 'initialize the database') +parser.add_option('-v', '--verbose', action = 'store_true', + help = 'print verbose messages') parser.add_option('-c', '--config', help = 'specify a configuration file') parser.add_option('-s', '--section', help = 'specify a configuration section') @@ -36,9 +34,10 @@ parser.add_option('-s', '--section', help = 'specify a configuration section') config = options.config or 'neo.conf' section = options.section or 'master' -app = Application(config, section) - -if options.initialize: - app.initializeDatabase() +if options.verbose: + logging.basicConfig(level = logging.DEBUG) +else: + logging.basicConfig(level = logging.ERROR) +app = Application(config, section) app.run() diff --git a/node.py b/node.py index 58eb37a1..ccdb3675 100644 --- a/node.py +++ b/node.py @@ -1,14 +1,14 @@ from time import time -from protocol import RUNNING_STATE, TEMPORARILY_DOWN_STATE, DOWN_STATE, BROKEN_STATE +from protocol import RUNNING_STATE, TEMPORARILY_DOWN_STATE, DOWN_STATE, BROKEN_STATE, \ + MASTER_NODE_TYPE, STORAGE_NODE_TYPE, CLIENT_NODE_TYPE class Node(object): """This class represents a node.""" - def __init__(self, ip_address = None, port = None, uuid = None): + def __init__(self, server = None, uuid = None): self.state = RUNNING_STATE - self.ip_address = ip_address - self.port = port + self.server = server self.uuid = uuid self.manager = None self.last_state_change = time() @@ -27,16 +27,15 @@ class Node(object): self.state = new_state self.last_state_change = time() - def setServer(self, ip_address, port): - if self.ip_address is not None: + def setServer(self, server): + if self.server is not None: self.manager.unregisterServer(self) - self.ip_address = ip_address - self.port = port + self.server = server self.manager.registerServer(self) def getServer(self): - return self.ip_address, self.port + return self.server def setUUID(self, uuid): if self.uuid is not None: @@ -48,17 +47,23 @@ class Node(object): def getUUID(self): return self.uuid + def getNodeType(self): + raise NotImplementedError + class MasterNode(Node): """This class represents a master node.""" - pass + def getNodeType(self): + return MASTER_NODE_TYPE class StorageNode(Node): """This class represents a storage node.""" - pass + def getNodeType(self): + return STORAGE_NODE_TYPE class ClientNode(Node): """This class represents a client node.""" - pass + def getNodeType(self): + return CLIENT_NODE_TYPE class NodeManager(object): """This class manages node status.""" @@ -71,7 +76,7 @@ class NodeManager(object): def add(self, node): node.setManager(self) self.node_list.append(node) - if node.getServer()[0] is not None: + if node.getServer() is not None: self.registerServer(node) if node.getUUID() is not None: self.registerUUID(node) @@ -113,8 +118,8 @@ class NodeManager(object): def getClientNodeList(self): return self.getNodeList(filter = lambda node: isinstance(node, ClientNode)) - def getNodeByServer(self, ip_address, port): - return self.server_dict.get((ip_address, port)) + def getNodeByServer(self, server): + return self.server_dict.get(server) def getNodeByUUID(self, uuid): return self.uuid_dict.get(uuid) diff --git a/protocol.py b/protocol.py index 3b5445e4..7bc9c9f0 100644 --- a/protocol.py +++ b/protocol.py @@ -21,14 +21,12 @@ ASK_PRIMARY_MASTER = 0x0003 ANSWER_PRIMARY_MASTER = 0x8003 ANNOUNCE_PRIMARY_MASTER = 0x0004 REELECT_PRIMARY_MASTER = 0x0005 -NOTIFY_NODE_STATE_CHANGE = 0x0006 -SEND_NODE_INFORMATION = 0x0007 -START_OPERATION = 0x0008 -STOP_OPERATION = 0x0009 -ASK_FINISHING_TRANSACTIONS = 0x000a -ANSWER_FINISHING_TRANSACTIONS = 0x800a -FINISH_TRANSACTIONS = 0x000b - +NOTIFY_NODE_INFORMATION = 0x0006 +START_OPERATION = 0x0007 +STOP_OPERATION = 0x0008 +ASK_FINISHING_TRANSACTIONS = 0x0009 +ANSWER_FINISHING_TRANSACTIONS = 0x8009 +FINISH_TRANSACTIONS = 0x000a # Error codes. NOT_READY_CODE = 1 @@ -63,7 +61,7 @@ INVALID_OID = '\0\0\0\0\0\0\0\0' class ProtocolError(Exception): pass -class Packet: +class Packet(object): """A packet.""" _id = None @@ -152,16 +150,16 @@ class Packet: self._body = pack('!H16s4sH', node_type, uuid, inet_aton(ip_address), port) return self - def askPrimaryMaster(self, msg_id, ltid, loid): + def askPrimaryMaster(self, msg_id): self._id = msg_id self._type = ASK_PRIMARY_MASTER - self._body = ltid + loid + self._body = '' return self - def answerPrimaryMaster(self, msg_id, ltid, loid, primary_uuid, known_master_list): + def answerPrimaryMaster(self, msg_id, primary_uuid, known_master_list): self._id = msg_id self._type = ANSWER_PRIMARY_MASTER - body = [ltid, loid, primary_uuid, pack('!L', len(known_master_list))] + body = [primary_uuid, pack('!L', len(known_master_list))] for master in known_master_list: body.append(pack('!4sH16s', inet_aton(master[0]), master[1], master[2])) self._body = ''.join(body) @@ -179,21 +177,9 @@ class Packet: self._body = '' return self - def notifyNodeStateChange(self, msg_id, node_type, ip_address, port, uuid, state): - self._id = msg_id - self._type = NOTIFY_NODE_STATE_CHANGE - self._body = pack('!H4sH16sH', node_type, inet_aton(ip_address), port, uuid, state) - return self - - def askNodeInformation(self, msg_id): - self._id = msg_id - self._type = ASK_NODE_INFORMATION - self._body = '' - return self - - def answerNodeInformation(self, msg_id, node_list): + def notifyNodeInformation(self, msg_id, node_list): self._id = msg_id - self._type = ANSWER_NODE_INFORMATION + self._type = NOTIFY_NODE_INFORMATION body = [pack('!L', len(node_list))] for node_type, ip_address, port, uuid, state in node_list: body.append(pack('!H4sH16sH', node_type, inet_aton(ip_address), port, @@ -261,16 +247,12 @@ class Packet: decode_table[ACCEPT_NODE_IDENTIFICATION] = _decodeAcceptNodeIdentification def _decodeAskPrimaryMaster(self): - try: - ltid, loid = unpack('!8s8s', self._body) - except: - raise ProtocolError(self, 'invalid ask primary master') - return ltid, loid + pass decode_table[ASK_PRIMARY_MASTER] = _decodeAskPrimaryMaster def _decodeAnswerPrimaryMaster(self): try: - ltid, loid, primary_uuid, n = unpack('!8s8s16sL', self._body[:36]) + primary_uuid, n = unpack('!16sL', self._body[:36]) known_master_list = [] for i in xrange(n): ip_address, port, uuid = unpack('!4sH16s', self._body[36+i*22:58+i*22]) @@ -278,7 +260,7 @@ class Packet: known_master_list.append((ip_address, port, uuid)) except: raise ProtocolError(self, 'invalid answer primary master') - return ltid, loid, primary_uuid, known_master_list + return primary_uuid, known_master_list decode_table[ANSWER_PRIMARY_MASTER] = _decodeAnswerPrimaryMaster def _decodeAnnouncePrimaryMaster(self): @@ -289,24 +271,7 @@ class Packet: pass decode_table[REELECT_PRIMARY_MASTER] = _decodeReelectPrimaryMaster - def _decodeNotifyNodeStateChange(self): - try: - node_type, ip_address, port, uuid, state = unpack('!H4sH16sH', self._body[:26]) - ip_address = inet_ntoa(ip_address) - except: - raise ProtocolError(self, 'invalid notify node state change') - if node_type not in VALID_NODE_TYPE_LIST: - raise ProtocolError(self, 'invalid node type %d' % node_type) - if state not in VALID_NODE_STATE_LIST: - raise ProtocolError(self, 'invalid node state %d' % state) - return node_type, ip_address, port, uuid, state - decode_table[NOTIFY_NODE_STATE_CHANGE] = _decodeNotifyNodeStateChange - - def _decodeAskNodeInformation(self): - pass - decode_table[ASK_NODE_INFORMATION] = _decodeAskNodeInformation - - def _decodeAnswerNodeInformation(self): + def _decodeNotifyNodeInformation(self): try: n = unpack('!L', self._body[:4]) node_list = [] @@ -324,4 +289,4 @@ class Packet: except: raise ProtocolError(self, 'invalid answer node information') return node_list - decode_table[ANSWER_NODE_INFORMATION] = _decodeAnswerNodeInformation + decode_table[NOTIFY_NODE_INFORMATION] = _decodeNotifyNodeInformation -- 2.30.9