diff --git a/product/ZMySQLDA/DA.py b/product/ZMySQLDA/DA.py index b183642899295673becbf007ac818b0dfd9fd8cb..3f2117ae3c0ad716d20f650897d9f0b4db7d8ac3 100644 --- a/product/ZMySQLDA/DA.py +++ b/product/ZMySQLDA/DA.py @@ -132,7 +132,7 @@ class Connection(DABase.Connection): self._v_database_connection = connection else: if connection is not None: - connection.close() + connection.closeConnection() DB = self.factory() database_connection_pool[pool_key] = DB(s) self._v_database_connection = database_connection_pool[pool_key] diff --git a/product/ZMySQLDA/db.py b/product/ZMySQLDA/db.py index df10470d92bb86189262da62a0c3453c7d1aba15..67510dd6318b7c1d4c7ae184afb8615f87a828e1 100644 --- a/product/ZMySQLDA/db.py +++ b/product/ZMySQLDA/db.py @@ -106,6 +106,7 @@ from zLOG import LOG, ERROR, INFO import string, sys from string import strip, split, find, upper, rfind from time import time +from thread import get_ident hosed_connection = ( CR.SERVER_GONE_ERROR, @@ -155,7 +156,7 @@ def int_or_long(s): try: return int(s) except: return long(s) -FINISH_OR_ABORT_CALLED_ID = '_v_finish_or_abort_called' +FINISH_OR_ABORT_CALLED_ID = '_finish_or_abort_called' class DB(TM): @@ -185,9 +186,10 @@ class DB(TM): def __init__(self,connection): self.connection=connection - self.kwargs = kwargs = self._parse_connection_string(connection) - self.db=apply(self.Database_Connection, (), kwargs) - transactional = self.db.server_capabilities & CLIENT.TRANSACTIONS + self.kwargs = self._parse_connection_string(connection) + self.db = {} + db = self.getConnection() + transactional = db.server_capabilities & CLIENT.TRANSACTIONS if self._try_transactions == '-': transactional = 0 elif not transactional and self._try_transactions == '+': @@ -196,10 +198,24 @@ class DB(TM): if self._mysql_lock: self._use_TM = 1 - def close(self): - if self.db is not None: - self.db.close() - self.db = None + def forceReconnection(self): + db = apply(self.Database_Connection, (), self.kwargs) + self.db[get_ident()] = db + return db + + def getConnection(self): + ident = get_ident() + db = self.db.get(ident) + if db is None: + db = self.forceReconnection() + return db + + def closeConnection(self): + ident = get_ident() + db = self.db.get(ident) + if db is not None: + db.close() + del self.db[ident] def _parse_connection_string(self, connection): kwargs = {'conv': self.conv} @@ -242,11 +258,11 @@ class DB(TM): _care=('TABLE', 'VIEW')): r=[] a=r.append - if 1: - self.db.query("SHOW TABLES") - result = self.db.store_result() + db = self.getConnection() + db.query("SHOW TABLES") + result = db.store_result() row = result.fetch_row(1) - while row: + while row: a({'TABLE_NAME': row[0][0], 'TABLE_TYPE': 'TABLE'}) row = result.fetch_row(1) return r @@ -254,10 +270,9 @@ class DB(TM): def columns(self, table_name): from string import join try: - # Field, Type, Null, Key, Default, Extra - if 1: - self.db.query('SHOW COLUMNS FROM %s' % table_name) - c=self.db.store_result() + db = self.getConnection() + db.query('SHOW COLUMNS FROM %s' % table_name) + c = db.store_result() except: return () r=[] @@ -301,7 +316,7 @@ class DB(TM): self._use_TM and self._register() desc=None result=() - db=self.db + db = self.getConnection() try: if 1: for qs in filter(None, map(strip,split(query_string, '\0'))): @@ -321,11 +336,10 @@ class DB(TM): result=c.fetch_row(max_rows) else: desc=None - except OperationalError, m: if m[0] not in hosed_connection: raise # Hm. maybe the db is hosed. Let's restart it. - db=self.db=apply(self.Database_Connection, (), self.kwargs) + self.forceReconnection() return self.query(query_string, max_rows) if desc is None: return (),() @@ -342,50 +356,54 @@ class DB(TM): func(item) return items, result - def string_literal(self, s): return self.db.string_literal(s) + def string_literal(self, s): + return self.getConnection().string_literal(s) def _begin(self, *ignored): try: + db = self.getConnection() if self._transactions: - self.db.query("BEGIN") - self.db.store_result() + db.query("BEGIN") + db.store_result() if self._mysql_lock: - self.db.query("SELECT GET_LOCK('%s',0)" % self._mysql_lock) - self.db.store_result() + db.query("SELECT GET_LOCK('%s',0)" % self._mysql_lock) + db.store_result() except: LOG('ZMySQLDA', ERROR, "exception during _begin", error=sys.exc_info()) raise + setattr(self, FINISH_OR_ABORT_CALLED_ID, False) def _finish(self, *ignored): if getattr(self, FINISH_OR_ABORT_CALLED_ID, False): return try: try: + db = self.getConnection() if self._mysql_lock: - self.db.query("SELECT RELEASE_LOCK('%s')" % self._mysql_lock) - self.db.store_result() + db.query("SELECT RELEASE_LOCK('%s')" % self._mysql_lock) + db.store_result() if self._transactions: - self.db.query("COMMIT") - self.db.store_result() + db.query("COMMIT") + db.store_result() except: LOG('ZMySQLDA', ERROR, "exception during _finish", error=sys.exc_info()) raise finally: - self._v_database_connection = None setattr(self, FINISH_OR_ABORT_CALLED_ID, True) def _abort(self, *ignored): if getattr(self, FINISH_OR_ABORT_CALLED_ID, False): return try: + db = self.getConnection() if self._mysql_lock: - self.db.query("SELECT RELEASE_LOCK('%s')" % self._mysql_lock) - self.db.store_result() + db.query("SELECT RELEASE_LOCK('%s')" % self._mysql_lock) + db.store_result() if self._transactions: - self.db.query("ROLLBACK") - self.db.store_result() + db.query("ROLLBACK") + db.store_result() else: LOG('ZMySQLDA', ERROR, "aborting when non-transactional") finally: