import tempfile
import sys, os, re, inspect
from cython import set

try:
    import hashlib
except ImportError:
    import md5 as hashlib

from distutils.core import Distribution, Extension
from distutils.command.build_ext import build_ext

import Cython
from Cython.Compiler.Main import Context, CompilationOptions, default_options

from Cython.Compiler.ParseTreeTransforms import CythonTransform, SkipDeclarations, AnalyseDeclarationsTransform
from Cython.Compiler.TreeFragment import parse_from_strings
from Cython.Build.Dependencies import strip_string_literals, cythonize

# A utility function to convert user-supplied ASCII strings to unicode.
if sys.version_info[0] >= 3:
    # Nothing to convert on Python 3: the value is handed back untouched.
    to_unicode = lambda x: x
else:
    def to_unicode(s):
        """Return ``s`` as a unicode object, decoding byte strings as ASCII."""
        if isinstance(s, unicode):
            return s
        return s.decode('ascii')

# Module-level dict that cython_inline() stores generated module names in
# after a successful build.  Nothing in the code visible here reads it back.
_code_cache = {}


class AllSymbols(CythonTransform, SkipDeclarations):
    """Tree visitor that records the name of every NameNode it visits.

    After an instance has been applied to a tree, ``names`` holds the set of
    collected identifiers.
    """
    def __init__(self):
        # The base initializer is given None as its only argument
        # (presumably meaning "no compilation context" -- TODO confirm).
        CythonTransform.__init__(self, None)
        self.names = set()
    def visit_NameNode(self, node):
        # Remember the identifier; the method itself returns nothing.
        self.names.add(node.name)

def unbound_symbols(code, context=None):
    """Return the names used in ``code`` that it does not bind itself.

    The code is parsed as a tree fragment and run through the compilation
    pipeline up to and including declaration analysis.  Every name collected
    from the resulting tree is then reported unless the tree's scope can
    resolve it or it is an attribute of the builtins module.

    code    -- source text of the snippet (passed through to_unicode()).
    context -- compilation Context to use; a default one with no include
               directories is created when omitted.

    Returns a list of name strings (order follows set iteration and is
    therefore unspecified).
    """
    code = to_unicode(code)
    if context is None:
        context = Context([], default_options)
    from Cython.Compiler.ParseTreeTransforms import AnalyseDeclarationsTransform
    tree = parse_from_strings('(tree fragment)', code)
    # Stop right after declaration analysis: later phases are not needed to
    # decide which names are bound.
    for phase in context.create_pipeline(pxd=False):
        if phase is None:
            continue
        tree = phase(tree)
        if isinstance(phase, AnalyseDeclarationsTransform):
            break
    symbol_collector = AllSymbols()
    symbol_collector(tree)
    # The builtins module is called ``builtins`` on Python 3 and
    # ``__builtin__`` on Python 2; accept whichever is available.
    try:
        import builtins
    except ImportError:
        import __builtin__ as builtins
    unbound = []
    for name in symbol_collector.names:
        if not tree.scope.lookup(name) and not hasattr(builtins, name):
            unbound.append(name)
    return unbound

def unsafe_type(arg, context=None):
    """Pick a C type name for ``arg``, mapping plain ints to ``'long'``.

    Only objects whose exact type is ``int`` get ``'long'``; everything else
    is delegated to safe_type() together with ``context``.
    """
    if type(arg) is int:
        return 'long'
    return safe_type(arg, context)

def safe_type(arg, context=None):
    """Return a type declaration string for ``arg``.

    Exact ``list``/``tuple``/``dict``/``str`` instances map to their own type
    name, ``complex`` to ``'double complex'``, ``float`` to ``'double'`` and
    ``bool`` to ``'bint'``.  When numpy has already been imported, ndarray
    instances map to a buffer declaration built from their dtype name and
    ndim.  For anything else the type's MRO is walked and the first base that
    ``context`` can resolve to a type entry is returned as ``module.name``.

    arg     -- the runtime value to describe.
    context -- object providing ``find_module(name, need_pxd=False)``; when it
               is None no module lookup is attempted.

    Returns ``'object'`` when no more specific type can be determined.
    """
    py_type = type(arg)
    if py_type in [list, tuple, dict, str]:
        return py_type.__name__
    elif py_type is complex:
        return 'double complex'
    elif py_type is float:
        return 'double'
    elif py_type is bool:
        return 'bint'
    elif 'numpy' in sys.modules and isinstance(arg, sys.modules['numpy'].ndarray):
        return 'numpy.ndarray[numpy.%s_t, ndim=%s]' % (arg.dtype.name, arg.ndim)
    else:
        for base_type in py_type.mro():
            # Builtin types live in '__builtin__' (Python 2) or 'builtins'
            # (Python 3); once one is reached nothing more specific exists.
            if base_type.__module__ in ('__builtin__', 'builtins'):
                return 'object'
            if context is None:
                # Without a context there is nothing to look the module up in.
                continue
            module = context.find_module(base_type.__module__, need_pxd=False)
            if module:
                entry = module.lookup(base_type.__name__)
                # lookup() may find nothing for this name.
                if entry is not None and entry.is_type:
                    return '%s.%s' % (base_type.__module__, base_type.__name__)
        return 'object'

def cython_inline(code,
                  get_type=unsafe_type,
                  lib_dir=os.path.expanduser('~/.cython/inline'),
                  cython_include_dirs=['.'],
                  force=False,
                  quiet=False,
                  locals=None,
                  globals=None,
                  **kwds):
    """Compile ``code`` into an extension module, call it and return the result.

    The snippet is wrapped in a generated ``__invoke`` function whose
    parameters are the names the snippet uses but does not bind.  The module
    name is derived from a hash of the code, the argument signatures, the
    interpreter and the Cython version, so an already-built module in
    ``lib_dir`` is imported instead of being rebuilt.

    code                -- source text of the snippet.
    get_type            -- callable ``(value, context)`` returning the declared
                           type for an argument; None declares everything as
                           ``object``.
    lib_dir             -- directory for generated sources and built modules;
                           created if missing and appended to ``sys.path``.
    cython_include_dirs -- include directories for the compilation Context.
    force               -- rebuild even if the module can already be imported.
    quiet               -- suppress the parse-failure notice and is passed on
                           to cythonize().
    locals, globals     -- mappings searched (in that order) for unbound
                           names; default to the frame two levels up the
                           call stack.
    **kwds              -- explicit argument values; these take precedence
                           over ``locals`` and ``globals``.

    Returns whatever the generated ``__invoke`` function returns.
    """
    if get_type is None:
        # get_type is always called with (value, context) below, so the
        # fallback has to accept the second argument as well.
        get_type = lambda x, context=None: 'object'
    code = to_unicode(code)
    code, literals = strip_string_literals(code)
    code = strip_common_indent(code)
    ctx = Context(cython_include_dirs, default_options)
    # NOTE(review): two frames up, not one -- presumably this function is
    # normally reached through a thin wrapper; confirm against callers.
    if locals is None:
        locals = inspect.currentframe().f_back.f_back.f_locals
    if globals is None:
        globals = inspect.currentframe().f_back.f_back.f_globals
    try:
        for symbol in unbound_symbols(code):
            if symbol in kwds:
                continue
            elif symbol in locals:
                kwds[symbol] = locals[symbol]
            elif symbol in globals:
                kwds[symbol] = globals[symbol]
            else:
                # Single formatted string: print("...", x) would print a
                # tuple on Python 2.
                print("Couldn't find %r" % symbol)
    except AssertionError:
        if not quiet:
            # Parsing from strings not fully supported (e.g. cimports).
            print("Could not parse code as a string (to extract unbound symbols).")
    # Sorted argument names keep the signature (and thus the hash) stable.
    arg_names = list(kwds)
    arg_names.sort()
    arg_sigs = tuple([(get_type(kwds[arg], ctx), arg) for arg in arg_names])
    key = code, arg_sigs, sys.version_info, sys.executable, Cython.__version__
    # md5 needs bytes on Python 3; str(key) is repr-based and encodes cleanly.
    module_name = "_cython_inline_" + hashlib.md5(str(key).encode('utf-8')).hexdigest()
    try:
        if not os.path.exists(lib_dir):
            os.makedirs(lib_dir)
        if lib_dir not in sys.path:
            sys.path.append(lib_dir)
        if force:
            raise ImportError
        else:
            __import__(module_name)
    except ImportError:
        # Not built yet (or a rebuild was forced): generate and compile it.
        cflags = []
        c_include_dirs = []
        cimports = []
        qualified = re.compile(r'([.\w]+)[.]')
        for sig_type, _ in arg_sigs:
            m = qualified.match(sig_type)
            if m:
                # A dotted type name needs its module cimported.
                cimports.append('\ncimport %s' % m.groups()[0])
                # one special case
                if m.groups()[0] == 'numpy':
                    import numpy
                    c_include_dirs.append(numpy.get_include())
                    cflags.append('-Wno-unused')
        module_body, func_body = extract_func_code(code)
        params = ', '.join(['%s %s' % a for a in arg_sigs])
        module_code = """
%(module_body)s
%(cimports)s
def __invoke(%(params)s):
%(func_body)s
        """ % {'cimports': '\n'.join(cimports), 'module_body': module_body, 'params': params, 'func_body': func_body }
        # Put the original string literals back in place of their labels.
        # Distinct loop names: reusing ``key`` here would overwrite the
        # cache key that is stored in _code_cache below.
        for label, literal in literals.items():
            module_code = module_code.replace(label, literal)
        pyx_file = os.path.join(lib_dir, module_name + '.pyx')
        # Close the file explicitly so the source is fully written before
        # cythonize() reads it.
        pyx_handle = open(pyx_file, 'w')
        try:
            pyx_handle.write(module_code)
        finally:
            pyx_handle.close()
        extension = Extension(
            name = module_name,
            sources = [pyx_file],
            include_dirs = c_include_dirs,
            extra_compile_args = cflags)
        build_extension = build_ext(Distribution())
        build_extension.finalize_options()
        build_extension.extensions = cythonize([extension], ctx=ctx, quiet=quiet)
        build_extension.build_temp = os.path.dirname(pyx_file)
        build_extension.build_lib  = lib_dir
        build_extension.run()
        _code_cache[key] = module_name
    arg_list = [kwds[arg] for arg in arg_names]
    return __import__(module_name).__invoke(*arg_list)

non_space = re.compile('[^ ]')
def strip_common_indent(code):
    """Remove the smallest indentation shared by the code lines of ``code``.

    Indentation is measured in leading space characters only.  Blank lines
    and lines whose first non-space character is ``#`` neither contribute to
    the minimum nor get modified.

    code -- source text with lines separated by ``'\\n'``.

    Returns the text with the common indent removed from every code line.
    """
    min_indent = None
    lines = code.split('\n')
    for line in lines:
        match = non_space.search(line)
        if not match:
            continue # blank
        indent = match.start()
        if line[indent] == '#':
            continue # comment
        elif min_indent is None or min_indent > indent:
            min_indent = indent
    for ix, line in enumerate(lines):
        match = non_space.search(line)
        # Test this line's own first non-space character; reusing the
        # position left over from the loop above misclassifies lines and can
        # index past the end of a short line.
        if not match or line[match.start()] == '#':
            continue
        # A code line exists here, so min_indent was set in the first pass.
        lines[ix] = line[min_indent:]
    return '\n'.join(lines)

module_statement = re.compile(r'^((cdef +(extern|class))|cimport|(from .+ cimport)|(from .+ import +[*]))')
def extract_func_code(code):
    """Split ``code`` into module-level text and an indented function body.

    Tabs are first replaced by single spaces.  Each line that does not start
    with a space decides where it and the indented lines after it go: to the
    module part if it matches ``module_statement``, otherwise to the function
    part.  Lines before the first such decision go to the function part.

    Returns a ``(module_text, function_text)`` tuple; every function line is
    prefixed with four spaces.
    """
    module_lines = []
    function_lines = []
    target = function_lines
    for text in code.replace('\t', ' ').split('\n'):
        if text.startswith(' '):
            # Indented continuation: stays with the current destination.
            pass
        elif module_statement.match(text):
            target = module_lines
        else:
            target = function_lines
        target.append(text)
    indent = '    '
    return '\n'.join(module_lines), indent + ('\n' + indent).join(function_lines)



try:
    from inspect import getcallargs
except ImportError:
    # inspect.getcallargs is not available; provide a minimal stand-in.
    def getcallargs(func, *arg_values, **kwd_values):
        """Map the given call arguments onto the parameter names of ``func``.

        func        -- the function whose signature is used.
        *arg_values -- positional arguments of the call.
        **kwd_values -- keyword arguments of the call.

        Returns a dict from parameter name to bound value, including the
        ``*args`` / ``**kwargs`` parameters when ``func`` declares them.

        Raises TypeError for duplicate, unexpected, surplus or missing
        arguments.
        """
        bound = {}
        args, varargs, kwds, defaults = inspect.getargspec(func)
        if varargs is not None:
            bound[varargs] = arg_values[len(args):]
        elif len(arg_values) > len(args):
            # No *args parameter to absorb the surplus: reject the call
            # rather than silently dropping values.
            raise TypeError("Too many positional arguments: expected at most %s, got %s"
                            % (len(args), len(arg_values)))
        for name, value in zip(args, arg_values):
            bound[name] = value
        # Iterate over a snapshot because matched entries are popped from
        # kwd_values inside the loop.
        for name, value in list(kwd_values.items()):
            if name in args:
                if name in bound:
                    raise TypeError("Duplicate argument %s" % name)
                bound[name] = kwd_values.pop(name)
        if kwds is not None:
            bound[kwds] = kwd_values
        elif kwd_values:
            raise TypeError("Unexpected keyword arguments: %s" % list(kwd_values.keys()))
        if defaults is None:
            defaults = ()
        # Defaults belong to the trailing parameters.
        first_default = len(args) - len(defaults)
        for ix, name in enumerate(args):
            if name not in bound:
                if ix >= first_default:
                    bound[name] = defaults[ix - first_default]
                else:
                    raise TypeError("Missing argument: %s" % name)
        return bound

def get_body(source):
    """Return the part of a function's source text that follows its header.

    Everything after the first ``':'`` in ``source`` is returned.  When the
    source starts with ``lambda`` the remainder is an expression, so it is
    turned into a ``return`` statement.

    source -- source text of a ``def`` or ``lambda``.

    Raises ValueError if ``source`` contains no ``':'``.
    """
    ix = source.index(':')
    # 'lambda' is six characters long; comparing against a five-character
    # slice could never match.
    if source.startswith('lambda'):
        return "return %s" % source[ix+1:]
    else:
        return source[ix+1:]

# Lots to be done here... It would be especially cool if compiled functions 
# could invoke each other quickly.
class RuntimeCompiledFunction(object):
    """Callable wrapper that runs a function's body through cython_inline().

    The body is extracted once from the function's source when the wrapper is
    created; each call binds the arguments to the function's parameter names
    and hands them to cython_inline() as keyword values.
    """

    def __init__(self, f):
        """Remember ``f`` and extract its body from ``inspect.getsource(f)``."""
        self._f = f
        self._body = get_body(inspect.getsource(f))

    def __call__(self, *args, **kwds):
        """Bind the call arguments and execute the compiled body."""
        call_args = getcallargs(self._f, *args, **kwds)
        # Functions expose their globals as ``__globals__``; the Python 2
        # only spelling ``func_globals`` is kept as a fallback.
        try:
            func_globals = self._f.__globals__
        except AttributeError:
            func_globals = self._f.func_globals
        # The function's globals serve as both lookup namespaces.
        return cython_inline(self._body, locals=func_globals, globals=func_globals, **call_args)