##############################################################################
#
# Copyright (c) 2002 Zope Corporation and Contributors.
# All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.0 (ZPL).  A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE
#
##############################################################################
"""Compromising positions involving threads."""

import threading

import transaction
from ZODB.tests.StorageTestBase import zodb_pickle, MinPO
import ZEO.ClientStorage

# Eight NUL bytes, used below as the second argument to store() for a
# freshly allocated oid.
ZERO = '\0'*8


class BasicThread(threading.Thread):
    """Daemon thread that drives transaction calls against a shared storage.

    Subclasses implement run().  The two events are supplied by the test:
    a subclass may set threadStartedEvent to report that it has reached a
    checkpoint, and wait on doNextEvent before continuing.
    """

    def __init__(self, storage, doNextEvent, threadStartedEvent):
        self.storage = storage
        self.trans = transaction.Transaction()
        self.doNextEvent = doNextEvent
        self.threadStartedEvent = threadStartedEvent
        # Set to 1 by subclasses when they catch a ClientStorageError
        # (despite the name, ClientStorageError is what is caught).
        self.gotValueError = 0
        # Not set by any thread class in this module.
        self.gotDisconnected = 0
        threading.Thread.__init__(self)
        # Daemon, so a thread stuck inside a storage call cannot keep the
        # test process alive.
        self.setDaemon(1)

    def join(self, timeout=10):
        """Wait up to `timeout` seconds (default 10); fail if still running.

        Raises AssertionError explicitly instead of using an ``assert``
        statement so the check is not stripped when Python runs with -O.
        """
        threading.Thread.join(self, timeout)
        if self.isAlive():
            raise AssertionError("%s still running after join(%r)"
                                 % (self.getName(), timeout))


class GetsThroughVoteThread(BasicThread):
    # This thread gets partially through a transaction before it turns
    # execution over to another thread.  We're trying to establish that a
    # tpc_finish() after a storage has been closed by another thread will get
    # a ClientStorageError error.
    #
    # This class does a tpc_begin(), store() and tpc_vote(), and is waiting
    # to do the tpc_finish() when the other thread closes the storage.

    def run(self):
        self.storage.tpc_begin(self.trans)
        oid = self.storage.new_oid()
        self.storage.store(oid, ZERO, zodb_pickle(MinPO("c")), '', self.trans)
        self.storage.tpc_vote(self.trans)
        # Tell the test we are past tpc_vote(), then wait (up to 10s) for
        # it to close the storage before attempting tpc_finish().
        self.threadStartedEvent.set()
        self.doNextEvent.wait(10)
        try:
            self.storage.tpc_finish(self.trans)
        except ZEO.ClientStorage.ClientStorageError:
            self.gotValueError = 1
            self.storage.tpc_abort(self.trans)


class GetsThroughBeginThread(BasicThread):
    # This class is like the above except that it is intended to be run when
    # another thread is already in a tpc_begin().  Thus, this thread will
    # block in the tpc_begin until another thread closes the storage.  When
    # that happens, this one will get disconnected too.

    def run(self):
        try:
            self.storage.tpc_begin(self.trans)
        except ZEO.ClientStorage.ClientStorageError:
            self.gotValueError = 1


class ThreadTests:
    # Mixin of thread-related storage checks.  The host test class must
    # provide self._storage, self._dostore(), self.assertEqual() and
    # self.failUnless(); none of them are defined here.

    # Thread 1 should start a transaction, but not get all the way through it.
    # Main thread should close the connection.  Thread 1 should then get
    # disconnected.
    def checkDisconnectedOnThread2Close(self):
        doNextEvent = threading.Event()
        threadStartedEvent = threading.Event()
        thread1 = GetsThroughVoteThread(self._storage,
                                        doNextEvent, threadStartedEvent)
        thread1.start()
        threadStartedEvent.wait(10)
        self._storage.close()
        doNextEvent.set()
        thread1.join()
        self.assertEqual(thread1.gotValueError, 1)

    # Thread 1 should start a transaction, but not get all the way through
    # it.  While thread 1 is in the middle of the transaction, a second thread
    # should start a transaction, and it will block in the tpc_begin() --
    # because thread 1 has acquired the lock in its tpc_begin().  Now the main
    # thread closes the storage and both sub-threads should get disconnected.
    def checkSecondBeginFails(self):
        doNextEvent = threading.Event()
        threadStartedEvent = threading.Event()
        thread1 = GetsThroughVoteThread(self._storage,
                                        doNextEvent, threadStartedEvent)
        thread2 = GetsThroughBeginThread(self._storage,
                                         doNextEvent, threadStartedEvent)
        thread1.start()
        # Same 10-second allowance as checkDisconnectedOnThread2Close.
        # wait() returns as soon as thread 1 sets the event, so this costs
        # nothing when thread 1 is fast; a shorter limit could let thread 2
        # start before thread 1 holds the lock on a slow machine.
        threadStartedEvent.wait(10)
        thread2.start()
        # NOTE(review): close() is not synchronized with thread 2, so thread 2
        # may not yet be inside tpc_begin() when the storage is closed; the
        # assertions below expect a ClientStorageError in either case.
        self._storage.close()
        doNextEvent.set()
        thread1.join()
        thread2.join()
        self.assertEqual(thread1.gotValueError, 1)
        self.assertEqual(thread2.gotValueError, 1)

    # Run a bunch of threads doing small and large stores in parallel.
    def checkMTStores(self):
        threads = []
        for _ in range(5):
            t = threading.Thread(target=self.mtstorehelper)
            threads.append(t)
            t.start()
        for t in threads:
            t.join(30)
        # Check every worker thread, not just the last one started.
        for t in threads:
            self.failUnless(not t.isAlive(),
                            "%s did not finish" % t.getName())

    # Helper for checkMTStores: stores ten large (200000-character) and ten
    # one-character MinPO objects, interleaved, via self._dostore().
    def mtstorehelper(self):
        objs = []
        for _ in range(10):
            objs.append(MinPO("X" * 200000))
            objs.append(MinPO("X"))
        for obj in objs:
            self._dostore(data=obj)