Commit 8e7d4aa7 authored by Julien Muchembled's avatar Julien Muchembled

Improvements to --dynamic-master-list

- atomic write to disk to avoid corruption
- update when the address changes (not only when a node is removed/added)
parent 017f248d
...@@ -14,9 +14,8 @@ ...@@ -14,9 +14,8 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import errno, json, os
from time import time from time import time
from os.path import exists, getsize
import json
from . import attributeTracker, logging from . import attributeTracker, logging
from .handler import DelayEvent, EventQueue from .handler import DelayEvent, EventQueue
...@@ -204,34 +203,51 @@ class MasterDB(object): ...@@ -204,34 +203,51 @@ class MasterDB(object):
""" """
def __init__(self, path): def __init__(self, path):
self._path = path self._path = path
try_load = exists(path) and getsize(path)
if try_load:
db = open(path, 'r')
init_set = map(tuple, json.load(db))
else:
db = open(path, 'w+')
init_set = []
self._set = set(init_set)
db.close()
def _save(self):
try: try:
db = open(self._path, 'w') with open(path) as db:
except IOError: self._set = set(map(tuple, json.load(db)))
logging.warning('failed opening master database at %r ' except IOError, e:
'for writing, update skipped', self._path) if e.errno != errno.ENOENT:
else: raise
self._set = set()
self._save(True)
def _save(self, raise_on_error=False):
tmp = self._path + '#neo#'
try:
with open(tmp, 'w') as db:
json.dump(list(self._set), db) json.dump(list(self._set), db)
db.close() os.rename(tmp, self._path)
except EnvironmentError:
if raise_on_error:
raise
logging.exception('failed saving list of master nodes to %r',
self._path)
finally:
try:
os.remove(tmp)
except OSError:
pass
def add(self, addr): def remove(self, addr):
self._set.add(addr) if addr in self._set:
self._set.remove(addr)
self._save() self._save()
def discard(self, addr): def addremove(self, old, new):
self._set.discard(addr) assert old != new
if None is not new not in self._set:
self._set.add(new)
elif old not in self._set:
return
self._set.discard(old)
self._save() self._save()
def __repr__(self):
return '<%s@%s: %s>' % (self.__class__.__name__, self._path,
', '.join(sorted(('[%s]:%s' if ':' in x[0] else '%s:%s') % x
for x in self._set)))
def __iter__(self): def __iter__(self):
return iter(self._set) return iter(self._set)
...@@ -276,8 +292,6 @@ class NodeManager(EventQueue): ...@@ -276,8 +292,6 @@ class NodeManager(EventQueue):
self._updateUUID(node, None) self._updateUUID(node, None)
self.__updateSet(self._type_dict, None, node.getType(), node) self.__updateSet(self._type_dict, None, node.getType(), node)
self.__updateSet(self._state_dict, None, node.getState(), node) self.__updateSet(self._state_dict, None, node.getState(), node)
if node.isMaster() and self._master_db is not None:
self._master_db.add(node.getAddress())
def remove(self, node): def remove(self, node):
self._node_set.remove(node) self._node_set.remove(node)
...@@ -288,9 +302,8 @@ class NodeManager(EventQueue): ...@@ -288,9 +302,8 @@ class NodeManager(EventQueue):
self._uuid_dict.pop(node.getUUID(), None) self._uuid_dict.pop(node.getUUID(), None)
self._state_dict[node.getState()].remove(node) self._state_dict[node.getState()].remove(node)
self._type_dict[node.getType()].remove(node) self._type_dict[node.getType()].remove(node)
uuid = node.getUUID()
if node.isMaster() and self._master_db is not None: if node.isMaster() and self._master_db is not None:
self._master_db.discard(node.getAddress()) self._master_db.remove(node.getAddress())
def __update(self, index_dict, old_key, new_key, node): def __update(self, index_dict, old_key, new_key, node):
""" Update an index from old to new key """ """ Update an index from old to new key """
...@@ -305,7 +318,10 @@ class NodeManager(EventQueue): ...@@ -305,7 +318,10 @@ class NodeManager(EventQueue):
index_dict[new_key] = node index_dict[new_key] = node
def _updateAddress(self, node, old_address): def _updateAddress(self, node, old_address):
self.__update(self._address_dict, old_address, node.getAddress(), node) address = node.getAddress()
self.__update(self._address_dict, old_address, address, node)
if node.isMaster() and self._master_db is not None:
self._master_db.addremove(old_address, address)
def _updateUUID(self, node, old_uuid): def _updateUUID(self, node, old_uuid):
self.__update(self._uuid_dict, old_uuid, node.getUUID(), node) self.__update(self._uuid_dict, old_uuid, node.getUUID(), node)
......
...@@ -195,7 +195,7 @@ class MasterDBTests(NeoUnitTestBase): ...@@ -195,7 +195,7 @@ class MasterDBTests(NeoUnitTestBase):
temp_dir = getTempDirectory() temp_dir = getTempDirectory()
directory = join(temp_dir, 'read_only') directory = join(temp_dir, 'read_only')
db_file = join(directory, 'not_created') db_file = join(directory, 'not_created')
mkdir(directory, 0400) mkdir(directory, 0500)
try: try:
self.assertRaises(IOError, MasterDB, db_file) self.assertRaises(IOError, MasterDB, db_file)
finally: finally:
...@@ -212,17 +212,17 @@ class MasterDBTests(NeoUnitTestBase): ...@@ -212,17 +212,17 @@ class MasterDBTests(NeoUnitTestBase):
try: try:
db = MasterDB(db_file) db = MasterDB(db_file)
self.assertTrue(exists(db_file), db_file) self.assertTrue(exists(db_file), db_file)
chmod(db_file, 0400) chmod(directory, 0500)
address = ('example.com', 1024) address = ('example.com', 1024)
# Must not raise # Must not raise
db.add(address) db.addremove(None, address)
# Value is stored # Value is stored
self.assertTrue(address in db, [x for x in db]) self.assertIn(address, db)
# But not visible to a new db instance (write access restored so # But not visible to a new db instance (write access restored so
# it can be created) # it can be created)
chmod(db_file, 0600) chmod(directory, 0700)
db2 = MasterDB(db_file) db2 = MasterDB(db_file)
self.assertFalse(address in db2, [x for x in db2]) self.assertNotIn(address, db2)
finally: finally:
shutil.rmtree(directory) shutil.rmtree(directory)
...@@ -235,18 +235,21 @@ class MasterDBTests(NeoUnitTestBase): ...@@ -235,18 +235,21 @@ class MasterDBTests(NeoUnitTestBase):
db = MasterDB(db_file) db = MasterDB(db_file)
self.assertTrue(exists(db_file), db_file) self.assertTrue(exists(db_file), db_file)
address = ('example.com', 1024) address = ('example.com', 1024)
db.add(address) db.addremove(None, address)
address2 = ('example.org', 1024) address2 = ('example.org', 1024)
db.add(address2) db.addremove(None, address2)
# Values are visible to a new db instance # Values are visible to a new db instance
db2 = MasterDB(db_file) db2 = MasterDB(db_file)
self.assertTrue(address in db2, [x for x in db2]) self.assertIn(address, db2)
self.assertTrue(address2 in db2, [x for x in db2]) self.assertIn(address2, db2)
db.discard(address) db.addremove(address, None)
# Create yet another instance (file is not supposed to be shared) # Create yet another instance (file is not supposed to be shared)
db3 = MasterDB(db_file) db2 = MasterDB(db_file)
self.assertFalse(address in db3, [x for x in db3]) self.assertNotIn(address, db2)
self.assertTrue(address2 in db3, [x for x in db3]) self.assertIn(address2, db2)
db.remove(address2)
# and again, to test remove()
self.assertNotIn(address2, MasterDB(db_file))
finally: finally:
shutil.rmtree(directory) shutil.rmtree(directory)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment