Commit 6dffb894 authored by Vincent Pelletier's avatar Vincent Pelletier

master: Forbid truncature before database's first transaction

This is intended as a sanity check, so simple typos in neoctl truncate
command do not easily lead to the entire database being wiped.
parent f70a688c
...@@ -26,7 +26,7 @@ except ImportError: ...@@ -26,7 +26,7 @@ except ImportError:
# The protocol version must be increased whenever upgrading a node may require # The protocol version must be increased whenever upgrading a node may require
# to upgrade other nodes. # to upgrade other nodes.
PROTOCOL_VERSION = 4 PROTOCOL_VERSION = 5
# By encoding the handshake packet with msgpack, the whole NEO stream can be # By encoding the handshake packet with msgpack, the whole NEO stream can be
# decoded with msgpack. The first byte is 0x92, which is different from TLS # decoded with msgpack. The first byte is 0x92, which is different from TLS
# Handshake (0x16). # Handshake (0x16).
......
...@@ -41,7 +41,7 @@ class MasterHandler(EventHandler): ...@@ -41,7 +41,7 @@ class MasterHandler(EventHandler):
def askLastIDs(self, conn): def askLastIDs(self, conn):
tm = self.app.tm tm = self.app.tm
conn.answer(Packets.AnswerLastIDs(tm.getLastTID(), tm.getLastOID())) conn.answer(Packets.AnswerLastIDs(tm.getLastTID(), tm.getLastOID(), tm.getFirstTID()))
def askLastTransaction(self, conn): def askLastTransaction(self, conn):
conn.answer(Packets.AnswerLastTransaction( conn.answer(Packets.AnswerLastTransaction(
......
...@@ -239,6 +239,10 @@ class AdministrationHandler(MasterHandler): ...@@ -239,6 +239,10 @@ class AdministrationHandler(MasterHandler):
app = self.app app = self.app
if app.getLastTransaction() <= tid: if app.getLastTransaction() <= tid:
raise AnswerDenied("Truncating after last transaction does nothing") raise AnswerDenied("Truncating after last transaction does nothing")
first_tid = app.tm.getFirstTID()
if first_tid is None or first_tid > tid:
raise AnswerDenied("Truncating before first transaction is "
"probably not what you intended to do")
if app.pm.getApprovedRejected(add64(tid, 1))[0]: if app.pm.getApprovedRejected(add64(tid, 1))[0]:
# TODO: The protocol must be extended to support safe cases # TODO: The protocol must be extended to support safe cases
# (e.g. no started pack whose id is after truncation tid). # (e.g. no started pack whose id is after truncation tid).
......
...@@ -179,6 +179,7 @@ class TransactionManager(EventQueue): ...@@ -179,6 +179,7 @@ class TransactionManager(EventQueue):
self._ttid_dict = {} self._ttid_dict = {}
self._last_oid = ZERO_OID self._last_oid = ZERO_OID
self._last_tid = ZERO_TID self._last_tid = ZERO_TID
self._first_tid = None
# queue filled with ttids pointing to transactions with increasing tids # queue filled with ttids pointing to transactions with increasing tids
self._queue = deque() self._queue = deque()
...@@ -212,6 +213,14 @@ class TransactionManager(EventQueue): ...@@ -212,6 +213,14 @@ class TransactionManager(EventQueue):
self._last_oid = oid_list[-1] self._last_oid = oid_list[-1]
return oid_list return oid_list
def setFirstTID(self, tid):
first_tid = self._first_tid
if first_tid is None or first_tid > tid:
self._first_tid = tid
def getFirstTID(self):
return self._first_tid
def setLastOID(self, oid): def setLastOID(self, oid):
if self._last_oid < oid: if self._last_oid < oid:
self._last_oid = oid self._last_oid = oid
...@@ -420,7 +429,10 @@ class TransactionManager(EventQueue): ...@@ -420,7 +429,10 @@ class TransactionManager(EventQueue):
is required is when some storages are already busy by other tasks. is required is when some storages are already busy by other tasks.
""" """
queue = self._queue queue = self._queue
self._on_commit(self._ttid_dict.pop(queue.popleft())) txn = self._ttid_dict.pop(queue.popleft())
if self._first_tid is None:
self._first_tid = txn.getTID()
self._on_commit(txn)
while queue: while queue:
ttid = queue[0] ttid = queue[0]
txn = self._ttid_dict[ttid] txn = self._ttid_dict[ttid]
......
...@@ -139,11 +139,12 @@ class VerificationManager(BaseServiceHandler): ...@@ -139,11 +139,12 @@ class VerificationManager(BaseServiceHandler):
def notifyPackCompleted(self, conn, pack_id): def notifyPackCompleted(self, conn, pack_id):
self.app.nm.getByUUID(conn.getUUID()).completed_pack_id = pack_id self.app.nm.getByUUID(conn.getUUID()).completed_pack_id = pack_id
def answerLastIDs(self, conn, ltid, loid): def answerLastIDs(self, conn, ltid, loid, ftid):
self._uuid_set.remove(conn.getUUID()) self._uuid_set.remove(conn.getUUID())
tm = self.app.tm tm = self.app.tm
tm.setLastTID(ltid) tm.setLastTID(ltid)
tm.setLastOID(loid) tm.setLastOID(loid)
tm.setFirstTID(ftid)
def answerPackOrders(self, conn, pack_list): def answerPackOrders(self, conn, pack_list):
self._uuid_set.remove(conn.getUUID()) self._uuid_set.remove(conn.getUUID())
......
...@@ -137,7 +137,7 @@ class NeoCTL(BaseApplication): ...@@ -137,7 +137,7 @@ class NeoCTL(BaseApplication):
response = self.__ask(Packets.AskLastIDs()) response = self.__ask(Packets.AskLastIDs())
if response[0] != Packets.AnswerLastIDs: if response[0] != Packets.AnswerLastIDs:
raise RuntimeError(response) raise RuntimeError(response)
return response[1:] return response[1:3]
def getLastTransaction(self): def getLastTransaction(self):
response = self.__ask(Packets.AskLastTransaction()) response = self.__ask(Packets.AskLastTransaction())
......
...@@ -601,6 +601,9 @@ class ImporterDatabaseManager(DatabaseManager): ...@@ -601,6 +601,9 @@ class ImporterDatabaseManager(DatabaseManager):
zodb = self.zodb[bisect(self.zodb_index, oid) - 1] zodb = self.zodb[bisect(self.zodb_index, oid) - 1]
return zodb, oid - zodb.shift_oid return zodb, oid - zodb.shift_oid
def getFirstTID(self):
return self.db.getFirstTID()
def getLastIDs(self): def getLastIDs(self):
tid, oid = self.db.getLastIDs() tid, oid = self.db.getLastIDs()
return (max(tid, util.p64(self.zodb_ltid)), return (max(tid, util.p64(self.zodb_ltid)),
......
...@@ -758,6 +758,22 @@ class DatabaseManager(object): ...@@ -758,6 +758,22 @@ class DatabaseManager(object):
# XXX: Consider splitting getLastIDs/_getLastIDs because # XXX: Consider splitting getLastIDs/_getLastIDs because
# sometimes the last oid is not wanted. # sometimes the last oid is not wanted.
def _getFirstTID(self, partition):
"""Return tid of first transaction in given 'partition'
tids are in unpacked format.
"""
@requires(_getFirstTID)
def getFirstTID(self):
"""Return tud of first transaction
tids are in unpacked format.
"""
x = self._readable_set
if x:
return min(self._getFirstTID(x) for x in x)
def _getLastTID(self, partition, max_tid=None): def _getLastTID(self, partition, max_tid=None):
"""Return tid of last transaction <= 'max_tid' in given 'partition' """Return tid of last transaction <= 'max_tid' in given 'partition'
......
...@@ -457,6 +457,14 @@ class MySQLDatabaseManager(MVCCDatabaseManager): ...@@ -457,6 +457,14 @@ class MySQLDatabaseManager(MVCCDatabaseManager):
def _getPartitionTable(self): def _getPartitionTable(self):
return self.query("SELECT * FROM pt") return self.query("SELECT * FROM pt")
def _getFirstTID(self, partition):
tid_list = self.query(
"SELECT MIN(tid) as t FROM trans FORCE INDEX (PRIMARY) "
"WHERE `partition`=%s" % partition)
if tid_list:
(tid, ), = tid_list
return tid
def _getLastTID(self, partition, max_tid=None): def _getLastTID(self, partition, max_tid=None):
x = "WHERE `partition`=%s" % partition x = "WHERE `partition`=%s" % partition
if max_tid: if max_tid:
......
...@@ -343,6 +343,13 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -343,6 +343,13 @@ class SQLiteDatabaseManager(DatabaseManager):
def _getPartitionTable(self): def _getPartitionTable(self):
return self.query("SELECT * FROM pt") return self.query("SELECT * FROM pt")
def _getFirstTID(self, partition):
try:
return self.query("SELECT MIN(tid) FROM trans WHERE partition=?",
(partition,)).fetchone()[0]
except TypeError:
pass
def _getLastTID(self, partition, max_tid=None): def _getLastTID(self, partition, max_tid=None):
x = self.query x = self.query
if max_tid is None: if max_tid is None:
......
...@@ -55,7 +55,8 @@ class InitializationHandler(BaseMasterHandler): ...@@ -55,7 +55,8 @@ class InitializationHandler(BaseMasterHandler):
if packed: if packed:
self.app.completed_pack_id = pack_id = min(packed.itervalues()) self.app.completed_pack_id = pack_id = min(packed.itervalues())
conn.send(Packets.NotifyPackCompleted(pack_id)) conn.send(Packets.NotifyPackCompleted(pack_id))
conn.answer(Packets.AnswerLastIDs(*dm.getLastIDs())) last_tid, last_oid = dm.getLastIDs()
conn.answer(Packets.AnswerLastIDs(last_tid, last_oid, dm.getFirstTID()))
def askPartitionTable(self, conn): def askPartitionTable(self, conn):
pt = self.app.pt pt = self.app.pt
......
...@@ -13,7 +13,7 @@ AnswerFetchObjects(?p64,?p64,{:}) ...@@ -13,7 +13,7 @@ AnswerFetchObjects(?p64,?p64,{:})
AnswerFetchTransactions(?p64,[],?p64) AnswerFetchTransactions(?p64,[],?p64)
AnswerFinalTID(p64) AnswerFinalTID(p64)
AnswerInformationLocked(p64) AnswerInformationLocked(p64)
AnswerLastIDs(?p64,?p64) AnswerLastIDs(?p64,?p64,?p64)
AnswerLastTransaction(p64) AnswerLastTransaction(p64)
AnswerLockedTransactions({p64:?p64}) AnswerLockedTransactions({p64:?p64})
AnswerMonitorInformation([?bin],[?bin],bin) AnswerMonitorInformation([?bin],[?bin],bin)
......
...@@ -183,6 +183,7 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -183,6 +183,7 @@ class StorageDBTests(NeoUnitTestBase):
txn1, objs1 = self.getTransaction([oid1]) txn1, objs1 = self.getTransaction([oid1])
txn2, objs2 = self.getTransaction([oid2]) txn2, objs2 = self.getTransaction([oid2])
# nothing in database # nothing in database
self.assertEqual(self.db.getFirstTID(), None)
self.assertEqual(self.db.getLastIDs(), (None, None)) self.assertEqual(self.db.getLastIDs(), (None, None))
self.assertEqual(self.db.getUnfinishedTIDDict(), {}) self.assertEqual(self.db.getUnfinishedTIDDict(), {})
self.assertEqual(self.db.getObject(oid1), None) self.assertEqual(self.db.getObject(oid1), None)
...@@ -199,6 +200,7 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -199,6 +200,7 @@ class StorageDBTests(NeoUnitTestBase):
([oid2], 'user', 'desc', 'ext', False, p64(2), None)) ([oid2], 'user', 'desc', 'ext', False, p64(2), None))
self.assertEqual(self.db.getTransaction(tid1, False), None) self.assertEqual(self.db.getTransaction(tid1, False), None)
self.assertEqual(self.db.getTransaction(tid2, False), None) self.assertEqual(self.db.getTransaction(tid2, False), None)
self.assertEqual(self.db.getFirstTID(), u64(tid1))
self.assertEqual(self.db.getTransaction(tid1, True), self.assertEqual(self.db.getTransaction(tid1, True),
([oid1], 'user', 'desc', 'ext', False, p64(1), None)) ([oid1], 'user', 'desc', 'ext', False, p64(1), None))
self.assertEqual(self.db.getTransaction(tid2, True), self.assertEqual(self.db.getTransaction(tid2, True),
......
...@@ -200,7 +200,7 @@ class StressApplication(AdminApplication): ...@@ -200,7 +200,7 @@ class StressApplication(AdminApplication):
if conn: if conn:
conn.ask(Packets.AskLastIDs()) conn.ask(Packets.AskLastIDs())
def answerLastIDs(self, ltid, loid): def answerLastIDs(self, ltid, loid, ftid):
self.loid = loid self.loid = loid
self.ltid = ltid self.ltid = ltid
self.em.setTimeout(int(time.time() + 1), self.askLastIDs) self.em.setTimeout(int(time.time() + 1), self.askLastIDs)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment