Commit 6dffb894 authored by Vincent Pelletier's avatar Vincent Pelletier

master: Forbid truncature before database's first transaction

This is intended as a sanity check, so simple typos in neoctl truncate
command do not easily lead to the entire database being wiped.
parent f70a688c
......@@ -26,7 +26,7 @@ except ImportError:
# The protocol version must be increased whenever upgrading a node may require
# to upgrade other nodes.
PROTOCOL_VERSION = 4
PROTOCOL_VERSION = 5
# By encoding the handshake packet with msgpack, the whole NEO stream can be
# decoded with msgpack. The first byte is 0x92, which is different from TLS
# Handshake (0x16).
......
......@@ -41,7 +41,7 @@ class MasterHandler(EventHandler):
def askLastIDs(self, conn):
tm = self.app.tm
conn.answer(Packets.AnswerLastIDs(tm.getLastTID(), tm.getLastOID()))
conn.answer(Packets.AnswerLastIDs(tm.getLastTID(), tm.getLastOID(), tm.getFirstTID()))
def askLastTransaction(self, conn):
conn.answer(Packets.AnswerLastTransaction(
......
......@@ -239,6 +239,10 @@ class AdministrationHandler(MasterHandler):
app = self.app
if app.getLastTransaction() <= tid:
raise AnswerDenied("Truncating after last transaction does nothing")
first_tid = app.tm.getFirstTID()
if first_tid is None or first_tid > tid:
raise AnswerDenied("Truncating before first transaction is "
"probably not what you intended to do")
if app.pm.getApprovedRejected(add64(tid, 1))[0]:
# TODO: The protocol must be extended to support safe cases
# (e.g. no started pack whose id is after truncation tid).
......
......@@ -179,6 +179,7 @@ class TransactionManager(EventQueue):
self._ttid_dict = {}
self._last_oid = ZERO_OID
self._last_tid = ZERO_TID
self._first_tid = None
# queue filled with ttids pointing to transactions with increasing tids
self._queue = deque()
......@@ -212,6 +213,14 @@ class TransactionManager(EventQueue):
self._last_oid = oid_list[-1]
return oid_list
def setFirstTID(self, tid):
first_tid = self._first_tid
if first_tid is None or first_tid > tid:
self._first_tid = tid
def getFirstTID(self):
return self._first_tid
def setLastOID(self, oid):
if self._last_oid < oid:
self._last_oid = oid
......@@ -420,7 +429,10 @@ class TransactionManager(EventQueue):
is required is when some storages are already busy by other tasks.
"""
queue = self._queue
self._on_commit(self._ttid_dict.pop(queue.popleft()))
txn = self._ttid_dict.pop(queue.popleft())
if self._first_tid is None:
self._first_tid = txn.getTID()
self._on_commit(txn)
while queue:
ttid = queue[0]
txn = self._ttid_dict[ttid]
......
......@@ -139,11 +139,12 @@ class VerificationManager(BaseServiceHandler):
def notifyPackCompleted(self, conn, pack_id):
self.app.nm.getByUUID(conn.getUUID()).completed_pack_id = pack_id
def answerLastIDs(self, conn, ltid, loid):
def answerLastIDs(self, conn, ltid, loid, ftid):
self._uuid_set.remove(conn.getUUID())
tm = self.app.tm
tm.setLastTID(ltid)
tm.setLastOID(loid)
tm.setFirstTID(ftid)
def answerPackOrders(self, conn, pack_list):
self._uuid_set.remove(conn.getUUID())
......
......@@ -137,7 +137,7 @@ class NeoCTL(BaseApplication):
response = self.__ask(Packets.AskLastIDs())
if response[0] != Packets.AnswerLastIDs:
raise RuntimeError(response)
return response[1:]
return response[1:3]
def getLastTransaction(self):
response = self.__ask(Packets.AskLastTransaction())
......
......@@ -601,6 +601,9 @@ class ImporterDatabaseManager(DatabaseManager):
zodb = self.zodb[bisect(self.zodb_index, oid) - 1]
return zodb, oid - zodb.shift_oid
def getFirstTID(self):
return self.db.getFirstTID()
def getLastIDs(self):
tid, oid = self.db.getLastIDs()
return (max(tid, util.p64(self.zodb_ltid)),
......
......@@ -758,6 +758,22 @@ class DatabaseManager(object):
# XXX: Consider splitting getLastIDs/_getLastIDs because
# sometimes the last oid is not wanted.
def _getFirstTID(self, partition):
"""Return tid of first transaction in given 'partition'
tids are in unpacked format.
"""
@requires(_getFirstTID)
def getFirstTID(self):
"""Return tud of first transaction
tids are in unpacked format.
"""
x = self._readable_set
if x:
return min(self._getFirstTID(x) for x in x)
def _getLastTID(self, partition, max_tid=None):
"""Return tid of last transaction <= 'max_tid' in given 'partition'
......
......@@ -457,6 +457,14 @@ class MySQLDatabaseManager(MVCCDatabaseManager):
def _getPartitionTable(self):
return self.query("SELECT * FROM pt")
def _getFirstTID(self, partition):
tid_list = self.query(
"SELECT MIN(tid) as t FROM trans FORCE INDEX (PRIMARY) "
"WHERE `partition`=%s" % partition)
if tid_list:
(tid, ), = tid_list
return tid
def _getLastTID(self, partition, max_tid=None):
x = "WHERE `partition`=%s" % partition
if max_tid:
......
......@@ -343,6 +343,13 @@ class SQLiteDatabaseManager(DatabaseManager):
def _getPartitionTable(self):
return self.query("SELECT * FROM pt")
def _getFirstTID(self, partition):
try:
return self.query("SELECT MIN(tid) FROM trans WHERE partition=?",
(partition,)).fetchone()[0]
except TypeError:
pass
def _getLastTID(self, partition, max_tid=None):
x = self.query
if max_tid is None:
......
......@@ -55,7 +55,8 @@ class InitializationHandler(BaseMasterHandler):
if packed:
self.app.completed_pack_id = pack_id = min(packed.itervalues())
conn.send(Packets.NotifyPackCompleted(pack_id))
conn.answer(Packets.AnswerLastIDs(*dm.getLastIDs()))
last_tid, last_oid = dm.getLastIDs()
conn.answer(Packets.AnswerLastIDs(last_tid, last_oid, dm.getFirstTID()))
def askPartitionTable(self, conn):
pt = self.app.pt
......
......@@ -13,7 +13,7 @@ AnswerFetchObjects(?p64,?p64,{:})
AnswerFetchTransactions(?p64,[],?p64)
AnswerFinalTID(p64)
AnswerInformationLocked(p64)
AnswerLastIDs(?p64,?p64)
AnswerLastIDs(?p64,?p64,?p64)
AnswerLastTransaction(p64)
AnswerLockedTransactions({p64:?p64})
AnswerMonitorInformation([?bin],[?bin],bin)
......
......@@ -183,6 +183,7 @@ class StorageDBTests(NeoUnitTestBase):
txn1, objs1 = self.getTransaction([oid1])
txn2, objs2 = self.getTransaction([oid2])
# nothing in database
self.assertEqual(self.db.getFirstTID(), None)
self.assertEqual(self.db.getLastIDs(), (None, None))
self.assertEqual(self.db.getUnfinishedTIDDict(), {})
self.assertEqual(self.db.getObject(oid1), None)
......@@ -199,6 +200,7 @@ class StorageDBTests(NeoUnitTestBase):
([oid2], 'user', 'desc', 'ext', False, p64(2), None))
self.assertEqual(self.db.getTransaction(tid1, False), None)
self.assertEqual(self.db.getTransaction(tid2, False), None)
self.assertEqual(self.db.getFirstTID(), u64(tid1))
self.assertEqual(self.db.getTransaction(tid1, True),
([oid1], 'user', 'desc', 'ext', False, p64(1), None))
self.assertEqual(self.db.getTransaction(tid2, True),
......
......@@ -200,7 +200,7 @@ class StressApplication(AdminApplication):
if conn:
conn.ask(Packets.AskLastIDs())
def answerLastIDs(self, ltid, loid):
def answerLastIDs(self, ltid, loid, ftid):
self.loid = loid
self.ltid = ltid
self.em.setTimeout(int(time.time() + 1), self.askLastIDs)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment