Commit 18cd0bf9 authored by Julien Muchembled, committed by Cédric Le Ninivin

CMFActivity: limit insertion by size in bytes instead of number of rows

This fixes the issue that a transaction with many big messages failed to
commit. By dynamically finding the maximum allowed size of a query, it also
speeds up insertion by minimizing the number of queries.
parent 94a4c950
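
The idea of the change, as a minimal standalone sketch (this is not code from the commit; the template, separator, byte budget and `execute` callback are placeholders): rows are accumulated until appending the next one would push the generated INSERT past the budget, at which point the current batch is flushed and a new batch is started.

# Illustrative sketch only -- not the commit's code.  Batch VALUES rows so that
# every generated INSERT stays within a byte budget such as MySQL's
# max_allowed_packet.  `execute` stands for whatever runs the SQL statement.
def batched_inserts(template, separator, rows, max_packet, execute):
  # The statement skeleton and one separator are counted once up front;
  # join() re-adds one separator per extra row, so charging
  # len(separator) + len(row) per row keeps every statement <= max_packet.
  max_payload = max_packet + len(separator) - len(template % '')
  budget = max_payload
  batch = []
  for row in rows:
    cost = len(separator) + len(row)
    budget -= cost
    if budget < 0:
      if not batch:
        raise ValueError("a single row exceeds the packet size")
      execute(template % separator.join(batch))
      del batch[:]
      budget = max_payload - cost
    batch.append(row)
  if batch:
    execute(template % separator.join(batch))

# Example with made-up rows and a tiny 60-byte budget: the first two rows fit
# in one INSERT, the third starts a second one.
statements = []
batched_inserts("INSERT INTO t (i, s) VALUES\n(%s)", "),\n(",
                ["1,'aa'", "2,'bbbb'", "3,'cccccc'"], 60, statements.append)
assert len(statements) == 2
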
@@ -46,8 +46,6 @@ from Products.CMFActivity.Errors import ActivityFlushError
 MAX_VALIDATED_LIMIT = 1000
 # Read this many messages to validate.
 READ_MESSAGE_LIMIT = 1000
-# TODO: Limit by size in bytes instead of number of rows.
-MAX_MESSAGE_LIST_SIZE = 100
 INVOKE_ERROR_STATE = -2
 # Activity uids are stored as 64 bits unsigned integers.
 # No need to depend on a database that supports unsigned integers.
@@ -163,17 +161,26 @@ CREATE TABLE %s (
     if src:
       LOG('CMFActivity', INFO, "%r table upgraded\n%s"
           % (self.sql_table, src))
+    self._insert_max_payload = (db.getMaxAllowedPacket()
+      + len(self._insert_separator)
+      - len(self._insert_template % (self.sql_table, '')))

   def _initialize(self, db, column_list):
     LOG('CMFActivity', ERROR, "Non-empty %r table upgraded."
         " The following added columns could not be initialized: %s"
         % (self.sql_table, ", ".join(column_list)))

+  _insert_template = ("INSERT INTO %s (uid,"
+    " path, active_process_uid, date, method_id, processing_node,"
+    " priority, group_method_id, tag, serialization_tag,"
+    " message) VALUES\n(%s)")
+  _insert_separator = "),\n("
+
   def prepareQueueMessageList(self, activity_tool, message_list):
     db = activity_tool.getSQLConnection()
     quote = db.string_literal
     def insert(reset_uid):
-      values = "),\n(".join(values_list)
+      values = self._insert_separator.join(values_list)
       del values_list[:]
       for _ in xrange(UID_ALLOCATION_TRY_COUNT):
         if reset_uid:
@@ -181,10 +188,7 @@ CREATE TABLE %s (
           # Overflow will result into IntegrityError.
           db.query("SET @uid := %s" % getrandbits(UID_SAFE_BITSIZE))
         try:
-          db.query("INSERT INTO %s (uid,"
-            " path, active_process_uid, date, method_id, processing_node,"
-            " priority, group_method_id, tag, serialization_tag,"
-            " message) VALUES\n(%s)" % (self.sql_table, values))
+          db.query(self._insert_template % (self.sql_table, values))
         except MySQLdb.IntegrityError, (code, _):
           if code != DUP_ENTRY:
             raise
@@ -196,13 +200,15 @@ CREATE TABLE %s (
     i = 0
     reset_uid = True
     values_list = []
+    max_payload = self._insert_max_payload
+    sep_len = len(self._insert_separator)
     for m in message_list:
       if m.is_registered:
         active_process_uid = m.active_process_uid
         order_validation_text = m.order_validation_text = \
           self.getOrderValidationText(m)
         date = m.activity_kw.get('at_date')
-        values_list.append(','.join((
+        row = ','.join((
           '@uid+%s' % i,
           quote('/'.join(m.object_path)),
           'NULL' if active_process_uid is None else str(active_process_uid),
@@ -213,11 +219,18 @@ CREATE TABLE %s (
           quote(m.getGroupId()),
           quote(m.activity_kw.get('tag', '')),
           quote(m.activity_kw.get('serialization_tag', '')),
-          quote(Message.dump(m)))))
+          quote(Message.dump(m))))
         i += 1
-        if not i % MAX_MESSAGE_LIST_SIZE:
-          insert(reset_uid)
-          reset_uid = False
+        n = sep_len + len(row)
+        max_payload -= n
+        if max_payload < 0:
+          if values_list:
+            insert(reset_uid)
+            reset_uid = False
+            max_payload = self._insert_max_payload - n
+          else:
+            raise ValueError("max_allowed_packet too small to insert message")
+        values_list.append(row)
     if values_list:
       insert(reset_uid)
...
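
To make the `_insert_max_payload` arithmetic above concrete: an INSERT built from k rows has length len(skeleton) + sum(row lengths) + (k - 1) * len(separator), where the skeleton is `self._insert_template % (self.sql_table, '')`. A tiny self-contained check with made-up values (the table name and rows below are not from the commit; this template bakes the table name in and keeps a single placeholder):

# Made-up values, only to verify the length accounting described above.
template = "INSERT INTO message (uid, message) VALUES\n(%s)"
separator = "),\n("
rows = ["@uid+0," + "'" + "x" * 50 + "'",
        "@uid+1," + "'" + "y" * 60 + "'"]
statement = template % separator.join(rows)
assert len(statement) == (len(template % '') + sum(len(r) for r in rows)
                          + (len(rows) - 1) * len(separator))
# Hence starting the budget at
#   max_allowed_packet + len(separator) - len(template % '')
# and subtracting len(separator) + len(row) per row keeps the statement
# within max_allowed_packet for as long as the budget stays >= 0.
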
@@ -31,7 +31,7 @@ from zLOG import LOG, TRACE, INFO, WARNING, ERROR, PANIC
 import MySQLdb
 from MySQLdb.constants.ER import DUP_ENTRY
 from SQLBase import (
-  SQLBase, sort_message_key, MAX_MESSAGE_LIST_SIZE,
+  SQLBase, sort_message_key,
   UID_SAFE_BITSIZE, UID_ALLOCATION_TRY_COUNT,
 )
 from Products.CMFActivity.ActivityTool import Message
@@ -75,11 +75,16 @@ CREATE TABLE %s (
     return (tuple(m.object_path), m.method_id, m.activity_kw.get('signature'),
             m.activity_kw.get('tag'), m.activity_kw.get('group_id'))

+  _insert_template = ("INSERT INTO %s (uid,"
+    " path, active_process_uid, date, method_id, processing_node,"
+    " priority, group_method_id, tag, signature, serialization_tag,"
+    " message) VALUES\n(%s)")
+
   def prepareQueueMessageList(self, activity_tool, message_list):
     db = activity_tool.getSQLConnection()
     quote = db.string_literal
     def insert(reset_uid):
-      values = "),\n(".join(values_list)
+      values = self._insert_separator.join(values_list)
       del values_list[:]
       for _ in xrange(UID_ALLOCATION_TRY_COUNT):
         if reset_uid:
@@ -87,10 +92,7 @@ CREATE TABLE %s (
           # Overflow will result into IntegrityError.
           db.query("SET @uid := %s" % getrandbits(UID_SAFE_BITSIZE))
         try:
-          db.query("INSERT INTO %s (uid,"
-            " path, active_process_uid, date, method_id, processing_node,"
-            " priority, group_method_id, tag, signature, serialization_tag,"
-            " message) VALUES\n(%s)" % (self.sql_table, values))
+          db.query(self._insert_template % (self.sql_table, values))
         except MySQLdb.IntegrityError, (code, _):
           if code != DUP_ENTRY:
             raise
@@ -102,17 +104,19 @@ CREATE TABLE %s (
     i = 0
     reset_uid = True
     values_list = []
+    max_payload = self._insert_max_payload
+    sep_len = len(self._insert_separator)
     for m in message_list:
       if m.is_registered:
         active_process_uid = m.active_process_uid
         order_validation_text = m.order_validation_text = \
           self.getOrderValidationText(m)
         date = m.activity_kw.get('at_date')
-        values_list.append(','.join((
+        row = ','.join((
           '@uid+%s' % i,
           quote('/'.join(m.object_path)),
           'NULL' if active_process_uid is None else str(active_process_uid),
-          "UTC_TIMESTAMP(6)" if date is None else render_datetime(date),
+          "UTC_TIMESTAMP(6)" if date is None else quote(render_datetime(date)),
           quote(m.method_id),
           '0' if order_validation_text == 'none' else '-1',
           str(m.activity_kw.get('priority', 1)),
@@ -120,11 +124,18 @@ CREATE TABLE %s (
           quote(m.activity_kw.get('tag', '')),
           quote(m.activity_kw.get('signature', '')),
           quote(m.activity_kw.get('serialization_tag', '')),
-          quote(Message.dump(m)))))
+          quote(Message.dump(m))))
         i += 1
-        if not i % MAX_MESSAGE_LIST_SIZE:
-          insert(reset_uid)
-          reset_uid = False
+        n = sep_len + len(row)
+        max_payload -= n
+        if max_payload < 0:
+          if values_list:
+            insert(reset_uid)
+            reset_uid = False
+            max_payload = self._insert_max_payload - n
+          else:
+            raise ValueError("max_allowed_packet too small to insert message")
+        values_list.append(row)
     if values_list:
       insert(reset_uid)
...
@@ -2048,29 +2048,61 @@ class TestCMFActivity(ERP5TypeTestCase, LogInterceptor):
     DB.query = DB.original_query
     del DB.original_query

-  def test_MAX_MESSAGE_LIST_SIZE(self):
-    from Products.CMFActivity.Activity import SQLBase
-    MAX_MESSAGE_LIST_SIZE = SQLBase.MAX_MESSAGE_LIST_SIZE
-    try:
-      SQLBase.MAX_MESSAGE_LIST_SIZE = 3
-      def dummy_counter(o):
-        self.__call_count += 1
-      o = self.portal.organisation_module.newContent(portal_type='Organisation')
-
-      for activity in "SQLDict", "SQLQueue", "SQLJoblib":
-        self.__call_count = 0
-        try:
-          for i in xrange(10):
-            method_name = 'dummy_counter_%s' % i
-            getattr(o.activate(activity=activity), method_name)()
-            setattr(Organisation, method_name, dummy_counter)
-          self.flushAllActivities()
-        finally:
-          for i in xrange(10):
-            delattr(Organisation, 'dummy_counter_%s' % i)
-        self.assertEqual(self.__call_count, 10)
-    finally:
-      SQLBase.MAX_MESSAGE_LIST_SIZE = MAX_MESSAGE_LIST_SIZE
+  def test_insert_max_payload(self):
+    activity_tool = self.portal.portal_activities
+    max_allowed_packet = activity_tool.getSQLConnection().getMaxAllowedPacket()
+    insert_list = []
+    invoke_list = []
+    N = 100
+    class Skip(Exception):
+      """
+      Speed up test by not interrupting the first transaction
+      as soon as we have the information we want.
+      """
+    original_query = DB.query.__func__
+    def query(self, query_string, *args, **kw):
+      if query_string.startswith('INSERT'):
+        insert_list.append(len(query_string))
+        if not n:
+          raise Skip
+      return original_query(self, query_string, *args, **kw)
+    def check():
+      for i in xrange(1, N):
+        activity_tool.activate(activity=activity, group_id=str(i)
+          ).doSomething(arg)
+      activity_tool.activate(activity=activity, group_id='~'
+        ).doSomething(' ' * n)
+      self.tic()
+      self.assertEqual(len(invoke_list), N)
+      invoke_list.remove(n)
+      self.assertEqual(set(invoke_list), {len(arg)})
+      del invoke_list[:]
+    activity_tool.__class__.doSomething = \
+      lambda self, arg: invoke_list.append(len(arg))
+    try:
+      DB.query = query
+      for activity in ActivityTool.activity_dict:
+        arg = ' ' * (max_allowed_packet // N)
+        # Find the size of the last message argument, such that all messages
+        # are inserted in a single query whose size is to the maximum allowed.
+        n = 0
+        self.assertRaises(Skip, check)
+        self.abort()
+        n = max_allowed_packet - insert_list.pop()
+        self.assertFalse(insert_list)
+        # Now check with the biggest insert query possible.
+        check()
+        self.assertEqual(max_allowed_packet, insert_list.pop())
+        self.assertFalse(insert_list)
+        # And check that the insert query is split
+        # in order not to exceed max_allowed_packet.
+        n += 1
+        check()
+        self.assertEqual(len(insert_list), 2)
+        del insert_list[:]
+    finally:
+      del activity_tool.__class__.doSomething
+      DB.query = original_query

   def test_115_TestSerializationTagSQLDictPreventsParallelExecution(self):
     """
@@ -2341,38 +2373,6 @@ class TestCMFActivity(ERP5TypeTestCase, LogInterceptor):
   def test_126_userNotificationSavedOnEventLogWhenSiteErrorLoggerRaisesWithSQLQueue(self):
     self.TryNotificationSavedOnEventLogWhenSiteErrorLoggerRaises('SQLQueue')

-  def test_127_checkConflictErrorAndNoRemainingActivities(self):
-    """
-    When an activity creates several activities, make sure that all newly
-    created activities are not commited if there is ZODB Conflict error
-    """
-    from Products.CMFActivity.Activity import SQLBase
-    MAX_MESSAGE_LIST_SIZE = SQLBase.MAX_MESSAGE_LIST_SIZE
-    try:
-      SQLBase.MAX_MESSAGE_LIST_SIZE = 1
-      activity_tool = self.portal.portal_activities
-      def doSomething(self):
-        self.serialize()
-        self.activate(activity='SQLQueue').getId()
-        self.activate(activity='SQLQueue').getTitle()
-        conn = self._p_jar
-        tid = self._p_serial
-        oid = self._p_oid
-        try:
-          conn.db().invalidate({oid: tid})
-        except TypeError:
-          conn.db().invalidate(tid, {oid: tid})
-      activity_tool.__class__.doSomething = doSomething
-      activity_tool.activate(activity='SQLQueue').doSomething()
-      self.commit()
-      activity_tool.tic()
-      message_list = activity_tool.getMessageList()
-      self.assertEqual(['doSomething'],[x.method_id for x in message_list])
-      activity_tool.manageClearActivities()
-    finally:
-      SQLBase.MAX_MESSAGE_LIST_SIZE = MAX_MESSAGE_LIST_SIZE
-
   def test_128_CheckDistributeWithSerializationTagAndGroupMethodId(self):
     activity_tool = self.portal.portal_activities
     obj1 = activity_tool.newActiveProcess()
...
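
The new test above works by wrapping `DB.query` so that the byte size of every INSERT can be observed, then sizing the last message so the batch lands exactly on `max_allowed_packet` (and one byte over to force a split into two queries). A stripped-down version of that probing pattern, with placeholder names rather than the test's actual fixtures:

# Placeholder sketch of the query-size probe used by the test: record the
# length of each INSERT statement passing through a query callable, and
# return a function that undoes the patch.
insert_sizes = []

def install_probe(holder):
  # `holder` is assumed to be an object exposing a query(sql, ...) method;
  # the real test patches the MySQL adapter class itself.
  original_query = holder.query
  def probed_query(sql, *args, **kw):
    if sql.startswith('INSERT'):
      insert_sizes.append(len(sql))
    return original_query(sql, *args, **kw)
  holder.query = probed_query
  def restore():
    holder.query = original_query
  return restore
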
@@ -482,6 +482,10 @@ class DB(TM):
                 if m[0] not in hosed_connection:
                     raise

+    def getMaxAllowedPacket(self):
+        # minus 2-bytes overhead from mysql library
+        return self._query("SELECT @@max_allowed_packet-2").fetch_row()[0][0]
+
     @contextmanager
     def lock(self):
         """Lock for the connected DB"""
...
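
For context on the `getMaxAllowedPacket` helper added above, here is a hedged sketch of reading the same value with the MySQLdb API directly (connection parameters are placeholders; the `- 2` mirrors the in-code comment about the client library's per-packet overhead):

# Standalone sketch, not ZMySQLDA code: fetch the server's max_allowed_packet
# with plain MySQLdb.  Connection arguments are placeholders.
import MySQLdb

def get_max_allowed_packet(**connect_kw):
  conn = MySQLdb.connect(**connect_kw)
  try:
    cursor = conn.cursor()
    cursor.execute("SELECT @@max_allowed_packet - 2")
    return int(cursor.fetchone()[0])
  finally:
    conn.close()
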