Commit 5cc0a44b authored by Dag Sverre Seljebotn

merge

parents da14f38b fa55b747
@@ -20,7 +20,7 @@ class EmbedSignature(CythonTransform):
         try:
             denv = self.denv # XXX
             ctval = default_val.compile_time_value(self.denv)
-            repr_val = '%r' % ctval
+            repr_val = repr(ctval)
             if isinstance(default_val, ExprNodes.UnicodeNode):
                 if repr_val[:1] != 'u':
                     return u'u%s' % repr_val
@@ -28,8 +28,8 @@ class EmbedSignature(CythonTransform):
                 if repr_val[:1] != 'b':
                     return u'b%s' % repr_val
             elif isinstance(default_val, ExprNodes.StringNode):
-                if repr_val[:1] in ('u', 'b'):
-                    repr_val[1:]
+                if repr_val[:1] in 'ub':
+                    return repr_val[1:]
             return repr_val
         except Exception:
             try:
...
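Note (illustrative only, not part of the commit): repr() of string constants differs between Python 2 and 3, so the embedded signature normalises the u/b prefix according to the kind of constant node. A plain-Python sketch of the intent, using a hypothetical helper name:

    def normalise_default_repr(value, prefix):
        # prefix is 'u' for unicode constants, 'b' for byte strings,
        # '' for native str constants
        r = repr(value)
        if prefix and r[:1] != prefix:
            return prefix + r
        if not prefix and r[:1] in 'ub':
            return r[1:]
        return r

    # e.g. normalise_default_repr(u'abc', 'u') == "u'abc'" on both Python 2 and 3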
@@ -137,7 +137,7 @@ def analyse_buffer_options(globalpos, env, posargs, dictargs, defaults=None, nee
     if defaults is None:
         defaults = buffer_defaults
-    posargs, dictargs = Interpreter.interpret_compiletime_options(posargs, dictargs, type_env=env)
+    posargs, dictargs = Interpreter.interpret_compiletime_options(posargs, dictargs, type_env=env, type_args = (0,'dtype'))
     if len(posargs) > buffer_positional_options_count:
         raise CompileError(posargs[-1][1], ERR_BUF_TOO_MANY)
...
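For reference (usage sketch, not part of the commit): in the buffer syntax the first positional option and the dtype keyword name a type, which is why they are now passed to the option interpreter via type_args=(0, 'dtype') and analysed with analyse_as_type() instead of being evaluated as compile-time values.

    cimport numpy as np

    def row_sum(np.ndarray[np.float64_t, ndim=2] a):
        # np.float64_t is positional option 0 (the dtype);
        # ndim=2 is an ordinary compile-time value.
        return a.sum(axis=1)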
@@ -637,6 +637,8 @@ class GlobalState(object):
     def put_cached_builtin_init(self, pos, name, cname):
         w = self.parts['cached_builtins']
         interned_cname = self.get_interned_identifier(name).cname
+        from ExprNodes import get_name_interned_utility_code
+        self.use_utility_code(get_name_interned_utility_code)
         w.putln('%s = __Pyx_GetName(%s, %s); if (!%s) %s' % (
             cname,
             Naming.builtins_cname,
@@ -667,7 +669,7 @@ class GlobalState(object):
         decls_writer = self.parts['decls']
         for _, cname, c in c_consts:
             decls_writer.putln('static char %s[] = "%s";' % (
-                cname, c.escaped_value))
+                cname, StringEncoding.split_string_literal(c.escaped_value)))
             if c.py_strings is not None:
                 for py_string in c.py_strings.itervalues():
                     py_strings.append((c.cname, len(py_string.cname), py_string))
...
This diff is collapsed.
@@ -17,7 +17,7 @@ class EmptyScope(object):
 empty_scope = EmptyScope()
-def interpret_compiletime_options(optlist, optdict, type_env=None):
+def interpret_compiletime_options(optlist, optdict, type_env=None, type_args=()):
     """
     Tries to interpret a list of compile time option nodes.
     The result will be a tuple (optlist, optdict) but where
@@ -34,21 +34,21 @@ def interpret_compiletime_options(optlist, optdict, type_env=None):
     A CompileError will be raised if there are problems.
     """
-    def interpret(node):
-        if isinstance(node, CBaseTypeNode):
+    def interpret(node, ix):
+        if ix in type_args:
             if type_env:
-                return (node.analyse(type_env), node.pos)
+                return (node.analyse_as_type(type_env), node.pos)
             else:
                 raise CompileError(node.pos, "Type not allowed here.")
         else:
             return (node.compile_time_value(empty_scope), node.pos)
     if optlist:
-        optlist = [interpret(x) for x in optlist]
+        optlist = [interpret(x, ix) for ix, x in enumerate(optlist)]
     if optdict:
         assert isinstance(optdict, DictNode)
         new_optdict = {}
         for item in optdict.key_value_pairs:
-            new_optdict[item.key.value] = interpret(item.value)
+            new_optdict[item.key.value] = interpret(item.value, item.key.value)
         optdict = new_optdict
     return (optlist, new_optdict)
@@ -66,7 +66,7 @@ class Context(object):
     #  include_directories   [string]
     #  future_directives     [object]
-    def __init__(self, include_directories, compiler_directives):
+    def __init__(self, include_directories, compiler_directives, cpp=False):
         #self.modules = {"__builtin__" : BuiltinScope()}
         import Builtin, CythonScope
         self.modules = {"__builtin__" : Builtin.builtin_scope}
@@ -74,6 +74,7 @@ class Context(object):
         self.include_directories = include_directories
         self.future_directives = set()
         self.compiler_directives = compiler_directives
+        self.cpp = cpp
         self.pxds = {} # full name -> node tree
@@ -87,7 +88,7 @@ class Context(object):
         from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform
         from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform
         from ParseTreeTransforms import InterpretCompilerDirectives, TransformBuiltinMethods
-        from TypeInference import MarkAssignments
+        from TypeInference import MarkAssignments, MarkOverflowingArithmatic
         from ParseTreeTransforms import AlignFunctionDefinitions, GilCheck
         from AnalysedTreeTransforms import AutoTestDictTransform
         from AutoDocTransforms import EmbedSignature
@@ -134,6 +135,7 @@ class Context(object):
             EmbedSignature(self),
             EarlyReplaceBuiltinCalls(self),
             MarkAssignments(self),
+            MarkOverflowingArithmatic(self),
             TransformBuiltinMethods(self),
             IntroduceBufferAuxiliaryVars(self),
             _check_c_declarations,
@@ -217,8 +219,11 @@ class Context(object):
             for phase in pipeline:
                 if phase is not None:
                     if DebugFlags.debug_verbose_pipeline:
+                        t = time()
                         print "Entering pipeline phase %r" % phase
                     data = phase(data)
+                    if DebugFlags.debug_verbose_pipeline:
+                        print "    %.3f seconds" % (time() - t)
         except CompileError, err:
             # err is set
             Errors.report_error(err)
@@ -451,6 +456,7 @@ class Context(object):
         if not isinstance(source_desc, FileSourceDescriptor):
             raise RuntimeError("Only file sources for code supported")
         source_filename = Utils.encode_filename(source_desc.filename)
+        scope.cpp = self.cpp
         # Parse the given source file and return a parse tree.
         try:
             f = Utils.open_source_file(source_filename, "rU")
@@ -540,7 +546,7 @@ def create_default_resultobj(compilation_source, options):
 def run_pipeline(source, options, full_module_name = None):
     # Set up context
-    context = Context(options.include_path, options.compiler_directives)
+    context = Context(options.include_path, options.compiler_directives, options.cplus)
     # Set up source object
     cwd = os.getcwd()
...
@@ -616,6 +616,9 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
         includes = []
         for filename in env.include_files:
             # fake decoding of filenames to their original byte sequence
+            if filename[0] == '<' and filename[-1] == '>':
+                code.putln('#include %s' % filename)
+            else:
                 code.putln('#include "%s"' % filename)
     def generate_filename_table(self, code):
...
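Usage sketch (not part of the commit): an include file whose name is already wrapped in angle brackets is now emitted verbatim, so a declaration such as

    cdef extern from "<vector>":
        pass

generates  #include <vector>  in the module's C file instead of  #include "<vector>".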
This diff is collapsed.
@@ -89,10 +89,12 @@ class IterationTransform(Visitor.VisitorTransform):
             return self._transform_dict_iteration(
                 node, dict_obj=iterator, keys=True, values=False)
-        # C array slice iteration?
+        # C array (slice) iteration?
         if isinstance(iterator, ExprNodes.SliceIndexNode) and \
                (iterator.base.type.is_array or iterator.base.type.is_ptr):
             return self._transform_carray_iteration(node, iterator)
+        elif iterator.type.is_array:
+            return self._transform_carray_iteration(node, iterator)
         elif not isinstance(iterator, ExprNodes.SimpleCallNode):
             return node
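Illustrative Cython that the extended dispatch now optimises (a sketch, not from the commit): iterating over a whole fixed-size C array is handled by the new elif branch, while slices of arrays and pointers go through the existing SliceIndexNode path.

    def sum_all():
        cdef int values[10]
        cdef int i, x, total = 0
        for i in range(10):
            values[i] = i
        for x in values:        # whole C array: new elif branch
            total += x
        for x in values[2:8]:   # slice of a C array: existing path
            total += x
        return total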
@@ -131,13 +133,26 @@ class IterationTransform(Visitor.VisitorTransform):
         return node
     def _transform_carray_iteration(self, node, slice_node):
+        if isinstance(slice_node, ExprNodes.SliceIndexNode):
+            slice_base = slice_node.base
             start = slice_node.start
             stop = slice_node.stop
             step = None
             if not stop:
                 return node
+        elif slice_node.type.is_array and slice_node.type.size is not None:
+            slice_base = slice_node
+            start = None
+            stop = ExprNodes.IntNode(
+                slice_node.pos, value=str(slice_node.type.size))
+            step = None
+        else:
+            return node
-        carray_ptr = slice_node.base.coerce_to_simple(self.current_scope)
+        ptr_type = slice_base.type
+        if ptr_type.is_array:
+            ptr_type = ptr_type.element_ptr_type()
+        carray_ptr = slice_base.coerce_to_simple(self.current_scope)
         if start and start.constant_result != 0:
             start_ptr_node = ExprNodes.AddNode(
@@ -145,7 +160,7 @@ class IterationTransform(Visitor.VisitorTransform):
                 operand1=carray_ptr,
                 operator='+',
                 operand2=start,
-                type=carray_ptr.type)
+                type=ptr_type)
         else:
             start_ptr_node = carray_ptr
@@ -154,13 +169,13 @@ class IterationTransform(Visitor.VisitorTransform):
             operand1=carray_ptr,
             operator='+',
             operand2=stop,
-            type=carray_ptr.type
+            type=ptr_type
             ).coerce_to_simple(self.current_scope)
-        counter = UtilNodes.TempHandle(carray_ptr.type)
+        counter = UtilNodes.TempHandle(ptr_type)
         counter_temp = counter.ref(node.target.pos)
-        if slice_node.base.type.is_string and node.target.type.is_pyobject:
+        if slice_base.type.is_string and node.target.type.is_pyobject:
             # special case: char* -> bytes
             target_value = ExprNodes.SliceIndexNode(
                 node.target.pos,
@@ -181,7 +196,7 @@ class IterationTransform(Visitor.VisitorTransform):
                     type=PyrexTypes.c_int_type),
                 base=counter_temp,
                 is_buffer_access=False,
-                type=carray_ptr.type.base_type)
+                type=ptr_type.base_type)
         if target_value.type != node.target.type:
             target_value = target_value.coerce_to(node.target.type,
@@ -1606,20 +1621,20 @@ impl = ""
 pop_utility_code = UtilityCode(
 proto = """
 static CYTHON_INLINE PyObject* __Pyx_PyObject_Pop(PyObject* L) {
+#if PY_VERSION_HEX >= 0x02040000
     if (likely(PyList_CheckExact(L))
             /* Check that both the size is positive and no reallocation shrinking needs to be done. */
             && likely(PyList_GET_SIZE(L) > (((PyListObject*)L)->allocated >> 1))) {
         Py_SIZE(L) -= 1;
         return PyList_GET_ITEM(L, PyList_GET_SIZE(L));
     }
-    else {
+#endif
     PyObject *r, *m;
     m = __Pyx_GetAttrString(L, "pop");
     if (!m) return NULL;
     r = PyObject_CallObject(m, NULL);
     Py_DECREF(m);
     return r;
-    }
 }
 """,
 impl = ""
@@ -1632,6 +1647,7 @@ static PyObject* __Pyx_PyObject_PopIndex(PyObject* L, Py_ssize_t ix);
 impl = """
 static PyObject* __Pyx_PyObject_PopIndex(PyObject* L, Py_ssize_t ix) {
     PyObject *r, *m, *t, *py_ix;
+#if PY_VERSION_HEX >= 0x02040000
     if (likely(PyList_CheckExact(L))) {
         Py_ssize_t size = PyList_GET_SIZE(L);
         if (likely(size > (((PyListObject*)L)->allocated >> 1))) {
@@ -1650,6 +1666,7 @@ static PyObject* __Pyx_PyObject_PopIndex(PyObject* L, Py_ssize_t ix) {
             }
         }
     }
+#endif
     py_ix = t = NULL;
     m = __Pyx_GetAttrString(L, "pop");
     if (!m) goto bad;
...
@@ -128,7 +128,6 @@ class PostParseError(CompileError): pass
 # error strings checked by unit tests, so define them
 ERR_CDEF_INCLASS = 'Cannot assign default value to fields in cdef classes, structs or unions'
-ERR_BUF_LOCALONLY = 'Buffer types only allowed as function local variables'
 ERR_BUF_DEFAULTS = 'Invalid buffer defaults specification (see docs)'
 ERR_INVALID_SPECIALATTR_TYPE = 'Special attributes must not have a type declared'
 class PostParse(CythonTransform):
@@ -145,7 +144,7 @@ class PostParse(CythonTransform):
     - Interpret some node structures into Python runtime values.
       Some nodes take compile-time arguments (currently:
-      CBufferAccessTypeNode[args] and __cythonbufferdefaults__ = {args}),
+      TemplatedTypeNode[args] and __cythonbufferdefaults__ = {args}),
       which should be interpreted. This happens in a general way
       and other steps should be taken to ensure validity.
@@ -154,7 +153,7 @@ class PostParse(CythonTransform):
     - For __cythonbufferdefaults__ the arguments are checked for
       validity.
-      CBufferAccessTypeNode has its directives interpreted:
+      TemplatedTypeNode has its directives interpreted:
       Any first positional argument goes into the "dtype" attribute,
       any "ndim" keyword argument goes into the "ndim" attribute and
       so on. Also it is checked that the directive combination is valid.
@@ -243,11 +242,6 @@ class PostParse(CythonTransform):
             self.context.nonfatal_error(e)
             return None
-    def visit_CBufferAccessTypeNode(self, node):
-        if not self.scope_type == 'function':
-            raise PostParseError(node.pos, ERR_BUF_LOCALONLY)
-        return node
 class PxdPostParse(CythonTransform, SkipDeclarations):
     """
     Basic interpretation/validity checking that should only be
@@ -329,7 +323,23 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
     duplication of functionality has to occur: We manually track cimports
     and which names the "cython" module may have been imported to.
     """
-    special_methods = set(['declare', 'union', 'struct', 'typedef', 'sizeof', 'typeof', 'cast', 'address', 'pointer', 'compiled', 'NULL'])
+    unop_method_nodes = {
'typeof': TypeofNode,
'operator.address': AmpersandNode,
'operator.dereference': DereferenceNode,
'operator.preincrement' : inc_dec_constructor(True, '++'),
'operator.predecrement' : inc_dec_constructor(True, '--'),
'operator.postincrement': inc_dec_constructor(False, '++'),
'operator.postdecrement': inc_dec_constructor(False, '--'),
# For backwards compatability.
'address': AmpersandNode,
}
special_methods = set(['declare', 'union', 'struct', 'typedef', 'sizeof',
'cast', 'pointer', 'compiled', 'NULL']
+ unop_method_nodes.keys())
     def __init__(self, context, compilation_directive_defaults):
         super(InterpretCompilerDirectives, self).__init__(context)
@@ -364,26 +374,37 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
         node.cython_module_names = self.cython_module_names
         return node
-    # Track cimports of the cython module.
+    # The following four functions track imports and cimports that
+    # begin with "cython"
+    def is_cython_directive(self, name):
+        return (name in Options.directive_types or
+                name in self.special_methods or
+                PyrexTypes.parse_basic_type(name))
     def visit_CImportStatNode(self, node):
         if node.module_name == u"cython":
+            self.cython_module_names.add(node.as_name or u"cython")
+        elif node.module_name.startswith(u"cython."):
             if node.as_name:
-                modname = node.as_name
+                self.directive_names[node.as_name] = node.module_name[7:]
             else:
-                modname = u"cython"
+                self.cython_module_names.add(u"cython")
-            self.cython_module_names.add(modname)
+            # if this cimport was a compiler directive, we don't
+            # want to leave the cimport node sitting in the tree
+            return None
         return node
     def visit_FromCImportStatNode(self, node):
-        if node.module_name == u"cython":
+        if (node.module_name == u"cython") or \
+               node.module_name.startswith(u"cython."):
+            submodule = (node.module_name + u".")[7:]
             newimp = []
             for pos, name, as_name, kind in node.imported_names:
-                if (name in Options.directive_types or
-                        name in self.special_methods or
-                        PyrexTypes.parse_basic_type(name)):
+                full_name = submodule + name
+                if self.is_cython_directive(full_name):
                     if as_name is None:
-                        as_name = name
+                        as_name = full_name
-                    self.directive_names[as_name] = name
+                    self.directive_names[as_name] = full_name
                     if kind is not None:
                         self.context.nonfatal_error(PostParseError(pos,
                             "Compiler directive imports must be plain imports"))
@@ -395,13 +416,14 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
         return node
     def visit_FromImportStatNode(self, node):
-        if node.module.module_name.value == u"cython":
+        if (node.module.module_name.value == u"cython") or \
+               node.module.module_name.value.startswith(u"cython."):
+            submodule = (node.module.module_name.value + u".")[7:]
             newimp = []
             for name, name_node in node.items:
-                if (name in Options.directive_types or
-                        name in self.special_methods or
-                        PyrexTypes.parse_basic_type(name)):
-                    self.directive_names[name_node.name] = name
+                full_name = submodule + name
+                if self.is_cython_directive(full_name):
+                    self.directive_names[name_node.name] = full_name
                 else:
                     newimp.append((name, name_node))
             if not newimp:
@@ -1016,7 +1038,12 @@ class TransformBuiltinMethods(EnvTransform):
         # cython.foo
         function = node.function.as_cython_attribute()
         if function:
-            if function == u'cast':
+            if function in InterpretCompilerDirectives.unop_method_nodes:
+                if len(node.args) != 1:
+                    error(node.function.pos, u"%s() takes exactly one argument" % function)
+                else:
+                    node = InterpretCompilerDirectives.unop_method_nodes[function](node.function.pos, operand=node.args[0])
+            elif function == u'cast':
                 if len(node.args) != 2:
                     error(node.function.pos, u"cast() takes exactly two arguments")
                 else:
@@ -1034,16 +1061,6 @@ class TransformBuiltinMethods(EnvTransform):
                     node = SizeofTypeNode(node.function.pos, arg_type=type)
                 else:
                     node = SizeofVarNode(node.function.pos, operand=node.args[0])
-            elif function == 'typeof':
-                if len(node.args) != 1:
-                    error(node.function.pos, u"typeof() takes exactly one argument")
-                else:
-                    node = TypeofNode(node.function.pos, operand=node.args[0])
-            elif function == 'address':
-                if len(node.args) != 1:
-                    error(node.function.pos, u"address() takes exactly one argument")
-                else:
-                    node = AmpersandNode(node.function.pos, operand=node.args[0])
             elif function == 'cmod':
                 if len(node.args) != 2:
                     error(node.function.pos, u"cmod() takes exactly two arguments")
...
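Usage note (not part of the diff): the new unop_method_nodes mapping, together with the "cython." submodule tracking above, lets pure-Cython code spell the C operators as functions and routes them through the same node constructors. A small sketch:

    from cython.operator cimport dereference as deref, preincrement as inc

    cdef int x = 0
    cdef int* p = &x
    inc(deref(p))       # compiles to ++(*p)
    print(deref(p))     # 1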
@@ -28,6 +28,7 @@ cpdef p_typecast(PyrexScanner s)
 cpdef p_sizeof(PyrexScanner s)
 cpdef p_yield_expression(PyrexScanner s)
 cpdef p_power(PyrexScanner s)
+cpdef p_new_expr(PyrexScanner s)
 cpdef p_trailer(PyrexScanner s, node1)
 cpdef p_call(PyrexScanner s, function)
 cpdef p_index(PyrexScanner s, base)
@@ -99,13 +100,13 @@ cpdef p_IF_statement(PyrexScanner s, ctx)
 cpdef p_statement(PyrexScanner s, ctx, bint first_statement = *)
 cpdef p_statement_list(PyrexScanner s, ctx, bint first_statement = *)
 cpdef p_suite(PyrexScanner s, ctx = *, bint with_doc = *, bint with_pseudo_doc = *)
-cpdef p_positional_and_keyword_args(PyrexScanner s, end_sy_set, type_positions= *, type_keywords= * )
-cpdef p_c_base_type(PyrexScanner s, bint self_flag = *, bint nonempty = *)
+cpdef p_positional_and_keyword_args(PyrexScanner s, end_sy_set, templates = *)
+cpdef p_c_base_type(PyrexScanner s, bint self_flag = *, bint nonempty = *, templates = *)
 cpdef p_calling_convention(PyrexScanner s)
 cpdef p_c_complex_base_type(PyrexScanner s)
-cpdef p_c_simple_base_type(PyrexScanner s, self_flag, nonempty)
-cpdef p_buffer_access(PyrexScanner s, base_type_node)
+cpdef p_c_simple_base_type(PyrexScanner s, bint self_flag, bint nonempty, templates = *)
+cpdef p_buffer_or_template(PyrexScanner s, base_type_node, templates)
 cpdef bint looking_at_name(PyrexScanner s) except -2
 cpdef bint looking_at_expr(PyrexScanner s) except -2
 cpdef bint looking_at_base_type(PyrexScanner s) except -2
@@ -149,3 +150,4 @@ cpdef p_doc_string(PyrexScanner s)
 cpdef p_code(PyrexScanner s, level= *)
 cpdef p_compiler_directive_comments(PyrexScanner s)
 cpdef p_module(PyrexScanner s, pxd, full_module_name)
+cpdef p_cpp_class_definition(PyrexScanner s, pos, ctx)
This diff is collapsed.
This diff is collapsed.
@@ -97,7 +97,10 @@ def initial_compile_time_env():
              'UNAME_VERSION', 'UNAME_MACHINE')
     for name, value in zip(names, platform.uname()):
         benv.declare(name, value)
+    try:
         import __builtin__ as builtins
+    except ImportError:
+        import builtins
     names = ('False', 'True',
              'abs', 'bool', 'chr', 'cmp', 'complex', 'dict', 'divmod', 'enumerate',
              'float', 'hash', 'hex', 'int', 'len', 'list', 'long', 'map', 'max', 'min',
@@ -355,6 +358,14 @@ class PyrexScanner(Scanner):
             t = "%s %s" % (self.sy, self.systring)
             print("--- %3d %2d %s" % (line, col, t))
def peek(self):
saved = self.sy, self.systring
self.next()
next = self.sy, self.systring
self.unread(*next)
self.sy, self.systring = saved
return next
    def put_back(self, sy, systring):
        self.unread(self.sy, self.systring)
        self.sy = sy
...
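Small usage sketch for the new peek() method (not part of the diff): it reports the upcoming (sy, systring) pair without consuming it, giving the parser one token of lookahead.

    upcoming = s.peek()      # e.g. ('IDENT', u'vector'); scanner state unchanged
    s.next()
    assert (s.sy, s.systring) == upcoming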
@@ -185,7 +185,9 @@ def escape_byte_string(s):
             append(c)
     return join_bytes(l).decode('ISO-8859-1')
-def split_docstring(s):
+def split_string_literal(s):
+    # MSVC can't handle long string literals.
     if len(s) < 2047:
         return s
-    return '\\n\"\"'.join(s.split(r'\n'))
+    else:
+        return '""'.join([s[i:i+2000] for i in range(0, len(s), 2000)]).replace(r'\""', '""\\')
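Rough illustration of the effect (example only, based on the code above): a long escaped value comes back as several adjacent C string literals, which the C compiler concatenates back into one array initialiser.

    from Cython.Compiler.StringEncoding import split_string_literal

    s = 'x' * 6000
    lit = split_string_literal(s)
    # emitted roughly as: static char k[] = "xx...x""xx...x""xx...x";
    assert lit.count('""') == 2 and len(lit) == 6004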
This diff is collapsed.
@@ -21,7 +21,7 @@ class TestBufferParsing(CythonTest):
     def test_basic(self):
         t = self.parse(u"cdef object[float, 4, ndim=2, foo=foo] x")
         bufnode = t.stats[0].base_type
-        self.assert_(isinstance(bufnode, CBufferAccessTypeNode))
+        self.assert_(isinstance(bufnode, TemplatedTypeNode))
         self.assertEqual(2, len(bufnode.positional_args))
         # print bufnode.dump()
         # should put more here...
@@ -32,14 +32,6 @@ class TestBufferParsing(CythonTest):
     def test_type_keyword(self):
         self.parse(u"cdef object[foo=foo, dtype=short unsigned int] x")
-    def test_notype_as_expr1(self):
-        self.not_parseable("Expected: expression",
-                           u"cdef object[foo2=short unsigned int] x")
-    def test_notype_as_expr2(self):
-        self.not_parseable("Expected: expression",
-                           u"cdef object[int, short unsigned int] x")
     def test_pos_after_key(self):
         self.not_parseable("Non-keyword arg following keyword arg",
                            u"cdef object[foo=1, 2] x")
@@ -65,7 +57,7 @@ class TestBufferOptions(CythonTest):
         vardef = root.stats[0].body.stats[0]
         assert isinstance(vardef, CVarDefNode) # use normal assert as this is to validate the test code
         buftype = vardef.base_type
-        self.assert_(isinstance(buftype, CBufferAccessTypeNode))
+        self.assert_(isinstance(buftype, TemplatedTypeNode))
         self.assert_(isinstance(buftype.base_type_node, CSimpleBaseTypeNode))
         self.assertEqual(u"object", buftype.base_type_node.name)
         return buftype
...
@@ -112,6 +112,62 @@ class MarkAssignments(CythonTransform):
         self.visitchildren(node)
         return node
class MarkOverflowingArithmatic(CythonTransform):
# It may be possible to integrate this with the above for
# performance improvements (though likely not worth it).
might_overflow = False
def __call__(self, root):
self.env_stack = []
self.env = root.scope
return super(MarkOverflowingArithmatic, self).__call__(root)
def visit_safe_node(self, node):
self.might_overflow, saved = False, self.might_overflow
self.visitchildren(node)
self.might_overflow = saved
return node
def visit_neutral_node(self, node):
self.visitchildren(node)
return node
def visit_dangerous_node(self, node):
self.might_overflow, saved = True, self.might_overflow
self.visitchildren(node)
self.might_overflow = saved
return node
def visit_FuncDefNode(self, node):
self.env_stack.append(self.env)
self.env = node.local_scope
self.visit_safe_node(node)
self.env = self.env_stack.pop()
return node
def visit_NameNode(self, node):
if self.might_overflow:
entry = node.entry or self.env.lookup(node.name)
if entry:
entry.might_overflow = True
return node
def visit_BinopNode(self, node):
if node.operator in '&|^':
return self.visit_neutral_node(node)
else:
return self.visit_dangerous_node(node)
visit_UnopNode = visit_neutral_node
visit_UnaryMinusNode = visit_dangerous_node
visit_InPlaceAssignmentNode = visit_dangerous_node
visit_Node = visit_safe_node
class PyObjectTypeInferer:
    """
@@ -175,7 +231,7 @@ class SimpleAssignmentTypeInferer:
             entry = ready_to_infer.pop()
             types = [expr.infer_type(scope) for expr in entry.assignments]
             if types:
-                entry.type = spanning_type(types)
+                entry.type = spanning_type(types, entry.might_overflow)
             else:
                 # FIXME: raise a warning?
                 # print "No assignments", entry.pos, entry
@@ -188,9 +244,9 @@ class SimpleAssignmentTypeInferer:
             if len(deps) == 1 and deps == set([entry]):
                 types = [expr.infer_type(scope) for expr in entry.assignments if expr.type_dependencies(scope) == ()]
                 if types:
-                    entry.type = spanning_type(types)
+                    entry.type = spanning_type(types, entry.might_overflow)
                     types = [expr.infer_type(scope) for expr in entry.assignments]
-                    entry.type = spanning_type(types) # might be wider...
+                    entry.type = spanning_type(types, entry.might_overflow) # might be wider...
                     resolve_dependancy(entry)
                     del dependancies_by_entry[entry]
                     if ready_to_infer:
@@ -218,11 +274,11 @@ def find_spanning_type(type1, type2):
         return PyrexTypes.c_double_type
     return result_type
-def aggressive_spanning_type(types):
+def aggressive_spanning_type(types, might_overflow):
     result_type = reduce(find_spanning_type, types)
     return result_type
-def safe_spanning_type(types):
+def safe_spanning_type(types, might_overflow):
     result_type = reduce(find_spanning_type, types)
     if result_type.is_pyobject:
         # any specific Python type is always safe to infer
@@ -235,6 +291,22 @@ def safe_spanning_type(types):
         # find_spanning_type() only returns 'bint' for clean boolean
         # operations without other int types, so this is safe, too
         return result_type
elif result_type.is_ptr and not (result_type.is_int and result_type.rank == 0):
# Any pointer except (signed|unsigned|) char* can't implicitly
# become a PyObject.
return result_type
elif result_type.is_cpp_class:
# These can't implicitly become Python objects either.
return result_type
elif result_type.is_struct:
# Though we have struct -> object for some structs, this is uncommonly
# used, won't arise in pure Python, and there shouldn't be side
# effects, so I'm declaring this safe.
return result_type
# TODO: double complex should be OK as well, but we need
# to make sure everything is supported.
elif result_type.is_int and not might_overflow:
return result_type
    return py_object_type
...
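A hedged illustration of what the overflow marking protects (example code, not from the commit): with type inference enabled, n below is only ever assigned integer expressions, but it takes part in multiplication that can overflow a C int, so marking its entry as might_overflow keeps safe_spanning_type() from narrowing it to a plain C integer.

    # cython: infer_types=True
    def close_to_factorial(int x):
        n = 1
        for i in range(2, x):
            n = n * i      # dangerous arithmetic: 'n' is marked might_overflow
        return n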
@@ -100,7 +100,8 @@ class TreeVisitor(BasicVisitor):
     def dump_node(self, node, indent=0):
         ignored = list(node.child_attrs) + [u'child_attrs', u'pos',
-                                            u'gil_message', u'subexprs']
+                                            u'gil_message', u'cpp_message',
+                                            u'subexprs']
         values = []
         pos = node.pos
         if pos:
...
@@ -7,5 +7,6 @@
 # and keep the old one under the module name _build_ext,
 # so that *our* build_ext can make use of it.
-from build_ext import build_ext
+from Cython.Distutils.build_ext import build_ext
 # from extension import Extension
@@ -15,16 +15,6 @@ from distutils.sysconfig import customize_compiler, get_python_version
 from distutils.dep_util import newer, newer_group
 from distutils import log
 from distutils.dir_util import mkpath
-try:
-    from Cython.Compiler.Main \
-        import CompilationOptions, \
-            default_options as pyrex_default_options, \
-            compile as cython_compile
-    from Cython.Compiler.Errors import PyrexError
-except ImportError, e:
-    print "failed to import Cython: %s" % e
-    PyrexError = None
 from distutils.command import build_ext as _build_ext
 extension_name_re = _build_ext.extension_name_re
@@ -83,18 +73,22 @@ class build_ext(_build_ext.build_ext):
             self.build_extension(ext)
     def cython_sources(self, sources, extension):
         """
         Walk the list of source files in 'sources', looking for Cython
         source files (.pyx and .py). Run Cython on all that are
         found, and return a modified 'sources' list with Cython source
         files replaced by the generated C (or C++) files.
         """
-        if PyrexError == None:
-            raise DistutilsPlatformError, \
-                  ("Cython does not appear to be installed "
-                   "on platform '%s'") % os.name
+        try:
+            from Cython.Compiler.Main \
+                import CompilationOptions, \
+                       default_options as pyrex_default_options, \
+                       compile as cython_compile
+            from Cython.Compiler.Errors import PyrexError
+        except ImportError:
+            e = sys.exc_info()[1]
+            print("failed to import Cython: %s" % e)
+            raise DistutilsPlatformError("Cython does not appear to be installed")
         new_sources = []
         pyrex_sources = []
...
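For context, a typical setup.py that exercises this code path (usage sketch, not part of the diff); the compiler is now imported lazily inside cython_sources(), so merely importing Cython.Distutils no longer pulls in the whole compiler:

    from distutils.core import setup
    from distutils.extension import Extension
    from Cython.Distutils import build_ext

    setup(
        name="demo",
        cmdclass={"build_ext": build_ext},
        ext_modules=[Extension("demo", ["demo.pyx"])],
    )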
cdef extern from "<vector>" namespace std:
cdef cppclass vector[TYPE]:
#constructors
__init__()
__init__(vector&)
__init__(int)
__init__(int, TYPE&)
__init__(iterator, iterator)
#operators
TYPE& __getitem__(int)
TYPE& __setitem__(int, TYPE&)
vector __new__(vector&)
bool __eq__(vector&, vector&)
bool __ne__(vector&, vector&)
bool __lt__(vector&, vector&)
bool __gt__(vector&, vector&)
bool __le__(vector&, vector&)
bool __ge__(vector&, vector&)
#others
void assign(int, TYPE)
#void assign(iterator, iterator)
TYPE& at(int)
TYPE& back()
iterator begin()
int capacity()
void clear()
bool empty()
iterator end()
iterator erase(iterator)
iterator erase(iterator, iterator)
TYPE& front()
iterator insert(iterator, TYPE&)
void insert(iterator, int, TYPE&)
void insert(iterator, iterator)
int max_size()
void pop_back()
void push_back(TYPE&)
iterator rbegin()
iterator rend()
void reserve(int)
void resize(int)
void resize(int, TYPE&) #void resize(size_type num, const TYPE& = TYPE())
int size()
void swap(container&)
cdef extern from "<deque>" namespace std:
cdef cppclass deque[TYPE]:
#constructors
__init__()
__init__(deque&)
__init__(int)
__init__(int, TYPE&)
__init__(iterator, iterator)
#operators
TYPE& operator[]( size_type index );
const TYPE& operator[]( size_type index ) const;
deque __new__(deque&);
bool __eq__(deque&, deque&);
bool __ne__(deque&, deque&);
bool __lt__(deque&, deque&);
bool __gt__(deque&, deque&);
bool __le__(deque&, deque&);
bool __ge__(deque&, deque&);
#others
void assign(int, TYPE&)
void assign(iterator, iterator)
TYPE& at(int)
TYPE& back()
iterator begin()
void clear()
bool empty()
iterator end()
iterator erase(iterator)
iterator erase(iterator, iterator)
TYPE& front()
iterator insert(iterator, TYPE&)
void insert(iterator, int, TYPE&)
void insert(iterator, iterator, iterator)
int max_size()
void pop_back()
void pop_front()
void push_back(TYPE&)
void push_front(TYPE&)
iterator rbegin()
iterator rend()
void resize(int)
void resize(int, TYPE&)
int size()
void swap(container&)
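Usage sketch for the new C++ container declarations (assumptions: the vector block above lives in a pxd named vector.pxd on the include path, and the module is compiled with language="c++"; not part of the diff):

    from vector cimport vector

    def to_list(int n):
        cdef vector[int] v
        cdef int i
        for i in range(n):
            v.push_back(i)
        return [v[i] for i in range(v.size())]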
@@ -13,10 +13,11 @@ except:
 ext_modules=[
     Extension("primes", ["primes.pyx"]),
     Extension("spam", ["spam.pyx"]),
+    Extension("square", ["square.pyx"], language="c++"),
 ]
 for file in glob.glob("*.pyx"):
-    if file != "numeric_demo.pyx":
+    if file != "numeric_demo.pyx" and file != "square.pyx":
         ext_modules.append(Extension(file[:-4], [file], include_dirs = numpy_include_dirs))
 setup(
...
-;;;; `Cython' mode. (add-to-list 'auto-mode-alist '("\\.pyx\\'" . cython-mode)) (define-derived-mode cython-mode python-mode "Cython" (font-lock-add-keywords nil `((,(concat "\\<\\(NULL" "\\|c\\(def\\|har\\|typedef\\)" "\\|e\\(num\\|xtern\\)" "\\|float" "\\|in\\(clude\\|t\\)" "\\|object\\|public\\|struct\\|type\\|union\\|void" "\\)\\>") 1 font-lock-keyword-face t))))
\ No newline at end of file
+;; Cython mode
(require 'python-mode)
(add-to-list 'auto-mode-alist '("\\.pyx\\'" . cython-mode))
(add-to-list 'auto-mode-alist '("\\.pxd\\'" . cython-mode))
(add-to-list 'auto-mode-alist '("\\.pxi\\'" . cython-mode))
(defun cython-compile ()
"Compile the file via Cython."
(interactive)
(let ((cy-buffer (current-buffer)))
(with-current-buffer
(compile compile-command)
(set (make-local-variable 'cython-buffer) cy-buffer)
(add-to-list (make-local-variable 'compilation-finish-functions)
'cython-compilation-finish)))
)
(defun cython-compilation-finish (buffer how)
"Called when Cython compilation finishes."
;; XXX could annotate source here
)
(defvar cython-mode-map
(let ((map (make-sparse-keymap)))
;; Will inherit from `python-mode-map' thanks to define-derived-mode.
(define-key map "\C-c\C-c" 'cython-compile)
map)
"Keymap used in `cython-mode'.")
(defvar cython-font-lock-keywords
`(;; new keywords in Cython language
(,(regexp-opt '("by" "cdef" "cimport" "cpdef" "ctypedef" "enum" "except?"
"extern" "gil" "include" "nogil" "property" "public"
"readonly" "struct" "union" "DEF" "IF" "ELIF" "ELSE") 'words)
1 font-lock-keyword-face)
;; C and Python types (highlight as builtins)
(,(regexp-opt '("NULL" "bint" "char" "dict" "double" "float" "int" "list"
"long" "object" "Py_ssize_t" "short" "size_t" "void") 'words)
1 font-lock-builtin-face)
;; cdef is used for more than functions, so simply highlighting the next
;; word is problematic. struct, enum and property work though.
("\\<\\(?:struct\\|enum\\)[ \t]+\\([a-zA-Z_]+[a-zA-Z0-9_]*\\)"
1 py-class-name-face)
("\\<property[ \t]+\\([a-zA-Z_]+[a-zA-Z0-9_]*\\)"
1 font-lock-function-name-face))
"Additional font lock keywords for Cython mode.")
(define-derived-mode cython-mode python-mode "Cython"
"Major mode for Cython development, derived from Python mode.
\\{cython-mode-map}"
(setcar font-lock-defaults
(append python-font-lock-keywords cython-font-lock-keywords))
(set (make-local-variable 'compile-command)
(concat "cython -a " buffer-file-name))
(add-to-list (make-local-variable 'compilation-finish-functions)
'cython-compilation-finish)
)
(provide 'cython-mode)
@@ -28,8 +28,8 @@ from distutils.core import Extension
 from distutils.command.build_ext import build_ext as _build_ext
 distutils_distro = Distribution()
-TEST_DIRS = ['compile', 'errors', 'run', 'pyregr']
-TEST_RUN_DIRS = ['run', 'pyregr']
+TEST_DIRS = ['compile', 'errors', 'run', 'wrappers', 'pyregr']
+TEST_RUN_DIRS = ['run', 'wrappers', 'pyregr']
 # Lists external modules, and a matcher matching tests
 # which should be excluded if the module is not present.
@@ -48,8 +48,8 @@ EXT_DEP_INCLUDES = [
 ]
 VER_DEP_MODULES = {
-    # such as:
-    # (2,4) : (operator.le, lambda x: x in ['run.set']),
+    (2,4) : (operator.le, lambda x: x in ['run.extern_builtins_T258'
+                                          ]),
     (3,): (operator.ge, lambda x: x in ['run.non_future_division',
                                         'compile.extsetslice',
                                         'compile.extdelslice']),
@@ -200,10 +200,10 @@ class TestBuilder(object):
                                   fork=self.fork)
 class CythonCompileTestCase(unittest.TestCase):
-    def __init__(self, directory, workdir, module, language='c',
+    def __init__(self, test_directory, workdir, module, language='c',
                  expect_errors=False, annotate=False, cleanup_workdir=True,
                  cleanup_sharedlibs=True, cython_only=False, fork=True):
-        self.directory = directory
+        self.test_directory = test_directory
         self.workdir = workdir
         self.module = module
         self.language = language
@@ -257,8 +257,8 @@ class CythonCompileTestCase(unittest.TestCase):
         self.runCompileTest()
     def runCompileTest(self):
-        self.compile(self.directory, self.module, self.workdir,
-                     self.directory, self.expect_errors, self.annotate)
+        self.compile(self.test_directory, self.module, self.workdir,
+                     self.test_directory, self.expect_errors, self.annotate)
     def find_module_source_file(self, source_file):
         if not os.path.exists(source_file):
@@ -269,8 +269,15 @@ class CythonCompileTestCase(unittest.TestCase):
         target = '%s.%s' % (module_name, self.language)
         return target
-    def split_source_and_output(self, directory, module, workdir):
-        source_file = os.path.join(directory, module) + '.pyx'
+    def find_source_files(self, test_directory, module_name):
+        is_related = re.compile('%s_.*[.]%s' % (module_name, self.language)).match
+        return [self.build_target_filename(module_name)] + [
+            os.path.join(test_directory, filename)
+            for filename in os.listdir(test_directory)
+            if is_related(filename) and os.path.isfile(os.path.join(test_directory, filename)) ]
+
+    def split_source_and_output(self, test_directory, module, workdir):
+        source_file = os.path.join(test_directory, module) + '.pyx'
         source_and_output = codecs.open(
             self.find_module_source_file(source_file), 'rU', 'ISO-8859-1')
         out = codecs.open(os.path.join(workdir, module + '.pyx'),
@@ -289,12 +296,12 @@ class CythonCompileTestCase(unittest.TestCase):
         else:
             return geterrors()
-    def run_cython(self, directory, module, targetdir, incdir, annotate):
+    def run_cython(self, test_directory, module, targetdir, incdir, annotate):
         include_dirs = INCLUDE_DIRS[:]
         if incdir:
             include_dirs.append(incdir)
         source = self.find_module_source_file(
-            os.path.join(directory, module + '.pyx'))
+            os.path.join(test_directory, module + '.pyx'))
         target = os.path.join(targetdir, self.build_target_filename(module))
         options = CompilationOptions(
             pyrex_default_options,
@@ -309,7 +316,7 @@ class CythonCompileTestCase(unittest.TestCase):
         cython_compile(source, options=options,
                        full_module_name=module)
-    def run_distutils(self, module, workdir, incdir):
+    def run_distutils(self, test_directory, module, workdir, incdir):
         cwd = os.getcwd()
         os.chdir(workdir)
         try:
@@ -324,7 +331,7 @@ class CythonCompileTestCase(unittest.TestCase):
             ext_include_dirs += get_additional_include_dirs()
             extension = Extension(
                 module,
-                sources = [self.build_target_filename(module)],
+                sources = self.find_source_files(test_directory, module),
                 include_dirs = ext_include_dirs,
                 extra_compile_args = CFLAGS,
                 )
@@ -337,19 +344,19 @@ class CythonCompileTestCase(unittest.TestCase):
         finally:
             os.chdir(cwd)
-    def compile(self, directory, module, workdir, incdir,
+    def compile(self, test_directory, module, workdir, incdir,
                 expect_errors, annotate):
         expected_errors = errors = ()
         if expect_errors:
             expected_errors = self.split_source_and_output(
-                directory, module, workdir)
-            directory = workdir
+                test_directory, module, workdir)
+            test_directory = workdir
         if WITH_CYTHON:
             old_stderr = sys.stderr
             try:
                 sys.stderr = ErrorWriter()
-                self.run_cython(directory, module, workdir, incdir, annotate)
+                self.run_cython(test_directory, module, workdir, incdir, annotate)
                 errors = sys.stderr.geterrors()
             finally:
                 sys.stderr = old_stderr
@@ -373,7 +380,7 @@ class CythonCompileTestCase(unittest.TestCase):
                 raise
         else:
             if not self.cython_only:
-                self.run_distutils(module, workdir, incdir)
+                self.run_distutils(test_directory, module, workdir, incdir)
 class CythonRunTestCase(CythonCompileTestCase):
     def shortDescription(self):
@@ -649,7 +656,7 @@ class FileListExcluder:
                 self.excludes[line.split()[0]] = True
     def __call__(self, testname):
-        return testname.split('.')[-1] in self.excludes
+        return testname in self.excludes or testname.split('.')[-1] in self.excludes
 if __name__ == '__main__':
     from optparse import OptionParser
...
@@ -25,6 +25,11 @@ if sys.platform == "darwin":
 setup_args = {}
+def add_command_class(name, cls):
+    cmdclasses = setup_args.get('cmdclass', {})
+    cmdclasses[name] = cls
+    setup_args['cmdclass'] = cmdclasses
+
 if sys.version_info[0] >= 3:
     import lib2to3.refactor
     from distutils.command.build_py \
@@ -34,7 +39,7 @@ if sys.version_info[0] >= 3:
               if fix.split('fix_')[-1] not in ('next',)
               ]
     build_py.fixer_names = fixers
-    setup_args['cmdclass'] = {"build_py" : build_py}
+    add_command_class("build_py", build_py)
 if sys.version_info < (2,4):
@@ -72,14 +77,45 @@ else:
 else:
     scripts = ["cython.py"]
+def compile_cython_modules():
+    source_root = os.path.abspath(os.path.dirname(__file__))
+    compiled_modules = ["Cython.Plex.Scanners",
+                        "Cython.Compiler.Scanning",
+                        "Cython.Compiler.Parsing",
+                        "Cython.Compiler.Visitor",
+                        "Cython.Runtime.refnanny"]
+    extensions = []
-try:
     if sys.version_info[0] >= 3:
-        raise ValueError
-    sys.argv.remove("--no-cython-compile")
-except ValueError:
-    try:
+        from Cython.Distutils import build_ext as build_ext_orig
+        for module in compiled_modules:
+            source_file = os.path.join(source_root, *module.split('.'))
+            if os.path.exists(source_file + ".py"):
+                pyx_source_file = source_file + ".py"
+            else:
+                pyx_source_file = source_file + ".pyx"
+            extensions.append(
+                Extension(module, sources = [pyx_source_file])
+                )
+        class build_ext(build_ext_orig):
+            def build_extensions(self):
+                # add path where 2to3 installed the transformed sources
+                # and make sure Python (re-)imports them from there
+                already_imported = [ module for module in sys.modules
+                                     if module == 'Cython' or module.startswith('Cython.') ]
+                for module in already_imported:
+                    del sys.modules[module]
+                sys.path.insert(0, os.path.join(source_root, self.build_lib))
+                build_ext_orig.build_extensions(self)
+        setup_args['ext_modules'] = extensions
+        add_command_class("build_ext", build_ext)
+    else: # Python 2.x
         from distutils.command.build_ext import build_ext as build_ext_orig
+        try:
             class build_ext(build_ext_orig):
                 def build_extension(self, ext, *args, **kargs):
                     try:
...@@ -89,12 +125,6 @@ except ValueError: ...@@ -89,12 +125,6 @@ except ValueError:
from Cython.Compiler.Main import compile from Cython.Compiler.Main import compile
from Cython import Utils from Cython import Utils
source_root = os.path.dirname(__file__) source_root = os.path.dirname(__file__)
compiled_modules = ["Cython.Plex.Scanners",
"Cython.Compiler.Scanning",
"Cython.Compiler.Parsing",
"Cython.Compiler.Visitor",
"Cython.Runtime.refnanny"]
extensions = []
for module in compiled_modules: for module in compiled_modules:
source_file = os.path.join(source_root, *module.split('.')) source_file = os.path.join(source_root, *module.split('.'))
if os.path.exists(source_file + ".py"): if os.path.exists(source_file + ".py"):
...@@ -116,11 +146,16 @@ except ValueError: ...@@ -116,11 +146,16 @@ except ValueError:
print("Compilation failed") print("Compilation failed")
if extensions: if extensions:
setup_args['ext_modules'] = extensions setup_args['ext_modules'] = extensions
setup_args['cmdclass'] = {"build_ext" : build_ext} add_command_class("build_ext", build_ext)
except Exception: except Exception:
print("ERROR: %s" % sys.exc_info()[1]) print("ERROR: %s" % sys.exc_info()[1])
print("Extension module compilation failed, using plain Python implementation") print("Extension module compilation failed, using plain Python implementation")
try:
sys.argv.remove("--no-cython-compile")
except ValueError:
compile_cython_modules()
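For readability, the new top-level flow that the hunk above spreads across several rows amounts to the following sketch (not a verbatim copy of setup.py):

try:
    sys.argv.remove("--no-cython-compile")   # opt-out flag: ship the pure-Python compiler
except ValueError:
    compile_cython_modules()                 # flag absent: build the listed modules as C extensions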
setup_args.update(setuptools_extra_args)
from Cython.Compiler.Version import version
...
...@@ -8,3 +8,6 @@ unsignedbehaviour_T184
missing_baseclass_in_predecl_T262
cfunc_call_tuple_args_T408
cascaded_list_unpacking_T467
compile.cpp_operators
cppwrap
cpp_overload_wrapper
cdef extern from "operators.h":
cdef cppclass Operators:
Operators(int)
Operators operator+(Operators)
Operators __add__(Operators, Operators)
Operators __sub__(Operators, Operators)
Operators __mul__(Operators, Operators)
Operators __div__(Operators, Operators)
bool __lt__(Operators, Operators)
bool __le__(Operators, Operators)
bool __eq__(Operators, Operators)
bool __ne__(Operators, Operators)
bool __gt__(Operators, Operators)
bool __ge__(Operators, Operators)
Operators __rshift__(Operators, int)
Operators __lshift__(Operators, int)
Operators __mod__(Operators, int)
cdef int v = 10
cdef Operators a
cdef Operators b
cdef Operators c
c = a + b
c = a - b
c = a * b
c = a / b
c = a << 2
c = a >> 1
c = b % 2
a < b
a <= b
a == b
a != b
a > b
a >= b
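The bare comparison statements above are evaluated only to exercise the wrapped C++ comparison operators; their results are discarded. A minimal sketch of capturing one instead (the name 'lt' is illustrative, not part of the test):

cdef bint lt = a < b   # calls Operators::operator< and stores the result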
cdef extern from "templates.h":
cdef cppclass TemplateTest1[T]:
TemplateTest1()
T value
int t
T getValue()
cdef cppclass TemplateTest2[T, U]:
TemplateTest2()
T value1
U value2
T getValue1()
U getValue2()
cdef TemplateTest1[int] a
cdef TemplateTest1[int]* b = new TemplateTest1[int]()
cdef int c = a.getValue()
c = b.getValue()
cdef TemplateTest2[int, char] d
cdef TemplateTest2[int, char]* e = new TemplateTest2[int, char]()
c = d.getValue1()
c = e.getValue2()
cdef char f = d.getValue2()
f = e.getValue2()
del b, e
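Objects created with 'new' are plain C++ heap allocations, so every 'new' above must be matched by a 'del'. A safer variant, used by the vector tests further down, wraps the lifetime in try/finally (a sketch; 'p' and 'result' are illustrative names):

cdef TemplateTest1[int]* p = new TemplateTest1[int]()
try:
    result = p.getValue()   # use the object while it is alive
finally:
    del p                   # C++ objects are not reference counted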
...@@ -11,6 +11,10 @@ cdef extern int (*iapfn())[5]
cdef extern char *(*cpapfn())[5]
cdef extern int fnargfn(int ())
cdef extern int ia[]
cdef extern int iaa[][3]
cdef extern int a(int[][3], int[][3][5])
cdef void f():
cdef void *p=NULL
global ifnp, cpa
...
#ifndef _OPERATORS_H_
#define _OPERATORS_H_
class Operators
{
public:
int value;
Operators() { }
Operators(int value) { this->value = value; }
virtual ~Operators() { }
Operators operator+(Operators f) { return Operators(this->value + f.value); }
Operators operator-(Operators f) { return Operators(this->value - f.value); }
Operators operator*(Operators f) { return Operators(this->value * f.value); }
Operators operator/(Operators f) { return Operators(this->value / f.value); }
bool operator<(Operators f) { return this->value < f.value; }
bool operator<=(Operators f) { return this->value <= f.value; }
bool operator==(Operators f) { return this->value == f.value; }
bool operator!=(Operators f) { return this->value != f.value; }
bool operator>(Operators f) { return this->value > f.value; }
bool operator>=(Operators f) { return this->value >= f.value; }
Operators operator>>(int v) { return Operators(this->value >> v); }
Operators operator<<(int v) { return Operators(this->value << v); }
Operators operator%(int v) { return Operators(this->value % v); }
};
#endif
#ifndef _TEMPLATES_H_
#define _TEMPLATES_H_
template<class T>
class TemplateTest1
{
public:
T value;
int t;
TemplateTest1() { }
T getValue() { return value; }
};
template<class T, class U>
class TemplateTest2
{
public:
T value1;
U value2;
TemplateTest2() { }
T getValue1() { return value1; }
U getValue2() { return value2; }
};
#endif
...@@ -12,9 +12,9 @@ def f():
cdef object[int, 2, well] buf6
_ERRORS = u"""
-1:11: Buffer types only allowed as function local variables
+1:17: Buffer types only allowed as function local variables
-3:15: Buffer types only allowed as function local variables
+3:21: Buffer types only allowed as function local variables
-6:27: "fakeoption" is not a buffer option
+6:31: "fakeoption" is not a buffer option
"""
#TODO:
#7:22: "ndim" must be non-negative
...
...@@ -12,7 +12,7 @@ def f(a):
del s.m # error: deletion of non-Python object
_ERRORS = u"""
8:6: Cannot assign to or delete this
-9:45: Deletion of non-Python object
+9:45: Deletion of non-Python, non-C++ object
-11:6: Deletion of non-Python object
+11:6: Deletion of non-Python, non-C++ object
-12:6: Deletion of non-Python object
+12:6: Deletion of non-Python, non-C++ object
"""
...@@ -111,21 +111,84 @@ def slice_charptr_for_loop_c_enumerate():
############################################################
# tests for int* slicing
-## cdef int cints[6]
-## for i in range(6):
-##     cints[i] = i
-## @cython.test_assert_path_exists("//ForFromStatNode",
-##                                 "//ForFromStatNode//IndexNode")
-## @cython.test_fail_if_path_exists("//ForInStatNode")
-## def slice_intptr_for_loop_c():
-##     """
-##     >>> slice_intptr_for_loop_c()
-##     [0, 1, 2]
-##     [1, 2, 3, 4]
-##     [4, 5]
-##     """
-##     cdef int i
-##     print [ i for i in cints[:3] ]
-##     print [ i for i in cints[1:5] ]
-##     print [ i for i in cints[4:6] ]
+cdef int cints[6]
+for i in range(6):
+    cints[i] = i
+@cython.test_assert_path_exists("//ForFromStatNode",
+                                "//ForFromStatNode//IndexNode")
+@cython.test_fail_if_path_exists("//ForInStatNode")
+def slice_intarray_for_loop_c():
+    """
+    >>> slice_intarray_for_loop_c()
+    [0, 1, 2]
+    [1, 2, 3, 4]
+    [4, 5]
+    """
+    cdef int i
+    print [ i for i in cints[:3] ]
+    print [ i for i in cints[1:5] ]
+    print [ i for i in cints[4:6] ]
@cython.test_assert_path_exists("//ForFromStatNode",
"//ForFromStatNode//IndexNode")
@cython.test_fail_if_path_exists("//ForInStatNode")
def iter_intarray_for_loop_c():
"""
>>> iter_intarray_for_loop_c()
[0, 1, 2, 3, 4, 5]
"""
cdef int i
print [ i for i in cints ]
@cython.test_assert_path_exists("//ForFromStatNode",
"//ForFromStatNode//IndexNode")
@cython.test_fail_if_path_exists("//ForInStatNode")
def slice_intptr_for_loop_c():
"""
>>> slice_intptr_for_loop_c()
[0, 1, 2]
[1, 2, 3, 4]
[4, 5]
"""
cdef int* nums = cints
cdef int i
print [ i for i in nums[:3] ]
print [ i for i in nums[1:5] ]
print [ i for i in nums[4:6] ]
############################################################
# tests for slicing other arrays
cdef double cdoubles[6]
for i in range(6):
cdoubles[i] = i + 0.5
cdef double* cdoubles_ptr = cdoubles
@cython.test_assert_path_exists("//ForFromStatNode",
"//ForFromStatNode//IndexNode")
@cython.test_fail_if_path_exists("//ForInStatNode")
def slice_doublptr_for_loop_c():
"""
>>> slice_doublptr_for_loop_c()
[0.5, 1.5, 2.5]
[1.5, 2.5, 3.5, 4.5]
[4.5, 5.5]
"""
cdef double d
print [ d for d in cdoubles_ptr[:3] ]
print [ d for d in cdoubles_ptr[1:5] ]
print [ d for d in cdoubles_ptr[4:6] ]
@cython.test_assert_path_exists("//ForFromStatNode",
"//ForFromStatNode//IndexNode")
@cython.test_fail_if_path_exists("//ForInStatNode")
def iter_doublearray_for_loop_c():
"""
>>> iter_doublearray_for_loop_c()
[0.5, 1.5, 2.5, 3.5, 4.5, 5.5]
"""
cdef double d
print [ d for d in cdoubles ]
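The decorators above assert that looping over a sliced C pointer or array compiles to an indexed C loop (ForFromStatNode) rather than a Python iterator. Roughly, the slice loop lowers to something like this sketch ('idx' is an illustrative name):

cdef Py_ssize_t idx
for idx in range(1, 5):        # roughly what "for d in cdoubles_ptr[1:5]" becomes
    print cdoubles_ptr[idx]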
cimport cython
cdef extern from "Python.h":
cdef cython.unicode PyUnicode_DecodeUTF8(char* s, Py_ssize_t size, char* errors)
def test_capi():
"""
>>> print(test_capi())
abc
"""
return PyUnicode_DecodeUTF8("abc", 3, NULL)
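For comparison, the same text can be produced without the C API, since slicing a char* yields a bytes object that can be decoded explicitly (a sketch; test_decode_slice is not part of the test file):

def test_decode_slice():
    cdef char* s = "abc"
    return s[:3].decode('UTF-8')   # same value as PyUnicode_DecodeUTF8(s, 3, NULL)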
...@@ -21,7 +21,7 @@ def test_arithmetic(double complex z, double complex w):
>>> test_arithmetic(5-10j, 3+4j)
((5-10j), (-5+10j), (8-6j), (2-14j), (55-10j), (-1-2j))
"""
-return +z, -z, z+w, z-w, z*w, z/w
+return +z, -z+0, z+w, z-w, z*w, z/w
@cython.cdivision(False)
def test_div_by_zero(double complex z):
...
__doc__ = u"""
>>> test_new_del()
(2, 2)
>>> test_rect_area(3, 4)
12.0
>>> test_square_area(15)
(225.0, 225.0)
"""
cdef extern from "shapes.h" namespace shapes:
cdef cppclass Shape:
float area()
cdef cppclass Circle(Shape):
int radius
Circle(int)
cdef cppclass Rectangle(Shape):
int width
int height
Rectangle(int, int)
cdef cppclass Square(Rectangle):
int side
Square(int)
int constructor_count, destructor_count
def test_new_del():
cdef Rectangle *rect = new Rectangle(10, 20)
cdef Circle *circ = new Circle(15)
del rect, circ
return constructor_count, destructor_count
def test_rect_area(w, h):
cdef Rectangle *rect = new Rectangle(w, h)
try:
return rect.area()
finally:
del rect
def test_square_area(w):
cdef Square *sqr = new Square(w)
cdef Rectangle *rect = sqr
try:
return rect.area(), sqr.area()
finally:
del sqr
cdef double get_area(Rectangle s):
return s.area()
def test_value_call(int w):
"""
>>> test_value_call(5)
(25.0, 25.0)
"""
cdef Square *sqr = new Square(w)
cdef Rectangle *rect = sqr
try:
return get_area(sqr[0]), get_area(rect[0])
finally:
del sqr
from cython.operator cimport dereference as deref
cdef extern from "cpp_templates_helper.h":
cdef cppclass Wrap[T]:
Wrap(T)
void set(T)
T get()
bint operator==(Wrap[T])
cdef cppclass Pair[T1,T2]:
Pair(T1,T2)
T1 first()
T2 second()
bint operator==(Pair[T1,T2])
bint operator!=(Pair[T1,T2])
def test_wrap_pair(int i, double x):
"""
>>> test_wrap_pair(1, 1.5)
(1, 1.5, True)
>>> test_wrap_pair(2, 2.25)
(2, 2.25, True)
"""
try:
wrap = new Wrap[Pair[int, double]](Pair[int, double](i, x))
return wrap.get().first(), wrap.get().second(), deref(wrap) == deref(wrap)
finally:
del wrap
cimport cython.operator
from cython.operator cimport dereference as deref
cdef extern from "cpp_operators_helper.h":
cdef cppclass TestOps:
char* operator+()
char* operator-()
char* operator*()
char* operator~()
char* operator++()
char* operator--()
char* operator++(int)
char* operator--(int)
char* operator+(int)
char* operator-(int)
char* operator*(int)
char* operator/(int)
char* operator%(int)
char* operator|(int)
char* operator&(int)
char* operator^(int)
char* operator<<(int)
char* operator>>(int)
char* operator==(int)
char* operator!=(int)
char* operator>=(int)
char* operator<=(int)
char* operator>(int)
char* operator<(int)
char* operator[](int)
char* operator()(int)
def test_unops():
"""
>>> test_unops()
unary +
unary -
unary ~
unary *
"""
cdef TestOps* t = new TestOps()
print +t[0]
print -t[0]
print ~t[0]
print deref(t[0])
del t
def test_incdec():
"""
>>> test_incdec()
unary ++
unary --
post ++
post --
"""
cdef TestOps* t = new TestOps()
print cython.operator.preincrement(t[0])
print cython.operator.predecrement(t[0])
print cython.operator.postincrement(t[0])
print cython.operator.postdecrement(t[0])
del t
def test_binop():
"""
>>> test_binop()
binary +
binary -
binary *
binary /
binary %
binary &
binary |
binary ^
binary <<
binary >>
"""
cdef TestOps* t = new TestOps()
print t[0] + 1
print t[0] - 1
print t[0] * 1
print t[0] / 1
print t[0] % 1
print t[0] & 1
print t[0] | 1
print t[0] ^ 1
print t[0] << 1
print t[0] >> 1
del t
def test_cmp():
"""
>>> test_cmp()
binary ==
binary !=
binary >=
binary >
binary <=
binary <
"""
cdef TestOps* t = new TestOps()
print t[0] == 1
print t[0] != 1
print t[0] >= 1
print t[0] > 1
print t[0] <= 1
print t[0] < 1
del t
def test_index_call():
"""
>>> test_index_call()
binary []
binary ()
"""
cdef TestOps* t = new TestOps()
print t[0][100]
print t[0](100)
del t
#define UN_OP(op) const char* operator op () { return "unary "#op; }
#define POST_UN_OP(op) const char* operator op (int x) { return "post "#op; }
#define BIN_OP(op) const char* operator op (int x) { return "binary "#op; }
class TestOps {
public:
UN_OP(-);
UN_OP(+);
UN_OP(*);
UN_OP(~);
UN_OP(!);
UN_OP(&);
UN_OP(++);
UN_OP(--);
POST_UN_OP(++);
POST_UN_OP(--);
BIN_OP(+);
BIN_OP(-);
BIN_OP(*);
BIN_OP(/);
BIN_OP(%);
BIN_OP(<<);
BIN_OP(>>);
BIN_OP(|);
BIN_OP(&);
BIN_OP(^);
BIN_OP(==);
BIN_OP(!=);
BIN_OP(<=);
BIN_OP(<);
BIN_OP(>=);
BIN_OP(>);
BIN_OP([]);
BIN_OP(());
};
__doc__ = u"""
>>> test_vector([1,10,100])
1
10
100
"""
cdef extern from "vector" namespace std:
cdef cppclass iterator[T]:
pass
cdef cppclass vector[T]:
#constructors
__init__()
T at(int)
void push_back(T t)
void assign(int, T)
void clear()
iterator end()
iterator begin()
int size()
def test_vector(L):
cdef vector[int] *V = new vector[int]()
for a in L:
V.push_back(a)
cdef int i
for i in range(len(L)):
print V.at(i)
del V
cdef extern from "<vector>" namespace std:
cdef cppclass vector[T]:
void push_back(T)
size_t size()
T operator[](size_t)
def simple_test(double x):
"""
>>> simple_test(55)
3
"""
cdef vector[double] *v
try:
v = new vector[double]()
v.push_back(1.0)
v.push_back(x)
from math import pi
v.push_back(pi)
return v.size()
finally:
del v
def list_test(L):
"""
>>> list_test([1,2,4,8])
(4, 4)
>>> list_test([])
(0, 0)
>>> list_test([-1] * 1000)
(1000, 1000)
"""
cdef vector[int] *v
try:
v = new vector[int]()
for a in L:
v.push_back(a)
return len(L), v.size()
finally:
del v
def index_test(L):
"""
>>> index_test([1,2,4,8])
(1.0, 8.0)
>>> index_test([1.25])
(1.25, 1.25)
"""
cdef vector[double] *v
try:
v = new vector[double]()
for a in L:
v.push_back(a)
return v[0][0], v[0][len(L)-1]
finally:
del v
from cython.operator import dereference as deref
cdef extern from "cpp_templates_helper.h":
cdef cppclass Wrap[T]:
Wrap(T)
void set(T)
T get()
bint operator==(Wrap[T])
cdef cppclass Pair[T1,T2]:
Pair(T1,T2)
T1 first()
T2 second()
bint operator==(Pair[T1,T2])
bint operator!=(Pair[T1,T2])
def test_int(int x, int y):
"""
>>> test_int(3, 4)
(3, 4, False)
>>> test_int(100, 100)
(100, 100, True)
"""
try:
a = new Wrap[int](x)
b = new Wrap[int](0)
b.set(y)
return a.get(), b.get(), a[0] == b[0]
finally:
del a, b
def test_double(double x, double y):
"""
>>> test_double(3, 3.5)
(3.0, 3.5, False)
>>> test_double(100, 100)
(100.0, 100.0, True)
"""
try:
a = new Wrap[double](x)
b = new Wrap[double](-1)
b.set(y)
return a.get(), b.get(), deref(a) == deref(b)
finally:
del a, b
def test_pair(int i, double x):
"""
>>> test_pair(1, 1.5)
(1, 1.5, True, False)
>>> test_pair(2, 2.25)
(2, 2.25, True, False)
"""
try:
pair = new Pair[int, double](i, x)
return pair.first(), pair.second(), deref(pair) == deref(pair), deref(pair) != deref(pair)
finally:
del pair
template <class T>
class Wrap {
T value;
public:
Wrap(T v) : value(v) { }
void set(T v) { value = v; }
T get(void) { return value; }
bool operator==(Wrap<T> other) { return value == other.value; }
};
template <class T1, class T2>
class Pair {
T1 _first;
T2 _second;
public:
Pair(T1 u, T2 v) { _first = u; _second = v; }
T1 first(void) { return _first; }
T2 second(void) { return _second; }
bool operator==(Pair<T1,T2> other) { return _first == other._first && _second == other._second; }
bool operator!=(Pair<T1,T2> other) { return _first != other._first || _second != other._second; }
};
cdef extern from "Python.h":
-    ctypedef class __builtin__.str [object PyStringObject]:
-        cdef long ob_shash
    ctypedef class __builtin__.list [object PyListObject]:
-        cdef Py_ssize_t ob_size
        cdef Py_ssize_t allocated
    ctypedef class __builtin__.dict [object PyDictObject]:
        pass
+    cdef Py_ssize_t Py_SIZE(object o)
-cdef str s = "abc"
cdef list L = [1,2,4]
cdef dict d = {'A': 'a'}
...@@ -23,18 +20,7 @@ def test_list(list L):
    >>> test_list(list_subclass([1,2,3]))
    True
    """
-    return L.ob_size <= L.allocated
+    return Py_SIZE(L) <= L.allocated
-def test_str(str s):
-    """
-    >>> test_str("abc")
-    True
-    >>> class str_subclass(str): pass
-    >>> test_str(str_subclass("xyz"))
-    True
-    """
-    cdef char* ss = s
-    return hash(s) == s.ob_shash
def test_tuple(tuple t):
    """
...
cdef extern from *:
int new(int new)
def new(x):
"""
>>> new(3)
3
"""
cdef int new = x
return new
def x(new):
"""
>>> x(10)
110
>>> x(1)
1
"""
if new*new != new:
return new + new**2
return new
class A:
def new(self, n):
"""
>>> a = A()
>>> a.new(3)
6
>>> a.new(5)
120
"""
if n <= 1:
return 1
else:
return n * self.new(n-1)
#ifndef SHAPES_H
#define SHAPES_H
namespace shapes {
int constructor_count = 0;
int destructor_count = 0;
class Shape
{
public:
virtual float area() = 0;
Shape() { constructor_count++; }
virtual ~Shape() { destructor_count++; }
};
class Rectangle : public Shape
{
public:
Rectangle(int width, int height)
{
this->width = width;
this->height = height;
}
float area() { return width * height; }
int width;
int height;
};
class Square : public Rectangle
{
public:
Square(int side) : Rectangle(side, side) { this->side = side; }
int side;
};
class Circle : public Shape {
public:
Circle(int radius) { this->radius = radius; }
float area() { return 3.1415926535897931f * radius; }
int radius;
};
}
#endif
...@@ -93,13 +93,49 @@ def arithmetic():
    >>> arithmetic()
    """
    a = 1 + 2
-    assert typeof(a) == "long"
+    assert typeof(a) == "long", typeof(a)
    b = 1 + 1.5
-    assert typeof(b) == "double"
+    assert typeof(b) == "double", typeof(b)
    c = 1 + <object>2
-    assert typeof(c) == "Python object"
+    assert typeof(c) == "Python object", typeof(c)
-    d = "abc %s" % "x"
-    assert typeof(d) == "Python object"
+    d = 1 * 1.5 ** 2
+    assert typeof(d) == "double", typeof(d)
def builtin_type_operations():
"""
>>> builtin_type_operations()
"""
b1 = b'a' * 10
b1 = 10 * b'a'
b1 = 10 * b'a' * 10
assert typeof(b1) == "bytes object", typeof(b1)
b2 = b'a' + b'b'
assert typeof(b2) == "bytes object", typeof(b2)
u1 = u'a' * 10
u1 = 10 * u'a'
assert typeof(u1) == "unicode object", typeof(u1)
u2 = u'a' + u'b'
assert typeof(u2) == "unicode object", typeof(u2)
u3 = u'a%s' % u'b'
u3 = u'a%s' % 10
assert typeof(u3) == "unicode object", typeof(u3)
s1 = "abc %s" % "x"
s1 = "abc %s" % 10
assert typeof(s1) == "str object", typeof(s1)
s2 = "abc %s" + "x"
assert typeof(s2) == "str object", typeof(s2)
s3 = "abc %s" * 10
s3 = "abc %s" * 10 * 10
s3 = 10 * "abc %s" * 10
assert typeof(s3) == "str object", typeof(s3)
L1 = [] + []
assert typeof(L1) == "list object", typeof(L1)
L2 = [] * 2
assert typeof(L2) == "list object", typeof(L2)
T1 = () + ()
assert typeof(T1) == "tuple object", typeof(T1)
T2 = () * 2
assert typeof(T2) == "tuple object", typeof(T2)
def cascade():
    """
...@@ -215,10 +251,29 @@ def safe_only():
    """
    a = 1.0
    assert typeof(a) == "double", typeof(c)
-    b = 1
+    b = 1;
-    assert typeof(b) == "Python object", typeof(b)
+    assert typeof(b) == "long", typeof(b)
    c = MyType()
    assert typeof(c) == "MyType", typeof(c)
for i in range(10): pass
assert typeof(i) == "long", typeof(i)
d = 1
res = ~d
assert typeof(d) == "long", typeof(d)
# potentially overflowing arithmetic
e = 1
e += 1
assert typeof(e) == "Python object", typeof(e)
f = 1
res = f * 10
assert typeof(f) == "Python object", typeof(f)
g = 1
res = 10*(~g)
assert typeof(g) == "Python object", typeof(g)
for j in range(10):
res = -j
assert typeof(j) == "Python object", typeof(j)
@infer_types(None)
def args_tuple_keywords(*args, **kwargs):
...@@ -249,3 +304,36 @@ def args_tuple_keywords_reassign_pyobjects(*args, **kwargs):
    args = []
    kwargs = "test"
# / A -> AA -> AAA
# Base0 -> Base -
# \ B -> BB
# C -> CC
cdef class Base0: pass
cdef class Base(Base0): pass
cdef class A(Base): pass
cdef class AA(A): pass
cdef class AAA(AA): pass
cdef class B(Base): pass
cdef class BB(B): pass
cdef class C: pass
cdef class CC(C): pass
@infer_types(None)
def common_extension_type_base():
"""
>>> common_extension_type_base()
"""
x = A()
x = AA()
assert typeof(x) == "A", typeof(x)
y = A()
y = B()
assert typeof(y) == "Base", typeof(y)
z = AAA()
z = BB()
assert typeof(z) == "Base", typeof(z)
w = A()
w = CC()
assert typeof(w) == "Python object", typeof(w)
cimport cython.operator
def test_deref(int x):
"""
>>> test_deref(3)
3
>>> test_deref(5)
5
"""
cdef int* x_ptr = &x
return cython.operator.dereference(x_ptr)
def increment_decrement(int x):
"""
>>> increment_decrement(10)
11 11 12
11 11 10
10
"""
print cython.operator.preincrement(x), cython.operator.postincrement(x), x
print cython.operator.predecrement(x), cython.operator.postdecrement(x), x
return x
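cython.operator.dereference(p) and p[0] are two spellings of the same C-level '*p'. A small sketch using an int pointer as in test_deref above (it reuses the cimport at the top of this file; the function name is illustrative):

def deref_spellings_agree():
    cdef int value = 7
    cdef int* p = &value
    return cython.operator.dereference(p) == p[0]   # both read *p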
void voidfunc(void);
double doublefunc (double a, double b, double c);
class DoubleKeeper
{
double number;
public:
DoubleKeeper ();
DoubleKeeper (double number);
virtual ~DoubleKeeper ();
void set_number (double num);
void set_number (void);
double get_number () const;
virtual double transmogrify (double value) const;
};
double transmogrify_from_cpp (DoubleKeeper const *obj, double value);