Inline.py 9.31 KB
Newer Older
Robert Bradshaw's avatar
Robert Bradshaw committed
1
import tempfile
2
import sys, os, re, inspect
Robert Bradshaw's avatar
Robert Bradshaw committed
3
from cython import set
Robert Bradshaw's avatar
Robert Bradshaw committed
4 5 6 7 8 9

try:
    import hashlib
except ImportError:
    import md5 as hashlib

10 11
from distutils.core import Distribution, Extension
from distutils.command.build_ext import build_ext
12

Robert Bradshaw's avatar
Robert Bradshaw committed
13
import Cython
14 15
from Cython.Compiler.Main import Context, CompilationOptions, default_options

16 17
from Cython.Compiler.ParseTreeTransforms import CythonTransform, SkipDeclarations, AnalyseDeclarationsTransform
from Cython.Compiler.TreeFragment import parse_from_strings
18
from Cython.Build.Dependencies import strip_string_literals, cythonize
19

20 21 22 23 24 25 26 27 28 29
# A utility function to convert user-supplied ASCII strings to unicode.
if sys.version_info[0] < 3:
    def to_unicode(s):
        """Return ``s`` as unicode, decoding byte strings as ASCII."""
        if not isinstance(s, unicode):
            return s.decode('ascii')
        else:
            return s
else:
    def to_unicode(s):
        """Return ``s`` as str; ``bytes`` are decoded as ASCII, anything else
        is returned unchanged."""
        if isinstance(s, bytes):
            return s.decode('ascii')
        return s

30 31 32 33 34 35 36 37 38
# Module-level dict that cython_inline populates with the names of the
# extension modules it builds.
# NOTE(review): no code in this file ever reads from it.
_code_cache = {}


class AllSymbols(CythonTransform, SkipDeclarations):
    """Collect the name of every NameNode visited in a Cython parse tree.

    After the transform has been applied to a tree, ``self.names`` holds the
    set of identifiers that appeared as NameNodes.
    """

    def __init__(self):
        # The base transform is constructed without a compiler context.
        CythonTransform.__init__(self, None)
        self.names = set()

    def visit_NameNode(self, node):
        identifier = node.name
        self.names.add(identifier)
Robert Bradshaw's avatar
Robert Bradshaw committed
39

40
def unbound_symbols(code, context=None):
    """Return the names referenced by ``code`` that it does not bind itself.

    ``code`` is parsed as a Cython tree fragment and run through the compiler
    pipeline up to and including AnalyseDeclarationsTransform, so that its
    own declarations are present in ``tree.scope``. Every NameNode name that
    is neither found in that scope nor an attribute of the builtins module is
    reported.

    :param code: Cython source text (passed through ``to_unicode``).
    :param context: compiler Context; a default one is created when None.
    :returns: list of unbound names, in no particular order.
    """
    code = to_unicode(code)
    if context is None:
        context = Context([], default_options)
    tree = parse_from_strings('(tree fragment)', code)
    for phase in context.create_pipeline(pxd=False):
        if phase is None:
            continue
        tree = phase(tree)
        # Only the phases up to declaration analysis are run; that is enough
        # to populate tree.scope with the fragment's own bindings.
        if isinstance(phase, AnalyseDeclarationsTransform):
            break
    symbol_collector = AllSymbols()
    symbol_collector(tree)
    try:
        import builtins
    except ImportError:
        import __builtin__ as builtins
    return [name for name in symbol_collector.names
            if not tree.scope.lookup(name) and not hasattr(builtins, name)]
63

64 65 66 67 68 69 70 71
def unsafe_type(arg, context=None):
    """Like safe_type, but declare exact ``int`` values as C ``long``.

    Presumably "unsafe" because a Python int may not fit in a C long --
    NOTE(review): confirm intent. Every other value is delegated to
    safe_type with the same ``context``.
    """
    if type(arg) is not int:
        return safe_type(arg, context)
    return 'long'

def safe_type(arg, context=None):
    """Return a Cython type declaration string for the value ``arg``.

    list/tuple/dict/str map to their own names; complex, float and bool map
    to C types; numpy arrays (only if numpy is already imported) map to a
    buffer type. Any other object is declared as the first class in its MRO
    that ``context`` resolves to a type entry, falling back to 'object'.

    :param arg: the runtime value to type.
    :param context: Cython compiler Context used to look up non-builtin
        classes. When None, no lookup is attempted and 'object' is used.
    :returns: a Cython type string such as 'double' or 'object'.
    """
    py_type = type(arg)
    if py_type in (list, tuple, dict, str):
        return py_type.__name__
    elif py_type is complex:
        return 'double complex'
    elif py_type is float:
        return 'double'
    elif py_type is bool:
        return 'bint'
    elif 'numpy' in sys.modules and isinstance(arg, sys.modules['numpy'].ndarray):
        return 'numpy.ndarray[numpy.%s_t, ndim=%s]' % (arg.dtype.name, arg.ndim)
    if context is None:
        # Without a compiler context non-builtin classes cannot be resolved
        # (previously this crashed with AttributeError on None).
        return 'object'
    for base_type in py_type.mro():
        if base_type.__module__ in ('__builtin__', 'builtins'):
            return 'object'
        module = context.find_module(base_type.__module__, need_pxd=False)
        if module:
            entry = module.lookup(base_type.__name__)
            # lookup() yields None for names the module scope does not define.
            if entry is not None and entry.is_type:
                return '%s.%s' % (base_type.__module__, base_type.__name__)
    return 'object'

94
def cython_inline(code,
                  get_type=unsafe_type,
                  lib_dir=os.path.expanduser('~/.cython/inline'),
                  cython_include_dirs=['.'],
                  force=False,
                  quiet=False,
                  locals=None,
                  globals=None,
                  **kwds):
    """Compile ``code`` as the body of a Cython function and call it.

    Names that ``code`` uses but does not bind are looked up, in order, in
    ``kwds``, ``locals`` and ``globals``. They become the parameters of a
    generated ``__invoke`` function, typed with ``get_type``. The generated
    extension module is built into ``lib_dir``. Its name is an MD5 hash of the
    source, the argument signatures, the interpreter and the Cython version,
    so later calls with the same inputs import the already-built module.

    :param code: Cython source (passed through ``to_unicode``).
    :param get_type: callable ``(value, ctx) -> type string``; None declares
        every argument as 'object'.
    :param lib_dir: directory for the generated .pyx/build output. It is
        created if missing and appended to ``sys.path``.
    :param cython_include_dirs: include directories for the compiler Context.
    :param force: rebuild even if the module can already be imported.
    :param quiet: suppress this function's diagnostics and pass ``quiet`` on
        to ``cythonize``.
    :param locals, globals: namespaces used to resolve unbound names. When
        None, they are taken from the frame two levels above this call.
        NOTE(review): this presumably assumes a thin wrapper sits between the
        user and this function; confirm before calling it directly.
    :returns: the return value of the generated ``__invoke`` function.
    """
    if get_type is None:
        get_type = lambda x: 'object'
    code = to_unicode(code)
    orig_code = code
    code, literals = strip_string_literals(code)
    code = strip_common_indent(code)
    ctx = Context(cython_include_dirs, default_options)
    if locals is None or globals is None:
        caller = inspect.currentframe().f_back.f_back
        if locals is None:
            locals = caller.f_locals
        if globals is None:
            globals = caller.f_globals
    try:
        for symbol in unbound_symbols(code):
            if symbol in kwds:
                continue
            elif symbol in locals:
                kwds[symbol] = locals[symbol]
            elif symbol in globals:
                kwds[symbol] = globals[symbol]
            elif not quiet:
                print("Couldn't find %s" % symbol)
    except AssertionError:
        if not quiet:
            # Parsing from strings not fully supported (e.g. cimports).
            print("Could not parse code as a string (to extract unbound symbols).")
    # sorted() works on Python 2 and 3, unlike dict.keys().sort().
    arg_names = sorted(kwds)
    arg_sigs = tuple([(get_type(kwds[arg], ctx), arg) for arg in arg_names])
    key = orig_code, arg_sigs, sys.version_info, sys.executable, Cython.__version__
    module_name = "_cython_inline_" + hashlib.md5(str(key).encode('utf-8')).hexdigest()
    try:
        if not os.path.exists(lib_dir):
            os.makedirs(lib_dir)
        if lib_dir not in sys.path:
            sys.path.append(lib_dir)
        if force:
            raise ImportError
        else:
            __import__(module_name)
    except ImportError:
        cflags = []
        c_include_dirs = []
        cimports = []
        # Qualified types such as 'numpy.ndarray[...]' require a cimport of
        # their leading module path.
        qualified = re.compile(r'([.\w]+)[.]')
        for arg_type, _ in arg_sigs:
            m = qualified.match(arg_type)
            if m:
                cimports.append('\ncimport %s' % m.groups()[0])
                # one special case
                if m.groups()[0] == 'numpy':
                    import numpy
                    c_include_dirs.append(numpy.get_include())
        module_body, func_body = extract_func_code(code)
        params = ', '.join(['%s %s' % a for a in arg_sigs])
        module_code = """
%(module_body)s
%(cimports)s
def __invoke(%(params)s):
%(func_body)s
        """ % {'cimports': '\n'.join(cimports), 'module_body': module_body, 'params': params, 'func_body': func_body }
        # Put the string literals removed by strip_string_literals back in.
        # Distinct loop names keep the cache ``key`` above from being clobbered.
        for literal_key, literal_value in literals.items():
            module_code = module_code.replace(literal_key, literal_value)
        pyx_file = os.path.join(lib_dir, module_name + '.pyx')
        fh = open(pyx_file, 'w')
        try:
            fh.write(module_code)
        finally:
            fh.close()
        extension = Extension(
            name = module_name,
            sources = [pyx_file],
            include_dirs = c_include_dirs,
            extra_compile_args = cflags)
        build_extension = build_ext(Distribution())
        build_extension.finalize_options()
        build_extension.extensions = cythonize([extension], ctx=ctx, quiet=quiet)
        build_extension.build_temp = os.path.dirname(pyx_file)
        build_extension.build_lib  = lib_dir
        build_extension.run()
        _code_cache[key] = module_name
    arg_list = [kwds[arg] for arg in arg_names]
    return __import__(module_name).__invoke(*arg_list)
Robert Bradshaw's avatar
Robert Bradshaw committed
186

187
non_space = re.compile('[^ ]')
def strip_common_indent(code):
    """Remove the smallest leading-space indent shared by the lines of ``code``.

    Blank lines, and lines whose first non-space character is '#', are
    ignored when computing the common indent and are left unchanged in the
    output. Only spaces count as indentation.
    """
    lines = code.split('\n')
    min_indent = None
    for line in lines:
        match = non_space.search(line)
        if not match:
            continue  # blank
        indent = match.start()
        if line[indent] == '#':
            continue  # comment
        if min_indent is None or indent < min_indent:
            min_indent = indent
    for ix, line in enumerate(lines):
        match = non_space.search(line)
        # Test this line's own first non-space character, not the indent
        # left over from the loop above.
        if not match or line[match.start()] == '#':
            continue
        lines[ix] = line[min_indent:]
    return '\n'.join(lines)
207

Robert Bradshaw's avatar
Robert Bradshaw committed
208 209 210 211 212
module_statement = re.compile(r'^((cdef +(extern|class))|cimport|(from .+ cimport)|(from .+ import +[*]))')
def extract_func_code(code):
    """Split ``code`` into module-level statements and a function body.

    A column-0 line that matches ``module_statement`` (``cdef extern``,
    ``cdef class``, ``cimport``, ``from ... cimport`` or ``from ... import *``)
    starts a module section. Any other column-0 line starts a function
    section. Indented lines and empty lines stay in the current section, so
    blank lines inside e.g. a cdef class body do not split it. Tabs are first
    replaced by single spaces.

    :returns: ``(module_code, function_body)``; the function body is
        re-indented by four spaces.
    """
    module = []
    function = []
    current = function
    lines = code.replace('\t', ' ').split('\n')
    for line in lines:
        if line and not line.startswith(' '):
            if module_statement.match(line):
                current = module
            else:
                current = function
        current.append(line)
    return '\n'.join(module), '    ' + '\n    '.join(function)
Robert Bradshaw's avatar
Robert Bradshaw committed
223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262



try:
    from inspect import getcallargs
except ImportError:
    def getcallargs(func, *arg_values, **kwd_values):
        """Fallback for Pythons whose inspect module lacks getcallargs.

        Bind positional and keyword argument values to the parameter names
        of ``func``, fill in defaults, and return the resulting name->value
        dict. Surplus positionals go to ``*varargs`` and unknown keywords to
        ``**kwds`` when ``func`` declares them.

        :raises TypeError: on duplicate, unexpected, surplus or missing
            arguments.
        """
        call_args = {}
        args, varargs, kwds, defaults = inspect.getargspec(func)
        if varargs is not None:
            call_args[varargs] = arg_values[len(args):]
        elif len(arg_values) > len(args):
            raise TypeError("Too many positional arguments: %d given, at most %d expected"
                            % (len(arg_values), len(args)))
        for name, value in zip(args, arg_values):
            call_args[name] = value
        # Iterate over a snapshot because matched names are popped below.
        for name, value in list(kwd_values.items()):
            if name in args:
                if name in call_args:
                    raise TypeError("Duplicate argument %s" % name)
                call_args[name] = kwd_values.pop(name)
        if kwds is not None:
            call_args[kwds] = kwd_values
        elif kwd_values:
            raise TypeError("Unexpected keyword arguments: %s" % list(kwd_values))
        if defaults is None:
            defaults = ()
        first_default = len(args) - len(defaults)
        for ix, name in enumerate(args):
            if name not in call_args:
                if ix >= first_default:
                    call_args[name] = defaults[ix - first_default]
                else:
                    raise TypeError("Missing argument: %s" % name)
        return call_args

def get_body(source):
    """Return the body of a function from its source text.

    Everything after the first ':' is returned. If the source (ignoring
    leading whitespace) is a lambda expression, its expression body is
    turned into a ``return`` statement.

    NOTE(review): the first ':' can belong to a decorator, an annotation or a
    default value (e.g. a dict literal), not the ``def`` line's colon.
    """
    ix = source.index(':')
    # 'lambda' has six characters; the old source[:5] comparison never matched.
    if source.lstrip().startswith('lambda'):
        return "return %s" % source[ix+1:]
    else:
        return source[ix+1:]

263
# Lots to be done here... It would be especially cool if compiled functions
# could invoke each other quickly.
class RuntimeCompiledFunction(object):
    """Callable wrapper that runs a Python function's body through cython_inline.

    The body text is extracted from the wrapped function's source once, at
    construction time. Each call binds the call arguments to the function's
    parameter names and compiles/invokes that body with them. The function's
    module globals are used to resolve any other names.
    """

    def __init__(self, f):
        self._f = f
        self._body = get_body(inspect.getsource(f))

    def __call__(self, *args, **kwds):
        call_args = getcallargs(self._f, *args, **kwds)
        # Python 3 only has __globals__; Python 2 also has func_globals.
        try:
            f_globals = self._f.__globals__
        except AttributeError:
            f_globals = self._f.func_globals
        # NOTE(review): parameters named like cython_inline's own keywords
        # (code, locals, globals, force, quiet, ...) would collide here.
        return cython_inline(self._body, locals=f_globals, globals=f_globals, **call_args)