#
# Copyright (C) 2009-2015  Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import unittest
import MySQLdb
from mock import Mock
from neo.lib.exception import DatabaseFailure
from .testStorageDBTests import StorageDBTests
from neo.storage.database.mysqldb import MySQLDatabaseManager

# Name of the MySQL database used by these tests. getDB() passes this name
# without its last character as `prefix` (together with number=1) to
# prepareDatabase().
NEO_SQL_DATABASE = 'test_mysqldb0'
# MySQL user name. getDB() joins it with the database name as
# '<user>@<database>' to build the argument given to MySQLDatabaseManager.
NEO_SQL_USER = 'test'

class StorageMySQLdbTests(StorageDBTests):
    """MySQL flavour of StorageDBTests, based on MySQLDatabaseManager.

    On top of whatever the base class provides, this class checks a few
    MySQL-specific points of the manager: formatting of query results,
    reaction to OperationalError, string escaping and the limit on the
    size of a query.

    The tests use self.db, which is not assigned in this class; presumably
    the base class creates it through getDB() (to be confirmed there).
    """

    # Storage engine given to MySQLDatabaseManager; None here, overridden by
    # subclasses that want to test a specific engine.
    engine = None

    def getDB(self, reset=0):
        """Return a MySQLDatabaseManager set up on the test database.

        reset -- passed unchanged to the setup() method of the manager.
        """
        self.prepareDatabase(number=1, prefix=NEO_SQL_DATABASE[:-1])
        # db manager, addressed as '<user>@<database>'
        database = '%s@%s' % (NEO_SQL_USER, NEO_SQL_DATABASE)
        db = MySQLDatabaseManager(database, self.engine)
        # the manager must have split the string back into its two parts
        self.assertEqual(db.db, NEO_SQL_DATABASE)
        self.assertEqual(db.user, NEO_SQL_USER)
        db.setup(reset)
        return db

    def test_query1(self):
        # query() must send the statement once to the connection and return
        # the fetched rows, with array values turned into plain strings.
        from array import array
        # fake result object
        result_object = Mock({
            "num_rows": 1,
            "fetch_row": ((1, 2, array('b', (1, 2, ))), ),
        })
        # expected formatted result
        expected_result = (
            (1, 2, '\x01\x02', ),
        )
        self.db.conn = Mock({ 'store_result': result_object })
        result = self.db.query('SELECT ')
        self.assertEqual(result, expected_result)
        calls = self.db.conn.mockGetNamedCalls('query')
        self.assertEqual(len(calls), 1)
        calls[0].checkArgs('SELECT ')

    def test_query2(self):
        # test the OperationalError exception
        # fake object, raise exception during the first call
        from MySQLdb import OperationalError
        from MySQLdb.constants.CR import SERVER_GONE_ERROR
        class FakeConn(object):
            def query(*args):
                raise OperationalError(SERVER_GONE_ERROR, 'this is a test')
        self.db.conn = FakeConn()
        self.connect_called = False
        def connect_hook():
            # mock object, break raise/connect loop
            self.db.conn = Mock()
            self.connect_called = True
        self.db._connect = connect_hook
        # make a query, exception will be raised then connect() will be
        # called and the second query will use the mock object
        self.db.query('INSERT')
        self.assertTrue(self.connect_called)

    def test_query3(self):
        # OperationalError > raise DatabaseFailure exception
        # (here with error code -1, unlike SERVER_GONE_ERROR in test_query2)
        from MySQLdb import OperationalError
        class FakeConn(object):
            def close(self):
                pass
            def query(*args):
                raise OperationalError(-1, 'this is a test')
        self.db.conn = FakeConn()
        self.assertRaises(DatabaseFailure, self.db.query, 'QUERY')

    def test_escape(self):
        # both kinds of quotes must be escaped with a backslash
        self.assertEqual(self.db.escape('a"b'), 'a\\"b')
        self.assertEqual(self.db.escape("a'b"), "a\\'b")

    def test_max_allowed_packet(self):
        # Number of bytes this test counts on top of the text of a query
        # when comparing its size with _max_allowed_packet.
        EXTRA = 2
        # Check EXTRA
        x = "SELECT '%s'" % ('x' * (self.db._max_allowed_packet - 11))
        # Sanity check of the test itself: x is exactly at the limit.
        # assertEqual is used rather than a bare assert so that the check
        # is kept under 'python -O' and reports both values on failure.
        self.assertEqual(len(x) + EXTRA, self.db._max_allowed_packet)
        # one more byte must be rejected, whereas x itself must pass
        self.assertRaises(DatabaseFailure, self.db.query, x + ' ')
        self.db.query(x)
        # Check MySQLDatabaseManager._max_allowed_packet
        # From now on, queries are not executed anymore: only their size
        # (text length + EXTRA) is recorded.
        query_list = []
        self.db.query = lambda query: query_list.append(EXTRA + len(query))
        # escaping a single byte gives at most 2 characters, and '\0' is
        # one of the bytes reaching this maximum
        self.assertEqual(2, max(len(self.db.escape(chr(code)))
                                for code in xrange(256)))
        self.assertEqual(2, len(self.db.escape('\0')))
        self.db.storeData('\0' * 20, '\0' * (2**24-1), 0)
        # storing this data must have issued exactly one query...
        size, = query_list
        # ... whose size is within the last 1024 bytes below the class-level
        # limit
        max_allowed = self.db.__class__._max_allowed_packet
        self.assertTrue(max_allowed - 1024 < size <= max_allowed, size)

class StorageMySQLdbTokuDBTests(StorageMySQLdbTests):
    # Same tests as StorageMySQLdbTests, with "TokuDB" given as the engine
    # argument of MySQLDatabaseManager (see getDB in the parent class).

    engine = "TokuDB"

# Remove the imported base class from the namespace of this module, presumably
# so that a test loader scanning the module does not collect it here as well
# (NOTE(review): confirm against the test runner in use).
del StorageDBTests

# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()