Commit c74309cf authored by Julien Muchembled's avatar Julien Muchembled

ZMySQLDA: attach connections to ZODB connections instead of threads

This makes code simpler, faster and easier to understand.
It is easy to forget that ZODB connections can be reused by different threads,
which led to bug such as the one fixed by commit 2c11b76a.

ZODB already maintains a pool of connections to reuse
so we don't need anymore to have one.
parent 9cb5bdd3
...@@ -26,7 +26,7 @@ ...@@ -26,7 +26,7 @@
# #
############################################################################## ##############################################################################
from Products.ZMySQLDA.DA import Connection, ThreadedDB from Products.ZMySQLDA.DA import Connection, DB
from Products.ERP5Type.Globals import InitializeClass from Products.ERP5Type.Globals import InitializeClass
from App.special_dtml import HTMLFile from App.special_dtml import HTMLFile
from Acquisition import aq_parent from Acquisition import aq_parent
...@@ -59,11 +59,11 @@ class ActivityConnection(Connection): ...@@ -59,11 +59,11 @@ class ActivityConnection(Connection):
permission_type = 'Add Z MySQL Database Connections' permission_type = 'Add Z MySQL Database Connections'
def factory(self): def factory(self):
return ActivityThreadedDB return ActivityDB
InitializeClass(ActivityConnection) InitializeClass(ActivityConnection)
class ActivityThreadedDB(ThreadedDB): class ActivityDB(DB):
_sort_key = (0,) _sort_key = (0,)
...@@ -2226,11 +2226,9 @@ class TestCMFActivity(ERP5TypeTestCase, LogInterceptor): ...@@ -2226,11 +2226,9 @@ class TestCMFActivity(ERP5TypeTestCase, LogInterceptor):
self.flushAllActivities(silent=1, loop_size=100) self.flushAllActivities(silent=1, loop_size=100)
self.commit() self.commit()
# Check that cmf_activity SQL connection still works # Check that cmf_activity SQL connection still works
connection_da_pool = self.getPortalObject().cmf_activity_sql_connection() connection_da = self.getPortalObject().cmf_activity_sql_connection()
import thread
connection_da = connection_da_pool._db_pool[thread.get_ident()]
self.assertFalse(connection_da._registered) self.assertFalse(connection_da._registered)
connection_da_pool.query('select 1') connection_da.query('select 1')
self.assertTrue(connection_da._registered) self.assertTrue(connection_da._registered)
self.commit() self.commit()
self.assertFalse(connection_da._registered) self.assertFalse(connection_da._registered)
...@@ -3513,7 +3511,7 @@ class TestCMFActivity(ERP5TypeTestCase, LogInterceptor): ...@@ -3513,7 +3511,7 @@ class TestCMFActivity(ERP5TypeTestCase, LogInterceptor):
app = ZopeTestCase.app() app = ZopeTestCase.app()
try: try:
c = app[self.getPortalName()].cmf_activity_sql_connection() c = app[self.getPortalName()].cmf_activity_sql_connection()
return app._p_jar, c._access_db('sortKey', (), {}) return app._p_jar, c.sortKey()
finally: finally:
ZopeTestCase.close(app) ZopeTestCase.close(app)
jar, sort_key = sortKey() jar, sort_key = sortKey()
......
...@@ -61,13 +61,10 @@ class TestInvalidationBug(ERP5TypeTestCase): ...@@ -61,13 +61,10 @@ class TestInvalidationBug(ERP5TypeTestCase):
test_list = [] test_list = []
for connection_id, table in (('erp5_sql_connection', 'catalog'), for connection_id, table in (('erp5_sql_connection', 'catalog'),
('cmf_activity_sql_connection', 'message')): ('cmf_activity_sql_connection', 'message')):
conn_class = self.portal[connection_id].__class__ connection = self.portal[connection_id]
conn_string = self.portal[connection_id].connection_string query = connection.factory()('-' + connection.connection_string).query
connection = conn_class('_' + connection_id, '', sql = "rollback\0select * from %s where path='%s'" % (table, path)
'-' + conn_string).__of__(self.portal) test_list.append(lambda query=query, sql=sql: len(query(sql)[1]))
query = "rollback\0select * from %s where path='%s'" % (table, path)
test_list.append(lambda manage_test=connection.manage_test, query=query:
len(manage_test(query)))
result_list = [map(apply, test_list)] result_list = [map(apply, test_list)]
Transaction_commitResources = transaction.Transaction._commitResources Transaction_commitResources = transaction.Transaction._commitResources
connection = module._p_jar connection = module._p_jar
......
...@@ -89,7 +89,10 @@ $Id: DA.py,v 1.4 2001/08/09 20:16:36 adustman Exp $''' % database_type ...@@ -89,7 +89,10 @@ $Id: DA.py,v 1.4 2001/08/09 20:16:36 adustman Exp $''' % database_type
__version__='$Revision: 1.4 $'[11:-2] __version__='$Revision: 1.4 $'[11:-2]
import os import os
from db import ThreadedDB from collections import defaultdict
from weakref import WeakKeyDictionary
from db import DB
import transaction
import Shared.DC.ZRDB import Shared.DC.ZRDB
import DABase import DABase
from App.Dialogs import MessageDialog from App.Dialogs import MessageDialog
...@@ -97,8 +100,6 @@ from App.special_dtml import HTMLFile ...@@ -97,8 +100,6 @@ from App.special_dtml import HTMLFile
from App.ImageFile import ImageFile from App.ImageFile import ImageFile
from ExtensionClass import Base from ExtensionClass import Base
from DateTime import DateTime from DateTime import DateTime
from thread import allocate_lock
from Acquisition import aq_parent
SHARED_DC_ZRDB_LOCATION = os.path.dirname(Shared.DC.ZRDB.__file__) SHARED_DC_ZRDB_LOCATION = os.path.dirname(Shared.DC.ZRDB.__file__)
...@@ -108,12 +109,15 @@ def manage_addZMySQLConnection(self, id, title, ...@@ -108,12 +109,15 @@ def manage_addZMySQLConnection(self, id, title,
connection_string, connection_string,
check=None, REQUEST=None): check=None, REQUEST=None):
"""Add a DB connection to a folder""" """Add a DB connection to a folder"""
self._setObject(id, Connection(id, title, connection_string, check)) connection = Connection(id, title, connection_string)
if REQUEST is not None: return self.manage_main(self,REQUEST) self._setObject(id, connection)
if check:
connection.connect(connection_string)
if REQUEST is not None:
return self.manage_main(self, REQUEST)
# Connection Pool for connections to MySQL. # Connection Pool for connections to MySQL.
database_connection_pool_lock = allocate_lock() database_connection_pool = defaultdict(WeakKeyDictionary)
database_connection_pool = {}
class Connection(DABase.Connection): class Connection(DABase.Connection):
" " " "
...@@ -124,34 +128,27 @@ class Connection(DABase.Connection): ...@@ -124,34 +128,27 @@ class Connection(DABase.Connection):
manage_properties=HTMLFile('connectionEdit', globals()) manage_properties=HTMLFile('connectionEdit', globals())
def factory(self): return ThreadedDB connect_on_load = False
def factory(self): return DB
def manage_beforeDelete(self, item, container):
database_connection_pool.get(self._p_oid, {}).pop(self._p_jar, None)
def connect(self, s): def connect(self, s):
# if acquisition wrappers are not there, do not connect in order to prevent
# having 2 distinct connections for the same connector. Without this
# two following lines, there is in the pool for the same connector two connections,
# one for (connection_id,) and another one for (some, path, connection_id,)
if aq_parent(self) is None:
return self
try:
database_connection_pool_lock.acquire()
self._v_connected = '' self._v_connected = ''
pool_key = self.getPhysicalPath() if not self._p_oid:
connection = database_connection_pool.get(pool_key) transaction.savepoint(optimistic=True)
if connection is not None and connection._connection == s: pool = database_connection_pool[self._p_oid]
connection = pool.get(self._p_jar)
DB = self.factory()
if connection.__class__ is not DB or connection._connection != s:
connection = pool[self._p_jar] = DB(s)
self._v_database_connection = connection self._v_database_connection = connection
else:
if connection is not None:
connection.closeConnection()
ThreadedDB = self.factory()
database_connection_pool[pool_key] = ThreadedDB(s)
self._v_database_connection = database_connection_pool[pool_key]
# XXX If date is used as such, it can be wrong because an existing # XXX If date is used as such, it can be wrong because an existing
# connection may be reused. But this is suposedly only used as a # connection may be reused. But this is suposedly only used as a
# marker to know if connection was successfull. # marker to know if connection was successfull.
self._v_connected = DateTime() self._v_connected = DateTime()
finally:
database_connection_pool_lock.release()
return self return self
def sql_quote__(self, v, escapes={}): def sql_quote__(self, v, escapes={}):
......
...@@ -106,7 +106,6 @@ from ZODB.POSException import ConflictError ...@@ -106,7 +106,6 @@ from ZODB.POSException import ConflictError
import sys import sys
from string import strip, split, upper, rfind from string import strip, split, upper, rfind
from thread import get_ident, allocate_lock
hosed_connection = ( hosed_connection = (
CR.SERVER_GONE_ERROR, CR.SERVER_GONE_ERROR,
...@@ -169,12 +168,7 @@ def ord_or_None(s): ...@@ -169,12 +168,7 @@ def ord_or_None(s):
if s is not None: if s is not None:
return ord(s) return ord(s)
class ThreadedDB: class DB(TM):
"""
This class is an interface to DB.
Its characteristic is that an instance of this class interfaces multiple
instances of DB class, each one being bound to a specific thread.
"""
conv=conversions.copy() conv=conversions.copy()
conv[FIELD_TYPE.LONG] = int_or_long conv[FIELD_TYPE.LONG] = int_or_long
...@@ -194,18 +188,14 @@ class ThreadedDB: ...@@ -194,18 +188,14 @@ class ThreadedDB:
""" """
self._connection = connection self._connection = connection
self._kw_args = self._parse_connection_string(connection) self._kw_args = self._parse_connection_string(connection)
self._db_pool = {} self._forceReconnection()
self._db_lock = allocate_lock() transactional = self.db.server_capabilities & CLIENT.TRANSACTIONS
connection = MySQLdb.connect(**self._kw_args)
transactional = connection.server_capabilities & CLIENT.TRANSACTIONS
connection.close()
if self._try_transactions == '-': if self._try_transactions == '-':
transactional = 0 transactional = 0
elif not transactional and self._try_transactions == '+': elif not transactional and self._try_transactions == '+':
raise NotSupportedError, "transactions not supported by this server" raise NotSupportedError, "transactions not supported by this server"
self._use_TM = self._transactions = transactional self._transactions = transactional
if self._mysql_lock: self._use_TM = transactional or self._mysql_lock
self._use_TM = 1
def _parse_connection_string(self, connection): def _parse_connection_string(self, connection):
kwargs = {'conv': self.conv} kwargs = {'conv': self.conv}
...@@ -252,65 +242,6 @@ class ThreadedDB: ...@@ -252,65 +242,6 @@ class ThreadedDB:
kwargs['unix_socket'], items = items[0], items[1:] kwargs['unix_socket'], items = items[0], items[1:]
return kwargs return kwargs
def _pool_set(self, key, value):
self._db_lock.acquire()
try:
self._db_pool[key] = value
finally:
self._db_lock.release()
def _pool_get(self, key):
self._db_lock.acquire()
try:
return self._db_pool.get(key)
finally:
self._db_lock.release()
def _pool_del(self, key):
self._db_lock.acquire()
try:
del self._db_pool[key]
finally:
self._db_lock.release()
def closeConnection(self):
ident = get_ident()
try:
self._pool_del(ident)
except KeyError:
pass
def _access_db(self, method_id, args, kw):
"""
Generic method to call pooled objects' methods.
When the current thread had never issued any call, create a DB
instance.
"""
ident = get_ident()
db = self._pool_get(ident)
if db is None:
db = DB(kw_args=self._kw_args, use_TM=self._use_TM,
mysql_lock=self._mysql_lock,
transactions=self._transactions)
db.setSortKey(self._sort_key)
self._pool_set(ident, db)
return getattr(db, method_id)(*args, **kw)
def tables(self, *args, **kw):
return self._access_db(method_id='tables', args=args, kw=kw)
def columns(self, *args, **kw):
return self._access_db(method_id='columns', args=args, kw=kw)
def query(self, *args, **kw):
return self._access_db(method_id='query', args=args, kw=kw)
def string_literal(self, *args, **kw):
return self._access_db(method_id='string_literal', args=args, kw=kw)
class DB(TM):
defs={ defs={
FIELD_TYPE.CHAR: "i", FIELD_TYPE.DATE: "d", FIELD_TYPE.CHAR: "i", FIELD_TYPE.DATE: "d",
FIELD_TYPE.DATETIME: "d", FIELD_TYPE.DECIMAL: "n", FIELD_TYPE.DATETIME: "d", FIELD_TYPE.DECIMAL: "n",
...@@ -322,19 +253,11 @@ class DB(TM): ...@@ -322,19 +253,11 @@ class DB(TM):
_p_oid=_p_changed=_registered=None _p_oid=_p_changed=_registered=None
def __init__(self, kw_args, use_TM, mysql_lock, transactions):
self._kw_args = kw_args
self._mysql_lock = mysql_lock
self._use_TM = use_TM
self._transactions = transactions
self._forceReconnection()
def __del__(self): def __del__(self):
self.db.close() self.db.close()
def _forceReconnection(self): def _forceReconnection(self):
db = MySQLdb.connect(**self._kw_args) self.db = MySQLdb.connect(**self._kw_args)
self.db = db
def tables(self, rdb=0, def tables(self, rdb=0,
_care=('TABLE', 'VIEW')): _care=('TABLE', 'VIEW')):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment