[Zodb-checkins] CVS: Zope3/lib/python/ZODB - ExportImport.py:1.18 Serialize.py:1.8

Jeremy Hylton jeremy@zope.com
Mon, 2 Dec 2002 16:22:35 -0500


Update of /cvs-repository/Zope3/lib/python/ZODB
In directory cvs.zope.org:/tmp/cvs-serv31574

Modified Files:
	ExportImport.py Serialize.py 
Log Message:
Move the copy logic from ExportImport into the new ObjectCopier class in the Serialize module.





=== Zope3/lib/python/ZODB/ExportImport.py 1.17 => 1.18 ===
--- Zope3/lib/python/ZODB/ExportImport.py:1.17	Tue Nov 26 12:41:21 2002
+++ Zope3/lib/python/ZODB/ExportImport.py	Mon Dec  2 16:22:33 2002
@@ -13,9 +13,9 @@
 ##############################################################################
 """Support for database export and import."""
 
-from ZODB import POSException
-from ZODB.utils import p64, u64
-from ZODB.Serialize import findrefs
+from ZODB.POSException import ExportError
+from ZODB.utils import p64, u64, Set
+from ZODB.Serialize import findrefs, ObjectCopier
 from Transaction import get_transaction
 
 from cStringIO import StringIO
@@ -23,6 +23,8 @@
 from tempfile import TemporaryFile
 from types import StringType, TupleType
 
+export_end_marker = '\377' * 16
+
 class ExportImport:
     # a mixin for use with ZODB.Connection.Connection
 
@@ -35,12 +37,12 @@
             file = open(file, 'w+b')
         file.write('ZEXP')
         oids = [oid]
-        done_oids = {}
+        done_oids = Set()
         while oids:
             oid = oids.pop(0)
             if oid in done_oids:
                 continue
-            done_oids[oid] = 1
+            done_oids.add(oid)
             try:
                 p, serial = self._storage.load(oid, self._version)
             except:
@@ -68,7 +70,7 @@
             if customImporters is not None and customImporters.has_key(magic):
                 file.seek(0)
                 return customImporters[magic](self, file, clue)
-            raise POSException.ExportError, 'Invalid export header'
+            raise ExportError("Invalid export header")
 
         t = get_transaction()
         if clue is not None:
@@ -93,80 +95,37 @@
             self._importDuringCommit(txn, file, L)
         del self.__hooks
 
-    def _importDuringCommit(self, transaction, file, return_oid_list):
+    def _importDuringCommit(self, txn, file, return_oid_list):
         """Invoked by the transaction manager mid commit.
         
         Appends one item, the OID of the first object created,
         to return_oid_list.
         """
-        oids = {}
-
-        def persistent_load(ooid):
-            "Remap a persistent id to a new ID and create a ghost for it."
-
-            if isinstance(ooid, TupleType):
-                ooid, klass = ooid
-            else:
-                klass = None
-
-            oid = oids.get(ooid)
-            if oid is None:
-                if klass is None:
-                    oid = self._storage.new_oid()
-                    self._created.add(oid)
-                else:
-                    oid = self._storage.new_oid(), klass
-                    self._created.add(oid[0])
-                oids[ooid] = oid
-
-            g = Placeholder()
-            g.oid = oid
-            return g
-
-        version = self._version
+        copier = ObjectCopier(self, self._storage, self._created)
 
         while 1:
             h = file.read(16)
             if h == export_end_marker:
                 break
             if len(h) != 16:
-                raise POSException.ExportError, 'Truncated export file'
+                raise ExportError("Truncated export file")
             l = u64(h[8:16])
             p = file.read(l)
             if len(p) != l:
-                raise POSException.ExportError, 'Truncated export file'
+                raise ExportError("Truncated export file")
 
-            # XXX what does the tuple in oids mean?
-            ooid = h[:8]
-            if oids:
-                oid = oids[ooid]
-                if isinstance(oid, TupleType):
-                    oid = oid[0]
+            # XXX I think it would be better if copier.copy()
+            # returned an oid and a new pickle so that this logic
+            # wasn't smeared across two modules.
+            oid = h[:8]
+            new_ref = copier.oids.get(oid)
+            if new_ref is None:
+                new_oid = self._storage.new_oid()
+                copier.oids[oid] = new_oid, None
+                return_oid_list.append(new_oid)
+                self._created.add(new_oid)
             else:
-                oids[ooid] = oid = self._storage.new_oid()
-                return_oid_list.append(oid)
-                self._created.add(oid)
-
-            pfile = StringIO(p)
-            unpickler = Unpickler(pfile)
-            unpickler.persistent_load = persistent_load
-
-            newp = StringIO()
-            pickler = Pickler(newp, 1)
-            pickler.persistent_id = persistent_id
-
-            pickler.dump(unpickler.load())
-            pickler.dump(unpickler.load())
-            p = newp.getvalue()
-
-            self._storage.store(oid, None, p, version, transaction)
-
-export_end_marker = '\377' * 16
-
-class Placeholder(object):
-    pass
-
-def persistent_id(object):
-    if isinstance(object, Placeholder):
-        return object.oid
+                new_oid = new_ref[0]
 
+            new = copier.copy(p)
+            self._storage.store(new_oid, None, new, self._version, txn)


=== Zope3/lib/python/ZODB/Serialize.py 1.7 => 1.8 ===
--- Zope3/lib/python/ZODB/Serialize.py:1.7	Mon Dec  2 14:17:00 2002
+++ Zope3/lib/python/ZODB/Serialize.py	Mon Dec  2 16:22:33 2002
@@ -48,6 +48,8 @@
 are used.
 """
 
+__metaclass__ = type
+
 from cStringIO import StringIO
 import cPickle
 from types import StringType, TupleType
@@ -135,10 +137,13 @@
         return NewObjectIterator(self._stack)
 
     def getState(self, obj):
+        return self._dump(getClassMetadata(obj), obj.__getstate__())
+
+    def _dump(self, classmeta, state):
         self._file.reset()
         self._p.clear_memo()
-        self._p.dump(getClassMetadata(obj))
-        self._p.dump(obj.__getstate__())
+        self._p.dump(classmeta)
+        self._p.dump(state)
         return self._file.getvalue()
 
 class NewObjectIterator:
@@ -233,6 +238,54 @@
         if object is not None:
             return object
         return self._conn[oid]
+
+class CopyReference:
+    def __init__(self, ref):
+        self.ref = ref
+
+class CopyObjectReader(BaseObjectReader):
+
+    def __init__(self, storage, created, oids):
+        self._storage = storage
+        self._created = created
+        self._cache = oids
+
+    def _persistent_load(self, oid):
+        if isinstance(oid, TupleType):
+            oid, classmeta = oid
+        else:
+            classmeta = None
+        new_ref = self._cache.get(oid)
+        if new_ref is None:
+            new_oid = self._storage.new_oid()
+            self._created.add(new_oid)
+            self._cache[oid] = new_ref = new_oid, classmeta
+        return CopyReference(new_ref)
+
+    def readPickle(self, pickle):
+        unpickler = self._get_unpickler(pickle)
+        classmeta = unpickler.load()
+        state = unpickler.load()
+        return classmeta, state
+
+class CopyObjectWriter(ObjectWriter):
+
+    def _persistent_id(self, obj):
+        if isinstance(obj, CopyReference):
+            return obj.ref
+        else:
+            return super(CopyObjectWriter, self)._persistent_id(obj)
+
+class ObjectCopier:
+
+    def __init__(self, jar, storage, created):
+        self.oids = {}
+        self._reader = CopyObjectReader(storage, created, self.oids)
+        self._writer = CopyObjectWriter(jar)
+
+    def copy(self, pickle):
+        classmeta, state = self._reader.readPickle(pickle)
+        return self._writer._dump(classmeta, state)
 
 def findrefs(p):
     f = StringIO(p)