Commit 12c71939 authored by Arnaud Fontaine's avatar Arnaud Fontaine

py3: _mysql.string_literal() returns bytes().

And _mysql/mysqldb API (_mysql.connection.query()) converts the query string to
bytes() (additionally, cursor.execute(QUERY, ARGS) calls query() after
converting everything to bytes() too).
parent 1d64f8ed
...@@ -107,14 +107,14 @@ def SQLLock(db, lock_name, timeout): ...@@ -107,14 +107,14 @@ def SQLLock(db, lock_name, timeout):
""" """
lock_name = db.string_literal(lock_name) lock_name = db.string_literal(lock_name)
query = db.query query = db.query
(_, ((acquired, ), )) = query('SELECT GET_LOCK(%s, %f)' % (lock_name, timeout)) (_, ((acquired, ), )) = query(b'SELECT GET_LOCK(%s, %f)' % (lock_name, timeout))
if acquired is None: if acquired is None:
raise ValueError('Error acquiring lock') raise ValueError('Error acquiring lock')
try: try:
yield acquired yield acquired
finally: finally:
if acquired: if acquired:
query('SELECT RELEASE_LOCK(%s)' % (lock_name, )) query(b'SELECT RELEASE_LOCK(%s)' % (lock_name, ))
# sqltest_dict ({'condition_name': <render_function>}) defines how to render # sqltest_dict ({'condition_name': <render_function>}) defines how to render
# condition statements in the SQL query used by SQLBase.getMessageList # condition statements in the SQL query used by SQLBase.getMessageList
def sqltest_dict(): def sqltest_dict():
...@@ -122,23 +122,23 @@ def sqltest_dict(): ...@@ -122,23 +122,23 @@ def sqltest_dict():
def _(name, column=None, op="="): def _(name, column=None, op="="):
if column is None: if column is None:
column = name column = name
column_op = "%s %s " % (column, op) column_op = ("%s %s " % (column, op)).encode()
def render(value, render_string): def render(value, render_string):
if isinstance(value, _SQLTEST_NO_QUOTE_TYPE_SET): if isinstance(value, _SQLTEST_NO_QUOTE_TYPE_SET):
return column_op + str(value) return column_op + str(value).encode()
if isinstance(value, DateTime): if isinstance(value, DateTime):
value = render_datetime(value) value = render_datetime(value)
if isinstance(value, basestring): if isinstance(value, basestring):
return column_op + render_string(value) return column_op + render_string(value)
assert op == "=", value assert op == "=", value
if value is None: # XXX: see comment in SQLBase._getMessageList if value is None: # XXX: see comment in SQLBase._getMessageList
return column + " IS NULL" return column + b" IS NULL"
for x in value: for x in value:
return "%s IN (%s)" % (column, ', '.join(map( return b"%s IN (%s)" % (column, ', '.join(map(
str if isinstance(x, _SQLTEST_NO_QUOTE_TYPE_SET) else str if isinstance(x, _SQLTEST_NO_QUOTE_TYPE_SET) else
render_datetime if isinstance(x, DateTime) else render_datetime if isinstance(x, DateTime) else
render_string, value))) render_string, value)).encode())
return "0" return b"0"
sqltest_dict[name] = render sqltest_dict[name] = render
_('active_process_uid') _('active_process_uid')
_('group_method_id') _('group_method_id')
...@@ -158,13 +158,13 @@ def sqltest_dict(): ...@@ -158,13 +158,13 @@ def sqltest_dict():
assert isinstance(priority, _SQLTEST_NO_QUOTE_TYPE_SET) assert isinstance(priority, _SQLTEST_NO_QUOTE_TYPE_SET)
assert isinstance(uid, _SQLTEST_NO_QUOTE_TYPE_SET) assert isinstance(uid, _SQLTEST_NO_QUOTE_TYPE_SET)
return ( return (
'(priority>%(priority)s OR (priority=%(priority)s AND ' b'(priority>%(priority)d OR (priority=%(priority)d AND '
'(date>%(date)s OR (date=%(date)s AND uid>%(uid)s))' b'(date>%(date)s OR (date=%(date)s AND uid>%(uid)d))'
'))' % { b'))' % {
'priority': priority, b'priority': priority,
# render_datetime raises if "date" lacks date API, so no need to check # render_datetime raises if "date" lacks date API, so no need to check
'date': render_string(render_datetime(date)), b'date': render_string(render_datetime(date)),
'uid': uid, b'uid': uid,
} }
) )
sqltest_dict['above_priority_date_uid'] = renderAbovePriorityDateUid sqltest_dict['above_priority_date_uid'] = renderAbovePriorityDateUid
...@@ -175,7 +175,7 @@ def _validate_after_path_and_method_id(value, render_string): ...@@ -175,7 +175,7 @@ def _validate_after_path_and_method_id(value, render_string):
path, method_id = value path, method_id = value
return ( return (
sqltest_dict['method_id'](method_id, render_string) + sqltest_dict['method_id'](method_id, render_string) +
' AND ' + b' AND ' +
sqltest_dict['path'](path, render_string) sqltest_dict['path'](path, render_string)
) )
...@@ -183,7 +183,7 @@ def _validate_after_tag_and_method_id(value, render_string): ...@@ -183,7 +183,7 @@ def _validate_after_tag_and_method_id(value, render_string):
tag, method_id = value tag, method_id = value
return ( return (
sqltest_dict['method_id'](method_id, render_string) + sqltest_dict['method_id'](method_id, render_string) +
' AND ' + b' AND ' +
sqltest_dict['tag'](tag, render_string) sqltest_dict['tag'](tag, render_string)
) )
...@@ -287,18 +287,18 @@ CREATE TABLE %s ( ...@@ -287,18 +287,18 @@ CREATE TABLE %s (
% (self.sql_table, src)) % (self.sql_table, src))
self._insert_max_payload = (db.getMaxAllowedPacket() self._insert_max_payload = (db.getMaxAllowedPacket()
+ len(self._insert_separator) + len(self._insert_separator)
- len(self._insert_template % (self.sql_table, ''))) - len(self._insert_template % (self.sql_table.encode(), b'')))
def _initialize(self, db, column_list): def _initialize(self, db, column_list):
LOG('CMFActivity', ERROR, "Non-empty %r table upgraded." LOG('CMFActivity', ERROR, "Non-empty %r table upgraded."
" The following added columns could not be initialized: %s" " The following added columns could not be initialized: %s"
% (self.sql_table, ", ".join(column_list))) % (self.sql_table, ", ".join(column_list)))
_insert_template = ("INSERT INTO %s (uid," _insert_template = (b"INSERT INTO %s (uid,"
" path, active_process_uid, date, method_id, processing_node," b" path, active_process_uid, date, method_id, processing_node,"
" priority, node, group_method_id, tag, serialization_tag," b" priority, node, group_method_id, tag, serialization_tag,"
" message) VALUES\n(%s)") b" message) VALUES\n(%s)")
_insert_separator = "),\n(" _insert_separator = b"),\n("
def _hasDependency(self, message): def _hasDependency(self, message):
get = message.activity_kw.get get = message.activity_kw.get
...@@ -317,10 +317,11 @@ CREATE TABLE %s ( ...@@ -317,10 +317,11 @@ CREATE TABLE %s (
if reset_uid: if reset_uid:
reset_uid = False reset_uid = False
# Overflow will result into IntegrityError. # Overflow will result into IntegrityError.
db.query("SET @uid := %s" % getrandbits(UID_SAFE_BITSIZE)) db.query(b"SET @uid := %d" % getrandbits(UID_SAFE_BITSIZE))
try: try:
db.query(self._insert_template % (self.sql_table, values)) db.query(self._insert_template % (self.sql_table.encode(), values))
except MySQLdb.IntegrityError, (code, _): except MySQLdb.IntegrityError as xxx_todo_changeme:
(code, _) = xxx_todo_changeme.args
if code != DUP_ENTRY: if code != DUP_ENTRY:
raise raise
reset_uid = True reset_uid = True
...@@ -338,18 +339,18 @@ CREATE TABLE %s ( ...@@ -338,18 +339,18 @@ CREATE TABLE %s (
if m.is_registered: if m.is_registered:
active_process_uid = m.active_process_uid active_process_uid = m.active_process_uid
date = m.activity_kw.get('at_date') date = m.activity_kw.get('at_date')
row = ','.join(( row = b','.join((
'@uid+%s' % i, b'@uid+%d' % i,
quote('/'.join(m.object_path)), quote('/'.join(m.object_path)),
'NULL' if active_process_uid is None else str(active_process_uid), b'NULL' if active_process_uid is None else str(active_process_uid).encode(),
"UTC_TIMESTAMP(6)" if date is None else quote(render_datetime(date)), b"UTC_TIMESTAMP(6)" if date is None else quote(render_datetime(date)),
quote(m.method_id), quote(m.method_id),
'-1' if hasDependency(m) else '0', b'-1' if hasDependency(m) else b'0',
str(m.activity_kw.get('priority', 1)), str(m.activity_kw.get('priority', 1)).encode(),
str(m.activity_kw.get('node', 0)), str(m.activity_kw.get('node', 0)).encode(),
quote(m.getGroupId()), quote(m.getGroupId()),
quote(m.activity_kw.get('tag', '')), quote(m.activity_kw.get('tag', b'')),
quote(m.activity_kw.get('serialization_tag', '')), quote(m.activity_kw.get('serialization_tag', b'')),
quote(Message.dump(m)))) quote(Message.dump(m))))
i += 1 i += 1
n = sep_len + len(row) n = sep_len + len(row)
...@@ -370,11 +371,11 @@ CREATE TABLE %s ( ...@@ -370,11 +371,11 @@ CREATE TABLE %s (
# value should be ignored, instead of trying to render them # value should be ignored, instead of trying to render them
# (with comparisons with NULL). # (with comparisons with NULL).
q = db.string_literal q = db.string_literal
sql = '\n AND '.join(sqltest_dict[k](v, q) for k, v in six.iteritems(kw)) sql = b'\n AND '.join(sqltest_dict[k](v, q) for k, v in six.iteritems(kw))
sql = "SELECT * FROM %s%s\nORDER BY priority, date, uid%s" % ( sql = b"SELECT * FROM %s%s\nORDER BY priority, date, uid%s" % (
self.sql_table, self.sql_table.encode(),
sql and '\nWHERE ' + sql, sql and b'\nWHERE ' + sql,
'' if count is None else '\nLIMIT %d' % count, b'' if count is None else b'\nLIMIT %d' % count,
) )
return sql if src__ else Results(db.query(sql, max_rows=0)) return sql if src__ else Results(db.query(sql, max_rows=0))
...@@ -399,17 +400,17 @@ CREATE TABLE %s ( ...@@ -399,17 +400,17 @@ CREATE TABLE %s (
def hasActivitySQL(self, quote, only_valid=False, only_invalid=False, **kw): def hasActivitySQL(self, quote, only_valid=False, only_invalid=False, **kw):
where = [sqltest_dict[k](v, quote) for (k, v) in six.iteritems(kw) if v] where = [sqltest_dict[k](v, quote) for (k, v) in six.iteritems(kw) if v]
if only_valid: if only_valid:
where.append('processing_node > %d' % INVOKE_ERROR_STATE) where.append(b'processing_node > %d' % INVOKE_ERROR_STATE)
if only_invalid: if only_invalid:
where.append('processing_node <= %d' % INVOKE_ERROR_STATE) where.append(b'processing_node <= %d' % INVOKE_ERROR_STATE)
return "SELECT 1 FROM %s WHERE %s LIMIT 1" % ( return b"SELECT 1 FROM %s WHERE %s LIMIT 1" % (
self.sql_table, " AND ".join(where) or "1") self.sql_table.encode(), b" AND ".join(where) or b"1")
def getPriority(self, activity_tool, processing_node, node_set=None): def getPriority(self, activity_tool, processing_node, node_set=None):
if node_set is None: if node_set is None:
q = ("SELECT 3*priority, date FROM %s" q = (b"SELECT 3*priority, date FROM %s"
" WHERE processing_node=0 AND date <= UTC_TIMESTAMP(6)" b" WHERE processing_node=0 AND date <= UTC_TIMESTAMP(6)"
" ORDER BY priority, date LIMIT 1" % self.sql_table) b" ORDER BY priority, date LIMIT 1" % self.sql_table.encode())
else: else:
subquery = ("(SELECT 3*priority{} as effective_priority, date FROM %s" subquery = ("(SELECT 3*priority{} as effective_priority, date FROM %s"
" WHERE {} AND processing_node=0 AND date <= UTC_TIMESTAMP(6)" " WHERE {} AND processing_node=0 AND date <= UTC_TIMESTAMP(6)"
...@@ -417,12 +418,12 @@ CREATE TABLE %s ( ...@@ -417,12 +418,12 @@ CREATE TABLE %s (
node = 'node=%s' % processing_node node = 'node=%s' % processing_node
# "ALL" on all but one, to incur deduplication cost only once. # "ALL" on all but one, to incur deduplication cost only once.
# "UNION ALL" between the two naturally distinct sets. # "UNION ALL" between the two naturally distinct sets.
q = ("SELECT * FROM (%s UNION ALL %s UNION %s%s) as t" q = (b"SELECT * FROM (%s UNION ALL %s UNION %s%s) as t"
" ORDER BY effective_priority, date LIMIT 1" % ( b" ORDER BY effective_priority, date LIMIT 1" % (
subquery(-1, node), subquery(-1, node).encode(),
subquery('', 'node=0'), subquery('', 'node=0').encode(),
subquery('+IF(node, IF(%s, -1, 1), 0)' % node, 'node>=0'), subquery('+IF(node, IF(%s, -1, 1), 0)' % node, 'node>=0').encode(),
' UNION ALL ' + subquery(-1, 'node IN (%s)' % ','.join(map(str, node_set))) if node_set else '', b' UNION ALL ' + subquery(-1, 'node IN (%s)' % ','.join(map(str, node_set))).encode() if node_set else b'',
)) ))
result = activity_tool.getSQLConnection().query(q, 0)[1] result = activity_tool.getSQLConnection().query(q, 0)[1]
if result: if result:
...@@ -595,18 +596,18 @@ CREATE TABLE %s ( ...@@ -595,18 +596,18 @@ CREATE TABLE %s (
if len(column_list) == 1 else if len(column_list) == 1 else
_IDENTITY _IDENTITY
) )
base_sql_suffix = ' WHERE processing_node > %i AND (%%s) LIMIT 1)' % ( base_sql_suffix = b' WHERE processing_node > %i AND (%%s) LIMIT 1)' % (
min_processing_node, min_processing_node,
) )
sql_suffix_list = [ sql_suffix_list = [
base_sql_suffix % to_sql(dependency_value, quote) base_sql_suffix % to_sql(dependency_value, quote)
for dependency_value in dependency_value_dict for dependency_value in dependency_value_dict
] ]
base_sql_prefix = '(SELECT %s FROM ' % ( base_sql_prefix = b'(SELECT %s FROM ' % (
','.join(column_list), b','.join([ c.encode() for c in column_list ]),
) )
subquery_list = [ subquery_list = [
base_sql_prefix + table_name + sql_suffix base_sql_prefix + table_name.encode() + sql_suffix
for table_name in table_name_list for table_name in table_name_list
for sql_suffix in sql_suffix_list for sql_suffix in sql_suffix_list
] ]
...@@ -617,7 +618,7 @@ CREATE TABLE %s ( ...@@ -617,7 +618,7 @@ CREATE TABLE %s (
# by the number of activty tables: it is also proportional to the # by the number of activty tables: it is also proportional to the
# number of distinct values being looked for in the current column. # number of distinct values being looked for in the current column.
for row in db.query( for row in db.query(
' UNION '.join(subquery_list[_MAX_DEPENDENCY_UNION_SUBQUERY_COUNT:]), b' UNION '.join(subquery_list[_MAX_DEPENDENCY_UNION_SUBQUERY_COUNT:]),
max_rows=0, max_rows=0,
)[1]: )[1]:
# Each row is a value which blocks some activities. # Each row is a value which blocks some activities.
...@@ -691,9 +692,9 @@ CREATE TABLE %s ( ...@@ -691,9 +692,9 @@ CREATE TABLE %s (
assert limit assert limit
quote = db.string_literal quote = db.string_literal
query = db.query query = db.query
args = (self.sql_table, sqltest_dict['to_date'](date, quote), args = (self.sql_table.encode(), sqltest_dict['to_date'](date, quote),
' AND group_method_id=' + quote(group_method_id) b' AND group_method_id=' + quote(group_method_id)
if group_method_id else '' , limit) if group_method_id else b'' , limit)
# Note: Not all write accesses to our table are protected by this lock. # Note: Not all write accesses to our table are protected by this lock.
# This lock is not here for data consistency reasons, but to avoid wasting # This lock is not here for data consistency reasons, but to avoid wasting
...@@ -722,25 +723,25 @@ CREATE TABLE %s ( ...@@ -722,25 +723,25 @@ CREATE TABLE %s (
# time). # time).
if node_set is None: if node_set is None:
result = Results(query( result = Results(query(
"SELECT * FROM %s WHERE processing_node=0 AND %s%s" b"SELECT * FROM %s WHERE processing_node=0 AND %s%s"
" ORDER BY priority, date LIMIT %s FOR UPDATE" % args, 0)) b" ORDER BY priority, date LIMIT %d FOR UPDATE" % args, 0))
else: else:
# We'd like to write # We'd like to write
# ORDER BY priority, IF(node, IF(node={node}, -1, 1), 0), date # ORDER BY priority, IF(node, IF(node={node}, -1, 1), 0), date
# but this makes indices inefficient. # but this makes indices inefficient.
subquery = ("(SELECT *, 3*priority{} as effective_priority FROM %s" subquery = (b"(SELECT *, 3*priority%%s as effective_priority FROM %s"
" WHERE {} AND processing_node=0 AND %s%s" b" WHERE %%s AND processing_node=0 AND %s%s"
" ORDER BY priority, date LIMIT %s FOR UPDATE)" % args).format b" ORDER BY priority, date LIMIT %d FOR UPDATE)" % args)
node = 'node=%s' % processing_node node = b'node=%d' % processing_node
result = Results(query( result = Results(query(
# "ALL" on all but one, to incur deduplication cost only once. # "ALL" on all but one, to incur deduplication cost only once.
# "UNION ALL" between the two naturally distinct sets. # "UNION ALL" between the two naturally distinct sets.
"SELECT * FROM (%s UNION ALL %s UNION %s%s) as t" b"SELECT * FROM (%s UNION ALL %s UNION %s%s) as t"
" ORDER BY effective_priority, date LIMIT %s"% ( b" ORDER BY effective_priority, date LIMIT %d"% (
subquery(-1, node), subquery % (b'-1', node),
subquery('', 'node=0'), subquery % (b'', b'node=0'),
subquery('+IF(node, IF(%s, -1, 1), 0)' % node, 'node>=0'), subquery % (b'+IF(node, IF(%s, -1, 1), 0)' % node, b'node>=0'),
' UNION ALL ' + subquery(-1, 'node IN (%s)' % ','.join(map(str, node_set))) if node_set else '', b' UNION ALL ' + subquery % (str(-1), b'node IN (%s)' % b','.join(map(str, node_set)).encode()) if node_set else b'',
limit), 0)) limit), 0))
if result: if result:
# Reserve messages. # Reserve messages.
...@@ -803,9 +804,9 @@ CREATE TABLE %s ( ...@@ -803,9 +804,9 @@ CREATE TABLE %s (
# To minimize the probability of deadlocks, we also COMMIT so that a # To minimize the probability of deadlocks, we also COMMIT so that a
# new transaction starts on the first 'FOR UPDATE' query, which is all # new transaction starts on the first 'FOR UPDATE' query, which is all
# the more important as the current on started with getPriority(). # the more important as the current on started with getPriority().
result = db.query("SELECT * FROM %s WHERE processing_node=%s" result = db.query(b"SELECT * FROM %s WHERE processing_node=%d"
" ORDER BY priority, date LIMIT 1\0COMMIT" % ( b" ORDER BY priority, date LIMIT 1\0COMMIT" % (
self.sql_table, processing_node), 0) self.sql_table.encode(), processing_node), 0)
already_assigned = result[1] already_assigned = result[1]
if already_assigned: if already_assigned:
result = Results(result) result = Results(result)
...@@ -834,10 +835,10 @@ CREATE TABLE %s ( ...@@ -834,10 +835,10 @@ CREATE TABLE %s (
cost *= count cost *= count
# Retrieve objects which have the same group method. # Retrieve objects which have the same group method.
result = iter(already_assigned result = iter(already_assigned
and Results(db.query("SELECT * FROM %s" and Results(db.query(b"SELECT * FROM %s"
" WHERE processing_node=%s AND group_method_id=%s" b" WHERE processing_node=%d AND group_method_id=%s"
" ORDER BY priority, date LIMIT %s" % ( b" ORDER BY priority, date LIMIT %d" % (
self.sql_table, processing_node, self.sql_table.encode(), processing_node,
db.string_literal(group_method_id), limit), 0)) db.string_literal(group_method_id), limit), 0))
# Do not optimize rare case: keep the code simple by not # Do not optimize rare case: keep the code simple by not
# adding more results from getReservedMessageList if the # adding more results from getReservedMessageList if the
......
...@@ -85,7 +85,7 @@ class SQLDict(SQLBase): ...@@ -85,7 +85,7 @@ class SQLDict(SQLBase):
uid = line.uid uid = line.uid
original_uid = path_and_method_id_dict.get(key) original_uid = path_and_method_id_dict.get(key)
if original_uid is None: if original_uid is None:
sql_method_id = " AND method_id = %s AND group_method_id = %s" % ( sql_method_id = b" AND method_id = %s AND group_method_id = %s" % (
quote(method_id), quote(line.group_method_id)) quote(method_id), quote(line.group_method_id))
m = Message.load(line.message, uid=uid, line=line) m = Message.load(line.message, uid=uid, line=line)
merge_parent = m.activity_kw.get('merge_parent') merge_parent = m.activity_kw.get('merge_parent')
...@@ -102,11 +102,11 @@ class SQLDict(SQLBase): ...@@ -102,11 +102,11 @@ class SQLDict(SQLBase):
uid_list = [] uid_list = []
if path_list: if path_list:
# Select parent messages. # Select parent messages.
result = Results(db.query("SELECT * FROM message" result = Results(db.query(b"SELECT * FROM message"
" WHERE processing_node IN (0, %s) AND path IN (%s)%s" b" WHERE processing_node IN (0, %d) AND path IN (%s)%s"
" ORDER BY path LIMIT 1 FOR UPDATE" % ( b" ORDER BY path LIMIT 1 FOR UPDATE" % (
processing_node, processing_node,
','.join(map(quote, path_list)), b','.join(map(quote, path_list)),
sql_method_id, sql_method_id,
), 0)) ), 0))
if result: # found a parent if result: # found a parent
...@@ -119,11 +119,11 @@ class SQLDict(SQLBase): ...@@ -119,11 +119,11 @@ class SQLDict(SQLBase):
m = Message.load(line.message, uid=uid, line=line) m = Message.load(line.message, uid=uid, line=line)
# return unreserved similar children # return unreserved similar children
path = line.path path = line.path
result = db.query("SELECT uid FROM message" result = db.query(b"SELECT uid FROM message"
" WHERE processing_node = 0 AND (path = %s OR path LIKE %s)" b" WHERE processing_node = 0 AND (path = %s OR path LIKE %s)"
"%s FOR UPDATE" % ( b"%s FOR UPDATE" % (
quote(path), quote(path.replace('_', r'\_') + '/%'), quote(path), quote(path.replace('_', r'\_') + '/%'),
sql_method_id, sql_method_id.encode(),
), 0)[1] ), 0)[1]
reserve_uid_list = [x for x, in result] reserve_uid_list = [x for x, in result]
uid_list += reserve_uid_list uid_list += reserve_uid_list
...@@ -132,8 +132,8 @@ class SQLDict(SQLBase): ...@@ -132,8 +132,8 @@ class SQLDict(SQLBase):
reserve_uid_list.append(uid) reserve_uid_list.append(uid)
else: else:
# Select duplicates. # Select duplicates.
result = db.query("SELECT uid FROM message" result = db.query(b"SELECT uid FROM message"
" WHERE processing_node = 0 AND path = %s%s FOR UPDATE" % ( b" WHERE processing_node = 0 AND path = %s%s FOR UPDATE" % (
quote(path), sql_method_id, quote(path), sql_method_id,
), 0)[1] ), 0)[1]
reserve_uid_list = uid_list = [x for x, in result] reserve_uid_list = uid_list = [x for x, in result]
......
...@@ -1413,7 +1413,7 @@ class ActivityTool (BaseTool): ...@@ -1413,7 +1413,7 @@ class ActivityTool (BaseTool):
path = None if obj is None else '/'.join(obj.getPhysicalPath()) path = None if obj is None else '/'.join(obj.getPhysicalPath())
db = self.getSQLConnection() db = self.getSQLConnection()
quote = db.string_literal quote = db.string_literal
return bool(db.query("(%s)" % ") UNION ALL (".join( return bool(db.query(b"(%s)" % b") UNION ALL (".join(
activity.hasActivitySQL(quote, path=path, **kw) activity.hasActivitySQL(quote, path=path, **kw)
for activity in six.itervalues(activity_dict)))[1]) for activity in six.itervalues(activity_dict)))[1])
......
...@@ -111,6 +111,7 @@ from Shared.DC.ZRDB.TM import TM ...@@ -111,6 +111,7 @@ from Shared.DC.ZRDB.TM import TM
from DateTime import DateTime from DateTime import DateTime
from zLOG import LOG, ERROR, WARNING from zLOG import LOG, ERROR, WARNING
from ZODB.POSException import ConflictError from ZODB.POSException import ConflictError
from Products.ERP5Type.Utils import str2bytes
hosed_connection = ( hosed_connection = (
CR.SERVER_GONE_ERROR, CR.SERVER_GONE_ERROR,
...@@ -203,7 +204,7 @@ def ord_or_None(s): ...@@ -203,7 +204,7 @@ def ord_or_None(s):
return ord(s) return ord(s)
match_select = re.compile( match_select = re.compile(
r'(?:SET\s+STATEMENT\s+(.+?)\s+FOR\s+)?SELECT\s+(.+)', rb'(?:SET\s+STATEMENT\s+(.+?)\s+FOR\s+)?SELECT\s+(.+)',
re.IGNORECASE | re.DOTALL, re.IGNORECASE | re.DOTALL,
).match ).match
...@@ -418,12 +419,14 @@ class DB(TM): ...@@ -418,12 +419,14 @@ class DB(TM):
"""Execute 'query_string' and return at most 'max_rows'.""" """Execute 'query_string' and return at most 'max_rows'."""
self._use_TM and self._register() self._use_TM and self._register()
desc = None desc = None
if not isinstance(query_string, bytes):
query_string = str2bytes(query_string)
# XXX deal with a typical mistake that the user appends # XXX deal with a typical mistake that the user appends
# an unnecessary and rather harmful semicolon at the end. # an unnecessary and rather harmful semicolon at the end.
# Unfortunately, MySQLdb does not want to be graceful. # Unfortunately, MySQLdb does not want to be graceful.
if query_string[-1:] == ';': if query_string[-1:] == b';':
query_string = query_string[:-1] query_string = query_string[:-1]
for qs in query_string.split('\0'): for qs in query_string.split(b'\0'):
qs = qs.strip() qs = qs.strip()
if qs: if qs:
select_match = match_select(qs) select_match = match_select(qs)
...@@ -432,12 +435,12 @@ class DB(TM): ...@@ -432,12 +435,12 @@ class DB(TM):
if query_timeout is not None: if query_timeout is not None:
statement, select = select_match.groups() statement, select = select_match.groups()
if statement: if statement:
statement += ", max_statement_time=%f" % query_timeout statement += b", max_statement_time=%f" % query_timeout
else: else:
statement = "max_statement_time=%f" % query_timeout statement = b"max_statement_time=%f" % query_timeout
qs = "SET STATEMENT %s FOR SELECT %s" % (statement, select) qs = b"SET STATEMENT %s FOR SELECT %s" % (statement, select)
if max_rows: if max_rows:
qs = "%s LIMIT %d" % (qs, max_rows) qs = b"%s LIMIT %d" % (qs, max_rows)
c = self._query(qs) c = self._query(qs)
if c: if c:
if desc is not None is not c.describe(): if desc is not None is not c.describe():
...@@ -640,7 +643,7 @@ class DeferredDB(DB): ...@@ -640,7 +643,7 @@ class DeferredDB(DB):
def query(self, query_string, max_rows=1000): def query(self, query_string, max_rows=1000):
self._register() self._register()
for qs in query_string.split('\0'): for qs in query_string.split(b'\0'):
qs = qs.strip() qs = qs.strip()
if qs: if qs:
if match_select(qs): if match_select(qs):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment