[Zodb-checkins] CVS: ZODB3/ZODB - TmpStore.py:1.8

Jeremy Hylton jeremy@zope.com
Tue, 3 Dec 2002 16:26:38 -0500


Update of /cvs-repository/ZODB3/ZODB
In directory cvs.zope.org:/tmp/cvs-serv31443

Modified Files:
	TmpStore.py 
Log Message:
Sync TmpStore between ZODB3 and ZODB4


=== ZODB3/ZODB/TmpStore.py 1.7 => 1.8 ===
--- ZODB3/ZODB/TmpStore.py:1.7	Wed Aug 14 18:07:09 2002
+++ ZODB3/ZODB/TmpStore.py	Tue Dec  3 16:26:38 2002
@@ -12,97 +12,113 @@
 #
 ##############################################################################
 
-import POSException
-from utils import p64, u64
+from ZODB import POSException
+from ZODB.utils import p64, u64, z64
+
+import tempfile
 
 class TmpStore:
-    _transaction=_isCommitting=None
+    """A storage to support subtransactions."""
+
+    _bver = ''
 
-    def __init__(self, base_version, file=None):
-        if file is None:
-            import tempfile
-            file=tempfile.TemporaryFile()
-
-        self._file=file
-        self._index={}
-        self._pos=self._tpos=0
-        self._bver=base_version
-        self._tindex=[]
-        self._db=None
-        self._creating=[]
+    def __init__(self, base_version):
+        self._transaction = None
+        if base_version:
+            self._bver = base_version
+        self._file = tempfile.TemporaryFile()
+        # _pos: current file position
+        # _tpos: file position at last commit point
+        self._pos = self._tpos = 0
+        # _index: map oid to pos of last committed version
+        self._index = {}
+        # _tindex: map oid to pos for new updates
+        self._tindex = {}
+        self._db = None
+        self._creating = []
 
-    def __del__(self): self.close()
+    def __del__(self):
+        # XXX Is this necessary?
+        self._file.close()
 
     def close(self):
         self._file.close()
-        del self._file
-        del self._index
-        del self._db
 
-    def getName(self): return self._db.getName()
-    def getSize(self): return self._pos
+    def getName(self):
+        return self._db.getName()
+    
+    def getSize(self):
+        return self._pos
 
     def load(self, oid, version):
-        #if version is not self: raise KeyError, oid
-        pos=self._index.get(oid, None)
+        pos = self._index.get(oid)
         if pos is None:
             return self._storage.load(oid, self._bver)
-        file=self._file
-        file.seek(pos)
-        h=file.read(24)
+        self._file.seek(pos)
+        h = self._file.read(24)
         if h[:8] != oid:
-            raise POSException.StorageSystemError, 'Bad temporary storage'
-        return file.read(u64(h[16:])), h[8:16]
+            raise POSException.StorageSystemError('Bad temporary storage')
+        size = u64(h[16:])
+        serial = h[8:16]
+        return self._file.read(size), serial
+
+    # XXX clarify difference between self._storage & self._db._storage
 
     def modifiedInVersion(self, oid):
-        if self._index.has_key(oid): return 1
+        if self._index.has_key(oid):
+            return self._bver
         return self._db._storage.modifiedInVersion(oid)
 
-    def new_oid(self): return self._db._storage.new_oid()
+    def new_oid(self):
+        return self._db._storage.new_oid()
 
     def registerDB(self, db, limit):
-        self._db=db
-        self._storage=db._storage
+        self._db = db
+        self._storage = db._storage
 
     def store(self, oid, serial, data, version, transaction):
         if transaction is not self._transaction:
             raise POSException.StorageTransactionError(self, transaction)
-        file=self._file
-        pos=self._pos
-        file.seek(pos)
-        l=len(data)
+        self._file.seek(self._pos)
+        l = len(data)
         if serial is None:
-            serial = '\0\0\0\0\0\0\0\0'
-        file.write(oid+serial+p64(l))
-        file.write(data)
-        self._tindex.append((oid,pos))
-        self._pos=pos+l+24
+            serial = z64
+        self._file.write(oid + serial + p64(l))
+        self._file.write(data)
+        self._tindex[oid] = self._pos
+        self._pos += l + 24
         return serial
 
     def tpc_abort(self, transaction):
-        if transaction is not self._transaction: return
-        del self._tindex[:]
-        self._transaction=None
-        self._pos=self._tpos
+        if transaction is not self._transaction:
+            return
+        self._tindex.clear()
+        self._transaction = None
+        self._pos = self._tpos
 
     def tpc_begin(self, transaction):
-        if self._transaction is transaction: return
-        self._transaction=transaction
-        del self._tindex[:]   # Just to be sure!
-        self._pos=self._tpos
+        if self._transaction is transaction:
+            return
+        self._transaction = transaction
+        self._tindex.clear() # Just to be sure!
+        self._pos = self._tpos
 
-    def tpc_vote(self, transaction): pass
+    def tpc_vote(self, transaction):
+        pass
 
     def tpc_finish(self, transaction, f=None):
-        if transaction is not self._transaction: return
-        if f is not None: f()
-        index=self._index
-        tindex=self._tindex
-        for oid, pos in tindex: index[oid]=pos
-        del tindex[:]
-        self._tpos=self._pos
+        if transaction is not self._transaction:
+            return
+        if f is not None:
+            f()
+        self._index.update(self._tindex)
+        self._tindex.clear()
+        self._tpos = self._pos
 
-    def undoLog(self, first, last, filter=None): return ()
+    def undoLog(self, first, last, filter=None):
+        return ()
 
     def versionEmpty(self, version):
-        if version is self: return len(self._index)
+        # XXX what is this supposed to do?
+        if version == self._bver:
+            return len(self._index)