[Zope-Checkins] CVS: Zope/lib/python/ZConfig - SchemaParser.py:1.1.2.16

Fred L. Drake, Jr. fred@zope.com
Thu, 5 Dec 2002 15:13:58 -0500


Update of /cvs-repository/Zope/lib/python/ZConfig
In directory cvs.zope.org:/tmp/cvs-serv19436

Modified Files:
      Tag: chrism-install-branch
	SchemaParser.py 
Log Message:
Convert to use SAX 2 instead of xml.parsers.expat, so we're not bound
to a specific parser.


=== Zope/lib/python/ZConfig/SchemaParser.py 1.1.2.15 => 1.1.2.16 ===
--- Zope/lib/python/ZConfig/SchemaParser.py:1.1.2.15	Sat Nov 30 00:54:16 2002
+++ Zope/lib/python/ZConfig/SchemaParser.py	Thu Dec  5 15:13:57 2002
@@ -18,6 +18,7 @@
 """
 
 import types
+import xml.sax
 
 from Common import *
 from Config import Configuration
@@ -32,48 +33,15 @@
 nodefault = []


-class SchemaParser:
+class SchemaParser(xml.sax.ContentHandler):

-    handler_names = [
-        "StartElementHandler",
-        "EndElementHandler",
-        "ProcessingInstructionHandler",
-        "CharacterDataHandler",
-        "UnparsedEntityDeclHandler",
-        "NotationDeclHandler",
-        "StartNamespaceDeclHandler",
-        "EndNamespaceDeclHandler",
-        "CommentHandler",
-        "StartCdataSectionHandler",
-        "EndCdataSectionHandler",
-        "DefaultHandler",
-        "DefaultHandlerExpand",
-        "NotStandaloneHandler",
-        "ExternalEntityRefHandler",
-        "XmlDeclHandler",
-        "StartDoctypeDeclHandler",
-        "EndDoctypeDeclHandler",
-        "ElementDeclHandler",
-        "AttlistDeclHandler"
-        ]
-
-    def __init__(self, encoding='ascii'):
-        self.parser = p = self.createParser(encoding)
-        for name in self.handler_names:
-            method = getattr(self, name, None)
-            if method is not None:
-                setattr(p, name, method)
-        self.cdata_tags = ('description', 'metadefault', 'example')
-
-    def createParser(self, encoding=None):
-        global XMLParseError
-        from xml.parsers import expat
-        XMLParseError = expat.ExpatError
-        parser = expat.ParserCreate(encoding, ' ')
-        parser.returns_unicode = False
-        if hasattr(parser, 'buffer_text'):
-            parser.buffer_text = True
-        return parser
+    _cdata_tags = 'description', 'metadefault', 'example'
+    _current_cdata_attr = None
+    _locator = None
+
+    def __init__(self):
+        xml.sax.ContentHandler.__init__(self)
+        self._schema = None

     def __call__(self, f, context, print_classes=False):
         self._stack = []
@@ -87,48 +54,84 @@
             self.parseStream(f)

     def parseFile(self, filename):
-        self.parseStream(open(filename))
+        # Close the file even if parsing raises (e.g. a SchemaError).
+        f = open(filename, 'rU')
+        try:
+            self.parseStream(f)
+        finally:
+            f.close()

     def parseString(self, s):
-        self.parser.Parse(s, 1)
+        xml.sax.parseString(s, self)

     def parseStream(self, stream):
-        self.parser.ParseFile(stream)
+        xml.sax.parse(stream, self)
+
+    # SAX 2 ContentHandler methods

-    def parseFragment(self, s, end=0):
-        self.parser.Parse(s, end)
+    def startDocument(self):
+        # __call__ resets _stack for every parse; reset the rest of the
+        # per-document state too so a reused parser starts clean.
+        self._schema = None
+        self._current_cdata_attr = None
+        self._cdata_chunks = []
+
+    def setDocumentLocator(self, locator):
+        self._locator = locator

-    def StartElementHandler(self, name, attrs):
+    def startElement(self, name, attrs):
+        attrs = dict(attrs)
         if name == SCHEMA_TYPE:
-            return self.handleSchema(attrs)
+            self.handleSchema(attrs)
         elif name == KEY_TYPE:
-            return self.handleKey(attrs)
+            self._checkInSchema(name)
+            self.handleKey(attrs)
         elif name == SECTION_TYPE:
-            return self.handleSection(attrs)
-        elif name in self.cdata_tags:
-            return self.handleCdata(name)
-        msg = "Unknown tag %s" % name
-        self.doSchemaError(msg)
+            self._checkInSchema(name)
+            self.handleSection(attrs)
+        elif name in self._cdata_tags:
+            self._checkInSchema(name)
+            self.handleCdata(name)
+            # SAX may split one text node across several characters()
+            # calls; collect the pieces and store them in endElement().
+            self._cdata_chunks = []
+        else:
+            self.doSchemaError("Unknown tag %s" % name)
+
+    def _checkInSchema(self, name):
+        # key, section and cdata elements are only valid inside the schema.
+        if not self._schema:
+            self.doSchemaError(`name` + " element outside of schema")

-    def CharacterDataHandler(self, data):
-        data = data.strip()
-        if not data:
-            return
-        attr = self.getCurrentCdata()
-        if attr is None:
-            msg = 'cdata only valid within %s tags' % (self.cdata_tags,)
-            self.doSchemaError(msg)
-        setattr(self._stack[-1], attr, data)
+    def characters(self, data):
+        if self.getCurrentCdata() is None:
+            # Whitespace between elements is ignored; other text is an error.
+            if data.strip():
+                self.doSchemaError('#pcdata only valid within %s tags'
+                                   % `self._cdata_tags`[1:-1])
+            return
+        self._cdata_chunks.append(data)

-    def EndElementHandler(self, name):
-        if name in self.cdata_tags:
-            self.current_cdata_attr = None
+    def endElement(self, name):
+        if name in self._cdata_tags:
+            # Join the buffered chunks before stripping so text delivered
+            # in several characters() calls is not truncated.
+            data = ''.join(self._cdata_chunks).strip()
+            if data:
+                setattr(self._stack[-1], self._current_cdata_attr, data)
+            self._cdata_chunks = []
+            self._current_cdata_attr = None
             return
-        if name == SCHEMA_TYPE:
-            self.checkClasses()
-            self.context.setSchema(self._stack[0])
         self._stack.pop()

+    def endDocument(self):
+        if not self._schema:
+            self.doSchemaError("no schema found")
+        self.checkClasses()
+        self.context.setSchema(self._schema)
+
+    # schema handling logic
+
     def checkClasses(self):
         items = self._seen_classes.items()
         items.sort()
@@ -142,7 +123,7 @@
             self.doSchemaError('Class loading failed for: %s' % ', '.join(l))

     def getCurrentCdata(self):
-        return getattr(self, 'current_cdata_attr', None)
+        return self._current_cdata_attr

     def handleCdata(self, name):
         if not self._stack:
@@ -152,14 +133,14 @@
         if current:
             msg = 'Cannot place %s attribute in %s' % (name, current)
             self.doSchemaError(msg)
-        self.current_cdata_attr = name
+        self._current_cdata_attr = name

     def getClass(self, name):
         if name.startswith('.'):
             name = self._prefix + name
         if self._seen_classes.get(name, missing) is missing:
             try:
-                klass = importer(name)
+                klass = importer(str(name))
             except ImportError:
                 klass = ConfigMissing
             self._seen_classes[name] = klass
@@ -208,11 +189,12 @@
             config = Klass(type, name, self.context, None)
         except Exception, e:
             self.doSchemaError(str(e))
+        self._schema = config
         self._stack.append(config)

     def handleKey(self, attrs):
         if self.getCurrentCdata():
-            msg = 'cannot nest key in %s' % self.cdata_tags
+            msg = 'cannot nest key in ' + `self._cdata_tags`
             self.doSchemaError(msg)
         if not self._stack:
             msg = '"key" tag valid only within "schema" tag and subordinates'
@@ -251,7 +233,7 @@

     def handleSection(self, attrs):
         if self.getCurrentCdata():
-            msg = 'cannot nest section in %s' % self.cdata_tags
+            msg = 'cannot nest section in ' + `self._cdata_tags`
             self.doSchemaError(msg)
         if not self._stack:
             msg = ('"section" tag valid only within "schema" tag '
@@ -286,8 +268,12 @@
         self._stack.append(config)

     def doSchemaError(self, msg):
-        raise SchemaError(msg, self.parser.ErrorLineNumber,
-                          self.parser.ErrorColumnNumber)
+        lineno = colno = None
+        locator = self._locator
+        if locator is not None:
+            colno = locator.getColumnNumber()
+            lineno = locator.getLineNumber()
+        raise SchemaError(msg, lineno, colno)

 class SchemaError(Exception):
     def __init__(self, msg, line, col):
@@ -296,7 +282,9 @@
         self.col = col

     def __str__(self):
-        return "%s at line %s, column %s" % (self.msg, self.line, self.col)
+        if self.line is None:
+            return self.msg + " (unknown position)"
+        return self.msg + " at line %d, column %s" % (self.line, self.col)

 class SchemaConfiguration(Configuration):
     resolved = missing
@@ -709,7 +699,8 @@
 def importer(name):
     components = name.split('.')
     start = components[0]
-    package = __import__(start, globals(), globals())
+    g = globals()
+    package = __import__(start, g, g)
     modulenames = [start]
     for component in components[1:]:
         modulenames.append(component)
@@ -717,7 +708,10 @@
             package = getattr(package, component)
         except AttributeError:
             name = '.'.join(modulenames)
-            package = __import__(name, globals(), globals(), component)
+            # fromlist must be a list (a bare string is iterated char by
+            # char); any non-empty fromlist returns the leaf module, and
+            # '__name__' always exists so no extra submodule is imported.
+            package = __import__(name, g, g, ['__name__'])
     return package

 if __name__ == '__main__':