Commit 521d36d3 authored by Grégory Wisniewski's avatar Grégory Wisniewski

Fix load/store of None configuration values with database manager.

- Store None values as NULL column values
- Raise KeyError if the configuration entry is not found

git-svn-id: https://svn.erp5.org/repos/neo/trunk@1993 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent 7a736ff7
No related merge requests found
...@@ -87,34 +87,46 @@ class Application(object): ...@@ -87,34 +87,46 @@ class Application(object):
"""Load persistent configuration data from the database. """Load persistent configuration data from the database.
If data is not present, generate it.""" If data is not present, generate it."""
def NoneOnKeyError(getter):
try:
return getter()
except KeyError:
return None
dm = self.dm dm = self.dm
self.uuid = dm.getUUID() # check cluster name
num_partitions = dm.getNumPartitions() try:
num_replicas = dm.getNumReplicas() if dm.getName() != self.name:
raise RuntimeError('name does not match with the database')
except KeyError:
dm.setName(self.name)
# load configuration
self.uuid = NoneOnKeyError(dm.getUUID)
num_partitions = NoneOnKeyError(dm.getNumPartitions)
num_replicas = NoneOnKeyError(dm.getNumReplicas)
ptid = NoneOnKeyError(dm.getPTID)
# check partition table configuration
if num_partitions is not None and num_replicas is not None: if num_partitions is not None and num_replicas is not None:
if num_partitions <= 0: if num_partitions <= 0:
raise RuntimeError, 'partitions must be more than zero' raise RuntimeError, 'partitions must be more than zero'
# create a partition table # create a partition table
self.pt = PartitionTable(num_partitions, num_replicas) self.pt = PartitionTable(num_partitions, num_replicas)
name = dm.getName()
if name is None:
dm.setName(self.name)
elif name != self.name:
raise RuntimeError('name does not match with the database')
ptid = dm.getPTID()
logging.info('Configuration loaded:') logging.info('Configuration loaded:')
logging.info('UUID : %s', dump(self.uuid)) logging.info('UUID : %s', dump(self.uuid))
logging.info('PTID : %s', dump(ptid)) logging.info('PTID : %s', dump(ptid))
logging.info('Name : %s', name) logging.info('Name : %s', self.name)
logging.info('Partitions: %s', num_partitions) logging.info('Partitions: %s', num_partitions)
logging.info('Replicas : %s', num_replicas) logging.info('Replicas : %s', num_replicas)
def loadPartitionTable(self): def loadPartitionTable(self):
"""Load a partition table from the database.""" """Load a partition table from the database."""
ptid = self.dm.getPTID() try:
ptid = self.dm.getPTID()
except KeyError:
ptid = None
cell_list = self.dm.getPartitionTable() cell_list = self.dm.getPartitionTable()
new_cell_list = [] new_cell_list = []
for offset, uuid, state in cell_list: for offset, uuid, state in cell_list:
......
...@@ -135,7 +135,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -135,7 +135,7 @@ class MySQLDatabaseManager(DatabaseManager):
# persistent data. # persistent data.
q("""CREATE TABLE IF NOT EXISTS config ( q("""CREATE TABLE IF NOT EXISTS config (
name VARBINARY(16) NOT NULL PRIMARY KEY, name VARBINARY(16) NOT NULL PRIMARY KEY,
value VARBINARY(255) NOT NULL value VARBINARY(255) NULL
) ENGINE = InnoDB""") ) ENGINE = InnoDB""")
# The table "pt" stores a partition table. # The table "pt" stores a partition table.
...@@ -191,18 +191,20 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -191,18 +191,20 @@ class MySQLDatabaseManager(DatabaseManager):
q = self.query q = self.query
e = self.escape e = self.escape
key = e(str(key)) key = e(str(key))
r = q("""SELECT value FROM config WHERE name = '%s'""" % key)
try: try:
return r[0][0] return q("SELECT value FROM config WHERE name = '%s'" % key)[0][0]
except IndexError: except IndexError:
return None raise KeyError, key
def _setConfiguration(self, key, value): def _setConfiguration(self, key, value):
q = self.query q = self.query
e = self.escape e = self.escape
key = e(str(key)) key = e(str(key))
value = e(str(value)) if value is None:
q("""REPLACE INTO config VALUES ('%s', '%s')""" % (key, value)) value = 'NULL'
else:
value = "'%s'" % (e(str(value)), )
q("""REPLACE INTO config VALUES ('%s', %s)""" % (key, value))
def getPartitionTable(self): def getPartitionTable(self):
q = self.query q = self.query
......
...@@ -27,8 +27,14 @@ class VerificationHandler(BaseMasterHandler): ...@@ -27,8 +27,14 @@ class VerificationHandler(BaseMasterHandler):
def askLastIDs(self, conn): def askLastIDs(self, conn):
app = self.app app = self.app
oid = app.dm.getLastOID() try:
tid = app.dm.getLastTID() oid = app.dm.getLastOID()
except KeyError:
oid = None
try:
tid = app.dm.getLastTID()
except KeyError:
tid = None
conn.answer(Packets.AnswerLastIDs(oid, tid, app.pt.getID())) conn.answer(Packets.AnswerLastIDs(oid, tid, app.pt.getID()))
def askPartitionTable(self, conn, offset_list): def askPartitionTable(self, conn, offset_list):
......
...@@ -151,9 +151,8 @@ class StorageMySQSLdbTests(NeoTestBase): ...@@ -151,9 +151,8 @@ class StorageMySQSLdbTests(NeoTestBase):
def test_10_getConfiguration(self): def test_10_getConfiguration(self):
# check if a configuration entry is well read # check if a configuration entry is well read
self.db.setup() self.db.setup()
result = self.db.getConfiguration('a') # doesn't exists, raise
# doesn't exists, None expected self.assertRaises(KeyError, self.db.getConfiguration, 'a')
self.assertEquals(result, None)
self.db.query("insert into config values ('a', 'b');") self.db.query("insert into config values ('a', 'b');")
result = self.db.getConfiguration('a') result = self.db.getConfiguration('a')
# exists, check result # exists, check result
...@@ -169,7 +168,7 @@ class StorageMySQSLdbTests(NeoTestBase): ...@@ -169,7 +168,7 @@ class StorageMySQSLdbTests(NeoTestBase):
def checkConfigEntry(self, get_call, set_call, value): def checkConfigEntry(self, get_call, set_call, value):
# generic test for all configuration entries accessors # generic test for all configuration entries accessors
self.db.setup() self.db.setup()
self.assertEquals(get_call(), None) self.assertRaises(KeyError, get_call)
set_call(value) set_call(value)
self.assertEquals(get_call(), value) self.assertEquals(get_call(), value)
set_call(value * 2) set_call(value * 2)
...@@ -194,13 +193,10 @@ class StorageMySQSLdbTests(NeoTestBase): ...@@ -194,13 +193,10 @@ class StorageMySQSLdbTests(NeoTestBase):
value='TEST_NAME') value='TEST_NAME')
def test_15_PTID(self): def test_15_PTID(self):
test = '\x01' * 8 self.checkConfigEntry(
self.db.setup() get_call=self.db.getPTID,
self.assertEquals(self.db.getPTID(), None) set_call=self.db.setPTID,
self.db.setPTID(test) value=self.getPTID(1))
self.assertEquals(self.db.getPTID(), test)
self.db.setPTID(test * 2)
self.assertEquals(self.db.getPTID(), test * 2)
def test_16_getPartitionTable(self): def test_16_getPartitionTable(self):
# insert an entry and check it # insert an entry and check it
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment