[Zope-Checkins] CVS: ZODB3/ZEO - ClientStorage.py:1.93.2.1

Jeremy Hylton jeremy@zope.com
Tue, 13 May 2003 16:42:39 -0400


Update of /cvs-repository/ZODB3/ZEO
In directory cvs.zope.org:/tmp/cvs-serv10167/ZEO

Modified Files:
      Tag: ZODB3-auth-branch
	ClientStorage.py 
Log Message:
Add ClientStorage support for authentication.

XXX Somehow this was lost when the rest of the branch was checked in.


=== ZODB3/ZEO/ClientStorage.py 1.93 => 1.93.2.1 ===
--- ZODB3/ZEO/ClientStorage.py:1.93	Tue Apr 22 14:00:16 2003
+++ ZODB3/ZEO/ClientStorage.py	Tue May 13 16:42:38 2003
@@ -1,6 +1,6 @@
 ##############################################################################
 #
-# Copyright (c) 2001, 2002 Zope Corporation and Contributors.
+# Copyright (c) 2001, 2002, 2003 Zope Corporation and Contributors.
 # All Rights Reserved.
 #
 # This software is subject to the provisions of the Zope Public License,
@@ -29,7 +29,8 @@
 from ZEO import ClientCache, ServerStub
 from ZEO.TransactionBuffer import TransactionBuffer
 from ZEO.Exceptions \
-     import ClientStorageError, UnrecognizedResult, ClientDisconnected
+     import ClientStorageError, UnrecognizedResult, ClientDisconnected, \
+            AuthError
 from ZEO.zrpc.client import ConnectionManager
 
 from ZODB import POSException
@@ -99,7 +100,8 @@
                  min_disconnect_poll=5, max_disconnect_poll=300,
                  wait_for_server_on_startup=None, # deprecated alias for wait
                  wait=None, # defaults to 1
-                 read_only=0, read_only_fallback=0):
+                 read_only=0, read_only_fallback=0,
+                 username='', password=''):
 
         """ClientStorage constructor.
 
@@ -159,6 +161,17 @@
             writable storages are available.  Defaults to false.  At
             most one of read_only and read_only_fallback should be
             true.
+
+        username -- string with username to be used when authenticating.
+            These only need to be provided if you are connecting to an
+            authenticated server storage.
+ 
+        password -- string with plaintext password to be used
+            when authenticated.
+
+        Note that the authentication protocol is defined by the server
+        and is detected by the ClientStorage upon connecting (see
+        testConnection() and doAuth() for details).
         """
 
         log2(INFO, "%s (pid=%d) created %s/%s for storage: %r" %
@@ -217,6 +230,8 @@
         self._conn_is_read_only = 0
         self._storage = storage
         self._read_only_fallback = read_only_fallback
+        self._username = username
+        self._password = password
         # _server_addr is used by sortKey()
         self._server_addr = None
         self._tfile = None
@@ -293,18 +308,21 @@
                     break
                 log2(INFO, "Wait for cache verification to finish")
         else:
-            # If there is no mainloop running, this code needs
-            # to call poll() to cause asyncore to handle events.
-            while 1:
-                if self._ready.isSet():
-                    break
-                log2(INFO, "Wait for cache verification to finish")
-                if self._connection is None:
-                    # If the connection was closed while we were
-                    # waiting for it to become ready, start over.
-                    return self._wait()
-                else:
-                    self._connection.pending(30)
+            self._wait_sync()
+
+    def _wait_sync(self):
+        # If there is no mainloop running, this code needs
+        # to call poll() to cause asyncore to handle events.
+        while 1:
+            if self._ready.isSet():
+                break
+            log2(INFO, "Wait for cache verification to finish")
+            if self._connection is None:
+                # If the connection was closed while we were
+                # waiting for it to become ready, start over.
+                return self._wait()
+            else:
+                self._connection.pending(30)
 
     def close(self):
         """Storage API: finalize the storage, releasing external resources."""
@@ -344,6 +362,38 @@
         if cn is not None:
             cn.pending()
 
+    def doAuth(self, protocol, stub):
+        if self._username == '' and self._password == '':
+            raise AuthError, "empty username or password"
+
+        # import the auth module
+        # XXX: Should we validate the client module that is being specified
+        # by the server? A malicious server could cause any auth_*.py file
+        # to be loaded according to Python import semantics.
+
+        # XXX There should probably be a registry of valid authentication
+        # mechanisms for the client, and we should only import those
+        # modules.
+        
+        fullname = 'ZEO.auth.auth_' + protocol
+        try:
+            module = __import__(fullname, globals(), locals(), protocol)
+        except ImportError:
+            log("%s: no such an auth protocol: %s" %
+                (self.__class__.__name__, protocol))
+
+        # instantiate the client authenticator
+        Client = getattr(module, 'Client', None)
+        if not Client:
+            log("%s: %s is not a valid auth protocol, must have a " + \
+                "Client class" % (self.__class__.__name__, protocol))
+            raise AuthError, "invalid protocol"
+        
+        c = Client(stub)
+        
+        # Initiate authentication, returns boolean specifying whether OK
+        return c.start(self._username, self._password)
+        
     def testConnection(self, conn):
         """Internal: test the given connection.
 
@@ -369,6 +419,12 @@
         # XXX Check the protocol version here?
         self._conn_is_read_only = 0
         stub = self.StorageServerStubClass(conn)
+
+        # XXX: Verify return value?
+        auth = stub.getAuthProtocol()
+        if auth and not self.doAuth(auth, stub):
+            raise AuthError, "Authentication failed"
+        
         try:
             stub.register(str(self._storage), self._is_read_only)
             return 1
@@ -414,6 +470,9 @@
         self._oids = []
         self._info.update(stub.get_info())
         self.verify_cache(stub)
+        if not conn.is_async():
+            log2(INFO, "Waiting for cache verification to finish")
+            self._wait_sync()
 
     def set_server_addr(self, addr):
         # Normalize server address and convert to string