Commit 2d2fba93 authored by Robert Bradshaw's avatar Robert Bradshaw

numpy and extension types for runtime cython

parent d0a66db8
...@@ -7,45 +7,69 @@ except ImportError: ...@@ -7,45 +7,69 @@ except ImportError:
import md5 as hashlib import md5 as hashlib
from distutils.dist import Distribution from distutils.dist import Distribution
from distutils.core import Extension from Cython.Distutils.extension import Extension
from Cython.Distutils import build_ext from Cython.Distutils import build_ext
from Cython.Compiler.Main import Context, CompilationOptions, default_options
code_cache = {} code_cache = {}
def get_type(arg):
def get_type(arg, context=None):
py_type = type(arg) py_type = type(arg)
# TODO: numpy
# TODO: extension types # TODO: extension types
if py_type in [list, tuple, dict, str]: if py_type in [list, tuple, dict, str]:
return py_type.__name__ return py_type.__name__
elif py_type is float: elif py_type is float:
return 'double' return 'double'
elif py_type is bool:
return 'bint'
elif py_type is int: elif py_type is int:
return 'long' return 'long'
elif 'numpy' in sys.modules and isinstance(arg, sys.modules['numpy'].ndarray):
return 'numpy.ndarray[numpy.%s_t, ndim=%s]' % (arg.dtype.name, arg.ndim)
else: else:
for base_type in py_type.mro():
if base_type.__module__ == '__builtin__':
return 'object'
module = context.find_module(base_type.__module__, need_pxd=False)
if module:
entry = module.lookup(base_type.__name__)
if entry.is_type:
return '%s.%s' % (base_type.__module__, base_type.__name__)
return 'object' return 'object'
# TODO: use locals/globals for unbound variables # TODO: use locals/globals for unbound variables
def cython_inline(code, types='aggressive', lib_dir=os.path.expanduser('~/.cython/inline'), **kwds): def cython_inline(code, types='aggressive', lib_dir=os.path.expanduser('~/.cython/inline'), include_dirs=['.'], **kwds):
ctx = Context(include_dirs, default_options)
_, pyx_file = tempfile.mkstemp('.pyx') _, pyx_file = tempfile.mkstemp('.pyx')
arg_names = kwds.keys() arg_names = kwds.keys()
arg_names.sort() arg_names.sort()
arg_sigs = tuple((get_type(kwds[arg]), arg) for arg in arg_names) arg_sigs = tuple((get_type(kwds[arg], ctx), arg) for arg in arg_names)
key = code, arg_sigs key = code, arg_sigs
module = code_cache.get(key) module = code_cache.get(key)
if not module: if not module:
cimports = ''
qualified = re.compile(r'([.\w]+)[.]')
for type, _ in arg_sigs:
m = qualified.match(type)
if m:
cimports += '\ncimport %s' % m.groups()[0]
module_body, func_body = extract_func_code(code) module_body, func_body = extract_func_code(code)
params = ', '.join('%s %s' % a for a in arg_sigs) params = ', '.join('%s %s' % a for a in arg_sigs)
module_code = """ module_code = """
%(cimports)s
%(module_body)s %(module_body)s
def __invoke(%(params)s): def __invoke(%(params)s):
%(func_body)s %(func_body)s
""" % locals() """ % locals()
print module_code
open(pyx_file, 'w').write(module_code) open(pyx_file, 'w').write(module_code)
module = "_" + hashlib.md5(code + str(arg_sigs)).hexdigest() module = "_" + hashlib.md5(code + str(arg_sigs)).hexdigest()
extension = Extension( extension = Extension(
name = module, name = module,
sources=[pyx_file]) sources = [pyx_file],
pyrex_include_dirs = include_dirs)
build_extension = build_ext(Distribution()) build_extension = build_ext(Distribution())
build_extension.finalize_options() build_extension.finalize_options()
build_extension.extensions = [extension] build_extension.extensions = [extension]
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment