Commit d56f662d authored by Jérome Perrin's avatar Jérome Perrin Committed by Arnaud Fontaine

fixup! py3: _mysql.string_literal() returns bytes() (!1751).

Followup of 94739085.
parent 1fca2fa7
......@@ -464,11 +464,10 @@ class TestERP5Type(PropertySheetTestCase, LogInterceptor):
modified_title = getTitleFromCatalog() + '_not_reindexed'
catalog_connection = self.getSQLConnection()()
catalog_connection.query(
'UPDATE catalog SET title=%s WHERE uid=%i' % (
b'UPDATE catalog SET title=%s WHERE uid=%i' % (
catalog_connection.string_literal(modified_title),
person_object.getUid(),
),
)
))
self.commit()
# sanity check
self.assertEqual(getTitleFromCatalog(), modified_title)
......
......@@ -38,17 +38,17 @@ class TestSQLVar(ERP5TypeTestCase):
connection_id='erp5_sql_connection',
arguments_src='value',
src='<dtml-sqlvar value type="string">')
self.assertEqual(sqlmethod(value='', src__=1), "''")
self.assertEqual(sqlmethod(value=None, src__=1), 'null')
self.assertEqual(sqlmethod(value='', src__=1), b"''")
self.assertEqual(sqlmethod(value=None, src__=1), b'null')
sqlmethod.edit(src='<dtml-sqlvar value type="string" optional>')
self.assertEqual(sqlmethod(value='', src__=1), "''")
self.assertEqual(sqlmethod(value=None, src__=1), 'null')
self.assertEqual(sqlmethod(value='', src__=1), b"''")
self.assertEqual(sqlmethod(value=None, src__=1), b'null')
sqlmethod.edit(src='<dtml-sqlvar value type="nb">')
self.assertRaises(ValueError, sqlmethod, value='', src__=1)
self.assertEqual(sqlmethod(value=None, src__=1), 'null')
self.assertEqual(sqlmethod(value=None, src__=1), b'null')
sqlmethod.edit(src='<dtml-sqlvar value type="nb" optional>')
self.assertEqual(sqlmethod(value='', src__=1), 'null')
self.assertEqual(sqlmethod(value=None, src__=1), 'null')
self.assertEqual(sqlmethod(value='', src__=1), b'null')
self.assertEqual(sqlmethod(value=None, src__=1), b'null')
......@@ -29,7 +29,7 @@ from __future__ import absolute_import
from six import string_types as basestring
from six.moves import xrange
from Products.ERP5Type.Utils import ensure_list, str2bytes
from Products.ERP5Type.Utils import ensure_list, str2bytes, bytes2str
from collections import defaultdict
from contextlib import contextmanager
from itertools import product, chain
......@@ -141,10 +141,10 @@ def sqltest_dict():
if value is None: # XXX: see comment in SQLBase._getMessageList
return column + b" IS NULL"
for x in value:
return b"%s IN (%s)" % (column, str2bytes(', '.join(map(
return str2bytes("%s IN (%s)" % (column, ', '.join(map(
str if isinstance(x, _SQLTEST_NO_QUOTE_TYPE_SET) else
render_datetime if isinstance(x, DateTime) else
render_string, value))))
lambda v: bytes2str(render_string(v)), value))))
return b"0"
sqltest_dict[name] = render
_('active_process_uid')
......@@ -246,7 +246,7 @@ def getNow(db):
Note that this value is not cached, and is not transactionnal on MySQL
side.
"""
return db.query("SELECT UTC_TIMESTAMP(6)", 0)[1][0][0]
return db.query(b"SELECT UTC_TIMESTAMP(6)", 0)[1][0][0]
class SQLBase(Queue):
"""
......@@ -284,7 +284,7 @@ CREATE TABLE %s (
db = activity_tool.getSQLConnection()
create = self.createTableSQL()
if clear:
db.query("DROP TABLE IF EXISTS " + self.sql_table)
db.query(str2bytes("DROP TABLE IF EXISTS " + self.sql_table))
db.query(create)
else:
src = db.upgradeSchema(create, create_if_not_exists=1,
......@@ -398,10 +398,10 @@ CREATE TABLE %s (
for line in result]
def countMessageSQL(self, quote, **kw):
return "SELECT count(*) FROM %s WHERE processing_node > %d AND %s" % (
self.sql_table, DEPENDENCY_IGNORED_ERROR_STATE, " AND ".join(
return b"SELECT count(*) FROM %s WHERE processing_node > %d AND %s" % (
str2bytes(self.sql_table), DEPENDENCY_IGNORED_ERROR_STATE, b" AND ".join(
sqltest_dict[k](v, quote) for (k, v) in six.iteritems(kw) if v
) or "1")
) or b"1")
def hasActivitySQL(self, quote, only_valid=False, only_invalid=False, **kw):
where = [sqltest_dict[k](v, quote) for (k, v) in six.iteritems(kw) if v]
......@@ -426,7 +426,7 @@ CREATE TABLE %s (
0,
)[1]
else:
subquery = (b"("
subquery = lambda *a, **k: str2bytes(bytes2str(b"("
b"SELECT 3*priority{} AS effective_priority, date"
b" FROM %s"
b" WHERE"
......@@ -435,7 +435,7 @@ CREATE TABLE %s (
b" date <= UTC_TIMESTAMP(6)"
b" ORDER BY priority, date"
b" LIMIT 1"
b")" % self.sql_table).format
b")" % str2bytes(self.sql_table)).format(*a, **k))
result = query(
b"SELECT *"
b" FROM (%s) AS t"
......@@ -444,11 +444,11 @@ CREATE TABLE %s (
b" UNION ALL ".join(
chain(
(
subquery(b'-1', b'node = %i' % processing_node),
subquery(b'', b'node=0'),
subquery('-1', 'node = %i' % processing_node),
subquery('', 'node=0'),
),
(
subquery(b'-1', b'node = %i' % x)
subquery('-1', 'node = %i' % x)
for x in node_set
),
),
......@@ -465,7 +465,7 @@ CREATE TABLE %s (
# sorted set to filter negative node values.
# This is why this query is only executed when the previous one
# did not find anything.
result = query(subquery(b'+1', b'node>0'), 0)[1]
result = query(subquery('+1', 'node>0'), 0)[1]
if result:
return result[0]
return Queue.getPriority(self, activity_tool, processing_node, node_set)
......@@ -781,7 +781,7 @@ CREATE TABLE %s (
0,
))
else:
subquery = (b"("
subquery = lambda *a, **k: str2bytes(bytes2str(b"("
b"SELECT *, 3*priority{} AS effective_priority"
b" FROM %s"
b" WHERE"
......@@ -790,8 +790,7 @@ CREATE TABLE %s (
b" %s%s"
b" ORDER BY priority, date"
b" LIMIT %i"
b" FOR UPDATE"
b")" % args).format
b")" % args).format(*a, **k))
result = Results(query(
b"SELECT *"
b" FROM (%s) AS t"
......@@ -800,11 +799,11 @@ CREATE TABLE %s (
b" UNION ALL ".join(
chain(
(
subquery(b'-1', b'node = %i' % processing_node),
subquery(b'', b'node=0'),
subquery('-1', 'node = %i' % processing_node),
subquery('', 'node=0'),
),
(
subquery(b'-1', b'node = %i' % x)
subquery('-1', 'node = %i' % x)
for x in node_set
),
),
......@@ -822,7 +821,7 @@ CREATE TABLE %s (
# sorted set to filter negative node values.
# This is why this query is only executed when the previous one
# did not find anything.
result = Results(query(subquery(b'+1', b'node>0'), 0))
result = Results(query(subquery('+1', 'node>0'), 0))
if result:
# Reserve messages.
uid_list = [x.uid for x in result]
......@@ -835,8 +834,8 @@ CREATE TABLE %s (
"""
Put messages back in given processing_node.
"""
db.query("UPDATE %s SET processing_node=%s WHERE uid IN (%s)\0COMMIT" % (
self.sql_table, state, ','.join(map(str, uid_list))))
db.query(str2bytes("UPDATE %s SET processing_node=%s WHERE uid IN (%s)\0COMMIT" % (
self.sql_table, state, ','.join(map(str, uid_list)))))
def getProcessableMessageLoader(self, db, processing_node):
# do not merge anything
......@@ -1043,16 +1042,16 @@ CREATE TABLE %s (
return bool(message_list)
def deleteMessageList(self, db, uid_list):
db.query("DELETE FROM %s WHERE uid IN (%s)" % (
self.sql_table, ','.join(map(str, uid_list))))
db.query(str2bytes("DELETE FROM %s WHERE uid IN (%s)" % (
self.sql_table, ','.join(map(str, uid_list)))))
def reactivateMessageList(self, db, uid_list, delay, retry):
db.query("UPDATE %s SET"
db.query(str2bytes("UPDATE %s SET"
" date = DATE_ADD(UTC_TIMESTAMP(6), INTERVAL %s SECOND)"
"%s WHERE uid IN (%s)" % (
self.sql_table, delay,
", retry = retry + 1" if retry else "",
",".join(map(str, uid_list))))
",".join(map(str, uid_list)))))
def finalizeMessageExecution(self, activity_tool, message_list,
uid_to_duplicate_uid_list_dict=None):
......@@ -1209,8 +1208,8 @@ CREATE TABLE %s (
To simulate time shift, we simply substract delay from
all dates in message(_queue) table
"""
activity_tool.getSQLConnection().query("UPDATE %s SET"
activity_tool.getSQLConnection().query(str2bytes("UPDATE %s SET"
" date = DATE_SUB(date, INTERVAL %s SECOND)"
% (self.sql_table, delay)
+ ('' if processing_node is None else
"WHERE processing_node=%s" % processing_node))
"WHERE processing_node=%s" % processing_node)))
......@@ -125,7 +125,7 @@ class SQLDict(SQLBase):
b" WHERE processing_node = 0 AND (path = %s OR path LIKE %s)"
b"%s FOR UPDATE" % (
quote(path), quote(path.replace('_', r'\_') + '/%'),
str2bytes(sql_method_id),
sql_method_id,
), 0)[1]
reserve_uid_list = [x for x, in result]
uid_list += reserve_uid_list
......@@ -142,10 +142,10 @@ class SQLDict(SQLBase):
if reserve_uid_list:
self.assignMessageList(db, processing_node, reserve_uid_list)
else:
db.query("COMMIT") # XXX: useful ?
db.query(b"COMMIT") # XXX: useful ?
except:
self._log(WARNING, 'Failed to reserve duplicates')
db.query("ROLLBACK")
db.query(b"ROLLBACK")
raise
if uid_list:
self._log(TRACE, 'Reserved duplicate messages: %r' % uid_list)
......
......@@ -36,6 +36,7 @@ from .SQLBase import (
UID_SAFE_BITSIZE, UID_ALLOCATION_TRY_COUNT,
)
from Products.CMFActivity.ActivityTool import Message
from Products.ERP5Type.Utils import str2bytes
from .SQLDict import SQLDict
from six.moves import xrange
......@@ -77,10 +78,10 @@ CREATE TABLE %s (
return (tuple(m.object_path), m.method_id, m.activity_kw.get('signature'),
m.activity_kw.get('tag'), m.activity_kw.get('group_id'))
_insert_template = ("INSERT INTO %s (uid,"
" path, active_process_uid, date, method_id, processing_node,"
" priority, group_method_id, tag, signature, serialization_tag,"
" message) VALUES\n(%s)")
_insert_template = (b"INSERT INTO %s (uid,"
b" path, active_process_uid, date, method_id, processing_node,"
b" priority, group_method_id, tag, signature, serialization_tag,"
b" message) VALUES\n(%s)")
def prepareQueueMessageList(self, activity_tool, message_list):
db = activity_tool.getSQLConnection()
......@@ -92,9 +93,9 @@ CREATE TABLE %s (
if reset_uid:
reset_uid = False
# Overflow will result into IntegrityError.
db.query("SET @uid := %s" % getrandbits(UID_SAFE_BITSIZE))
db.query(b"SET @uid := %s" % str2bytes(str(getrandbits(UID_SAFE_BITSIZE))))
try:
db.query(self._insert_template % (self.sql_table, values))
db.query(self._insert_template % (str2bytes(self.sql_table), values))
except MySQLdb.IntegrityError as e:
if e.args[0] != DUP_ENTRY:
raise
......@@ -113,14 +114,14 @@ CREATE TABLE %s (
if m.is_registered:
active_process_uid = m.active_process_uid
date = m.activity_kw.get('at_date')
row = ','.join((
'@uid+%s' % i,
row = b','.join((
b'@uid+%s' % str2bytes(str(i)),
quote('/'.join(m.object_path)),
'NULL' if active_process_uid is None else str(active_process_uid),
"UTC_TIMESTAMP(6)" if date is None else quote(render_datetime(date)),
b'NULL' if active_process_uid is None else str2bytes(str(active_process_uid)),
b"UTC_TIMESTAMP(6)" if date is None else quote(render_datetime(date)),
quote(m.method_id),
'-1' if hasDependency(m) else '0',
str(m.activity_kw.get('priority', 1)),
b'-1' if hasDependency(m) else b'0',
str2bytes(str(m.activity_kw.get('priority', 1))),
quote(m.getGroupId()),
quote(m.activity_kw.get('tag', '')),
quote(m.activity_kw.get('signature', '')),
......@@ -156,9 +157,9 @@ CREATE TABLE %s (
m = Message.load(line.message, uid=uid, line=line)
try:
# Select duplicates.
result = db.query("SELECT uid FROM message_job"
" WHERE processing_node = 0 AND path = %s AND signature = %s"
" AND method_id = %s AND group_method_id = %s FOR UPDATE" % (
result = db.query(b"SELECT uid FROM message_job"
b" WHERE processing_node = 0 AND path = %s AND signature = %s"
b" AND method_id = %s AND group_method_id = %s FOR UPDATE" % (
quote(path), quote(line.signature),
quote(method_id), quote(line.group_method_id),
), 0)[1]
......@@ -166,10 +167,10 @@ CREATE TABLE %s (
if uid_list:
self.assignMessageList(db, processing_node, uid_list)
else:
db.query("COMMIT") # XXX: useful ?
db.query(b"COMMIT") # XXX: useful ?
except:
self._log(WARNING, 'Failed to reserve duplicates')
db.query("ROLLBACK")
db.query(b"ROLLBACK")
raise
if uid_list:
self._log(TRACE, 'Reserved duplicate messages: %r' % uid_list)
......
......@@ -1825,7 +1825,7 @@ class ActivityTool (BaseTool):
"""
db = self.getSQLConnection()
quote = db.string_literal
return sum(x for x, in db.query("(%s)" % ") UNION ALL (".join(
return sum(x for x, in db.query(b"(%s)" % b") UNION ALL (".join(
activity.countMessageSQL(quote, **kw)
for activity in six.itervalues(activity_dict)))[1])
......
......@@ -618,7 +618,7 @@ class TestCMFActivity(ERP5TypeTestCase, LogInterceptor):
# Monkey patch Queue to induce conflict errors artificially.
def query(self, query_string,*args, **kw):
# Not so nice, this is specific to zsql method
if "REPLACE INTO" in query_string:
if b"REPLACE INTO" in query_string:
raise OperationalError
return self.original_query(query_string,*args, **kw)
......@@ -1236,7 +1236,7 @@ class TestCMFActivity(ERP5TypeTestCase, LogInterceptor):
# Check that cmf_activity SQL connection still works
connection_da = self.portal.cmf_activity_sql_connection()
self.assertFalse(connection_da._registered)
connection_da.query('select 1')
connection_da.query(b'select 1')
self.assertTrue(connection_da._registered)
self.commit()
self.assertFalse(connection_da._registered)
......@@ -1893,7 +1893,7 @@ class TestCMFActivity(ERP5TypeTestCase, LogInterceptor):
"""
original_query = six.get_unbound_function(DB.query)
def query(self, query_string, *args, **kw):
if query_string.startswith('INSERT'):
if query_string.startswith(b'INSERT'):
insert_list.append(len(query_string))
if not n:
raise Skip
......@@ -2502,7 +2502,7 @@ class TestCMFActivity(ERP5TypeTestCase, LogInterceptor):
self.assertEqual(1, activity_tool.countMessage())
self.flushAllActivities()
sender, recipients, mail = message_list.pop()
self.assertIn('UID mismatch', mail)
self.assertIn(b'UID mismatch', mail)
m, = activity_tool.getMessageList()
self.assertEqual(m.processing_node, INVOKE_ERROR_STATE)
obj.flushActivity()
......
......@@ -102,7 +102,7 @@ class TestDeferredConnection(ERP5TypeTestCase):
Check that a basic query succeeds.
"""
connection = self.getDeferredConnection()
connection.query('REPLACE INTO `full_text` SET `uid`=0, `SearchableText`="dummy test"')
connection.query(b'REPLACE INTO `full_text` SET `uid`=0, `SearchableText`="dummy test"')
try:
self.commit()
except OperationalError:
......@@ -119,7 +119,7 @@ class TestDeferredConnection(ERP5TypeTestCase):
"""
connection = self.getDeferredConnection()
# Queue a query
connection.query('REPLACE INTO `full_text` SET `uid`=0, `SearchableText`="dummy test"')
connection.query(b'REPLACE INTO `full_text` SET `uid`=0, `SearchableText`="dummy test"')
# Replace dynamically the function used to send queries to mysql so it's
# dumber than the implemented one.
self.monkeypatchConnection(connection)
......@@ -144,7 +144,7 @@ class TestDeferredConnection(ERP5TypeTestCase):
"""
connection = self.getDeferredConnection()
# Queue a query
connection.query('REPLACE INTO `full_text` SET `uid`=0, `SearchableText`="dummy test"')
connection.query(b'REPLACE INTO `full_text` SET `uid`=0, `SearchableText`="dummy test"')
# Artificially cause a connection close.
self.monkeypatchConnection(connection)
try:
......@@ -160,10 +160,10 @@ class TestDeferredConnection(ERP5TypeTestCase):
"""
connection = self.getDeferredConnection()
# Queue a query
connection.query('REPLACE INTO `full_text` SET `uid`=0, `SearchableText`="dummy test"')
connection.query(b'REPLACE INTO `full_text` SET `uid`=0, `SearchableText`="dummy test"')
self.assertEqual(len(connection._sql_string_list), 1)
self.commit()
connection.query('REPLACE INTO `full_text` SET `uid`=0, `SearchableText`="dummy test"')
connection.query(b'REPLACE INTO `full_text` SET `uid`=0, `SearchableText`="dummy test"')
self.assertEqual(len(connection._sql_string_list), 1)
if __name__ == '__main__':
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment