Commit 70abf0f1 authored by Robert Bradshaw's avatar Robert Bradshaw

merge

parents af626302 747a0110
...@@ -25,22 +25,6 @@ from Cython import Utils ...@@ -25,22 +25,6 @@ from Cython import Utils
module_name_pattern = re.compile(r"[A-Za-z_][A-Za-z0-9_]*(\.[A-Za-z_][A-Za-z0-9_]*)*$") module_name_pattern = re.compile(r"[A-Za-z_][A-Za-z0-9_]*(\.[A-Za-z_][A-Za-z0-9_]*)*$")
# Note: PHASES and TransformSet should be removed soon; but that's for
# another day and another commit.
PHASES = [
'before_analyse_function', # run in FuncDefNode.generate_function_definitions
'after_analyse_function' # run in FuncDefNode.generate_function_definitions
]
class TransformSet(dict):
def __init__(self):
for name in PHASES:
self[name] = []
def run(self, name, node, **options):
assert name in self, "Transform phase %s not defined" % name
for transform in self[name]:
transform(node, phase=name, **options)
verbose = 0 verbose = 0
class Context: class Context:
...@@ -295,7 +279,10 @@ class Context: ...@@ -295,7 +279,10 @@ class Context:
else: else:
Errors.open_listing_file(None) Errors.open_listing_file(None)
def teardown_errors(self, errors_occurred, options, result, source_desc): def teardown_errors(self, errors_occurred, options, result):
source_desc = result.compilation_source.source_desc
if not isinstance(source_desc, FileSourceDescriptor):
raise RuntimeError("Only file sources for code supported")
Errors.close_listing_file() Errors.close_listing_file()
result.num_errors = Errors.num_errors result.num_errors = Errors.num_errors
if result.num_errors > 0: if result.num_errors > 0:
...@@ -315,17 +302,16 @@ class Context: ...@@ -315,17 +302,16 @@ class Context:
extra_objects = options.objects, extra_objects = options.objects,
verbose_flag = options.show_version, verbose_flag = options.show_version,
cplus = options.cplus) cplus = options.cplus)
class CompilationSource(object): def run_pipeline(self, pipeline, source):
""" errors_occurred = False
Contains the data necesarry to start up a compilation pipeline for data = source
a single compilation source (= file, usually). try:
""" for phase in pipeline:
def __init__(self, source_desc, full_module_name, cwd): data = phase(data)
self.source_desc = source_desc except CompileError:
self.full_module_name = full_module_name errors_occurred = True
self.cwd = cwd return (errors_occurred, data)
def create_parse(context): def create_parse(context):
def parse(compsrc): def parse(compsrc):
...@@ -339,17 +325,18 @@ def create_parse(context): ...@@ -339,17 +325,18 @@ def create_parse(context):
return tree return tree
return parse return parse
def create_generate_code(context, options): def create_generate_code(context, options, result):
def generate_code(module_node): def generate_code(module_node):
scope = module_node.scope scope = module_node.scope
result = create_default_resultobj(module_node.compilation_source, options)
module_node.process_implementation(options, result) module_node.process_implementation(options, result)
result.compilation_source = module_node.compilation_source
return result return result
return generate_code return generate_code
def create_default_pipeline(context, options): def create_default_pipeline(context, options, result):
from ParseTreeTransforms import WithTransform, PostParse from ParseTreeTransforms import WithTransform, PostParse
from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform, MarkClosureVisitor from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform
from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor
from ModuleNode import check_c_classes from ModuleNode import check_c_classes
return [ return [
...@@ -360,12 +347,14 @@ def create_default_pipeline(context, options): ...@@ -360,12 +347,14 @@ def create_default_pipeline(context, options):
AnalyseDeclarationsTransform(), AnalyseDeclarationsTransform(),
check_c_classes, check_c_classes,
AnalyseExpressionsTransform(), AnalyseExpressionsTransform(),
create_generate_code(context, options) CreateClosureClasses(),
create_generate_code(context, options, result)
] ]
def create_default_resultobj(compilation_source, options): def create_default_resultobj(compilation_source, options):
result = CompilationResult() result = CompilationResult()
result.main_source_file = compilation_source.source_desc.filename result.main_source_file = compilation_source.source_desc.filename
result.compilation_source = compilation_source
source_desc = compilation_source.source_desc source_desc = compilation_source.source_desc
if options.output_file: if options.output_file:
result.c_file = os.path.join(compilation_source.cwd, options.output_file) result.c_file = os.path.join(compilation_source.cwd, options.output_file)
...@@ -384,13 +373,9 @@ def create_default_resultobj(compilation_source, options): ...@@ -384,13 +373,9 @@ def create_default_resultobj(compilation_source, options):
pass pass
return result return result
def run_pipeline(source, options = None, full_module_name = None): def run_pipeline(source, options, full_module_name = None):
if not options:
options = default_options
# Set up context # Set up context
context = Context(options.include_path) context = Context(options.include_path)
context.setup_errors(options)
# Set up source object # Set up source object
cwd = os.getcwd() cwd = os.getcwd()
...@@ -398,18 +383,15 @@ def run_pipeline(source, options = None, full_module_name = None): ...@@ -398,18 +383,15 @@ def run_pipeline(source, options = None, full_module_name = None):
full_module_name = full_module_name or context.extract_module_name(source, options) full_module_name = full_module_name or context.extract_module_name(source, options)
source = CompilationSource(source_desc, full_module_name, cwd) source = CompilationSource(source_desc, full_module_name, cwd)
# Set up result object
result = create_default_resultobj(source, options)
# Get pipeline # Get pipeline
pipeline = create_default_pipeline(context, options) pipeline = create_default_pipeline(context, options, result)
data = source context.setup_errors(options)
errors_occurred = False errors_occurred, enddata = context.run_pipeline(pipeline, source)
try: context.teardown_errors(errors_occurred, options, result)
for phase in pipeline:
data = phase(data)
except CompileError:
errors_occurred = True
result = data
context.teardown_errors(errors_occurred, options, result, source_desc)
return result return result
#------------------------------------------------------------------------ #------------------------------------------------------------------------
...@@ -418,6 +400,16 @@ def run_pipeline(source, options = None, full_module_name = None): ...@@ -418,6 +400,16 @@ def run_pipeline(source, options = None, full_module_name = None):
# #
#------------------------------------------------------------------------ #------------------------------------------------------------------------
class CompilationSource(object):
"""
Contains the data necesarry to start up a compilation pipeline for
a single compilation unit.
"""
def __init__(self, source_desc, full_module_name, cwd):
self.source_desc = source_desc
self.full_module_name = full_module_name
self.cwd = cwd
class CompilationOptions: class CompilationOptions:
""" """
Options to the Cython compiler: Options to the Cython compiler:
...@@ -433,7 +425,6 @@ class CompilationOptions: ...@@ -433,7 +425,6 @@ class CompilationOptions:
defaults to true when recursive is true. defaults to true when recursive is true.
verbose boolean Always print source names being compiled verbose boolean Always print source names being compiled
quiet boolean Don't print source names in recursive mode quiet boolean Don't print source names in recursive mode
transforms Transform.TransformSet Transforms to use on the parse tree
Following options are experimental and only used on MacOSX: Following options are experimental and only used on MacOSX:
...@@ -471,6 +462,7 @@ class CompilationResult: ...@@ -471,6 +462,7 @@ class CompilationResult:
object_file string or None Result of compiling the C file object_file string or None Result of compiling the C file
extension_file string or None Result of linking the object file extension_file string or None Result of linking the object file
num_errors integer Number of compilation errors num_errors integer Number of compilation errors
compilation_source CompilationSource
""" """
def __init__(self): def __init__(self):
...@@ -620,9 +612,9 @@ default_options = dict( ...@@ -620,9 +612,9 @@ default_options = dict(
output_file = None, output_file = None,
annotate = False, annotate = False,
generate_pxi = 0, generate_pxi = 0,
transforms = TransformSet(),
working_path = "", working_path = "",
recursive = 0, recursive = 0,
transforms = None, # deprecated
timestamps = None, timestamps = None,
verbose = 0, verbose = 0,
quiet = 0) quiet = 0)
......
from Cython.Compiler.Visitor import VisitorTransform, temp_name_handle from Cython.Compiler.Visitor import VisitorTransform, temp_name_handle
from Cython.Compiler.ModuleNode import ModuleNode
from Cython.Compiler.Nodes import * from Cython.Compiler.Nodes import *
from Cython.Compiler.ExprNodes import * from Cython.Compiler.ExprNodes import *
from Cython.Compiler.TreeFragment import TreeFragment from Cython.Compiler.TreeFragment import TreeFragment
class PostParse(VisitorTransform): class PostParse(VisitorTransform):
""" """
This transform fixes up a few things after parsing This transform fixes up a few things after parsing
...@@ -170,7 +170,6 @@ class AnalyseDeclarationsTransform(VisitorTransform): ...@@ -170,7 +170,6 @@ class AnalyseDeclarationsTransform(VisitorTransform):
class AnalyseExpressionsTransform(VisitorTransform): class AnalyseExpressionsTransform(VisitorTransform):
def visit_ModuleNode(self, node): def visit_ModuleNode(self, node):
node.body.analyse_expressions(node.scope) node.body.analyse_expressions(node.scope)
self.visitchildren(node) self.visitchildren(node)
...@@ -208,3 +207,35 @@ class MarkClosureVisitor(VisitorTransform): ...@@ -208,3 +207,35 @@ class MarkClosureVisitor(VisitorTransform):
self.visitchildren(node) self.visitchildren(node)
return node return node
class CreateClosureClasses(VisitorTransform):
# Output closure classes in module scope for all functions
# that need it.
def visit_ModuleNode(self, node):
self.module_scope = node.scope
self.visitchildren(node)
return node
def create_class_from_scope(self, node, target_module_scope):
as_name = temp_name_handle("closure")
func_scope = node.local_scope
entry = target_module_scope.declare_c_class(name = as_name,
pos = node.pos, defining = True, implementing = True)
class_scope = entry.type.scope
for entry in func_scope.entries.values():
class_scope.declare_var(pos=node.pos,
name=entry.name,
cname=entry.cname,
type=entry.type,
is_cdef=True)
def visit_FuncDefNode(self, node):
self.create_class_from_scope(node, self.module_scope)
return node
def visit_Node(self, node):
self.visitchildren(node)
return node
...@@ -1386,8 +1386,8 @@ def p_statement(s, ctx, first_statement = 0): ...@@ -1386,8 +1386,8 @@ def p_statement(s, ctx, first_statement = 0):
if ctx.api: if ctx.api:
error(s.pos, "'api' not allowed with this statement") error(s.pos, "'api' not allowed with this statement")
elif s.sy == 'def': elif s.sy == 'def':
# if ctx.level not in ('module', 'class', 'c_class', 'property'): if ctx.level not in ('module', 'class', 'c_class', 'function', 'property'):
# s.error('def statement not allowed here') s.error('def statement not allowed here')
s.level = ctx.level s.level = ctx.level
return p_def_statement(s) return p_def_statement(s)
elif s.sy == 'class': elif s.sy == 'class':
......
...@@ -27,6 +27,15 @@ class NodeTypeWriter(TreeVisitor): ...@@ -27,6 +27,15 @@ class NodeTypeWriter(TreeVisitor):
self.visitchildren(node) self.visitchildren(node)
self._indents -= 1 self._indents -= 1
def treetypes(root):
"""Returns a string representing the tree by class names.
There's a leading and trailing whitespace so that it can be
compared by simple string comparison while still making test
cases look ok."""
w = NodeTypeWriter()
w.visit(root)
return u"\n".join([u""] + w.result + [u""])
class CythonTest(unittest.TestCase): class CythonTest(unittest.TestCase):
def assertLines(self, expected, result): def assertLines(self, expected, result):
...@@ -58,13 +67,7 @@ class CythonTest(unittest.TestCase): ...@@ -58,13 +67,7 @@ class CythonTest(unittest.TestCase):
return TreeFragment(code, name, pxds) return TreeFragment(code, name, pxds)
def treetypes(self, root): def treetypes(self, root):
"""Returns a string representing the tree by class names. return treetypes(root)
There's a leading and trailing whitespace so that it can be
compared by simple string comparison while still making test
cases look ok."""
w = NodeTypeWriter()
w.visit(root)
return u"\n".join([u""] + w.result + [u""])
class TransformTest(CythonTest): class TransformTest(CythonTest):
""" """
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment