Commit 70abf0f1 authored by Robert Bradshaw's avatar Robert Bradshaw

merge

parents af626302 747a0110
......@@ -25,22 +25,6 @@ from Cython import Utils
module_name_pattern = re.compile(r"[A-Za-z_][A-Za-z0-9_]*(\.[A-Za-z_][A-Za-z0-9_]*)*$")
# Note: PHASES and TransformSet should be removed soon; but that's for
# another day and another commit.
PHASES = [
'before_analyse_function', # run in FuncDefNode.generate_function_definitions
'after_analyse_function' # run in FuncDefNode.generate_function_definitions
]
class TransformSet(dict):
def __init__(self):
for name in PHASES:
self[name] = []
def run(self, name, node, **options):
assert name in self, "Transform phase %s not defined" % name
for transform in self[name]:
transform(node, phase=name, **options)
verbose = 0
class Context:
......@@ -295,7 +279,10 @@ class Context:
else:
Errors.open_listing_file(None)
def teardown_errors(self, errors_occurred, options, result, source_desc):
def teardown_errors(self, errors_occurred, options, result):
source_desc = result.compilation_source.source_desc
if not isinstance(source_desc, FileSourceDescriptor):
raise RuntimeError("Only file sources for code supported")
Errors.close_listing_file()
result.num_errors = Errors.num_errors
if result.num_errors > 0:
......@@ -316,16 +303,15 @@ class Context:
verbose_flag = options.show_version,
cplus = options.cplus)
class CompilationSource(object):
"""
Contains the data necesarry to start up a compilation pipeline for
a single compilation source (= file, usually).
"""
def __init__(self, source_desc, full_module_name, cwd):
self.source_desc = source_desc
self.full_module_name = full_module_name
self.cwd = cwd
def run_pipeline(self, pipeline, source):
errors_occurred = False
data = source
try:
for phase in pipeline:
data = phase(data)
except CompileError:
errors_occurred = True
return (errors_occurred, data)
def create_parse(context):
def parse(compsrc):
......@@ -339,17 +325,18 @@ def create_parse(context):
return tree
return parse
def create_generate_code(context, options):
def create_generate_code(context, options, result):
def generate_code(module_node):
scope = module_node.scope
result = create_default_resultobj(module_node.compilation_source, options)
module_node.process_implementation(options, result)
result.compilation_source = module_node.compilation_source
return result
return generate_code
def create_default_pipeline(context, options):
def create_default_pipeline(context, options, result):
from ParseTreeTransforms import WithTransform, PostParse
from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform, MarkClosureVisitor
from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform
from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor
from ModuleNode import check_c_classes
return [
......@@ -360,12 +347,14 @@ def create_default_pipeline(context, options):
AnalyseDeclarationsTransform(),
check_c_classes,
AnalyseExpressionsTransform(),
create_generate_code(context, options)
CreateClosureClasses(),
create_generate_code(context, options, result)
]
def create_default_resultobj(compilation_source, options):
result = CompilationResult()
result.main_source_file = compilation_source.source_desc.filename
result.compilation_source = compilation_source
source_desc = compilation_source.source_desc
if options.output_file:
result.c_file = os.path.join(compilation_source.cwd, options.output_file)
......@@ -384,13 +373,9 @@ def create_default_resultobj(compilation_source, options):
pass
return result
def run_pipeline(source, options = None, full_module_name = None):
if not options:
options = default_options
def run_pipeline(source, options, full_module_name = None):
# Set up context
context = Context(options.include_path)
context.setup_errors(options)
# Set up source object
cwd = os.getcwd()
......@@ -398,18 +383,15 @@ def run_pipeline(source, options = None, full_module_name = None):
full_module_name = full_module_name or context.extract_module_name(source, options)
source = CompilationSource(source_desc, full_module_name, cwd)
# Set up result object
result = create_default_resultobj(source, options)
# Get pipeline
pipeline = create_default_pipeline(context, options)
pipeline = create_default_pipeline(context, options, result)
data = source
errors_occurred = False
try:
for phase in pipeline:
data = phase(data)
except CompileError:
errors_occurred = True
result = data
context.teardown_errors(errors_occurred, options, result, source_desc)
context.setup_errors(options)
errors_occurred, enddata = context.run_pipeline(pipeline, source)
context.teardown_errors(errors_occurred, options, result)
return result
#------------------------------------------------------------------------
......@@ -418,6 +400,16 @@ def run_pipeline(source, options = None, full_module_name = None):
#
#------------------------------------------------------------------------
class CompilationSource(object):
"""
Contains the data necesarry to start up a compilation pipeline for
a single compilation unit.
"""
def __init__(self, source_desc, full_module_name, cwd):
self.source_desc = source_desc
self.full_module_name = full_module_name
self.cwd = cwd
class CompilationOptions:
"""
Options to the Cython compiler:
......@@ -433,7 +425,6 @@ class CompilationOptions:
defaults to true when recursive is true.
verbose boolean Always print source names being compiled
quiet boolean Don't print source names in recursive mode
transforms Transform.TransformSet Transforms to use on the parse tree
Following options are experimental and only used on MacOSX:
......@@ -471,6 +462,7 @@ class CompilationResult:
object_file string or None Result of compiling the C file
extension_file string or None Result of linking the object file
num_errors integer Number of compilation errors
compilation_source CompilationSource
"""
def __init__(self):
......@@ -620,9 +612,9 @@ default_options = dict(
output_file = None,
annotate = False,
generate_pxi = 0,
transforms = TransformSet(),
working_path = "",
recursive = 0,
transforms = None, # deprecated
timestamps = None,
verbose = 0,
quiet = 0)
......
from Cython.Compiler.Visitor import VisitorTransform, temp_name_handle
from Cython.Compiler.ModuleNode import ModuleNode
from Cython.Compiler.Nodes import *
from Cython.Compiler.ExprNodes import *
from Cython.Compiler.TreeFragment import TreeFragment
class PostParse(VisitorTransform):
"""
This transform fixes up a few things after parsing
......@@ -170,7 +170,6 @@ class AnalyseDeclarationsTransform(VisitorTransform):
class AnalyseExpressionsTransform(VisitorTransform):
def visit_ModuleNode(self, node):
node.body.analyse_expressions(node.scope)
self.visitchildren(node)
......@@ -208,3 +207,35 @@ class MarkClosureVisitor(VisitorTransform):
self.visitchildren(node)
return node
class CreateClosureClasses(VisitorTransform):
# Output closure classes in module scope for all functions
# that need it.
def visit_ModuleNode(self, node):
self.module_scope = node.scope
self.visitchildren(node)
return node
def create_class_from_scope(self, node, target_module_scope):
as_name = temp_name_handle("closure")
func_scope = node.local_scope
entry = target_module_scope.declare_c_class(name = as_name,
pos = node.pos, defining = True, implementing = True)
class_scope = entry.type.scope
for entry in func_scope.entries.values():
class_scope.declare_var(pos=node.pos,
name=entry.name,
cname=entry.cname,
type=entry.type,
is_cdef=True)
def visit_FuncDefNode(self, node):
self.create_class_from_scope(node, self.module_scope)
return node
def visit_Node(self, node):
self.visitchildren(node)
return node
......@@ -1386,8 +1386,8 @@ def p_statement(s, ctx, first_statement = 0):
if ctx.api:
error(s.pos, "'api' not allowed with this statement")
elif s.sy == 'def':
# if ctx.level not in ('module', 'class', 'c_class', 'property'):
# s.error('def statement not allowed here')
if ctx.level not in ('module', 'class', 'c_class', 'function', 'property'):
s.error('def statement not allowed here')
s.level = ctx.level
return p_def_statement(s)
elif s.sy == 'class':
......
......@@ -27,6 +27,15 @@ class NodeTypeWriter(TreeVisitor):
self.visitchildren(node)
self._indents -= 1
def treetypes(root):
"""Returns a string representing the tree by class names.
There's a leading and trailing whitespace so that it can be
compared by simple string comparison while still making test
cases look ok."""
w = NodeTypeWriter()
w.visit(root)
return u"\n".join([u""] + w.result + [u""])
class CythonTest(unittest.TestCase):
def assertLines(self, expected, result):
......@@ -58,13 +67,7 @@ class CythonTest(unittest.TestCase):
return TreeFragment(code, name, pxds)
def treetypes(self, root):
"""Returns a string representing the tree by class names.
There's a leading and trailing whitespace so that it can be
compared by simple string comparison while still making test
cases look ok."""
w = NodeTypeWriter()
w.visit(root)
return u"\n".join([u""] + w.result + [u""])
return treetypes(root)
class TransformTest(CythonTest):
"""
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment