Commit a17fac09 authored by Mark Florisson's avatar Mark Florisson

merge

parents fe2191ac fd236aa1
...@@ -13,17 +13,27 @@ import sys ...@@ -13,17 +13,27 @@ import sys
import os import os
from distutils import sysconfig from distutils import sysconfig
INCDIR = sysconfig.get_python_inc() def get_config_var(name):
LIBDIR1 = sysconfig.get_config_var('LIBDIR') return sysconfig.get_config_var(name) or ''
LIBDIR2 = sysconfig.get_config_var('LIBPL')
PYLIB = sysconfig.get_config_var('LIBRARY')[3:-2]
CC = sysconfig.get_config_var('CC') INCDIR = sysconfig.get_python_inc()
CFLAGS = sysconfig.get_config_var('CFLAGS') + ' ' + os.environ.get('CFLAGS', '') LIBDIR1 = get_config_var('LIBDIR')
LINKCC = sysconfig.get_config_var('LINKCC') LIBDIR2 = get_config_var('LIBPL')
LINKFORSHARED = sysconfig.get_config_var('LINKFORSHARED') PYLIB = get_config_var('LIBRARY')
LIBS = sysconfig.get_config_var('LIBS') PYLIB_DYN = get_config_var('LDLIBRARY')
SYSLIBS = sysconfig.get_config_var('SYSLIBS') if PYLIB_DYN == PYLIB:
# no shared library
PYLIB_DYN = ''
else:
PYLIB_DYN = os.path.splitext(PYLIB_DYN[3:])[0] # 'lib(XYZ).so' -> XYZ
CC = get_config_var('CC')
CFLAGS = get_config_var('CFLAGS') + ' ' + os.environ.get('CFLAGS', '')
LINKCC = get_config_var('LINKCC')
LINKFORSHARED = get_config_var('LINKFORSHARED')
LIBS = get_config_var('LIBS')
SYSLIBS = get_config_var('SYSLIBS')
EXE_EXT = sysconfig.get_config_var('EXE')
def _debug(msg, *args): def _debug(msg, *args):
if DEBUG: if DEBUG:
...@@ -36,12 +46,14 @@ def dump_config(): ...@@ -36,12 +46,14 @@ def dump_config():
_debug('LIBDIR1: %s', LIBDIR1) _debug('LIBDIR1: %s', LIBDIR1)
_debug('LIBDIR2: %s', LIBDIR2) _debug('LIBDIR2: %s', LIBDIR2)
_debug('PYLIB: %s', PYLIB) _debug('PYLIB: %s', PYLIB)
_debug('PYLIB_DYN: %s', PYLIB_DYN)
_debug('CC: %s', CC) _debug('CC: %s', CC)
_debug('CFLAGS: %s', CFLAGS) _debug('CFLAGS: %s', CFLAGS)
_debug('LINKCC: %s', LINKCC) _debug('LINKCC: %s', LINKCC)
_debug('LINKFORSHARED: %s', LINKFORSHARED) _debug('LINKFORSHARED: %s', LINKFORSHARED)
_debug('LIBS: %s', LIBS) _debug('LIBS: %s', LIBS)
_debug('SYSLIBS: %s', SYSLIBS) _debug('SYSLIBS: %s', SYSLIBS)
_debug('EXE_EXT: %s', EXE_EXT)
def runcmd(cmd, shell=True): def runcmd(cmd, shell=True):
if shell: if shell:
...@@ -61,7 +73,8 @@ def runcmd(cmd, shell=True): ...@@ -61,7 +73,8 @@ def runcmd(cmd, shell=True):
sys.exit(returncode) sys.exit(returncode)
def clink(basename): def clink(basename):
runcmd([LINKCC, '-o', basename, basename+'.o', '-L'+LIBDIR1, '-L'+LIBDIR2, '-l'+PYLIB] runcmd([LINKCC, '-o', basename + EXE_EXT, basename+'.o', '-L'+LIBDIR1, '-L'+LIBDIR2]
+ [PYLIB_DYN and ('-l'+PYLIB_DYN) or os.path.join(LIBDIR1, PYLIB)]
+ LIBS.split() + SYSLIBS.split() + LINKFORSHARED.split()) + LIBS.split() + SYSLIBS.split() + LINKFORSHARED.split())
def ccompile(basename): def ccompile(basename):
...@@ -75,8 +88,8 @@ def cycompile(input_file, options=()): ...@@ -75,8 +88,8 @@ def cycompile(input_file, options=()):
if result.num_errors > 0: if result.num_errors > 0:
sys.exit(1) sys.exit(1)
def exec_file(basename, args=()): def exec_file(program_name, args=()):
runcmd([os.path.abspath(basename)] + list(args), shell=False) runcmd([os.path.abspath(program_name)] + list(args), shell=False)
def build(input_file, compiler_args=()): def build(input_file, compiler_args=()):
""" """
...@@ -88,7 +101,7 @@ def build(input_file, compiler_args=()): ...@@ -88,7 +101,7 @@ def build(input_file, compiler_args=()):
cycompile(input_file, compiler_args) cycompile(input_file, compiler_args)
ccompile(basename) ccompile(basename)
clink(basename) clink(basename)
return basename return basename + EXE_EXT
def build_and_run(args): def build_and_run(args):
""" """
...@@ -114,3 +127,6 @@ def build_and_run(args): ...@@ -114,3 +127,6 @@ def build_and_run(args):
program_name = build(input_file, cy_args) program_name = build(input_file, cy_args)
exec_file(program_name, args) exec_file(program_name, args)
if __name__ == '__main__':
build_and_run(sys.argv[1:])
...@@ -95,6 +95,54 @@ static CYTHON_INLINE int __Pyx_HasAttr(PyObject *o, PyObject *n) { ...@@ -95,6 +95,54 @@ static CYTHON_INLINE int __Pyx_HasAttr(PyObject *o, PyObject *n) {
} }
""") """)
globals_utility_code = UtilityCode(
# This is a stub implementation until we have something more complete.
# Currently, we only handle the most common case of a read-only dict
# of Python names. Supporting cdef names in the module and write
# access requires a rewrite as a dedicated class.
proto = """
static PyObject* __Pyx_Globals(); /*proto*/
""",
impl = '''
static PyObject* __Pyx_Globals() {
Py_ssize_t i;
/*PyObject *d;*/
PyObject *names = NULL;
PyObject *globals = PyObject_GetAttrString(%(MODULE)s, "__dict__");
if (!globals) {
PyErr_SetString(PyExc_TypeError,
"current module must have __dict__ attribute");
goto bad;
}
names = PyObject_Dir(%(MODULE)s);
if (!names)
goto bad;
for (i = 0; i < PyList_GET_SIZE(names); i++) {
PyObject* name = PyList_GET_ITEM(names, i);
if (!PyDict_Contains(globals, name)) {
PyObject* value = PyObject_GetAttr(%(MODULE)s, PyList_GET_ITEM(names, i));
if (!value)
goto bad;
if (PyDict_SetItem(globals, name, value) < 0) {
Py_DECREF(value);
goto bad;
}
}
}
Py_DECREF(names);
return globals;
/*
d = PyDictProxy_New(globals);
Py_DECREF(globals);
return d;
*/
bad:
Py_XDECREF(names);
Py_XDECREF(globals);
return NULL;
}
''' % {'MODULE' : Naming.module_cname})
pyexec_utility_code = UtilityCode( pyexec_utility_code = UtilityCode(
proto = """ proto = """
#if PY_VERSION_HEX < 0x02040000 #if PY_VERSION_HEX < 0x02040000
...@@ -384,6 +432,8 @@ builtin_function_table = [ ...@@ -384,6 +432,8 @@ builtin_function_table = [
utility_code = getattr3_utility_code), utility_code = getattr3_utility_code),
BuiltinFunction('getattr3', "OOO", "O", "__Pyx_GetAttr3", "getattr", BuiltinFunction('getattr3', "OOO", "O", "__Pyx_GetAttr3", "getattr",
utility_code = getattr3_utility_code), # Pyrex compatibility utility_code = getattr3_utility_code), # Pyrex compatibility
BuiltinFunction('globals', "", "O", "__Pyx_Globals",
utility_code = globals_utility_code),
BuiltinFunction('hasattr', "OO", "b", "__Pyx_HasAttr", BuiltinFunction('hasattr', "OO", "b", "__Pyx_HasAttr",
utility_code = hasattr_utility_code), utility_code = hasattr_utility_code),
BuiltinFunction('hash', "O", "h", "PyObject_Hash"), BuiltinFunction('hash', "O", "h", "PyObject_Hash"),
......
...@@ -163,15 +163,7 @@ def parse_command_line(args): ...@@ -163,15 +163,7 @@ def parse_command_line(args):
sys.stderr.write("Unknown compiler flag: %s\n" % option) sys.stderr.write("Unknown compiler flag: %s\n" % option)
sys.exit(1) sys.exit(1)
else: else:
arg = pop_arg() sources.append(pop_arg())
if arg.endswith(".pyx"):
sources.append(arg)
elif arg.endswith(".py"):
# maybe do some other stuff, but this should work for now
sources.append(arg)
else:
sys.stderr.write(
"cython: %s: Unknown filename suffix\n" % arg)
if options.use_listing_file and len(sources) > 1: if options.use_listing_file and len(sources) > 1:
sys.stderr.write( sys.stderr.write(
"cython: Only one source file allowed when using -o\n") "cython: Only one source file allowed when using -o\n")
......
...@@ -59,9 +59,10 @@ cdef class StringConst: ...@@ -59,9 +59,10 @@ cdef class StringConst:
cdef public object text cdef public object text
cdef public object escaped_value cdef public object escaped_value
cdef public dict py_strings cdef public dict py_strings
cdef public list py_versions
@cython.locals(intern=bint, is_str=bint, is_unicode=bint) @cython.locals(intern=bint, is_str=bint, is_unicode=bint)
cpdef get_py_string_const(self, encoding, identifier=*, is_str=*) cpdef get_py_string_const(self, encoding, identifier=*, is_str=*, py3str_cstring=*)
## cdef class PyStringConst: ## cdef class PyStringConst:
## cdef public object cname ## cdef public object cname
......
...@@ -25,12 +25,19 @@ except ImportError: ...@@ -25,12 +25,19 @@ except ImportError:
non_portable_builtins_map = { non_portable_builtins_map = {
# builtins that have different names in different Python versions
'bytes' : ('PY_MAJOR_VERSION < 3', 'str'), 'bytes' : ('PY_MAJOR_VERSION < 3', 'str'),
'unicode' : ('PY_MAJOR_VERSION >= 3', 'str'), 'unicode' : ('PY_MAJOR_VERSION >= 3', 'str'),
'xrange' : ('PY_MAJOR_VERSION >= 3', 'range'), 'xrange' : ('PY_MAJOR_VERSION >= 3', 'range'),
'BaseException' : ('PY_VERSION_HEX < 0x02050000', 'Exception'), 'BaseException' : ('PY_VERSION_HEX < 0x02050000', 'Exception'),
} }
uncachable_builtins = [
# builtin names that cannot be cached because they may or may not
# be available at import time
'WindowsError',
]
class UtilityCode(object): class UtilityCode(object):
# Stores utility code to add during code generation. # Stores utility code to add during code generation.
# #
...@@ -324,8 +331,16 @@ class StringConst(object): ...@@ -324,8 +331,16 @@ class StringConst(object):
self.text = text self.text = text
self.escaped_value = StringEncoding.escape_byte_string(byte_string) self.escaped_value = StringEncoding.escape_byte_string(byte_string)
self.py_strings = None self.py_strings = None
self.py_versions = []
def add_py_version(self, version):
if not version:
self.py_versions = [2,3]
elif version not in self.py_versions:
self.py_versions.append(version)
def get_py_string_const(self, encoding, identifier=None, is_str=False): def get_py_string_const(self, encoding, identifier=None,
is_str=False, py3str_cstring=None):
py_strings = self.py_strings py_strings = self.py_strings
text = self.text text = self.text
...@@ -344,47 +359,52 @@ class StringConst(object): ...@@ -344,47 +359,52 @@ class StringConst(object):
else: else:
encoding_key = ''.join(find_alphanums(encoding)) encoding_key = ''.join(find_alphanums(encoding))
key = (is_str, is_unicode, encoding_key) key = (is_str, is_unicode, encoding_key, py3str_cstring)
if py_strings is not None and key in py_strings: if py_strings is not None:
py_string = py_strings[key] try:
return py_strings[key]
except KeyError:
pass
else: else:
if py_strings is None: self.py_strings = {}
self.py_strings = {}
if identifier:
intern = True
elif identifier is None:
if isinstance(text, unicode):
intern = bool(possible_unicode_identifier(text))
else:
intern = bool(possible_bytes_identifier(text))
else:
intern = False
if intern:
prefix = Naming.interned_str_prefix
else:
prefix = Naming.py_const_prefix
pystring_cname = "%s%s_%s" % (
prefix,
(is_str and 's') or (is_unicode and 'u') or 'b',
self.cname[len(Naming.const_prefix):])
py_string = PyStringConst(
pystring_cname, encoding, is_unicode, is_str, intern)
self.py_strings[key] = py_string
if identifier:
intern = True
elif identifier is None:
if isinstance(text, unicode):
intern = bool(possible_unicode_identifier(text))
else:
intern = bool(possible_bytes_identifier(text))
else:
intern = False
if intern:
prefix = Naming.interned_str_prefix
else:
prefix = Naming.py_const_prefix
pystring_cname = "%s%s_%s" % (
prefix,
(is_str and 's') or (is_unicode and 'u') or 'b',
self.cname[len(Naming.const_prefix):])
py_string = PyStringConst(
pystring_cname, encoding, is_unicode, is_str, py3str_cstring, intern)
self.py_strings[key] = py_string
return py_string return py_string
class PyStringConst(object): class PyStringConst(object):
"""Global info about a Python string constant held by GlobalState. """Global info about a Python string constant held by GlobalState.
""" """
# cname string # cname string
# py3str_cstring string
# encoding string # encoding string
# intern boolean # intern boolean
# is_unicode boolean # is_unicode boolean
# is_str boolean # is_str boolean
def __init__(self, cname, encoding, is_unicode, is_str=False, intern=False): def __init__(self, cname, encoding, is_unicode, is_str=False,
py3str_cstring=None, intern=False):
self.cname = cname self.cname = cname
self.py3str_cstring = py3str_cstring
self.encoding = encoding self.encoding = encoding
self.is_str = is_str self.is_str = is_str
self.is_unicode = is_unicode self.is_unicode = is_unicode
...@@ -595,7 +615,7 @@ class GlobalState(object): ...@@ -595,7 +615,7 @@ class GlobalState(object):
cleanup_writer.put_xdecref_clear(const.cname, type, nanny=False) cleanup_writer.put_xdecref_clear(const.cname, type, nanny=False)
return const return const
def get_string_const(self, text): def get_string_const(self, text, py_version=None):
# return a C string constant, creating a new one if necessary # return a C string constant, creating a new one if necessary
if text.is_unicode: if text.is_unicode:
byte_string = text.utf8encode() byte_string = text.utf8encode()
...@@ -605,12 +625,21 @@ class GlobalState(object): ...@@ -605,12 +625,21 @@ class GlobalState(object):
c = self.string_const_index[byte_string] c = self.string_const_index[byte_string]
except KeyError: except KeyError:
c = self.new_string_const(text, byte_string) c = self.new_string_const(text, byte_string)
c.add_py_version(py_version)
return c return c
def get_py_string_const(self, text, identifier=None, is_str=False): def get_py_string_const(self, text, identifier=None,
is_str=False, unicode_value=None):
# return a Python string constant, creating a new one if necessary # return a Python string constant, creating a new one if necessary
c_string = self.get_string_const(text) py3str_cstring = None
py_string = c_string.get_py_string_const(text.encoding, identifier, is_str) if is_str and unicode_value is not None \
and unicode_value.utf8encode() != text.byteencode():
py3str_cstring = self.get_string_const(unicode_value, py_version=3)
c_string = self.get_string_const(text, py_version=2)
else:
c_string = self.get_string_const(text)
py_string = c_string.get_py_string_const(
text.encoding, identifier, is_str, py3str_cstring)
return py_string return py_string
def get_interned_identifier(self, text): def get_interned_identifier(self, text):
...@@ -711,8 +740,15 @@ class GlobalState(object): ...@@ -711,8 +740,15 @@ class GlobalState(object):
decls_writer = self.parts['decls'] decls_writer = self.parts['decls']
for _, cname, c in c_consts: for _, cname, c in c_consts:
conditional = False
if c.py_versions and (2 not in c.py_versions or 3 not in c.py_versions):
conditional = True
decls_writer.putln("#if PY_MAJOR_VERSION %s 3" % (
(2 in c.py_versions) and '<' or '>='))
decls_writer.putln('static char %s[] = "%s";' % ( decls_writer.putln('static char %s[] = "%s";' % (
cname, StringEncoding.split_string_literal(c.escaped_value))) cname, StringEncoding.split_string_literal(c.escaped_value)))
if conditional:
decls_writer.putln("#endif")
if c.py_strings is not None: if c.py_strings is not None:
for py_string in c.py_strings.values(): for py_string in c.py_strings.values():
py_strings.append((c.cname, len(py_string.cname), py_string)) py_strings.append((c.cname, len(py_string.cname), py_string))
...@@ -736,6 +772,17 @@ class GlobalState(object): ...@@ -736,6 +772,17 @@ class GlobalState(object):
decls_writer.putln( decls_writer.putln(
"static PyObject *%s;" % py_string.cname) "static PyObject *%s;" % py_string.cname)
if py_string.py3str_cstring:
w.putln("#if PY_MAJOR_VERSION >= 3")
w.putln(
"{&%s, %s, sizeof(%s), %s, %d, %d, %d}," % (
py_string.cname,
py_string.py3str_cstring.cname,
py_string.py3str_cstring.cname,
'0', 1, 0,
py_string.intern
))
w.putln("#else")
w.putln( w.putln(
"{&%s, %s, sizeof(%s), %s, %d, %d, %d}," % ( "{&%s, %s, sizeof(%s), %s, %d, %d, %d}," % (
py_string.cname, py_string.cname,
...@@ -746,6 +793,8 @@ class GlobalState(object): ...@@ -746,6 +793,8 @@ class GlobalState(object):
py_string.is_str, py_string.is_str,
py_string.intern py_string.intern
)) ))
if py_string.py3str_cstring:
w.putln("#endif")
w.putln("{0, 0, 0, 0, 0, 0, 0}") w.putln("{0, 0, 0, 0, 0, 0, 0}")
w.putln("};") w.putln("};")
...@@ -1003,8 +1052,10 @@ class CCodeWriter(object): ...@@ -1003,8 +1052,10 @@ class CCodeWriter(object):
def get_string_const(self, text): def get_string_const(self, text):
return self.globalstate.get_string_const(text).cname return self.globalstate.get_string_const(text).cname
def get_py_string_const(self, text, identifier=None, is_str=False): def get_py_string_const(self, text, identifier=None,
return self.globalstate.get_py_string_const(text, identifier, is_str).cname is_str=False, unicode_value=None):
return self.globalstate.get_py_string_const(
text, identifier, is_str, unicode_value).cname
def get_argument_default_const(self, type): def get_argument_default_const(self, type):
return self.globalstate.get_py_const(type).cname return self.globalstate.get_py_const(type).cname
......
...@@ -1122,16 +1122,6 @@ class StringNode(PyConstNode): ...@@ -1122,16 +1122,6 @@ class StringNode(PyConstNode):
if not dst_type.is_pyobject: if not dst_type.is_pyobject:
return BytesNode(self.pos, value=self.value).coerce_to(dst_type, env) return BytesNode(self.pos, value=self.value).coerce_to(dst_type, env)
self.check_for_coercion_error(dst_type, fail=True) self.check_for_coercion_error(dst_type, fail=True)
# this will be a unicode string in Py3, so make sure we can decode it
if self.value.encoding and isinstance(self.value, StringEncoding.BytesLiteral):
try:
self.value.decode(self.value.encoding)
except UnicodeDecodeError:
error(self.pos, ("Decoding unprefixed string literal from '%s' failed. Consider using"
"a byte string or unicode string explicitly, "
"or adjust the source code encoding.") % self.value.encoding)
return self return self
def can_coerce_to_char_literal(self): def can_coerce_to_char_literal(self):
...@@ -1139,7 +1129,8 @@ class StringNode(PyConstNode): ...@@ -1139,7 +1129,8 @@ class StringNode(PyConstNode):
def generate_evaluation_code(self, code): def generate_evaluation_code(self, code):
self.result_code = code.get_py_string_const( self.result_code = code.get_py_string_const(
self.value, identifier=self.is_identifier, is_str=True) self.value, identifier=self.is_identifier, is_str=True,
unicode_value=self.unicode_value)
def get_constant_c_result_code(self): def get_constant_c_result_code(self):
return None return None
...@@ -1926,6 +1917,40 @@ class NextNode(AtomicExprNode): ...@@ -1926,6 +1917,40 @@ class NextNode(AtomicExprNode):
code.putln("}") code.putln("}")
class WithExitCallNode(ExprNode):
# The __exit__() call of a 'with' statement. Used in both the
# except and finally clauses.
# with_stat WithStatNode the surrounding 'with' statement
# args TupleNode or ResultStatNode the exception info tuple
subexprs = ['args']
def analyse_types(self, env):
self.args.analyse_types(env)
self.type = PyrexTypes.c_bint_type
self.is_temp = True
def generate_result_code(self, code):
if isinstance(self.args, TupleNode):
# call only if it was not already called (and decref-cleared)
code.putln("if (%s) {" % self.with_stat.exit_var)
result_var = code.funcstate.allocate_temp(py_object_type, manage_ref=False)
code.putln("%s = PyObject_Call(%s, %s, NULL);" % (
result_var,
self.with_stat.exit_var,
self.args.result()))
code.put_decref_clear(self.with_stat.exit_var, type=py_object_type)
code.putln(code.error_goto_if_null(result_var, self.pos))
code.put_gotref(result_var)
code.putln("%s = __Pyx_PyObject_IsTrue(%s);" % (self.result(), result_var))
code.put_decref_clear(result_var, type=py_object_type)
code.putln(code.error_goto_if_neg(self.result(), self.pos))
code.funcstate.release_temp(result_var)
if isinstance(self.args, TupleNode):
code.putln("}")
class ExcValueNode(AtomicExprNode): class ExcValueNode(AtomicExprNode):
# Node created during analyse_types phase # Node created during analyse_types phase
# of an ExceptClauseNode to fetch the current # of an ExceptClauseNode to fetch the current
...@@ -1960,7 +1985,7 @@ class TempNode(ExprNode): ...@@ -1960,7 +1985,7 @@ class TempNode(ExprNode):
subexprs = [] subexprs = []
def __init__(self, pos, type, env): def __init__(self, pos, type, env=None):
ExprNode.__init__(self, pos) ExprNode.__init__(self, pos)
self.type = type self.type = type
if type.is_pyobject: if type.is_pyobject:
...@@ -1970,6 +1995,9 @@ class TempNode(ExprNode): ...@@ -1970,6 +1995,9 @@ class TempNode(ExprNode):
def analyse_types(self, env): def analyse_types(self, env):
return self.type return self.type
def analyse_target_declaration(self, env):
pass
def generate_result_code(self, code): def generate_result_code(self, code):
pass pass
...@@ -6665,6 +6693,14 @@ class CmpNode(object): ...@@ -6665,6 +6693,14 @@ class CmpNode(object):
env.use_utility_code(pyunicode_equals_utility_code) env.use_utility_code(pyunicode_equals_utility_code)
self.special_bool_cmp_function = "__Pyx_PyUnicode_Equals" self.special_bool_cmp_function = "__Pyx_PyUnicode_Equals"
return True return True
elif type1 is Builtin.bytes_type or type2 is Builtin.bytes_type:
env.use_utility_code(pybytes_equals_utility_code)
self.special_bool_cmp_function = "__Pyx_PyBytes_Equals"
return True
elif type1 is Builtin.str_type or type2 is Builtin.str_type:
env.use_utility_code(pystr_equals_utility_code)
self.special_bool_cmp_function = "__Pyx_PyString_Equals"
return True
return False return False
def generate_operation_code(self, code, result_code, def generate_operation_code(self, code, result_code,
...@@ -6861,8 +6897,6 @@ static CYTHON_INLINE int __Pyx_PyUnicode_Equals(PyObject* s1, PyObject* s2, int ...@@ -6861,8 +6897,6 @@ static CYTHON_INLINE int __Pyx_PyUnicode_Equals(PyObject* s1, PyObject* s2, int
return -1; return -1;
return (equals == Py_EQ) ? (result == 0) : (result != 0); return (equals == Py_EQ) ? (result == 0) : (result != 0);
} }
} else if ((s1 == Py_None) & (s2 == Py_None)) {
return (equals == Py_EQ);
} else if ((s1 == Py_None) & PyUnicode_CheckExact(s2)) { } else if ((s1 == Py_None) & PyUnicode_CheckExact(s2)) {
return (equals == Py_NE); return (equals == Py_NE);
} else if ((s2 == Py_None) & PyUnicode_CheckExact(s1)) { } else if ((s2 == Py_None) & PyUnicode_CheckExact(s1)) {
...@@ -6879,6 +6913,53 @@ static CYTHON_INLINE int __Pyx_PyUnicode_Equals(PyObject* s1, PyObject* s2, int ...@@ -6879,6 +6913,53 @@ static CYTHON_INLINE int __Pyx_PyUnicode_Equals(PyObject* s1, PyObject* s2, int
} }
""") """)
pybytes_equals_utility_code = UtilityCode(
proto="""
static CYTHON_INLINE int __Pyx_PyBytes_Equals(PyObject* s1, PyObject* s2, int equals); /*proto*/
""",
impl="""
static CYTHON_INLINE int __Pyx_PyBytes_Equals(PyObject* s1, PyObject* s2, int equals) {
if (s1 == s2) { /* as done by PyObject_RichCompareBool(); also catches the (interned) empty string */
return (equals == Py_EQ);
} else if (PyBytes_CheckExact(s1) & PyBytes_CheckExact(s2)) {
if (PyBytes_GET_SIZE(s1) != PyBytes_GET_SIZE(s2)) {
return (equals == Py_NE);
} else if (PyBytes_GET_SIZE(s1) == 1) {
if (equals == Py_EQ)
return (PyBytes_AS_STRING(s1)[0] == PyBytes_AS_STRING(s2)[0]);
else
return (PyBytes_AS_STRING(s1)[0] != PyBytes_AS_STRING(s2)[0]);
} else {
int result = memcmp(PyBytes_AS_STRING(s1), PyBytes_AS_STRING(s2), PyBytes_GET_SIZE(s1));
return (equals == Py_EQ) ? (result == 0) : (result != 0);
}
} else if ((s1 == Py_None) & PyBytes_CheckExact(s2)) {
return (equals == Py_NE);
} else if ((s2 == Py_None) & PyBytes_CheckExact(s1)) {
return (equals == Py_NE);
} else {
int result;
PyObject* py_result = PyObject_RichCompare(s1, s2, equals);
if (!py_result)
return -1;
result = __Pyx_PyObject_IsTrue(py_result);
Py_DECREF(py_result);
return result;
}
}
""",
requires=[Builtin.include_string_h_utility_code])
pystr_equals_utility_code = UtilityCode(
proto="""
#if PY_MAJOR_VERSION >= 3
#define __Pyx_PyString_Equals __Pyx_PyUnicode_Equals
#else
#define __Pyx_PyString_Equals __Pyx_PyBytes_Equals
#endif
""",
requires=[pybytes_equals_utility_code, pyunicode_equals_utility_code])
class PrimaryCmpNode(ExprNode, CmpNode): class PrimaryCmpNode(ExprNode, CmpNode):
# Non-cascaded comparison or first comparison of # Non-cascaded comparison or first comparison of
...@@ -7684,11 +7765,18 @@ impl = """ ...@@ -7684,11 +7765,18 @@ impl = """
static PyObject *__Pyx_GetName(PyObject *dict, PyObject *name) { static PyObject *__Pyx_GetName(PyObject *dict, PyObject *name) {
PyObject *result; PyObject *result;
result = PyObject_GetAttr(dict, name); result = PyObject_GetAttr(dict, name);
if (!result) if (!result) {
PyErr_SetObject(PyExc_NameError, name); if (dict != %(BUILTINS)s) {
PyErr_Clear();
result = PyObject_GetAttr(%(BUILTINS)s, name);
}
if (!result) {
PyErr_SetObject(PyExc_NameError, name);
}
}
return result; return result;
} }
""") """ % {'BUILTINS' : Naming.builtins_cname})
#------------------------------------------------------------------------------------ #------------------------------------------------------------------------------------
...@@ -8459,8 +8547,8 @@ static int __Pyx_cdivision_warning(void) { ...@@ -8459,8 +8547,8 @@ static int __Pyx_cdivision_warning(void) {
# from intobject.c # from intobject.c
division_overflow_test_code = UtilityCode( division_overflow_test_code = UtilityCode(
proto=""" proto="""
#define UNARY_NEG_WOULD_OVERFLOW(x) \ #define UNARY_NEG_WOULD_OVERFLOW(x) \
(((x) < 0) & ((unsigned long)(x) == 0-(unsigned long)(x))) (((x) < 0) & ((unsigned long)(x) == 0-(unsigned long)(x)))
""") """)
...@@ -8483,29 +8571,29 @@ static int %(binding_cfunc)s_init(void); /* proto */ ...@@ -8483,29 +8571,29 @@ static int %(binding_cfunc)s_init(void); /* proto */
impl=""" impl="""
static PyObject *%(binding_cfunc)s_NewEx(PyMethodDef *ml, PyObject *self, PyObject *module) { static PyObject *%(binding_cfunc)s_NewEx(PyMethodDef *ml, PyObject *self, PyObject *module) {
%(binding_cfunc)s_object *op = PyObject_GC_New(%(binding_cfunc)s_object, %(binding_cfunc)s); %(binding_cfunc)s_object *op = PyObject_GC_New(%(binding_cfunc)s_object, %(binding_cfunc)s);
if (op == NULL) if (op == NULL)
return NULL; return NULL;
op->func.m_ml = ml; op->func.m_ml = ml;
Py_XINCREF(self); Py_XINCREF(self);
op->func.m_self = self; op->func.m_self = self;
Py_XINCREF(module); Py_XINCREF(module);
op->func.m_module = module; op->func.m_module = module;
PyObject_GC_Track(op); PyObject_GC_Track(op);
return (PyObject *)op; return (PyObject *)op;
} }
static void %(binding_cfunc)s_dealloc(%(binding_cfunc)s_object *m) { static void %(binding_cfunc)s_dealloc(%(binding_cfunc)s_object *m) {
PyObject_GC_UnTrack(m); PyObject_GC_UnTrack(m);
Py_XDECREF(m->func.m_self); Py_XDECREF(m->func.m_self);
Py_XDECREF(m->func.m_module); Py_XDECREF(m->func.m_module);
PyObject_GC_Del(m); PyObject_GC_Del(m);
} }
static PyObject *%(binding_cfunc)s_descr_get(PyObject *func, PyObject *obj, PyObject *type) { static PyObject *%(binding_cfunc)s_descr_get(PyObject *func, PyObject *obj, PyObject *type) {
if (obj == Py_None) if (obj == Py_None)
obj = NULL; obj = NULL;
return PyMethod_New(func, obj, type); return PyMethod_New(func, obj, type);
} }
static int %(binding_cfunc)s_init(void) { static int %(binding_cfunc)s_init(void) {
...@@ -8532,6 +8620,17 @@ static PyObject *__Pyx_Generator_Throw(PyObject *gen, PyObject *args, CYTHON_UNU ...@@ -8532,6 +8620,17 @@ static PyObject *__Pyx_Generator_Throw(PyObject *gen, PyObject *args, CYTHON_UNU
typedef PyObject *(*__pyx_generator_body_t)(PyObject *, PyObject *); typedef PyObject *(*__pyx_generator_body_t)(PyObject *, PyObject *);
""", """,
impl=""" impl="""
static CYTHON_INLINE void __Pyx_Generator_ExceptionClear(struct __pyx_Generator_object *self)
{
Py_XDECREF(self->exc_type);
Py_XDECREF(self->exc_value);
Py_XDECREF(self->exc_traceback);
self->exc_type = NULL;
self->exc_value = NULL;
self->exc_traceback = NULL;
}
static CYTHON_INLINE PyObject *__Pyx_Generator_SendEx(struct __pyx_Generator_object *self, PyObject *value) static CYTHON_INLINE PyObject *__Pyx_Generator_SendEx(struct __pyx_Generator_object *self, PyObject *value)
{ {
PyObject *retval; PyObject *retval;
...@@ -8556,10 +8655,21 @@ static CYTHON_INLINE PyObject *__Pyx_Generator_SendEx(struct __pyx_Generator_obj ...@@ -8556,10 +8655,21 @@ static CYTHON_INLINE PyObject *__Pyx_Generator_SendEx(struct __pyx_Generator_obj
return NULL; return NULL;
} }
if (value)
__Pyx_ExceptionSwap(&self->exc_type, &self->exc_value, &self->exc_traceback);
else
__Pyx_Generator_ExceptionClear(self);
self->is_running = 1; self->is_running = 1;
retval = self->body((PyObject *) self, value); retval = self->body((PyObject *) self, value);
self->is_running = 0; self->is_running = 0;
if (retval)
__Pyx_ExceptionSwap(&self->exc_type, &self->exc_value, &self->exc_traceback);
else
__Pyx_Generator_ExceptionClear(self);
return retval; return retval;
} }
...@@ -8612,10 +8722,10 @@ static PyObject *__Pyx_Generator_Throw(PyObject *self, PyObject *args, CYTHON_UN ...@@ -8612,10 +8722,10 @@ static PyObject *__Pyx_Generator_Throw(PyObject *self, PyObject *args, CYTHON_UN
if (!PyArg_UnpackTuple(args, (char *)"throw", 1, 3, &typ, &val, &tb)) if (!PyArg_UnpackTuple(args, (char *)"throw", 1, 3, &typ, &val, &tb))
return NULL; return NULL;
__Pyx_Raise(typ, val, tb); __Pyx_Raise(typ, val, tb, NULL);
return __Pyx_Generator_SendEx(generator, NULL); return __Pyx_Generator_SendEx(generator, NULL);
} }
""", """,
proto_block='utility_code_proto_before_types', proto_block='utility_code_proto_before_types',
requires=[Nodes.raise_utility_code], requires=[Nodes.raise_utility_code, Nodes.swap_exception_utility_code],
) )
...@@ -616,10 +616,10 @@ def run_pipeline(source, options, full_module_name = None): ...@@ -616,10 +616,10 @@ def run_pipeline(source, options, full_module_name = None):
if os.path.exists(html_filename): if os.path.exists(html_filename):
line = codecs.open(html_filename, "r", encoding="UTF-8").readline() line = codecs.open(html_filename, "r", encoding="UTF-8").readline()
if line.startswith(u'<!-- Generated by Cython'): if line.startswith(u'<!-- Generated by Cython'):
options.annotate = True options.annotate = True
# Get pipeline # Get pipeline
if source_ext.lower() == '.py': if source_ext.lower() == '.py' or not source_ext:
pipeline = context.create_py_pipeline(options, result) pipeline = context.create_py_pipeline(options, result)
else: else:
pipeline = context.create_pyx_pipeline(options, result) pipeline = context.create_pyx_pipeline(options, result)
......
...@@ -73,7 +73,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -73,7 +73,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
self.generate_c_code(env, options, result) self.generate_c_code(env, options, result)
self.generate_h_code(env, options, result) self.generate_h_code(env, options, result)
self.generate_api_code(env, result) self.generate_api_code(env, result)
def has_imported_c_functions(self): def has_imported_c_functions(self):
for module in self.referenced_modules: for module in self.referenced_modules:
for entry in module.cfunc_entries: for entry in module.cfunc_entries:
...@@ -172,7 +172,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -172,7 +172,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
def api_name(self, env): def api_name(self, env):
return env.qualified_name.replace(".", "__") return env.qualified_name.replace(".", "__")
def generate_api_code(self, env, result): def generate_api_code(self, env, result):
def api_entries(entries, pxd=0): def api_entries(entries, pxd=0):
return [entry for entry in entries return [entry for entry in entries
...@@ -255,7 +255,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -255,7 +255,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
def generate_cclass_header_code(self, type, h_code): def generate_cclass_header_code(self, type, h_code):
h_code.putln("%s %s %s;" % ( h_code.putln("%s %s %s;" % (
Naming.extern_c_macro, Naming.extern_c_macro,
PyrexTypes.public_decl("PyTypeObject", "DL_IMPORT"), PyrexTypes.public_decl("PyTypeObject", "DL_IMPORT"),
type.typeobj_cname)) type.typeobj_cname))
...@@ -1785,7 +1785,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1785,7 +1785,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln("#endif") code.putln("#endif")
code.putln("{") code.putln("{")
tempdecl_code = code.insertion_point() tempdecl_code = code.insertion_point()
code.put_declare_refcount_context() code.put_declare_refcount_context()
code.putln("#if CYTHON_REFNANNY") code.putln("#if CYTHON_REFNANNY")
code.putln("__Pyx_RefNanny = __Pyx_RefNannyImportAPI(\"refnanny\");") code.putln("__Pyx_RefNanny = __Pyx_RefNannyImportAPI(\"refnanny\");")
...@@ -2900,104 +2900,104 @@ static int __Pyx_main(int argc, wchar_t **argv) { ...@@ -2900,104 +2900,104 @@ static int __Pyx_main(int argc, wchar_t **argv) {
static wchar_t* static wchar_t*
__Pyx_char2wchar(char* arg) __Pyx_char2wchar(char* arg)
{ {
wchar_t *res; wchar_t *res;
#ifdef HAVE_BROKEN_MBSTOWCS #ifdef HAVE_BROKEN_MBSTOWCS
/* Some platforms have a broken implementation of /* Some platforms have a broken implementation of
* mbstowcs which does not count the characters that * mbstowcs which does not count the characters that
* would result from conversion. Use an upper bound. * would result from conversion. Use an upper bound.
*/ */
size_t argsize = strlen(arg); size_t argsize = strlen(arg);
#else #else
size_t argsize = mbstowcs(NULL, arg, 0); size_t argsize = mbstowcs(NULL, arg, 0);
#endif #endif
size_t count; size_t count;
unsigned char *in; unsigned char *in;
wchar_t *out; wchar_t *out;
#ifdef HAVE_MBRTOWC #ifdef HAVE_MBRTOWC
mbstate_t mbs; mbstate_t mbs;
#endif #endif
if (argsize != (size_t)-1) { if (argsize != (size_t)-1) {
res = (wchar_t *)malloc((argsize+1)*sizeof(wchar_t)); res = (wchar_t *)malloc((argsize+1)*sizeof(wchar_t));
if (!res) if (!res)
goto oom; goto oom;
count = mbstowcs(res, arg, argsize+1); count = mbstowcs(res, arg, argsize+1);
if (count != (size_t)-1) { if (count != (size_t)-1) {
wchar_t *tmp; wchar_t *tmp;
/* Only use the result if it contains no /* Only use the result if it contains no
surrogate characters. */ surrogate characters. */
for (tmp = res; *tmp != 0 && for (tmp = res; *tmp != 0 &&
(*tmp < 0xd800 || *tmp > 0xdfff); tmp++) (*tmp < 0xd800 || *tmp > 0xdfff); tmp++)
; ;
if (*tmp == 0) if (*tmp == 0)
return res; return res;
} }
free(res); free(res);
} }
/* Conversion failed. Fall back to escaping with surrogateescape. */ /* Conversion failed. Fall back to escaping with surrogateescape. */
#ifdef HAVE_MBRTOWC #ifdef HAVE_MBRTOWC
/* Try conversion with mbrtwoc (C99), and escape non-decodable bytes. */ /* Try conversion with mbrtwoc (C99), and escape non-decodable bytes. */
/* Overallocate; as multi-byte characters are in the argument, the /* Overallocate; as multi-byte characters are in the argument, the
actual output could use less memory. */ actual output could use less memory. */
argsize = strlen(arg) + 1; argsize = strlen(arg) + 1;
res = malloc(argsize*sizeof(wchar_t)); res = malloc(argsize*sizeof(wchar_t));
if (!res) goto oom; if (!res) goto oom;
in = (unsigned char*)arg; in = (unsigned char*)arg;
out = res; out = res;
memset(&mbs, 0, sizeof mbs); memset(&mbs, 0, sizeof mbs);
while (argsize) { while (argsize) {
size_t converted = mbrtowc(out, (char*)in, argsize, &mbs); size_t converted = mbrtowc(out, (char*)in, argsize, &mbs);
if (converted == 0) if (converted == 0)
/* Reached end of string; null char stored. */ /* Reached end of string; null char stored. */
break; break;
if (converted == (size_t)-2) { if (converted == (size_t)-2) {
/* Incomplete character. This should never happen, /* Incomplete character. This should never happen,
since we provide everything that we have - since we provide everything that we have -
unless there is a bug in the C library, or I unless there is a bug in the C library, or I
misunderstood how mbrtowc works. */ misunderstood how mbrtowc works. */
fprintf(stderr, "unexpected mbrtowc result -2\\n"); fprintf(stderr, "unexpected mbrtowc result -2\\n");
return NULL; return NULL;
} }
if (converted == (size_t)-1) { if (converted == (size_t)-1) {
/* Conversion error. Escape as UTF-8b, and start over /* Conversion error. Escape as UTF-8b, and start over
in the initial shift state. */ in the initial shift state. */
*out++ = 0xdc00 + *in++; *out++ = 0xdc00 + *in++;
argsize--; argsize--;
memset(&mbs, 0, sizeof mbs); memset(&mbs, 0, sizeof mbs);
continue; continue;
} }
if (*out >= 0xd800 && *out <= 0xdfff) { if (*out >= 0xd800 && *out <= 0xdfff) {
/* Surrogate character. Escape the original /* Surrogate character. Escape the original
byte sequence with surrogateescape. */ byte sequence with surrogateescape. */
argsize -= converted; argsize -= converted;
while (converted--) while (converted--)
*out++ = 0xdc00 + *in++; *out++ = 0xdc00 + *in++;
continue; continue;
} }
/* successfully converted some bytes */ /* successfully converted some bytes */
in += converted; in += converted;
argsize -= converted; argsize -= converted;
out++; out++;
} }
#else #else
/* Cannot use C locale for escaping; manually escape as if charset /* Cannot use C locale for escaping; manually escape as if charset
is ASCII (i.e. escape all bytes > 128. This will still roundtrip is ASCII (i.e. escape all bytes > 128. This will still roundtrip
correctly in the locale's charset, which must be an ASCII superset. */ correctly in the locale's charset, which must be an ASCII superset. */
res = malloc((strlen(arg)+1)*sizeof(wchar_t)); res = malloc((strlen(arg)+1)*sizeof(wchar_t));
if (!res) goto oom; if (!res) goto oom;
in = (unsigned char*)arg; in = (unsigned char*)arg;
out = res; out = res;
while(*in) while(*in)
if(*in < 128) if(*in < 128)
*out++ = *in++; *out++ = *in++;
else else
*out++ = 0xdc00 + *in++; *out++ = 0xdc00 + *in++;
*out = 0; *out = 0;
#endif #endif
return res; return res;
oom: oom:
fprintf(stderr, "out of memory\\n"); fprintf(stderr, "out of memory\\n");
return NULL; return NULL;
} }
int int
......
...@@ -550,7 +550,7 @@ class CFuncDeclaratorNode(CDeclaratorNode): ...@@ -550,7 +550,7 @@ class CFuncDeclaratorNode(CDeclaratorNode):
other_type = type_node.analyse_as_type(env) other_type = type_node.analyse_as_type(env)
if other_type is None: if other_type is None:
error(type_node.pos, "Not a type") error(type_node.pos, "Not a type")
elif (type is not PyrexTypes.py_object_type elif (type is not PyrexTypes.py_object_type
and not type.same_as(other_type)): and not type.same_as(other_type)):
error(self.base.pos, "Signature does not agree with previous declaration") error(self.base.pos, "Signature does not agree with previous declaration")
error(type_node.pos, "Previous declaration here") error(type_node.pos, "Previous declaration here")
...@@ -1139,7 +1139,7 @@ class CEnumDefNode(StatNode): ...@@ -1139,7 +1139,7 @@ class CEnumDefNode(StatNode):
# api boolean # api boolean
# in_pxd boolean # in_pxd boolean
# entry Entry # entry Entry
child_attrs = ["items"] child_attrs = ["items"]
def analyse_declarations(self, env): def analyse_declarations(self, env):
...@@ -1186,7 +1186,7 @@ class CEnumDefItemNode(StatNode): ...@@ -1186,7 +1186,7 @@ class CEnumDefItemNode(StatNode):
if not self.value.type.is_int: if not self.value.type.is_int:
self.value = self.value.coerce_to(PyrexTypes.c_int_type, env) self.value = self.value.coerce_to(PyrexTypes.c_int_type, env)
self.value.analyse_const_expression(env) self.value.analyse_const_expression(env)
entry = env.declare_const(self.name, enum_entry.type, entry = env.declare_const(self.name, enum_entry.type,
self.value, self.pos, cname = self.cname, self.value, self.pos, cname = self.cname,
visibility = enum_entry.visibility, api = enum_entry.api) visibility = enum_entry.visibility, api = enum_entry.api)
enum_entry.enum_values.append(entry) enum_entry.enum_values.append(entry)
...@@ -1281,7 +1281,7 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -1281,7 +1281,7 @@ class FuncDefNode(StatNode, BlockNode):
other_type = type_node.analyse_as_type(env) other_type = type_node.analyse_as_type(env)
if other_type is None: if other_type is None:
error(type_node.pos, "Not a type") error(type_node.pos, "Not a type")
elif (type is not PyrexTypes.py_object_type elif (type is not PyrexTypes.py_object_type
and not type.same_as(other_type)): and not type.same_as(other_type)):
error(arg.base_type.pos, "Signature does not agree with previous declaration") error(arg.base_type.pos, "Signature does not agree with previous declaration")
error(type_node.pos, "Previous declaration here") error(type_node.pos, "Previous declaration here")
...@@ -1912,7 +1912,7 @@ class CFuncDefNode(FuncDefNode): ...@@ -1912,7 +1912,7 @@ class CFuncDefNode(FuncDefNode):
self.modifiers[self.modifiers.index('inline')] = 'cython_inline' self.modifiers[self.modifiers.index('inline')] = 'cython_inline'
if self.modifiers: if self.modifiers:
modifiers = "%s " % ' '.join(self.modifiers).upper() modifiers = "%s " % ' '.join(self.modifiers).upper()
header = self.return_type.declaration_code(entity, dll_linkage=dll_linkage) header = self.return_type.declaration_code(entity, dll_linkage=dll_linkage)
#print (storage_class, modifiers, header) #print (storage_class, modifiers, header)
code.putln("%s%s%s {" % (storage_class, modifiers, header)) code.putln("%s%s%s {" % (storage_class, modifiers, header))
...@@ -2406,14 +2406,7 @@ class DefNode(FuncDefNode): ...@@ -2406,14 +2406,7 @@ class DefNode(FuncDefNode):
entry.doc = None entry.doc = None
def declare_lambda_function(self, env): def declare_lambda_function(self, env):
name = self.name entry = env.declare_lambda_function(self.lambda_name, self.pos)
prefix = env.scope_prefix
func_cname = \
Naming.lambda_func_prefix + u'funcdef' + prefix + self.lambda_name
entry = env.declare_lambda_function(func_cname, self.pos)
entry.pymethdef_cname = \
Naming.lambda_func_prefix + u'methdef' + prefix + self.lambda_name
entry.qualified_name = env.qualify_name(self.lambda_name)
entry.doc = None entry.doc = None
self.entry = entry self.entry = entry
...@@ -3627,7 +3620,7 @@ class CClassDefNode(ClassDefNode): ...@@ -3627,7 +3620,7 @@ class CClassDefNode(ClassDefNode):
visibility = self.visibility, visibility = self.visibility,
typedef_flag = self.typedef_flag, typedef_flag = self.typedef_flag,
api = self.api, api = self.api,
buffer_defaults = buffer_defaults, buffer_defaults = buffer_defaults,
shadow = self.shadow) shadow = self.shadow)
if self.shadow: if self.shadow:
home_scope.lookup(self.class_name).as_variable = self.entry home_scope.lookup(self.class_name).as_variable = self.entry
...@@ -4349,8 +4342,9 @@ class RaiseStatNode(StatNode): ...@@ -4349,8 +4342,9 @@ class RaiseStatNode(StatNode):
# exc_type ExprNode or None # exc_type ExprNode or None
# exc_value ExprNode or None # exc_value ExprNode or None
# exc_tb ExprNode or None # exc_tb ExprNode or None
# cause ExprNode or None
child_attrs = ["exc_type", "exc_value", "exc_tb"] child_attrs = ["exc_type", "exc_value", "exc_tb", "cause"]
def analyse_expressions(self, env): def analyse_expressions(self, env):
if self.exc_type: if self.exc_type:
...@@ -4362,13 +4356,16 @@ class RaiseStatNode(StatNode): ...@@ -4362,13 +4356,16 @@ class RaiseStatNode(StatNode):
if self.exc_tb: if self.exc_tb:
self.exc_tb.analyse_types(env) self.exc_tb.analyse_types(env)
self.exc_tb = self.exc_tb.coerce_to_pyobject(env) self.exc_tb = self.exc_tb.coerce_to_pyobject(env)
if self.cause:
self.cause.analyse_types(env)
self.cause = self.cause.coerce_to_pyobject(env)
# special cases for builtin exceptions # special cases for builtin exceptions
self.builtin_exc_name = None self.builtin_exc_name = None
if self.exc_type and not self.exc_value and not self.exc_tb: if self.exc_type and not self.exc_value and not self.exc_tb:
exc = self.exc_type exc = self.exc_type
import ExprNodes import ExprNodes
if (isinstance(exc, ExprNodes.SimpleCallNode) and if (isinstance(exc, ExprNodes.SimpleCallNode) and
not (exc.args or (exc.arg_tuple is not None and not (exc.args or (exc.arg_tuple is not None and
exc.arg_tuple.args))): exc.arg_tuple.args))):
exc = exc.function # extract the exception type exc = exc.function # extract the exception type
if exc.is_name and exc.entry.is_builtin: if exc.is_name and exc.entry.is_builtin:
...@@ -4399,13 +4396,19 @@ class RaiseStatNode(StatNode): ...@@ -4399,13 +4396,19 @@ class RaiseStatNode(StatNode):
tb_code = self.exc_tb.py_result() tb_code = self.exc_tb.py_result()
else: else:
tb_code = "0" tb_code = "0"
if self.cause:
self.cause.generate_evaluation_code(code)
cause_code = self.cause.py_result()
else:
cause_code = "0"
code.globalstate.use_utility_code(raise_utility_code) code.globalstate.use_utility_code(raise_utility_code)
code.putln( code.putln(
"__Pyx_Raise(%s, %s, %s);" % ( "__Pyx_Raise(%s, %s, %s, %s);" % (
type_code, type_code,
value_code, value_code,
tb_code)) tb_code,
for obj in (self.exc_type, self.exc_value, self.exc_tb): cause_code))
for obj in (self.exc_type, self.exc_value, self.exc_tb, self.cause):
if obj: if obj:
obj.generate_disposal_code(code) obj.generate_disposal_code(code)
obj.free_temps(code) obj.free_temps(code)
...@@ -4419,6 +4422,8 @@ class RaiseStatNode(StatNode): ...@@ -4419,6 +4422,8 @@ class RaiseStatNode(StatNode):
self.exc_value.generate_function_definitions(env, code) self.exc_value.generate_function_definitions(env, code)
if self.exc_tb is not None: if self.exc_tb is not None:
self.exc_tb.generate_function_definitions(env, code) self.exc_tb.generate_function_definitions(env, code)
if self.cause is not None:
self.cause.generate_function_definitions(env, code)
def annotate(self, code): def annotate(self, code):
if self.exc_type: if self.exc_type:
...@@ -4427,6 +4432,8 @@ class RaiseStatNode(StatNode): ...@@ -4427,6 +4432,8 @@ class RaiseStatNode(StatNode):
self.exc_value.annotate(code) self.exc_value.annotate(code)
if self.exc_tb: if self.exc_tb:
self.exc_tb.annotate(code) self.exc_tb.annotate(code)
if self.cause:
self.cause.annotate(code)
class ReraiseStatNode(StatNode): class ReraiseStatNode(StatNode):
...@@ -5030,14 +5037,134 @@ class WithStatNode(StatNode): ...@@ -5030,14 +5037,134 @@ class WithStatNode(StatNode):
""" """
Represents a Python with statement. Represents a Python with statement.
This is only used at parse tree level; and is not present in Implemented by the WithTransform as follows:
analysis or generation phases.
MGR = EXPR
EXIT = MGR.__exit__
VALUE = MGR.__enter__()
EXC = True
try:
try:
TARGET = VALUE # optional
BODY
except:
EXC = False
if not EXIT(*EXCINFO):
raise
finally:
if EXC:
EXIT(None, None, None)
MGR = EXIT = VALUE = None
""" """
# manager The with statement manager object # manager The with statement manager object
# target Node (lhs expression) # target ExprNode the target lhs of the __enter__() call
# body StatNode # body StatNode
child_attrs = ["manager", "target", "body"] child_attrs = ["manager", "target", "body"]
has_target = False
def analyse_declarations(self, env):
self.manager.analyse_declarations(env)
self.body.analyse_declarations(env)
def analyse_expressions(self, env):
self.manager.analyse_types(env)
self.body.analyse_expressions(env)
def generate_function_definitions(self, env, code):
self.manager.generate_function_definitions(env, code)
self.body.generate_function_definitions(env, code)
def generate_execution_code(self, code):
code.putln("/*with:*/ {")
self.manager.generate_evaluation_code(code)
self.exit_var = code.funcstate.allocate_temp(py_object_type, manage_ref=False)
code.putln("%s = PyObject_GetAttr(%s, %s); %s" % (
self.exit_var,
self.manager.py_result(),
code.get_py_string_const(EncodedString('__exit__'), identifier=True),
code.error_goto_if_null(self.exit_var, self.pos),
))
code.put_gotref(self.exit_var)
# need to free exit_var in the face of exceptions during setup
old_error_label = code.new_error_label()
intermediate_error_label = code.error_label
enter_func = code.funcstate.allocate_temp(py_object_type, manage_ref=True)
code.putln("%s = PyObject_GetAttr(%s, %s); %s" % (
enter_func,
self.manager.py_result(),
code.get_py_string_const(EncodedString('__enter__'), identifier=True),
code.error_goto_if_null(enter_func, self.pos),
))
code.put_gotref(enter_func)
self.manager.generate_disposal_code(code)
self.manager.free_temps(code)
self.target_temp.allocate(code)
code.putln('%s = PyObject_Call(%s, ((PyObject *)%s), NULL); %s' % (
self.target_temp.result(),
enter_func,
Naming.empty_tuple,
code.error_goto_if_null(self.target_temp.result(), self.pos),
))
code.put_gotref(self.target_temp.result())
code.put_decref_clear(enter_func, py_object_type)
code.funcstate.release_temp(enter_func)
if not self.has_target:
code.put_decref_clear(self.target_temp.result(), type=py_object_type)
self.target_temp.release(code)
# otherwise, WithTargetAssignmentStatNode will do it for us
code.error_label = old_error_label
self.body.generate_execution_code(code)
step_over_label = code.new_label()
code.put_goto(step_over_label)
code.put_label(intermediate_error_label)
code.put_decref_clear(self.exit_var, py_object_type)
code.put_goto(old_error_label)
code.put_label(step_over_label)
code.funcstate.release_temp(self.exit_var)
code.putln('}')
class WithTargetAssignmentStatNode(AssignmentNode):
# The target assignment of the 'with' statement value (return
# value of the __enter__() call).
#
# This is a special cased assignment that steals the RHS reference
# and frees its temp.
#
# lhs ExprNode the assignment target
# rhs TempNode the return value of the __enter__() call
child_attrs = ["lhs", "rhs"]
def analyse_declarations(self, env):
self.lhs.analyse_target_declaration(env)
def analyse_types(self, env):
self.rhs.analyse_types(env)
self.lhs.analyse_target_types(env)
self.lhs.gil_assignment_check(env)
self.orig_rhs = self.rhs
self.rhs = self.rhs.coerce_to(self.lhs.type, env)
def generate_execution_code(self, code):
self.rhs.generate_evaluation_code(code)
self.lhs.generate_assignment_code(self.rhs, code)
self.orig_rhs.release(code)
def generate_function_definitions(self, env, code):
self.rhs.generate_function_definitions(env, code)
def annotate(self, code):
self.lhs.annotate(code)
self.rhs.annotate(code)
class TryExceptStatNode(StatNode): class TryExceptStatNode(StatNode):
# try .. except statement # try .. except statement
# #
...@@ -5203,7 +5330,7 @@ class ExceptClauseNode(Node): ...@@ -5203,7 +5330,7 @@ class ExceptClauseNode(Node):
# pattern [ExprNode] # pattern [ExprNode]
# target ExprNode or None # target ExprNode or None
# body StatNode # body StatNode
# excinfo_target NameNode or None optional target for exception info # excinfo_target ResultRefNode or None optional target for exception info
# match_flag string result of exception match # match_flag string result of exception match
# exc_value ExcValueNode used internally # exc_value ExcValueNode used internally
# function_name string qualified name of enclosing function # function_name string qualified name of enclosing function
...@@ -5221,8 +5348,6 @@ class ExceptClauseNode(Node): ...@@ -5221,8 +5348,6 @@ class ExceptClauseNode(Node):
def analyse_declarations(self, env): def analyse_declarations(self, env):
if self.target: if self.target:
self.target.analyse_target_declaration(env) self.target.analyse_target_declaration(env)
if self.excinfo_target is not None:
self.excinfo_target.analyse_target_declaration(env)
self.body.analyse_declarations(env) self.body.analyse_declarations(env)
def analyse_expressions(self, env): def analyse_expressions(self, env):
...@@ -5243,7 +5368,6 @@ class ExceptClauseNode(Node): ...@@ -5243,7 +5368,6 @@ class ExceptClauseNode(Node):
self.excinfo_tuple = ExprNodes.TupleNode(pos=self.pos, args=[ self.excinfo_tuple = ExprNodes.TupleNode(pos=self.pos, args=[
ExprNodes.ExcValueNode(pos=self.pos, env=env) for x in range(3)]) ExprNodes.ExcValueNode(pos=self.pos, env=env) for x in range(3)])
self.excinfo_tuple.analyse_expressions(env) self.excinfo_tuple.analyse_expressions(env)
self.excinfo_target.analyse_target_expression(env, self.excinfo_tuple)
self.body.analyse_expressions(env) self.body.analyse_expressions(env)
...@@ -5298,7 +5422,7 @@ class ExceptClauseNode(Node): ...@@ -5298,7 +5422,7 @@ class ExceptClauseNode(Node):
for tempvar, node in zip(exc_vars, self.excinfo_tuple.args): for tempvar, node in zip(exc_vars, self.excinfo_tuple.args):
node.set_var(tempvar) node.set_var(tempvar)
self.excinfo_tuple.generate_evaluation_code(code) self.excinfo_tuple.generate_evaluation_code(code)
self.excinfo_target.generate_assignment_code(self.excinfo_tuple, code) self.excinfo_target.result_code = self.excinfo_tuple.result()
old_break_label, old_continue_label = code.break_label, code.continue_label old_break_label, old_continue_label = code.break_label, code.continue_label
code.break_label = code.new_label('except_break') code.break_label = code.new_label('except_break')
...@@ -5308,24 +5432,32 @@ class ExceptClauseNode(Node): ...@@ -5308,24 +5432,32 @@ class ExceptClauseNode(Node):
code.funcstate.exc_vars = exc_vars code.funcstate.exc_vars = exc_vars
self.body.generate_execution_code(code) self.body.generate_execution_code(code)
code.funcstate.exc_vars = old_exc_vars code.funcstate.exc_vars = old_exc_vars
if self.excinfo_target is not None:
self.excinfo_tuple.generate_disposal_code(code)
for var in exc_vars: for var in exc_vars:
code.putln("__Pyx_DECREF(%s); %s = 0;" % (var, var)) code.put_decref_clear(var, py_object_type)
code.put_goto(end_label) code.put_goto(end_label)
if code.label_used(code.break_label): if code.label_used(code.break_label):
code.put_label(code.break_label) code.put_label(code.break_label)
if self.excinfo_target is not None:
self.excinfo_tuple.generate_disposal_code(code)
for var in exc_vars: for var in exc_vars:
code.putln("__Pyx_DECREF(%s); %s = 0;" % (var, var)) code.put_decref_clear(var, py_object_type)
code.put_goto(old_break_label) code.put_goto(old_break_label)
code.break_label = old_break_label code.break_label = old_break_label
if code.label_used(code.continue_label): if code.label_used(code.continue_label):
code.put_label(code.continue_label) code.put_label(code.continue_label)
if self.excinfo_target is not None:
self.excinfo_tuple.generate_disposal_code(code)
for var in exc_vars: for var in exc_vars:
code.putln("__Pyx_DECREF(%s); %s = 0;" % (var, var)) code.put_decref_clear(var, py_object_type)
code.put_goto(old_continue_label) code.put_goto(old_continue_label)
code.continue_label = old_continue_label code.continue_label = old_continue_label
if self.excinfo_target is not None:
self.excinfo_tuple.free_temps(code)
for temp in exc_vars: for temp in exc_vars:
code.funcstate.release_temp(temp) code.funcstate.release_temp(temp)
...@@ -5365,6 +5497,9 @@ class TryFinallyStatNode(StatNode): ...@@ -5365,6 +5497,9 @@ class TryFinallyStatNode(StatNode):
preserve_exception = 1 preserve_exception = 1
# handle exception case, in addition to return/break/continue
handle_error_case = True
disallow_continue_in_try_finally = 0 disallow_continue_in_try_finally = 0
# There doesn't seem to be any point in disallowing # There doesn't seem to be any point in disallowing
# continue in the try block, since we have no problem # continue in the try block, since we have no problem
...@@ -5398,6 +5533,8 @@ class TryFinallyStatNode(StatNode): ...@@ -5398,6 +5533,8 @@ class TryFinallyStatNode(StatNode):
old_labels = code.all_new_labels() old_labels = code.all_new_labels()
new_labels = code.get_all_labels() new_labels = code.get_all_labels()
new_error_label = code.error_label new_error_label = code.error_label
if not self.handle_error_case:
code.error_label = old_error_label
catch_label = code.new_label() catch_label = code.new_label()
code.putln( code.putln(
"/*try:*/ {") "/*try:*/ {")
...@@ -6068,11 +6205,12 @@ static CYTHON_INLINE void __Pyx_ErrFetch(PyObject **type, PyObject **value, PyOb ...@@ -6068,11 +6205,12 @@ static CYTHON_INLINE void __Pyx_ErrFetch(PyObject **type, PyObject **value, PyOb
raise_utility_code = UtilityCode( raise_utility_code = UtilityCode(
proto = """ proto = """
static void __Pyx_Raise(PyObject *type, PyObject *value, PyObject *tb); /*proto*/ static void __Pyx_Raise(PyObject *type, PyObject *value, PyObject *tb, PyObject *cause); /*proto*/
""", """,
impl = """ impl = """
#if PY_MAJOR_VERSION < 3 #if PY_MAJOR_VERSION < 3
static void __Pyx_Raise(PyObject *type, PyObject *value, PyObject *tb) { static void __Pyx_Raise(PyObject *type, PyObject *value, PyObject *tb, PyObject *cause) {
/* cause is unused */
Py_XINCREF(type); Py_XINCREF(type);
Py_XINCREF(value); Py_XINCREF(value);
Py_XINCREF(tb); Py_XINCREF(tb);
...@@ -6139,7 +6277,7 @@ raise_error: ...@@ -6139,7 +6277,7 @@ raise_error:
#else /* Python 3+ */ #else /* Python 3+ */
static void __Pyx_Raise(PyObject *type, PyObject *value, PyObject *tb) { static void __Pyx_Raise(PyObject *type, PyObject *value, PyObject *tb, PyObject *cause) {
if (tb == Py_None) { if (tb == Py_None) {
tb = 0; tb = 0;
} else if (tb && !PyTraceBack_Check(tb)) { } else if (tb && !PyTraceBack_Check(tb)) {
...@@ -6164,6 +6302,29 @@ static void __Pyx_Raise(PyObject *type, PyObject *value, PyObject *tb) { ...@@ -6164,6 +6302,29 @@ static void __Pyx_Raise(PyObject *type, PyObject *value, PyObject *tb) {
goto bad; goto bad;
} }
if (cause) {
PyObject *fixed_cause;
if (PyExceptionClass_Check(cause)) {
fixed_cause = PyObject_CallObject(cause, NULL);
if (fixed_cause == NULL)
goto bad;
}
else if (PyExceptionInstance_Check(cause)) {
fixed_cause = cause;
Py_INCREF(fixed_cause);
}
else {
PyErr_SetString(PyExc_TypeError,
"exception causes must derive from "
"BaseException");
goto bad;
}
if (!value) {
value = PyObject_CallObject(type, NULL);
}
PyException_SetCause(value, fixed_cause);
}
PyErr_SetObject(type, value); PyErr_SetObject(type, value);
if (tb) { if (tb) {
...@@ -6300,6 +6461,31 @@ static void __Pyx_ExceptionReset(PyObject *type, PyObject *value, PyObject *tb) ...@@ -6300,6 +6461,31 @@ static void __Pyx_ExceptionReset(PyObject *type, PyObject *value, PyObject *tb)
#------------------------------------------------------------------------------------ #------------------------------------------------------------------------------------
swap_exception_utility_code = UtilityCode(
proto = """
static CYTHON_INLINE void __Pyx_ExceptionSwap(PyObject **type, PyObject **value, PyObject **tb); /*proto*/
""",
impl = """
static CYTHON_INLINE void __Pyx_ExceptionSwap(PyObject **type, PyObject **value, PyObject **tb) {
PyObject *tmp_type, *tmp_value, *tmp_tb;
PyThreadState *tstate = PyThreadState_GET();
tmp_type = tstate->exc_type;
tmp_value = tstate->exc_value;
tmp_tb = tstate->exc_traceback;
tstate->exc_type = *type;
tstate->exc_value = *value;
tstate->exc_traceback = *tb;
*type = tmp_type;
*value = tmp_value;
*tb = tmp_tb;
}
""")
#------------------------------------------------------------------------------------
arg_type_test_utility_code = UtilityCode( arg_type_test_utility_code = UtilityCode(
proto = """ proto = """
static int __Pyx_ArgTypeTest(PyObject *obj, PyTypeObject *type, int none_allowed, static int __Pyx_ArgTypeTest(PyObject *obj, PyTypeObject *type, int none_allowed,
......
...@@ -930,7 +930,15 @@ class FlattenInListTransform(Visitor.VisitorTransform, SkipDeclarations): ...@@ -930,7 +930,15 @@ class FlattenInListTransform(Visitor.VisitorTransform, SkipDeclarations):
conds = [] conds = []
temps = [] temps = []
for arg in args: for arg in args:
if not arg.is_simple(): try:
# Trial optimisation to avoid redundant temp
# assignments. However, since is_simple() is meant to
# be called after type analysis, we ignore any errors
# and just play safe in that case.
is_simple_arg = arg.is_simple()
except Exception:
is_simple_arg = False
if not is_simple_arg:
# must evaluate all non-simple RHS before doing the comparisons # must evaluate all non-simple RHS before doing the comparisons
arg = UtilNodes.LetRefNode(arg) arg = UtilNodes.LetRefNode(arg)
temps.append(arg) temps.append(arg)
...@@ -3134,6 +3142,15 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations): ...@@ -3134,6 +3142,15 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
return ExprNodes.BoolNode(node.pos, value=bool_result, return ExprNodes.BoolNode(node.pos, value=bool_result,
constant_result=bool_result) constant_result=bool_result)
def visit_CondExprNode(self, node):
self._calculate_const(node)
if node.test.constant_result is ExprNodes.not_a_constant:
return node
if node.test.constant_result:
return node.true_val
else:
return node.false_val
def visit_IfStatNode(self, node): def visit_IfStatNode(self, node):
self.visitchildren(node) self.visitchildren(node)
# eliminate dead code based on constant condition results # eliminate dead code based on constant condition results
......
...@@ -11,11 +11,12 @@ import Naming ...@@ -11,11 +11,12 @@ import Naming
import ExprNodes import ExprNodes
import Nodes import Nodes
import Options import Options
import Builtin
from Cython.Compiler.Visitor import VisitorTransform, TreeVisitor from Cython.Compiler.Visitor import VisitorTransform, TreeVisitor
from Cython.Compiler.Visitor import CythonTransform, EnvTransform, ScopeTrackingTransform from Cython.Compiler.Visitor import CythonTransform, EnvTransform, ScopeTrackingTransform
from Cython.Compiler.ModuleNode import ModuleNode from Cython.Compiler.ModuleNode import ModuleNode
from Cython.Compiler.UtilNodes import LetNode, LetRefNode from Cython.Compiler.UtilNodes import LetNode, LetRefNode, ResultRefNode
from Cython.Compiler.TreeFragment import TreeFragment, TemplateTransform from Cython.Compiler.TreeFragment import TreeFragment, TemplateTransform
from Cython.Compiler.StringEncoding import EncodedString from Cython.Compiler.StringEncoding import EncodedString
from Cython.Compiler.Errors import error, warning, CompileError, InternalError from Cython.Compiler.Errors import error, warning, CompileError, InternalError
...@@ -928,81 +929,55 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations): ...@@ -928,81 +929,55 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
class WithTransform(CythonTransform, SkipDeclarations): class WithTransform(CythonTransform, SkipDeclarations):
# EXCINFO is manually set to a variable that contains
# the exc_info() tuple that can be generated by the enclosing except
# statement.
template_without_target = TreeFragment(u"""
MGR = EXPR
EXIT = MGR.__exit__
MGR.__enter__()
EXC = True
try:
try:
EXCINFO = None
BODY
except:
EXC = False
if not EXIT(*EXCINFO):
raise
finally:
if EXC:
EXIT(None, None, None)
""", temps=[u'MGR', u'EXC', u"EXIT"],
pipeline=[NormalizeTree(None)])
template_with_target = TreeFragment(u"""
MGR = EXPR
EXIT = MGR.__exit__
VALUE = MGR.__enter__()
EXC = True
try:
try:
EXCINFO = None
TARGET = VALUE
BODY
except:
EXC = False
if not EXIT(*EXCINFO):
raise
finally:
if EXC:
EXIT(None, None, None)
MGR = EXIT = VALUE = EXC = None
""", temps=[u'MGR', u'EXC', u"EXIT", u"VALUE"],
pipeline=[NormalizeTree(None)])
def visit_WithStatNode(self, node): def visit_WithStatNode(self, node):
# TODO: Cleanup badly needed self.visitchildren(node, 'body')
TemplateTransform.temp_name_counter += 1 pos = node.pos
handle = "__tmpvar_%d" % TemplateTransform.temp_name_counter body, target, manager = node.body, node.target, node.manager
node.target_temp = ExprNodes.TempNode(pos, type=PyrexTypes.py_object_type)
self.visitchildren(node, ['body']) if target is not None:
excinfo_temp = ExprNodes.NameNode(node.pos, name=handle)#TempHandle(Builtin.tuple_type) node.has_target = True
if node.target is not None: body = Nodes.StatListNode(
result = self.template_with_target.substitute({ pos, stats = [
u'EXPR' : node.manager, Nodes.WithTargetAssignmentStatNode(
u'BODY' : node.body, pos, lhs = target, rhs = node.target_temp),
u'TARGET' : node.target, body
u'EXCINFO' : excinfo_temp ])
}, pos=node.pos) node.target = None
else:
result = self.template_without_target.substitute({ excinfo_target = ResultRefNode(
u'EXPR' : node.manager, pos=pos, type=Builtin.tuple_type, may_hold_none=False)
u'BODY' : node.body, except_clause = Nodes.ExceptClauseNode(
u'EXCINFO' : excinfo_temp pos, body = Nodes.IfStatNode(
}, pos=node.pos) pos, if_clauses = [
Nodes.IfClauseNode(
# Set except excinfo target to EXCINFO pos, condition = ExprNodes.NotNode(
try_except = result.stats[-1].body.stats[-1] pos, operand = ExprNodes.WithExitCallNode(
try_except.except_clauses[0].excinfo_target = ExprNodes.NameNode(node.pos, name=handle) pos, with_stat = node,
# excinfo_temp.ref(node.pos)) args = excinfo_target)),
body = Nodes.ReraiseStatNode(pos),
# result.stats[-1].body.stats[-1] = TempsBlockNode( ),
# node.pos, temps=[excinfo_temp], body=try_except) ],
else_clause = None),
return result pattern = None,
target = None,
excinfo_target = excinfo_target,
)
node.body = Nodes.TryFinallyStatNode(
pos, body = Nodes.TryExceptStatNode(
pos, body = body,
except_clauses = [except_clause],
else_clause = None,
),
finally_clause = Nodes.ExprStatNode(
pos, expr = ExprNodes.WithExitCallNode(
pos, with_stat = node,
args = ExprNodes.TupleNode(
pos, args = [ExprNodes.NoneNode(pos) for _ in range(3)]
))),
handle_error_case = False,
)
return node
def visit_ExprNode(self, node): def visit_ExprNode(self, node):
# With statements are never inside expressions. # With statements are never inside expressions.
...@@ -1256,7 +1231,7 @@ if VALUE is not None: ...@@ -1256,7 +1231,7 @@ if VALUE is not None:
arg = copy.deepcopy(arg_template) arg = copy.deepcopy(arg_template)
arg.declarator.name = entry.name arg.declarator.name = entry.name
init_method.args.append(arg) init_method.args.append(arg)
# setters/getters # setters/getters
for entry, attr in zip(var_entries, attributes): for entry, attr in zip(var_entries, attributes):
# TODO: branch on visibility # TODO: branch on visibility
...@@ -1269,7 +1244,7 @@ if VALUE is not None: ...@@ -1269,7 +1244,7 @@ if VALUE is not None:
}, pos = entry.pos).stats[0] }, pos = entry.pos).stats[0]
property.name = entry.name property.name = entry.name
wrapper_class.body.stats.append(property) wrapper_class.body.stats.append(property)
wrapper_class.analyse_declarations(self.env_stack[-1]) wrapper_class.analyse_declarations(self.env_stack[-1])
return self.visit_CClassDefNode(wrapper_class) return self.visit_CClassDefNode(wrapper_class)
...@@ -1602,6 +1577,12 @@ class CreateClosureClasses(CythonTransform): ...@@ -1602,6 +1577,12 @@ class CreateClosureClasses(CythonTransform):
is_cdef=True) is_cdef=True)
klass.declare_var(pos=pos, name='resume_label', cname='resume_label', type=PyrexTypes.c_int_type, klass.declare_var(pos=pos, name='resume_label', cname='resume_label', type=PyrexTypes.c_int_type,
is_cdef=True) is_cdef=True)
klass.declare_var(pos=pos, name='exc_type', cname='exc_type',
type=PyrexTypes.py_object_type, is_cdef=True)
klass.declare_var(pos=pos, name='exc_value', cname='exc_value',
type=PyrexTypes.py_object_type, is_cdef=True)
klass.declare_var(pos=pos, name='exc_traceback', cname='exc_traceback',
type=PyrexTypes.py_object_type, is_cdef=True)
import TypeSlots import TypeSlots
e = klass.declare_pyfunction('send', pos) e = klass.declare_pyfunction('send', pos)
......
...@@ -1172,6 +1172,7 @@ def p_raise_statement(s): ...@@ -1172,6 +1172,7 @@ def p_raise_statement(s):
exc_type = None exc_type = None
exc_value = None exc_value = None
exc_tb = None exc_tb = None
cause = None
if s.sy not in statement_terminators: if s.sy not in statement_terminators:
exc_type = p_test(s) exc_type = p_test(s)
if s.sy == ',': if s.sy == ',':
...@@ -1180,11 +1181,15 @@ def p_raise_statement(s): ...@@ -1180,11 +1181,15 @@ def p_raise_statement(s):
if s.sy == ',': if s.sy == ',':
s.next() s.next()
exc_tb = p_test(s) exc_tb = p_test(s)
elif s.sy == 'from':
s.next()
cause = p_test(s)
if exc_type or exc_value or exc_tb: if exc_type or exc_value or exc_tb:
return Nodes.RaiseStatNode(pos, return Nodes.RaiseStatNode(pos,
exc_type = exc_type, exc_type = exc_type,
exc_value = exc_value, exc_value = exc_value,
exc_tb = exc_tb) exc_tb = exc_tb,
cause = cause)
else: else:
return Nodes.ReraiseStatNode(pos) return Nodes.ReraiseStatNode(pos)
...@@ -1660,15 +1665,27 @@ def p_simple_statement_list(s, ctx, first_statement = 0): ...@@ -1660,15 +1665,27 @@ def p_simple_statement_list(s, ctx, first_statement = 0):
# Parse a series of simple statements on one line # Parse a series of simple statements on one line
# separated by semicolons. # separated by semicolons.
stat = p_simple_statement(s, first_statement = first_statement) stat = p_simple_statement(s, first_statement = first_statement)
if s.sy == ';': pos = stat.pos
stats = [stat] stats = []
while s.sy == ';': if not isinstance(stat, Nodes.PassStatNode):
#print "p_simple_statement_list: maybe more to follow" ### stats.append(stat)
s.next() while s.sy == ';':
if s.sy in ('NEWLINE', 'EOF'): #print "p_simple_statement_list: maybe more to follow" ###
break s.next()
stats.append(p_simple_statement(s)) if s.sy in ('NEWLINE', 'EOF'):
stat = Nodes.StatListNode(stats[0].pos, stats = stats) break
stat = p_simple_statement(s, first_statement = first_statement)
if isinstance(stat, Nodes.PassStatNode):
continue
stats.append(stat)
first_statement = False
if not stats:
stat = Nodes.PassStatNode(pos)
elif len(stats) == 1:
stat = stats[0]
else:
stat = Nodes.StatListNode(pos, stats = stats)
s.expect_newline("Syntax error in simple statement list") s.expect_newline("Syntax error in simple statement list")
return stat return stat
...@@ -1805,9 +1822,14 @@ def p_statement_list(s, ctx, first_statement = 0): ...@@ -1805,9 +1822,14 @@ def p_statement_list(s, ctx, first_statement = 0):
pos = s.position() pos = s.position()
stats = [] stats = []
while s.sy not in ('DEDENT', 'EOF'): while s.sy not in ('DEDENT', 'EOF'):
stats.append(p_statement(s, ctx, first_statement = first_statement)) stat = p_statement(s, ctx, first_statement = first_statement)
first_statement = 0 if isinstance(stat, Nodes.PassStatNode):
if len(stats) == 1: continue
stats.append(stat)
first_statement = False
if not stats:
return Nodes.PassStatNode(pos)
elif len(stats) == 1:
return stats[0] return stats[0]
else: else:
return Nodes.StatListNode(pos, stats = stats) return Nodes.StatListNode(pos, stats = stats)
...@@ -2523,7 +2545,7 @@ def p_c_struct_or_union_definition(s, pos, ctx): ...@@ -2523,7 +2545,7 @@ def p_c_struct_or_union_definition(s, pos, ctx):
s.expect_dedent() s.expect_dedent()
else: else:
s.expect_newline("Syntax error in struct or union definition") s.expect_newline("Syntax error in struct or union definition")
return Nodes.CStructOrUnionDefNode(pos, return Nodes.CStructOrUnionDefNode(pos,
name = name, cname = cname, kind = kind, attributes = attributes, name = name, cname = cname, kind = kind, attributes = attributes,
typedef_flag = ctx.typedef_flag, visibility = ctx.visibility, typedef_flag = ctx.typedef_flag, visibility = ctx.visibility,
api = ctx.api, in_pxd = ctx.level == 'module_pxd', packed = packed) api = ctx.api, in_pxd = ctx.level == 'module_pxd', packed = packed)
......
...@@ -22,10 +22,10 @@ class BaseType(object): ...@@ -22,10 +22,10 @@ class BaseType(object):
def cast_code(self, expr_code): def cast_code(self, expr_code):
return "((%s)%s)" % (self.declaration_code(""), expr_code) return "((%s)%s)" % (self.declaration_code(""), expr_code)
def specialization_name(self): def specialization_name(self):
return self.declaration_code("").replace(" ", "__") return self.declaration_code("").replace(" ", "__")
def base_declaration_code(self, base_code, entity_code): def base_declaration_code(self, base_code, entity_code):
if entity_code: if entity_code:
return "%s %s" % (base_code, entity_code) return "%s %s" % (base_code, entity_code)
...@@ -98,7 +98,7 @@ class PyrexType(BaseType): ...@@ -98,7 +98,7 @@ class PyrexType(BaseType):
# default_value string Initial value # default_value string Initial value
# entry Entry The Entry for this type # entry Entry The Entry for this type
# #
# declaration_code(entity_code, # declaration_code(entity_code,
# for_display = 0, dll_linkage = None, pyrex = 0) # for_display = 0, dll_linkage = None, pyrex = 0)
# Returns a code fragment for the declaration of an entity # Returns a code fragment for the declaration of an entity
# of this type, given a code fragment for the entity. # of this type, given a code fragment for the entity.
...@@ -122,7 +122,7 @@ class PyrexType(BaseType): ...@@ -122,7 +122,7 @@ class PyrexType(BaseType):
# Coerces array type into pointer type for use as # Coerces array type into pointer type for use as
# a formal argument type. # a formal argument type.
# #
is_pyobject = 0 is_pyobject = 0
is_unspecified = 0 is_unspecified = 0
is_extension_type = 0 is_extension_type = 0
...@@ -150,44 +150,44 @@ class PyrexType(BaseType): ...@@ -150,44 +150,44 @@ class PyrexType(BaseType):
is_buffer = 0 is_buffer = 0
has_attributes = 0 has_attributes = 0
default_value = "" default_value = ""
def resolve(self): def resolve(self):
# If a typedef, returns the base type. # If a typedef, returns the base type.
return self return self
def specialize(self, values): def specialize(self, values):
# TODO(danilo): Override wherever it makes sense. # TODO(danilo): Override wherever it makes sense.
return self return self
def literal_code(self, value): def literal_code(self, value):
# Returns a C code fragment representing a literal # Returns a C code fragment representing a literal
# value of this type. # value of this type.
return str(value) return str(value)
def __str__(self): def __str__(self):
return self.declaration_code("", for_display = 1).strip() return self.declaration_code("", for_display = 1).strip()
def same_as(self, other_type, **kwds): def same_as(self, other_type, **kwds):
return self.same_as_resolved_type(other_type.resolve(), **kwds) return self.same_as_resolved_type(other_type.resolve(), **kwds)
def same_as_resolved_type(self, other_type): def same_as_resolved_type(self, other_type):
return self == other_type or other_type is error_type return self == other_type or other_type is error_type
def subtype_of(self, other_type): def subtype_of(self, other_type):
return self.subtype_of_resolved_type(other_type.resolve()) return self.subtype_of_resolved_type(other_type.resolve())
def subtype_of_resolved_type(self, other_type): def subtype_of_resolved_type(self, other_type):
return self.same_as(other_type) return self.same_as(other_type)
def assignable_from(self, src_type): def assignable_from(self, src_type):
return self.assignable_from_resolved_type(src_type.resolve()) return self.assignable_from_resolved_type(src_type.resolve())
def assignable_from_resolved_type(self, src_type): def assignable_from_resolved_type(self, src_type):
return self.same_as(src_type) return self.same_as(src_type)
def as_argument_type(self): def as_argument_type(self):
return self return self
def is_complete(self): def is_complete(self):
# A type is incomplete if it is an unsized array, # A type is incomplete if it is an unsized array,
# a struct whose attributes are not defined, etc. # a struct whose attributes are not defined, etc.
...@@ -209,7 +209,7 @@ def public_decl(base_code, dll_linkage): ...@@ -209,7 +209,7 @@ def public_decl(base_code, dll_linkage):
return "%s(%s)" % (dll_linkage, base_code) return "%s(%s)" % (dll_linkage, base_code)
else: else:
return base_code return base_code
def create_typedef_type(name, base_type, cname, is_external=0): def create_typedef_type(name, base_type, cname, is_external=0):
if base_type.is_complex: if base_type.is_complex:
if is_external: if is_external:
...@@ -231,7 +231,7 @@ class CTypedefType(BaseType): ...@@ -231,7 +231,7 @@ class CTypedefType(BaseType):
# typedef_cname string # typedef_cname string
# typedef_base_type PyrexType # typedef_base_type PyrexType
# typedef_is_external bool # typedef_is_external bool
is_typedef = 1 is_typedef = 1
typedef_is_external = 0 typedef_is_external = 0
...@@ -239,31 +239,31 @@ class CTypedefType(BaseType): ...@@ -239,31 +239,31 @@ class CTypedefType(BaseType):
from_py_utility_code = None from_py_utility_code = None
subtypes = ['typedef_base_type'] subtypes = ['typedef_base_type']
def __init__(self, name, base_type, cname, is_external=0): def __init__(self, name, base_type, cname, is_external=0):
assert not base_type.is_complex assert not base_type.is_complex
self.typedef_name = name self.typedef_name = name
self.typedef_cname = cname self.typedef_cname = cname
self.typedef_base_type = base_type self.typedef_base_type = base_type
self.typedef_is_external = is_external self.typedef_is_external = is_external
def resolve(self): def resolve(self):
return self.typedef_base_type.resolve() return self.typedef_base_type.resolve()
def declaration_code(self, entity_code, def declaration_code(self, entity_code,
for_display = 0, dll_linkage = None, pyrex = 0): for_display = 0, dll_linkage = None, pyrex = 0):
if pyrex or for_display: if pyrex or for_display:
base_code = self.typedef_name base_code = self.typedef_name
else: else:
base_code = public_decl(self.typedef_cname, dll_linkage) base_code = public_decl(self.typedef_cname, dll_linkage)
return self.base_declaration_code(base_code, entity_code) return self.base_declaration_code(base_code, entity_code)
def as_argument_type(self): def as_argument_type(self):
return self return self
def cast_code(self, expr_code): def cast_code(self, expr_code):
# If self is really an array (rather than pointer), we can't cast. # If self is really an array (rather than pointer), we can't cast.
# For example, the gmp mpz_t. # For example, the gmp mpz_t.
if self.typedef_base_type.is_array: if self.typedef_base_type.is_array:
base_type = self.typedef_base_type.base_type base_type = self.typedef_base_type.base_type
return CPtrType(base_type).cast_code(expr_code) return CPtrType(base_type).cast_code(expr_code)
...@@ -272,7 +272,7 @@ class CTypedefType(BaseType): ...@@ -272,7 +272,7 @@ class CTypedefType(BaseType):
def __repr__(self): def __repr__(self):
return "<CTypedefType %s>" % self.typedef_cname return "<CTypedefType %s>" % self.typedef_cname
def __str__(self): def __str__(self):
return self.typedef_name return self.typedef_name
...@@ -346,7 +346,7 @@ class BufferType(BaseType): ...@@ -346,7 +346,7 @@ class BufferType(BaseType):
# Delegates most attribute # Delegates most attribute
# lookups to the base type. ANYTHING NOT DEFINED # lookups to the base type. ANYTHING NOT DEFINED
# HERE IS DELEGATED! # HERE IS DELEGATED!
# dtype PyrexType # dtype PyrexType
# ndim int # ndim int
# mode str # mode str
...@@ -368,7 +368,7 @@ class BufferType(BaseType): ...@@ -368,7 +368,7 @@ class BufferType(BaseType):
self.mode = mode self.mode = mode
self.negative_indices = negative_indices self.negative_indices = negative_indices
self.cast = cast self.cast = cast
def as_argument_type(self): def as_argument_type(self):
return self return self
...@@ -394,7 +394,7 @@ class PyObjectType(PyrexType): ...@@ -394,7 +394,7 @@ class PyObjectType(PyrexType):
def __str__(self): def __str__(self):
return "Python object" return "Python object"
def __repr__(self): def __repr__(self):
return "<PyObjectType>" return "<PyObjectType>"
...@@ -408,8 +408,8 @@ class PyObjectType(PyrexType): ...@@ -408,8 +408,8 @@ class PyObjectType(PyrexType):
def assignable_from(self, src_type): def assignable_from(self, src_type):
# except for pointers, conversion will be attempted # except for pointers, conversion will be attempted
return not src_type.is_ptr or src_type.is_string return not src_type.is_ptr or src_type.is_string
def declaration_code(self, entity_code, def declaration_code(self, entity_code,
for_display = 0, dll_linkage = None, pyrex = 0): for_display = 0, dll_linkage = None, pyrex = 0):
if pyrex or for_display: if pyrex or for_display:
base_code = "object" base_code = "object"
...@@ -444,15 +444,15 @@ class BuiltinObjectType(PyObjectType): ...@@ -444,15 +444,15 @@ class BuiltinObjectType(PyObjectType):
self.cname = cname self.cname = cname
self.typeptr_cname = "(&%s)" % cname self.typeptr_cname = "(&%s)" % cname
self.objstruct_cname = objstruct_cname self.objstruct_cname = objstruct_cname
def set_scope(self, scope): def set_scope(self, scope):
self.scope = scope self.scope = scope
if scope: if scope:
scope.parent_type = self scope.parent_type = self
def __str__(self): def __str__(self):
return "%s object" % self.name return "%s object" % self.name
def __repr__(self): def __repr__(self):
return "<%s>"% self.cname return "<%s>"% self.cname
...@@ -473,13 +473,13 @@ class BuiltinObjectType(PyObjectType): ...@@ -473,13 +473,13 @@ class BuiltinObjectType(PyObjectType):
src_type.name == self.name) src_type.name == self.name)
else: else:
return True return True
def typeobj_is_available(self): def typeobj_is_available(self):
return True return True
def attributes_known(self): def attributes_known(self):
return True return True
def subtype_of(self, type): def subtype_of(self, type):
return type.is_pyobject and self.assignable_from(type) return type.is_pyobject and self.assignable_from(type)
...@@ -506,7 +506,7 @@ class BuiltinObjectType(PyObjectType): ...@@ -506,7 +506,7 @@ class BuiltinObjectType(PyObjectType):
error = '(PyErr_Format(PyExc_TypeError, "Expected %s, got %%.200s", Py_TYPE(%s)->tp_name), 0)' % (self.name, arg) error = '(PyErr_Format(PyExc_TypeError, "Expected %s, got %%.200s", Py_TYPE(%s)->tp_name), 0)' % (self.name, arg)
return check + '||' + error return check + '||' + error
def declaration_code(self, entity_code, def declaration_code(self, entity_code,
for_display = 0, dll_linkage = None, pyrex = 0): for_display = 0, dll_linkage = None, pyrex = 0):
if pyrex or for_display: if pyrex or for_display:
base_code = self.name base_code = self.name
...@@ -539,12 +539,12 @@ class PyExtensionType(PyObjectType): ...@@ -539,12 +539,12 @@ class PyExtensionType(PyObjectType):
# vtabstruct_cname string Name of C method table struct # vtabstruct_cname string Name of C method table struct
# vtabptr_cname string Name of pointer to C method table # vtabptr_cname string Name of pointer to C method table
# vtable_cname string Name of C method table definition # vtable_cname string Name of C method table definition
is_extension_type = 1 is_extension_type = 1
has_attributes = 1 has_attributes = 1
objtypedef_cname = None objtypedef_cname = None
def __init__(self, name, typedef_flag, base_type, is_external=0): def __init__(self, name, typedef_flag, base_type, is_external=0):
self.name = name self.name = name
self.scope = None self.scope = None
...@@ -561,28 +561,28 @@ class PyExtensionType(PyObjectType): ...@@ -561,28 +561,28 @@ class PyExtensionType(PyObjectType):
self.vtabptr_cname = None self.vtabptr_cname = None
self.vtable_cname = None self.vtable_cname = None
self.is_external = is_external self.is_external = is_external
def set_scope(self, scope): def set_scope(self, scope):
self.scope = scope self.scope = scope
if scope: if scope:
scope.parent_type = self scope.parent_type = self
def subtype_of_resolved_type(self, other_type): def subtype_of_resolved_type(self, other_type):
if other_type.is_extension_type: if other_type.is_extension_type:
return self is other_type or ( return self is other_type or (
self.base_type and self.base_type.subtype_of(other_type)) self.base_type and self.base_type.subtype_of(other_type))
else: else:
return other_type is py_object_type return other_type is py_object_type
def typeobj_is_available(self): def typeobj_is_available(self):
# Do we have a pointer to the type object? # Do we have a pointer to the type object?
return self.typeptr_cname return self.typeptr_cname
def typeobj_is_imported(self): def typeobj_is_imported(self):
# If we don't know the C name of the type object but we do # If we don't know the C name of the type object but we do
# know which module it's defined in, it will be imported. # know which module it's defined in, it will be imported.
return self.typeobj_cname is None and self.module_name is not None return self.typeobj_cname is None and self.module_name is not None
def assignable_from(self, src_type): def assignable_from(self, src_type):
if self == src_type: if self == src_type:
return True return True
...@@ -591,7 +591,7 @@ class PyExtensionType(PyObjectType): ...@@ -591,7 +591,7 @@ class PyExtensionType(PyObjectType):
return self.assignable_from(src_type.base_type) return self.assignable_from(src_type.base_type)
return False return False
def declaration_code(self, entity_code, def declaration_code(self, entity_code,
for_display = 0, dll_linkage = None, pyrex = 0, deref = 0): for_display = 0, dll_linkage = None, pyrex = 0, deref = 0):
if pyrex or for_display: if pyrex or for_display:
base_code = self.name base_code = self.name
...@@ -619,14 +619,14 @@ class PyExtensionType(PyObjectType): ...@@ -619,14 +619,14 @@ class PyExtensionType(PyObjectType):
def attributes_known(self): def attributes_known(self):
return self.scope is not None return self.scope is not None
def __str__(self): def __str__(self):
return self.name return self.name
def __repr__(self): def __repr__(self):
return "<PyExtensionType %s%s>" % (self.scope.class_name, return "<PyExtensionType %s%s>" % (self.scope.class_name,
("", " typedef")[self.typedef_flag]) ("", " typedef")[self.typedef_flag])
class CType(PyrexType): class CType(PyrexType):
# #
...@@ -635,7 +635,7 @@ class CType(PyrexType): ...@@ -635,7 +635,7 @@ class CType(PyrexType):
# to_py_function string C function for converting to Python object # to_py_function string C function for converting to Python object
# from_py_function string C function for constructing from Python object # from_py_function string C function for constructing from Python object
# #
to_py_function = None to_py_function = None
from_py_function = None from_py_function = None
exception_value = None exception_value = None
...@@ -643,7 +643,7 @@ class CType(PyrexType): ...@@ -643,7 +643,7 @@ class CType(PyrexType):
def create_to_py_utility_code(self, env): def create_to_py_utility_code(self, env):
return self.to_py_function is not None return self.to_py_function is not None
def create_from_py_utility_code(self, env): def create_from_py_utility_code(self, env):
return self.from_py_function is not None return self.from_py_function is not None
...@@ -709,18 +709,18 @@ class CVoidType(CType): ...@@ -709,18 +709,18 @@ class CVoidType(CType):
# #
is_void = 1 is_void = 1
def __repr__(self): def __repr__(self):
return "<CVoidType>" return "<CVoidType>"
def declaration_code(self, entity_code, def declaration_code(self, entity_code,
for_display = 0, dll_linkage = None, pyrex = 0): for_display = 0, dll_linkage = None, pyrex = 0):
if pyrex or for_display: if pyrex or for_display:
base_code = "void" base_code = "void"
else: else:
base_code = public_decl("void", dll_linkage) base_code = public_decl("void", dll_linkage)
return self.base_declaration_code(base_code, entity_code) return self.base_declaration_code(base_code, entity_code)
def is_complete(self): def is_complete(self):
return 0 return 0
...@@ -732,27 +732,27 @@ class CNumericType(CType): ...@@ -732,27 +732,27 @@ class CNumericType(CType):
# rank integer Relative size # rank integer Relative size
# signed integer 0 = unsigned, 1 = unspecified, 2 = explicitly signed # signed integer 0 = unsigned, 1 = unspecified, 2 = explicitly signed
# #
is_numeric = 1 is_numeric = 1
default_value = "0" default_value = "0"
has_attributes = True has_attributes = True
scope = None scope = None
sign_words = ("unsigned ", "", "signed ") sign_words = ("unsigned ", "", "signed ")
def __init__(self, rank, signed = 1): def __init__(self, rank, signed = 1):
self.rank = rank self.rank = rank
self.signed = signed self.signed = signed
def sign_and_name(self): def sign_and_name(self):
s = self.sign_words[self.signed] s = self.sign_words[self.signed]
n = rank_to_type_name[self.rank] n = rank_to_type_name[self.rank]
return s + n return s + n
def __repr__(self): def __repr__(self):
return "<CNumericType %s>" % self.sign_and_name() return "<CNumericType %s>" % self.sign_and_name()
def declaration_code(self, entity_code, def declaration_code(self, entity_code,
for_display = 0, dll_linkage = None, pyrex = 0): for_display = 0, dll_linkage = None, pyrex = 0):
type_name = self.sign_and_name() type_name = self.sign_and_name()
if pyrex or for_display: if pyrex or for_display:
...@@ -760,7 +760,7 @@ class CNumericType(CType): ...@@ -760,7 +760,7 @@ class CNumericType(CType):
else: else:
base_code = public_decl(type_name, dll_linkage) base_code = public_decl(type_name, dll_linkage)
return self.base_declaration_code(base_code, entity_code) return self.base_declaration_code(base_code, entity_code)
def attributes_known(self): def attributes_known(self):
if self.scope is None: if self.scope is None:
import Symtab import Symtab
...@@ -931,7 +931,7 @@ static CYTHON_INLINE PyObject *__Pyx_PyInt_to_py_%(TypeName)s(%(type)s val) { ...@@ -931,7 +931,7 @@ static CYTHON_INLINE PyObject *__Pyx_PyInt_to_py_%(TypeName)s(%(type)s val) {
} else { } else {
int one = 1; int little = (int)*(unsigned char *)&one; int one = 1; int little = (int)*(unsigned char *)&one;
unsigned char *bytes = (unsigned char *)&val; unsigned char *bytes = (unsigned char *)&val;
return _PyLong_FromByteArray(bytes, sizeof(%(type)s), return _PyLong_FromByteArray(bytes, sizeof(%(type)s),
little, !is_unsigned); little, !is_unsigned);
} }
} }
...@@ -1186,22 +1186,22 @@ class CFloatType(CNumericType): ...@@ -1186,22 +1186,22 @@ class CFloatType(CNumericType):
from_py_function = "__pyx_PyFloat_AsDouble" from_py_function = "__pyx_PyFloat_AsDouble"
exception_value = -1 exception_value = -1
def __init__(self, rank, math_h_modifier = ''): def __init__(self, rank, math_h_modifier = ''):
CNumericType.__init__(self, rank, 1) CNumericType.__init__(self, rank, 1)
self.math_h_modifier = math_h_modifier self.math_h_modifier = math_h_modifier
def assignable_from_resolved_type(self, src_type): def assignable_from_resolved_type(self, src_type):
return (src_type.is_numeric and not src_type.is_complex) or src_type is error_type return (src_type.is_numeric and not src_type.is_complex) or src_type is error_type
class CComplexType(CNumericType): class CComplexType(CNumericType):
is_complex = 1 is_complex = 1
to_py_function = "__pyx_PyComplex_FromComplex" to_py_function = "__pyx_PyComplex_FromComplex"
has_attributes = 1 has_attributes = 1
scope = None scope = None
def __init__(self, real_type): def __init__(self, real_type):
while real_type.is_typedef and not real_type.typedef_is_external: while real_type.is_typedef and not real_type.typedef_is_external:
real_type = real_type.typedef_base_type real_type = real_type.typedef_base_type
...@@ -1213,7 +1213,7 @@ class CComplexType(CNumericType): ...@@ -1213,7 +1213,7 @@ class CComplexType(CNumericType):
self.funcsuffix = real_type.math_h_modifier self.funcsuffix = real_type.math_h_modifier
else: else:
self.funcsuffix = "_%s" % real_type.specialization_name() self.funcsuffix = "_%s" % real_type.specialization_name()
self.real_type = real_type self.real_type = real_type
CNumericType.__init__(self, real_type.rank + 0.5, real_type.signed) CNumericType.__init__(self, real_type.rank + 0.5, real_type.signed)
self.binops = {} self.binops = {}
...@@ -1225,7 +1225,7 @@ class CComplexType(CNumericType): ...@@ -1225,7 +1225,7 @@ class CComplexType(CNumericType):
return self.real_type == other.real_type return self.real_type == other.real_type
else: else:
return False return False
def __ne__(self, other): def __ne__(self, other):
if isinstance(self, CComplexType) and isinstance(other, CComplexType): if isinstance(self, CComplexType) and isinstance(other, CComplexType):
return self.real_type != other.real_type return self.real_type != other.real_type
...@@ -1243,7 +1243,7 @@ class CComplexType(CNumericType): ...@@ -1243,7 +1243,7 @@ class CComplexType(CNumericType):
def __hash__(self): def __hash__(self):
return ~hash(self.real_type) return ~hash(self.real_type)
def declaration_code(self, entity_code, def declaration_code(self, entity_code,
for_display = 0, dll_linkage = None, pyrex = 0): for_display = 0, dll_linkage = None, pyrex = 0):
if pyrex or for_display: if pyrex or for_display:
real_code = self.real_type.declaration_code("", for_display, dll_linkage, pyrex) real_code = self.real_type.declaration_code("", for_display, dll_linkage, pyrex)
...@@ -1257,7 +1257,7 @@ class CComplexType(CNumericType): ...@@ -1257,7 +1257,7 @@ class CComplexType(CNumericType):
real_type_name = real_type_name.replace('long__double','long_double') real_type_name = real_type_name.replace('long__double','long_double')
real_type_name = real_type_name.replace('PY_LONG_LONG','long_long') real_type_name = real_type_name.replace('PY_LONG_LONG','long_long')
return Naming.type_prefix + real_type_name + "_complex" return Naming.type_prefix + real_type_name + "_complex"
def assignable_from(self, src_type): def assignable_from(self, src_type):
# Temporary hack/feature disabling, see #441 # Temporary hack/feature disabling, see #441
if (not src_type.is_complex and src_type.is_numeric and src_type.is_typedef if (not src_type.is_complex and src_type.is_numeric and src_type.is_typedef
...@@ -1265,12 +1265,12 @@ class CComplexType(CNumericType): ...@@ -1265,12 +1265,12 @@ class CComplexType(CNumericType):
return False return False
else: else:
return super(CComplexType, self).assignable_from(src_type) return super(CComplexType, self).assignable_from(src_type)
def assignable_from_resolved_type(self, src_type): def assignable_from_resolved_type(self, src_type):
return (src_type.is_complex and self.real_type.assignable_from_resolved_type(src_type.real_type) return (src_type.is_complex and self.real_type.assignable_from_resolved_type(src_type.real_type)
or src_type.is_numeric and self.real_type.assignable_from_resolved_type(src_type) or src_type.is_numeric and self.real_type.assignable_from_resolved_type(src_type)
or src_type is error_type) or src_type is error_type)
def attributes_known(self): def attributes_known(self):
if self.scope is None: if self.scope is None:
import Symtab import Symtab
...@@ -1301,7 +1301,7 @@ class CComplexType(CNumericType): ...@@ -1301,7 +1301,7 @@ class CComplexType(CNumericType):
complex_arithmetic_utility_code): complex_arithmetic_utility_code):
env.use_utility_code( env.use_utility_code(
utility_code.specialize( utility_code.specialize(
self, self,
real_type = self.real_type.declaration_code(''), real_type = self.real_type.declaration_code(''),
m = self.funcsuffix, m = self.funcsuffix,
is_float = self.real_type.is_float)) is_float = self.real_type.is_float))
...@@ -1319,13 +1319,13 @@ class CComplexType(CNumericType): ...@@ -1319,13 +1319,13 @@ class CComplexType(CNumericType):
complex_from_py_utility_code): complex_from_py_utility_code):
env.use_utility_code( env.use_utility_code(
utility_code.specialize( utility_code.specialize(
self, self,
real_type = self.real_type.declaration_code(''), real_type = self.real_type.declaration_code(''),
m = self.funcsuffix, m = self.funcsuffix,
is_float = self.real_type.is_float)) is_float = self.real_type.is_float))
self.from_py_function = "__Pyx_PyComplex_As_" + self.specialization_name() self.from_py_function = "__Pyx_PyComplex_As_" + self.specialization_name()
return True return True
def lookup_op(self, nargs, op): def lookup_op(self, nargs, op):
try: try:
return self.binops[nargs, op] return self.binops[nargs, op]
...@@ -1340,10 +1340,10 @@ class CComplexType(CNumericType): ...@@ -1340,10 +1340,10 @@ class CComplexType(CNumericType):
def unary_op(self, op): def unary_op(self, op):
return self.lookup_op(1, op) return self.lookup_op(1, op)
def binary_op(self, op): def binary_op(self, op):
return self.lookup_op(2, op) return self.lookup_op(2, op)
complex_ops = { complex_ops = {
(1, '-'): 'neg', (1, '-'): 'neg',
(1, 'zero'): 'is_zero', (1, 'zero'): 'is_zero',
...@@ -1614,7 +1614,7 @@ impl=""" ...@@ -1614,7 +1614,7 @@ impl="""
class CArrayType(CType): class CArrayType(CType):
# base_type CType Element type # base_type CType Element type
# size integer or None Number of elements # size integer or None Number of elements
is_array = 1 is_array = 1
subtypes = ['base_type'] subtypes = ['base_type']
...@@ -1624,23 +1624,23 @@ class CArrayType(CType): ...@@ -1624,23 +1624,23 @@ class CArrayType(CType):
self.size = size self.size = size
if base_type in (c_char_type, c_uchar_type, c_schar_type): if base_type in (c_char_type, c_uchar_type, c_schar_type):
self.is_string = 1 self.is_string = 1
def __repr__(self): def __repr__(self):
return "<CArrayType %s %s>" % (self.size, repr(self.base_type)) return "<CArrayType %s %s>" % (self.size, repr(self.base_type))
def same_as_resolved_type(self, other_type): def same_as_resolved_type(self, other_type):
return ((other_type.is_array and return ((other_type.is_array and
self.base_type.same_as(other_type.base_type)) self.base_type.same_as(other_type.base_type))
or other_type is error_type) or other_type is error_type)
def assignable_from_resolved_type(self, src_type): def assignable_from_resolved_type(self, src_type):
# Can't assign to a variable of an array type # Can't assign to a variable of an array type
return 0 return 0
def element_ptr_type(self): def element_ptr_type(self):
return c_ptr_type(self.base_type) return c_ptr_type(self.base_type)
def declaration_code(self, entity_code, def declaration_code(self, entity_code,
for_display = 0, dll_linkage = None, pyrex = 0): for_display = 0, dll_linkage = None, pyrex = 0):
if self.size is not None: if self.size is not None:
dimension_code = self.size dimension_code = self.size
...@@ -1651,40 +1651,40 @@ class CArrayType(CType): ...@@ -1651,40 +1651,40 @@ class CArrayType(CType):
return self.base_type.declaration_code( return self.base_type.declaration_code(
"%s[%s]" % (entity_code, dimension_code), "%s[%s]" % (entity_code, dimension_code),
for_display, dll_linkage, pyrex) for_display, dll_linkage, pyrex)
def as_argument_type(self): def as_argument_type(self):
return c_ptr_type(self.base_type) return c_ptr_type(self.base_type)
def is_complete(self): def is_complete(self):
return self.size is not None return self.size is not None
class CPtrType(CType): class CPtrType(CType):
# base_type CType Referenced type # base_type CType Referenced type
is_ptr = 1 is_ptr = 1
default_value = "0" default_value = "0"
subtypes = ['base_type'] subtypes = ['base_type']
def __init__(self, base_type): def __init__(self, base_type):
self.base_type = base_type self.base_type = base_type
def __repr__(self): def __repr__(self):
return "<CPtrType %s>" % repr(self.base_type) return "<CPtrType %s>" % repr(self.base_type)
def same_as_resolved_type(self, other_type): def same_as_resolved_type(self, other_type):
return ((other_type.is_ptr and return ((other_type.is_ptr and
self.base_type.same_as(other_type.base_type)) self.base_type.same_as(other_type.base_type))
or other_type is error_type) or other_type is error_type)
def declaration_code(self, entity_code, def declaration_code(self, entity_code,
for_display = 0, dll_linkage = None, pyrex = 0): for_display = 0, dll_linkage = None, pyrex = 0):
#print "CPtrType.declaration_code: pointer to", self.base_type ### #print "CPtrType.declaration_code: pointer to", self.base_type ###
return self.base_type.declaration_code( return self.base_type.declaration_code(
"*%s" % entity_code, "*%s" % entity_code,
for_display, dll_linkage, pyrex) for_display, dll_linkage, pyrex)
def assignable_from_resolved_type(self, other_type): def assignable_from_resolved_type(self, other_type):
if other_type is error_type: if other_type is error_type:
return 1 return 1
...@@ -1697,13 +1697,13 @@ class CPtrType(CType): ...@@ -1697,13 +1697,13 @@ class CPtrType(CType):
return self.base_type.pointer_assignable_from_resolved_type(other_type) return self.base_type.pointer_assignable_from_resolved_type(other_type)
else: else:
return 0 return 0
if (self.base_type.is_cpp_class and other_type.is_ptr if (self.base_type.is_cpp_class and other_type.is_ptr
and other_type.base_type.is_cpp_class and other_type.base_type.is_subclass(self.base_type)): and other_type.base_type.is_cpp_class and other_type.base_type.is_subclass(self.base_type)):
return 1 return 1
if other_type.is_array or other_type.is_ptr: if other_type.is_array or other_type.is_ptr:
return self.base_type.is_void or self.base_type.same_as(other_type.base_type) return self.base_type.is_void or self.base_type.same_as(other_type.base_type)
return 0 return 0
def specialize(self, values): def specialize(self, values):
base_type = self.base_type.specialize(values) base_type = self.base_type.specialize(values)
if base_type == self.base_type: if base_type == self.base_type:
...@@ -1715,7 +1715,7 @@ class CPtrType(CType): ...@@ -1715,7 +1715,7 @@ class CPtrType(CType):
class CNullPtrType(CPtrType): class CNullPtrType(CPtrType):
is_null_ptr = 1 is_null_ptr = 1
class CReferenceType(BaseType): class CReferenceType(BaseType):
...@@ -1726,20 +1726,20 @@ class CReferenceType(BaseType): ...@@ -1726,20 +1726,20 @@ class CReferenceType(BaseType):
def __repr__(self): def __repr__(self):
return "<CReferenceType %s>" % repr(self.ref_base_type) return "<CReferenceType %s>" % repr(self.ref_base_type)
def __str__(self): def __str__(self):
return "%s &" % self.ref_base_type return "%s &" % self.ref_base_type
def as_argument_type(self): def as_argument_type(self):
return self return self
def declaration_code(self, entity_code, def declaration_code(self, entity_code,
for_display = 0, dll_linkage = None, pyrex = 0): for_display = 0, dll_linkage = None, pyrex = 0):
#print "CReferenceType.declaration_code: pointer to", self.base_type ### #print "CReferenceType.declaration_code: pointer to", self.base_type ###
return self.ref_base_type.declaration_code( return self.ref_base_type.declaration_code(
"&%s" % entity_code, "&%s" % entity_code,
for_display, dll_linkage, pyrex) for_display, dll_linkage, pyrex)
def specialize(self, values): def specialize(self, values):
base_type = self.ref_base_type.specialize(values) base_type = self.ref_base_type.specialize(values)
if base_type == self.ref_base_type: if base_type == self.ref_base_type:
...@@ -1761,7 +1761,7 @@ class CFuncType(CType): ...@@ -1761,7 +1761,7 @@ class CFuncType(CType):
# nogil boolean Can be called without gil # nogil boolean Can be called without gil
# with_gil boolean Acquire gil around function body # with_gil boolean Acquire gil around function body
# templates [string] or None # templates [string] or None
is_cfunction = 1 is_cfunction = 1
original_sig = None original_sig = None
...@@ -1798,14 +1798,14 @@ class CFuncType(CType): ...@@ -1798,14 +1798,14 @@ class CFuncType(CType):
self.calling_convention_prefix(), self.calling_convention_prefix(),
",".join(arg_reprs), ",".join(arg_reprs),
except_clause) except_clause)
def calling_convention_prefix(self): def calling_convention_prefix(self):
cc = self.calling_convention cc = self.calling_convention
if cc: if cc:
return cc + " " return cc + " "
else: else:
return "" return ""
def same_c_signature_as(self, other_type, as_cmethod = 0): def same_c_signature_as(self, other_type, as_cmethod = 0):
return self.same_c_signature_as_resolved_type( return self.same_c_signature_as_resolved_type(
other_type.resolve(), as_cmethod) other_type.resolve(), as_cmethod)
...@@ -1841,7 +1841,7 @@ class CFuncType(CType): ...@@ -1841,7 +1841,7 @@ class CFuncType(CType):
def compatible_signature_with(self, other_type, as_cmethod = 0): def compatible_signature_with(self, other_type, as_cmethod = 0):
return self.compatible_signature_with_resolved_type(other_type.resolve(), as_cmethod) return self.compatible_signature_with_resolved_type(other_type.resolve(), as_cmethod)
def compatible_signature_with_resolved_type(self, other_type, as_cmethod): def compatible_signature_with_resolved_type(self, other_type, as_cmethod):
#print "CFuncType.same_c_signature_as_resolved_type:", \ #print "CFuncType.same_c_signature_as_resolved_type:", \
# self, other_type, "as_cmethod =", as_cmethod ### # self, other_type, "as_cmethod =", as_cmethod ###
...@@ -1875,11 +1875,11 @@ class CFuncType(CType): ...@@ -1875,11 +1875,11 @@ class CFuncType(CType):
if as_cmethod: if as_cmethod:
self.args[0] = other_type.args[0] self.args[0] = other_type.args[0]
return 1 return 1
def narrower_c_signature_than(self, other_type, as_cmethod = 0): def narrower_c_signature_than(self, other_type, as_cmethod = 0):
return self.narrower_c_signature_than_resolved_type(other_type.resolve(), as_cmethod) return self.narrower_c_signature_than_resolved_type(other_type.resolve(), as_cmethod)
def narrower_c_signature_than_resolved_type(self, other_type, as_cmethod): def narrower_c_signature_than_resolved_type(self, other_type, as_cmethod):
if other_type is error_type: if other_type is error_type:
return 1 return 1
...@@ -1915,7 +1915,7 @@ class CFuncType(CType): ...@@ -1915,7 +1915,7 @@ class CFuncType(CType):
sc1 = self.calling_convention == '__stdcall' sc1 = self.calling_convention == '__stdcall'
sc2 = other.calling_convention == '__stdcall' sc2 = other.calling_convention == '__stdcall'
return sc1 == sc2 return sc1 == sc2
def same_exception_signature_as(self, other_type): def same_exception_signature_as(self, other_type):
return self.same_exception_signature_as_resolved_type( return self.same_exception_signature_as_resolved_type(
other_type.resolve()) other_type.resolve())
...@@ -1923,18 +1923,18 @@ class CFuncType(CType): ...@@ -1923,18 +1923,18 @@ class CFuncType(CType):
def same_exception_signature_as_resolved_type(self, other_type): def same_exception_signature_as_resolved_type(self, other_type):
return self.exception_value == other_type.exception_value \ return self.exception_value == other_type.exception_value \
and self.exception_check == other_type.exception_check and self.exception_check == other_type.exception_check
def same_as_resolved_type(self, other_type, as_cmethod = 0): def same_as_resolved_type(self, other_type, as_cmethod = 0):
return self.same_c_signature_as_resolved_type(other_type, as_cmethod) \ return self.same_c_signature_as_resolved_type(other_type, as_cmethod) \
and self.same_exception_signature_as_resolved_type(other_type) \ and self.same_exception_signature_as_resolved_type(other_type) \
and self.nogil == other_type.nogil and self.nogil == other_type.nogil
def pointer_assignable_from_resolved_type(self, other_type): def pointer_assignable_from_resolved_type(self, other_type):
return self.same_c_signature_as_resolved_type(other_type) \ return self.same_c_signature_as_resolved_type(other_type) \
and self.same_exception_signature_as_resolved_type(other_type) \ and self.same_exception_signature_as_resolved_type(other_type) \
and not (self.nogil and not other_type.nogil) and not (self.nogil and not other_type.nogil)
def declaration_code(self, entity_code, def declaration_code(self, entity_code,
for_display = 0, dll_linkage = None, pyrex = 0, for_display = 0, dll_linkage = None, pyrex = 0,
with_calling_convention = 1): with_calling_convention = 1):
arg_decl_list = [] arg_decl_list = []
...@@ -1972,7 +1972,7 @@ class CFuncType(CType): ...@@ -1972,7 +1972,7 @@ class CFuncType(CType):
return self.return_type.declaration_code( return self.return_type.declaration_code(
"%s%s(%s)%s" % (cc, entity_code, arg_decl_code, trailer), "%s%s(%s)%s" % (cc, entity_code, arg_decl_code, trailer),
for_display, dll_linkage, pyrex) for_display, dll_linkage, pyrex)
def function_header_code(self, func_name, arg_code): def function_header_code(self, func_name, arg_code):
return "%s%s(%s)" % (self.calling_convention_prefix(), return "%s%s(%s)" % (self.calling_convention_prefix(),
func_name, arg_code) func_name, arg_code)
...@@ -1984,7 +1984,7 @@ class CFuncType(CType): ...@@ -1984,7 +1984,7 @@ class CFuncType(CType):
def signature_cast_string(self): def signature_cast_string(self):
s = self.declaration_code("(*)", with_calling_convention=False) s = self.declaration_code("(*)", with_calling_convention=False)
return '(%s)' % s return '(%s)' % s
def specialize(self, values): def specialize(self, values):
if self.templates is None: if self.templates is None:
new_templates = None new_templates = None
...@@ -2001,7 +2001,7 @@ class CFuncType(CType): ...@@ -2001,7 +2001,7 @@ class CFuncType(CType):
is_overridable = self.is_overridable, is_overridable = self.is_overridable,
optional_arg_count = self.optional_arg_count, optional_arg_count = self.optional_arg_count,
templates = new_templates) templates = new_templates)
def opt_arg_cname(self, arg_name): def opt_arg_cname(self, arg_name):
return self.op_arg_struct.base_type.scope.lookup(arg_name).cname return self.op_arg_struct.base_type.scope.lookup(arg_name).cname
...@@ -2127,13 +2127,13 @@ class CFuncTypeArg(BaseType): ...@@ -2127,13 +2127,13 @@ class CFuncTypeArg(BaseType):
self.type = type self.type = type
self.pos = pos self.pos = pos
self.needs_type_test = False # TODO: should these defaults be set in analyse_types()? self.needs_type_test = False # TODO: should these defaults be set in analyse_types()?
def __repr__(self): def __repr__(self):
return "%s:%s" % (self.name, repr(self.type)) return "%s:%s" % (self.name, repr(self.type))
def declaration_code(self, for_display = 0): def declaration_code(self, for_display = 0):
return self.type.declaration_code(self.cname, for_display) return self.type.declaration_code(self.cname, for_display)
def specialize(self, values): def specialize(self, values):
return CFuncTypeArg(self.name, self.type.specialize(values), self.pos, self.cname) return CFuncTypeArg(self.name, self.type.specialize(values), self.pos, self.cname)
...@@ -2147,11 +2147,11 @@ class StructUtilityCode(object): ...@@ -2147,11 +2147,11 @@ class StructUtilityCode(object):
return isinstance(other, StructUtilityCode) and self.header == other.header return isinstance(other, StructUtilityCode) and self.header == other.header
def __hash__(self): def __hash__(self):
return hash(self.header) return hash(self.header)
def put_code(self, output): def put_code(self, output):
code = output['utility_code_def'] code = output['utility_code_def']
proto = output['utility_code_proto'] proto = output['utility_code_proto']
code.putln("%s {" % self.header) code.putln("%s {" % self.header)
code.putln("PyObject* res;") code.putln("PyObject* res;")
code.putln("PyObject* member;") code.putln("PyObject* member;")
...@@ -2174,7 +2174,7 @@ class StructUtilityCode(object): ...@@ -2174,7 +2174,7 @@ class StructUtilityCode(object):
if self.forward_decl: if self.forward_decl:
proto.putln(self.type.declaration_code('') + ';') proto.putln(self.type.declaration_code('') + ';')
proto.putln(self.header + ";") proto.putln(self.header + ";")
class CStructOrUnionType(CType): class CStructOrUnionType(CType):
# name string # name string
...@@ -2183,12 +2183,12 @@ class CStructOrUnionType(CType): ...@@ -2183,12 +2183,12 @@ class CStructOrUnionType(CType):
# scope StructOrUnionScope, or None if incomplete # scope StructOrUnionScope, or None if incomplete
# typedef_flag boolean # typedef_flag boolean
# packed boolean # packed boolean
# entry Entry # entry Entry
is_struct_or_union = 1 is_struct_or_union = 1
has_attributes = 1 has_attributes = 1
def __init__(self, name, kind, scope, typedef_flag, cname, packed=False): def __init__(self, name, kind, scope, typedef_flag, cname, packed=False):
self.name = name self.name = name
self.cname = cname self.cname = cname
...@@ -2201,7 +2201,7 @@ class CStructOrUnionType(CType): ...@@ -2201,7 +2201,7 @@ class CStructOrUnionType(CType):
self.exception_check = True self.exception_check = True
self._convert_code = None self._convert_code = None
self.packed = packed self.packed = packed
def create_to_py_utility_code(self, env): def create_to_py_utility_code(self, env):
if env.outer_scope is None: if env.outer_scope is None:
return False return False
...@@ -2216,15 +2216,15 @@ class CStructOrUnionType(CType): ...@@ -2216,15 +2216,15 @@ class CStructOrUnionType(CType):
return False return False
forward_decl = (self.entry.visibility != 'extern') forward_decl = (self.entry.visibility != 'extern')
self._convert_code = StructUtilityCode(self, forward_decl) self._convert_code = StructUtilityCode(self, forward_decl)
env.use_utility_code(self._convert_code) env.use_utility_code(self._convert_code)
return True return True
def __repr__(self): def __repr__(self):
return "<CStructOrUnionType %s %s%s>" % (self.name, self.cname, return "<CStructOrUnionType %s %s%s>" % (self.name, self.cname,
("", " typedef")[self.typedef_flag]) ("", " typedef")[self.typedef_flag])
def declaration_code(self, entity_code, def declaration_code(self, entity_code,
for_display = 0, dll_linkage = None, pyrex = 0): for_display = 0, dll_linkage = None, pyrex = 0):
if pyrex or for_display: if pyrex or for_display:
base_code = self.name base_code = self.name
...@@ -2256,7 +2256,7 @@ class CStructOrUnionType(CType): ...@@ -2256,7 +2256,7 @@ class CStructOrUnionType(CType):
def is_complete(self): def is_complete(self):
return self.scope is not None return self.scope is not None
def attributes_known(self): def attributes_known(self):
return self.is_complete() return self.is_complete()
...@@ -2279,12 +2279,12 @@ class CppClassType(CType): ...@@ -2279,12 +2279,12 @@ class CppClassType(CType):
# cname string # cname string
# scope CppClassScope # scope CppClassScope
# templates [string] or None # templates [string] or None
is_cpp_class = 1 is_cpp_class = 1
has_attributes = 1 has_attributes = 1
exception_check = True exception_check = True
namespace = None namespace = None
def __init__(self, name, scope, cname, base_classes, templates = None, template_type = None): def __init__(self, name, scope, cname, base_classes, templates = None, template_type = None):
self.name = name self.name = name
self.cname = cname self.cname = cname
...@@ -2300,11 +2300,11 @@ class CppClassType(CType): ...@@ -2300,11 +2300,11 @@ class CppClassType(CType):
error(pos, "'%s' type is not a template" % self); error(pos, "'%s' type is not a template" % self);
return PyrexTypes.error_type return PyrexTypes.error_type
if len(self.templates) != len(template_values): if len(self.templates) != len(template_values):
error(pos, "%s templated type receives %d arguments, got %d" % error(pos, "%s templated type receives %d arguments, got %d" %
(self.name, len(self.templates), len(template_values))) (self.name, len(self.templates), len(template_values)))
return error_type return error_type
return self.specialize(dict(zip(self.templates, template_values))) return self.specialize(dict(zip(self.templates, template_values)))
def specialize(self, values): def specialize(self, values):
if not self.templates and not self.namespace: if not self.templates and not self.namespace:
return self return self
...@@ -2351,7 +2351,7 @@ class CppClassType(CType): ...@@ -2351,7 +2351,7 @@ class CppClassType(CType):
if base_class.is_subclass(other_type): if base_class.is_subclass(other_type):
return 1 return 1
return 0 return 0
def same_as_resolved_type(self, other_type): def same_as_resolved_type(self, other_type):
if other_type.is_cpp_class: if other_type.is_cpp_class:
if self == other_type: if self == other_type:
...@@ -2370,23 +2370,23 @@ class CppClassType(CType): ...@@ -2370,23 +2370,23 @@ class CppClassType(CType):
if other_type is error_type: if other_type is error_type:
return True return True
return other_type.is_cpp_class and other_type.is_subclass(self) return other_type.is_cpp_class and other_type.is_subclass(self)
def attributes_known(self): def attributes_known(self):
return self.scope is not None return self.scope is not None
class TemplatePlaceholderType(CType): class TemplatePlaceholderType(CType):
def __init__(self, name): def __init__(self, name):
self.name = name self.name = name
def declaration_code(self, entity_code, def declaration_code(self, entity_code,
for_display = 0, dll_linkage = None, pyrex = 0): for_display = 0, dll_linkage = None, pyrex = 0):
if entity_code: if entity_code:
return self.name + " " + entity_code return self.name + " " + entity_code
else: else:
return self.name return self.name
def specialize(self, values): def specialize(self, values):
if self in values: if self in values:
return values[self] return values[self]
...@@ -2398,10 +2398,10 @@ class TemplatePlaceholderType(CType): ...@@ -2398,10 +2398,10 @@ class TemplatePlaceholderType(CType):
return self.name == other_type.name return self.name == other_type.name
else: else:
return 0 return 0
def __hash__(self): def __hash__(self):
return hash(self.name) return hash(self.name)
def __cmp__(self, other): def __cmp__(self, other):
if isinstance(other, TemplatePlaceholderType): if isinstance(other, TemplatePlaceholderType):
return cmp(self.name, other.name) return cmp(self.name, other.name)
...@@ -2424,15 +2424,15 @@ class CEnumType(CType): ...@@ -2424,15 +2424,15 @@ class CEnumType(CType):
self.cname = cname self.cname = cname
self.values = [] self.values = []
self.typedef_flag = typedef_flag self.typedef_flag = typedef_flag
def __str__(self): def __str__(self):
return self.name return self.name
def __repr__(self): def __repr__(self):
return "<CEnumType %s %s%s>" % (self.name, self.cname, return "<CEnumType %s %s%s>" % (self.name, self.cname,
("", " typedef")[self.typedef_flag]) ("", " typedef")[self.typedef_flag])
def declaration_code(self, entity_code, def declaration_code(self, entity_code,
for_display = 0, dll_linkage = None, pyrex = 0): for_display = 0, dll_linkage = None, pyrex = 0):
if pyrex or for_display: if pyrex or for_display:
base_code = self.name base_code = self.name
...@@ -2450,7 +2450,7 @@ class CStringType(object): ...@@ -2450,7 +2450,7 @@ class CStringType(object):
is_string = 1 is_string = 1
is_unicode = 0 is_unicode = 0
to_py_function = "PyBytes_FromString" to_py_function = "PyBytes_FromString"
from_py_function = "PyBytes_AsString" from_py_function = "PyBytes_AsString"
exception_value = "NULL" exception_value = "NULL"
...@@ -2462,32 +2462,32 @@ class CStringType(object): ...@@ -2462,32 +2462,32 @@ class CStringType(object):
class CUTF8CharArrayType(CStringType, CArrayType): class CUTF8CharArrayType(CStringType, CArrayType):
# C 'char []' type. # C 'char []' type.
is_unicode = 1 is_unicode = 1
to_py_function = "PyUnicode_DecodeUTF8" to_py_function = "PyUnicode_DecodeUTF8"
exception_value = "NULL" exception_value = "NULL"
def __init__(self, size): def __init__(self, size):
CArrayType.__init__(self, c_char_type, size) CArrayType.__init__(self, c_char_type, size)
class CCharArrayType(CStringType, CArrayType): class CCharArrayType(CStringType, CArrayType):
# C 'char []' type. # C 'char []' type.
def __init__(self, size): def __init__(self, size):
CArrayType.__init__(self, c_char_type, size) CArrayType.__init__(self, c_char_type, size)
class CCharPtrType(CStringType, CPtrType): class CCharPtrType(CStringType, CPtrType):
# C 'char *' type. # C 'char *' type.
def __init__(self): def __init__(self):
CPtrType.__init__(self, c_char_type) CPtrType.__init__(self, c_char_type)
class CUCharPtrType(CStringType, CPtrType): class CUCharPtrType(CStringType, CPtrType):
# C 'unsigned char *' type. # C 'unsigned char *' type.
to_py_function = "__Pyx_PyBytes_FromUString" to_py_function = "__Pyx_PyBytes_FromUString"
from_py_function = "__Pyx_PyBytes_AsUString" from_py_function = "__Pyx_PyBytes_AsUString"
...@@ -2497,39 +2497,39 @@ class CUCharPtrType(CStringType, CPtrType): ...@@ -2497,39 +2497,39 @@ class CUCharPtrType(CStringType, CPtrType):
class UnspecifiedType(PyrexType): class UnspecifiedType(PyrexType):
# Used as a placeholder until the type can be determined. # Used as a placeholder until the type can be determined.
is_unspecified = 1 is_unspecified = 1
def declaration_code(self, entity_code, def declaration_code(self, entity_code,
for_display = 0, dll_linkage = None, pyrex = 0): for_display = 0, dll_linkage = None, pyrex = 0):
return "<unspecified>" return "<unspecified>"
def same_as_resolved_type(self, other_type): def same_as_resolved_type(self, other_type):
return False return False
class ErrorType(PyrexType): class ErrorType(PyrexType):
# Used to prevent propagation of error messages. # Used to prevent propagation of error messages.
is_error = 1 is_error = 1
exception_value = "0" exception_value = "0"
exception_check = 0 exception_check = 0
to_py_function = "dummy" to_py_function = "dummy"
from_py_function = "dummy" from_py_function = "dummy"
def create_to_py_utility_code(self, env): def create_to_py_utility_code(self, env):
return True return True
def create_from_py_utility_code(self, env): def create_from_py_utility_code(self, env):
return True return True
def declaration_code(self, entity_code, def declaration_code(self, entity_code,
for_display = 0, dll_linkage = None, pyrex = 0): for_display = 0, dll_linkage = None, pyrex = 0):
return "<error>" return "<error>"
def same_as_resolved_type(self, other_type): def same_as_resolved_type(self, other_type):
return 1 return 1
def error_condition(self, result_code): def error_condition(self, result_code):
return "dummy" return "dummy"
...@@ -2659,7 +2659,7 @@ modifiers_and_name_to_type = { ...@@ -2659,7 +2659,7 @@ modifiers_and_name_to_type = {
def is_promotion(src_type, dst_type): def is_promotion(src_type, dst_type):
# It's hard to find a hard definition of promotion, but empirical # It's hard to find a hard definition of promotion, but empirical
# evidence suggests that the below is all that's allowed. # evidence suggests that the below is all that's allowed.
if src_type.is_numeric: if src_type.is_numeric:
if dst_type.same_as(c_int_type): if dst_type.same_as(c_int_type):
unsigned = (not src_type.signed) unsigned = (not src_type.signed)
...@@ -2682,7 +2682,7 @@ def best_match(args, functions, pos=None, env=None): ...@@ -2682,7 +2682,7 @@ def best_match(args, functions, pos=None, env=None):
functions based on how much work must be done to convert the functions based on how much work must be done to convert the
arguments, with the following priorities: arguments, with the following priorities:
* identical types or pointers to identical types * identical types or pointers to identical types
* promotions * promotions
* non-Python types * non-Python types
That is, we prefer functions where no arguments need converted, That is, we prefer functions where no arguments need converted,
and failing that, functions where only promotions are required, and and failing that, functions where only promotions are required, and
...@@ -2692,7 +2692,7 @@ def best_match(args, functions, pos=None, env=None): ...@@ -2692,7 +2692,7 @@ def best_match(args, functions, pos=None, env=None):
the same weight, we return None (as there is no best match). If pos the same weight, we return None (as there is no best match). If pos
is not None, we also generate an error. is not None, we also generate an error.
""" """
# TODO: args should be a list of types, not a list of Nodes. # TODO: args should be a list of types, not a list of Nodes.
actual_nargs = len(args) actual_nargs = len(args)
candidates = [] candidates = []
...@@ -2724,7 +2724,7 @@ def best_match(args, functions, pos=None, env=None): ...@@ -2724,7 +2724,7 @@ def best_match(args, functions, pos=None, env=None):
errors.append((func, error_mesg)) errors.append((func, error_mesg))
continue continue
candidates.append((func, func_type)) candidates.append((func, func_type))
# Optimize the most common case of no overloading... # Optimize the most common case of no overloading...
if len(candidates) == 1: if len(candidates) == 1:
return candidates[0][0] return candidates[0][0]
...@@ -2735,11 +2735,12 @@ def best_match(args, functions, pos=None, env=None): ...@@ -2735,11 +2735,12 @@ def best_match(args, functions, pos=None, env=None):
else: else:
error(pos, "no suitable method found") error(pos, "no suitable method found")
return None return None
possibilities = [] possibilities = []
bad_types = [] bad_types = []
needed_coercions = {} needed_coercions = {}
for func, func_type in candidates:
for index, (func, func_type) in enumerate(candidates):
score = [0,0,0,0] score = [0,0,0,0]
for i in range(min(len(args), len(func_type.args))): for i in range(min(len(args), len(func_type.args))):
src_type = args[i].type src_type = args[i].type
...@@ -2782,7 +2783,7 @@ def best_match(args, functions, pos=None, env=None): ...@@ -2782,7 +2783,7 @@ def best_match(args, functions, pos=None, env=None):
bad_types.append((func, error_mesg)) bad_types.append((func, error_mesg))
break break
else: else:
possibilities.append((score, func)) # so we can sort it possibilities.append((score, index, func)) # so we can sort it
if possibilities: if possibilities:
possibilities.sort() possibilities.sort()
...@@ -2791,7 +2792,7 @@ def best_match(args, functions, pos=None, env=None): ...@@ -2791,7 +2792,7 @@ def best_match(args, functions, pos=None, env=None):
error(pos, "ambiguous overloaded method") error(pos, "ambiguous overloaded method")
return None return None
function = possibilities[0][1] function = possibilities[0][-1]
if function in needed_coercions and env: if function in needed_coercions and env:
arg_i, coerce_to_type = needed_coercions[function] arg_i, coerce_to_type = needed_coercions[function]
...@@ -2819,7 +2820,7 @@ def widest_numeric_type(type1, type2): ...@@ -2819,7 +2820,7 @@ def widest_numeric_type(type1, type2):
return ntype return ntype
widest_type = CComplexType( widest_type = CComplexType(
widest_numeric_type( widest_numeric_type(
real_type(type1), real_type(type1),
real_type(type2))) real_type(type2)))
elif type1.is_enum and type2.is_enum: elif type1.is_enum and type2.is_enum:
widest_type = c_int_type widest_type = c_int_type
...@@ -2911,7 +2912,7 @@ def simple_c_type(signed, longness, name): ...@@ -2911,7 +2912,7 @@ def simple_c_type(signed, longness, name):
# Find type descriptor for simple type given name and modifiers. # Find type descriptor for simple type given name and modifiers.
# Returns None if arguments don't make sense. # Returns None if arguments don't make sense.
return modifiers_and_name_to_type.get((signed, longness, name)) return modifiers_and_name_to_type.get((signed, longness, name))
def parse_basic_type(name): def parse_basic_type(name):
base = None base = None
if name.startswith('p_'): if name.startswith('p_'):
...@@ -2945,7 +2946,7 @@ def parse_basic_type(name): ...@@ -2945,7 +2946,7 @@ def parse_basic_type(name):
if name.startswith('u'): if name.startswith('u'):
name = name[1:] name = name[1:]
signed = 0 signed = 0
elif (name.startswith('s') and elif (name.startswith('s') and
not name.startswith('short')): not name.startswith('short')):
name = name[1:] name = name[1:]
signed = 2 signed = 2
...@@ -2989,7 +2990,7 @@ def c_ref_type(base_type): ...@@ -2989,7 +2990,7 @@ def c_ref_type(base_type):
def same_type(type1, type2): def same_type(type1, type2):
return type1.same_as(type2) return type1.same_as(type2)
def assignable_from(type1, type2): def assignable_from(type1, type2):
return type1.assignable_from(type2) return type1.assignable_from(type2)
......
...@@ -409,13 +409,13 @@ class Scope(object): ...@@ -409,13 +409,13 @@ class Scope(object):
except ValueError, e: except ValueError, e:
error(pos, e.args[0]) error(pos, e.args[0])
type = PyrexTypes.error_type type = PyrexTypes.error_type
entry = self.declare_type(name, type, pos, cname, entry = self.declare_type(name, type, pos, cname,
visibility = visibility, api = api) visibility = visibility, api = api)
type.qualified_name = entry.qualified_name type.qualified_name = entry.qualified_name
return entry return entry
def declare_struct_or_union(self, name, kind, scope, def declare_struct_or_union(self, name, kind, scope,
typedef_flag, pos, cname = None, typedef_flag, pos, cname = None,
visibility = 'private', api = 0, visibility = 'private', api = 0,
packed = False): packed = False):
# Add an entry for a struct or union definition. # Add an entry for a struct or union definition.
...@@ -496,7 +496,7 @@ class Scope(object): ...@@ -496,7 +496,7 @@ class Scope(object):
if entry.visibility != visibility: if entry.visibility != visibility:
error(pos, "'%s' previously declared as '%s'" % ( error(pos, "'%s' previously declared as '%s'" % (
entry.name, entry.visibility)) entry.name, entry.visibility))
def declare_enum(self, name, pos, cname, typedef_flag, def declare_enum(self, name, pos, cname, typedef_flag,
visibility = 'private', api = 0): visibility = 'private', api = 0):
if name: if name:
...@@ -512,7 +512,7 @@ class Scope(object): ...@@ -512,7 +512,7 @@ class Scope(object):
visibility = visibility, api = api) visibility = visibility, api = api)
entry.enum_values = [] entry.enum_values = []
self.sue_entries.append(entry) self.sue_entries.append(entry)
return entry return entry
def declare_var(self, name, type, pos, def declare_var(self, name, type, pos,
cname = None, visibility = 'private', api = 0, is_cdef = 0): cname = None, visibility = 'private', api = 0, is_cdef = 0):
...@@ -564,11 +564,16 @@ class Scope(object): ...@@ -564,11 +564,16 @@ class Scope(object):
entry.is_anonymous = True entry.is_anonymous = True
return entry return entry
def declare_lambda_function(self, func_cname, pos): def declare_lambda_function(self, lambda_name, pos):
# Add an entry for an anonymous Python function. # Add an entry for an anonymous Python function.
entry = self.declare_var(None, py_object_type, pos, func_cname = self.mangle(Naming.lambda_func_prefix + u'funcdef_', lambda_name)
cname=func_cname, visibility='private') pymethdef_cname = self.mangle(Naming.lambda_func_prefix + u'methdef_', lambda_name)
entry.name = EncodedString(func_cname) qualified_name = self.qualify_name(lambda_name)
entry = self.declare(None, func_cname, py_object_type, pos, 'private')
entry.name = lambda_name
entry.qualified_name = qualified_name
entry.pymethdef_cname = pymethdef_cname
entry.func_cname = func_cname entry.func_cname = func_cname
entry.signature = pyfunction_signature entry.signature = pyfunction_signature
entry.is_anonymous = True entry.is_anonymous = True
...@@ -927,21 +932,18 @@ class ModuleScope(Scope): ...@@ -927,21 +932,18 @@ class ModuleScope(Scope):
return self.outer_scope.lookup(name, language_level = self.context.language_level) return self.outer_scope.lookup(name, language_level = self.context.language_level)
def declare_builtin(self, name, pos): def declare_builtin(self, name, pos):
if not hasattr(builtins, name) and name not in Code.non_portable_builtins_map: if not hasattr(builtins, name) \
# 'xrange' and 'BaseException' are special cased in Code.py and name not in Code.non_portable_builtins_map \
and name not in Code.uncachable_builtins:
if self.has_import_star: if self.has_import_star:
entry = self.declare_var(name, py_object_type, pos) entry = self.declare_var(name, py_object_type, pos)
return entry return entry
## elif self.outer_scope is not None:
## entry = self.outer_scope.declare_builtin(name, pos)
## print entry
## return entry
else: else:
# unknown - assume it's builtin and look it up at runtime
if Options.error_on_unknown_names: if Options.error_on_unknown_names:
error(pos, "undeclared name not builtin: %s" % name) error(pos, "undeclared name not builtin: %s" % name)
else: else:
warning(pos, "undeclared name not builtin: %s" % name, 2) warning(pos, "undeclared name not builtin: %s" % name, 2)
# unknown - assume it's builtin and look it up at runtime
entry = self.declare(name, None, py_object_type, pos, 'private') entry = self.declare(name, None, py_object_type, pos, 'private')
entry.is_builtin = 1 entry.is_builtin = 1
return entry return entry
...@@ -950,7 +952,7 @@ class ModuleScope(Scope): ...@@ -950,7 +952,7 @@ class ModuleScope(Scope):
if entry.name == name: if entry.name == name:
return entry return entry
entry = self.declare(None, None, py_object_type, pos, 'private') entry = self.declare(None, None, py_object_type, pos, 'private')
if Options.cache_builtins: if Options.cache_builtins and name not in Code.uncachable_builtins:
entry.is_builtin = 1 entry.is_builtin = 1
entry.is_const = 1 # cached entry.is_const = 1 # cached
entry.name = name entry.name = name
...@@ -959,6 +961,7 @@ class ModuleScope(Scope): ...@@ -959,6 +961,7 @@ class ModuleScope(Scope):
self.undeclared_cached_builtins.append(entry) self.undeclared_cached_builtins.append(entry)
else: else:
entry.is_builtin = 1 entry.is_builtin = 1
entry.name = name
return entry return entry
def find_module(self, module_name, pos): def find_module(self, module_name, pos):
...@@ -1067,7 +1070,7 @@ class ModuleScope(Scope): ...@@ -1067,7 +1070,7 @@ class ModuleScope(Scope):
buffer_defaults = None, shadow = 0): buffer_defaults = None, shadow = 0):
# If this is a non-extern typedef class, expose the typedef, but use # If this is a non-extern typedef class, expose the typedef, but use
# the non-typedef struct internally to avoid needing forward # the non-typedef struct internally to avoid needing forward
# declarations for anonymous structs. # declarations for anonymous structs.
if typedef_flag and visibility != 'extern': if typedef_flag and visibility != 'extern':
if not (visibility == 'public' or api): if not (visibility == 'public' or api):
warning(pos, "ctypedef only valid for 'extern' , 'public', and 'api'", 2) warning(pos, "ctypedef only valid for 'extern' , 'public', and 'api'", 2)
...@@ -1414,6 +1417,12 @@ class GeneratorExpressionScope(Scope): ...@@ -1414,6 +1417,12 @@ class GeneratorExpressionScope(Scope):
self.entries[name] = entry self.entries[name] = entry
return entry return entry
def declare_lambda_function(self, func_cname, pos):
return self.outer_scope.declare_lambda_function(func_cname, pos)
def add_lambda_def(self, def_node):
return self.outer_scope.add_lambda_def(def_node)
class ClosureScope(LocalScope): class ClosureScope(LocalScope):
...@@ -1464,7 +1473,7 @@ class StructOrUnionScope(Scope): ...@@ -1464,7 +1473,7 @@ class StructOrUnionScope(Scope):
def declare_cfunction(self, name, type, pos, def declare_cfunction(self, name, type, pos,
cname = None, visibility = 'private', defining = 0, cname = None, visibility = 'private', defining = 0,
api = 0, in_pxd = 0, modifiers = ()): # currently no utility code ... api = 0, in_pxd = 0, modifiers = ()): # currently no utility code ...
return self.declare_var(name, type, pos, return self.declare_var(name, type, pos,
cname=cname, visibility=visibility) cname=cname, visibility=visibility)
class ClassScope(Scope): class ClassScope(Scope):
...@@ -1629,7 +1638,7 @@ class CClassScope(ClassScope): ...@@ -1629,7 +1638,7 @@ class CClassScope(ClassScope):
if name == "__new__": if name == "__new__":
error(pos, "__new__ method of extension type will change semantics " error(pos, "__new__ method of extension type will change semantics "
"in a future version of Pyrex and Cython. Use __cinit__ instead.") "in a future version of Pyrex and Cython. Use __cinit__ instead.")
entry = self.declare_var(name, py_object_type, pos, entry = self.declare_var(name, py_object_type, pos,
visibility='extern') visibility='extern')
special_sig = get_special_method_signature(name) special_sig = get_special_method_signature(name)
if special_sig: if special_sig:
...@@ -1755,7 +1764,7 @@ class CppClassScope(Scope): ...@@ -1755,7 +1764,7 @@ class CppClassScope(Scope):
self.inherited_var_entries = [] self.inherited_var_entries = []
def declare_var(self, name, type, pos, def declare_var(self, name, type, pos,
cname = None, visibility = 'extern', api = 0, cname = None, visibility = 'extern', api = 0,
is_cdef = 0, allow_pyobject = 0): is_cdef = 0, allow_pyobject = 0):
# Add an entry for an attribute. # Add an entry for an attribute.
if not cname: if not cname:
...@@ -1797,7 +1806,7 @@ class CppClassScope(Scope): ...@@ -1797,7 +1806,7 @@ class CppClassScope(Scope):
error(pos, "no matching function for call to %s::%s()" % error(pos, "no matching function for call to %s::%s()" %
(self.default_constructor, self.default_constructor)) (self.default_constructor, self.default_constructor))
def declare_cfunction(self, name, type, pos, cname = None, def declare_cfunction(self, name, type, pos, cname = None,
visibility = 'extern', api = 0, defining = 0, visibility = 'extern', api = 0, defining = 0,
in_pxd = 0, modifiers = (), utility_code = None): in_pxd = 0, modifiers = (), utility_code = None):
if name == self.name.split('::')[-1] and cname is None: if name == self.name.split('::')[-1] and cname is None:
...@@ -1805,7 +1814,7 @@ class CppClassScope(Scope): ...@@ -1805,7 +1814,7 @@ class CppClassScope(Scope):
name = '<init>' name = '<init>'
type.return_type = self.lookup(self.name).type type.return_type = self.lookup(self.name).type
prev_entry = self.lookup_here(name) prev_entry = self.lookup_here(name)
entry = self.declare_var(name, type, pos, entry = self.declare_var(name, type, pos,
cname=cname, visibility=visibility) cname=cname, visibility=visibility)
if prev_entry: if prev_entry:
entry.overloaded_alternatives = prev_entry.all_alternatives() entry.overloaded_alternatives = prev_entry.all_alternatives()
......
...@@ -54,11 +54,10 @@ class TestTreeFragments(CythonTest): ...@@ -54,11 +54,10 @@ class TestTreeFragments(CythonTest):
x = TMP x = TMP
""") """)
T = F.substitute(temps=[u"TMP"]) T = F.substitute(temps=[u"TMP"])
s = T.stats s = T.body.stats
self.assert_(s[0].expr.name == "__tmpvar_1") self.assert_(isinstance(s[0].expr, TempRefNode))
# self.assert_(isinstance(s[0].expr, TempRefNode)) self.assert_(isinstance(s[1].rhs, TempRefNode))
# self.assert_(isinstance(s[1].rhs, TempRefNode)) self.assert_(s[0].expr.handle is s[1].rhs.handle)
# self.assert_(s[0].expr.handle is s[1].rhs.handle)
if __name__ == "__main__": if __name__ == "__main__":
import unittest import unittest
......
...@@ -121,16 +121,15 @@ class TemplateTransform(VisitorTransform): ...@@ -121,16 +121,15 @@ class TemplateTransform(VisitorTransform):
temphandles = [] temphandles = []
for temp in temps: for temp in temps:
TemplateTransform.temp_name_counter += 1 TemplateTransform.temp_name_counter += 1
handle = "__tmpvar_%d" % TemplateTransform.temp_name_counter handle = UtilNodes.TempHandle(PyrexTypes.py_object_type)
# handle = UtilNodes.TempHandle(PyrexTypes.py_object_type)
tempmap[temp] = handle tempmap[temp] = handle
# temphandles.append(handle) temphandles.append(handle)
self.tempmap = tempmap self.tempmap = tempmap
result = super(TemplateTransform, self).__call__(node) result = super(TemplateTransform, self).__call__(node)
# if temps: if temps:
# result = UtilNodes.TempsBlockNode(self.get_pos(node), result = UtilNodes.TempsBlockNode(self.get_pos(node),
# temps=temphandles, temps=temphandles,
# body=result) body=result)
return result return result
def get_pos(self, node): def get_pos(self, node):
...@@ -161,9 +160,8 @@ class TemplateTransform(VisitorTransform): ...@@ -161,9 +160,8 @@ class TemplateTransform(VisitorTransform):
def visit_NameNode(self, node): def visit_NameNode(self, node):
temphandle = self.tempmap.get(node.name) temphandle = self.tempmap.get(node.name)
if temphandle: if temphandle:
return NameNode(pos=node.pos, name=temphandle)
# Replace name with temporary # Replace name with temporary
#return temphandle.ref(self.get_pos(node)) return temphandle.ref(self.get_pos(node))
else: else:
return self.try_substitution(node, node.name) return self.try_substitution(node, node.name)
......
...@@ -134,6 +134,10 @@ class ResultRefNode(AtomicExprNode): ...@@ -134,6 +134,10 @@ class ResultRefNode(AtomicExprNode):
self.type = type self.type = type
assert self.pos is not None assert self.pos is not None
def clone_node(self):
# nothing to do here
return self
def analyse_types(self, env): def analyse_types(self, env):
if self.expression is not None: if self.expression is not None:
self.type = self.expression.type self.type = self.expression.type
......
cdef extern from "<string>" namespace "std":
size_t npos = -1
cdef cppclass string:
string()
string(char *)
string(char *, size_t)
string(string&)
# as a string formed by a repetition of character c, n times.
string(size_t, char)
char* c_str()
size_t size()
size_t max_size()
size_t length()
void resize(size_t)
void resize(size_t, char c)
size_t capacity()
void reserve(size_t)
void clear()
bint empty()
char at(size_t)
char operator[](size_t)
int compare(string&)
string& append(string&)
string& append(string&, size_t, size_t)
string& append(char *)
string& append(char *, size_t)
string& append(size_t, char)
void push_back(char c)
string& assign (string&)
string& assign (string&, size_t, size_t)
string& assign (char *, size_t)
string& assign (char *)
string& assign (size_t n, char c)
string& insert(size_t, string&)
string& insert(size_t, string&, size_t, size_t)
string& insert(size_t, char* s, size_t)
string& insert(size_t, char* s)
string& insert(size_t, size_t, char c)
size_t copy(char *, size_t, size_t)
size_t find(string&)
size_t find(string&, size_t)
size_t find(char*, size_t pos, size_t)
size_t find(char*, size_t pos)
size_t find(char, size_t pos)
size_t rfind(string&, size_t)
size_t rfind(char* s, size_t, size_t)
size_t rfind(char*, size_t pos)
size_t rfind(char c, size_t)
size_t rfind(char c)
size_t find_first_of(string&, size_t)
size_t find_first_of(char* s, size_t, size_t)
size_t find_first_of(char*, size_t pos)
size_t find_first_of(char c, size_t)
size_t find_first_of(char c)
size_t find_first_not_of(string&, size_t)
size_t find_first_not_of(char* s, size_t, size_t)
size_t find_first_not_of(char*, size_t pos)
size_t find_first_not_of(char c, size_t)
size_t find_first_not_of(char c)
size_t find_last_of(string&, size_t)
size_t find_last_of(char* s, size_t, size_t)
size_t find_last_of(char*, size_t pos)
size_t find_last_of(char c, size_t)
size_t find_last_of(char c)
size_t find_last_not_of(string&, size_t)
size_t find_last_not_of(char* s, size_t, size_t)
size_t find_last_not_of(char*, size_t pos)
string substr(size_t, size_t)
string substr()
string substr(size_t)
size_t find_last_not_of(char c, size_t)
size_t find_last_not_of(char c)
#string& operator= (string&)
#string& operator= (char*)
#string& operator= (char)
bint operator==(string&)
bint operator==(char*)
bint operator!= (string& rhs )
bint operator!= (char* )
bint operator< (string&)
bint operator< (char*)
bint operator> (string&)
bint operator> (char*)
bint operator<= (string&)
bint operator<= (char*)
bint operator>= (string&)
bint operator>= (char*)
...@@ -39,7 +39,7 @@ class UnrecognizedInput(PlexError): ...@@ -39,7 +39,7 @@ class UnrecognizedInput(PlexError):
def __init__(self, scanner, state_name): def __init__(self, scanner, state_name):
self.scanner = scanner self.scanner = scanner
self.position = scanner.position() self.position = scanner.get_position()
self.state_name = state_name self.state_name = state_name
def __str__(self): def __str__(self):
......
...@@ -299,6 +299,11 @@ class Scanner(object): ...@@ -299,6 +299,11 @@ class Scanner(object):
""" """
return (self.name, self.start_line, self.start_col) return (self.name, self.start_line, self.start_col)
def get_position(self):
"""Python accessible wrapper around position(), only for error reporting.
"""
return self.position()
def begin(self, state_name): def begin(self, state_name):
"""Set the current state of the scanner to the named state.""" """Set the current state of the scanner to the named state."""
self.initial_state = ( self.initial_state = (
......
...@@ -261,7 +261,8 @@ class PyImporter(PyxImporter): ...@@ -261,7 +261,8 @@ class PyImporter(PyxImporter):
self.super = super(PyImporter, self) self.super = super(PyImporter, self)
self.super.__init__(extension='.py', pyxbuild_dir=pyxbuild_dir) self.super.__init__(extension='.py', pyxbuild_dir=pyxbuild_dir)
self.uncompilable_modules = {} self.uncompilable_modules = {}
self.blocked_modules = ['Cython', 'distutils.extension'] self.blocked_modules = ['Cython', 'distutils.extension',
'distutils.sysconfig']
def find_module(self, fullname, package_path=None): def find_module(self, fullname, package_path=None):
if fullname in sys.modules: if fullname in sys.modules:
......
...@@ -69,20 +69,18 @@ if sys.platform == 'win32': ...@@ -69,20 +69,18 @@ if sys.platform == 'win32':
distutils_distro.parse_config_files(cfgfiles) distutils_distro.parse_config_files(cfgfiles)
EXT_DEP_MODULES = { EXT_DEP_MODULES = {
'numpy' : 'tag:numpy', 'tag:numpy' : 'numpy',
'pstats' : 'tag:pstats', 'tag:pstats': 'pstats',
'posix' : 'tag:posix', 'tag:posix' : 'posix',
} }
def get_numpy_include_dirs(): def update_numpy_extension(ext):
import numpy import numpy
return [numpy.get_include()] ext.include_dirs.append(numpy.get_include())
# TODO: use tags EXT_EXTRAS = {
EXT_DEP_INCLUDES = [ 'tag:numpy' : update_numpy_extension,
# test name matcher , callable returning list }
(re.compile('numpy_.*').match, get_numpy_include_dirs),
]
# TODO: use tags # TODO: use tags
VER_DEP_MODULES = { VER_DEP_MODULES = {
...@@ -101,11 +99,15 @@ VER_DEP_MODULES = { ...@@ -101,11 +99,15 @@ VER_DEP_MODULES = {
'run.generators_py', # generators, with statement 'run.generators_py', # generators, with statement
'run.pure_py', # decorators, with statement 'run.pure_py', # decorators, with statement
]), ]),
(2,7) : (operator.lt, lambda x: x in ['run.withstat_py', # multi context with statement
]),
# The next line should start (3,); but this is a dictionary, so # The next line should start (3,); but this is a dictionary, so
# we can only have one (3,) key. Since 2.7 is supposed to be the # we can only have one (3,) key. Since 2.7 is supposed to be the
# last 2.x release, things would have to change drastically for this # last 2.x release, things would have to change drastically for this
# to be unsafe... # to be unsafe...
(2,999): (operator.lt, lambda x: x in ['run.special_methods_T561_py3']), (2,999): (operator.lt, lambda x: x in ['run.special_methods_T561_py3',
'run.test_raisefrom',
]),
(3,): (operator.ge, lambda x: x in ['run.non_future_division', (3,): (operator.ge, lambda x: x in ['run.non_future_division',
'compile.extsetslice', 'compile.extsetslice',
'compile.extdelslice', 'compile.extdelslice',
...@@ -155,6 +157,8 @@ def parse_tags(filepath): ...@@ -155,6 +157,8 @@ def parse_tags(filepath):
parse_tags = memoize(parse_tags) parse_tags = memoize(parse_tags)
list_unchanging_dir = memoize(lambda x: os.listdir(x))
class build_ext(_build_ext): class build_ext(_build_ext):
def build_extension(self, ext): def build_extension(self, ext):
...@@ -241,7 +245,7 @@ class TestBuilder(object): ...@@ -241,7 +245,7 @@ class TestBuilder(object):
os.makedirs(workdir) os.makedirs(workdir)
suite = unittest.TestSuite() suite = unittest.TestSuite()
filenames = os.listdir(path) filenames = list_unchanging_dir(path)
filenames.sort() filenames.sort()
for filename in filenames: for filename in filenames:
filepath = os.path.join(path, filename) filepath = os.path.join(path, filename)
...@@ -293,7 +297,7 @@ class TestBuilder(object): ...@@ -293,7 +297,7 @@ class TestBuilder(object):
return suite return suite
def build_tests(self, test_class, path, workdir, module, expect_errors, tags): def build_tests(self, test_class, path, workdir, module, expect_errors, tags):
if 'werror' in tags['tags']: if 'werror' in tags['tag']:
warning_errors = True warning_errors = True
else: else:
warning_errors = False warning_errors = False
...@@ -353,6 +357,8 @@ class CythonCompileTestCase(unittest.TestCase): ...@@ -353,6 +357,8 @@ class CythonCompileTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
from Cython.Compiler import Options from Cython.Compiler import Options
self._saved_options = [ (name, getattr(Options, name))
for name in ('warning_errors', 'error_on_unknown_names') ]
Options.warning_errors = self.warning_errors Options.warning_errors = self.warning_errors
if self.workdir not in sys.path: if self.workdir not in sys.path:
...@@ -360,7 +366,8 @@ class CythonCompileTestCase(unittest.TestCase): ...@@ -360,7 +366,8 @@ class CythonCompileTestCase(unittest.TestCase):
def tearDown(self): def tearDown(self):
from Cython.Compiler import Options from Cython.Compiler import Options
Options.warning_errors = False for name, value in self._saved_options:
setattr(Options, name, value)
try: try:
sys.path.remove(self.workdir) sys.path.remove(self.workdir)
...@@ -407,19 +414,21 @@ class CythonCompileTestCase(unittest.TestCase): ...@@ -407,19 +414,21 @@ class CythonCompileTestCase(unittest.TestCase):
def build_target_filename(self, module_name): def build_target_filename(self, module_name):
target = '%s.%s' % (module_name, self.language) target = '%s.%s' % (module_name, self.language)
return target return target
def copy_related_files(self, test_directory, target_directory, module_name): def related_files(self, test_directory, module_name):
is_related = re.compile('%s_.*[.].*' % module_name).match is_related = re.compile('%s_.*[.].*' % module_name).match
for filename in os.listdir(test_directory): return [filename for filename in list_unchanging_dir(test_directory)
if is_related(filename): if is_related(filename)]
shutil.copy(os.path.join(test_directory, filename),
target_directory)
def find_source_files(self, workdir, module_name): def copy_files(self, test_directory, target_directory, file_list):
is_related = re.compile('%s_.*[.]%s' % (module_name, self.language)).match for filename in file_list:
return [self.build_target_filename(module_name)] + [ shutil.copy(os.path.join(test_directory, filename),
filename for filename in os.listdir(workdir) target_directory)
if is_related(filename) and os.path.isfile(os.path.join(workdir, filename)) ]
def source_files(self, workdir, module_name, file_list):
return ([self.build_target_filename(module_name)] +
[filename for filename in file_list
if not os.path.isfile(os.path.join(workdir, filename))])
def split_source_and_output(self, test_directory, module, workdir): def split_source_and_output(self, test_directory, module, workdir):
source_file = self.find_module_source_file(os.path.join(test_directory, module) + '.pyx') source_file = self.find_module_source_file(os.path.join(test_directory, module) + '.pyx')
...@@ -480,6 +489,12 @@ class CythonCompileTestCase(unittest.TestCase): ...@@ -480,6 +489,12 @@ class CythonCompileTestCase(unittest.TestCase):
def run_distutils(self, test_directory, module, workdir, incdir, def run_distutils(self, test_directory, module, workdir, incdir,
extra_extension_args=None): extra_extension_args=None):
original_source = self.find_module_source_file(
os.path.join(test_directory, module + '.pyx'))
try:
tags = parse_tags(original_source)
except IOError:
tags = {}
cwd = os.getcwd() cwd = os.getcwd()
os.chdir(workdir) os.chdir(workdir)
try: try:
...@@ -490,24 +505,27 @@ class CythonCompileTestCase(unittest.TestCase): ...@@ -490,24 +505,27 @@ class CythonCompileTestCase(unittest.TestCase):
build_extension.finalize_options() build_extension.finalize_options()
if COMPILER: if COMPILER:
build_extension.compiler = COMPILER build_extension.compiler = COMPILER
ext_include_dirs = []
for match, get_additional_include_dirs in EXT_DEP_INCLUDES:
if match(module):
ext_include_dirs += get_additional_include_dirs()
ext_compile_flags = CFLAGS[:] ext_compile_flags = CFLAGS[:]
if build_extension.compiler == 'mingw32': if build_extension.compiler == 'mingw32':
ext_compile_flags.append('-Wno-format') ext_compile_flags.append('-Wno-format')
if extra_extension_args is None: if extra_extension_args is None:
extra_extension_args = {} extra_extension_args = {}
self.copy_related_files(test_directory, workdir, module) related_files = self.related_files(test_directory, module)
self.copy_files(test_directory, workdir, related_files)
extension = Extension( extension = Extension(
module, module,
sources = self.find_source_files(workdir, module), sources = self.source_files(workdir, module, related_files),
include_dirs = ext_include_dirs,
extra_compile_args = ext_compile_flags, extra_compile_args = ext_compile_flags,
**extra_extension_args **extra_extension_args
) )
for matcher, fixer in EXT_EXTRAS.items():
if isinstance(matcher, str):
del EXT_EXTRAS[matcher]
matcher = string_selector(matcher)
EXT_EXTRAS[matcher] = fixer
if matcher(module, tags):
extension = fixer(extension) or extension
if self.language == 'cpp': if self.language == 'cpp':
extension.language = 'c++' extension.language = 'c++'
build_extension.extensions = [extension] build_extension.extensions = [extension]
...@@ -583,64 +601,70 @@ class CythonRunTestCase(CythonCompileTestCase): ...@@ -583,64 +601,70 @@ class CythonRunTestCase(CythonCompileTestCase):
self.run_doctests(self.module, result) self.run_doctests(self.module, result)
def run_doctests(self, module_name, result): def run_doctests(self, module_name, result):
if sys.version_info[0] >= 3 or not hasattr(os, 'fork') or not self.fork: def run_test(result):
doctest.DocTestSuite(module_name).run(result) tests = doctest.DocTestSuite(module_name)
gc.collect() tests.run(result)
return run_forked_test(result, run_test, self.shortDescription(), self.fork)
# fork to make sure we do not keep the tested module loaded
result_handle, result_file = tempfile.mkstemp() def run_forked_test(result, run_func, test_name, fork=True):
os.close(result_handle) if not fork or sys.version_info[0] >= 3 or not hasattr(os, 'fork'):
child_id = os.fork() run_func(result)
if not child_id: gc.collect()
result_code = 0 return
try:
try:
tests = None
try:
partial_result = PartialTestResult(result)
tests = doctest.DocTestSuite(module_name)
tests.run(partial_result)
gc.collect()
except Exception:
if tests is None:
# importing failed, try to fake a test class
tests = _FakeClass(
failureException=sys.exc_info()[1],
_shortDescription=self.shortDescription(),
module_name=None)
partial_result.addError(tests, sys.exc_info())
result_code = 1
output = open(result_file, 'wb')
pickle.dump(partial_result.data(), output)
except:
traceback.print_exc()
finally:
try: output.close()
except: pass
os._exit(result_code)
# fork to make sure we do not keep the tested module loaded
result_handle, result_file = tempfile.mkstemp()
os.close(result_handle)
child_id = os.fork()
if not child_id:
result_code = 0
try: try:
cid, result_code = os.waitpid(child_id, 0) try:
# os.waitpid returns the child's result code in the tests = None
# upper byte of result_code, and the signal it was
# killed by in the lower byte
if result_code & 255:
raise Exception("Tests in module '%s' were unexpectedly killed by signal %d"%
(module_name, result_code & 255))
result_code = result_code >> 8
if result_code in (0,1):
input = open(result_file, 'rb')
try: try:
PartialTestResult.join_results(result, pickle.load(input)) partial_result = PartialTestResult(result)
finally: run_func(partial_result)
input.close() gc.collect()
if result_code: except Exception:
raise Exception("Tests in module '%s' exited with status %d" % if tests is None:
(module_name, result_code)) # importing failed, try to fake a test class
tests = _FakeClass(
failureException=sys.exc_info()[1],
_shortDescription=test_name,
module_name=None)
partial_result.addError(tests, sys.exc_info())
result_code = 1
output = open(result_file, 'wb')
pickle.dump(partial_result.data(), output)
except:
traceback.print_exc()
finally: finally:
try: os.unlink(result_file) try: output.close()
except: pass except: pass
os._exit(result_code)
try:
cid, result_code = os.waitpid(child_id, 0)
# os.waitpid returns the child's result code in the
# upper byte of result_code, and the signal it was
# killed by in the lower byte
if result_code & 255:
raise Exception("Tests in module '%s' were unexpectedly killed by signal %d"%
(module_name, result_code & 255))
result_code = result_code >> 8
if result_code in (0,1):
input = open(result_file, 'rb')
try:
PartialTestResult.join_results(result, pickle.load(input))
finally:
input.close()
if result_code:
raise Exception("Tests in module '%s' exited with status %d" %
(module_name, result_code))
finally:
try: os.unlink(result_file)
except: pass
class PureDoctestTestCase(unittest.TestCase): class PureDoctestTestCase(unittest.TestCase):
def __init__(self, module_name, module_path): def __init__(self, module_name, module_path):
...@@ -738,6 +762,11 @@ class CythonUnitTestCase(CythonRunTestCase): ...@@ -738,6 +762,11 @@ class CythonUnitTestCase(CythonRunTestCase):
class CythonPyregrTestCase(CythonRunTestCase): class CythonPyregrTestCase(CythonRunTestCase):
def setUp(self):
CythonRunTestCase.setUp(self)
from Cython.Compiler import Options
Options.error_on_unknown_names = False
def _run_unittest(self, result, *classes): def _run_unittest(self, result, *classes):
"""Run tests from unittest.TestCase-derived classes.""" """Run tests from unittest.TestCase-derived classes."""
valid_types = (unittest.TestSuite, unittest.TestCase) valid_types = (unittest.TestSuite, unittest.TestCase)
...@@ -763,20 +792,23 @@ class CythonPyregrTestCase(CythonRunTestCase): ...@@ -763,20 +792,23 @@ class CythonPyregrTestCase(CythonRunTestCase):
except ImportError: # Py3k except ImportError: # Py3k
from test import support from test import support
def run_unittest(*classes): def run_test(result):
return self._run_unittest(result, *classes) def run_unittest(*classes):
def run_doctest(module, verbosity=None): return self._run_unittest(result, *classes)
return self._run_doctest(result, module) def run_doctest(module, verbosity=None):
return self._run_doctest(result, module)
support.run_unittest = run_unittest support.run_unittest = run_unittest
support.run_doctest = run_doctest support.run_doctest = run_doctest
try: try:
module = __import__(self.module) module = __import__(self.module)
if hasattr(module, 'test_main'): if hasattr(module, 'test_main'):
module.test_main() module.test_main()
except (unittest.SkipTest, support.ResourceDenied): except (unittest.SkipTest, support.ResourceDenied):
result.addSkip(self, 'ok') result.addSkip(self, 'ok')
run_forked_test(result, run_test, self.shortDescription(), self.fork)
include_debugger = sys.version_info[:2] > (2, 5) include_debugger = sys.version_info[:2] > (2, 5)
...@@ -979,9 +1011,9 @@ class EmbedTest(unittest.TestCase): ...@@ -979,9 +1011,9 @@ class EmbedTest(unittest.TestCase):
class MissingDependencyExcluder: class MissingDependencyExcluder:
def __init__(self, deps): def __init__(self, deps):
# deps: { module name : matcher func } # deps: { matcher func : module name }
self.exclude_matchers = [] self.exclude_matchers = []
for mod, matcher in deps.items(): for matcher, mod in deps.items():
try: try:
__import__(mod) __import__(mod)
except ImportError: except ImportError:
...@@ -1168,6 +1200,9 @@ def main(): ...@@ -1168,6 +1200,9 @@ def main():
parser.add_option("--coverage-xml", dest="coverage_xml", parser.add_option("--coverage-xml", dest="coverage_xml",
action="store_true", default=False, action="store_true", default=False,
help="collect source coverage data for the Compiler in XML format") help="collect source coverage data for the Compiler in XML format")
parser.add_option("--coverage-html", dest="coverage_html",
action="store_true", default=False,
help="collect source coverage data for the Compiler in HTML format")
parser.add_option("-A", "--annotate", dest="annotate_source", parser.add_option("-A", "--annotate", dest="annotate_source",
action="store_true", default=True, action="store_true", default=True,
help="generate annotated HTML versions of the test source files") help="generate annotated HTML versions of the test source files")
...@@ -1223,9 +1258,9 @@ def main(): ...@@ -1223,9 +1258,9 @@ def main():
WITH_CYTHON = options.with_cython WITH_CYTHON = options.with_cython
if options.coverage or options.coverage_xml: if options.coverage or options.coverage_xml or options.coverage_html:
if not WITH_CYTHON: if not WITH_CYTHON:
options.coverage = options.coverage_xml = False options.coverage = options.coverage_xml = options.coverage_html = False
else: else:
from coverage import coverage as _coverage from coverage import coverage as _coverage
coverage = _coverage(branch=True) coverage = _coverage(branch=True)
...@@ -1339,15 +1374,15 @@ def main(): ...@@ -1339,15 +1374,15 @@ def main():
test_suite.addTest(filetests.build_suite()) test_suite.addTest(filetests.build_suite())
if options.system_pyregr and languages: if options.system_pyregr and languages:
filetests = TestBuilder(ROOTDIR, WORKDIR, selectors, exclude_selectors, sys_pyregr_dir = os.path.join(sys.prefix, 'lib', 'python'+sys.version[:3], 'test')
options.annotate_source, options.cleanup_workdir, if os.path.isdir(sys_pyregr_dir):
options.cleanup_sharedlibs, True, filetests = TestBuilder(ROOTDIR, WORKDIR, selectors, exclude_selectors,
options.cython_only, languages, test_bugs, options.annotate_source, options.cleanup_workdir,
options.fork, options.language_level) options.cleanup_sharedlibs, True,
test_suite.addTest( options.cython_only, languages, test_bugs,
filetests.handle_directory( options.fork, sys.version_info[0])
os.path.join(sys.prefix, 'lib', 'python'+sys.version[:3], 'test'), sys.stderr.write("Including CPython regression tests in %s\n" % sys_pyregr_dir)
'pyregr')) test_suite.addTest(filetests.handle_directory(sys_pyregr_dir, 'pyregr'))
if options.xml_output_dir: if options.xml_output_dir:
from Cython.Tests.xmlrunner import XMLTestRunner from Cython.Tests.xmlrunner import XMLTestRunner
...@@ -1358,7 +1393,7 @@ def main(): ...@@ -1358,7 +1393,7 @@ def main():
result = test_runner.run(test_suite) result = test_runner.run(test_suite)
if options.coverage or options.coverage_xml: if options.coverage or options.coverage_xml or options.coverage_html:
coverage.stop() coverage.stop()
ignored_modules = ('Options', 'Version', 'DebugFlags', 'CmdLine') ignored_modules = ('Options', 'Version', 'DebugFlags', 'CmdLine')
modules = [ module for name, module in sys.modules.items() modules = [ module for name, module in sys.modules.items()
...@@ -1369,6 +1404,8 @@ def main(): ...@@ -1369,6 +1404,8 @@ def main():
coverage.report(modules, show_missing=0) coverage.report(modules, show_missing=0)
if options.coverage_xml: if options.coverage_xml:
coverage.xml_report(modules, outfile="coverage-report.xml") coverage.xml_report(modules, outfile="coverage-report.xml")
if options.coverage_html:
coverage.html_report(modules, directory="coverage-report-html")
if missing_dep_excluder.tests_missing_deps: if missing_dep_excluder.tests_missing_deps:
sys.stderr.write("Following tests excluded because of missing dependencies on your system:\n") sys.stderr.write("Following tests excluded because of missing dependencies on your system:\n")
......
...@@ -249,7 +249,11 @@ except ValueError: ...@@ -249,7 +249,11 @@ except ValueError:
try: try:
sys.argv.remove("--no-cython-compile") sys.argv.remove("--no-cython-compile")
compile_cython_itself = False
except ValueError: except ValueError:
compile_cython_itself = True
if compile_cython_itself:
compile_cython_modules(cython_profile, cython_compile_more, cython_with_refnanny) compile_cython_modules(cython_profile, cython_compile_more, cython_with_refnanny)
setup_args.update(setuptools_extra_args) setup_args.update(setuptools_extra_args)
......
...@@ -10,7 +10,6 @@ cfunc_call_tuple_args_T408 ...@@ -10,7 +10,6 @@ cfunc_call_tuple_args_T408
compile.cpp_operators compile.cpp_operators
cpp_templated_ctypedef cpp_templated_ctypedef
cpp_structs cpp_structs
with_statement_module_level_T536
function_as_method_T494 function_as_method_T494
closure_inside_cdef_T554 closure_inside_cdef_T554
pure_mode_cmethod_inheritance_T583 pure_mode_cmethod_inheritance_T583
...@@ -22,12 +21,20 @@ class_scope_T671 ...@@ -22,12 +21,20 @@ class_scope_T671
# CPython regression tests that don't current work: # CPython regression tests that don't current work:
pyregr.test_threadsignals pyregr.test_threadsignals
pyregr.test_module
pyregr.test_capi pyregr.test_capi
pyregr.test_socket pyregr.test_socket
pyregr.test_threading pyregr.test_threading
pyregr.test_sys pyregr.test_sys
pyregr.test_pep3131
# CPython regression tests that don't make sense # CPython regression tests that don't make sense
pyregr.test_gdb pyregr.test_gdb
pyregr.test_support pyregr.test_support
# Inlined generators
all
any
builtin_sorted
dictcomp
inlined_generator_expressions
setcomp
# mode: compile
from __future__ import nested_scopes
from __future__ import with_statement
pass
from __future__ import nested_scopes ; from __future__ import nested_scopes
# mode: error # mode: error
# tags: werror # tag: werror
cdef foo(): cdef foo():
pass pass
......
...@@ -18,40 +18,40 @@ def short_binop(short val): ...@@ -18,40 +18,40 @@ def short_binop(short val):
""" """
Arithmetic in C is always done with at least int precision. Arithmetic in C is always done with at least int precision.
>>> short_binop(3) >>> print(short_binop(3))
'int called' int called
""" """
assert typeof(val + val) == "int", typeof(val + val) assert typeof(val + val) == "int", typeof(val + val)
assert typeof(val - val) == "int", typeof(val - val) assert typeof(val - val) == "int", typeof(val - val)
assert typeof(val & val) == "int", typeof(val & val) assert typeof(val & val) == "int", typeof(val & val)
cdef int_return x = f(val + val) cdef int_return x = f(val + val)
return x.msg return x.msg.decode('ASCII')
def short_unnop(short val): def short_unnop(short val):
""" """
Arithmetic in C is always done with at least int precision. Arithmetic in C is always done with at least int precision.
>>> short_unnop(3) >>> print(short_unnop(3))
'int called' int called
""" """
cdef int_return x = f(-val) cdef int_return x = f(-val)
return x.msg return x.msg.decode('ASCII')
def longlong_binop(long long val): def longlong_binop(long long val):
""" """
>>> longlong_binop(3) >>> print(longlong_binop(3))
'long long called' long long called
""" """
cdef longlong_return x = f(val * val) cdef longlong_return x = f(val * val)
return x.msg return x.msg.decode('ASCII')
def longlong_unnop(long long val): def longlong_unnop(long long val):
""" """
>>> longlong_unnop(3) >>> print(longlong_unnop(3))
'long long called' long long called
""" """
cdef longlong_return x = f(~val) cdef longlong_return x = f(~val)
return x.msg return x.msg.decode('ASCII')
def test_bint(bint a): def test_bint(bint a):
......
/* A set of mutually incompatable return types. */ /* A set of mutually incompatable return types. */
struct short_return { const char *msg; }; struct short_return { char *msg; };
struct int_return { const char *msg; }; struct int_return { char *msg; };
struct longlong_return { const char *msg; }; struct longlong_return { char *msg; };
/* A set of overloaded methods. */ /* A set of overloaded methods. */
short_return f(short arg) { short_return f(short arg) {
short_return val; short_return val;
val.msg = "short called"; val.msg = (char*)"short called";
return val; return val;
} }
int_return f(int arg) { int_return f(int arg) {
int_return val; int_return val;
val.msg = "int called"; val.msg = (char*)"int called";
return val; return val;
} }
longlong_return f(long long arg) { longlong_return f(long long arg) {
longlong_return val; longlong_return val;
val.msg = "long long called"; val.msg = (char*)"long long called";
return val; return val;
} }
......
# tag: cpp
from libcpp.string cimport string
b_asdf = b'asdf'
b_asdg = b'asdg'
b_s = b's'
def test_indexing(char *py_str):
"""
>>> test_indexing(b_asdf)
('s', 's')
"""
cdef string s
s = string(py_str)
return chr(s[1]), chr(s.at(1))
def test_size(char *py_str):
"""
>>> test_size(b_asdf)
(4, 4)
"""
cdef string s
s = string(py_str)
return s.size(), s.length()
def test_compare(char *a, char *b):
"""
>>> test_compare(b_asdf, b_asdf)
0
>>> test_compare(b_asdf, b_asdg) < 0
True
"""
cdef string s = string(a)
cdef string t = string(b)
return s.compare(t)
def test_empty():
"""
>>> test_empty()
(True, False)
"""
cdef string a = string(<char *>b"")
cdef string b = string(<char *>b"aa")
return a.empty(), b.empty()
def test_push_back(char *a):
"""
>>> test_push_back(b_asdf) == b_asdf + b_s
True
"""
cdef string s = string(a)
s.push_back(<char>ord('s'))
return s.c_str()
def test_insert(char *a, char *b, int i):
"""
>>> test_insert('AAAA'.encode('ASCII'), 'BBBB'.encode('ASCII'), 2) == 'AABBBBAA'.encode('ASCII')
True
"""
cdef string s = string(a)
cdef string t = string(b)
cdef string u = s.insert(i, t)
return u.c_str()
def test_copy(char *a):
"""
>>> test_copy(b_asdf) == b_asdf[1:]
True
"""
cdef string t = string(a)
cdef char buffer[6]
cdef size_t length = t.copy(buffer, 4, 1)
buffer[length] = c'\0'
return buffer
def test_find(char *a, char *b):
"""
>>> test_find(b_asdf, 'df'.encode('ASCII'))
2
"""
cdef string s = string(a)
cdef string t = string(b)
cdef size_t i = s.find(t)
return i
def test_clear():
"""
>>> test_clear() == ''.encode('ASCII')
True
"""
cdef string s = string(<char *>"asdf")
s.clear()
return s.c_str()
def test_assign(char *a):
"""
>>> test_assign(b_asdf) == 'ggg'.encode('ASCII')
True
"""
cdef string s = string(a)
s.assign(<char *>"ggg")
return s.c_str()
def test_substr(char *a):
"""
>>> test_substr('ABCDEFGH'.encode('ASCII')) == ('BCDEFGH'.encode('ASCII'), 'BCDE'.encode('ASCII'), 'ABCDEFGH'.encode('ASCII'))
True
"""
cdef string s = string(a)
cdef string x, y, z
x = s.substr(1)
y = s.substr(1, 4)
z = s.substr()
return x.c_str(), y.c_str(), z.c_str()
def test_append(char *a, char *b):
"""
>>> test_append(b_asdf, '1234'.encode('ASCII')) == b_asdf + '1234'.encode('ASCII')
True
"""
cdef string s = string(a)
cdef string t = string(b)
cdef string j = s.append(t)
return j.c_str()
def test_char_compare(py_str):
"""
>>> test_char_compare(b_asdf)
True
"""
cdef char *a = py_str
cdef string b = string(a)
return b.compare(b) == 0
def test_cstr(char *a):
"""
>>> test_cstr(b_asdf) == b_asdf
True
"""
cdef string b = string(a)
return b.c_str()
def test_equals_operator(char *a, char *b):
"""
>>> test_equals_operator(b_asdf, b_asdf)
(True, False)
"""
cdef string s = string(a)
cdef string t = string(b)
return t == s, t != <char *>"asdf"
def test_less_than(char *a, char *b):
"""
>>> test_less_than(b_asdf[:-1], b_asdf)
(True, True, True)
>>> test_less_than(b_asdf[:-1], b_asdf[:-1])
(False, False, True)
"""
cdef string s = string(a)
cdef string t = string(b)
return (s < t, s < b, s <= b)
def test_greater_than(char *a, char *b):
"""
>>> test_greater_than(b_asdf[:-1], b_asdf)
(False, False, False)
>>> test_greater_than(b_asdf[:-1], b_asdf[:-1])
(False, False, True)
"""
cdef string s = string(a)
cdef string t = string(b)
return (s > t, s > b, s >= b)
# cython: language_level=3 # cython: language_level=3
# mode: run
# tag: generators, python3
cimport cython cimport cython
...@@ -89,6 +91,16 @@ def list_comp(): ...@@ -89,6 +91,16 @@ def list_comp():
assert x == 'abc' # don't leak in Py3 code assert x == 'abc' # don't leak in Py3 code
return result return result
def list_comp_with_lambda():
"""
>>> list_comp_with_lambda()
[0, 4, 8]
"""
x = 'abc'
result = [x*2 for x in range(5) if (lambda x:x % 2)(x) == 0]
assert x == 'abc' # don't leak in Py3 code
return result
module_level_lc = [ module_level_loopvar*2 for module_level_loopvar in range(4) ] module_level_lc = [ module_level_loopvar*2 for module_level_loopvar in range(4) ]
def list_comp_module_level(): def list_comp_module_level():
""" """
......
...@@ -13,7 +13,7 @@ def test_in(s): ...@@ -13,7 +13,7 @@ def test_in(s):
>>> test_in('') >>> test_in('')
5 5
""" """
if s in (u'ABC', u'BCD'): if s in (u'ABC', u'BCD', u'ABC'[:3], u'ABC'[::-1], u'ABC'[-1]):
return 1 return 1
elif s.upper() in (u'ABC', u'BCD'): elif s.upper() in (u'ABC', u'BCD'):
return 2 return 2
......
__doc__ = u"""
>>> def bar():
... try:
... foo()
... except ValueError:
... if IS_PY3:
... print(isinstance(sys.exc_info()[1].__cause__, TypeError))
... else:
... print(True)
>>> bar()
True
>>> print(sys.exc_info())
(None, None, None)
>>> def bar2():
... try:
... foo2()
... except ValueError:
... if IS_PY3:
... cause = sys.exc_info()[1].__cause__
... print(isinstance(cause, TypeError))
... print(cause.args==('value',))
... pass
... else:
... print(True)
... print(True)
>>> bar2()
True
True
"""
import sys
IS_PY3 = sys.version_info[0] >= 3
if not IS_PY3:
sys.exc_clear()
def foo():
try:
raise TypeError
except TypeError:
raise ValueError from TypeError
def foo2():
try:
raise TypeError
except TypeError:
raise ValueError() from TypeError('value')
# mode: run
# tag: generators, lambda
def genexpr():
"""
>>> genexpr()
[0, 2, 4, 6, 8]
"""
x = 'abc'
result = list( x*2 for x in range(5) )
assert x == 'abc' # don't leak
return result
def genexpr_if():
"""
>>> genexpr_if()
[0, 4, 8]
"""
x = 'abc'
result = list( x*2 for x in range(5) if x % 2 == 0 )
assert x == 'abc' # don't leak
return result
def genexpr_with_lambda():
"""
>>> genexpr_with_lambda()
[0, 4, 8]
"""
x = 'abc'
result = list( x*2 for x in range(5) if (lambda x:x % 2)(x) == 0 )
assert x == 'abc' # don't leak
return result
def genexpr_of_lambdas(int N):
"""
>>> [ (f(), g()) for f,g in genexpr_of_lambdas(5) ]
[(0, 0), (1, 2), (2, 4), (3, 6), (4, 8)]
"""
return ( ((lambda : x), (lambda : x*2)) for x in range(N) )
...@@ -167,6 +167,40 @@ def check_yield_in_except(): ...@@ -167,6 +167,40 @@ def check_yield_in_except():
except ValueError: except ValueError:
yield yield
def yield_in_except_throw_exc_type():
"""
>>> import sys
>>> g = yield_in_except_throw_exc_type()
>>> next(g)
>>> g.throw(TypeError)
Traceback (most recent call last):
TypeError
>>> next(g)
Traceback (most recent call last):
StopIteration
"""
try:
raise ValueError
except ValueError:
yield
def yield_in_except_throw_instance():
"""
>>> import sys
>>> g = yield_in_except_throw_instance()
>>> next(g)
>>> g.throw(TypeError())
Traceback (most recent call last):
TypeError
>>> next(g)
Traceback (most recent call last):
StopIteration
"""
try:
raise ValueError
except ValueError:
yield
def test_swap_assignment(): def test_swap_assignment():
""" """
>>> gen = test_swap_assignment() >>> gen = test_swap_assignment()
......
# mode: run
# tag: condexpr
# ticket: 267 # ticket: 267
""" cimport cython
>>> constants(4)
1
>>> constants(5)
10
>>> temps(4)
1
>>> temps(5)
10
>>> nested(1)
1
>>> nested(2)
2
>>> nested(3)
3
"""
def ident(x): return x def ident(x): return x
def constants(x): def constants(x):
"""
>>> constants(4)
1
>>> constants(5)
10
"""
a = 1 if x < 5 else 10 a = 1 if x < 5 else 10
return a return a
def temps(x): def temps(x):
"""
>>> temps(4)
1
>>> temps(5)
10
"""
return ident(1) if ident(x) < ident(5) else ident(10) return ident(1) if ident(x) < ident(5) else ident(10)
def nested(x): def nested(x):
"""
>>> nested(1)
1
>>> nested(2)
2
>>> nested(3)
3
"""
return 1 if x == 1 else (2 if x == 2 else 3) return 1 if x == 1 else (2 if x == 2 else 3)
@cython.test_fail_if_path_exists('//CondExprNode')
def const_true(a,b):
"""
>>> const_true(1,2)
1
"""
return a if 1 == 1 else b
@cython.test_fail_if_path_exists('//CondExprNode')
def const_false(a,b):
"""
>>> const_false(1,2)
2
"""
return a if 1 != 1 else b
...@@ -20,10 +20,9 @@ def test_relative(): ...@@ -20,10 +20,9 @@ def test_relative():
def test_absolute(): def test_absolute():
""" """
>>> test_absolute() >>> test_absolute() # doctest: +ELLIPSIS
Traceback (most recent call last): Traceback (most recent call last):
... ImportError: No module named ...debug...
ImportError: No module named debug
""" """
import debug import debug
return return
......
...@@ -132,6 +132,18 @@ __doc__ = ur""" ...@@ -132,6 +132,18 @@ __doc__ = ur"""
>>> len(bytes_uescape) >>> len(bytes_uescape)
28 28
>>> (sys.version_info[0] >= 3 and sys.maxunicode == 1114111 and len(str_uescape) == 3 or
... sys.version_info[0] >= 3 and sys.maxunicode == 65535 and len(str_uescape) == 4 or
... sys.version_info[0] < 3 and len(str_uescape) == 17 or
... len(str_uescape))
True
>>> (sys.version_info[0] >= 3 and str_uescape[0] == 'c' or
... sys.version_info[0] < 3 and str_uescape[0] == '\\' or
... str_uescape[0])
True
>>> print(str_uescape[-1])
B
>>> newlines == "Aaa\n" >>> newlines == "Aaa\n"
True True
...@@ -173,6 +185,7 @@ bresc = br'\12\'\"\\' ...@@ -173,6 +185,7 @@ bresc = br'\12\'\"\\'
uresc = ur'\12\'\"\\' uresc = ur'\12\'\"\\'
bytes_uescape = b'\u1234\U12345678\u\u1\u12\uX' bytes_uescape = b'\u1234\U12345678\u\u1\u12\uX'
str_uescape = '\u0063\U00012345\x42'
newlines = "Aaa\n" newlines = "Aaa\n"
......
import unittest
# adapted from pyregr
class TestCause(unittest.TestCase):
def test_invalid_cause(self):
try:
raise IndexError from 5
except TypeError as e:
self.assertTrue("exception cause" in str(e))
else:
self.fail("No exception raised")
def test_class_cause(self):
try:
raise IndexError from KeyError
except IndexError as e:
self.assertTrue(isinstance(e.__cause__, KeyError))
else:
self.fail("No exception raised")
def test_instance_cause(self):
cause = KeyError()
try:
raise IndexError from cause
except IndexError as e:
self.assertTrue(e.__cause__ is cause)
else:
self.fail("No exception raised")
def test_erroneous_cause(self):
class MyException(Exception):
def __init__(self):
raise RuntimeError()
try:
raise IndexError from MyException
except RuntimeError:
pass
else:
self.fail("No exception raised")
__doc__ = u"""
>>> print(foo())
a
"""
# Indirectly makes sure the cleanup happens correctly on breaking. # Indirectly makes sure the cleanup happens correctly on breaking.
def foo():
for x in "abc": def try_except_break():
"""
>>> print(try_except_break())
a
"""
for x in list("abc"):
try: try:
x() x()
except: except:
break break
for x in "abc": return x
def try_break_except():
"""
>>> print(try_break_except())
a
"""
for x in list("abc"):
try:
break
except:
pass
return x
def try_no_break_except_return():
"""
>>> print(try_no_break_except_return())
a
"""
for x in list("abc"):
try: try:
x() x()
break
except: except:
return x return x
return x
...@@ -3,17 +3,32 @@ ...@@ -3,17 +3,32 @@
__doc__ = """ __doc__ = """
>>> inner_result >>> inner_result
['ENTER'] ['ENTER']
>>> result >>> result # doctest: +ELLIPSIS
['ENTER', ...EXIT (<...ValueError...>,...ValueError..., <traceback object at ...)...]
>>> inner_result_no_exc
['ENTER']
>>> result_no_exc
['ENTER', 'EXIT (None, None, None)'] ['ENTER', 'EXIT (None, None, None)']
""" """
result = []
class ContextManager(object): class ContextManager(object):
def __init__(self, result):
self.result = result
def __enter__(self): def __enter__(self):
result.append("ENTER") self.result.append("ENTER")
def __exit__(self, *values): def __exit__(self, *values):
result.append("EXIT %r" % (values,)) self.result.append("EXIT %r" % (values,))
return True
result_no_exc = []
with ContextManager(result_no_exc) as c:
inner_result_no_exc = result_no_exc[:]
result = []
with ContextManager() as c: with ContextManager(result) as c:
inner_result = result[:] inner_result = result[:]
raise ValueError('TEST')
...@@ -56,17 +56,6 @@ def with_pass(): ...@@ -56,17 +56,6 @@ def with_pass():
with ContextManager(u"value") as x: with ContextManager(u"value") as x:
pass pass
def with_return():
"""
>>> with_return()
enter
exit <type 'NoneType'> <type 'NoneType'> <type 'NoneType'>
"""
with ContextManager(u"value") as x:
# FIXME: DISABLED - currently crashes!!
# return x
pass
def with_exception(exit_ret): def with_exception(exit_ret):
""" """
>>> with_exception(None) >>> with_exception(None)
......
import sys
def typename(t):
name = type(t).__name__
if sys.version_info < (2,5):
if name == 'classobj' and issubclass(t, MyException):
name = 'type'
elif name == 'instance' and isinstance(t, MyException):
name = 'MyException'
return "<type '%s'>" % name
class MyException(Exception):
pass
class ContextManager(object):
def __init__(self, value, exit_ret = None):
self.value = value
self.exit_ret = exit_ret
def __exit__(self, a, b, tb):
print("exit %s %s %s" % (typename(a), typename(b), typename(tb)))
return self.exit_ret
def __enter__(self):
print("enter")
return self.value
def no_as():
"""
>>> no_as()
enter
hello
exit <type 'NoneType'> <type 'NoneType'> <type 'NoneType'>
"""
with ContextManager("value"):
print("hello")
def basic():
"""
>>> basic()
enter
value
exit <type 'NoneType'> <type 'NoneType'> <type 'NoneType'>
"""
with ContextManager("value") as x:
print(x)
def with_pass():
"""
>>> with_pass()
enter
exit <type 'NoneType'> <type 'NoneType'> <type 'NoneType'>
"""
with ContextManager("value") as x:
pass
def with_return():
"""
>>> print(with_return())
enter
exit <type 'NoneType'> <type 'NoneType'> <type 'NoneType'>
value
"""
with ContextManager("value") as x:
return x
def with_break():
"""
>>> print(with_break())
enter
exit <type 'NoneType'> <type 'NoneType'> <type 'NoneType'>
a
"""
for c in list("abc"):
with ContextManager("value") as x:
break
print("FAILED")
return c
def with_continue():
"""
>>> print(with_continue())
enter
exit <type 'NoneType'> <type 'NoneType'> <type 'NoneType'>
enter
exit <type 'NoneType'> <type 'NoneType'> <type 'NoneType'>
enter
exit <type 'NoneType'> <type 'NoneType'> <type 'NoneType'>
c
"""
for c in list("abc"):
with ContextManager("value") as x:
continue
print("FAILED")
return c
def with_exception(exit_ret):
"""
>>> with_exception(None)
enter
value
exit <type 'type'> <type 'MyException'> <type 'traceback'>
outer except
>>> with_exception(True)
enter
value
exit <type 'type'> <type 'MyException'> <type 'traceback'>
"""
try:
with ContextManager("value", exit_ret=exit_ret) as value:
print(value)
raise MyException()
except:
print("outer except")
def functions_in_with():
"""
>>> f = functions_in_with()
enter
exit <type 'type'> <type 'MyException'> <type 'traceback'>
outer except
>>> f(1)[0]
1
>>> print(f(1)[1])
value
"""
try:
with ContextManager("value") as value:
def f(x): return x, value
make = lambda x:x()
raise make(MyException)
except:
print("outer except")
return f
def multitarget():
"""
>>> multitarget()
enter
1 2 3 4 5
exit <type 'NoneType'> <type 'NoneType'> <type 'NoneType'>
"""
with ContextManager((1, 2, (3, (4, 5)))) as (a, b, (c, (d, e))):
print('%s %s %s %s %s' % (a, b, c, d, e))
def tupletarget():
"""
>>> tupletarget()
enter
(1, 2, (3, (4, 5)))
exit <type 'NoneType'> <type 'NoneType'> <type 'NoneType'>
"""
with ContextManager((1, 2, (3, (4, 5)))) as t:
print(t)
def multimanager():
"""
>>> multimanager()
enter
enter
enter
enter
enter
enter
2
value
1 2 3 4 5
nested
exit <type 'NoneType'> <type 'NoneType'> <type 'NoneType'>
exit <type 'NoneType'> <type 'NoneType'> <type 'NoneType'>
exit <type 'NoneType'> <type 'NoneType'> <type 'NoneType'>
exit <type 'NoneType'> <type 'NoneType'> <type 'NoneType'>
exit <type 'NoneType'> <type 'NoneType'> <type 'NoneType'>
exit <type 'NoneType'> <type 'NoneType'> <type 'NoneType'>
"""
with ContextManager(1), ContextManager(2) as x, ContextManager('value') as y,\
ContextManager(3), ContextManager((1, 2, (3, (4, 5)))) as (a, b, (c, (d, e))):
with ContextManager('nested') as nested:
print(x)
print(y)
print('%s %s %s %s %s' % (a, b, c, d, e))
print(nested)
# Tests borrowed from pyregr test_with.py,
# modified to follow the constraints of Cython.
import unittest
class Dummy(object):
def __init__(self, value=None, gobble=False):
if value is None:
value = self
self.value = value
self.gobble = gobble
self.enter_called = False
self.exit_called = False
def __enter__(self):
self.enter_called = True
return self.value
def __exit__(self, *exc_info):
self.exit_called = True
self.exc_info = exc_info
if self.gobble:
return True
class InitRaises(object):
def __init__(self): raise RuntimeError()
class EnterRaises(object):
def __enter__(self): raise RuntimeError()
def __exit__(self, *exc_info): pass
class ExitRaises(object):
def __enter__(self): pass
def __exit__(self, *exc_info): raise RuntimeError()
class NestedWith(unittest.TestCase):
"""
>>> NestedWith().runTest()
"""
def runTest(self):
self.testNoExceptions()
self.testExceptionInExprList()
self.testExceptionInEnter()
self.testExceptionInExit()
self.testEnterReturnsTuple()
def testNoExceptions(self):
with Dummy() as a, Dummy() as b:
self.assertTrue(a.enter_called)
self.assertTrue(b.enter_called)
self.assertTrue(a.exit_called)
self.assertTrue(b.exit_called)
def testExceptionInExprList(self):
try:
with Dummy() as a, InitRaises():
pass
except:
pass
self.assertTrue(a.enter_called)
self.assertTrue(a.exit_called)
def testExceptionInEnter(self):
try:
with Dummy() as a, EnterRaises():
self.fail('body of bad with executed')
except RuntimeError:
pass
else:
self.fail('RuntimeError not reraised')
self.assertTrue(a.enter_called)
self.assertTrue(a.exit_called)
def testExceptionInExit(self):
body_executed = False
with Dummy(gobble=True) as a, ExitRaises():
body_executed = True
self.assertTrue(a.enter_called)
self.assertTrue(a.exit_called)
self.assertTrue(body_executed)
self.assertNotEqual(a.exc_info[0], None)
def testEnterReturnsTuple(self):
with Dummy(value=(1,2)) as (a1, a2), \
Dummy(value=(10, 20)) as (b1, b2):
self.assertEquals(1, a1)
self.assertEquals(2, a2)
self.assertEquals(10, b1)
self.assertEquals(20, b2)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment