[Zodb-checkins] CVS: Packages/ZEO - zrpc2.py:1.1.2.2

jeremy@digicool.com jeremy@digicool.com
Thu, 29 Mar 2001 08:34:33 -0500 (EST)


Update of /cvs-repository/Packages/ZEO
In directory korak:/tmp/cvs-serv10202

Modified Files:
      Tag: ZEO-ZRPC-Dev
	zrpc2.py 
Log Message:
Add proper mainloop maintenance to zrpc2

Add a _do_io() method to Connection that decides when to trigger a
ThreadedAsync loop and when to call asyncore.poll() directly.

Hack alert: Add a ServerConnection class whose _do_io() method is a
no-op.  On the server side, the connection is driven by a single
asyncore.poll() and doesn't need to explicitly restart loops.

Fix name errors in exception handlers; err namespace no longer exists.

Add a DebugLock class that logs each acquire and release call together
with its caller.  XXX Depends on sys._getframe()

Use a thread lock (thread.allocate_lock()) instead of threading.Lock();
a re-entrant lock is not what is needed.

Fiddle with logging calls.





--- Updated File zrpc2.py in package Packages/ZEO --
--- zrpc2.py	2001/03/17 00:14:51	1.1.2.1
+++ zrpc2.py	2001/03/29 13:34:32	1.1.2.2
@@ -1,4 +1,4 @@
-"""RPC protocol for ZEO
+"""RPC protocol for ZEO based on asyncore
 
 The basic protocol is as:
 a pickled tuple containing: msgid, flags, method, args
@@ -18,6 +18,7 @@
 import socket
 import sys
 import threading
+import thread
 import time
 import traceback
 import types
@@ -26,6 +27,7 @@
 
 from ZODB import POSException
 import smac
+import trigger
 import zLOG
 
 REPLY = ".reply" # message name used for replies
@@ -82,7 +84,7 @@
         try:
             msgid, flags, name, args = unpickler.load()
         except (cPickle.UnpicklingError, IndexError), msg:
-            raise err.DecodingError(msg)
+            raise DecodingError(msg)
         return msgid, flags, name, args
 
 class Delay:
@@ -99,6 +101,31 @@
 
     def reply(self, obj):
         self.send_reply(self.msgid, obj)
+
+class DebugLock:
+    def __init__(self):
+        self.lock = thread.allocate_lock()
+
+    def _debug(self):
+        method = sys._getframe().f_back
+        caller = method.f_back
+        filename = os.path.split(caller.f_code.co_filename)[1]
+        log("LOCK %s: %s called by %s, %s, line %s" % (id(self.lock),
+                                                         method.f_code.co_name,
+                                                         caller.f_code.co_name,
+                                                         filename,
+                                                         caller.f_lineno))
+
+    def acquire(self, wait=None):
+        self._debug()
+        if wait is not None:
+            return self.lock.acquire(wait)
+        else:
+            return self.lock.acquire()
+        
+    def release(self):
+        self._debug()
+        return self.lock.release()
     
 class Connection(smac.SizedMessageAsyncConnection):
     """Dispatcher for RPC on object
@@ -122,9 +149,10 @@
         self.obj = obj
         self.marshal = Marshaller(pickle)
         self.closed = 0
+        self.async = 0
         # The reply lock is used to block when a synchronous call is
         # waiting for a response
-        self.__reply_lock = threading.Lock()
+        self.__reply_lock = thread.allocate_lock()
         self.__reply_lock.acquire()
         self.__super_init(sock, addr)
         if isinstance(obj, Handler):
@@ -146,24 +174,26 @@
         """Decoding an incoming message and dispatch it"""
         try:
             msgid, flags, name, args = self.marshal.decode(message)
-        except err.DecodingError, msg:
+        except DecodingError, msg:
             return self.return_error(None, None, sys.exc_info()[0],
                                      sys.exc_info()[1])  
 
+        log("message: %s, %s, %s, %s" % (msgid, flags, name, repr(args)[:40]),
+            level=zLOG.TRACE)
         if name == REPLY:
             self.handle_reply(msgid, flags, args)
         else:
             self.handle_request(msgid, flags, name, args)
 
     def handle_reply(self, msgid, flags, args):
+        log("reply: %s, %s, %s" % (msgid, flags, str(args)[:40]))
         self.__reply = msgid, flags, args
-#        self.__lock.release()
+        self.__reply_lock.release() # will fail if lock is unlocked
 
     def handle_request(self, msgid, flags, name, args):
-        if __debug__:
-            log("%s%s" % (name, args), zLOG.TRACE)
+        log("%s%s" % (name, repr(args)[:40]), zLOG.BLATHER)
         if not self.check_method(name):
-            raise err.ZRPCError("Invalid method name: %s" % name)
+            raise ZRPCError("Invalid method name: %s" % name)
 
         meth = getattr(self.obj, name)
         try:
@@ -185,8 +215,9 @@
         
         if flags & ASYNC:
             if ret is not None:
-                raise err.ZRPCError("async method returned value")
+                raise ZRPCError("async method returned value")
         else:
+            log("%s reply %s" % (name, repr(ret)[:40]), zLOG.BLATHER)
             if isinstance(ret, Delay):
                 ret.set_sender(msgid, self.send_reply)
             else:
@@ -194,7 +225,6 @@
 
     def handle_error(self):
         t, v, tb = sys.exc_info()
-        print "%s: %s" % (str(t), v)
         traceback.print_tb(tb)
 
     def check_method(self, name):
@@ -205,6 +235,10 @@
         self.message_output(msg)
     
     def return_error(self, msgid, flags, err_type, err_value):
+        if flags is None:
+            print "Exception raised during decoding"
+            self.handle_error()
+            return
         if flags & ASYNC:
             print "Asynchronous call raised exception:"
             self.handle_error()
@@ -220,6 +254,7 @@
             err = ZRPCError("Couldn't pickle error %s" % `err_value`)
             msg = self.marshal.encode(msgid, 0, REPLY, (ZRPCError, err))
         self.message_output(msg)
+        self._do_io()
         print "Sent error message for:"
         self.handle_error()
 
@@ -244,10 +279,6 @@
         
     # The previous five methods implement an asyncore socket map
 
-    def _mainloop(self):
-        """Invoke the asyncore mainloop"""
-        
-
     # The next two methods are used by clients to invoke methods on
     # remote objects  
 
@@ -258,17 +289,17 @@
         self.message_output(self.marshal.encode(msgid, 0, method, args))
 
         self.__reply = None
-        while self.__reply is None:
-            # this is where you want to call the main loop
-            asyncore.poll(60.0)
+        self._do_io(wait=1)
         r_msgid, r_flags, r_args = self.__reply
+        self.__reply_lock.acquire()
         assert r_msgid == msgid, "%s != %s: %s" % (r_msgid, msgid, r_args)
+
         if type(r_args) == types.TupleType \
            and type(r_args[0]) == types.ClassType \
            and issubclass(r_args[0], Exception):
-            log("error")
-            print repr(r_args[1])
+            log("call %s %s raised error" % (msgid, method))
             raise r_args[1]
+        log("call %s %s returned" % (msgid, method))
         return r_args
 
     def callAsync(self, method, *args):
@@ -276,6 +307,33 @@
         self.msgid += 1
         log("async %s %s" % (msgid, method))
         self.message_output(self.marshal.encode(msgid, ASYNC, method, args))
+        self._do_io()
+
+    # handle IO, possibly in async mode
+
+    def _do_io(self, wait=0): # XXX need better name
+        # XXX invariant? lock must be held when calling with wait==1
+        # otherwise, in non-async mode, there will be no poll
+        
+        log("_do_io(wait=%d), async=%d" % (wait, self.async),
+            level=zLOG.BLATHER)
+        if self.async:
+            self.trigger.pull_trigger()
+            if wait:
+                self.__reply_lock.acquire()
+        else:
+            if wait:
+                # do loop only if lock is already acquired
+                while not self.__reply_lock.acquire(0):
+                    asyncore.poll(60.0, self)
+                self.__reply_lock.release()
+            else:
+                asyncore.poll(0.0, self)
+
+class ServerConnection(Connection):
+    def _do_io(self, wait=0):
+        """If this is a server, there is no explicit IO to do"""
+        pass
 
 class ConnectionManager:
     """Keeps a connection up over time"""
@@ -293,6 +351,10 @@
     def register_object(self, obj):
         self.obj = obj
 
+    def set_async(self):
+        self.async = 1
+        self.trigger = trigger.trigger()
+
     def connect(self, sync=0, callback=None):
         if self.connected == 1:
             return
@@ -330,6 +392,7 @@
                     log("Connected to server", level=zLOG.DEBUG)
                 self.connected = 1
         if self.connected:
+            # XXX how do we get here with s being defined?
             c = ManagedConnection(s, self.addr, self.obj, self)
             log("Connection created: %s" % c)
             log("callback = %s" % self._callback)
@@ -345,8 +408,22 @@
         return t
 
     def closed(self, conn):
+        self.connected = 0
         self.connect()
 
+class ManagedServerConnection(ServerConnection):
+    """A connection that notifies its ConnectionManager of closing"""
+    __super_init = Connection.__init__
+    __super_close = Connection.close
+
+    def __init__(self, sock, addr, obj, mgr, pickle=None):
+        self.__mgr = mgr
+        self.__super_init(sock, addr, obj, pickle)
+
+    def close(self):
+        self.__super_close()
+        self.__mgr.closed(self)
+
 class ManagedConnection(Connection):
     """A connection that notifies its ConnectionManager of closing"""
     __super_init = Connection.__init__
@@ -418,12 +495,12 @@
     try:
         m = __import__(module, _globals, _globals, _silly)
     except ImportError, msg:
-        raise err.ZRPCError("import error %s: %s" % (module, msg))
+        raise ZRPCError("import error %s: %s" % (module, msg))
 
     try:
         r = getattr(m, name)
     except AttributeError:
-        raise err.ZRPCError("module %s has no global %s" % (module, name))
+        raise ZRPCError("module %s has no global %s" % (module, name))
         
     safe = getattr(r, '__no_side_effects__', 0)
     if safe:
@@ -432,5 +509,5 @@
     if type(r) == types.ClassType and issubclass(r, Exception):
         return r
 
-    raise err.ZRPCError("Unsafe global: %s.%s" % (module, name))
+    raise ZRPCError("Unsafe global: %s.%s" % (module, name))