Commit 8b43f516 authored by Jérome Perrin's avatar Jérome Perrin

upgradeSchema bug fixes

Add some tests and fix bugs:
* https://nexedi.erp5.net/bug_module/20170426-A3962E
* another bug that columns names were not escaped ( in a project we have a custom table with a column named `use` and this breaks `upgradeSchema`  )

/reviewed-on !854
parents dacf12bb 0016855f
...@@ -42,6 +42,7 @@ from Acquisition import aq_base, aq_inner, aq_parent, ImplicitAcquisitionWrapper ...@@ -42,6 +42,7 @@ from Acquisition import aq_base, aq_inner, aq_parent, ImplicitAcquisitionWrapper
from Products.CMFActivity.ActiveObject import ActiveObject from Products.CMFActivity.ActiveObject import ActiveObject
from Products.CMFActivity.ActivityTool import GroupedMessage from Products.CMFActivity.ActivityTool import GroupedMessage
from Products.ERP5Type.TransactionalVariable import getTransactionalVariable from Products.ERP5Type.TransactionalVariable import getTransactionalVariable
from Products.ZMySQLDA.DA import DeferredConnection
from AccessControl.PermissionRole import rolesForPermissionOn from AccessControl.PermissionRole import rolesForPermissionOn
...@@ -1361,20 +1362,45 @@ class CatalogTool (UniqueObject, ZCatalog, CMFCoreCatalogTool, ActiveObject): ...@@ -1361,20 +1362,45 @@ class CatalogTool (UniqueObject, ZCatalog, CMFCoreCatalogTool, ActiveObject):
security.declareProtected(Permissions.ManagePortal, 'upgradeSchema') security.declareProtected(Permissions.ManagePortal, 'upgradeSchema')
def upgradeSchema(self, sql_catalog_id=None, src__=0): def upgradeSchema(self, sql_catalog_id=None, src__=0):
"""Upgrade all catalog tables, with ALTER or CREATE queries""" """Upgrade all catalog tables, with ALTER or CREATE queries"""
portal = self.getPortalObject()
catalog = self.getSQLCatalog(sql_catalog_id) catalog = self.getSQLCatalog(sql_catalog_id)
connection_id = catalog.z_create_catalog.connection_id
src = [] # group methods by connection
db = self.getPortalObject()[connection_id]() method_list_by_connection_id = defaultdict(list)
with db.lock(): for method_id in catalog.sql_clear_catalog:
for clear_method in catalog.sql_clear_catalog: method = catalog[method_id]
r = catalog[clear_method]._upgradeSchema( method_list_by_connection_id[method.connection_id].append(method)
connection_id, create_if_not_exists=1, src__=1)
if r: # Because we cannot select on deferred connections, _upgradeSchema
src.append(r) # cannot be used on SQL methods using a deferred connection.
if not src__: # We try to find a "non deferred" connection using the same connection
for r in src: # string and we'll use it instead.
db.query(r) connection_by_connection_id = {}
return src for connection_id in method_list_by_connection_id:
connection = portal[connection_id]
connection_string = connection.connection_string
connection_by_connection_id[connection_id] = connection
if isinstance(connection, DeferredConnection):
for other_connection in portal.objectValues(
spec=('Z MySQL Database Connection',)):
if connection_string == other_connection.connection_string:
connection_by_connection_id[connection_id] = other_connection
break
queries_by_connection_id = defaultdict(list)
for connection_id, method_list in method_list_by_connection_id.items():
connection = connection_by_connection_id[connection_id]
db = connection()
with db.lock():
for method in method_list:
query = method._upgradeSchema(connection.getId(), create_if_not_exists=1, src__=1)
if query:
queries_by_connection_id[connection_id].append(query)
if not src__:
for query in queries_by_connection_id[connection_id]:
db.query(query)
return sum(queries_by_connection_id.values(), [])
security.declarePublic('getDocumentValueList') security.declarePublic('getDocumentValueList')
def getDocumentValueList(self, sql_catalog_id=None, def getDocumentValueList(self, sql_catalog_id=None,
......
...@@ -34,6 +34,7 @@ import httplib ...@@ -34,6 +34,7 @@ import httplib
from AccessControl import getSecurityManager from AccessControl import getSecurityManager
from AccessControl.SecurityManagement import newSecurityManager from AccessControl.SecurityManagement import newSecurityManager
from DateTime import DateTime from DateTime import DateTime
from _mysql_exceptions import ProgrammingError
from OFS.ObjectManager import ObjectManager from OFS.ObjectManager import ObjectManager
from Products.ERP5Type.tests.ERP5TypeTestCase import ERP5TypeTestCase from Products.ERP5Type.tests.ERP5TypeTestCase import ERP5TypeTestCase
from Products.ERP5Type.tests.utils import LogInterceptor, createZODBPythonScript, todo_erp5, getExtraSqlConnectionStringList from Products.ERP5Type.tests.utils import LogInterceptor, createZODBPythonScript, todo_erp5, getExtraSqlConnectionStringList
...@@ -3830,7 +3831,135 @@ VALUES ...@@ -3830,7 +3831,135 @@ VALUES
# but a proper page # but a proper page
self.assertIn('<title>Catalog Tool - portal_catalog', ret.getBody()) self.assertIn('<title>Catalog Tool - portal_catalog', ret.getBody())
def test_suite():
suite = unittest.TestSuite() class CatalogToolUpgradeSchemaTestCase(ERP5TypeTestCase):
suite.addTest(unittest.makeSuite(TestERP5Catalog)) """Tests for "upgrade schema" feature of ERP5 Catalog.
return suite """
def getBusinessTemplateList(self):
return ("erp5_full_text_mroonga_catalog",)
def afterSetUp(self):
# Add two connections
db1, db2 = getExtraSqlConnectionStringList()[:2]
addConnection = self.portal.manage_addProduct[
"ZMySQLDA"].manage_addZMySQLConnection
addConnection("erp5_test_connection_1", "", db1)
addConnection("erp5_test_connection_2", "", db2)
addConnection("erp5_test_connection_deferred_2", "", db2, deferred=True)
self.catalog_tool = self.portal.portal_catalog
self.catalog = self.catalog_tool.newContent(portal_type="Catalog")
self.catalog.newContent(
portal_type="SQL Method",
connection_id="erp5_test_connection_1",
id="z_create_catalog",
src="CREATE TABLE dummy_catalog (uid int)")
# These will be cleaned up at tear down
self._db1_table_list = ["dummy_catalog"]
self._db2_table_list = []
def beforeTearDown(self):
for table in self._db1_table_list:
self.query_connection_1("DROP TABLE IF EXISTS `%s`" % table)
for table in self._db2_table_list:
self.query_connection_2("DROP TABLE IF EXISTS `%s`" % table)
self.portal.manage_delObjects([
"erp5_test_connection_1",
"erp5_test_connection_2",
"erp5_test_connection_deferred_2"])
self.commit()
def query_connection_1(self, q):
return self.portal.erp5_test_connection_1().query(q)
def query_connection_2(self, q):
return self.portal.erp5_test_connection_2().query(q)
def upgradeSchema(self):
self.assertTrue(
self.catalog_tool.upgradeSchema(
sql_catalog_id=self.catalog.getId(), src__=True))
self.catalog_tool.upgradeSchema(sql_catalog_id=self.catalog.getId())
self.assertFalse(
self.catalog_tool.upgradeSchema(
sql_catalog_id=self.catalog.getId(), src__=True))
def test_upgradeSchema_add_table(self):
self._db1_table_list.append("add_table")
method = self.catalog.newContent(
portal_type="SQL Method",
connection_id="erp5_test_connection_1",
id=self.id(),
src="CREATE TABLE add_table (a int)")
self.catalog.setSqlClearCatalogList([method.getId()])
self.commit()
self.upgradeSchema()
self.commit()
self.query_connection_1("SELECT a from add_table")
def test_upgradeSchema_alter_table(self):
self._db1_table_list.append("altered_table")
self.query_connection_1("CREATE TABLE altered_table (a int)")
self.commit()
method = self.catalog.newContent(
portal_type="SQL Method",
connection_id="erp5_test_connection_1",
id=self.id(),
src="CREATE TABLE altered_table (a int, b int)")
self.catalog.setSqlClearCatalogList([method.getId()])
self.commit()
self.upgradeSchema()
self.commit()
self.query_connection_1("SELECT b from altered_table")
def test_upgradeSchema_multi_connections(self):
# Check that we can upgrade tables on more than one connection,
# like when using an external datawarehouse. This is a reproduction
# for https://nexedi.erp5.net/bug_module/20170426-A3962E
# In this test we use both "normal" and deferred connections,
# which is what happens in default erp5 catalog.
self._db1_table_list.append("table1")
self.query_connection_1("CREATE TABLE table1 (a int)")
self._db2_table_list.extend(("table2", "table_deferred2"))
self.query_connection_2("CREATE TABLE table2 (a int)")
self.query_connection_2("CREATE TABLE table_deferred2 (a int)")
self.commit()
method1 = self.catalog.newContent(
portal_type="SQL Method",
connection_id="erp5_test_connection_1",
src="CREATE TABLE table1 (a int, b int)")
method2 = self.catalog.newContent(
portal_type="SQL Method",
connection_id="erp5_test_connection_2",
src="CREATE TABLE table2 (a int, b int)")
method_deferred2 = self.catalog.newContent(
portal_type="SQL Method",
connection_id="erp5_test_connection_deferred_2",
src="CREATE TABLE table_deferred2 (a int, b int)")
self.catalog.setSqlClearCatalogList(
[method1.getId(),
method2.getId(),
method_deferred2.getId()])
self.commit()
self.upgradeSchema()
self.commit()
self.query_connection_1("SELECT b from table1")
self.query_connection_2("SELECT b from table2")
self.query_connection_2("SELECT b from table_deferred2")
with self.assertRaisesRegexp(ProgrammingError,
r"Table '.*\.table2' doesn't exist"):
self.query_connection_1("SELECT b from table2")
with self.assertRaisesRegexp(ProgrammingError,
r"Table '.*\.table_deferred2' doesn't exist"):
self.query_connection_1("SELECT b from table_deferred2")
with self.assertRaisesRegexp(ProgrammingError,
r"Table '.*\.table1' doesn't exist"):
self.query_connection_2("SELECT b from table1")
...@@ -530,7 +530,7 @@ class DB(TM): ...@@ -530,7 +530,7 @@ class DB(TM):
# already done it (in case that it plans to execute the returned query). # already done it (in case that it plans to execute the returned query).
with (nested if src__ else self.lock)(): with (nested if src__ else self.lock)():
try: try:
old_list, old_set, old_default = self._getTableSchema(name) old_list, old_set, old_default = self._getTableSchema("`%s`" % name)
except ProgrammingError, e: except ProgrammingError, e:
if e[0] != ER.NO_SUCH_TABLE or not create_if_not_exists: if e[0] != ER.NO_SUCH_TABLE or not create_if_not_exists:
raise raise
...@@ -538,7 +538,7 @@ class DB(TM): ...@@ -538,7 +538,7 @@ class DB(TM):
self.query(create_sql) self.query(create_sql)
return create_sql return create_sql
name_new = '_%s_new' % name name_new = '`_%s_new`' % name
self.query('CREATE TEMPORARY TABLE %s %s' self.query('CREATE TEMPORARY TABLE %s %s'
% (name_new, create_sql[m.end():])) % (name_new, create_sql[m.end():]))
try: try:
...@@ -559,7 +559,7 @@ class DB(TM): ...@@ -559,7 +559,7 @@ class DB(TM):
old_dict[column] = pos, spec old_dict[column] = pos, spec
pos += 1 pos += 1
else: else:
q("DROP COLUMN " + column) q("DROP COLUMN `%s`" % column)
for key in old_set - new_set: for key in old_set - new_set:
if "PRIMARY" in key: if "PRIMARY" in key:
...@@ -574,26 +574,26 @@ class DB(TM): ...@@ -574,26 +574,26 @@ class DB(TM):
try: try:
old = old_dict[column] old = old_dict[column]
except KeyError: except KeyError:
q("ADD COLUMN %s %s %s" % (column, spec, where)) q("ADD COLUMN `%s` %s %s" % (column, spec, where))
column_list.append(column) column_list.append(column)
else: else:
if old != (pos, spec): if old != (pos, spec):
q("MODIFY COLUMN %s %s %s" % (column, spec, where)) q("MODIFY COLUMN `%s` %s %s" % (column, spec, where))
if old[1] != spec: if old[1] != spec:
column_list.append(column) column_list.append(column)
pos += 1 pos += 1
where = "AFTER " + column where = "AFTER `%s`" % column
for key in new_set - old_set: for key in new_set - old_set:
q("ADD " + key) q("ADD " + key)
if src: if src:
src = "ALTER TABLE %s%s" % (name, ','.join("\n " + q src = "ALTER TABLE `%s`%s" % (name, ','.join("\n " + q
for q in src)) for q in src))
if not src__: if not src__:
self.query(src) self.query(src)
if column_list and initialize and self.query( if column_list and initialize and self.query(
"SELECT 1 FROM " + name, 1)[1]: "SELECT 1 FROM `%s`" % name, 1)[1]:
initialize(self, column_list) initialize(self, column_list)
return src return src
......
##############################################################################
# coding: utf-8
# Copyright (c) 2019 Nexedi SA and Contributors. All Rights Reserved.
# Jérome Perrin <jerome@nexedi.com>
#
# WARNING: This program as such is intended to be used by professional
# programmers who take the whole responsability of assessing all potential
# consequences resulting from its eventual inadequacies and bugs
# End users who are looking for a ready-to-use solution with commercial
# garantees and support are strongly adviced to contract a Free Software
# Service Company
#
# This program is Free Software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
#
##############################################################################
from textwrap import dedent
from _mysql_exceptions import OperationalError
from Shared.DC.ZRDB.DA import DA
from Products.ERP5Type.tests.ERP5TypeTestCase import ERP5TypeTestCase
class TestTableStructureMigrationTestCase(ERP5TypeTestCase):
def getBusinessTemplateList(self):
return 'erp5_full_text_mroonga_catalog',
def beforeTearDown(self):
self.portal.erp5_sql_connection().query('DROP table if exists X')
self.portal.erp5_sql_connection().query('DROP table if exists `table`')
self.commit()
def query(self, q):
return self.portal.erp5_sql_connection().query(q)
def check_upgrade_schema(self, previous_schema, new_schema, table_name='X'):
self.query(previous_schema)
da = DA(
id=self.id(),
title=self.id(),
connection_id=self.portal.erp5_sql_connection.getId(),
arguments=(),
template=new_schema).__of__(self.portal)
self.assertTrue(da._upgradeSchema(src__=True))
da._upgradeSchema()
self.assertFalse(da._upgradeSchema(src__=True))
self.assertEqual(
new_schema,
self.query('SHOW CREATE TABLE `%s`' % table_name)[1][0][1])
def test_add_column(self):
self.check_upgrade_schema(
dedent(
"""\
CREATE TABLE `X` (
`a` int(11) DEFAULT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci"""),
dedent(
"""\
CREATE TABLE `X` (
`a` int(11) DEFAULT NULL,
`b` int(11) DEFAULT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci"""))
self.query("SELECT a, b FROM X")
def test_remove_column(self):
self.check_upgrade_schema(
dedent(
"""\
CREATE TABLE `X` (
`a` int(11) DEFAULT NULL,
`b` int(11) DEFAULT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci"""),
dedent(
"""\
CREATE TABLE `X` (
`b` int(11) DEFAULT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci"""))
self.query("SELECT b FROM X")
with self.assertRaisesRegexp(OperationalError,
"Unknown column 'a' in 'field list'"):
self.query("SELECT a FROM X")
def test_rename_column(self):
self.check_upgrade_schema(
dedent(
"""\
CREATE TABLE `X` (
`a` int(11) DEFAULT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci"""),
dedent(
"""\
CREATE TABLE `X` (
`b` int(11) DEFAULT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci"""))
self.query("SELECT b FROM X")
with self.assertRaisesRegexp(OperationalError,
"Unknown column 'a' in 'field list'"):
self.query("SELECT a FROM X")
def test_change_column_type(self):
self.check_upgrade_schema(
dedent(
"""\
CREATE TABLE `X` (
`a` int(11) DEFAULT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci"""),
dedent(
"""\
CREATE TABLE `X` (
`a` varchar(10) COLLATE utf8_unicode_ci DEFAULT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci"""))
# insterting 1 will be casted as string
self.query("INSERT INTO X VALUES (1)")
self.assertEqual(('1',), self.query("SELECT a FROM X")[1][0])
def test_change_column_default(self):
self.check_upgrade_schema(
dedent(
"""\
CREATE TABLE `X` (
`a` int(11) DEFAULT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci"""),
dedent(
"""\
CREATE TABLE `X` (
`a` int(11) DEFAULT 123
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci"""))
self.query("INSERT INTO X VALUES ()")
self.assertEqual((123,), self.query("SELECT a FROM X")[1][0])
def test_add_index(self):
self.check_upgrade_schema(
dedent(
"""\
CREATE TABLE `X` (
`a` int(11) DEFAULT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci"""),
dedent(
"""\
CREATE TABLE `X` (
`a` int(11) DEFAULT NULL,
KEY `idx_a` (`a`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci"""))
self.query("SELECT * FROM X USE INDEX (`idx_a`)")
def test_remove_index(self):
self.check_upgrade_schema(
dedent(
"""\
CREATE TABLE `X` (
`a` int(11) DEFAULT NULL,
KEY `idx_a` (`a`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci"""),
dedent(
"""\
CREATE TABLE `X` (
`a` int(11) DEFAULT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci"""))
with self.assertRaisesRegexp(OperationalError,
"Key 'idx_a' doesn't exist in table 'X'"):
self.query("SELECT * FROM X USE INDEX (`idx_a`)")
def test_escape(self):
self.check_upgrade_schema(
dedent(
"""\
CREATE TABLE `table` (
`drop` int(11) DEFAULT NULL,
`alter` int(11) DEFAULT NULL,
KEY `CASE` (`drop`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci"""),
dedent(
"""\
CREATE TABLE `table` (
`and` int(11) DEFAULT NULL,
`alter` varchar(255) COLLATE utf8_unicode_ci DEFAULT 'BETWEEN',
KEY `use` (`alter`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci"""),
table_name='table')
self.query(
"SELECT `alter`, `and` FROM `table` USE INDEX (`use`)")
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment