Commit 9de33b2a authored by Robert Bradshaw's avatar Robert Bradshaw

More module caching.

parent 858709b6
...@@ -12,6 +12,7 @@ except ImportError: ...@@ -12,6 +12,7 @@ except ImportError:
from distutils.core import Distribution, Extension from distutils.core import Distribution, Extension
from distutils.command.build_ext import build_ext from distutils.command.build_ext import build_ext
import Cython
from Cython.Compiler.Main import Context, CompilationOptions, default_options from Cython.Compiler.Main import Context, CompilationOptions, default_options
from Cython.Compiler.ParseTreeTransforms import CythonTransform, SkipDeclarations, AnalyseDeclarationsTransform from Cython.Compiler.ParseTreeTransforms import CythonTransform, SkipDeclarations, AnalyseDeclarationsTransform
...@@ -106,8 +107,14 @@ def cython_inline(code, ...@@ -106,8 +107,14 @@ def cython_inline(code,
arg_names.sort() arg_names.sort()
arg_sigs = tuple([(get_type(kwds[arg], ctx), arg) for arg in arg_names]) arg_sigs = tuple([(get_type(kwds[arg], ctx), arg) for arg in arg_names])
key = code, arg_sigs key = code, arg_sigs
module = _code_cache.get(key) module_name = _code_cache.get(key)
if not module: if module_name is None:
module_name = "_cython_inline_" + hashlib.md5(code + str(arg_sigs) + Cython.__version__).hexdigest()
try:
if lib_dir not in sys.path:
sys.path.append(lib_dir)
__import__(module_name)
except ImportError:
cimports = [] cimports = []
qualified = re.compile(r'([.\w]+)[.]') qualified = re.compile(r'([.\w]+)[.]')
for type, _ in arg_sigs: for type, _ in arg_sigs:
...@@ -124,24 +131,21 @@ def __invoke(%(params)s): ...@@ -124,24 +131,21 @@ def __invoke(%(params)s):
""" % {'cimports': '\n'.join(cimports), 'module_body': module_body, 'params': params, 'func_body': func_body } """ % {'cimports': '\n'.join(cimports), 'module_body': module_body, 'params': params, 'func_body': func_body }
for key, value in literals.items(): for key, value in literals.items():
module_code = module_code.replace(key, value) module_code = module_code.replace(key, value)
module = "_" + hashlib.md5(code + str(arg_sigs)).hexdigest() pyx_file = os.path.join(tempfile.mkdtemp(), module_name + '.pyx')
pyx_file = os.path.join(tempfile.mkdtemp(), module + '.pyx')
open(pyx_file, 'w').write(module_code) open(pyx_file, 'w').write(module_code)
extension = Extension( extension = Extension(
name = module, name = module_name,
sources = [pyx_file], sources = [pyx_file],
pyrex_include_dirs = include_dirs) pyrex_include_dirs = include_dirs)
build_extension = build_ext(Distribution()) build_extension = build_ext(Distribution())
build_extension.finalize_options() build_extension.finalize_options()
build_extension.extensions = cythonize([extension]) build_extension.extensions = cythonize([extension])
build_extension.build_temp = os.path.dirname(pyx_file) build_extension.build_temp = os.path.dirname(pyx_file)
if lib_dir not in sys.path:
sys.path.append(lib_dir)
build_extension.build_lib = lib_dir build_extension.build_lib = lib_dir
build_extension.run() build_extension.run()
_code_cache[key] = module _code_cache[key] = module_name
arg_list = [kwds[arg] for arg in arg_names] arg_list = [kwds[arg] for arg in arg_names]
return __import__(module).__invoke(*arg_list) return __import__(module_name).__invoke(*arg_list)
non_space = re.compile('[^ ]') non_space = re.compile('[^ ]')
def strip_common_indent(code): def strip_common_indent(code):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment