Commit 2787ac2a authored by Stefan Behnel's avatar Stefan Behnel

TreePath implementation for selecting nodes from the code tree

parent 26d83ebd
......@@ -21,7 +21,7 @@ builtin_function_table = [
#('eval', "", "", ""),
#('execfile', "", "", ""),
#('filter', "", "", ""),
#('getattr', "OO", "O", "PyObject_GetAttr"), # optimised later on
#('getattr', "OO", "O", "PyObject_GetAttr"), # optimised later on
('getattr3', "OOO", "O", "__Pyx_GetAttr3", "getattr"),
('hasattr', "OO", "b", "PyObject_HasAttr"),
('hash', "O", "l", "PyObject_Hash"),
......@@ -29,7 +29,7 @@ builtin_function_table = [
#('id', "", "", ""),
#('input', "", "", ""),
('intern', "s", "O", "__Pyx_InternFromString"),
('isinstance', "OO", "b", "PyObject_IsInstance"),
#('isinstance', "OO", "b", "PyObject_IsInstance"), # optimised later on
('issubclass', "OO", "b", "PyObject_IsSubclass"),
('iter', "O", "O", "PyObject_GetIter"),
('len', "O", "Z", "PyObject_Length"),
......
......@@ -712,6 +712,41 @@ class OptimizeBuiltinCalls(Visitor.VisitorTransform):
"expected 2 or 3, found %d" % len(args))
return node
PyObject_TypeCheck_func_type = PyrexTypes.CFuncType(
PyrexTypes.c_bint_type, [
PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None),
PyrexTypes.CFuncTypeArg("type", PyrexTypes.c_py_type_object_ptr_type, None),
])
PyObject_IsInstance_func_type = PyrexTypes.CFuncType(
PyrexTypes.c_bint_type, [
PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None),
PyrexTypes.CFuncTypeArg("type", PyrexTypes.py_object_type, None),
])
def _handle_simple_function_isinstance(self, node, pos_args):
"""Replace generic calls to isinstance(x, type) by a more
efficient type check.
"""
args = pos_args.args
if len(args) != 2:
error(node.pos, "isinstance(x, type) called with wrong number of args, found %d" %
len(args))
return node
type_arg = args[1]
if type_arg.type is Builtin.type_type:
function_name = "PyObject_TypeCheck"
function_type = self.PyObject_TypeCheck_func_type
args[1] = ExprNodes.CastNode(type_arg, PyrexTypes.c_py_type_object_ptr_type)
else:
function_name = "PyObject_IsInstance"
function_type = self.PyObject_IsInstance_func_type
return ExprNodes.PythonCapiCallNode(
node.pos, function_name, function_type,
args = args, is_temp = node.is_temp)
Pyx_Type_func_type = PyrexTypes.CFuncType(
Builtin.type_type, [
PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None)
......@@ -1058,8 +1093,8 @@ class FinalOptimizePhase(Visitor.CythonTransform):
just before the C code generation phase.
The optimizations currently implemented in this class are:
- Eliminate None assignment and refcounting for first assignment.
- isinstance -> typecheck for cdef types
- Eliminate None assignment and refcounting for first assignment.
- Eliminate dead coercion nodes.
"""
def visit_SingleAssignmentNode(self, node):
"""Avoid redundant initialisation of local variables before their
......@@ -1075,18 +1110,23 @@ class FinalOptimizePhase(Visitor.CythonTransform):
lhs.entry.init = 0
return node
def visit_SimpleCallNode(self, node):
"""Replace generic calls to isinstance(x, type) by a more efficient
type check.
def visit_NoneCheckNode(self, node):
"""Remove NoneCheckNode nodes wrapping nodes that cannot
possibly be None.
FIXME: the list below might be better maintained as a node
class attribute...
"""
self.visitchildren(node)
if node.function.type.is_cfunction and isinstance(node.function, ExprNodes.NameNode):
if node.function.name == 'isinstance':
type_arg = node.args[1]
if type_arg.type.is_builtin_type and type_arg.type.name == 'type':
from CythonScope import utility_scope
node.function.entry = utility_scope.lookup('PyObject_TypeCheck')
node.function.type = node.function.entry.type
PyTypeObjectPtr = PyrexTypes.CPtrType(utility_scope.lookup('PyTypeObject').type)
node.args[1] = ExprNodes.CastNode(node.args[1], PyTypeObjectPtr)
target = node.arg
if isinstance(target, ExprNodes.NoneNode):
return node
if not target.type.is_pyobject:
return target
if isinstance(target, (ExprNodes.ConstNode,
ExprNodes.NumBinopNode)):
return target
if isinstance(target, (ExprNodes.SequenceNode,
ExprNodes.ComprehensionNode,
ExprNodes.SetNode, ExprNodes.DictNode)):
return target
return node
......@@ -1691,6 +1691,9 @@ c_anon_enum_type = CAnonEnumType(-1, 1)
c_py_buffer_type = CStructOrUnionType("Py_buffer", "struct", None, 1, "Py_buffer")
c_py_buffer_ptr_type = CPtrType(c_py_buffer_type)
c_py_type_object_type = CStructOrUnionType("PyTypeObject", "struct", None, 1, "PyTypeObject")
c_py_type_object_ptr_type = CPtrType(c_py_type_object_type)
error_type = ErrorType()
unspecified_type = UnspecifiedType()
......
import unittest
from Cython.Compiler.Visitor import PrintTree
from Cython.TestUtils import TransformTest
from Cython.Compiler.TreePath import find_first, find_all
class TestTreePath(TransformTest):
def test_node_path(self):
t = self.run_pipeline([], u"""
def decorator(fun): # DefNode
return fun # ReturnStatNode, NameNode
@decorator # NameNode
def decorated(): # DefNode
pass
""")
self.assertEquals(2, len(find_all(t, "//DefNode")))
self.assertEquals(2, len(find_all(t, "//NameNode")))
self.assertEquals(1, len(find_all(t, "//ReturnStatNode")))
self.assertEquals(1, len(find_all(t, "//DefNode//ReturnStatNode")))
def test_node_path_child(self):
t = self.run_pipeline([], u"""
def decorator(fun): # DefNode
return fun # ReturnStatNode, NameNode
@decorator # NameNode
def decorated(): # DefNode
pass
""")
self.assertEquals(1, len(find_all(t, "//DefNode/ReturnStatNode/NameNode")))
self.assertEquals(1, len(find_all(t, "//ReturnStatNode/NameNode")))
def test_node_path_attribute_exists(self):
t = self.run_pipeline([], u"""
def decorator(fun):
return fun
@decorator
def decorated():
pass
""")
self.assertEquals(2, len(find_all(t, "//NameNode[@name]")))
def test_node_path_attribute_string_predicate(self):
t = self.run_pipeline([], u"""
def decorator(fun):
return fun
@decorator
def decorated():
pass
""")
self.assertEquals(1, len(find_all(t, "//NameNode[@name = 'decorator']")))
if __name__ == '__main__':
unittest.main()
"""
A simple XPath-like language for tree traversal.
This works by creating a filter chain of generator functions. Each
function selects a part of the expression, e.g. a child node, a
specific descendant or a node that holds an attribute.
"""
import re
path_tokenizer = re.compile(
"("
"'[^']*'|\"[^\"]*\"|"
"//?|"
"\(\)|"
"==?|"
"[/.*\[\]\(\)@])|"
"([^/\[\]\(\)@=\s]+)|"
"\s+"
).findall
def iterchildren(node, attr_name):
# returns an iterable of all child nodes of that name
child = getattr(node, attr_name)
if child is not None:
if type(child) is list:
return child
else:
return [child]
else:
return ()
def _get_first_or_none(it):
try:
try:
_next = it.next
except AttributeError:
return next(it)
else:
return _next()
except StopIteration:
return None
def type_name(node):
return node.__class__.__name__.split('.')[-1]
def parse_func(next, token):
name = token[1]
token = next()
if token[0] != '(':
raise ValueError("Expected '(' after function name '%s'" % name)
predicate = handle_predicate(next, token, end_marker=')')
return name, predicate
def handle_func_not(next, token):
"""
func(...)
"""
name, predicate = parse_func(next, token)
def select(result):
for node in result:
if _get_first_or_none(predicate(node)) is not None:
yield node
return select
def handle_name(next, token):
"""
/NodeName/
or
func(...)
"""
name = token[1]
if name in functions:
return functions[name](next, token)
def select(result):
for node in result:
for attr_name in node.child_attrs:
for child in iterchildren(node, attr_name):
if type_name(child) == name:
yield child
return select
def handle_star(next, token):
"""
/*/
"""
def select(result):
for node in result:
for name in node.child_attrs:
for child in iterchildren(node, name):
yield child
return select
def handle_dot(next, token):
"""
/./
"""
def select(result):
return result
return select
def handle_descendants(next, token):
"""
//...
"""
token = next()
if token[0] == "*":
def iter_recursive(node):
for name in node.child_attrs:
for child in iterchildren(node, name):
yield child
for c in iter_recursive(child):
yield c
elif not token[0]:
node_name = token[1]
def iter_recursive(node):
for name in node.child_attrs:
for child in iterchildren(node, name):
if type_name(child) == node_name:
yield child
for c in iter_recursive(child):
yield c
else:
raise ValueError("Expected node name after '//'")
def select(result):
for node in result:
for child in iter_recursive(node):
yield child
return select
def handle_attribute(next, token):
token = next()
if token[0]:
raise ValueError("Expected attribute name")
name = token[1]
token = next()
value = None
if token[0] == '=':
value = parse_path_value(next)
if value is None:
def select(result):
for node in result:
try:
attr_value = getattr(node, name)
except AttributeError:
continue
if attr_value is not None:
yield attr_value
else:
def select(result):
for node in result:
try:
attr_value = getattr(node, name)
except AttributeError:
continue
if attr_value == value:
yield value
return select
def parse_path_value(next):
token = next()
value = token[0]
if value[:1] == "'" or value[:1] == '"':
value = value[1:-1]
else:
try:
value = int(value)
except ValueError:
raise ValueError("Invalid attribute predicate: '%s'" % value)
return value
def handle_predicate(next, token, end_marker=']'):
token = next()
selector = []
while token[0] != end_marker:
selector.append( operations[token[0]](next, token) )
try:
token = next()
except StopIteration:
break
else:
if token[0] == "/":
token = next()
def select(result):
for node in result:
subresult = iter((node,))
for select in selector:
subresult = select(subresult)
predicate_result = _get_first_or_none(subresult)
if predicate_result is not None:
yield predicate_result
return select
operations = {
"@": handle_attribute,
"": handle_name,
"*": handle_star,
".": handle_dot,
"//": handle_descendants,
"[": handle_predicate,
}
functions = {
'not' : handle_func_not
}
def _build_path_iterator(path):
# parse pattern
stream = iter([ (special,text)
for (special,text) in path_tokenizer(path)
if special or text ])
try:
_next = stream.next
except AttributeError:
# Python 3
def _next():
return next(stream)
token = _next()
selector = []
while 1:
try:
selector.append(operations[token[0]](_next, token))
except StopIteration:
raise ValueError("invalid path")
try:
token = _next()
if token[0] == "/":
token = _next()
except StopIteration:
break
return selector
# main module API
def iterfind(node, path):
selector_chain = _build_path_iterator(path)
result = iter((node,))
for select in selector_chain:
result = select(result)
return result
def find_first(node, path):
return _get_first_or_none(iterfind(node, path))
def find_all(node, path):
return list(iterfind(node, path))
......@@ -8,6 +8,7 @@ unsignedbehaviour_T184
funcexc_iter_T228
bad_c_struct_T252
missing_baseclass_in_predecl_T262
compile_time_unraisable_T370
# Not yet enabled
profile_test
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment