#
# Copyright (C) 2006-2010  Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.

Yoshinori Okuji's avatar
Yoshinori Okuji committed
18 19 20
import MySQLdb
from MySQLdb import OperationalError
from MySQLdb.constants.CR import SERVER_GONE_ERROR, SERVER_LOST
21
from neo import logging
22
from array import array
23
import string
Yoshinori Okuji's avatar
Yoshinori Okuji committed
24

25 26
from neo.storage.database import DatabaseManager
from neo.exception import DatabaseFailure
27
from neo.protocol import CellStates
28
from neo import util
Yoshinori Okuji's avatar
Yoshinori Okuji committed
29

30 31
# Debugging switch read by MySQLDatabaseManager.query(): when true, the
# first line of every SQL statement (truncated to 70 characters, with
# non-printable characters escaped) is logged at debug level.
LOG_QUERIES = False

32 33
class MySQLDatabaseManager(DatabaseManager):
    """This class manages a database on MySQL.

    The connection is opened through the MySQLdb module with autocommit
    disabled. Data lives in six InnoDB tables created by setup():
    "config" (configuration parameters), "pt" (partition table), "trans"
    and "obj" (committed transactions and object data), "ttrans" and
    "tobj" (uncommitted transactions and object data).
    """
Yoshinori Okuji's avatar
Yoshinori Okuji committed
34

35 36 37
    def __init__(self, database):
        """Parse the connection string and connect to the MySQL server.

        database -- credentials in the form [user[:password]@]database
        """
        super(MySQLDatabaseManager, self).__init__()
        user, passwd, db = self._parse(database)
        self.user = user
        self.passwd = passwd
        self.db = db
        # There is no connection object until _connect() has succeeded.
        self.conn = None
        self._connect()
40 41 42 43 44 45 46 47 48 49 50

    def _parse(self, database):
        """ Get the database credentials (username, password, database) """
        # expected pattern : [user[:password]@]database
        username = None
        password = None
        if '@' in database:
            (username, database) = database.split('@')
            if ':' in username:
                (username, password) = username.split(':')
        return (username, password, database)
Yoshinori Okuji's avatar
Yoshinori Okuji committed
51

52 53 54
    def close(self):
        self.conn.close()

55
    def _connect(self):
        """Open a new connection to the MySQL server and store it.

        The connection is made to the database self.db and is stored in
        self.conn with autocommit disabled, so that transactions are
        delimited explicitly.
        """
        kwd = {'db' : self.db}
        # Only pass the credentials which were actually given. The password
        # was already handled this way; the user name is None as well when
        # the connection string has no "user@" part.
        # NOTE(review): MySQLdb.connect() presumably rejects None for its
        # string parameters -- confirm against the MySQLdb version in use.
        if self.user is not None:
            kwd['user'] = self.user
        if self.passwd is not None:
            kwd['passwd'] = self.passwd
        logging.info('connecting to MySQL on the database %s with user %s',
                     self.db, self.user)
        self.conn = MySQLdb.connect(**kwd)
        self.conn.autocommit(False)

64
    def _begin(self):
65 66
        self.query("""BEGIN""")

67
    def _commit(self):
68 69
        self.conn.commit()

70
    def _rollback(self):
71
        self.conn.rollback()
Yoshinori Okuji's avatar
Yoshinori Okuji committed
72 73 74 75 76

    def query(self, query):
        """Query data from a database."""
        conn = self.conn
        try:
77 78 79 80 81 82 83 84
            if LOG_QUERIES:
                printable_char_list = []
                for c in query.split('\n', 1)[0][:70]:
                    if c not in string.printable or c in '\t\x0b\x0c\r':
                        c = '\\x%02x' % ord(c)
                    printable_char_list.append(c)
                query_part = ''.join(printable_char_list)
                logging.debug('querying %s...', query_part)
85

Yoshinori Okuji's avatar
Yoshinori Okuji committed
86 87 88
            conn.query(query)
            r = conn.store_result()
            if r is not None:
89 90 91 92 93 94 95 96 97 98
                new_r = []
                for row in r.fetch_row(r.num_rows()):
                    new_row = []
                    for d in row:
                        if isinstance(d, array):
                            d = d.tostring()
                        new_row.append(d)
                    new_r.append(tuple(new_row))
                r = tuple(new_r)

Yoshinori Okuji's avatar
Yoshinori Okuji committed
99 100 101 102 103
        except OperationalError, m:
            if m[0] in (SERVER_GONE_ERROR, SERVER_LOST):
                logging.info('the MySQL server is gone; reconnecting')
                self.connect()
                return self.query(query)
104
            raise DatabaseFailure('MySQL error %d: %s' % (m[0], m[1]))
Yoshinori Okuji's avatar
Yoshinori Okuji committed
105
        return r
106

107 108 109
    def escape(self, s):
        """Escape special characters in a string."""
        return self.conn.escape_string(s)
110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126

    def setup(self, reset = 0):
        """Create the tables used by this manager if they do not exist.

        reset -- when true, drop the six tables first so that they are
                 all recreated empty
        """
        q = self.query

        if reset:
            q("""DROP TABLE IF EXISTS config, pt, trans, obj, ttrans, tobj""")

        # The table "config" stores configuration parameters which affect the
        # persistent data.
        q("""CREATE TABLE IF NOT EXISTS config (
                 name VARBINARY(16) NOT NULL PRIMARY KEY,
                 value VARBINARY(255) NOT NULL
             ) ENGINE = InnoDB""")

        # The table "pt" stores a partition table.
        q("""CREATE TABLE IF NOT EXISTS pt (
                 rid INT UNSIGNED NOT NULL,
                 uuid CHAR(32) NOT NULL,
                 state TINYINT UNSIGNED NOT NULL,
                 PRIMARY KEY (rid, uuid)
             ) ENGINE = InnoDB""")

        # The table "trans" stores information on committed transactions.
        q("""CREATE TABLE IF NOT EXISTS trans (
                 tid BIGINT UNSIGNED NOT NULL PRIMARY KEY,
                 packed BOOLEAN NOT NULL,
                 oids MEDIUMBLOB NOT NULL,
                 user BLOB NOT NULL,
                 description BLOB NOT NULL,
                 ext BLOB NOT NULL
             ) ENGINE = InnoDB""")

        # The table "obj" stores committed object data.
        q("""CREATE TABLE IF NOT EXISTS obj (
                 oid BIGINT UNSIGNED NOT NULL,
                 serial BIGINT UNSIGNED NOT NULL,
                 compression TINYINT UNSIGNED NOT NULL,
                 checksum INT UNSIGNED NOT NULL,
                 value MEDIUMBLOB NOT NULL,
                 PRIMARY KEY (oid, serial)
             ) ENGINE = InnoDB""")

        # The table "ttrans" stores information on uncommitted transactions.
        # It has the same columns as "trans" but declares no primary key.
        q("""CREATE TABLE IF NOT EXISTS ttrans (
                 tid BIGINT UNSIGNED NOT NULL,
                 packed BOOLEAN NOT NULL,
                 oids MEDIUMBLOB NOT NULL,
                 user BLOB NOT NULL,
                 description BLOB NOT NULL,
                 ext BLOB NOT NULL
             ) ENGINE = InnoDB""")

        # The table "tobj" stores uncommitted object data.
        # It has the same columns as "obj" but declares no primary key.
        q("""CREATE TABLE IF NOT EXISTS tobj (
                 oid BIGINT UNSIGNED NOT NULL,
                 serial BIGINT UNSIGNED NOT NULL,
                 compression TINYINT UNSIGNED NOT NULL,
                 checksum INT UNSIGNED NOT NULL,
                 value MEDIUMBLOB NOT NULL
             ) ENGINE = InnoDB""")

    def getConfiguration(self, key):
        q = self.query
        e = self.escape
        key = e(str(key))
        r = q("""SELECT value FROM config WHERE name = '%s'""" % key)
        try:
            return r[0][0]
        except IndexError:
            return None

181 182 183 184 185 186
    def _setConfiguration(self, key, value):
        q = self.query
        e = self.escape
        key = e(str(key))
        value = e(str(value))
        q("""REPLACE INTO config VALUES ('%s', '%s')""" % (key, value))
187 188 189

    def getPartitionTable(self):
        q = self.query
190 191 192 193 194 195
        cell_list = q("""SELECT rid, uuid, state FROM pt""")
        pt = []
        for offset, uuid, state in cell_list:
            uuid = util.bin(uuid)
            pt.append((offset, uuid, state))
        return pt
196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217

    def getLastTID(self, all = True):
        # XXX this does not consider serials in obj.
        # I am not sure if this is really harmful. For safety,
        # check for tobj only at the moment. The reason why obj is
        # not tested is that it is too slow to get the max serial
        # from obj when it has a huge number of objects, because
        # serial is the second part of the primary key, so the index
        # is not used in this case. If doing it, it is better to
        # make another index for serial, but I doubt the cost increase
        # is worth.
        q = self.query
        self.begin()
        ltid = q("""SELECT MAX(tid) FROM trans""")[0][0]
        if all:
            tmp_ltid = q("""SELECT MAX(tid) FROM ttrans""")[0][0]
            if ltid is None or (tmp_ltid is not None and ltid < tmp_ltid):
                ltid = tmp_ltid
            tmp_serial = q("""SELECT MAX(serial) FROM tobj""")[0][0]
            if ltid is None or (tmp_serial is not None and ltid < tmp_serial):
                ltid = tmp_serial
        self.commit()
218
        if ltid is not None:
219
            ltid = util.p64(ltid)
220 221 222 223 224 225 226
        return ltid

    def getUnfinishedTIDList(self):
        q = self.query
        tid_set = set()
        self.begin()
        r = q("""SELECT tid FROM ttrans""")
227
        tid_set.update((util.p64(t[0]) for t in r))
228 229
        r = q("""SELECT serial FROM tobj""")
        self.commit()
230
        tid_set.update((util.p64(t[0]) for t in r))
231 232 233 234
        return list(tid_set)

    def objectPresent(self, oid, tid, all = True):
        """Tell whether the revision tid of the object oid is stored.

        oid, tid -- packed identifiers, decoded with util.u64()
        all      -- when true, "tobj" is searched as well if the revision
                    is not found in "obj"

        Return True or False.
        """
        q = self.query
        key = (util.u64(oid), util.u64(tid))
        self.begin()
        found = q("""SELECT oid FROM obj WHERE oid = %d AND serial = %d"""
                % key)
        if all and not found:
            found = q("""SELECT oid FROM tobj WHERE oid = %d AND serial = %d"""
                    % key)
        self.commit()
        return bool(found)

248 249
    def getObject(self, oid, tid = None, before_tid = None):
        """Load one revision of the object oid from the "obj" table.

        oid        -- packed object identifier
        tid        -- when given, load exactly this revision
        before_tid -- otherwise, when given, load the latest revision
                      strictly older than it
        When neither is given, the latest revision is loaded.

        Return (serial, next_serial, compression, checksum, data), with
        serial and next_serial packed by util.p64(), or None when no
        revision matches. next_serial is only looked up in the before_tid
        case (first serial >= before_tid) and is None otherwise.
        """
        q = self.query
        oid = util.u64(oid)
        look_for_next = False
        if tid is not None:
            rows = q("""SELECT serial, compression, checksum, value FROM obj
                        WHERE oid = %d AND serial = %d"""
                    % (oid, util.u64(tid)))
        elif before_tid is not None:
            before_tid = util.u64(before_tid)
            look_for_next = True
            rows = q("""SELECT serial, compression, checksum, value FROM obj
                        WHERE oid = %d AND serial < %d
                        ORDER BY serial DESC LIMIT 1"""
                    % (oid, before_tid))
        else:
            # XXX "HAVING serial = MAX(serial)" would express this better,
            # but MySQL does not use an index for a HAVING clause!
            rows = q("""SELECT serial, compression, checksum, value FROM obj
                        WHERE oid = %d ORDER BY serial DESC LIMIT 1"""
                    % oid)

        try:
            serial, compression, checksum, data = rows[0]
        except IndexError:
            # Nothing matched the requested revision.
            return None

        next_serial = None
        if look_for_next:
            following = q("""SELECT serial FROM obj
                        WHERE oid = %d AND serial >= %d
                        ORDER BY serial LIMIT 1"""
                    % (oid, before_tid))
            try:
                next_serial = following[0][0]
            except IndexError:
                pass

        if serial is not None:
            serial = util.p64(serial)
        if next_serial is not None:
            next_serial = util.p64(next_serial)
        return serial, next_serial, compression, checksum, data
296

297 298 299 300 301 302 303 304
    def doSetPartitionTable(self, ptid, cell_list, reset):
        """Write cells to the "pt" table and record the new ptid.

        ptid      -- partition table identifier, passed to setPTID()
        cell_list -- iterable of (offset, uuid, state) tuples
        reset     -- when true, the table is emptied first

        Cells in the DISCARDED state are deleted; the others are inserted,
        or have their state updated when already present. On any error the
        transaction is rolled back and the exception is re-raised.
        """
        q = self.query
        e = self.escape
        self.begin()
        try:
            if reset:
                # NOTE(review): MySQL documents TRUNCATE as causing an
                # implicit commit, so the rollback below presumably cannot
                # undo it -- confirm.
                q("""TRUNCATE pt""")
            for offset, uuid, state in cell_list:
                uuid = e(util.dump(uuid))
                if state == CellStates.DISCARDED:
                    statement = """DELETE FROM pt WHERE rid = %d AND uuid = '%s'""" \
                            % (offset, uuid)
                else:
                    statement = """INSERT INTO pt VALUES (%d, '%s', %d)
                            ON DUPLICATE KEY UPDATE state = %d""" \
                                    % (offset, uuid, state, state)
                q(statement)
            self.setPTID(ptid)
        except:
            # Deliberately catches everything: whatever went wrong, undo
            # the changes and let the caller see the original exception.
            self.rollback()
            raise
        else:
            self.commit()

    def changePartitionTable(self, ptid, cell_list):
320
        self.doSetPartitionTable(ptid, cell_list, False)
321 322

    def setPartitionTable(self, ptid, cell_list):
323
        self.doSetPartitionTable(ptid, cell_list, True)
324

325 326 327 328
    def dropPartition(self, num_partitions, offset):
        q = self.query
        self.begin()
        try:
329
            q("""DELETE FROM obj WHERE MOD(oid, %d) = %d""" %
330
                (num_partitions, offset))
331
            q("""DELETE FROM trans WHERE MOD(tid, %d) = %d""" %
332 333 334
                (num_partitions, offset))
        except:
            self.rollback()
335
            raise
336 337
        self.commit()

338 339 340 341 342 343 344 345 346 347 348
    def dropUnfinishedData(self):
        q = self.query
        self.begin()
        try:
            q("""TRUNCATE tobj""")
            q("""TRUNCATE ttrans""")
        except:
            self.rollback()
            raise
        self.commit()

349
    def storeTransaction(self, tid, object_list, transaction, temporary = True):
        """Store the objects and the metadata of a transaction.

        tid         -- packed transaction identifier
        object_list -- iterable of (oid, compression, checksum, data)
        transaction -- (oid_list, user, desc, ext, packed), or None to
                       store the objects only
        temporary   -- when true, write to "tobj"/"ttrans" (uncommitted
                       data); otherwise write to "obj"/"trans"

        Existing rows with the same key are replaced. On any error the
        transaction is rolled back and the exception is re-raised.
        """
        q = self.query
        e = self.escape
        tid = util.u64(tid)

        if temporary:
            obj_table, trans_table = 'tobj', 'ttrans'
        else:
            obj_table, trans_table = 'obj', 'trans'

        self.begin()
        try:
            for oid, compression, checksum, data in object_list:
                obj_row = (obj_table, util.u64(oid), tid, compression,
                           checksum, e(data))
                q("""REPLACE INTO %s VALUES (%d, %d, %d, %d, '%s')"""
                        % obj_row)
            if transaction is not None:
                oid_list, user, desc, ext, packed = transaction
                # The "packed" column receives 1 or 0.
                if packed:
                    packed_flag = 1
                else:
                    packed_flag = 0
                trans_row = (trans_table, tid, packed_flag,
                             e(''.join(oid_list)), e(user), e(desc), e(ext))
                q("""REPLACE INTO %s VALUES (%d, %i, '%s', '%s', '%s', '%s')"""
                        % trans_row)
        except:
            self.rollback()
            raise
        else:
            self.commit()

    def finishTransaction(self, tid):
        """Move the data of the transaction tid to the committed tables.

        The rows of "tobj" and "ttrans" which belong to tid are copied to
        "obj" and "trans", then deleted from the temporary tables. On any
        error the transaction is rolled back and the exception is
        re-raised.
        """
        q = self.query
        tid = util.u64(tid)
        statement_list = (
            """INSERT INTO obj SELECT * FROM tobj WHERE tobj.serial = %d""",
            """DELETE FROM tobj WHERE serial = %d""",
            """INSERT INTO trans SELECT * FROM ttrans WHERE ttrans.tid = %d""",
            """DELETE FROM ttrans WHERE tid = %d""",
        )
        self.begin()
        try:
            for statement in statement_list:
                q(statement % tid)
        except:
            self.rollback()
            raise
        else:
            self.commit()

398 399
    def deleteTransaction(self, tid, all = False):
        """Delete the data of the transaction tid.

        tid -- packed transaction identifier
        all -- when true, the committed tables "obj" and "trans" are
               cleaned in addition to "tobj" and "ttrans"

        On any error the transaction is rolled back and the exception is
        re-raised.
        """
        q = self.query
        tid = util.u64(tid)
        self.begin()
        try:
            statement_list = ["""DELETE FROM tobj WHERE serial = %d""",
                              """DELETE FROM ttrans WHERE tid = %d"""]
            if all:
                # Note that this can be very slow.
                statement_list.append("""DELETE FROM obj WHERE serial = %d""")
                statement_list.append("""DELETE FROM trans WHERE tid = %d""")
            for statement in statement_list:
                q(statement % tid)
        except:
            self.rollback()
            raise
        else:
            self.commit()

    def getTransaction(self, tid, all = False):
        """Return the metadata of the transaction tid, or None.

        tid -- packed transaction identifier
        all -- when true, "ttrans" is searched as well if the transaction
               is not found in "trans"

        Return (oid_list, user, desc, ext, packed) where oid_list is the
        stored oids string cut into 8-byte pieces and packed is a bool.
        Raise DatabaseFailure if the stored oids string is empty or its
        length is not a multiple of 8.
        """
        q = self.query
        tid = util.u64(tid)
        self.begin()
        rows = q("""SELECT oids, user, description, ext, packed FROM trans
                    WHERE tid = %d"""
                % tid)
        if all and not rows:
            rows = q("""SELECT oids, user, description, ext, packed FROM ttrans
                        WHERE tid = %d"""
                    % tid)
        self.commit()
        if not rows:
            return None
        oids, user, desc, ext, packed = rows[0]
        size = len(oids)
        if size == 0 or size % 8 != 0:
            raise DatabaseFailure('invalid oids for tid %x' % tid)
        oid_list = [oids[i:i + 8] for i in xrange(0, size, 8)]
        return oid_list, user, desc, ext, bool(packed)

436 437
    def getOIDList(self, offset, length, num_partitions, partition_list):
        """Return a page of the distinct oids stored in given partitions.

        offset, length  -- position and size of the page (SQL LIMIT)
        num_partitions  -- total number of partitions; an oid belongs to
                           the partition "oid modulo num_partitions"
        partition_list  -- partitions to look into

        The oids are sorted in descending order and packed with
        util.p64(). An empty partition_list yields an empty list.
        """
        partitions = ','.join([str(p) for p in partition_list])
        if not partitions:
            # "in ()" is not valid SQL, and no partition means no match.
            return []
        r = self.query("""SELECT DISTINCT oid FROM obj WHERE MOD(oid, %d) in (%s)
                    ORDER BY oid DESC LIMIT %d,%d""" \
                % (num_partitions, partitions, offset, length))
        return [util.p64(t[0]) for t in r]
443

444
    def getObjectHistory(self, oid, offset = 0, length = 1):
        """Return a page of the revisions of the object oid, or None.

        oid            -- packed object identifier
        offset, length -- position and size of the page (SQL LIMIT)

        The result is a list of (serial, size) tuples, newest revision
        first, where serial is packed with util.p64() and size is the
        length of the stored value. None is returned when no row matches.
        """
        rows = self.query("""SELECT serial, LENGTH(value) FROM obj WHERE oid = %d
                    ORDER BY serial DESC LIMIT %d, %d"""
                % (util.u64(oid), offset, length))
        if not rows:
            return None
        return [(util.p64(serial), size) for serial, size in rows]
453 454 455

    def getTIDList(self, offset, length, num_partitions, partition_list):
        """Return a page of the committed tids stored in given partitions.

        offset, length  -- position and size of the page (SQL LIMIT)
        num_partitions  -- total number of partitions; a tid belongs to
                           the partition "tid modulo num_partitions"
        partition_list  -- partitions to look into

        The tids are sorted in ascending order and packed with util.p64().
        An empty partition_list yields an empty list.
        """
        partitions = ','.join([str(p) for p in partition_list])
        if not partitions:
            # "in ()" is not valid SQL, and no partition means no match.
            return []
        r = self.query("""SELECT tid FROM trans WHERE MOD(tid, %d) in (%s)
                    ORDER BY tid ASC LIMIT %d,%d""" \
                % (num_partitions, partitions, offset, length))
        return [util.p64(t[0]) for t in r]
462 463 464 465

    def getTIDListPresent(self, tid_list):
        """Return the tids of tid_list which exist in the "trans" table.

        tid_list -- iterable of packed tids

        The result is a list of tids packed with util.p64(). An empty
        tid_list yields an empty list.
        """
        tids = ','.join([str(util.u64(tid)) for tid in tid_list])
        if not tids:
            # "in ()" is not valid SQL, and nothing can match anyway.
            return []
        r = self.query("""SELECT tid FROM trans WHERE tid in (%s)""" % tids)
        return [util.p64(t[0]) for t in r]
468 469 470

    def getSerialListPresent(self, oid, serial_list):
        """Return the serials of serial_list stored for the object oid.

        oid         -- packed object identifier
        serial_list -- iterable of packed serials

        Only the "obj" table is searched. The result is a list of serials
        packed with util.p64(). An empty serial_list yields an empty list.
        """
        serials = ','.join([str(util.u64(serial)) for serial in serial_list])
        if not serials:
            # "in ()" is not valid SQL, and nothing can match anyway.
            return []
        oid = util.u64(oid)
        r = self.query(
            """SELECT serial FROM obj WHERE oid = %d AND serial in (%s)"""
            % (oid, serials))
        return [util.p64(t[0]) for t in r]
475