Commit 626301a5 authored by Stefan Behnel's avatar Stefan Behnel

merge

parents 4978fea0 2ac6a622
...@@ -6,6 +6,7 @@ syntax: glob ...@@ -6,6 +6,7 @@ syntax: glob
Cython/Compiler/Lexicon.pickle Cython/Compiler/Lexicon.pickle
BUILD/ BUILD/
build/ build/
dist/
.coverage .coverage
*~ *~
*.orig *.orig
......
print "Warning: Using prototype cython.inline code..."
import tempfile
import sys, os, re, inspect
try:
import hashlib
except ImportError:
import md5 as hashlib
from distutils.dist import Distribution
from Cython.Distutils.extension import Extension
from Cython.Distutils import build_ext
from Cython.Compiler.Main import Context, CompilationOptions, default_options
from Cython.Compiler.ParseTreeTransforms import CythonTransform, SkipDeclarations, AnalyseDeclarationsTransform
from Cython.Compiler.TreeFragment import parse_from_strings
_code_cache = {}
class AllSymbols(CythonTransform, SkipDeclarations):
def __init__(self):
CythonTransform.__init__(self, None)
self.names = set()
def visit_NameNode(self, node):
self.names.add(node.name)
def unbound_symbols(code, context=None):
if context is None:
context = Context([], default_options)
from Cython.Compiler.ParseTreeTransforms import AnalyseDeclarationsTransform
if isinstance(code, str):
code = code.decode('ascii')
tree = parse_from_strings('(tree fragment)', code)
for phase in context.create_pipeline(pxd=False):
if phase is None:
continue
tree = phase(tree)
if isinstance(phase, AnalyseDeclarationsTransform):
break
symbol_collector = AllSymbols()
symbol_collector(tree)
unbound = []
import __builtin__
for name in symbol_collector.names:
if not tree.scope.lookup(name) and not hasattr(__builtin__, name):
unbound.append(name)
return unbound
def get_type(arg, context=None):
py_type = type(arg)
if py_type in [list, tuple, dict, str]:
return py_type.__name__
elif py_type is float:
return 'double'
elif py_type is bool:
return 'bint'
elif py_type is int:
return 'long'
elif 'numpy' in sys.modules and isinstance(arg, sys.modules['numpy'].ndarray):
return 'numpy.ndarray[numpy.%s_t, ndim=%s]' % (arg.dtype.name, arg.ndim)
else:
for base_type in py_type.mro():
if base_type.__module__ == '__builtin__':
return 'object'
module = context.find_module(base_type.__module__, need_pxd=False)
if module:
entry = module.lookup(base_type.__name__)
if entry.is_type:
return '%s.%s' % (base_type.__module__, base_type.__name__)
return 'object'
# TODO: use locals/globals for unbound variables
def cython_inline(code,
types='aggressive',
lib_dir=os.path.expanduser('~/.cython/inline'),
include_dirs=['.'],
locals=None,
globals=None,
**kwds):
code = strip_common_indent(code)
ctx = Context(include_dirs, default_options)
if locals is None:
locals = inspect.currentframe().f_back.f_back.f_locals
if globals is None:
globals = inspect.currentframe().f_back.f_back.f_globals
try:
for symbol in unbound_symbols(code):
if symbol in kwds:
continue
elif symbol in locals:
kwds[symbol] = locals[symbol]
elif symbol in globals:
kwds[symbol] = globals[symbol]
else:
print "Couldn't find ", symbol
except AssertionError:
# Parsing from strings not fully supported (e.g. cimports).
print "Could not parse code as a string (to extract unbound symbols)."
arg_names = kwds.keys()
arg_names.sort()
arg_sigs = tuple([(get_type(kwds[arg], ctx), arg) for arg in arg_names])
key = code, arg_sigs
module = _code_cache.get(key)
if not module:
cimports = []
qualified = re.compile(r'([.\w]+)[.]')
for type, _ in arg_sigs:
m = qualified.match(type)
if m:
cimports.append('\ncimport %s' % m.groups()[0])
module_body, func_body = extract_func_code(code)
params = ', '.join('%s %s' % a for a in arg_sigs)
module_code = """
%(cimports)s
%(module_body)s
def __invoke(%(params)s):
%(func_body)s
""" % {'cimports': '\n'.join(cimports), 'module_body': module_body, 'params': params, 'func_body': func_body }
# print module_code
_, pyx_file = tempfile.mkstemp('.pyx')
open(pyx_file, 'w').write(module_code)
module = "_" + hashlib.md5(code + str(arg_sigs)).hexdigest()
extension = Extension(
name = module,
sources = [pyx_file],
pyrex_include_dirs = include_dirs)
build_extension = build_ext(Distribution())
build_extension.finalize_options()
build_extension.extensions = [extension]
build_extension.build_temp = os.path.dirname(pyx_file)
if lib_dir not in sys.path:
sys.path.append(lib_dir)
build_extension.build_lib = lib_dir
build_extension.run()
_code_cache[key] = module
arg_list = [kwds[arg] for arg in arg_names]
return __import__(module).__invoke(*arg_list)
non_space = re.compile('[^ ]')
def strip_common_indent(code):
min_indent = None
lines = code.split('\n')
for line in lines:
match = non_space.search(line)
if not match:
continue # blank
indent = match.start()
if line[indent] == '#':
continue # comment
elif min_indent is None or min_indent > indent:
min_indent = indent
for ix, line in enumerate(lines):
match = non_space.search(line)
if not match or line[indent] == '#':
continue
else:
lines[ix] = line[min_indent:]
return '\n'.join(lines)
module_statement = re.compile(r'^((cdef +(extern|class))|cimport|(from .+ cimport)|(from .+ import +[*]))')
def extract_func_code(code):
module = []
function = []
# TODO: string literals, backslash
current = function
code = code.replace('\t', ' ')
lines = code.split('\n')
for line in lines:
if not line.startswith(' '):
if module_statement.match(line):
current = module
else:
current = function
current.append(line)
return '\n'.join(module), ' ' + '\n '.join(function)
...@@ -60,6 +60,7 @@ def parse_from_strings(name, code, pxds={}, level=None, initial_pos=None): ...@@ -60,6 +60,7 @@ def parse_from_strings(name, code, pxds={}, level=None, initial_pos=None):
scope = scope, context = context, initial_pos = initial_pos) scope = scope, context = context, initial_pos = initial_pos)
if level is None: if level is None:
tree = Parsing.p_module(scanner, 0, module_name) tree = Parsing.p_module(scanner, 0, module_name)
tree.scope = scope
else: else:
tree = Parsing.p_code(scanner, level=level) tree = Parsing.p_code(scanner, level=level)
return tree return tree
...@@ -201,6 +202,8 @@ class TreeFragment(object): ...@@ -201,6 +202,8 @@ class TreeFragment(object):
if not isinstance(t, StatListNode): if not isinstance(t, StatListNode):
t = StatListNode(pos=mod.pos, stats=[t]) t = StatListNode(pos=mod.pos, stats=[t])
for transform in pipeline: for transform in pipeline:
if transform is None:
continue
t = transform(t) t = transform(t)
self.root = t self.root = t
elif isinstance(code, Node): elif isinstance(code, Node):
......
# cython.* namespace for pure mode.
compiled = False compiled = False
def empty_decorator(x): def empty_decorator(x):
return x return x
# Function decorators
def locals(**arg_types): def locals(**arg_types):
return empty_decorator return empty_decorator
def inline(f, *args, **kwds):
if isinstance(f, basestring):
from Cython.Build.Inline import cython_inline
return cython_inline(f, *args, **kwds)
else:
assert len(args) == len(kwds) == 0
return f
# Special functions # Special functions
def cdiv(a, b): def cdiv(a, b):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment