connection.py 9.64 KB
Newer Older
Yoshinori Okuji's avatar
Yoshinori Okuji committed
1 2 3 4 5 6
import socket
import errno
import logging
from select import select

from protocol import Packet, ProtocolError
Yoshinori Okuji's avatar
Yoshinori Okuji committed
7
from event import IdleEvent
Yoshinori Okuji's avatar
Yoshinori Okuji committed
8

Yoshinori Okuji's avatar
Yoshinori Okuji committed
9 10 11 12 13 14 15 16 17
class BaseConnection(object):
    """State shared by every connection type.

    Holds the event manager, the socket (which may be attached later via
    setSocket), the address and the handler that receives callbacks. The
    connection registers itself with the event manager as soon as it owns
    a socket. Subclasses must implement readable() and writable().
    """
    def __init__(self, event_manager, handler, s=None, addr=None):
        self.em = event_manager
        self.s = s
        self.addr = addr
        self.handler = handler
        if s is not None:
            event_manager.register(self)

    def getSocket(self):
        """Return the attached socket, or None if there is none."""
        return self.s

    def setSocket(self, s):
        """Attach socket s and register this connection with the event manager.

        Passing None is a no-op.

        Raises:
            RuntimeError: a socket is already attached.
        """
        if self.s is not None:
            raise RuntimeError('cannot overwrite a socket in a connection')
        if s is not None:
            self.s = s
            self.em.register(self)

    def getAddress(self):
        """Return the address given at construction time (may be None)."""
        return self.addr

    def readable(self):
        """Handle readability of the socket; must be overridden."""
        raise NotImplementedError

    def writable(self):
        """Handle writability of the socket; must be overridden."""
        raise NotImplementedError

    def getHandler(self):
        """Return the callback handler."""
        return self.handler

    def setHandler(self, handler):
        """Replace the callback handler.

        The original signature lacked the ``handler`` parameter, so any call
        failed with NameError.
        """
        self.handler = handler

    def getEventManager(self):
        """Return the event manager this connection belongs to."""
        return self.em
class ListeningConnection(BaseConnection):
    """A non-blocking TCP listener that hands accepted sockets to its handler.

    Every accepted socket is passed to
    ``handler.connectionAccepted(self, new_socket, peer_addr)``.
    """

    # accept() errors meaning "nothing to accept right now": EAGAIN and
    # EWOULDBLOCK for an empty queue, EINTR for a signal, ECONNABORTED when
    # the peer gave up before we accepted. Wait for the next readiness event.
    _RETRY_ERRNOS = (errno.EAGAIN, errno.EWOULDBLOCK, errno.EINTR,
                     errno.ECONNABORTED)

    def __init__(self, event_manager, handler, addr=None, **kw):
        """Bind to addr, start listening and register for read events.

        addr is a (host, port) pair. It is required despite the None default
        because it is unpacked for logging and passed to bind(). Extra
        keyword arguments are accepted and ignored.
        """
        logging.info('listening to %s:%d', *addr)
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            s.setblocking(0)
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            s.bind(addr)
            s.listen(5)
        except BaseException:
            # Never leak the descriptor if setup fails, whatever the reason.
            s.close()
            raise
        BaseConnection.__init__(self, event_manager, handler, s=s, addr=addr)
        self.em.addReader(self)

    def readable(self):
        """Accept one pending connection and pass it to the handler."""
        # Only accept() is guarded: errors raised by the handler itself must
        # propagate rather than be mistaken for "nothing to accept".
        try:
            new_s, addr = self.s.accept()
        except socket.error as m:
            if m.args[0] in self._RETRY_ERRNOS:
                return
            raise
        logging.info('accepted a connection from %s:%d', *addr)
        self.handler.connectionAccepted(self, new_s, addr)

class Connection(BaseConnection):
    """A packet-oriented connection over a stream socket.

    Received bytes accumulate in read_buf and are split into packets by
    analyse(). Each packet is passed to handler.packetReceived(). Outgoing
    packets are encoded into write_buf by addPacket() and flushed by send()
    from writable(). Idle events created by expectMessage() are kept in
    event_dict, keyed by the expected message ID (None = any message).
    """

    # recv()/send() errno values that mean "nothing can be done right now".
    # The operation is simply retried on the next readiness event.
    _RETRY_ERRNOS = (errno.EAGAIN, errno.EWOULDBLOCK, errno.EINTR)

    def __init__(self, event_manager, handler, s=None, addr=None):
        # Build all per-connection state before the connection becomes
        # visible to the event manager.
        self.read_buf = []
        self.write_buf = []
        self.cur_id = 0
        self.event_dict = {}
        self.aborted = False
        self.uuid = None
        # BaseConnection's signature is (event_manager, handler, ...).
        BaseConnection.__init__(self, event_manager, handler, s=s, addr=addr)
        if s is not None:
            event_manager.addReader(self)

    def getUUID(self):
        return self.uuid

    def setUUID(self, uuid):
        self.uuid = uuid

    def getNextId(self):
        """Return the next message ID, wrapping from 0xffff back to 0."""
        next_id = self.cur_id
        self.cur_id += 1
        if self.cur_id > 0xffff:
            self.cur_id = 0
        return next_id

    def close(self):
        """Close the connection and cancel its pending idle events.

        Does nothing if the connection is already closed.
        """
        s = self.s
        if s is None:
            return
        em = self.em
        logging.debug('closing a socket for %s:%d', *self.addr)
        em.removeReader(self)
        em.removeWriter(self)
        em.unregister(self)
        try:
            # This may fail if the socket is not connected.
            s.shutdown(socket.SHUT_RDWR)
        except socket.error:
            pass
        s.close()
        self.s = None
        # Iterate over a snapshot: removeIdleEvent is external code.
        for event in list(self.event_dict.values()):
            em.removeIdleEvent(event)
        self.event_dict.clear()

    def abort(self):
        """Abort dealing with this connection.

        Reading stops at the next readable() call. The connection is closed
        once the write buffer has been flushed (see writable()).
        """
        logging.debug('aborting a socket for %s:%d', *self.addr)
        self.aborted = True

    def writable(self):
        """Called when self is writable: flush output, then stop polling."""
        self.send()
        if self.s is None:
            # send() closed the connection, which already removed us from
            # the event manager.
            return
        if not self.pending():
            if self.aborted:
                self.close()
            else:
                self.em.removeWriter(self)

    def readable(self):
        """Called when self is readable: receive, then dispatch packets."""
        self.recv()
        self.analyse()
        # If recv() closed the connection, close() already removed the reader.
        if self.aborted and self.s is not None:
            self.em.removeReader(self)

    def analyse(self):
        """Parse complete packets out of read_buf and dispatch them in order.

        An incomplete trailing packet stays buffered for the next call. On
        a parse error, handler.packetMalformed() is called. Only the
        unparsed remainder (starting at the bad packet) is kept, so packets
        already dispatched are never delivered twice.
        """
        if not self.read_buf:
            return
        if len(self.read_buf) == 1:
            msg = self.read_buf[0]
        else:
            msg = ''.join(self.read_buf)

        while True:
            try:
                packet = Packet.parse(msg)
            except ProtocolError as m:
                self.read_buf = [msg] if msg else []
                self.handler.packetMalformed(self, *m.args)
                return

            if packet is None:
                # Not enough data for a complete packet yet.
                break

            # A message arrived: cancel both the "any message" (None)
            # expectation and the one waiting for this particular ID.
            for msg_id in (None, packet.getId()):
                event = self.event_dict.pop(msg_id, None)
                if event is not None:
                    self.em.removeIdleEvent(event)

            self.handler.packetReceived(self, packet)
            msg = msg[len(packet):]

        if msg:
            self.read_buf = [msg]
        else:
            del self.read_buf[:]

    def pending(self):
        """Return True if the connection is open and has unsent data."""
        return self.s is not None and len(self.write_buf) != 0

    def recv(self):
        """Receive up to 4096 bytes from the socket into read_buf.

        End of stream or a hard socket error calls handler.connectionClosed()
        and closes the connection. Transient errors (see _RETRY_ERRNOS) are
        ignored.
        """
        try:
            r = self.s.recv(4096)
        except socket.error as m:
            if m.args[0] in self._RETRY_ERRNOS:
                return
            logging.error('%s', m.args[-1])
            self.handler.connectionClosed(self)
            self.close()
            return
        if not r:
            logging.error('cannot read')
            self.handler.connectionClosed(self)
            self.close()
        else:
            self.read_buf.append(r)

    def send(self):
        """Send as much of write_buf as the socket accepts.

        Bytes the kernel did not take stay buffered for the next call. A
        zero-length send or a hard socket error calls
        handler.connectionClosed() and closes the connection.
        """
        if not self.write_buf:
            return
        if len(self.write_buf) == 1:
            msg = self.write_buf[0]
        else:
            msg = ''.join(self.write_buf)
        try:
            r = self.s.send(msg)
        except socket.error as m:
            if m.args[0] in self._RETRY_ERRNOS:
                return
            logging.error('%s', m.args[-1])
            self.handler.connectionClosed(self)
            self.close()
            return
        if not r:
            logging.error('cannot write')
            self.handler.connectionClosed(self)
            self.close()
        elif r == len(msg):
            del self.write_buf[:]
        else:
            # Partial send: keep only the bytes that were NOT sent.
            self.write_buf = [msg[r:]]

    def addPacket(self, packet):
        """Encode packet into the write buffer and enable write polling.

        If encoding fails with ProtocolError (the code treats this as
        "message too big"), the internal-error packet built by
        packet.internalError() is queued instead.
        """
        try:
            self.write_buf.append(packet.encode())
        except ProtocolError as m:
            logging.critical('trying to send a too big message')
            return self.addPacket(packet.internalError(packet.getId(),
                                                       m.args[1]))

        # If this is the first queued chunk, enable polling for writing.
        # Like every other event-manager call here, pass the connection
        # rather than its socket.
        if len(self.write_buf) == 1:
            self.em.addWriter(self)

    def expectMessage(self, msg_id=None, timeout=5, additional_timeout=30):
        """Expect a message for a reply to a given message ID or any message.

        The purpose of this method is to define how much time is acceptable
        to wait for a message, in order to detect a down or broken peer.
        This is important, because otherwise one error may halt a whole
        cluster. Although TCP defines a keep-alive feature, its timeout is
        generally too long, and it does not detect a certain type of reply,
        so it is better to probe for problems at the application level.

        The message ID specifies which ID is expected. Usually this should
        be identical to the ID of a request message. If it is None, any
        message is acceptable, so it can be used to check idle time.

        The timeout is the amount of time to wait until keep-alive messages
        start. Once the timeout expires, the connection starts to ping the
        peer.

        The additional timeout defines the amount of time after the timeout
        before the timeoutExpired callback is invoked. If it is zero, no
        ping is sent, and the callback is executed immediately.

        A previous expectation for the same msg_id is cancelled and replaced.
        Otherwise its idle event would stay registered and fire later.
        """
        old_event = self.event_dict.pop(msg_id, None)
        if old_event is not None:
            self.em.removeIdleEvent(old_event)
        event = IdleEvent(self, msg_id, timeout, additional_timeout)
        self.event_dict[msg_id] = event
        self.em.addIdleEvent(event)

class ClientConnection(Connection):
    """A connection from this node to a remote node.

    The constructor creates the socket and starts a non-blocking connect().
    If the connect completes immediately, handler.connectionCompleted() is
    called at once. If it is in progress, completion is detected in
    writable(). Failures are reported via handler.connectionFailed() and the
    connection is closed.
    """
    def __init__(self, event_manager, handler, addr=None, **kw):
        Connection.__init__(self, event_manager, handler, addr=addr)
        self.connecting = False
        handler.connectionStarted(self)
        try:
            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            self.setSocket(s)
            s.setblocking(0)
            try:
                s.connect(addr)
            except socket.error as m:
                if m.args[0] != errno.EINPROGRESS:
                    raise
                connected = False
            else:
                connected = True
        except Exception:
            # Setup failures are reported to the handler rather than raised
            # from the constructor. They are logged so programming errors
            # are not silently disguised as connection failures.
            logging.exception('cannot connect to %r', addr)
            handler.connectionFailed(self)
            self.close()
            return

        # Kept outside the try above: a handler error after a successful
        # connect is not a connection failure.
        if connected:
            self.handler.connectionCompleted(self)
            event_manager.addReader(self)
        else:
            # The result of the pending connect is reported by writability.
            self.connecting = True
            event_manager.addWriter(self)

    def writable(self):
        """Called when self is writable.

        While a connect is in progress, writability means it has finished.
        SO_ERROR tells whether it succeeded. Otherwise this defers to
        Connection.writable() to flush outgoing data.
        """
        if not self.connecting:
            Connection.writable(self)
            return
        err = self.s.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
        if err:
            logging.error('connection to %r failed: %s',
                          self.addr, errno.errorcode.get(err, err))
            self.handler.connectionFailed(self)
            self.close()
            return
        self.connecting = False
        self.handler.connectionCompleted(self)
        self.em.addReader(self)
Yoshinori Okuji's avatar
Yoshinori Okuji committed
300

Yoshinori Okuji's avatar
Yoshinori Okuji committed
301 302 303
class ServerConnection(Connection):
    """Connection that a remote node opened towards this node.

    It adds nothing to Connection; the separate type only records the
    direction in which the connection was established.
    """