Commit 4e739de4 authored by Julien Muchembled's avatar Julien Muchembled

client: review connection locking (MTClientConnection)

This mainly changes several methods to lock automatically instead of asserting
that the caller did it. This removes any overhead for non-MT classes, and
the use of 'with' instead of lock/unlock methods also simplifies the API.
parent e438f864
...@@ -175,11 +175,8 @@ class Application(object): ...@@ -175,11 +175,8 @@ class Application(object):
handler = self.primary_handler handler = self.primary_handler
else: else:
raise ValueError, 'Unknown node type: %r' % (node.__class__, ) raise ValueError, 'Unknown node type: %r' % (node.__class__, )
conn.lock() with conn.lock:
try:
handler.dispatch(conn, packet, kw) handler.dispatch(conn, packet, kw)
finally:
conn.unlock()
def _waitAnyMessage(self, queue, block=True): def _waitAnyMessage(self, queue, block=True):
""" """
......
...@@ -26,7 +26,7 @@ class BaseHandler(EventHandler): ...@@ -26,7 +26,7 @@ class BaseHandler(EventHandler):
self.dispatcher = app.dispatcher self.dispatcher = app.dispatcher
def dispatch(self, conn, packet, kw={}): def dispatch(self, conn, packet, kw={}):
assert conn._lock._is_owned() assert conn.lock._is_owned() # XXX: see also lockCheckWrapper
super(BaseHandler, self).dispatch(conn, packet, kw) super(BaseHandler, self).dispatch(conn, packet, kw)
def packetReceived(self, conn, packet, kw={}): def packetReceived(self, conn, packet, kw={}):
......
...@@ -72,8 +72,7 @@ class ConnectionPool(object): ...@@ -72,8 +72,7 @@ class ConnectionPool(object):
"""Drop connections.""" """Drop connections."""
for conn in self.connection_dict.values(): for conn in self.connection_dict.values():
# Drop first connection which looks not used # Drop first connection which looks not used
conn.lock() with conn.lock:
try:
if not conn.pending() and \ if not conn.pending() and \
not self.app.dispatcher.registered(conn): not self.app.dispatcher.registered(conn):
del self.connection_dict[conn.getUUID()] del self.connection_dict[conn.getUUID()]
...@@ -82,8 +81,6 @@ class ConnectionPool(object): ...@@ -82,8 +81,6 @@ class ConnectionPool(object):
'storage node %s:%d closed', *conn.getAddress()) 'storage node %s:%d closed', *conn.getAddress())
if len(self.connection_dict) <= self.max_pool_size: if len(self.connection_dict) <= self.max_pool_size:
break break
finally:
conn.unlock()
def notifyFailure(self, node): def notifyFailure(self, node):
self.node_failure_dict[node.getUUID()] = time.time() + MAX_FAILURE_AGE self.node_failure_dict[node.getUUID()] = time.time() + MAX_FAILURE_AGE
......
...@@ -41,28 +41,6 @@ def not_closed(func): ...@@ -41,28 +41,6 @@ def not_closed(func):
return wraps(func)(decorator) return wraps(func)(decorator)
def lockCheckWrapper(func):
"""
This function is to be used as a wrapper around
MT(Client|Server)Connection class methods.
It uses a "_" method on RLock class, so it might stop working without
notice (sadly, RLock does not offer any "acquired" method, but that one
will do as it checks that current thread holds this lock).
It requires moniroted class to have an RLock instance in self._lock
property.
"""
def wrapper(self, *args, **kw):
if not self._lock._is_owned():
import traceback
logging.warning('%s called on %s instance without being locked.'
' Stack:\n%s', func.func_code.co_name,
self.__class__.__name__, ''.join(traceback.format_stack()))
# Call anyway
return func(self, *args, **kw)
return wraps(func)(wrapper)
class HandlerSwitcher(object): class HandlerSwitcher(object):
_next_timeout = None _next_timeout = None
_next_timeout_msg_id = None _next_timeout_msg_id = None
...@@ -250,11 +228,8 @@ class BaseConnection(object): ...@@ -250,11 +228,8 @@ class BaseConnection(object):
def checkTimeout(self, t): def checkTimeout(self, t):
pass pass
def lock(self): def lockWrapper(self, func):
return 1 return func
def unlock(self):
return None
def getConnector(self): def getConnector(self):
return self.connector return self.connector
...@@ -495,6 +470,7 @@ class Connection(BaseConnection): ...@@ -495,6 +470,7 @@ class Connection(BaseConnection):
self.analyse() self.analyse()
if self.aborted: if self.aborted:
self.em.removeReader(self) self.em.removeReader(self)
return not not self._queue
def analyse(self): def analyse(self):
"""Analyse received data.""" """Analyse received data."""
...@@ -562,8 +538,8 @@ class Connection(BaseConnection): ...@@ -562,8 +538,8 @@ class Connection(BaseConnection):
global connect_limit global connect_limit
t = time() t = time()
if t < connect_limit: if t < connect_limit:
self.checkTimeout = lambda t: t < connect_limit or \ self.checkTimeout = self.lockWrapper(lambda t:
self._delayed_closure() t < connect_limit or self._delayed_closure())
self.readable = self.writable = lambda: None self.readable = self.writable = lambda: None
else: else:
connect_limit = t + 1 connect_limit = t + 1
...@@ -707,7 +683,7 @@ class ClientConnection(Connection): ...@@ -707,7 +683,7 @@ class ClientConnection(Connection):
self.writable() self.writable()
def _connectionCompleted(self): def _connectionCompleted(self):
self.writable = super(ClientConnection, self).writable self.writable = self.lockWrapper(super(ClientConnection, self).writable)
self.connecting = False self.connecting = False
self.updateTimeout(time()) self.updateTimeout(time())
self.getHandler().connectionCompleted(self) self.getHandler().connectionCompleted(self)
...@@ -729,36 +705,53 @@ class ServerConnection(Connection): ...@@ -729,36 +705,53 @@ class ServerConnection(Connection):
self.updateTimeout(time()) self.updateTimeout(time())
class MTConnectionType(type):
def __init__(cls, *args):
if __debug__:
for name in 'analyse', 'answer':
setattr(cls, name, cls.lockCheckWrapper(name))
for name in ('close', 'checkTimeout', 'notify',
'process', 'readable', 'writable'):
setattr(cls, name, cls.__class__.lockWrapper(cls, name))
def lockCheckWrapper(cls, name):
def wrapper(self, *args, **kw):
# XXX: Unfortunately, RLock does not has any public method
# to test whether we own the lock or not.
assert self.lock._is_owned(), (self, args, kw)
return getattr(super(cls, self), name)(*args, **kw)
return wraps(getattr(cls, name).im_func)(wrapper)
def lockWrapper(cls, name):
def wrapper(self, *args, **kw):
with self.lock:
return getattr(super(cls, self), name)(*args, **kw)
return wraps(getattr(cls, name).im_func)(wrapper)
class MTClientConnection(ClientConnection): class MTClientConnection(ClientConnection):
"""A Multithread-safe version of ClientConnection.""" """A Multithread-safe version of ClientConnection."""
def __metaclass__(name, base, d): __metaclass__ = MTConnectionType
for k in ('analyse', 'answer', 'checkTimeout',
'process', 'readable', 'writable'): def lockWrapper(self, func):
d[k] = lockCheckWrapper(getattr(base[0], k).im_func) lock = self.lock
return type(name, base, d) def wrapper(*args, **kw):
with lock:
return func(*args, **kw)
return wrapper
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
# _lock is only here for lock debugging purposes. Do not use. self.lock = lock = RLock()
self._lock = lock = RLock()
self.lock = lock.acquire
self.unlock = lock.release
self.dispatcher = kwargs.pop('dispatcher') self.dispatcher = kwargs.pop('dispatcher')
self.dispatcher.needPollThread() self.dispatcher.needPollThread()
with lock: with lock:
super(MTClientConnection, self).__init__(*args, **kwargs) super(MTClientConnection, self).__init__(*args, **kwargs)
def notify(self, *args, **kw):
self.lock()
try:
return super(MTClientConnection, self).notify(*args, **kw)
finally:
self.unlock()
def ask(self, packet, timeout=CRITICAL_TIMEOUT, on_timeout=None, def ask(self, packet, timeout=CRITICAL_TIMEOUT, on_timeout=None,
queue=None, **kw): queue=None, **kw):
self.lock() with self.lock:
try:
if self.isClosed(): if self.isClosed():
raise ConnectionClosed raise ConnectionClosed
# XXX: Here, we duplicate Connection.ask because we need to call # XXX: Here, we duplicate Connection.ask because we need to call
...@@ -778,12 +771,3 @@ class MTClientConnection(ClientConnection): ...@@ -778,12 +771,3 @@ class MTClientConnection(ClientConnection):
handlers.emit(packet, timeout, on_timeout, kw) handlers.emit(packet, timeout, on_timeout, kw)
self.updateTimeout(t) self.updateTimeout(t)
return msg_id return msg_id
finally:
self.unlock()
def close(self):
self.lock()
try:
super(MTClientConnection, self).close()
finally:
self.unlock()
...@@ -92,16 +92,12 @@ class EpollEventManager(object): ...@@ -92,16 +92,12 @@ class EpollEventManager(object):
if not self._pending_processing: if not self._pending_processing:
return return
to_process = self._pending_processing.pop(0) to_process = self._pending_processing.pop(0)
to_process.lock()
try: try:
try: to_process.process()
to_process.process()
finally:
# ...and requeue if there are pending messages
if to_process.hasPendingMessages():
self._addPendingConnection(to_process)
finally: finally:
to_process.unlock() # ...and requeue if there are pending messages
if to_process.hasPendingMessages():
self._addPendingConnection(to_process)
# Non-blocking call: as we handled a packet, we should just offer # Non-blocking call: as we handled a packet, we should just offer
# poll a chance to fetch & send already-available data, but it must # poll a chance to fetch & send already-available data, but it must
# not delay us. # not delay us.
...@@ -122,12 +118,7 @@ class EpollEventManager(object): ...@@ -122,12 +118,7 @@ class EpollEventManager(object):
for fd, event in event_list: for fd, event in event_list:
if event & EPOLLIN: if event & EPOLLIN:
conn = self.connection_dict[fd] conn = self.connection_dict[fd]
conn.lock() if conn.readable():
try:
conn.readable()
finally:
conn.unlock()
if conn.hasPendingMessages():
self._addPendingConnection(conn) self._addPendingConnection(conn)
if event & EPOLLOUT: if event & EPOLLOUT:
wlist.append(fd) wlist.append(fd)
...@@ -140,11 +131,7 @@ class EpollEventManager(object): ...@@ -140,11 +131,7 @@ class EpollEventManager(object):
conn = self.connection_dict[fd] conn = self.connection_dict[fd]
except KeyError: except KeyError:
continue continue
conn.lock() conn.writable()
try:
conn.writable()
finally:
conn.unlock()
for fd in elist: for fd in elist:
# This can fail, if a connection is closed in previous calls to # This can fail, if a connection is closed in previous calls to
...@@ -153,21 +140,12 @@ class EpollEventManager(object): ...@@ -153,21 +140,12 @@ class EpollEventManager(object):
conn = self.connection_dict[fd] conn = self.connection_dict[fd]
except KeyError: except KeyError:
continue continue
conn.lock() if conn.readable():
try:
conn.readable()
finally:
conn.unlock()
if conn.hasPendingMessages():
self._addPendingConnection(conn) self._addPendingConnection(conn)
t = time() t = time()
for conn in self.connection_dict.values(): for conn in self.connection_dict.values():
conn.lock() conn.checkTimeout(t)
try:
conn.checkTimeout(t)
finally:
conn.unlock()
def addReader(self, conn): def addReader(self, conn):
connector = conn.getConnector() connector = conn.getConnector()
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import threading
import unittest import unittest
from cPickle import dumps from cPickle import dumps
from mock import Mock, ReturnValues from mock import Mock, ReturnValues
...@@ -44,6 +45,11 @@ def _getMasterConnection(self): ...@@ -44,6 +45,11 @@ def _getMasterConnection(self):
self.master_conn = Mock() self.master_conn = Mock()
return self.master_conn return self.master_conn
def getConnection(kw):
conn = Mock(kw)
conn.lock = threading.RLock()
return conn
def _ask(self, conn, packet, handler=None, **kw): def _ask(self, conn, packet, handler=None, **kw):
self.setHandlerData(None) self.setHandlerData(None)
conn.ask(packet, **kw) conn.ask(packet, **kw)
...@@ -110,7 +116,7 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -110,7 +116,7 @@ class ClientApplicationTests(NeoUnitTestBase):
makeTID = makeOID makeTID = makeOID
def getNodeCellConn(self, index=1, address=('127.0.0.1', 10000), uuid=None): def getNodeCellConn(self, index=1, address=('127.0.0.1', 10000), uuid=None):
conn = Mock({ conn = getConnection({
'getAddress': address, 'getAddress': address,
'__repr__': 'connection mock', '__repr__': 'connection mock',
'getUUID': uuid, 'getUUID': uuid,
...@@ -167,8 +173,7 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -167,8 +173,7 @@ class ClientApplicationTests(NeoUnitTestBase):
response_packet = Packets.AnswerNewOIDs(test_oid_list[:]) response_packet = Packets.AnswerNewOIDs(test_oid_list[:])
response_packet.setId(0) response_packet.setId(0)
app.master_conn = Mock({'getNextId': test_msg_id, '_addPacket': None, app.master_conn = Mock({'getNextId': test_msg_id, '_addPacket': None,
'expectMessage': None, 'lock': None, 'expectMessage': None,
'unlock': None,
# Test-specific method # Test-specific method
'fakeReceived': response_packet}) 'fakeReceived': response_packet})
new_oid = app.new_oid() new_oid = app.new_oid()
...@@ -434,12 +439,12 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -434,12 +439,12 @@ class ClientApplicationTests(NeoUnitTestBase):
packet2 = Packets.AnswerStoreObject(conflicting=1, oid=oid1, serial=tid) packet2 = Packets.AnswerStoreObject(conflicting=1, oid=oid1, serial=tid)
packet3 = Packets.AnswerStoreObject(conflicting=0, oid=oid2, serial=tid) packet3 = Packets.AnswerStoreObject(conflicting=0, oid=oid2, serial=tid)
[p.setId(i) for p, i in zip([packet1, packet2, packet3], range(3))] [p.setId(i) for p, i in zip([packet1, packet2, packet3], range(3))]
conn1 = Mock({'__repr__': 'conn1', 'getAddress': address1, conn1 = getConnection({'__repr__': 'conn1', 'getAddress': address1,
'fakeReceived': packet1, 'getUUID': uuid1}) 'fakeReceived': packet1, 'getUUID': uuid1})
conn2 = Mock({'__repr__': 'conn2', 'getAddress': address2, conn2 = getConnection({'__repr__': 'conn2', 'getAddress': address2,
'fakeReceived': packet2, 'getUUID': uuid2}) 'fakeReceived': packet2, 'getUUID': uuid2})
conn3 = Mock({'__repr__': 'conn3', 'getAddress': address3, conn3 = getConnection({'__repr__': 'conn3', 'getAddress': address3,
'fakeReceived': packet3, 'getUUID': uuid3}) 'fakeReceived': packet3, 'getUUID': uuid3})
node1 = Mock({'__repr__': 'node1', '__hash__': 1, 'getConnection': conn1}) node1 = Mock({'__repr__': 'node1', '__hash__': 1, 'getConnection': conn1})
node2 = Mock({'__repr__': 'node2', '__hash__': 2, 'getConnection': conn2}) node2 = Mock({'__repr__': 'node2', '__hash__': 2, 'getConnection': conn2})
node3 = Mock({'__repr__': 'node3', '__hash__': 3, 'getConnection': conn3}) node3 = Mock({'__repr__': 'node3', '__hash__': 3, 'getConnection': conn3})
...@@ -520,7 +525,7 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -520,7 +525,7 @@ class ClientApplicationTests(NeoUnitTestBase):
transaction_info = Packets.AnswerTransactionInformation(tid1, '', '', transaction_info = Packets.AnswerTransactionInformation(tid1, '', '',
'', False, (oid0, )) '', False, (oid0, ))
transaction_info.setId(1) transaction_info.setId(1)
conn = Mock({ conn = getConnection({
'getNextId': 1, 'getNextId': 1,
'fakeReceived': transaction_info, 'fakeReceived': transaction_info,
'getAddress': ('127.0.0.1', 10020), 'getAddress': ('127.0.0.1', 10020),
...@@ -706,7 +711,7 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -706,7 +711,7 @@ class ClientApplicationTests(NeoUnitTestBase):
}) })
asked = [] asked = []
def answerTIDs(packet): def answerTIDs(packet):
conn = Mock({'getAddress': packet}) conn = getConnection({'getAddress': packet})
app.nm.createStorage(address=conn.getAddress()) app.nm.createStorage(address=conn.getAddress())
def ask(p, queue, **kw): def ask(p, queue, **kw):
asked.append(p) asked.append(p)
......
...@@ -587,7 +587,7 @@ class ConnectionTests(NeoUnitTestBase): ...@@ -587,7 +587,7 @@ class ConnectionTests(NeoUnitTestBase):
DoNothingConnector.receive = receive DoNothingConnector.receive = receive
try: try:
bc = self._makeConnection() bc = self._makeConnection()
bc._queue = Mock() bc._queue = Mock({'__len__': 0})
self._checkReadBuf(bc, '') self._checkReadBuf(bc, '')
self.assertFalse(bc.aborted) self.assertFalse(bc.aborted)
bc.readable() bc.readable()
......
...@@ -104,13 +104,9 @@ class EventTests(NeoUnitTestBase): ...@@ -104,13 +104,9 @@ class EventTests(NeoUnitTestBase):
self.assertEqual(data, 10) self.assertEqual(data, 10)
# need to rebuild completely this test and the the packet queue # need to rebuild completely this test and the the packet queue
# check readable conn # check readable conn
#self.assertEqual(len(r_conn.mockGetNamedCalls("lock")), 1)
#self.assertEqual(len(r_conn.mockGetNamedCalls("unlock")), 1)
#self.assertEqual(len(r_conn.mockGetNamedCalls("readable")), 1) #self.assertEqual(len(r_conn.mockGetNamedCalls("readable")), 1)
#self.assertEqual(len(r_conn.mockGetNamedCalls("writable")), 0) #self.assertEqual(len(r_conn.mockGetNamedCalls("writable")), 0)
# check writable conn # check writable conn
#self.assertEqual(len(w_conn.mockGetNamedCalls("lock")), 1)
#self.assertEqual(len(w_conn.mockGetNamedCalls("unlock")), 1)
#self.assertEqual(len(w_conn.mockGetNamedCalls("readable")), 0) #self.assertEqual(len(w_conn.mockGetNamedCalls("readable")), 0)
#self.assertEqual(len(w_conn.mockGetNamedCalls("writable")), 1) #self.assertEqual(len(w_conn.mockGetNamedCalls("writable")), 1)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment