[Zope3-checkins] CVS: ZODB4/src/zodb/zeo/zrpc - smac.py:1.5 connection.py:1.6

Jeremy Hylton jeremy@zope.com
Thu, 19 Jun 2003 17:41:37 -0400


Update of /cvs-repository/ZODB4/src/zodb/zeo/zrpc
In directory cvs.zope.org:/tmp/cvs-serv15960/src/zodb/zeo/zrpc

Modified Files:
	smac.py connection.py 
Log Message:
Merge ZODB3-2-merge branch to the head.

This completes the porting of bug fixes and miscellaneous improvements
from ZODB 3.2 to ZODB 4.


=== ZODB4/src/zodb/zeo/zrpc/smac.py 1.4 => 1.5 ===
--- ZODB4/src/zodb/zeo/zrpc/smac.py:1.4	Thu Mar 13 17:11:36 2003
+++ ZODB4/src/zodb/zeo/zrpc/smac.py	Thu Jun 19 17:41:07 2003
@@ -11,11 +11,27 @@
 # FOR A PARTICULAR PURPOSE
 #
 ##############################################################################
-"""Sized Message Async Connections."""
+"""Sized Message Async Connections.
 
-import asyncore, struct
+This class extends the basic asyncore layer with a record-marking
+layer.  The message_output() method accepts an arbitrary sized string
+as its argument.  It sends over the wire the length of the string
+encoded using struct.pack('>i') and the string itself.  The receiver
+passes the original string to message_input().
+
+This layer also supports an optional message authentication code
+(MAC).  If a session key is present, it uses HMAC-SHA-1 to generate a
+20-byte MAC.  If a MAC is present, the high-order bit of the length
+is set to 1 and the MAC immediately follows the length.
+"""
+
+import asyncore
+import errno
+import hmac
+import sha
+import socket
+import struct
 import threading
-import socket, errno
 
 from zodb.zeo.zrpc.interfaces import DisconnectedError
 from zodb.zeo.zrpc import log
@@ -42,6 +58,8 @@
 # that we could pass to send() without blocking.
 SEND_SIZE = 60000
 
+MAC_BIT = 0x80000000  # set in the length word when a 20-byte MAC follows; NOTE(review): lengths use signed ">i", ">I" likely intended — verify
+
 class SizedMessageAsyncConnection(asyncore.dispatcher):
     __super_init = asyncore.dispatcher.__init__
     __super_close = asyncore.dispatcher.close
@@ -72,8 +90,12 @@
         self.__output_lock = threading.Lock() # Protects __output
         self.__output = []
         self.__closed = 0
+        self.__hmac = None
         self.__super_init(sock, map)
 
+    def setSessionKey(self, sesskey):  # once set, outgoing messages carry an HMAC-SHA-1 MAC and incoming MACs are checked
+        self.__hmac = hmac.HMAC(sesskey, digestmod=sha)  # NOTE(review): this one HMAC is updated by both sent and received messages; confirm peers agree on ordering
+
     def get_addr(self):
         return self.addr
 
@@ -121,12 +143,17 @@
                 inp = "".join(inp)
 
             offset = 0
+            expect_mac = 0
             while (offset + msg_size) <= input_len:
                 msg = inp[offset:offset + msg_size]
                 offset = offset + msg_size
                 if not state:
                     # waiting for message
                     msg_size = struct.unpack(">i", msg)[0]
+                    expect_mac = msg_size & MAC_BIT
+                    if expect_mac:
+                        msg_size ^= MAC_BIT
+                        msg_size += 20
                     state = 1
                 else:
                     msg_size = 4
@@ -141,6 +168,17 @@
                     # incoming call to be handled.  During all this
                     # time, the __input_lock is held.  That's a good
                     # thing, because it serializes incoming calls.
+                    if expect_mac:
+                        mac = msg[:20]
+                        msg = msg[20:]
+                        if self.__hmac:
+                            self.__hmac.update(msg)
+                            _mac = self.__hmac.digest()
+                            if mac != _mac:
+                                raise ValueError("MAC failed: %r != %r"
+                                                 % (_mac, mac))
+                        else:
+                            log.warn("Received MAC but no session key set")
                     self.message_input(msg)
 
             self.__state = state
@@ -208,8 +246,13 @@
             raise DisconnectedError("Action is temporarily unavailable")
         self.__output_lock.acquire()
         try:
-            # do two separate appends to avoid copying the message string
-            self.__output.append(struct.pack(">i", len(message)))
+            # do separate appends to avoid copying the message string
+            if self.__hmac:
+                self.__output.append(struct.pack(">i", len(message) | MAC_BIT))
+                self.__hmac.update(message)
+                self.__output.append(self.__hmac.digest())
+            else:
+                self.__output.append(struct.pack(">i", len(message)))
             if len(message) <= SEND_SIZE:
                 self.__output.append(message)
             else:


=== ZODB4/src/zodb/zeo/zrpc/connection.py 1.5 => 1.6 ===
--- ZODB4/src/zodb/zeo/zrpc/connection.py:1.5	Fri Mar 14 10:51:05 2003
+++ ZODB4/src/zodb/zeo/zrpc/connection.py	Thu Jun 19 17:41:07 2003
@@ -113,6 +113,7 @@
 
     __super_init = smac.SizedMessageAsyncConnection.__init__
     __super_close = smac.SizedMessageAsyncConnection.close
+    __super_setSessionKey = smac.SizedMessageAsyncConnection.setSessionKey
 
     oldest_protocol_version = "Z400"
     protocol_version = "Z400"
@@ -146,6 +147,11 @@
         # waiting for a response
         self.replies_cond = threading.Condition()
         self.replies = {}
+        # waiting_for_reply is used internally to indicate whether
+        # a call is in progress.  setting a session key is deferred
+        # until after the call returns.
+        self.waiting_for_reply = False
+        self.delay_sesskey = None
         self.register_object(obj)
         self.handshake()
 
@@ -235,7 +241,11 @@
 
         meth = getattr(self.obj, name)
         try:
-            ret = meth(*args)
+            self.waiting_for_reply = True
+            try:
+                ret = meth(*args)
+            finally:
+                self.waiting_for_reply = False
         except (SystemExit, KeyboardInterrupt):
             raise
         except Exception, msg:
@@ -258,6 +268,10 @@
             else:
                 self.send_reply(msgid, ret)
 
+        if self.delay_sesskey:
+            self.__super_setSessionKey(self.delay_sesskey)
+            self.delay_sesskey = None
+
     def handle_error(self):
         if sys.exc_info()[0] == SystemExit:
             raise sys.exc_info()
@@ -305,6 +319,12 @@
         self.message_output(msg)
         self.poll()
 
+    def setSessionKey(self, key):  # install the session key now, or defer it while a call is in progress
+        if self.waiting_for_reply:
+            self.delay_sesskey = key  # applied by handle_request after the in-progress call returns
+        else:
+            self.__super_setSessionKey(key)
+
     # The next two public methods (call and callAsync) are used by
     # clients to invoke methods on remote objects
 
@@ -328,7 +348,7 @@
             raise DisconnectedError()
         msgid = self.send_call(method, args, 0)
         r_flags, r_args = self.wait(msgid)
-        if (isinstance(r_args, tuple)
+        if (isinstance(r_args, tuple) and len(r_args) > 1
             and isinstance(r_args[0], types.ClassType)
             and issubclass(r_args[0], Exception)):
             inst = r_args[1]
@@ -365,6 +385,14 @@
         self.send_call(method, args, ASYNC)
         self.poll()
 
+    def callAsyncNoPoll(self, method, *args):
+        # Like callAsync() but doesn't poll.  This exists so that we can
+        # send invalidations atomically to all clients without
+        # allowing any client to sneak in a load request.
+        if self.closed:
+            raise DisconnectedError()
+        self.send_call(method, args, ASYNC)
+
     # handle IO, possibly in async mode
 
     def _prepare_async(self):
@@ -438,6 +466,13 @@
                         self.replies_cond.acquire()
         finally:
             self.replies_cond.release()
+
+    def flush(self):
+        """Call poll() while writable() is true, i.e. until pending output is sent."""
+        if __debug__:
+            log.debug("flush")
+        while self.writable():
+            self.poll()
 
     def poll(self):
         """Invoke asyncore mainloop to get pending message out."""