Commit e35445d1 authored by Robert Bradshaw's avatar Robert Bradshaw

Compile decorator.

parent 6d95c710
...@@ -161,7 +161,7 @@ def strip_string_literals(code, prefix='__Pyx_L'): ...@@ -161,7 +161,7 @@ def strip_string_literals(code, prefix='__Pyx_L'):
if q == -1: q = max(single_q, double_q) if q == -1: q = max(single_q, double_q)
# Process comment. # Process comment.
if hash_mark < q or hash_mark > -1 == q: if -1 < hash_mark and (hash_mark < q or q == -1):
end = code.find('\n', hash_mark) end = code.find('\n', hash_mark)
if end == -1: if end == -1:
end = None end = None
...@@ -173,6 +173,7 @@ def strip_string_literals(code, prefix='__Pyx_L'): ...@@ -173,6 +173,7 @@ def strip_string_literals(code, prefix='__Pyx_L'):
if end is None: if end is None:
break break
q = end q = end
start = q
# We're done. # We're done.
elif q == -1: elif q == -1:
...@@ -194,8 +195,8 @@ def strip_string_literals(code, prefix='__Pyx_L'): ...@@ -194,8 +195,8 @@ def strip_string_literals(code, prefix='__Pyx_L'):
literals[label] = code[start+len(in_quote):q] literals[label] = code[start+len(in_quote):q]
new_code.append("%s%s%s" % (in_quote, label, in_quote)) new_code.append("%s%s%s" % (in_quote, label, in_quote))
q += len(in_quote) q += len(in_quote)
start = q
in_quote = False in_quote = False
start = q
else: else:
q += 1 q += 1
......
...@@ -108,7 +108,12 @@ def cython_inline(code, ...@@ -108,7 +108,12 @@ def cython_inline(code,
arg_sigs = tuple([(get_type(kwds[arg], ctx), arg) for arg in arg_names]) arg_sigs = tuple([(get_type(kwds[arg], ctx), arg) for arg in arg_names])
key = code, arg_sigs, sys.version_info, sys.executable, Cython.__version__ key = code, arg_sigs, sys.version_info, sys.executable, Cython.__version__
module_name = "_cython_inline_" + hashlib.md5(str(key)).hexdigest() module_name = "_cython_inline_" + hashlib.md5(str(key)).hexdigest()
# # TODO: Does this cover all the platforms?
# if (not os.path.exists(os.path.join(lib_dir, module_name + ".so")) and
# not os.path.exists(os.path.join(lib_dir, module_name + ".dll"))):
try: try:
if not os.path.exists(lib_dir):
os.makedirs(lib_dir)
if lib_dir not in sys.path: if lib_dir not in sys.path:
sys.path.append(lib_dir) sys.path.append(lib_dir)
__import__(module_name) __import__(module_name)
...@@ -134,7 +139,7 @@ def __invoke(%(params)s): ...@@ -134,7 +139,7 @@ def __invoke(%(params)s):
""" % {'cimports': '\n'.join(cimports), 'module_body': module_body, 'params': params, 'func_body': func_body } """ % {'cimports': '\n'.join(cimports), 'module_body': module_body, 'params': params, 'func_body': func_body }
for key, value in literals.items(): for key, value in literals.items():
module_code = module_code.replace(key, value) module_code = module_code.replace(key, value)
pyx_file = os.path.join(tempfile.mkdtemp(), module_name + '.pyx') pyx_file = os.path.join(lib_dir, module_name + '.pyx')
open(pyx_file, 'w').write(module_code) open(pyx_file, 'w').write(module_code)
extension = Extension( extension = Extension(
name = module_name, name = module_name,
...@@ -175,7 +180,6 @@ module_statement = re.compile(r'^((cdef +(extern|class))|cimport|(from .+ cimpor ...@@ -175,7 +180,6 @@ module_statement = re.compile(r'^((cdef +(extern|class))|cimport|(from .+ cimpor
def extract_func_code(code): def extract_func_code(code):
module = [] module = []
function = [] function = []
# TODO: string literals, backslash
current = function current = function
code = code.replace('\t', ' ') code = code.replace('\t', ' ')
lines = code.split('\n') lines = code.split('\n')
...@@ -187,3 +191,54 @@ def extract_func_code(code): ...@@ -187,3 +191,54 @@ def extract_func_code(code):
current = function current = function
current.append(line) current.append(line)
return '\n'.join(module), ' ' + '\n '.join(function) return '\n'.join(module), ' ' + '\n '.join(function)
try:
from inspect import getcallargs
except ImportError:
def getcallargs(func, *arg_values, **kwd_values):
all = {}
args, varargs, kwds, defaults = inspect.getargspec(func)
if varargs is not None:
all[varargs] = arg_values[len(args):]
for name, value in zip(args, arg_values):
all[name] = value
for name, value in kwd_values.items():
if name in args:
if name in all:
raise TypeError, "Duplicate argument %s" % name
all[name] = kwd_values.pop(name)
if kwds is not None:
all[kwds] = kwd_values
elif kwd_values:
raise TypeError, "Unexpected keyword arguments: %s" % kwd_values.keys()
if defaults is None:
defaults = ()
first_default = len(args) - len(defaults)
for ix, name in enumerate(args):
if name not in all:
if ix >= first_default:
all[name] = defaults[ix - first_default]
else:
raise TypeError, "Missing argument: %s" % name
return all
def get_body(source):
ix = source.index(':')
if source[:5] == 'lambda':
return "return %s" % source[ix+1:]
else:
return source[ix+1:]
# Lots to be done here... It would be especially cool if compiled functions
# could invoke each other quickly.
class RuntimeCompiledFunction(object):
def __init__(self, f):
self._f = f
self._body = get_body(inspect.getsource(f))
def __call__(self, *args, **kwds):
all = getcallargs(self._f, *args, **kwds)
return cython_inline(self._body, locals=self._f.func_globals, globals=self._f.func_globals, **all)
...@@ -18,6 +18,10 @@ def inline(f, *args, **kwds): ...@@ -18,6 +18,10 @@ def inline(f, *args, **kwds):
assert len(args) == len(kwds) == 0 assert len(args) == len(kwds) == 0
return f return f
def compile(f):
from Cython.Build.Inline import RuntimeCompiledFunction
return RuntimeCompiledFunction(f)
# Special functions # Special functions
def cdiv(a, b): def cdiv(a, b):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment