Commit a0ede829 authored by Stefan Behnel's avatar Stefan Behnel

extend switch transform to not-in tests, some refactoring

parent 628d4d97
...@@ -507,7 +507,9 @@ class SwitchTransform(Visitor.VisitorTransform): ...@@ -507,7 +507,9 @@ class SwitchTransform(Visitor.VisitorTransform):
The requirement is that every clause be an (or of) var == value, where the var The requirement is that every clause be an (or of) var == value, where the var
is common among all clauses and both var and value are ints. is common among all clauses and both var and value are ints.
""" """
def extract_conditions(self, cond): NO_MATCH = (None, None, None)
def extract_conditions(self, cond, allow_not_in):
while True: while True:
if isinstance(cond, ExprNodes.CoerceToTempNode): if isinstance(cond, ExprNodes.CoerceToTempNode):
cond = cond.arg cond = cond.arg
...@@ -519,51 +521,80 @@ class SwitchTransform(Visitor.VisitorTransform): ...@@ -519,51 +521,80 @@ class SwitchTransform(Visitor.VisitorTransform):
else: else:
break break
if (isinstance(cond, ExprNodes.PrimaryCmpNode) if isinstance(cond, ExprNodes.PrimaryCmpNode):
and cond.cascade is None if cond.cascade is None and not cond.is_python_comparison():
and cond.operator == '==' if cond.operator == '==':
and not cond.is_python_comparison()): not_in = False
elif allow_not_in and cond.operator == '!=':
not_in = True
else:
return self.NO_MATCH
# this looks somewhat silly, but it does the right
# checks for NameNode and AttributeNode
if is_common_value(cond.operand1, cond.operand1): if is_common_value(cond.operand1, cond.operand1):
if cond.operand2.is_literal: if cond.operand2.is_literal:
return cond.operand1, [cond.operand2] return not_in, cond.operand1, [cond.operand2]
elif getattr(cond.operand2, 'entry', None) and cond.operand2.entry.is_const: elif getattr(cond.operand2, 'entry', None) \
return cond.operand1, [cond.operand2] and cond.operand2.entry.is_const:
return not_in, cond.operand1, [cond.operand2]
if is_common_value(cond.operand2, cond.operand2): if is_common_value(cond.operand2, cond.operand2):
if cond.operand1.is_literal: if cond.operand1.is_literal:
return cond.operand2, [cond.operand1] return not_in, cond.operand2, [cond.operand1]
elif getattr(cond.operand1, 'entry', None) and cond.operand1.entry.is_const: elif getattr(cond.operand1, 'entry', None) \
return cond.operand2, [cond.operand1] and cond.operand1.entry.is_const:
elif (isinstance(cond, ExprNodes.BoolBinopNode) return not_in, cond.operand2, [cond.operand1]
and cond.operator == 'or'): elif isinstance(cond, ExprNodes.BoolBinopNode):
t1, c1 = self.extract_conditions(cond.operand1) if cond.operator == 'or' or (allow_not_in and cond.operator == 'and'):
t2, c2 = self.extract_conditions(cond.operand2) allow_not_in = (cond.operator == 'and')
if is_common_value(t1, t2): not_in_1, t1, c1 = self.extract_conditions(cond.operand1, allow_not_in)
return t1, c1+c2 not_in_2, t2, c2 = self.extract_conditions(cond.operand2, allow_not_in)
return None, None if t1 is not None and not_in_1 == not_in_2 and is_common_value(t1, t2):
if (not not_in_1) or allow_not_in:
def extract_common_conditions(self, common_var, condition): return not_in_1, t1, c1+c2
var, conditions = self.extract_conditions(condition) return self.NO_MATCH
def extract_common_conditions(self, common_var, condition, allow_not_in):
not_in, var, conditions = self.extract_conditions(condition, allow_not_in)
if var is None: if var is None:
return None, None return self.NO_MATCH
elif common_var is not None and not is_common_value(var, common_var): elif common_var is not None and not is_common_value(var, common_var):
return None, None return self.NO_MATCH
elif not var.type.is_int or sum([not cond.type.is_int for cond in conditions]): elif not var.type.is_int or sum([not cond.type.is_int for cond in conditions]):
return None, None return self.NO_MATCH
return var, conditions return not_in, var, conditions
def has_duplicate_values(self, condition_values):
# duplicated values don't work in a switch statement
seen = set()
for value in condition_values:
if value.constant_result is not ExprNodes.not_a_constant:
if value.constant_result in seen:
return True
seen.add(value.constant_result)
else:
# this isn't completely safe as we don't know the
# final C value, but this is about the best we can do
seen.add(getattr(getattr(value, 'entry', None), 'cname'))
return False
def visit_IfStatNode(self, node): def visit_IfStatNode(self, node):
common_var = None common_var = None
cases = [] cases = []
for if_clause in node.if_clauses: for if_clause in node.if_clauses:
common_var, conditions = self.extract_common_conditions( _, common_var, conditions = self.extract_common_conditions(
common_var, if_clause.condition) common_var, if_clause.condition, False)
if common_var is None: if common_var is None:
self.visitchildren(node)
return node return node
cases.append(Nodes.SwitchCaseNode(pos = if_clause.pos, cases.append(Nodes.SwitchCaseNode(pos = if_clause.pos,
conditions = conditions, conditions = conditions,
body = if_clause.body)) body = if_clause.body))
if sum([ len(case.conditions) for case in cases ]) < 2: if sum([ len(case.conditions) for case in cases ]) < 2:
self.visitchildren(node)
return node
if self.has_duplicate_values(sum([case.conditions for case in cases], [])):
self.visitchildren(node)
return node return node
common_var = unwrap_node(common_var) common_var = unwrap_node(common_var)
...@@ -571,59 +602,51 @@ class SwitchTransform(Visitor.VisitorTransform): ...@@ -571,59 +602,51 @@ class SwitchTransform(Visitor.VisitorTransform):
test = common_var, test = common_var,
cases = cases, cases = cases,
else_clause = node.else_clause) else_clause = node.else_clause)
self.visitchildren(switch_node)
return switch_node return switch_node
def visit_CondExprNode(self, node): def visit_CondExprNode(self, node):
common_var, conditions = self.extract_common_conditions(None, node.test) not_in, common_var, conditions = self.extract_common_conditions(
if common_var is None: None, node.test, True)
return node if common_var is None \
if len(conditions) < 2: or len(conditions) < 2 \
or self.has_duplicate_values(conditions):
self.visitchildren(node)
return node return node
return self.build_simple_switch_statement(
result_ref = UtilNodes.ResultRefNode(node) node, common_var, conditions, not_in,
true_body = Nodes.SingleAssignmentNode( node.true_val, node.false_val)
node.pos,
lhs = result_ref,
rhs = node.true_val,
first = True)
false_body = Nodes.SingleAssignmentNode(
node.pos,
lhs = result_ref,
rhs = node.false_val,
first = True)
cases = [Nodes.SwitchCaseNode(pos = node.pos,
conditions = conditions,
body = true_body)]
common_var = unwrap_node(common_var)
switch_node = Nodes.SwitchStatNode(pos = node.pos,
test = common_var,
cases = cases,
else_clause = false_body)
self.visitchildren(switch_node)
return UtilNodes.TempResultFromStatNode(result_ref, switch_node)
def visit_BoolBinopNode(self, node): def visit_BoolBinopNode(self, node):
common_var, conditions = self.extract_common_conditions(None, node) not_in, common_var, conditions = self.extract_common_conditions(
if common_var is None: None, node, True)
return node if common_var is None \
if len(conditions) < 2: or len(conditions) < 2 \
or self.has_duplicate_values(conditions):
self.visitchildren(node)
return node return node
return self.build_simple_switch_statement(
node, common_var, conditions, not_in,
ExprNodes.BoolNode(node.pos, value=True),
ExprNodes.BoolNode(node.pos, value=False))
def build_simple_switch_statement(self, node, common_var, conditions,
not_in, true_val, false_val):
result_ref = UtilNodes.ResultRefNode(node) result_ref = UtilNodes.ResultRefNode(node)
true_body = Nodes.SingleAssignmentNode( true_body = Nodes.SingleAssignmentNode(
node.pos, node.pos,
lhs = result_ref, lhs = result_ref,
rhs = ExprNodes.BoolNode(node.pos, value=True), rhs = true_val,
first = True) first = True)
false_body = Nodes.SingleAssignmentNode( false_body = Nodes.SingleAssignmentNode(
node.pos, node.pos,
lhs = result_ref, lhs = result_ref,
rhs = ExprNodes.BoolNode(node.pos, value=False), rhs = false_val,
first = True) first = True)
if not_in:
true_body, false_body = false_body, true_body
cases = [Nodes.SwitchCaseNode(pos = node.pos, cases = [Nodes.SwitchCaseNode(pos = node.pos,
conditions = conditions, conditions = conditions,
body = true_body)] body = true_body)]
...@@ -633,7 +656,6 @@ class SwitchTransform(Visitor.VisitorTransform): ...@@ -633,7 +656,6 @@ class SwitchTransform(Visitor.VisitorTransform):
test = common_var, test = common_var,
cases = cases, cases = cases,
else_clause = false_body) else_clause = false_body)
self.visitchildren(switch_node)
return UtilNodes.TempResultFromStatNode(result_ref, switch_node) return UtilNodes.TempResultFromStatNode(result_ref, switch_node)
visit_Node = Visitor.VisitorTransform.recurse_to_children visit_Node = Visitor.VisitorTransform.recurse_to_children
......
cimport cython
def f(a,b): def f(a,b):
""" """
>>> f(1,[1,2,3]) >>> f(1,[1,2,3])
...@@ -44,6 +47,7 @@ def j(b): ...@@ -44,6 +47,7 @@ def j(b):
result = 2 not in b result = 2 not in b
return result return result
@cython.test_fail_if_path_exists("//SwitchStatNode")
def k(a): def k(a):
""" """
>>> k(1) >>> k(1)
...@@ -54,16 +58,86 @@ def k(a): ...@@ -54,16 +58,86 @@ def k(a):
cdef int result = a not in [1,2,3,4] cdef int result = a not in [1,2,3,4]
return result return result
def m(int a): @cython.test_assert_path_exists("//SwitchStatNode")
@cython.test_fail_if_path_exists("//PrimaryCmpNode")
def m_list(int a):
""" """
>>> m(2) >>> m_list(2)
0 0
>>> m(5) >>> m_list(5)
1 1
""" """
cdef int result = a not in [1,2,3,4] cdef int result = a not in [1,2,3,4]
return result return result
@cython.test_assert_path_exists("//SwitchStatNode")
@cython.test_fail_if_path_exists("//PrimaryCmpNode")
def m_tuple(int a):
"""
>>> m_tuple(2)
0
>>> m_tuple(5)
1
"""
cdef int result = a not in (1,2,3,4)
return result
@cython.test_assert_path_exists("//SwitchStatNode", "//BoolBinopNode")
@cython.test_fail_if_path_exists("//PrimaryCmpNode")
def m_tuple_in_or_notin(int a):
"""
>>> m_tuple_in_or_notin(2)
0
>>> m_tuple_in_or_notin(3)
1
>>> m_tuple_in_or_notin(5)
1
"""
cdef int result = a not in (1,2,3,4) or a in (3,4)
return result
@cython.test_assert_path_exists("//SwitchStatNode", "//BoolBinopNode")
@cython.test_fail_if_path_exists("//PrimaryCmpNode")
def m_tuple_notin_or_notin(int a):
"""
>>> m_tuple_notin_or_notin(2)
1
>>> m_tuple_notin_or_notin(6)
1
>>> m_tuple_notin_or_notin(4)
0
"""
cdef int result = a not in (1,2,3,4) or a not in (4,5)
return result
@cython.test_assert_path_exists("//SwitchStatNode")
@cython.test_fail_if_path_exists("//BoolBinopNode", "//PrimaryCmpNode")
def m_tuple_notin_and_notin(int a):
"""
>>> m_tuple_notin_and_notin(2)
0
>>> m_tuple_notin_and_notin(6)
0
>>> m_tuple_notin_and_notin(5)
1
"""
cdef int result = a not in (1,2,3,4) and a not in (6,7)
return result
@cython.test_assert_path_exists("//SwitchStatNode", "//BoolBinopNode")
@cython.test_fail_if_path_exists("//PrimaryCmpNode")
def m_tuple_notin_and_notin_overlap(int a):
"""
>>> m_tuple_notin_and_notin_overlap(2)
0
>>> m_tuple_notin_and_notin_overlap(4)
0
>>> m_tuple_notin_and_notin_overlap(5)
1
"""
cdef int result = a not in (1,2,3,4) and a not in (3,4)
return result
def n(a): def n(a):
""" """
>>> n('d *') >>> n('d *')
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment