Commit 1e4733c8 authored by Jeremy Hylton's avatar Jeremy Hylton

Improve thread test infrastructure to avoid assertion errors in

threads that lead to unreported test failures.
parent 39c022b8
##############################################################################
#
# Copyright (c) 2001, 2002 Zope Corporation and Contributors.
# All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.0 (ZPL). A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE.
#
##############################################################################
import random import random
import sys
import threading import threading
import time import time
...@@ -31,19 +19,50 @@ def sort(l): ...@@ -31,19 +19,50 @@ def sort(l):
l.sort() l.sort()
return l return l
class ZODBClientThread(threading.Thread): class TestThread(threading.Thread):
"""Base class for defining threads that run from unittest.
__super_init = threading.Thread.__init__ If the thread exits with an uncaught exception, catch it and
re-raise it when the thread is joined. The re-raise will cause
the test to fail.
The subclass should define a runtest() method instead of a run()
method.
"""
def __init__(self, test):
threading.Thread.__init__(self)
self.test = test
self._fail = None
self._exc_info = None
def run(self):
try:
self.runtest()
except:
self._exc_info = sys.exc_info()
def fail(self, msg=""):
self._test.fail(msg)
def join(self, timeout=None):
threading.Thread.join(self, timeout)
if self._exc_info:
raise self._exc_info[0], self._exc_info[1], self._exc_info[2]
class ZODBClientThread(TestThread):
__super_init = TestThread.__init__
def __init__(self, db, test, commits=10, delay=SHORT_DELAY): def __init__(self, db, test, commits=10, delay=SHORT_DELAY):
self.__super_init() self.__super_init(test)
self.setDaemon(1) self.setDaemon(1)
self.db = db self.db = db
self.test = test self.test = test
self.commits = commits self.commits = commits
self.delay = delay self.delay = delay
def run(self): def runtest(self):
conn = self.db.open() conn = self.db.open()
root = conn.root() root = conn.root()
d = self.get_thread_dict(root) d = self.get_thread_dict(root)
...@@ -69,27 +88,28 @@ class ZODBClientThread(threading.Thread): ...@@ -69,27 +88,28 @@ class ZODBClientThread(threading.Thread):
root[name] = m root[name] = m
get_transaction().commit() get_transaction().commit()
break break
except ConflictError: except ConflictError, err:
get_transaction().abort() get_transaction().abort()
root._p_jar.sync()
for i in range(10): for i in range(10):
try: try:
return root.get(name) return root.get(name)
except ConflictError: except ConflictError:
get_transaction().abort() get_transaction().abort()
class StorageClientThread(threading.Thread): class StorageClientThread(TestThread):
__super_init = threading.Thread.__init__ __super_init = TestThread.__init__
def __init__(self, storage, test, commits=10, delay=SHORT_DELAY): def __init__(self, storage, test, commits=10, delay=SHORT_DELAY):
self.__super_init() self.__super_init(test)
self.storage = storage self.storage = storage
self.test = test self.test = test
self.commits = commits self.commits = commits
self.delay = delay self.delay = delay
self.oids = {} self.oids = {}
def run(self): def runtest(self):
for i in range(self.commits): for i in range(self.commits):
self.dostore(i) self.dostore(i)
self.check() self.check()
...@@ -133,7 +153,7 @@ class StorageClientThread(threading.Thread): ...@@ -133,7 +153,7 @@ class StorageClientThread(threading.Thread):
class ExtStorageClientThread(StorageClientThread): class ExtStorageClientThread(StorageClientThread):
def run(self): def runtest(self):
# pick some other storage ops to execute # pick some other storage ops to execute
ops = [getattr(self, meth) for meth in dir(ExtStorageClientThread) ops = [getattr(self, meth) for meth in dir(ExtStorageClientThread)
if meth.startswith('do_')] if meth.startswith('do_')]
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment