Commit edefaca7 authored by Julien Muchembled

client: add support for reconnection to master

This implements proper cache invalidation.

Connection to the master is also made optional for loading from storage
nodes, as long as the partition table contains up-to-date data (which is
unlikely to change anyway while there is no master).
parent 1ea04be7
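In outline, the reconnection flow works like this (a minimal sketch, not the actual handler; the real logic is in the PrimaryNotificationsHandler hunk below, and cache locking is omitted here): when the first AnswerLastTransaction after reconnecting reports a last TID different from the one known before the disconnection, every revision cached as "current" may be stale, so it is dropped and the ZODB layer is told which oids to invalidate.

    from neo.lib.util import add64   # packed-TID arithmetic, as imported below

    def on_answer_last_transaction(app, ltid):
        # Called with the last TID reported by the (new) primary master.
        if app.last_tid != ltid:
            if app.master_conn is None:   # first answer after a disconnection
                # Drop all "current" revisions from the client cache and
                # propagate the invalidated oids to the ZODB layer.
                oid_list = app._cache.clear_current()
                db = app.getDB()
                if db is not None:
                    # invalidate everything committed after the old last_tid
                    db.invalidate(app.last_tid and add64(app.last_tid, 1),
                                  oid_list)
            app.last_tid = ltid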
@@ -68,7 +68,6 @@ class Application(object):
self.dispatcher = Dispatcher(self.poll_thread)
self.nm = NodeManager(dynamic_master_list)
self.cp = ConnectionPool(self)
self.pt = None
self.master_conn = None
self.primary_master_node = None
self.trying_master_node = None
@@ -117,6 +116,16 @@ class Application(object):
self.compress = compress
registerLiveDebugger(on_log=self.log)
def __getattr__(self, attr):
if attr == 'pt':
self._getMasterConnection()
return self.__getattribute__(attr)
@property
def txn_contexts(self):
# do not iterate lazily, to avoid a race condition
return self._txn_container.values
def getHandlerData(self):
return self._thread_container.answer
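The `__getattr__` hook above makes the partition table lazy: the hook only runs when normal attribute lookup fails, i.e. while `self.pt` is still unset, so the first access establishes the master connection (which fills in `pt`) and every later access bypasses the hook entirely. A toy demonstration of the pattern (hypothetical `Lazy` class, not part of the codebase):

    class Lazy(object):
        def _connect(self):
            # stands in for _getMasterConnection(), which sets self.pt
            self.pt = 'partition table'
        def __getattr__(self, attr):
            # only reached when normal lookup fails, i.e. 'pt' not set yet
            if attr == 'pt':
                self._connect()
            return self.__getattribute__(attr)

    app = Lazy()
    assert app.pt == 'partition table'   # first access triggers _connect()
    assert 'pt' in app.__dict__          # later accesses are plain lookups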
@@ -241,13 +250,6 @@ class Application(object):
result = self.master_conn = self._connectToPrimaryNode()
return result
def getPartitionTable(self):
""" Return the partition table manager, reconnect the PMN if needed """
# this ensures the master connection is established and the partition
# table is up to date.
self._getMasterConnection()
return self.pt
def _connectToPrimaryNode(self):
"""
Look up the current primary master node
@@ -660,7 +662,6 @@ class Application(object):
ttid = txn_context['ttid']
# Store data on each node
txn_stored_counter = 0
assert not txn_context['data_dict'], txn_context
packet = Packets.AskStoreTransaction(ttid, str(transaction.user),
str(transaction.description), dumps(transaction._extension),
@@ -674,20 +675,17 @@ class Application(object):
except ConnectionClosed:
continue
add_involved_nodes(node)
txn_stored_counter += 1
# check at least one storage node accepted
if txn_stored_counter == 0:
if txn_context['involved_nodes']:
txn_context['voted'] = None
# We must not go further if connection to master was lost since
# tpc_begin, to lower the probability of failing during tpc_finish.
if 'error' in txn_context:
raise NEOStorageError(txn_context['error'])
return result
logging.error('tpc_vote failed')
raise NEOStorageError('tpc_vote failed')
# Check if master connection is still alive.
# This is just here to lower the probability of detecting a problem
# in tpc_finish, as we should do our best to detect problems before
# tpc_finish.
self._getMasterConnection()
txn_context['txn_voted'] = True
return result
def tpc_abort(self, transaction):
"""Abort current transaction."""
@@ -718,7 +716,7 @@ class Application(object):
def tpc_finish(self, transaction, tryToResolveConflict, f=None):
"""Finish current transaction."""
txn_container = self._txn_container
if not txn_container.get(transaction)['txn_voted']:
if 'voted' not in txn_container.get(transaction):
self.tpc_vote(transaction, tryToResolveConflict)
self._load_lock_acquire()
try:
@@ -735,15 +733,13 @@ class Application(object):
def undo(self, undone_tid, txn, tryToResolveConflict):
txn_context = self._txn_container.get(txn)
txn_info, txn_ext = self._getTransactionInformation(undone_tid)
txn_oid_list = txn_info['oids']
# Regroup objects per partition, to query a minimal set of storage nodes.
partition_oid_dict = {}
pt = self.getPartitionTable()
for oid in txn_oid_list:
partition = pt.getPartition(oid)
partition = self.pt.getPartition(oid)
try:
oid_list = partition_oid_dict[partition]
except KeyError:
@@ -752,7 +748,7 @@ class Application(object):
# Ask each storage node for the undo serial (the serial at which the
# object's previous data is)
getCellList = pt.getCellList
getCellList = self.pt.getCellList
getCellSortKey = self.cp.getCellSortKey
getConnForCell = self.cp.getConnForCell
queue = self._thread_container.queue
@@ -838,11 +834,10 @@ class Application(object):
# First get a list of transactions from all storage nodes.
# Each storage node will return TIDs only for UP_TO_DATE state and
# FEEDING state cells
pt = self.getPartitionTable()
queue = self._thread_container.queue
packet = Packets.AskTIDs(first, last, INVALID_PARTITION)
tid_set = set()
for storage_node in pt.getNodeSet(True):
for storage_node in self.pt.getNodeSet(True):
conn = self.cp.getConnForNode(storage_node)
if conn is None:
continue
@@ -235,6 +235,22 @@ class ClientCache(object):
else:
assert item.next_tid <= tid, (item, oid, tid)
def clear_current(self):
oid_list = []
for oid, item_list in self._oid_dict.items():
item = item_list[-1]
if item.next_tid is None:
self._remove(item)
del item_list[-1]
# We don't preserve statistics of removed items. This could be
# done easily when previous versions are cached, by copying
# counters, but it would not be fair for other oids, so it's
# probably not worth it.
if not item_list:
del self._oid_dict[oid]
oid_list.append(oid)
return oid_list
def test(self):
cache = ClientCache()
@@ -250,7 +266,11 @@ def test(self):
data = '15', 15, None
cache.store(1, *data)
self.assertEqual(cache.load(1, None), data)
self.assertEqual(cache.clear_current(), [1])
self.assertEqual(cache.load(1, None), None)
cache.store(1, *data)
cache.invalidate(1, 20)
self.assertEqual(cache.clear_current(), [])
self.assertEqual(cache.load(1, 20), ('15', 15, 20))
cache.store(1, '10', 10, 15)
cache.store(1, '20', 20, 21)
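Condensed from the test above, a small usage sketch of the new clear_current semantics (illustrative values): only entries whose next_tid is None, i.e. the current revisions, are dropped, and their oids are returned so the caller can propagate invalidations; historical entries survive.

    cache = ClientCache()
    cache.store(1, 'cur', 15, None)      # current revision of oid 1
    cache.store(2, 'old', 10, None)      # current revision of oid 2 ...
    cache.invalidate(2, 20)              # ... made historical at tid 20
    assert cache.clear_current() == [1]  # only the still-current entry goes
    assert cache.load(1, None) is None   # oid 1 is gone from the cache
    assert cache.load(2, 20) == ('old', 10, 20)   # history is kept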
@@ -99,7 +99,6 @@ class TransactionContainer(dict):
'object_stored_counter_dict': {},
'conflict_serial_dict': {},
'resolved_conflict_serial_dict': {},
'txn_voted': False,
'involved_nodes': set(),
}
return context
@@ -17,7 +17,7 @@
from neo.lib import logging
from neo.lib.pt import MTPartitionTable as PartitionTable
from neo.lib.protocol import NodeStates, Packets, ProtocolError
from neo.lib.util import dump
from neo.lib.util import dump, add64
from . import BaseHandler, AnswerBaseHandler
from ..exception import NEOStorageError
@@ -96,7 +96,20 @@ class PrimaryNotificationsHandler(BaseHandler):
def packetReceived(self, conn, packet, kw={}):
if type(packet) is Packets.AnswerLastTransaction:
self.app.last_tid = packet.decode()[0]
app = self.app
ltid = packet.decode()[0]
if app.last_tid != ltid:
if app.master_conn is None:
app._cache_lock_acquire()
try:
oid_list = app._cache.clear_current()
db = app.getDB()
if db is not None:
db.invalidate(app.last_tid and
add64(app.last_tid, 1), oid_list)
finally:
app._cache_lock_release()
app.last_tid = ltid
elif type(packet) is Packets.AnswerTransactionFinished:
app = self.app
app.last_tid = tid = packet.decode()[1]
@@ -125,8 +138,11 @@ class PrimaryNotificationsHandler(BaseHandler):
def connectionClosed(self, conn):
app = self.app
if app.master_conn is not None:
logging.critical("connection to primary master node closed")
msg = "connection to primary master node closed"
logging.critical(msg)
app.master_conn = None
for txn_context in app.txn_contexts():
txn_context['error'] = msg
app.primary_master_node = None
super(PrimaryNotificationsHandler, self).connectionClosed(conn)
@@ -151,10 +167,6 @@ class PrimaryNotificationsHandler(BaseHandler):
finally:
app._cache_lock_release()
# For the two methods below, we must not use app._getPartitionTable()
# to avoid a deadlock. It is safe not to check the master connection
# because we are in the master handler, so the connection is already
# established.
def notifyPartitionChanges(self, conn, ptid, cell_list):
if self.app.pt.filled():
self.app.pt.update(ptid, cell_list, self.app.nm)
@@ -107,7 +107,7 @@ class ConnectionPool(object):
def iterateForObject(self, object_id, readable=False):
""" Iterate over nodes managing an object """
pt = self.app.getPartitionTable()
pt = self.app.pt
if type(object_id) is str:
object_id = pt.getPartition(object_id)
cell_list = pt.getCellList(object_id, readable)
@@ -44,11 +44,6 @@ def _getMasterConnection(self):
self.master_conn = Mock()
return self.master_conn
def getPartitionTable(self):
if self.pt is None:
self.master_conn = _getMasterConnection(self)
return self.pt
def _ask(self, conn, packet, handler=None, **kw):
self.setHandlerData(None)
conn.ask(packet, **kw)
@@ -71,10 +66,8 @@ class ClientApplicationTests(NeoUnitTestBase):
# apply monkey patches
self._getMasterConnection = Application._getMasterConnection
self._ask = Application._ask
self.getPartitionTable = Application.getPartitionTable
Application._getMasterConnection = _getMasterConnection
Application._ask = _ask
Application.getPartitionTable = getPartitionTable
self._to_stop_list = []
def _tearDown(self, success):
@@ -82,9 +75,8 @@ class ClientApplicationTests(NeoUnitTestBase):
for app in self._to_stop_list:
app.close()
# restore environment
Application._getMasterConnection = self._getMasterConnection
Application._ask = self._ask
Application.getPartitionTable = self.getPartitionTable
Application._getMasterConnection = self._getMasterConnection
NeoUnitTestBase._tearDown(self, success)
# some helpers
@@ -499,7 +491,7 @@ class ClientApplicationTests(NeoUnitTestBase):
'getAddress': ('127.0.0.1', 10010),
'fakeReceived': packet,
})
txn_context['txn_voted'] = True
txn_context['voted'] = None
app.tpc_finish(txn, None)
self.checkAskFinishTransaction(app.master_conn)
#self.checkDispatcherRegisterCalled(app, app.master_conn)
@@ -76,8 +76,8 @@ class ConnectionPoolTests(NeoUnitTestBase):
def test_iterateForObject_noStorageAvailable(self):
# no node available
oid = self.getOID(1)
pt = Mock({'getCellList': []})
app = Mock({'getPartitionTable': pt})
app = Mock()
app.pt = Mock({'getCellList': []})
pool = ConnectionPool(app)
self.assertRaises(NEOStorageError, pool.iterateForObject(oid).next)
@@ -87,8 +87,8 @@ class ConnectionPoolTests(NeoUnitTestBase):
node = Mock({'__repr__': 'node', 'isRunning': True})
cell = Mock({'__repr__': 'cell', 'getNode': node})
conn = Mock({'__repr__': 'conn'})
pt = Mock({'getCellList': [cell]})
app = Mock({'getPartitionTable': pt})
app = Mock()
app.pt = Mock({'getCellList': [cell]})
pool = ConnectionPool(app)
pool.getConnForNode = Mock({'__call__': ReturnValues(None, conn)})
self.assertEqual(list(pool.iterateForObject(oid)), [(node, conn)])
@@ -99,8 +99,8 @@ class ConnectionPoolTests(NeoUnitTestBase):
node = Mock({'__repr__': 'node', 'isRunning': True})
cell = Mock({'__repr__': 'cell', 'getNode': node})
conn = Mock({'__repr__': 'conn'})
pt = Mock({'getCellList': [cell]})
app = Mock({'getPartitionTable': pt})
app = Mock()
app.pt = Mock({'getCellList': [cell]})
pool = ConnectionPool(app)
pool.getConnForNode = Mock({'__call__': conn})
self.assertEqual(list(pool.iterateForObject(oid)), [(node, conn)])
@@ -29,7 +29,8 @@ class MasterHandlerTests(NeoUnitTestBase):
def setUp(self):
super(MasterHandlerTests, self).setUp()
self.db = Mock()
self.app = Mock({'getDB': self.db})
self.app = Mock({'getDB': self.db,
'txn_contexts': ()})
self.app.nm = NodeManager()
self.app.dispatcher = Mock()
self._next_port = 3000
@@ -649,6 +649,44 @@ class Test(NEOThreadedTest):
finally:
cluster.stop()
def testClientReconnection(self):
cluster = NEOCluster()
try:
cluster.start()
t1, c1 = cluster.getTransaction()
c1.root()['x'] = x1 = PCounter()
c1.root()['y'] = y = PCounter()
y.value = 1
t1.commit()
x = c1._storage.load(x1._p_oid)[0]
y = c1._storage.load(y._p_oid)[0]
# close connections to master & storage
c, = cluster.master.nm.getClientList()
c.getConnection().close()
c, = cluster.storage.nm.getClientList()
c.getConnection().close()
cluster.tic()
# modify x with another client
client = ClientApplication(name=cluster.name,
master_nodes=cluster.master_nodes)
cluster.client.setPoll(0)
client.setPoll(1)
txn = transaction.Transaction()
client.tpc_begin(txn)
client.store(x1._p_oid, x1._p_serial, y, '', txn)
tid = client.tpc_finish(txn, None)
client.close()
client.setPoll(0)
cluster.client.setPoll(1)
t1.begin()
self.assertEqual(x1._p_changed, None)
self.assertEqual(x1.value, 1)
finally:
cluster.stop()
def testInvalidTTID(self):
cluster = NEOCluster()
try: