[Zope3-checkins] CVS: Zope3/src/zodb/storage/tests - mt.py:1.5

Jeremy Hylton jeremy@zope.com
Thu, 13 Mar 2003 15:48:51 -0500


Update of /cvs-repository/Zope3/src/zodb/storage/tests
In directory cvs.zope.org:/tmp/cvs-serv20012

Modified Files:
	mt.py 
Log Message:
Add mechanism so that failures in a thread cause the test to fail.


=== Zope3/src/zodb/storage/tests/mt.py 1.4 => 1.5 ===
--- Zope3/src/zodb/storage/tests/mt.py:1.4	Wed Feb  5 18:28:27 2003
+++ Zope3/src/zodb/storage/tests/mt.py	Thu Mar 13 15:48:50 2003
@@ -12,9 +12,10 @@
 #
 ##############################################################################
 
-import time
 import random
+import sys
 import threading
+import time
 
 from persistence.dict import PersistentDict
 from transaction import get_transaction
@@ -33,20 +34,49 @@
     l.sort()
     return l
 
+class TestThread(threading.Thread):
+    """Base class for defining threads that run from unittest.
+
+    If the thread exits with an uncaught exception, catch it and
+    re-raise it when the thread is joined.  The re-raise will cause
+    the test to fail.
+
+    The subclass should define a runtest() method instead of a run()
+    method.
+    """
+
+    def __init__(self, test):
+        threading.Thread.__init__(self)
+        self.test = test
+        # sys.exc_info() triple saved by run() if runtest() raises.
+        self._exc_info = None
+
+    def run(self):
+        try:
+            self.runtest()
+        except:
+            self._exc_info = sys.exc_info()
+
+    def fail(self, msg=""):
+        self.test.fail(msg)
+
+    def join(self, timeout=None):
+        threading.Thread.join(self, timeout)
+        if self._exc_info:
+            raise self._exc_info[0], self._exc_info[1], self._exc_info[2]
 
-class ZODBClientThread(threading.Thread):
+class ZODBClientThread(TestThread):
 
-    __super_init = threading.Thread.__init__
+    __super_init = TestThread.__init__
 
     def __init__(self, db, test, commits=10, delay=SHORT_DELAY):
-        self.__super_init()
+        self.__super_init(test)
         self.setDaemon(1)
         self.db = db
-        self.test = test
         self.commits = commits
         self.delay = delay
 
-    def run(self):
+    def runtest(self):
         conn = self.db.open()
         root = conn.root()
         d = self.get_thread_dict(root)
@@ -81,19 +111,18 @@
                 get_transaction().abort()
 
 
-class StorageClientThread(threading.Thread):
+class StorageClientThread(TestThread):
 
-    __super_init = threading.Thread.__init__
+    __super_init = TestThread.__init__
 
     def __init__(self, storage, test, commits=10, delay=SHORT_DELAY):
-        self.__super_init()
+        self.__super_init(test)
         self.storage = storage
-        self.test = test
         self.commits = commits
         self.delay = delay
         self.oids = {}
 
-    def run(self):
+    def runtest(self):
         for i in range(self.commits):
             self.dostore(i)
         self.check()
@@ -138,7 +167,7 @@
 
 class ExtStorageClientThread(StorageClientThread):
 
-    def run(self):
+    def runtest(self):
         # pick some other storage ops to execute
         ops = [getattr(self, meth) for meth in dir(ExtStorageClientThread)
                if meth.startswith('do_')]