Commit 16aba7eb authored by Matus Valo's avatar Matus Valo Committed by GitHub

Execute AlignFunctionDefinitions before MarkClosureTransform. (GH-4127)

This commit fixes a crash of Cython when generator expressions are used in cdef functions in pure python mode.
Closes https://github.com/cython/cython/issues/3477
parent b1f7d592
...@@ -2496,8 +2496,6 @@ class AlignFunctionDefinitions(CythonTransform): ...@@ -2496,8 +2496,6 @@ class AlignFunctionDefinitions(CythonTransform):
def visit_ModuleNode(self, node): def visit_ModuleNode(self, node):
self.scope = node.scope self.scope = node.scope
self.directives = node.directives
self.imported_names = set() # hack, see visit_FromImportStatNode()
self.visitchildren(node) self.visitchildren(node)
return node return node
...@@ -2535,15 +2533,45 @@ class AlignFunctionDefinitions(CythonTransform): ...@@ -2535,15 +2533,45 @@ class AlignFunctionDefinitions(CythonTransform):
error(pxd_def.pos, "previous declaration here") error(pxd_def.pos, "previous declaration here")
return None return None
node = node.as_cfunction(pxd_def) node = node.as_cfunction(pxd_def)
elif (self.scope.is_module_scope and self.directives['auto_cpdef']
and node.name not in self.imported_names
and node.is_cdef_func_compatible()):
# FIXME: cpdef-ing should be done in analyse_declarations()
node = node.as_cfunction(scope=self.scope)
# Enable this when nested cdef functions are allowed. # Enable this when nested cdef functions are allowed.
# self.visitchildren(node) # self.visitchildren(node)
return node return node
def visit_ExprNode(self, node):
# ignore lambdas and everything else that appears in expressions
return node
class AutoCpdefFunctionDefinitions(CythonTransform):
def visit_ModuleNode(self, node):
self.directives = node.directives
self.imported_names = set() # hack, see visit_FromImportStatNode()
self.scope = node.scope
self.visitchildren(node)
return node
def visit_DefNode(self, node):
if (self.scope.is_module_scope and self.directives['auto_cpdef']
and node.name not in self.imported_names
and node.is_cdef_func_compatible()):
# FIXME: cpdef-ing should be done in analyse_declarations()
node = node.as_cfunction(scope=self.scope)
return node
def visit_CClassDefNode(self, node, pxd_def=None):
if pxd_def is None:
pxd_def = self.scope.lookup(node.class_name)
if pxd_def:
if not pxd_def.defined_in_pxd:
return node
outer_scope = self.scope
self.scope = pxd_def.type.scope
self.visitchildren(node)
if pxd_def:
self.scope = outer_scope
return node
def visit_FromImportStatNode(self, node): def visit_FromImportStatNode(self, node):
# hack to prevent conditional import fallback functions from # hack to prevent conditional import fallback functions from
# being cdpef-ed (global Python variables currently conflict # being cdpef-ed (global Python variables currently conflict
......
...@@ -149,7 +149,7 @@ def create_pipeline(context, mode, exclude_classes=()): ...@@ -149,7 +149,7 @@ def create_pipeline(context, mode, exclude_classes=()):
from .ParseTreeTransforms import ExpandInplaceOperators, ParallelRangeTransform from .ParseTreeTransforms import ExpandInplaceOperators, ParallelRangeTransform
from .ParseTreeTransforms import CalculateQualifiedNamesTransform from .ParseTreeTransforms import CalculateQualifiedNamesTransform
from .TypeInference import MarkParallelAssignments, MarkOverflowingArithmetic from .TypeInference import MarkParallelAssignments, MarkOverflowingArithmetic
from .ParseTreeTransforms import AdjustDefByDirectives, AlignFunctionDefinitions from .ParseTreeTransforms import AdjustDefByDirectives, AlignFunctionDefinitions, AutoCpdefFunctionDefinitions
from .ParseTreeTransforms import RemoveUnreachableCode, GilCheck from .ParseTreeTransforms import RemoveUnreachableCode, GilCheck
from .FlowControl import ControlFlowAnalysis from .FlowControl import ControlFlowAnalysis
from .AnalysedTreeTransforms import AutoTestDictTransform from .AnalysedTreeTransforms import AutoTestDictTransform
...@@ -186,10 +186,11 @@ def create_pipeline(context, mode, exclude_classes=()): ...@@ -186,10 +186,11 @@ def create_pipeline(context, mode, exclude_classes=()):
TrackNumpyAttributes(), TrackNumpyAttributes(),
InterpretCompilerDirectives(context, context.compiler_directives), InterpretCompilerDirectives(context, context.compiler_directives),
ParallelRangeTransform(context), ParallelRangeTransform(context),
AdjustDefByDirectives(context),
WithTransform(context), WithTransform(context),
MarkClosureVisitor(context), AdjustDefByDirectives(context),
_align_function_definitions, _align_function_definitions,
MarkClosureVisitor(context),
AutoCpdefFunctionDefinitions(context),
RemoveUnreachableCode(context), RemoveUnreachableCode(context),
ConstantFolding(), ConstantFolding(),
FlattenInListTransform(), FlattenInListTransform(),
......
...@@ -55,6 +55,17 @@ def func(a, b, c): ...@@ -55,6 +55,17 @@ def func(a, b, c):
""" """
return a + b + c return a + b + c
def sum_generator_expression(a):
# GH-3477 - closure variables incorrectly captured in functions transformed to cdef
return sum(i for i in range(a))
def run_sum_generator_expression(a):
"""
>>> run_sum_generator_expression(5)
10
"""
return sum_generator_expression(a)
def test(module): def test(module):
import os.path import os.path
...@@ -95,3 +106,6 @@ cdef class TypedMethod: ...@@ -95,3 +106,6 @@ cdef class TypedMethod:
cpdef int func(x, int y, z) except? -1 # argument names should not matter, types should cpdef int func(x, int y, z) except? -1 # argument names should not matter, types should
cdef int sum_generator_expression(int a)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment