Commit e969ef32 authored by Stefan Behnel's avatar Stefan Behnel

Simplify optimisation code for cascaded comparisons and improve its test coverage.

parent 714123e3
...@@ -4527,22 +4527,20 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations): ...@@ -4527,22 +4527,20 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
cascades = [[node.operand1]] cascades = [[node.operand1]]
final_false_result = [] final_false_result = []
def split_cascades(cmp_node): cmp_node = node
while cmp_node is not None:
if cmp_node.has_constant_result(): if cmp_node.has_constant_result():
if not cmp_node.constant_result: if not cmp_node.constant_result:
# False => short-circuit # False => short-circuit
final_false_result.append(self._bool_node(cmp_node, False)) final_false_result.append(self._bool_node(cmp_node, False))
return break
else: else:
# True => discard and start new cascade # True => discard and start new cascade
cascades.append([cmp_node.operand2]) cascades.append([cmp_node.operand2])
else: else:
# not constant => append to current cascade # not constant => append to current cascade
cascades[-1].append(cmp_node) cascades[-1].append(cmp_node)
if cmp_node.cascade: cmp_node = cmp_node.cascade
split_cascades(cmp_node.cascade)
split_cascades(node)
cmp_nodes = [] cmp_nodes = []
for cascade in cascades: for cascade in cascades:
......
# mode: compile
cdef void foo():
cdef int bool, int1=0, int2=0, int3=0, int4=0
cdef object obj1, obj2, obj3, obj4
obj1 = 1
obj2 = 2
obj3 = 3
obj4 = 4
bool = int1 < int2 < int3
bool = obj1 < obj2 < obj3
bool = int1 < int2 < obj3
bool = obj1 < 2 < 3
bool = obj1 < 2 < 3 < 4
bool = int1 < (int2 == int3) < int4
foo()
# mode: run
# tag: cascade, compare
def ints_and_objects():
"""
>>> ints_and_objects()
(0, 1, 0, 1, 1, 0)
"""
cdef int int1=0, int2=0, int3=0, int4=0
cdef int r1, r2, r3, r4, r5, r6
cdef object obj1, obj2, obj3, obj4
obj1 = 1
obj2 = 2
obj3 = 3
obj4 = 4
r1 = int1 < int2 < int3
r2 = obj1 < obj2 < obj3
r3 = int1 < int2 < obj3
r4 = obj1 < 2 < 3
r5 = obj1 < 2 < 3 < 4
r6 = int1 < (int2 == int3) < int4
return r1, r2, r3, r4, r5, r6
def const_cascade(x):
"""
>>> const_cascade(2)
(True, False, True, False, False, True, False)
"""
return (
0 <= 1,
1 <= 0,
1 <= 1 <= 2,
1 <= 0 < 1,
1 <= 1 <= 0,
1 <= 1 <= x <= 2 <= 3 > x <= 2 <= 2,
1 <= 1 <= x <= 1 <= 1 <= x <= 2,
)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment