[Zodb-checkins] CVS: ZODB4/src/zodb/zeo - schema.xml:1.2 runzeo.py:1.2 component.xml:1.2 threadedasync.py:1.3 stubs.py:1.8 server.py:1.13 interfaces.py:1.4 client.py:1.14 cache.py:1.5

Jeremy Hylton jeremy at zope.com
Thu Jun 19 18:41:39 EDT 2003


Update of /cvs-repository/ZODB4/src/zodb/zeo
In directory cvs.zope.org:/tmp/cvs-serv15960/src/zodb/zeo

Modified Files:
	threadedasync.py stubs.py server.py interfaces.py client.py 
	cache.py 
Added Files:
	schema.xml runzeo.py component.xml 
Log Message:
Merge ZODB3-2-merge branch to the head.

This completes the porting of bug fixes and random improvements from
ZODB 3.2 to ZODB 4.


=== ZODB4/src/zodb/zeo/schema.xml 1.1 => 1.2 ===
--- /dev/null	Thu Jun 19 17:41:39 2003
+++ ZODB4/src/zodb/zeo/schema.xml	Thu Jun 19 17:41:08 2003
@@ -0,0 +1,29 @@
+<schema>
+
+  <description>
+    This schema describes the configuration of the ZEO storage server
+    process.
+  </description>
+
+  <!-- Use the storage types defined by ZODB. -->
+  <import package="zodb"/>
+
+  <!-- Use the ZEO server information structure. -->
+  <import package="zodb/zeo"/>
+
+  <section type="zeo" name="*" required="yes" attribute="zeo" />
+
+  <multisection name="+" type="ZODB.storage"
+                attribute="storages"
+                required="yes">
+    <description>
+      One or more storages that are provided by the ZEO server.  The
+      section names are used as the storage names, and must be unique
+      within each ZEO storage server.  Traditionally, these names
+      represent small integers starting at '1'.
+    </description>
+  </multisection>
+
+  <section name="*" type="eventlog" attribute="eventlog" required="no" />
+
+</schema>


=== ZODB4/src/zodb/zeo/runzeo.py 1.1 => 1.2 ===
--- /dev/null	Thu Jun 19 17:41:39 2003
+++ ZODB4/src/zodb/zeo/runzeo.py	Thu Jun 19 17:41:08 2003
@@ -0,0 +1,276 @@
+#!python
+##############################################################################
+#
+# Copyright (c) 2001, 2002, 2003 Zope Corporation and Contributors.
+# All Rights Reserved.
+#
+# This software is subject to the provisions of the Zope Public License,
+# Version 2.0 (ZPL).  A copy of the ZPL should accompany this distribution.
+# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
+# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
+# FOR A PARTICULAR PURPOSE
+#
+##############################################################################
+"""Start the ZEO storage server.
+
+Usage: %s [-C URL] [-a ADDRESS] [-f FILENAME] [-h]
+
+Options:
+-C/--configuration URL -- configuration file or URL
+-a/--address ADDRESS -- server address of the form PORT, HOST:PORT, or PATH
+                        (a PATH must contain at least one "/")
+-f/--filename FILENAME -- filename for FileStorage
+-t/--timeout TIMEOUT -- transaction timeout in seconds (default no timeout)
+-h/--help -- print this usage message and exit
+-m/--monitor ADDRESS -- address of monitor server ([HOST:]PORT or PATH)
+
+Unless -C is specified, -a and -f are required.
+"""
+
+# The code here is designed to be reused by other, similar servers.
+# For the foreseeable future, it must work under Python 2.1 as well as
+# 2.2 and above.
+
+import os
+import sys
+import getopt
+import signal
+import socket
+import logging
+
+import ZConfig
+from zdaemon.zdoptions import ZDOptions
+from zodb import zeo
+
+def parse_address(arg):
+    # XXX Not part of the official ZConfig API
+    obj = ZConfig.datatypes.SocketAddress(arg)
+    return obj.family, obj.address
+
+class ZEOOptionsMixin:
+
+    storages = None
+
+    def handle_address(self, arg):
+        self.family, self.address = parse_address(arg)
+
+    def handle_monitor_address(self, arg):
+        self.monitor_family, self.monitor_address = parse_address(arg)
+
+    def handle_filename(self, arg):
+        from zodb.config import FileStorage # That's a FileStorage *opener*!
+        class FSConfig:
+            def __init__(self, name, path):
+                self._name = name
+                self.path = path
+                self.create = 0
+                self.read_only = 0
+                self.stop = None
+                self.quota = None
+            def getSectionName(self):
+                return self._name
+        if not self.storages:
+            self.storages = []
+        name = str(1 + len(self.storages))
+        conf = FileStorage(FSConfig(name, arg))
+        self.storages.append(conf)
+
+    def add_zeo_options(self):
+        self.add(None, None, "a:", "address=", self.handle_address)
+        self.add(None, None, "f:", "filename=", self.handle_filename)
+        self.add("family", "zeo.address.family")
+        self.add("address", "zeo.address.address",
+                 required="no server address specified; use -a or -C")
+        self.add("read_only", "zeo.read_only", default=0)
+        self.add("invalidation_queue_size", "zeo.invalidation_queue_size",
+                 default=100)
+        self.add("transaction_timeout", "zeo.transaction_timeout",
+                 "t:", "timeout=", float)
+        self.add("monitor_address", "zeo.monitor_address", "m:", "monitor=",
+                 self.handle_monitor_address)
+        self.add('auth_protocol', 'zeo.authentication_protocol',
+                 None, 'auth-protocol=', default=None)
+        self.add('auth_database', 'zeo.authentication_database',
+                 None, 'auth-database=')
+        self.add('auth_realm', 'zeo.authentication_realm',
+                 None, 'auth-realm=')
+
+class ZEOOptions(ZDOptions, ZEOOptionsMixin):
+
+    logsectionname = "eventlog"
+
+    def __init__(self):
+        self.schemadir = os.path.dirname(zeo.__file__)
+        ZDOptions.__init__(self)
+        self.add_zeo_options()
+        self.add("storages", "storages",
+                 required="no storages specified; use -f or -C")
+
+
+class ZEOServer:
+
+    def __init__(self, options):
+        self.options = options
+
+    def main(self):
+        self.setup_default_logging()
+        self.check_socket()
+        self.clear_socket()
+        try:
+            self.open_storages()
+            self.setup_signals()
+            self.create_server()
+            self.loop_forever()
+        finally:
+            self.close_storages()
+            self.clear_socket()
+
+    def setup_default_logging(self):
+        if self.options.config_logger is not None:
+            return
+        if os.getenv("EVENT_LOG_FILE") is not None:
+            return
+        if os.getenv("STUPID_LOG_FILE") is not None:
+            return
+        # No log file is configured; default to stderr.  The logging
+        # level can still be controlled by {STUPID,EVENT}_LOG_SEVERITY.
+        os.environ["EVENT_LOG_FILE"] = ""
+
+    def check_socket(self):
+        if self.can_connect(self.options.family, self.options.address):
+            self.options.usage("address %s already in use" %
+                               repr(self.options.address))
+
+    def can_connect(self, family, address):
+        s = socket.socket(family, socket.SOCK_STREAM)
+        try:
+            s.connect(address)
+        except socket.error:
+            return 0
+        else:
+            s.close()
+            return 1
+
+    def clear_socket(self):
+        if isinstance(self.options.address, type("")):
+            try:
+                os.unlink(self.options.address)
+            except os.error:
+                pass
+
+    def open_storages(self):
+        self.storages = {}
+        for opener in self.options.storages:
+            _logger.info("opening storage %r using %s"
+                 % (opener.name, opener.__class__.__name__))
+            self.storages[opener.name] = opener.open()
+
+    def setup_signals(self):
+        """Set up signal handlers.
+
+        The signal handler for SIGFOO is a method handle_sigfoo().
+        If no handler method is defined for a signal, the signal
+        action is not changed from its initial value.  The handler
+        method is called without additional arguments.
+        """
+        if os.name != "posix":
+            return
+        if hasattr(signal, 'SIGXFSZ'):
+            signal.signal(signal.SIGXFSZ, signal.SIG_IGN) # Special case
+        init_signames()
+        for sig, name in signames.items():
+            method = getattr(self, "handle_" + name.lower(), None)
+            if method is not None:
+                def wrapper(sig_dummy, frame_dummy, method=method):
+                    method()
+                signal.signal(sig, wrapper)
+
+    def create_server(self):
+        from zodb.zeo.server import StorageServer
+        self.server = StorageServer(
+            self.options.address,
+            self.storages,
+            read_only=self.options.read_only,
+            invalidation_queue_size=self.options.invalidation_queue_size,
+            transaction_timeout=self.options.transaction_timeout,
+            monitor_address=self.options.monitor_address,
+            auth_protocol=self.options.auth_protocol,
+            auth_database=self.options.auth_database,
+            auth_realm=self.options.auth_realm)
+
+    def loop_forever(self):
+        from zodb.zeo.threadedasync import LoopCallback
+        LoopCallback.loop()
+
+    def handle_sigterm(self):
+        _logger.info("terminated by SIGTERM")
+        sys.exit(0)
+
+    def handle_sigint(self):
+        _logger.info("terminated by SIGINT")
+        sys.exit(0)
+
+    def handle_sighup(self):
+        _logger.info("restarted by SIGHUP")
+        sys.exit(1)
+
+    def handle_sigusr2(self):
+        # How should this work with new logging?
+        
+        # This requires a modern zLOG (from Zope 2.6 or later); older
+        # zLOG packages don't have the initialize() method
+        _logger.info("reinitializing zLOG")
+        # XXX Shouldn't this be below with _log()?
+        import zLOG
+        zLOG.initialize()
+        _logger.info("reinitialized zLOG")
+
+    def close_storages(self):
+        for name, storage in self.storages.items():
+            _logger.info("closing storage %r" % name)
+            try:
+                storage.close()
+            except: # Keep going
+                _logging.exception("failed to close storage %r" % name)
+
+
+# Signal names
+
+signames = None
+
+def signame(sig):
+    """Return a symbolic name for a signal.
+
+    Return "signal NNN" if there is no corresponding SIG name in the
+    signal module.
+    """
+
+    if signames is None:
+        init_signames()
+    return signames.get(sig) or "signal %d" % sig
+
+def init_signames():
+    global signames
+    signames = {}
+    for name, sig in signal.__dict__.items():
+        k_startswith = getattr(name, "startswith", None)
+        if k_startswith is None:
+            continue
+        if k_startswith("SIG") and not k_startswith("SIG_"):
+            signames[sig] = name
+
+
+# Main program
+
+def main(args=None):
+    global _logger
+    _logger = logging.getLogger("runzeo")
+
+    options = ZEOOptions()
+    options.realize(args)
+    s = ZEOServer(options)
+    s.main()
+
+if __name__ == "__main__":
+    main()


=== ZODB4/src/zodb/zeo/component.xml 1.1 => 1.2 ===
--- /dev/null	Thu Jun 19 17:41:39 2003
+++ ZODB4/src/zodb/zeo/component.xml	Thu Jun 19 17:41:08 2003
@@ -0,0 +1,101 @@
+<component>
+
+  <!-- stub out the type until we figure out how to zconfig logging -->
+  <sectiontype name="eventlog" />
+
+  <sectiontype name="zeo">
+
+    <description>
+      The content of a ZEO section describes operational parameters
+      of a ZEO server except for the storage(s) to be served.
+    </description>
+
+    <key name="address" datatype="socket-address"
+         required="yes">
+      <description>
+        The address at which the server should listen.  This can be in
+        the form 'host:port' to signify a TCP/IP connection or a
+        pathname string to signify a Unix domain socket connection (at
+        least one '/' is required).  A hostname may be a DNS name or a
+        dotted IP address.  If the hostname is omitted, the platform's
+        default behavior is used when binding the listening socket (''
+        is passed to socket.bind() as the hostname portion of the
+        address).
+      </description>
+    </key>
+
+    <key name="read-only" datatype="boolean"
+         required="no"
+         default="false">
+      <description>
+        Flag indicating whether the server should operate in read-only
+        mode.  Defaults to false.  Note that even if the server is
+        operating in writable mode, individual storages may still be
+        read-only.  But if the server is in read-only mode, no write
+        operations are allowed, even if the storages are writable.  Note
+        that pack() is considered a read-only operation.
+      </description>
+    </key>
+
+    <key name="invalidation-queue-size" datatype="integer"
+         required="no"
+         default="100">
+      <description>
+        The storage server keeps a queue of the objects modified by the
+        last N transactions, where N == invalidation_queue_size.  This
+        queue is used to speed client cache verification when a client
+        disconnects for a short period of time.
+      </description>
+    </key>
+
+    <key name="monitor-address" datatype="socket-address"
+         required="no">
+      <description>
+        The address at which the monitor server should listen.  If
+        specified, a monitor server is started.  The monitor server
+        provides server statistics in a simple text format.  This can
+        be in the form 'host:port' to signify a TCP/IP connection or a
+        pathname string to signify a Unix domain socket connection (at
+        least one '/' is required).  A hostname may be a DNS name or a
+        dotted IP address.  If the hostname is omitted, the platform's
+        default behavior is used when binding the listening socket (''
+        is passed to socket.bind() as the hostname portion of the
+        address).
+      </description>
+    </key>
+
+    <key name="transaction-timeout" datatype="integer"
+         required="no">
+      <description>
+        The maximum amount of time to wait for a transaction to commit
+        after acquiring the storage lock, specified in seconds.  If the
+        transaction takes too long, the client connection will be closed
+        and the transaction aborted.
+      </description>
+    </key>
+
+    <key name="authentication-protocol" required="no">
+      <description>
+        The name of the protocol used for authentication.  The
+        only protocol provided with ZEO is "digest," but extensions
+        may provide other protocols.
+      </description>
+    </key>
+
+    <key name="authentication-database" required="no">
+      <description>
+        The path of the database containing authentication credentials.
+      </description>
+    </key>
+
+    <key name="authentication-realm" required="no">
+      <description>
+        The authentication realm of the server.  Some authentication
+        schemes use a realm to identify the logical set of usernames
+        that are accepted by this server.
+      </description>
+    </key>
+
+  </sectiontype>
+
+</component>


=== ZODB4/src/zodb/zeo/threadedasync.py 1.2 => 1.3 ===
--- ZODB4/src/zodb/zeo/threadedasync.py:1.2	Wed Dec 25 09:12:22 2002
+++ ZODB4/src/zodb/zeo/threadedasync.py	Thu Jun 19 17:41:08 2003
@@ -52,7 +52,7 @@
     _loop_lock.acquire()
     try:
         if _looping is not None:
-            apply(callback, (_looping,) + args, kw or {})
+            callback(_looping, *args, **(kw or {}))
         else:
             _loop_callbacks.append((callback, args, kw))
     finally:
@@ -65,7 +65,7 @@
         _looping = map
         while _loop_callbacks:
             cb, args, kw = _loop_callbacks.pop()
-            apply(cb, (map,) + args, kw or {})
+            cb(map, *args, **(kw or {}))
     finally:
         _loop_lock.release()
 


=== ZODB4/src/zodb/zeo/stubs.py 1.7 => 1.8 ===
--- ZODB4/src/zodb/zeo/stubs.py:1.7	Mon May 19 11:02:51 2003
+++ ZODB4/src/zodb/zeo/stubs.py	Thu Jun 19 17:41:08 2003
@@ -52,7 +52,7 @@
         self.rpc.callAsync('endVerify')
 
     def invalidateTransaction(self, tid, invlist):
-        self.rpc.callAsync('invalidateTransaction', tid, invlist)
+        self.rpc.callAsyncNoPoll('invalidateTransaction', tid, invlist)
 
     def serialnos(self, arg):
         self.rpc.callAsync('serialnos', arg)
@@ -102,6 +102,12 @@
 
     def get_info(self):
         return self.rpc.call('get_info')
+
+    def getAuthProtocol(self):
+        return self.rpc.call('getAuthProtocol')
+    
+    def lastTransaction(self):
+        return self.rpc.call('lastTransaction')
 
     def getInvalidations(self, tid):
         return self.rpc.call('getInvalidations', tid)


=== ZODB4/src/zodb/zeo/server.py 1.12 => 1.13 ===
--- ZODB4/src/zodb/zeo/server.py:1.12	Sat Jun  7 02:54:23 2003
+++ ZODB4/src/zodb/zeo/server.py	Thu Jun 19 17:41:08 2003
@@ -58,7 +58,11 @@
 
     ClientStorageStubClass = ClientStorageStub
 
-    def __init__(self, server, read_only=0):
+    # A list of extension methods.  A subclass with extra methods
+    # should override.
+    extensions = []
+
+    def __init__(self, server, read_only=0, auth_realm=None):
         self.server = server
         # timeout and stats will be initialized in register()
         self.timeout = None
@@ -73,7 +77,22 @@
         self.verifying = 0
         self.logger = logging.getLogger("ZSS.%d.ZEO" % os.getpid())
         self.log_label = ""
+        self.authenticated = 0
+        self.auth_realm = auth_realm
+        # The authentication protocol may define extra methods.
+        self._extensions = {}
+        for func in self.extensions:
+            self._extensions[func.func_name] = None
+
+    def finish_auth(self, authenticated):
+        if not self.auth_realm:
+            return 1
+        self.authenticated = authenticated
+        return authenticated
 
+    def set_database(self, database):
+        self.database = database
+        
     def notifyConnected(self, conn):
         self.connection = conn # For restart_other() below
         self.client = self.ClientStorageStubClass(conn)
@@ -110,6 +129,7 @@
         """Delegate several methods to the storage"""
         self.versionEmpty = self.storage.versionEmpty
         self.versions = self.storage.versions
+        self.getSerial = self.storage.getSerial
         self.load = self.storage.load
         self.modifiedInVersion = self.storage.modifiedInVersion
         self.getVersion = self.storage.getVersion
@@ -125,9 +145,11 @@
             # can be removed
             pass
         else:
-            for name in fn().keys():
-                if not hasattr(self,name):
-                    setattr(self, name, getattr(self.storage, name))
+            d = fn()
+            self._extensions.update(d)
+            for name in d.keys():
+                assert not hasattr(self, name)
+                setattr(self, name, getattr(self.storage, name))
         self.lastTransaction = self.storage.lastTransaction
 
     def _check_tid(self, tid, exc=None):
@@ -149,6 +171,15 @@
                 return 0
         return 1
 
+    def getAuthProtocol(self):
+        """Return string specifying name of authentication module to use.
+
+        The module name should be auth_%s where %s is auth_protocol."""
+        protocol = self.server.auth_protocol
+        if not protocol or protocol == 'none':
+            return None
+        return protocol
+    
     def register(self, storage_id, read_only):
         """Select the storage that this client will use
 
@@ -173,19 +204,14 @@
                                                                    self)
 
     def get_info(self):
-        return {'name': self.storage.getName(),
-                'extensionMethods': self.getExtensionMethods(),
+        return {"name": self.storage.getName(),
+                "extensionMethods": self.getExtensionMethods(),
                  "implements": [iface.__name__
                                 for iface in providedBy(self.storage)],
                 }
 
     def getExtensionMethods(self):
-        try:
-            e = self.storage.getExtensionMethods
-        except AttributeError:
-            return {}
-        else:
-            return e()
+        return self._extensions
 
     def zeoLoad(self, oid):
         self.stats.loads += 1
@@ -564,7 +590,10 @@
     def __init__(self, addr, storages, read_only=0,
                  invalidation_queue_size=100,
                  transaction_timeout=None,
-                 monitor_address=None):
+                 monitor_address=None,
+                 auth_protocol=None,
+                 auth_filename=None,
+                 auth_realm=None):
         """StorageServer constructor.
 
         This is typically invoked from the start.py script.
@@ -606,6 +635,21 @@
             should listen.  If specified, a monitor server is started.
             The monitor server provides server statistics in a simple
             text format.
+
+        auth_protocol -- The name of the authentication protocol to use.
+            Examples are "digest" and "srp".
+            
+        auth_filename -- The name of the password database filename.
+            It should be in a format compatible with the authentication
+            protocol used; for instance, "sha" and "srp" require different
+            formats.
+            
+            Note that to implement an authentication protocol, a server
+            and client authentication mechanism must be implemented in an
+            auth_* module, which should be stored inside the "auth"
+            subdirectory. This module may also define a DatabaseClass
+            variable that should indicate what database should be used
+            by the authenticator.
         """
 
         self.addr = addr
@@ -621,6 +665,12 @@
         for s in storages.values():
             s._waiting = []
         self.read_only = read_only
+        self.auth_protocol = auth_protocol
+        self.auth_filename = auth_filename
+        self.auth_realm = auth_realm
+        self.database = None
+        if auth_protocol:
+            self._setup_auth(auth_protocol)
         # A list of at most invalidation_queue_size invalidations
         self.invq = []
         self.invq_bound = invalidation_queue_size
@@ -643,6 +693,40 @@
         else:
             self.monitor = None
 
+    def _setup_auth(self, protocol):
+        # Can't be done in global scope, because of cyclic references
+        from zodb.zeo.auth import get_module
+
+        name = self.__class__.__name__
+
+        module = get_module(protocol)
+        if not module:
+            self.logger.info("%s: no such an auth protocol: %s",
+                             name, protocol)
+            return
+        
+        storage_class, client, db_class = module
+        
+        if not storage_class or not issubclass(storage_class, ZEOStorage):
+            self.logger.info("%s: %s isn't a valid protocol, "
+                             "must have a StorageClass", name, protocol)
+            self.auth_protocol = None
+            return
+        self.ZEOStorageClass = storage_class
+
+        self.logger.info("%s: using auth protocol: %s", name, protocol)
+        
+        # We create a Database instance here for use with the authenticator
+        # modules. Having one instance allows it to be shared between multiple
+        # storages, avoiding the need to bloat each with a new authenticator
+        # Database that would contain the same info, and also avoiding any
+        # possible synchronization issues between them.
+        self.database = db_class(self.auth_filename)
+        if self.database.realm != self.auth_realm:
+            raise ValueError("password database realm %r "
+                             "does not match storage realm %r"
+                             % (self.database.realm, self.auth_realm))
+
     def new_connection(self, sock, addr):
         """Internal: factory to create a new connection.
 
@@ -650,8 +734,13 @@
         whenever accept() returns a socket for a new incoming
         connection.
         """
-        z = self.ZEOStorageClass(self, self.read_only)
-        c = self.ManagedServerConnectionClass(sock, addr, z, self)
+        if self.auth_protocol and self.database:
+            zstorage = self.ZEOStorageClass(self, self.read_only,
+                                            auth_realm=self.auth_realm)
+            zstorage.set_database(self.database)
+        else:
+            zstorage = self.ZEOStorageClass(self, self.read_only)
+        c = self.ManagedServerConnectionClass(sock, addr, zstorage, self)
         self.logger.warn("new connection %s: %s", addr, `c`)
         return c
 


=== ZODB4/src/zodb/zeo/interfaces.py 1.3 => 1.4 ===
--- ZODB4/src/zodb/zeo/interfaces.py:1.3	Tue Feb 25 13:55:05 2003
+++ ZODB4/src/zodb/zeo/interfaces.py	Thu Jun 19 17:41:08 2003
@@ -27,3 +27,5 @@
 class ClientDisconnected(ClientStorageError):
     """The database storage is disconnected from the storage."""
 
+class AuthError(StorageError):
+    """The client provided invalid authentication credentials."""


=== ZODB4/src/zodb/zeo/client.py 1.13 => 1.14 ===
--- ZODB4/src/zodb/zeo/client.py:1.13	Fri Jun  6 11:24:21 2003
+++ ZODB4/src/zodb/zeo/client.py	Thu Jun 19 17:41:08 2003
@@ -33,6 +33,7 @@
 from zope.interface import directlyProvides, implements
 
 from zodb.zeo import cache
+from zodb.zeo.auth import get_module
 from zodb.zeo.stubs import StorageServerStub
 from zodb.zeo.tbuf import TransactionBuffer
 from zodb.zeo.zrpc.client import ConnectionManager
@@ -57,7 +58,7 @@
     the argument.
     """
     t = time.time()
-    t = apply(TimeStamp, (time.gmtime(t)[:5] + (t % 60,)))
+    t = TimeStamp(*time.gmtime(t)[:5] + (t % 60,))
     if prev_ts is not None:
         t = t.laterThan(prev_ts)
     return t
@@ -106,7 +107,8 @@
     def __init__(self, addr, storage='1', cache_size=20 * MB,
                  name='', client=None, var=None,
                  min_disconnect_poll=5, max_disconnect_poll=300,
-                 wait=True, read_only=False, read_only_fallback=False):
+                 wait=True, read_only=False, read_only_fallback=False,
+                 username='', password='', realm=None):
 
         """ClientStorage constructor.
 
@@ -161,6 +163,17 @@
             writable storages are available.  Defaults to false.  At
             most one of read_only and read_only_fallback should be
             true.
+
+        username -- string with username to be used when authenticating.
+            These only need to be provided if you are connecting to an
+            authenticated server storage.
+
+        password -- string with plaintext password to be used
+            when authenticating.
+
+        Note that the authentication protocol is defined by the server
+        and is detected by the ClientStorage upon connecting (see
+        testConnection() and doAuth() for details).
         """
 
         self.logger = logging.getLogger("ZCS.%d" % os.getpid())
@@ -202,6 +215,9 @@
         self._conn_is_read_only = 0
         self._storage = storage
         self._read_only_fallback = read_only_fallback
+        self._username = username
+        self._password = password
+        self._realm = realm
         # _server_addr is used by sortKey()
         self._server_addr = None
         self._tfile = None
@@ -236,6 +252,21 @@
         self._oid_lock = threading.Lock()
         self._oids = [] # Object ids retrieved from newObjectIds()
 
+        # load() and tpc_finish() must be serialized to guarantee
+        # that cache modifications from each occur atomically.
+        # It also prevents multiple load calls occurring simultaneously,
+        # which simplifies the cache logic.
+        self._load_lock = threading.Lock()
+        # _load_oid and _load_status are protected by _lock
+        self._load_oid = None
+        self._load_status = None
+
+        # Can't read data in one thread while writing data
+        # (tpc_finish) in another thread.  In general, the lock
+        # must prevent access to the cache while _update_cache
+        # is executing.
+        self._lock = threading.Lock()
+
         t = self._ts = get_timestamp()
         self._serial = `t`
         self._oid = '\0\0\0\0\0\0\0\0'
@@ -330,6 +361,29 @@
         if cn is not None:
             cn.pending()
 
+    def doAuth(self, protocol, stub):
+        if not (self._username and self._password):
+            raise AuthError, "empty username or password"
+
+        module = get_module(protocol)
+        if not module:
+            log2(PROBLEM, "%s: no such an auth protocol: %s" %
+                 (self.__class__.__name__, protocol))
+            return
+
+        storage_class, client, db_class = module
+
+        if not client:
+            log2(PROBLEM,
+                 "%s: %s isn't a valid protocol, must have a Client class" %
+                 (self.__class__.__name__, protocol))
+            raise AuthError, "invalid protocol"
+
+        c = client(stub)
+
+        # Initiate authentication, returns boolean specifying whether OK
+        return c.start(self._username, self._realm, self._password)
+
     def testConnection(self, conn):
         """Internal: test the given connection.
 
@@ -355,6 +409,16 @@
         # XXX Check the protocol version here?
         self._conn_is_read_only = 0
         stub = self.StorageServerStubClass(conn)
+        
+        auth = stub.getAuthProtocol()
+        self.logger.info("Client authentication successful")
+        if auth:
+            if self.doAuth(auth, stub):
+                self.logger.info("Client authentication successful")
+            else:
+                self.logger.error("Authentication failed")
+                raise AuthError, "Authentication failed"
+
         try:
             stub.register(str(self._storage), self._is_read_only)
             return 1
@@ -406,6 +470,12 @@
         if not conn.is_async():
             self.logger.warn("Waiting for cache verification to finish")
             self._wait_sync()
+        self._handle_extensions()
+
+    def _handle_extensions(self):
+        for name in self.getExtensionMethods().keys():
+            if not hasattr(self, name):
+                setattr(self, name, self._server.extensionMethod(name))
 
     def update_interfaces(self):
         # Update what interfaces the instance provides based on the server.
@@ -600,12 +670,6 @@
         """
         return self._server.history(oid, version, length)
 
-    def __getattr__(self, name):
-        if self.getExtensionMethods().has_key(name):
-            return self._server.extensionMethod(name)
-        else:
-            raise AttributeError(name)
-
     def loadSerial(self, oid, serial):
         """Storage API: load a historical revision of an object."""
         return self._server.loadSerial(oid, serial)
@@ -621,14 +685,39 @@
         specified by the given object id and version, if they exist;
         otherwise a KeyError is raised.
         """
-        p = self._cache.load(oid, version)
-        if p:
-            return p
+        self._lock.acquire()    # for atomic processing of invalidations
+        try:
+            pair = self._cache.load(oid, version)
+            if pair:
+                return pair
+        finally:
+            self._lock.release()
+
         if self._server is None:
             raise ClientDisconnected()
-        p, s, v, pv, sv = self._server.zeoLoad(oid)
-        self._cache.checkSize(0)
-        self._cache.store(oid, p, s, v, pv, sv)
+
+        self._load_lock.acquire()
+        try:
+            self._lock.acquire()
+            try:
+                self._load_oid = oid
+                self._load_status = 1
+            finally:
+                self._lock.release()
+
+            p, s, v, pv, sv = self._server.zeoLoad(oid)
+
+            self._lock.acquire()    # for atomic processing of invalidations
+            try:
+                if self._load_status:
+                    self._cache.checkSize(0)
+                    self._cache.store(oid, p, s, v, pv, sv)
+                self._load_oid = None
+            finally:
+                self._lock.release()
+        finally:
+            self._load_lock.release()
+
         if v and version and v == version:
             return pv, sv
         else:
@@ -641,9 +730,13 @@
 
         If no version modified the object, return an empty string.
         """
-        v = self._cache.modifiedInVersion(oid)
-        if v is not None:
-            return v
+        self._lock.acquire()
+        try:
+            v = self._cache.modifiedInVersion(oid)
+            if v is not None:
+                return v
+        finally:
+            self._lock.release()
         return self._server.modifiedInVersion(oid)
 
     def newObjectId(self):
@@ -740,6 +833,7 @@
 
         self._serial = id
         self._seriald.clear()
+        self._tbuf.clear()
         del self._serials[:]
 
     def end_transaction(self):
@@ -779,18 +873,23 @@
         """Storage API: finish a transaction."""
         if transaction is not self._transaction:
             return
+        self._load_lock.acquire()
         try:
-            if f is not None:
-                f()
+            self._lock.acquire()  # for atomic processing of invalidations
+            try:
+                self._update_cache()
+                if f is not None:
+                    f()
+            finally:
+                self._lock.release()
 
             tid = self._server.tpcFinish(self._serial)
+            self._cache.setLastTid(tid)
 
             r = self._check_serials()
             assert r is None or len(r) == 0, "unhandled serialnos: %s" % r
-
-            self._update_cache()
-            self._cache.setLastTid(tid)
         finally:
+            self._load_lock.release()
             self.end_transaction()
 
     def _update_cache(self):
@@ -799,6 +898,13 @@
         This iterates over the objects in the transaction buffer and
         update or invalidate the cache.
         """
+        # Must be called with _lock already acquired.
+
+        # XXX not sure why _update_cache() would be called on
+        # a closed storage.
+        if self._cache is None:
+            return
+
         self._cache.checkSize(self._tbuf.get_size())
         try:
             self._tbuf.begin_iterate()
@@ -892,15 +998,21 @@
         # oid, version pairs.  The DB's invalidate() method expects a
         # dictionary of oids.
 
-        # versions maps version names to dictionary of invalidations
-        versions = {}
-        for oid, version in invs:
-            d = versions.setdefault(version, {})
-            self._cache.invalidate(oid, version=version)
-            d[oid] = 1
-        if self._db is not None:
-            for v, d in versions.items():
-                self._db.invalidate(d, version=v)
+        self._lock.acquire()
+        try:
+            # versions maps version names to dictionary of invalidations
+            versions = {}
+            for oid, version in invs:
+                if oid == self._load_oid:
+                    self._load_status = 0
+                self._cache.invalidate(oid, version=version)
+                versions.setdefault(version, {})[oid] = 1
+
+            if self._db is not None:
+                for v, d in versions.items():
+                    self._db.invalidate(d, version=v)
+        finally:
+            self._lock.release()
 
     def endVerify(self):
         """Server callback to signal end of cache validation."""
@@ -928,7 +1040,7 @@
             self.logger.debug(
                 "Transactional invalidation during cache verification")
             for t in args:
-                self.self._pickler.dump(t)
+                self._pickler.dump(t)
             return
         self._process_invalidations(args)
 


=== ZODB4/src/zodb/zeo/cache.py 1.4 => 1.5 ===
--- ZODB4/src/zodb/zeo/cache.py:1.4	Thu Mar 13 16:32:30 2003
+++ ZODB4/src/zodb/zeo/cache.py	Thu Jun 19 17:41:08 2003
@@ -13,10 +13,11 @@
 ##############################################################################
 
 # XXX TO DO
-# use two indices rather than the sign bit of the index??????
-# add a shared routine to read + verify a record???
-# redesign header to include vdlen???
-# rewrite the cache using a different algorithm???
+# Add a shared routine to read + verify a record.  Have that routine
+#   return a record object rather than a string.
+# Use two indices rather than the sign bit of the index??????
+# Redesign header to include vdlen???
+# Rewrite the cache using a different algorithm???
 
 """Implement a client cache
 
@@ -44,7 +45,9 @@
 
   offset in record: name -- description
 
-  0: oid -- 8-byte object id
+  0: oidlen -- 2-byte unsigned object id length
+
+  2: reserved (6 bytes)
 
   8: status -- 1-byte status 'v': valid, 'n': non-version valid, 'i': invalid
                ('n' means only the non-version data in the record is valid)
@@ -57,23 +60,25 @@
 
   19: serial -- 8-byte non-version serial (timestamp)
 
-  27: data -- non-version data
+  27: oid -- object id
+
+  27+oidlen: data -- non-version data
 
-  27+dlen: version -- Version string (if vlen > 0)
+  27+oidlen+dlen: version -- Version string (if vlen > 0)
 
-  27+dlen+vlen: vdlen -- 4-byte length of version data (if vlen > 0)
+  27+oidlen+dlen+vlen: vdlen -- 4-byte length of version data (if vlen > 0)
 
-  31+dlen+vlen: vdata -- version data (if vlen > 0)
+  31+oidlen+dlen+vlen: vdata -- version data (if vlen > 0)
 
-  31+dlen+vlen+vdlen: vserial -- 8-byte version serial (timestamp)
+  31+oidlen+dlen+vlen+vdlen: vserial -- 8-byte version serial (timestamp)
                                  (if vlen > 0)
 
-  27+dlen (if vlen == 0) **or**
-  39+dlen+vlen+vdlen: tlen -- 4-byte (unsigned) record length (for
-                              redundancy and backward traversal)
+  27+oidlen+dlen (if vlen == 0) **or**
+  39+oidlen+dlen+vlen+vdlen: tlen -- 4-byte (unsigned) record length (for
+                                     redundancy and backward traversal)
 
-  31+dlen (if vlen == 0) **or**
-  43+dlen+vlen+vdlen: -- total record length (equal to tlen)
+  31+oidlen+dlen (if vlen == 0) **or**
+  43+oidlen+dlen+vlen+vdlen: -- total record length (equal to tlen)
 
 There is a cache size limit.
 
@@ -105,7 +110,6 @@
 file 0 and file 1.
 """
 
-import logging
 import os
 import time
 import logging
@@ -114,9 +118,9 @@
 from thread import allocate_lock
 
 from zodb.utils import u64
-from zodb.interfaces import ZERO
+from zodb.interfaces import ZERO, _fmt_oid
 
-magic = 'ZEC1'
+magic = 'ZEC2'
 headersize = 12
 
 MB = 1024**2
@@ -158,15 +162,13 @@
                 if os.path.exists(p[i]):
                     fi = open(p[i],'r+b')
                     if fi.read(4) == magic: # Minimal sanity
-                        fi.seek(0, 2)
-                        if fi.tell() > headersize:
-                            # Read serial at offset 19 of first record
-                            fi.seek(headersize + 19)
-                            s[i] = fi.read(8)
+                        # Read the ltid for this file.  If it never
+                        # saw a transaction commit, it will get tossed,
+                        # even if it has valid data.
+                        s[i] = fi.read(8)
                     # If we found a non-zero serial, then use the file
                     if s[i] != ZERO:
                         f[i] = fi
-                    fi = None
 
             # Whoever has the larger serial is the current
             if s[1] > s[0]:
@@ -186,11 +188,16 @@
             self._p = p = [None, None]
             f[0].write(magic + '\0' * (headersize - len(magic)))
             current = 0
+        self._current = current
 
-        self.log("%s: storage=%r, size=%r; file[%r]=%r",
-                 self.__class__.__name__, storage, size, current, p[current])
+        if self._ltid:
+            ts = "; last txn=%x" % u64(self._ltid)
+        else:
+            ts = ""
+        self.log("%s: storage=%r, size=%r; file[%r]=%r%s" %
+                 (self.__class__.__name__, storage, size, current, p[current],
+                  ts))
 
-        self._current = current
         self._setup_trace()
 
     def open(self):
@@ -224,6 +231,18 @@
                 except OSError:
                     pass
 
+    def _read_header(self, f, pos):
+        # Read record header from f at pos, returning header and oid.
+        f.seek(pos)
+        h = f.read(27)
+        if len(h) != 27:
+            self.log("_read_header: short record at %s in %s", pos, f.name)
+            return None, None
+        oidlen = unpack(">H", h[:2])[0]
+        oid = f.read(oidlen)
+        return h, oid
+        
+
     def getLastTid(self):
         """Get the last transaction id stored by setLastTid().
 
@@ -243,7 +262,7 @@
         f = self._f[self._current]
         f.seek(4)
         tid = f.read(8)
-        if len(tid) < 8 or tid == '\0\0\0\0\0\0\0\0':
+        if len(tid) < 8 or tid == ZERO:
             return None
         else:
             return tid
@@ -255,7 +274,7 @@
         cache file; otherwise it's an instance variable.
         """
         if self._client is None:
-            if tid == '\0\0\0\0\0\0\0\0':
+            if tid == ZERO:
                 tid = None
             self._ltid = tid
         else:
@@ -267,7 +286,7 @@
 
     def _setLastTid(self, tid):
         if tid is None:
-            tid = '\0\0\0\0\0\0\0\0'
+            tid = ZERO
         else:
             tid = str(tid)
             assert len(tid) == 8
@@ -292,18 +311,14 @@
                 return None
             f = self._f[p < 0]
             ap = abs(p)
-            f.seek(ap)
-            h = f.read(27)
-            if len(h) != 27:
-                self.log("invalidate: short record for oid %16x "
-                         "at position %d in cache file %d",
-                         U64(oid), ap, p < 0)
+            h, rec_oid = self._read_header(f, ap)
+            if h is None:
                 del self._index[oid]
                 return None
-            if h[:8] != oid:
-                self.log("invalidate: oid mismatch: expected %16x read %16x "
+            if rec_oid != oid:
+                self.log("invalidate: oid mismatch: expected %s read %s "
                          "at position %d in cache file %d",
-                         U64(oid), U64(h[:8]), ap, p < 0)
+                         _fmt_oid(oid), _fmt_oid(rec_oid), ap, p < 0)
                 del self._index[oid]
                 return None
             f.seek(ap+8) # Switch from reading to writing
@@ -329,16 +344,18 @@
             ap = abs(p)
             seek = f.seek
             read = f.read
-            seek(ap)
-            h = read(27)
-            if len(h)==27 and h[8] in 'nv' and h[:8]==oid:
+            h, rec_oid = self._read_header(f, ap)
+            if h is None:
+                del self._index[oid]
+                return None
+            if len(h) == 27 and h[8] in 'nv' and rec_oid == oid:
                 tlen, vlen, dlen = unpack(">iHi", h[9:19])
             else:
                 tlen = -1
             if tlen <= 0 or vlen < 0 or dlen < 0 or vlen+dlen > tlen:
-                self.log("load: bad record for oid %16x "
+                self.log("load: bad record for oid %s "
                          "at position %d in cache file %d",
-                         U64(oid), ap, p < 0)
+                         _fmt_oid(oid), ap, p < 0)
                 del self._index[oid]
                 return None
 
@@ -357,7 +374,7 @@
                     data = read(dlen)
                     self._trace(0x2A, oid, version, h[19:], dlen)
                     if (p < 0) != self._current:
-                        self._copytocurrent(ap, tlen, dlen, vlen, h, data)
+                        self._copytocurrent(ap, tlen, dlen, vlen, h, oid, data)
                     return data, h[19:]
                 else:
                     self._trace(0x26, oid, version)
@@ -369,12 +386,12 @@
             v = vheader[:-4]
             if version != v:
                 if dlen:
-                    seek(ap+27)
+                    seek(ap+27+len(oid))
                     data = read(dlen)
                     self._trace(0x2C, oid, version, h[19:], dlen)
                     if (p < 0) != self._current:
                         self._copytocurrent(ap, tlen, dlen, vlen, h,
-                                            data, vheader)
+                                            oid, data, vheader)
                     return data, h[19:]
                 else:
                     self._trace(0x28, oid, version)
@@ -386,12 +403,12 @@
             self._trace(0x2E, oid, version, vserial, vdlen)
             if (p < 0) != self._current:
                 self._copytocurrent(ap, tlen, dlen, vlen, h,
-                                    None, vheader, vdata, vserial)
+                                    oid, None, vheader, vdata, vserial)
             return vdata, vserial
         finally:
             self._release()
 
-    def _copytocurrent(self, pos, tlen, dlen, vlen, header,
+    def _copytocurrent(self, pos, tlen, dlen, vlen, header, oid,
                        data=None, vheader=None, vdata=None, vserial=None):
         """Copy a cache hit from the non-current file to the current file.
 
@@ -402,29 +419,31 @@
         if self._pos + tlen > self._limit:
             return # Don't let this cause a cache flip
         assert len(header) == 27
+        oidlen = len(oid)
         if header[8] == 'n':
             # Rewrite the header to drop the version data.
             # This shortens the record.
-            tlen = 31 + dlen
+            tlen = 31 + oidlen + dlen
             vlen = 0
-            # (oid:8, status:1, tlen:4, vlen:2, dlen:4, serial:8)
+            # (oidlen:2, reserved:6, status:1, tlen:4,
+            #  vlen:2, dlen:4, serial:8)
             header = header[:9] + pack(">IHI", tlen, vlen, dlen) + header[-8:]
         else:
             assert header[8] == 'v'
         f = self._f[not self._current]
         if data is None:
-            f.seek(pos+27)
+            f.seek(pos + 27 + len(oid))
             data = f.read(dlen)
             if len(data) != dlen:
                 return
-        l = [header, data]
+        l = [header, oid, data]
         if vlen:
             assert vheader is not None
             l.append(vheader)
             assert (vdata is None) == (vserial is None)
             if vdata is None:
                 vdlen = unpack(">I", vheader[-4:])[0]
-                f.seek(pos+27+dlen+vlen+4)
+                f.seek(pos + 27 + len(oid) + dlen + vlen + 4)
                 vdata = f.read(vdlen)
                 if len(vdata) != vdlen:
                     return
@@ -440,13 +459,12 @@
         g.seek(self._pos)
         g.writelines(l)
         assert g.tell() == self._pos + tlen
-        oid = header[:8]
         if self._current:
             self._index[oid] = - self._pos
         else:
             self._index[oid] = self._pos
         self._pos += tlen
-        self._trace(0x6A, header[:8], vlen and vheader[:-4] or '',
+        self._trace(0x6A, oid, vlen and vheader[:-4] or '',
                     vlen and vserial or header[-8:], dlen)
 
     def update(self, oid, serial, version, data, refs):
@@ -462,9 +480,11 @@
                 ap = abs(p)
                 seek = f.seek
                 read = f.read
-                seek(ap)
-                h = read(27)
-                if len(h)==27 and h[8] in 'nv' and h[:8]==oid:
+                h, rec_oid = self._read_header(f, ap)
+                if h is None:
+                    del self._index[oid]
+                    return None
+                if len(h)==27 and h[8] in 'nv' and rec_oid == oid:
                     tlen, vlen, dlen = unpack(">iHi", h[9:19])
                 else:
                     return self._store(oid, '', '', version, data, serial)
@@ -500,16 +520,19 @@
             ap = abs(p)
             seek = f.seek
             read = f.read
-            seek(ap)
-            h = read(27)
-            if len(h)==27 and h[8] in 'nv' and h[:8]==oid:
+            h, rec_oid = self._read_header(f, ap)
+            if h is None:
+                del self._index[oid]
+                return None
+                
+            if len(h) == 27 and h[8] in 'nv' and rec_oid == oid:
                 tlen, vlen, dlen = unpack(">iHi", h[9:19])
             else:
                 tlen = -1
             if tlen <= 0 or vlen < 0 or dlen < 0 or vlen+dlen > tlen:
-                self.log("modifiedInVersion: bad record for oid %16x "
+                self.log("modifiedInVersion: bad record for oid %s "
                          "at position %d in cache file %d",
-                         U64(oid), ap, p < 0)
+                         _fmt_oid(oid), ap, p < 0)
                 del self._index[oid]
                 return None
 
@@ -581,7 +604,7 @@
         if not s:
             p = ''
             s = ZERO
-        tlen = 31 + len(p)
+        tlen = 31 + len(oid) + len(p)
         if version:
             tlen = tlen + len(version) + 12 + len(pv)
             vlen = len(version)
@@ -590,7 +613,11 @@
 
         stlen = pack(">I", tlen)
         # accumulate various data to write into a list
-        l = [oid, 'v', stlen, pack(">HI", vlen, len(p)), s]
+        assert len(oid) < 2**16
+        assert vlen < 2**16
+        assert tlen < 2L**32
+        l = [pack(">H6x", len(oid)), 'v', stlen,
+             pack(">HI", vlen, len(p)), s, oid]
         if p:
             l.append(p)
         if version:
@@ -643,11 +670,11 @@
         if version:
             code |= 0x80
         self._tracefile.write(
-            struct_pack(">ii8s8s",
+            struct_pack(">iiH8s",
                         time_time(),
                         (dlen+255) & 0x7fffff00 | code | self._current,
-                        oid,
-                        serial))
+                        len(oid),
+                        serial) + oid)
 
     def read_index(self, serial, fileindex):
         index = self._index
@@ -658,9 +685,8 @@
         count = 0
 
         while 1:
-            f.seek(pos)
-            h = read(27)
-            if len(h) != 27:
+            h, oid = self._read_header(f, pos)
+            if h is None:
                 # An empty read is expected, anything else is suspect
                 if h:
                     self.rilog("truncated header", pos, fileindex)
@@ -674,8 +700,6 @@
                 self.rilog("invalid header data", pos, fileindex)
                 break
 
-            oid = h[:8]
-
             if h[8] == 'v' and vlen:
                 seek(dlen+vlen, 1)
                 vdlen = read(4)
@@ -683,7 +707,7 @@
                     self.rilog("truncated record", pos, fileindex)
                     break
                 vdlen = unpack(">i", vdlen)[0]
-                if vlen+dlen+43+vdlen != tlen:
+                if vlen + dlen + 43 + len(oid) + vdlen != tlen:
                     self.rilog("inconsistent lengths", pos, fileindex)
                     break
                 seek(vdlen, 1)
@@ -693,7 +717,7 @@
                     break
             else:
                 if h[8] in 'vn' and vlen == 0:
-                    if dlen+31 != tlen:
+                    if dlen + len(oid) + 31 != tlen:
                         self.rilog("inconsistent nv lengths", pos, fileindex)
                     seek(dlen, 1)
                     if read(4) != h[9:13]:




More information about the Zodb-checkins mailing list