Commit 992219fc authored by Stefan Behnel's avatar Stefan Behnel

speed up tree visitor somewhat by moving code out of the critical methods

parent d6bc4573
cimport cython
cdef class BasicVisitor: cdef class BasicVisitor:
cdef dict dispatch_table cdef dict dispatch_table
cpdef visit(self, obj) cpdef visit(self, obj)
cpdef find_handler(self, obj)
cdef class TreeVisitor(BasicVisitor): cdef class TreeVisitor(BasicVisitor):
cdef public list access_path cdef public list access_path
cpdef visitchild(self, child, parent, attrname, idx) cpdef visitchild(self, child, parent, attrname, idx)
@cython.locals(idx=int)
cpdef dict _visitchildren(self, parent, attrs)
# cpdef visitchildren(self, parent, attrs=*) # cpdef visitchildren(self, parent, attrs=*)
cdef class VisitorTransform(TreeVisitor): cdef class VisitorTransform(TreeVisitor):
......
# cython: infer_types=True
# #
# Tree visitor and transform framework # Tree visitor and transform framework
# #
...@@ -19,10 +21,15 @@ class BasicVisitor(object): ...@@ -19,10 +21,15 @@ class BasicVisitor(object):
self.dispatch_table = {} self.dispatch_table = {}
def visit(self, obj): def visit(self, obj):
cls = type(obj)
try: try:
handler_method = self.dispatch_table[cls] handler_method = self.dispatch_table[type(obj)]
except KeyError: except KeyError:
handler_method = self.find_handler(obj)
self.dispatch_table[type(obj)] = handler_method
return handler_method(obj)
def find_handler(self, obj):
cls = type(obj)
#print "Cache miss for class %s in visitor %s" % ( #print "Cache miss for class %s in visitor %s" % (
# cls.__name__, type(self).__name__) # cls.__name__, type(self).__name__)
# Must resolve, try entire hierarchy # Must resolve, try entire hierarchy
...@@ -34,7 +41,7 @@ class BasicVisitor(object): ...@@ -34,7 +41,7 @@ class BasicVisitor(object):
handler_method = getattr(self, pattern % mro_cls.__name__) handler_method = getattr(self, pattern % mro_cls.__name__)
break break
if handler_method is None: if handler_method is None:
print type(self), type(obj) print type(self), cls
if hasattr(self, 'access_path') and self.access_path: if hasattr(self, 'access_path') and self.access_path:
print self.access_path print self.access_path
if self.access_path: if self.access_path:
...@@ -42,8 +49,7 @@ class BasicVisitor(object): ...@@ -42,8 +49,7 @@ class BasicVisitor(object):
print self.access_path[-1][0].__dict__ print self.access_path[-1][0].__dict__
raise RuntimeError("Visitor does not accept object: %s" % obj) raise RuntimeError("Visitor does not accept object: %s" % obj)
#print "Caching " + cls.__name__ #print "Caching " + cls.__name__
self.dispatch_table[cls] = handler_method return handler_method
return handler_method(obj)
class TreeVisitor(BasicVisitor): class TreeVisitor(BasicVisitor):
""" """
...@@ -144,16 +150,8 @@ class TreeVisitor(BasicVisitor): ...@@ -144,16 +150,8 @@ class TreeVisitor(BasicVisitor):
stacktrace = stacktrace.tb_next stacktrace = stacktrace.tb_next
return (last_traceback, nodes) return (last_traceback, nodes)
def visitchild(self, child, parent, attrname, idx): def _raise_compiler_error(self, child, e):
self.access_path.append((parent, attrname, idx))
try:
result = self.visit(child)
except Errors.CompileError:
raise
except Exception, e:
import sys import sys
if DebugFlags.debug_no_exception_intercept:
raise
trace = [''] trace = ['']
for parent, attribute, index in self.access_path: for parent, attribute, index in self.access_path:
node = getattr(parent, attribute) node = getattr(parent, attribute)
...@@ -174,10 +172,24 @@ class TreeVisitor(BasicVisitor): ...@@ -174,10 +172,24 @@ class TreeVisitor(BasicVisitor):
raise Errors.CompilerCrash( raise Errors.CompilerCrash(
last_node.pos, self.__class__.__name__, last_node.pos, self.__class__.__name__,
u'\n'.join(trace), e, stacktrace) u'\n'.join(trace), e, stacktrace)
def visitchild(self, child, parent, attrname, idx):
self.access_path.append((parent, attrname, idx))
try:
result = self.visit(child)
except Errors.CompileError:
raise
except Exception, e:
if DebugFlags.debug_no_exception_intercept:
raise
self._raise_compiler_error(child, e)
self.access_path.pop() self.access_path.pop()
return result return result
def visitchildren(self, parent, attrs=None): def visitchildren(self, parent, attrs=None):
return self._visitchildren(parent, attrs)
def _visitchildren(self, parent, attrs):
""" """
Visits the children of the given parent. If parent is None, returns Visits the children of the given parent. If parent is None, returns
immediately (returning None). immediately (returning None).
...@@ -223,8 +235,7 @@ class VisitorTransform(TreeVisitor): ...@@ -223,8 +235,7 @@ class VisitorTransform(TreeVisitor):
are within a StatListNode or similar before doing this.) are within a StatListNode or similar before doing this.)
""" """
def visitchildren(self, parent, attrs=None): def visitchildren(self, parent, attrs=None):
result = cython.declare(dict) result = self._visitchildren(parent, attrs)
result = TreeVisitor.visitchildren(self, parent, attrs)
for attr, newnode in result.iteritems(): for attr, newnode in result.iteritems():
if not type(newnode) is list: if not type(newnode) is list:
setattr(parent, attr, newnode) setattr(parent, attr, newnode)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment