Commit c8b2401a authored by Robert Bradshaw's avatar Robert Bradshaw

Merge gsoc-danilo C++ code into main branch.

parents c5b12aac 09cbfe82
...@@ -1041,6 +1041,40 @@ class ImagNode(AtomicExprNode): ...@@ -1041,6 +1041,40 @@ class ImagNode(AtomicExprNode):
code.put_gotref(self.py_result()) code.put_gotref(self.py_result())
class NewExprNode(AtomicExprNode):
# C++ new statement
#
# cppclass string c++ class to create
# template_parameters None or [ExprNode] temlate parameters, if any
def analyse_types(self, env):
entry = env.lookup(self.cppclass)
if entry is None or not entry.is_cpp_class:
error(self.pos, "new operator can only be applied to a C++ class")
return
self.cpp_check(env)
if self.template_parameters is not None:
template_types = [v.analyse_as_type(env) for v in self.template_parameters]
type = entry.type.specialize_here(self.pos, template_types)
else:
type = entry.type
constructor = type.scope.lookup(u'<init>')
if constructor is None:
return_type = PyrexTypes.CFuncType(type, [])
return_type = PyrexTypes.CPtrType(return_type)
type.scope.declare_cfunction(u'<init>', return_type, self.pos)
constructor = type.scope.lookup(u'<init>')
self.class_type = type
self.entry = constructor
self.type = constructor.type
def generate_result_code(self, code):
pass
def calculate_result_code(self):
return "new " + self.class_type.declaration_code("")
class NameNode(AtomicExprNode): class NameNode(AtomicExprNode):
# Reference to a local or global variable name. # Reference to a local or global variable name.
...@@ -1239,7 +1273,8 @@ class NameNode(AtomicExprNode): ...@@ -1239,7 +1273,8 @@ class NameNode(AtomicExprNode):
if entry.is_type and entry.type.is_extension_type: if entry.is_type and entry.type.is_extension_type:
self.type_entry = entry self.type_entry = entry
if not (entry.is_const or entry.is_variable if not (entry.is_const or entry.is_variable
or entry.is_builtin or entry.is_cfunction): or entry.is_builtin or entry.is_cfunction
or entry.is_cpp_class):
if self.entry.as_variable: if self.entry.as_variable:
self.entry = self.entry.as_variable self.entry = self.entry.as_variable
else: else:
...@@ -1767,7 +1802,19 @@ class IndexNode(ExprNode): ...@@ -1767,7 +1802,19 @@ class IndexNode(ExprNode):
def analyse_as_type(self, env): def analyse_as_type(self, env):
base_type = self.base.analyse_as_type(env) base_type = self.base.analyse_as_type(env)
if base_type and not base_type.is_pyobject: if base_type and not base_type.is_pyobject:
return PyrexTypes.CArrayType(base_type, int(self.index.compile_time_value(env))) if base_type.is_cpp_class:
if isinstance(self.index, TupleExprNode):
template_values = self.index.args
else:
template_values = [self.index]
import Nodes
type_node = Nodes.TemplatedTypeNode(
pos = self.pos,
positional_args = template_values,
keyword_args = None)
return type_node.analyse(env, base_type = base_type)
else:
return PyrexTypes.CArrayType(base_type, int(self.index.compile_time_value(env)))
return None return None
def type_dependencies(self, env): def type_dependencies(self, env):
...@@ -1869,18 +1916,33 @@ class IndexNode(ExprNode): ...@@ -1869,18 +1916,33 @@ class IndexNode(ExprNode):
else: else:
if self.base.type.is_ptr or self.base.type.is_array: if self.base.type.is_ptr or self.base.type.is_array:
self.type = self.base.type.base_type self.type = self.base.type.base_type
if self.index.type.is_pyobject:
self.index = self.index.coerce_to(
PyrexTypes.c_py_ssize_t_type, env)
if not self.index.type.is_int:
error(self.pos,
"Invalid index type '%s'" %
self.index.type)
elif self.base.type.is_cpp_class:
function = env.lookup_operator("[]", [self.base, self.index])
function = self.base.type.scope.lookup("operator[]")
if function is None:
error(self.pos, "Indexing '%s' not supported for index type '%s'" % (self.base.type, self.index.type))
self.type = PyrexTypes.error_type
self.result_code = "<error>"
return
func_type = function.type
if func_type.is_ptr:
func_type = func_type.base_type
self.index = self.index.coerce_to(func_type.args[0].type, env)
self.type = func_type.return_type
if setting and not func_type.return_type.is_reference:
error(self.pos, "Can't set non-reference '%s'" % self.type)
else: else:
error(self.pos, error(self.pos,
"Attempting to index non-array type '%s'" % "Attempting to index non-array type '%s'" %
self.base.type) self.base.type)
self.type = PyrexTypes.error_type self.type = PyrexTypes.error_type
if self.index.type.is_pyobject:
self.index = self.index.coerce_to(
PyrexTypes.c_py_ssize_t_type, env)
if not self.index.type.is_int:
error(self.pos,
"Invalid index type '%s'" %
self.index.type)
gil_message = "Indexing Python object" gil_message = "Indexing Python object"
def nogil_check(self, env): def nogil_check(self, env):
...@@ -2515,36 +2577,26 @@ class SimpleCallNode(CallNode): ...@@ -2515,36 +2577,26 @@ class SimpleCallNode(CallNode):
return func_type return func_type
def analyse_c_function_call(self, env): def analyse_c_function_call(self, env):
func_type = self.function_type() if self.function.type.is_cpp_class:
# Check function type function = self.function.type.scope.lookup("operator()")
if not func_type.is_cfunction: if function is None:
if not func_type.is_error: self.type = PyrexTypes.error_type
error(self.pos, "Calling non-function type '%s'" % self.result_code = "<error>"
func_type) return
else:
function = self.function.entry
entry = PyrexTypes.best_match(self.args, function.all_alternatives(), self.pos)
if not entry:
self.type = PyrexTypes.error_type self.type = PyrexTypes.error_type
self.result_code = "<error>" self.result_code = "<error>"
return return
self.function.entry = entry
self.function.type = entry.type
func_type = self.function_type()
# Check no. of args # Check no. of args
max_nargs = len(func_type.args) max_nargs = len(func_type.args)
expected_nargs = max_nargs - func_type.optional_arg_count expected_nargs = max_nargs - func_type.optional_arg_count
actual_nargs = len(self.args) actual_nargs = len(self.args)
if actual_nargs < expected_nargs \
or (not func_type.has_varargs and actual_nargs > max_nargs):
expected_str = str(expected_nargs)
if func_type.has_varargs:
expected_str = "at least " + expected_str
elif func_type.optional_arg_count:
if actual_nargs < max_nargs:
expected_str = "at least " + expected_str
else:
expected_str = "at most " + str(max_nargs)
error(self.pos,
"Call with wrong number of arguments (expected %s, got %s)"
% (expected_str, actual_nargs))
self.args = None
self.type = PyrexTypes.error_type
self.result_code = "<error>"
return
if func_type.optional_arg_count and expected_nargs != actual_nargs: if func_type.optional_arg_count and expected_nargs != actual_nargs:
self.has_optional_args = 1 self.has_optional_args = 1
self.is_temp = 1 self.is_temp = 1
...@@ -2557,7 +2609,10 @@ class SimpleCallNode(CallNode): ...@@ -2557,7 +2609,10 @@ class SimpleCallNode(CallNode):
error(self.args[i].pos, error(self.args[i].pos,
"Python object cannot be passed as a varargs parameter") "Python object cannot be passed as a varargs parameter")
# Calc result type and code fragment # Calc result type and code fragment
self.type = func_type.return_type if isinstance(self.function, NewExprNode):
self.type = PyrexTypes.CPtrType(self.function.class_type)
else:
self.type = func_type.return_type
if self.type.is_pyobject: if self.type.is_pyobject:
self.result_ctype = py_object_type self.result_ctype = py_object_type
self.is_temp = 1 self.is_temp = 1
...@@ -2574,7 +2629,7 @@ class SimpleCallNode(CallNode): ...@@ -2574,7 +2629,7 @@ class SimpleCallNode(CallNode):
def c_call_code(self): def c_call_code(self):
func_type = self.function_type() func_type = self.function_type()
if self.args is None or not func_type.is_cfunction: if self.type is PyrexTypes.error_type or not func_type.is_cfunction:
return "<error>" return "<error>"
formal_args = func_type.args formal_args = func_type.args
arg_list_code = [] arg_list_code = []
...@@ -2874,6 +2929,9 @@ class AttributeNode(ExprNode): ...@@ -2874,6 +2929,9 @@ class AttributeNode(ExprNode):
def as_cython_attribute(self): def as_cython_attribute(self):
if isinstance(self.obj, NameNode) and self.obj.is_cython_module: if isinstance(self.obj, NameNode) and self.obj.is_cython_module:
return self.attribute return self.attribute
cy = self.obj.as_cython_attribute()
if cy:
return "%s.%s" % (cy, self.attribute)
def coerce_to(self, dst_type, env): def coerce_to(self, dst_type, env):
# If coercing to a generic pyobject and this is a cpdef function # If coercing to a generic pyobject and this is a cpdef function
...@@ -4012,6 +4070,7 @@ class UnboundMethodNode(ExprNode): ...@@ -4012,6 +4070,7 @@ class UnboundMethodNode(ExprNode):
code.error_goto_if_null(self.result(), self.pos))) code.error_goto_if_null(self.result(), self.pos)))
code.put_gotref(self.py_result()) code.put_gotref(self.py_result())
class PyCFunctionNode(AtomicExprNode): class PyCFunctionNode(AtomicExprNode):
# Helper class used in the implementation of Python # Helper class used in the implementation of Python
# class definitions. Constructs a PyCFunction object # class definitions. Constructs a PyCFunction object
...@@ -4088,6 +4147,8 @@ class UnopNode(ExprNode): ...@@ -4088,6 +4147,8 @@ class UnopNode(ExprNode):
self.coerce_operand_to_pyobject(env) self.coerce_operand_to_pyobject(env)
self.type = py_object_type self.type = py_object_type
self.is_temp = 1 self.is_temp = 1
elif self.is_cpp_operation():
self.analyse_cpp_operation(env)
else: else:
self.analyse_c_operation(env) self.analyse_c_operation(env)
...@@ -4101,6 +4162,10 @@ class UnopNode(ExprNode): ...@@ -4101,6 +4162,10 @@ class UnopNode(ExprNode):
if self.is_py_operation(): if self.is_py_operation():
self.gil_error() self.gil_error()
def is_cpp_operation(self):
type = self.operand.type
return type.is_cpp_class or type.is_reference and type.base_type.is_cpp_class
def coerce_operand_to_pyobject(self, env): def coerce_operand_to_pyobject(self, env):
self.operand = self.operand.coerce_to_pyobject(env) self.operand = self.operand.coerce_to_pyobject(env)
...@@ -4124,6 +4189,22 @@ class UnopNode(ExprNode): ...@@ -4124,6 +4189,22 @@ class UnopNode(ExprNode):
(self.operator, self.operand.type)) (self.operator, self.operand.type))
self.type = PyrexTypes.error_type self.type = PyrexTypes.error_type
def analyse_cpp_operation(self, env):
type = self.operand.type
if type.is_ptr or type.is_reference:
type = type.base_type
entry = env.lookup(type.name)
function = entry.type.scope.lookup("operator%s" % self.operator)
if not function:
error(self.pos, "'%s' operator not defined for %s"
% (self.operator, type))
self.type_error()
return
func_type = function.type
if func_type.is_ptr:
func_type = func_type.base_type
self.type = func_type.return_type
class NotNode(ExprNode): class NotNode(ExprNode):
# 'not' operator # 'not' operator
...@@ -4170,7 +4251,10 @@ class UnaryPlusNode(UnopNode): ...@@ -4170,7 +4251,10 @@ class UnaryPlusNode(UnopNode):
return "PyNumber_Positive" return "PyNumber_Positive"
def calculate_result_code(self): def calculate_result_code(self):
return self.operand.result() if self.is_cpp_operation():
return "(+%s)" % self.operand.result()
else:
return self.operand.result()
class UnaryMinusNode(UnopNode): class UnaryMinusNode(UnopNode):
...@@ -4216,6 +4300,45 @@ class TildeNode(UnopNode): ...@@ -4216,6 +4300,45 @@ class TildeNode(UnopNode):
return "(~%s)" % self.operand.result() return "(~%s)" % self.operand.result()
class CUnopNode(UnopNode):
def is_py_operation(self):
return False
class DereferenceNode(CUnopNode):
# unary * operator
operator = '*'
def analyse_c_operation(self, env):
if self.operand.type.is_ptr:
self.type = self.operand.type.base_type
else:
self.type_error()
def calculate_result_code(self):
return "(*%s)" % self.operand.result()
class DecrementIncrementNode(CUnopNode):
# unary ++/-- operator
def analyse_c_operation(self, env):
if self.operand.type.is_ptr or self.operand.type.is_numeric:
self.type = self.operand.type
else:
self.type_error()
def calculate_result_code(self):
if self.is_prefix:
return "(%s%s)" % (self.operator, self.operand.result())
else:
return "(%s%s)" % (self.operand.result(), self.operator)
def inc_dec_constructor(is_prefix, operator):
return lambda pos, **kwds: DecrementIncrementNode(pos, is_prefix=is_prefix, operator=operator, **kwds)
class AmpersandNode(ExprNode): class AmpersandNode(ExprNode):
# The C address-of operator. # The C address-of operator.
# #
...@@ -4572,6 +4695,8 @@ class BinopNode(ExprNode): ...@@ -4572,6 +4695,8 @@ class BinopNode(ExprNode):
self.coerce_operands_to_pyobjects(env) self.coerce_operands_to_pyobjects(env)
self.type = py_object_type self.type = py_object_type
self.is_temp = 1 self.is_temp = 1
elif self.is_cpp_operation():
self.analyse_cpp_operation(env)
else: else:
self.analyse_c_operation(env) self.analyse_c_operation(env)
...@@ -4581,6 +4706,33 @@ class BinopNode(ExprNode): ...@@ -4581,6 +4706,33 @@ class BinopNode(ExprNode):
def is_py_operation_types(self, type1, type2): def is_py_operation_types(self, type1, type2):
return type1.is_pyobject or type2.is_pyobject return type1.is_pyobject or type2.is_pyobject
def is_cpp_operation(self):
type1 = self.operand1.type
type2 = self.operand2.type
if type1.is_reference:
type1 = type1.base_type
if type2.is_reference:
type2 = type2.base_type
return (type1.is_cpp_class
or type2.is_cpp_class)
def analyse_cpp_operation(self, env):
type1 = self.operand1.type
type2 = self.operand2.type
entry = env.lookup_operator(self.operator, [self.operand1, self.operand2])
if not entry:
self.type_error()
return
func_type = entry.type
if func_type.is_ptr:
func_type = func_type.base_type
if len(func_type.args) == 1:
self.operand2 = self.operand2.coerce_to(func_type.args[0].type, env)
else:
self.operand1 = self.operand1.coerce_to(func_type.args[0].type, env)
self.operand2 = self.operand2.coerce_to(func_type.args[1].type, env)
self.type = func_type.return_type
def result_type(self, type1, type2): def result_type(self, type1, type2):
if self.is_py_operation_types(type1, type2): if self.is_py_operation_types(type1, type2):
return py_object_type return py_object_type
...@@ -4790,6 +4942,8 @@ class DivNode(NumBinopNode): ...@@ -4790,6 +4942,8 @@ class DivNode(NumBinopNode):
else: else:
self.ctruedivision = self.truedivision self.ctruedivision = self.truedivision
NumBinopNode.analyse_types(self, env) NumBinopNode.analyse_types(self, env)
if self.is_cpp_operation():
self.cdivision = True
if not self.type.is_pyobject: if not self.type.is_pyobject:
self.zerodivision_check = ( self.zerodivision_check = (
self.cdivision is None and not env.directives['cdivision'] self.cdivision is None and not env.directives['cdivision']
...@@ -5184,6 +5338,15 @@ class CmpNode(object): ...@@ -5184,6 +5338,15 @@ class CmpNode(object):
result = result and cascade.compile_time_value(operand2, denv) result = result and cascade.compile_time_value(operand2, denv)
return result return result
def is_cpp_comparison(self):
type1 = self.operand1.type
type2 = self.operand2.type
if type1.is_reference:
type1 = type1.base_type
if type2.is_reference:
type2 = type2.base_type
return type1.is_cpp_class or type2.is_cpp_class
def find_common_int_type(self, env, op, operand1, operand2): def find_common_int_type(self, env, op, operand1, operand2):
# type1 != type2 and at least one of the types is not a C int # type1 != type2 and at least one of the types is not a C int
type1 = operand1.type type1 = operand1.type
...@@ -5442,6 +5605,11 @@ class PrimaryCmpNode(ExprNode, CmpNode): ...@@ -5442,6 +5605,11 @@ class PrimaryCmpNode(ExprNode, CmpNode):
def analyse_types(self, env): def analyse_types(self, env):
self.operand1.analyse_types(env) self.operand1.analyse_types(env)
self.operand2.analyse_types(env) self.operand2.analyse_types(env)
if self.is_cpp_comparison():
self.analyse_cpp_comparison(env)
if self.cascade:
error(self.pos, "Cascading comparison not yet supported for cpp types.")
return
if self.cascade: if self.cascade:
self.cascade.analyse_types(env) self.cascade.analyse_types(env)
...@@ -5470,7 +5638,27 @@ class PrimaryCmpNode(ExprNode, CmpNode): ...@@ -5470,7 +5638,27 @@ class PrimaryCmpNode(ExprNode, CmpNode):
cdr = cdr.cascade cdr = cdr.cascade
if self.is_pycmp or self.cascade: if self.is_pycmp or self.cascade:
self.is_temp = 1 self.is_temp = 1
def analyse_cpp_comparison(self, env):
type1 = self.operand1.type
type2 = self.operand2.type
entry = env.lookup_operator(self.operator, [self.operand1, self.operand2])
if entry is None:
error(self.pos, "Invalid types for '%s' (%s, %s)" %
(self.operator, type1, type2))
self.type = PyrexTypes.error_type
self.result_code = "<error>"
return
func_type = entry.type
if func_type.is_ptr:
func_type = func_type.base_type
if len(func_type.args) == 1:
self.operand2 = self.operand2.coerce_to(func_type.args[0].type, env)
else:
self.operand1 = self.operand1.coerce_to(func_type.args[0].type, env)
self.operand2 = self.operand2.coerce_to(func_type.args[1].type, env)
self.type = func_type.return_type
def has_python_operands(self): def has_python_operands(self):
return (self.operand1.type.is_pyobject return (self.operand1.type.is_pyobject
or self.operand2.type.is_pyobject) or self.operand2.type.is_pyobject)
......
...@@ -66,7 +66,7 @@ class Context(object): ...@@ -66,7 +66,7 @@ class Context(object):
# include_directories [string] # include_directories [string]
# future_directives [object] # future_directives [object]
def __init__(self, include_directories, compiler_directives): def __init__(self, include_directories, compiler_directives, cpp=False):
#self.modules = {"__builtin__" : BuiltinScope()} #self.modules = {"__builtin__" : BuiltinScope()}
import Builtin, CythonScope import Builtin, CythonScope
self.modules = {"__builtin__" : Builtin.builtin_scope} self.modules = {"__builtin__" : Builtin.builtin_scope}
...@@ -74,6 +74,7 @@ class Context(object): ...@@ -74,6 +74,7 @@ class Context(object):
self.include_directories = include_directories self.include_directories = include_directories
self.future_directives = set() self.future_directives = set()
self.compiler_directives = compiler_directives self.compiler_directives = compiler_directives
self.cpp = cpp
self.pxds = {} # full name -> node tree self.pxds = {} # full name -> node tree
...@@ -451,6 +452,7 @@ class Context(object): ...@@ -451,6 +452,7 @@ class Context(object):
if not isinstance(source_desc, FileSourceDescriptor): if not isinstance(source_desc, FileSourceDescriptor):
raise RuntimeError("Only file sources for code supported") raise RuntimeError("Only file sources for code supported")
source_filename = Utils.encode_filename(source_desc.filename) source_filename = Utils.encode_filename(source_desc.filename)
scope.cpp = self.cpp
# Parse the given source file and return a parse tree. # Parse the given source file and return a parse tree.
try: try:
f = Utils.open_source_file(source_filename, "rU") f = Utils.open_source_file(source_filename, "rU")
...@@ -540,7 +542,7 @@ def create_default_resultobj(compilation_source, options): ...@@ -540,7 +542,7 @@ def create_default_resultobj(compilation_source, options):
def run_pipeline(source, options, full_module_name = None): def run_pipeline(source, options, full_module_name = None):
# Set up context # Set up context
context = Context(options.include_path, options.compiler_directives) context = Context(options.include_path, options.compiler_directives, options.cplus)
# Set up source object # Set up source object
cwd = os.getcwd() cwd = os.getcwd()
......
...@@ -616,7 +616,10 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -616,7 +616,10 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
includes = [] includes = []
for filename in env.include_files: for filename in env.include_files:
# fake decoding of filenames to their original byte sequence # fake decoding of filenames to their original byte sequence
code.putln('#include "%s"' % filename) if filename[0] == '<' and filename[-1] == '>':
code.putln('#include %s' % filename)
else:
code.putln('#include "%s"' % filename)
def generate_filename_table(self, code): def generate_filename_table(self, code):
code.putln("") code.putln("")
......
# #
# Pyrex - Parse tree nodes # Pyrex - Parse tree nodes
# #
...@@ -18,7 +19,7 @@ import PyrexTypes ...@@ -18,7 +19,7 @@ import PyrexTypes
import TypeSlots import TypeSlots
from PyrexTypes import py_object_type, error_type, CFuncType from PyrexTypes import py_object_type, error_type, CFuncType
from Symtab import ModuleScope, LocalScope, GeneratorLocalScope, \ from Symtab import ModuleScope, LocalScope, GeneratorLocalScope, \
StructOrUnionScope, PyClassScope, CClassScope StructOrUnionScope, PyClassScope, CClassScope, CppClassScope
from Cython.Utils import open_new_file, replace_suffix from Cython.Utils import open_new_file, replace_suffix
from Code import UtilityCode from Code import UtilityCode
from StringEncoding import EncodedString, escape_byte_string, split_docstring from StringEncoding import EncodedString, escape_byte_string, split_docstring
...@@ -143,6 +144,15 @@ class Node(object): ...@@ -143,6 +144,15 @@ class Node(object):
def gil_error(self, env=None): def gil_error(self, env=None):
error(self.pos, "%s not allowed without gil" % self.gil_message) error(self.pos, "%s not allowed without gil" % self.gil_message)
cpp_message = "Operation"
def cpp_check(self, env):
if not env.is_cpp():
self.cpp_error()
def cpp_error(self):
error(self.pos, "%s only allowed in c++" % self.cpp_message)
def clone_node(self): def clone_node(self):
"""Clone the node. This is defined as a shallow copy, except for member lists """Clone the node. This is defined as a shallow copy, except for member lists
...@@ -447,7 +457,19 @@ class CPtrDeclaratorNode(CDeclaratorNode): ...@@ -447,7 +457,19 @@ class CPtrDeclaratorNode(CDeclaratorNode):
"Pointer base type cannot be a Python object") "Pointer base type cannot be a Python object")
ptr_type = PyrexTypes.c_ptr_type(base_type) ptr_type = PyrexTypes.c_ptr_type(base_type)
return self.base.analyse(ptr_type, env, nonempty = nonempty) return self.base.analyse(ptr_type, env, nonempty = nonempty)
class CReferenceDeclaratorNode(CDeclaratorNode):
# base CDeclaratorNode
child_attrs = ["base"]
def analyse(self, base_type, env, nonempty = 0):
if base_type.is_pyobject:
error(self.pos,
"Reference base type cannot be a Python object")
ref_type = PyrexTypes.c_ref_type(base_type)
return self.base.analyse(ref_type, env, nonempty = nonempty)
class CArrayDeclaratorNode(CDeclaratorNode): class CArrayDeclaratorNode(CDeclaratorNode):
# base CDeclaratorNode # base CDeclaratorNode
# dimension ExprNode # dimension ExprNode
...@@ -455,6 +477,19 @@ class CArrayDeclaratorNode(CDeclaratorNode): ...@@ -455,6 +477,19 @@ class CArrayDeclaratorNode(CDeclaratorNode):
child_attrs = ["base", "dimension"] child_attrs = ["base", "dimension"]
def analyse(self, base_type, env, nonempty = 0): def analyse(self, base_type, env, nonempty = 0):
if base_type.is_cpp_class:
from ExprNodes import TupleNode
if isinstance(self.dimension, TupleNode):
args = self.dimension.args
else:
args = self.dimension,
values = [v.analyse_as_type(env) for v in args]
if None in values:
ix = values.index(None)
error(args[ix].pos, "Template parameter not a type.")
return error_type
base_type = base_type.specialize_here(self.pos, values)
return self.base.analyse(base_type, env, nonempty = nonempty)
if self.dimension: if self.dimension:
self.dimension.analyse_const_expression(env) self.dimension.analyse_const_expression(env)
if not self.dimension.type.is_int: if not self.dimension.type.is_int:
...@@ -655,6 +690,9 @@ class CBaseTypeNode(Node): ...@@ -655,6 +690,9 @@ class CBaseTypeNode(Node):
pass pass
def analyse_as_type(self, env):
return self.analyse(env)
class CAnalysedBaseTypeNode(Node): class CAnalysedBaseTypeNode(Node):
# type type # type type
...@@ -714,7 +752,12 @@ class CSimpleBaseTypeNode(CBaseTypeNode): ...@@ -714,7 +752,12 @@ class CSimpleBaseTypeNode(CBaseTypeNode):
type = py_object_type type = py_object_type
self.arg_name = self.name self.arg_name = self.name
else: else:
error(self.pos, "'%s' is not a type identifier" % self.name) if self.templates:
if not self.name in self.templates:
error(self.pos, "'%s' is not a type identifier" % self.name)
type = PyrexTypes.TemplatePlaceholderType(self.name)
else:
error(self.pos, "'%s' is not a type identifier" % self.name)
if self.complex: if self.complex:
if not type.is_numeric or type.is_complex: if not type.is_numeric or type.is_complex:
error(self.pos, "can only complexify c numeric types") error(self.pos, "can only complexify c numeric types")
...@@ -725,14 +768,14 @@ class CSimpleBaseTypeNode(CBaseTypeNode): ...@@ -725,14 +768,14 @@ class CSimpleBaseTypeNode(CBaseTypeNode):
else: else:
return PyrexTypes.error_type return PyrexTypes.error_type
class CBufferAccessTypeNode(CBaseTypeNode): class TemplatedTypeNode(CBaseTypeNode):
# After parsing: # After parsing:
# positional_args [ExprNode] List of positional arguments # positional_args [ExprNode] List of positional arguments
# keyword_args DictNode Keyword arguments # keyword_args DictNode Keyword arguments
# base_type_node CBaseTypeNode # base_type_node CBaseTypeNode
# After analysis: # After analysis:
# type PyrexType.BufferType ...containing the right options # type PyrexTypes.BufferType or PyrexTypes.CppClassType ...containing the right options
child_attrs = ["base_type_node", "positional_args", child_attrs = ["base_type_node", "positional_args",
...@@ -742,24 +785,38 @@ class CBufferAccessTypeNode(CBaseTypeNode): ...@@ -742,24 +785,38 @@ class CBufferAccessTypeNode(CBaseTypeNode):
name = None name = None
def analyse(self, env, could_be_name = False): def analyse(self, env, could_be_name = False, base_type = None):
base_type = self.base_type_node.analyse(env) if base_type is None:
base_type = self.base_type_node.analyse(env)
if base_type.is_error: return base_type if base_type.is_error: return base_type
import Buffer
if base_type.is_cpp_class:
if len(self.keyword_args.key_value_pairs) != 0:
error(self.pos, "c++ templates cannot take keyword arguments");
self.type = PyrexTypes.error_type
else:
template_types = []
for template_node in self.positional_args:
template_types.append(template_node.analyse_as_type(env))
self.type = base_type.specialize_here(self.pos, template_types)
else:
import Buffer
options = Buffer.analyse_buffer_options( options = Buffer.analyse_buffer_options(
self.pos, self.pos,
env, env,
self.positional_args, self.positional_args,
self.keyword_args, self.keyword_args,
base_type.buffer_defaults) base_type.buffer_defaults)
if sys.version_info[0] < 3: if sys.version_info[0] < 3:
# Py 2.x enforces byte strings as keyword arguments ... # Py 2.x enforces byte strings as keyword arguments ...
options = dict([ (name.encode('ASCII'), value) options = dict([ (name.encode('ASCII'), value)
for name, value in options.iteritems() ]) for name, value in options.iteritems() ])
self.type = PyrexTypes.BufferType(base_type, **options) self.type = PyrexTypes.BufferType(base_type, **options)
return self.type return self.type
class CComplexBaseTypeNode(CBaseTypeNode): class CComplexBaseTypeNode(CBaseTypeNode):
...@@ -904,6 +961,46 @@ class CStructOrUnionDefNode(StatNode): ...@@ -904,6 +961,46 @@ class CStructOrUnionDefNode(StatNode):
pass pass
class CppClassNode(CStructOrUnionDefNode):
# name string
# cname string or None
# visibility "extern"
# in_pxd boolean
# attributes [CVarDefNode] or None
# entry Entry
# base_classes [string]
# templates [string] or None
def analyse_declarations(self, env):
scope = None
if len(self.attributes) != 0:
scope = CppClassScope(self.name, env)
else:
self.attributes = None
base_class_types = []
for base_class_name in self.base_classes:
base_class_entry = env.lookup(base_class_name)
if base_class_entry is None:
error(self.pos, "'%s' not found" % base_class_name)
elif not base_class_entry.is_type or not base_class_entry.type.is_cpp_class:
error(self.pos, "'%s' is not a cpp class type" % base_class_name)
else:
base_class_types.append(base_class_entry.type)
if self.templates is None:
template_types = None
else:
template_types = [PyrexTypes.TemplatePlaceholderType(template_name) for template_name in self.templates]
self.entry = env.declare_cpp_class(
self.name, scope, self.pos,
self.cname, base_class_types, visibility = self.visibility, templates = template_types)
self.entry.is_cpp_class = 1
if self.attributes is not None:
if self.in_pxd and not env.in_cinclude:
self.entry.defined_in_pxd = 1
for attr in self.attributes:
attr.analyse_declarations(scope)
class CEnumDefNode(StatNode): class CEnumDefNode(StatNode):
# name string or None # name string or None
# cname string or None # cname string or None
...@@ -3391,8 +3488,14 @@ class DelStatNode(StatNode): ...@@ -3391,8 +3488,14 @@ class DelStatNode(StatNode):
def analyse_expressions(self, env): def analyse_expressions(self, env):
for arg in self.args: for arg in self.args:
arg.analyse_target_expression(env, None) arg.analyse_target_expression(env, None)
if not arg.type.is_pyobject: if arg.type.is_pyobject:
error(arg.pos, "Deletion of non-Python object") pass
elif arg.type.is_ptr and arg.type.base_type.is_cpp_class:
self.cpp_check(env)
elif arg.type.is_cpp_class:
error(arg.pos, "Deletion of non-heap C++ object")
else:
error(arg.pos, "Deletion of non-Python, non-C++ object")
#arg.release_target_temp(env) #arg.release_target_temp(env)
def nogil_check(self, env): def nogil_check(self, env):
...@@ -3406,6 +3509,9 @@ class DelStatNode(StatNode): ...@@ -3406,6 +3509,9 @@ class DelStatNode(StatNode):
for arg in self.args: for arg in self.args:
if arg.type.is_pyobject: if arg.type.is_pyobject:
arg.generate_deletion_code(code) arg.generate_deletion_code(code)
elif arg.type.is_ptr and arg.type.base_type.is_cpp_class:
arg.generate_result_code(code)
code.putln("delete %s;" % arg.result())
# else error reported earlier # else error reported earlier
def annotate(self, code): def annotate(self, code):
......
...@@ -128,7 +128,6 @@ class PostParseError(CompileError): pass ...@@ -128,7 +128,6 @@ class PostParseError(CompileError): pass
# error strings checked by unit tests, so define them # error strings checked by unit tests, so define them
ERR_CDEF_INCLASS = 'Cannot assign default value to fields in cdef classes, structs or unions' ERR_CDEF_INCLASS = 'Cannot assign default value to fields in cdef classes, structs or unions'
ERR_BUF_LOCALONLY = 'Buffer types only allowed as function local variables'
ERR_BUF_DEFAULTS = 'Invalid buffer defaults specification (see docs)' ERR_BUF_DEFAULTS = 'Invalid buffer defaults specification (see docs)'
ERR_INVALID_SPECIALATTR_TYPE = 'Special attributes must not have a type declared' ERR_INVALID_SPECIALATTR_TYPE = 'Special attributes must not have a type declared'
class PostParse(CythonTransform): class PostParse(CythonTransform):
...@@ -145,7 +144,7 @@ class PostParse(CythonTransform): ...@@ -145,7 +144,7 @@ class PostParse(CythonTransform):
- Interpret some node structures into Python runtime values. - Interpret some node structures into Python runtime values.
Some nodes take compile-time arguments (currently: Some nodes take compile-time arguments (currently:
CBufferAccessTypeNode[args] and __cythonbufferdefaults__ = {args}), TemplatedTypeNode[args] and __cythonbufferdefaults__ = {args}),
which should be interpreted. This happens in a general way which should be interpreted. This happens in a general way
and other steps should be taken to ensure validity. and other steps should be taken to ensure validity.
...@@ -154,7 +153,7 @@ class PostParse(CythonTransform): ...@@ -154,7 +153,7 @@ class PostParse(CythonTransform):
- For __cythonbufferdefaults__ the arguments are checked for - For __cythonbufferdefaults__ the arguments are checked for
validity. validity.
CBufferAccessTypeNode has its directives interpreted: TemplatedTypeNode has its directives interpreted:
Any first positional argument goes into the "dtype" attribute, Any first positional argument goes into the "dtype" attribute,
any "ndim" keyword argument goes into the "ndim" attribute and any "ndim" keyword argument goes into the "ndim" attribute and
so on. Also it is checked that the directive combination is valid. so on. Also it is checked that the directive combination is valid.
...@@ -243,11 +242,6 @@ class PostParse(CythonTransform): ...@@ -243,11 +242,6 @@ class PostParse(CythonTransform):
self.context.nonfatal_error(e) self.context.nonfatal_error(e)
return None return None
def visit_CBufferAccessTypeNode(self, node):
if not self.scope_type == 'function':
raise PostParseError(node.pos, ERR_BUF_LOCALONLY)
return node
class PxdPostParse(CythonTransform, SkipDeclarations): class PxdPostParse(CythonTransform, SkipDeclarations):
""" """
Basic interpretation/validity checking that should only be Basic interpretation/validity checking that should only be
...@@ -329,7 +323,22 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations): ...@@ -329,7 +323,22 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
duplication of functionality has to occur: We manually track cimports duplication of functionality has to occur: We manually track cimports
and which names the "cython" module may have been imported to. and which names the "cython" module may have been imported to.
""" """
special_methods = set(['declare', 'union', 'struct', 'typedef', 'sizeof', 'typeof', 'cast', 'address', 'pointer', 'compiled', 'NULL']) unop_method_nodes = {
'typeof': TypeofNode,
'operator.address': AmpersandNode,
'operator.dereference': DereferenceNode,
'operator.preincrement' : inc_dec_constructor(True, '++'),
'operator.predecrement' : inc_dec_constructor(True, '--'),
'operator.postincrement': inc_dec_constructor(False, '++'),
'operator.postdecrement': inc_dec_constructor(False, '--'),
# For backwards compatability.
'address': AmpersandNode,
}
special_methods = set(['declare', 'union', 'struct', 'typedef', 'sizeof', 'cast', 'pointer', 'compiled', 'NULL']
+ unop_method_nodes.keys())
def __init__(self, context, compilation_directive_defaults): def __init__(self, context, compilation_directive_defaults):
super(InterpretCompilerDirectives, self).__init__(context) super(InterpretCompilerDirectives, self).__init__(context)
...@@ -372,18 +381,33 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations): ...@@ -372,18 +381,33 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
else: else:
modname = u"cython" modname = u"cython"
self.cython_module_names.add(modname) self.cython_module_names.add(modname)
return node elif node.module_name.startswith(u"cython."):
if node.as_name:
self.directive_names[node.as_name] = node.module_name[7:]
else:
self.cython_module_names.add(u"cython")
else:
return node
def visit_FromCImportStatNode(self, node): def visit_FromCImportStatNode(self, node):
if node.module_name == u"cython": if node.module_name.startswith(u"cython."):
is_cython_module = True
submodule = node.module_name[7:] + u"."
elif node.module_name == u"cython":
is_cython_module = True
submodule = u""
else:
is_cython_module = False
if is_cython_module:
newimp = [] newimp = []
for pos, name, as_name, kind in node.imported_names: for pos, name, as_name, kind in node.imported_names:
if (name in Options.directive_types or full_name = submodule + name
name in self.special_methods or if (full_name in Options.directive_types or
PyrexTypes.parse_basic_type(name)): full_name in self.special_methods or
PyrexTypes.parse_basic_type(full_name)):
if as_name is None: if as_name is None:
as_name = name as_name = full_name
self.directive_names[as_name] = name self.directive_names[as_name] = full_name
if kind is not None: if kind is not None:
self.context.nonfatal_error(PostParseError(pos, self.context.nonfatal_error(PostParseError(pos,
"Compiler directive imports must be plain imports")) "Compiler directive imports must be plain imports"))
...@@ -395,13 +419,22 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations): ...@@ -395,13 +419,22 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
return node return node
def visit_FromImportStatNode(self, node): def visit_FromImportStatNode(self, node):
if node.module.module_name.value == u"cython": if node.module.module_name.value.startswith(u"cython."):
is_cython_module = True
submodule = node.module.module_name.value[7:] + u"."
elif node.module.module_name.value == u"cython":
is_cython_module = True
submodule = u""
else:
is_cython_module = False
if is_cython_module:
newimp = [] newimp = []
for name, name_node in node.items: for name, name_node in node.items:
if (name in Options.directive_types or full_name = submodule + name
name in self.special_methods or if (full_name in Options.directive_types or
PyrexTypes.parse_basic_type(name)): full_name in self.special_methods or
self.directive_names[name_node.name] = name PyrexTypes.parse_basic_type(full_name)):
self.directive_names[name_node.name] = full_name
else: else:
newimp.append((name, name_node)) newimp.append((name, name_node))
if not newimp: if not newimp:
...@@ -1016,7 +1049,12 @@ class TransformBuiltinMethods(EnvTransform): ...@@ -1016,7 +1049,12 @@ class TransformBuiltinMethods(EnvTransform):
# cython.foo # cython.foo
function = node.function.as_cython_attribute() function = node.function.as_cython_attribute()
if function: if function:
if function == u'cast': if function in InterpretCompilerDirectives.unop_method_nodes:
if len(node.args) != 1:
error(node.function.pos, u"%s() takes exactly one argument" % function)
else:
node = InterpretCompilerDirectives.unop_method_nodes[function](node.function.pos, operand=node.args[0])
elif function == u'cast':
if len(node.args) != 2: if len(node.args) != 2:
error(node.function.pos, u"cast() takes exactly two arguments") error(node.function.pos, u"cast() takes exactly two arguments")
else: else:
...@@ -1034,16 +1072,6 @@ class TransformBuiltinMethods(EnvTransform): ...@@ -1034,16 +1072,6 @@ class TransformBuiltinMethods(EnvTransform):
node = SizeofTypeNode(node.function.pos, arg_type=type) node = SizeofTypeNode(node.function.pos, arg_type=type)
else: else:
node = SizeofVarNode(node.function.pos, operand=node.args[0]) node = SizeofVarNode(node.function.pos, operand=node.args[0])
elif function == 'typeof':
if len(node.args) != 1:
error(node.function.pos, u"typeof() takes exactly one argument")
else:
node = TypeofNode(node.function.pos, operand=node.args[0])
elif function == 'address':
if len(node.args) != 1:
error(node.function.pos, u"address() takes exactly one argument")
else:
node = AmpersandNode(node.function.pos, operand=node.args[0])
elif function == 'cmod': elif function == 'cmod':
if len(node.args) != 2: if len(node.args) != 2:
error(node.function.pos, u"cmod() takes exactly two arguments") error(node.function.pos, u"cmod() takes exactly two arguments")
......
...@@ -28,6 +28,7 @@ cpdef p_typecast(PyrexScanner s) ...@@ -28,6 +28,7 @@ cpdef p_typecast(PyrexScanner s)
cpdef p_sizeof(PyrexScanner s) cpdef p_sizeof(PyrexScanner s)
cpdef p_yield_expression(PyrexScanner s) cpdef p_yield_expression(PyrexScanner s)
cpdef p_power(PyrexScanner s) cpdef p_power(PyrexScanner s)
cpdef p_new_expr(PyrexScanner s)
cpdef p_trailer(PyrexScanner s, node1) cpdef p_trailer(PyrexScanner s, node1)
cpdef p_call(PyrexScanner s, function) cpdef p_call(PyrexScanner s, function)
cpdef p_index(PyrexScanner s, base) cpdef p_index(PyrexScanner s, base)
...@@ -149,3 +150,4 @@ cpdef p_doc_string(PyrexScanner s) ...@@ -149,3 +150,4 @@ cpdef p_doc_string(PyrexScanner s)
cpdef p_code(PyrexScanner s, level= *) cpdef p_code(PyrexScanner s, level= *)
cpdef p_compiler_directive_comments(PyrexScanner s) cpdef p_compiler_directive_comments(PyrexScanner s)
cpdef p_module(PyrexScanner s, pxd, full_module_name) cpdef p_module(PyrexScanner s, pxd, full_module_name)
cpdef p_cpp_class_definition(PyrexScanner s, ctx)
...@@ -30,6 +30,8 @@ class Ctx(object): ...@@ -30,6 +30,8 @@ class Ctx(object):
api = 0 api = 0
overridable = 0 overridable = 0
nogil = 0 nogil = 0
namespace = None
templates = None
def __init__(self, **kwds): def __init__(self, **kwds):
self.__dict__.update(kwds) self.__dict__.update(kwds)
...@@ -298,6 +300,8 @@ def p_yield_expression(s): ...@@ -298,6 +300,8 @@ def p_yield_expression(s):
#power: atom trailer* ('**' factor)* #power: atom trailer* ('**' factor)*
def p_power(s): def p_power(s):
if s.systring == 'new' and s.peek()[0] == 'IDENT':
return p_new_expr(s)
n1 = p_atom(s) n1 = p_atom(s)
while s.sy in ('(', '[', '.'): while s.sy in ('(', '[', '.'):
n1 = p_trailer(s, n1) n1 = p_trailer(s, n1)
...@@ -308,6 +312,19 @@ def p_power(s): ...@@ -308,6 +312,19 @@ def p_power(s):
n1 = ExprNodes.binop_node(pos, '**', n1, n2) n1 = ExprNodes.binop_node(pos, '**', n1, n2)
return n1 return n1
def p_new_expr(s):
# s.systring == 'new'.
pos = s.position()
s.next()
name = p_ident(s)
if s.sy == '[':
s.next()
template_parameters = p_simple_expr_list(s)
s.expect(']')
else:
template_parameters = None
return p_call(s, ExprNodes.NewExprNode(pos, cppclass = name, template_parameters = template_parameters))
#trailer: '(' [arglist] ')' | '[' subscriptlist ']' | '.' NAME #trailer: '(' [arglist] ')' | '[' subscriptlist ']' | '.' NAME
def p_trailer(s, node1): def p_trailer(s, node1):
...@@ -1458,6 +1475,29 @@ def p_with_statement(s): ...@@ -1458,6 +1475,29 @@ def p_with_statement(s):
s.next() s.next()
body = p_suite(s) body = p_suite(s)
return Nodes.GILStatNode(pos, state = state, body = body) return Nodes.GILStatNode(pos, state = state, body = body)
elif s.systring == 'template':
templates = []
s.next()
s.expect('[')
#s.next()
templates.append(s.systring)
s.next()
while s.systring == ',':
s.next()
templates.append(s.systring)
s.next()
s.expect(']')
if s.sy == ':':
s.next()
s.expect_newline("Syntax error in template function declaration")
s.expect_indent()
body_ctx = Ctx()
body_ctx.templates = templates
func_or_var = p_c_func_or_var_declaration(s, pos, body_ctx)
s.expect_dedent()
return func_or_var
else:
error(pos, "Syntax error in template function declaration")
else: else:
manager = p_expr(s) manager = p_expr(s)
target = None target = None
...@@ -1748,13 +1788,13 @@ def p_positional_and_keyword_args(s, end_sy_set, type_positions=(), type_keyword ...@@ -1748,13 +1788,13 @@ def p_positional_and_keyword_args(s, end_sy_set, type_positions=(), type_keyword
s.next() s.next()
return positional_args, keyword_args return positional_args, keyword_args
def p_c_base_type(s, self_flag = 0, nonempty = 0): def p_c_base_type(s, self_flag = 0, nonempty = 0, templates = None):
# If self_flag is true, this is the base type for the # If self_flag is true, this is the base type for the
# self argument of a C method of an extension type. # self argument of a C method of an extension type.
if s.sy == '(': if s.sy == '(':
return p_c_complex_base_type(s) return p_c_complex_base_type(s)
else: else:
return p_c_simple_base_type(s, self_flag, nonempty = nonempty) return p_c_simple_base_type(s, self_flag, nonempty = nonempty, templates = templates)
def p_calling_convention(s): def p_calling_convention(s):
if s.sy == 'IDENT' and s.systring in calling_convention_words: if s.sy == 'IDENT' and s.systring in calling_convention_words:
...@@ -1776,7 +1816,7 @@ def p_c_complex_base_type(s): ...@@ -1776,7 +1816,7 @@ def p_c_complex_base_type(s):
return Nodes.CComplexBaseTypeNode(pos, return Nodes.CComplexBaseTypeNode(pos,
base_type = base_type, declarator = declarator) base_type = base_type, declarator = declarator)
def p_c_simple_base_type(s, self_flag, nonempty): def p_c_simple_base_type(s, self_flag, nonempty, templates = None):
#print "p_c_simple_base_type: self_flag =", self_flag, nonempty #print "p_c_simple_base_type: self_flag =", self_flag, nonempty
is_basic = 0 is_basic = 0
signed = 1 signed = 1
...@@ -1827,12 +1867,12 @@ def p_c_simple_base_type(s, self_flag, nonempty): ...@@ -1827,12 +1867,12 @@ def p_c_simple_base_type(s, self_flag, nonempty):
elif s.sy not in ('*', '**', '['): elif s.sy not in ('*', '**', '['):
s.put_back('IDENT', name) s.put_back('IDENT', name)
name = None name = None
type_node = Nodes.CSimpleBaseTypeNode(pos, type_node = Nodes.CSimpleBaseTypeNode(pos,
name = name, module_path = module_path, name = name, module_path = module_path,
is_basic_c_type = is_basic, signed = signed, is_basic_c_type = is_basic, signed = signed,
complex = complex, longness = longness, complex = complex, longness = longness,
is_self_arg = self_flag) is_self_arg = self_flag, templates = templates)
# Treat trailing [] on type as buffer access if it appears in a context # Treat trailing [] on type as buffer access if it appears in a context
...@@ -1842,11 +1882,11 @@ def p_c_simple_base_type(s, self_flag, nonempty): ...@@ -1842,11 +1882,11 @@ def p_c_simple_base_type(s, self_flag, nonempty):
# (This means that buffers cannot occur where there can be empty declarators, # (This means that buffers cannot occur where there can be empty declarators,
# which is an ok restriction to make.) # which is an ok restriction to make.)
if nonempty and s.sy == '[': if nonempty and s.sy == '[':
return p_buffer_access(s, type_node) return p_buffer_or_template(s, type_node)
else: else:
return type_node return type_node
def p_buffer_access(s, base_type_node): def p_buffer_or_template(s, base_type_node):
# s.sy == '[' # s.sy == '['
pos = s.position() pos = s.position()
s.next() s.next()
...@@ -1860,8 +1900,7 @@ def p_buffer_access(s, base_type_node): ...@@ -1860,8 +1900,7 @@ def p_buffer_access(s, base_type_node):
ExprNodes.DictItemNode(pos=key.pos, key=key, value=value) ExprNodes.DictItemNode(pos=key.pos, key=key, value=value)
for key, value in keyword_args for key, value in keyword_args
]) ])
result = Nodes.TemplatedTypeNode(pos,
result = Nodes.CBufferAccessTypeNode(pos,
positional_args = positional_args, positional_args = positional_args,
keyword_args = keyword_dict, keyword_args = keyword_dict,
base_type_node = base_type_node) base_type_node = base_type_node)
...@@ -2018,6 +2057,13 @@ def p_c_func_declarator(s, pos, ctx, base, cmethod_flag): ...@@ -2018,6 +2057,13 @@ def p_c_func_declarator(s, pos, ctx, base, cmethod_flag):
exception_value = exc_val, exception_check = exc_check, exception_value = exc_val, exception_check = exc_check,
nogil = nogil or ctx.nogil or with_gil, with_gil = with_gil) nogil = nogil or ctx.nogil or with_gil, with_gil = with_gil)
supported_overloaded_operators = set([
'+', '-', '*', '/', '%',
'++', '--', '~', '|', '&', '^', '<<', '>>',
'==', '!=', '>=', '>', '<=', '<',
'[]', '()',
])
def p_c_simple_declarator(s, ctx, empty, is_type, cmethod_flag, def p_c_simple_declarator(s, ctx, empty, is_type, cmethod_flag,
assignable, nonempty): assignable, nonempty):
pos = s.position() pos = s.position()
...@@ -2037,6 +2083,12 @@ def p_c_simple_declarator(s, ctx, empty, is_type, cmethod_flag, ...@@ -2037,6 +2083,12 @@ def p_c_simple_declarator(s, ctx, empty, is_type, cmethod_flag,
result = Nodes.CPtrDeclaratorNode(pos, result = Nodes.CPtrDeclaratorNode(pos,
base = Nodes.CPtrDeclaratorNode(pos, base = Nodes.CPtrDeclaratorNode(pos,
base = base)) base = base))
elif s.sy == '&':
s.next()
base = p_c_declarator(s, ctx, empty = empty, is_type = is_type,
cmethod_flag = cmethod_flag,
assignable = assignable, nonempty = nonempty)
result = Nodes.CReferenceDeclaratorNode(pos, base = base)
else: else:
rhs = None rhs = None
if s.sy == 'IDENT': if s.sy == 'IDENT':
...@@ -2053,6 +2105,27 @@ def p_c_simple_declarator(s, ctx, empty, is_type, cmethod_flag, ...@@ -2053,6 +2105,27 @@ def p_c_simple_declarator(s, ctx, empty, is_type, cmethod_flag,
error(s.position(), "Empty declarator") error(s.position(), "Empty declarator")
name = "" name = ""
cname = None cname = None
if cname is None and ctx.namespace is not None:
cname = ctx.namespace + "::" + name
if name == 'operator' and ctx.visibility == 'extern':
op = s.sy
s.next()
# Handle diphthong operators.
if op == '(':
s.expect(')')
op = '()'
elif op == '[':
s.expect(']')
op = '[]'
if op in ['-', '+', '|', '&'] and s.sy == op:
op = op*2
s.next()
if s.sy == '=':
op += s.sy
s.next()
if op not in supported_overloaded_operators:
s.error("Overloading operator '%s' not yet supported." % op)
name = name+op
result = Nodes.CNameDeclaratorNode(pos, result = Nodes.CNameDeclaratorNode(pos,
name = name, cname = cname, default = rhs) name = name, cname = cname, default = rhs)
result.calling_convention = calling_convention result.calling_convention = calling_convention
...@@ -2184,6 +2257,10 @@ def p_cdef_statement(s, ctx): ...@@ -2184,6 +2257,10 @@ def p_cdef_statement(s, ctx):
if ctx.overridable: if ctx.overridable:
error(pos, "Extension types cannot be declared cpdef") error(pos, "Extension types cannot be declared cpdef")
return p_c_class_definition(s, pos, ctx) return p_c_class_definition(s, pos, ctx)
elif s.sy == 'IDENT' and s.systring == 'cppclass':
if ctx.visibility != 'extern':
error(pos, "C++ classes need to be declared extern")
return p_cpp_class_definition(s, pos, ctx)
elif s.sy == 'IDENT' and s.systring in ("struct", "union", "enum", "packed"): elif s.sy == 'IDENT' and s.systring in ("struct", "union", "enum", "packed"):
if ctx.level not in ('module', 'module_pxd'): if ctx.level not in ('module', 'module_pxd'):
error(pos, "C struct/union/enum definition not allowed here") error(pos, "C struct/union/enum definition not allowed here")
...@@ -2208,13 +2285,17 @@ def p_cdef_extern_block(s, pos, ctx): ...@@ -2208,13 +2285,17 @@ def p_cdef_extern_block(s, pos, ctx):
s.next() s.next()
else: else:
_, include_file = p_string_literal(s) _, include_file = p_string_literal(s)
if s.systring == "namespace":
s.next()
ctx.namespace = p_dotted_name(s, as_allowed=False)[2].replace('.', '::')
ctx = ctx(cdef_flag = 1, visibility = 'extern') ctx = ctx(cdef_flag = 1, visibility = 'extern')
if p_nogil(s): if p_nogil(s):
ctx.nogil = 1 ctx.nogil = 1
body = p_suite(s, ctx) body = p_suite(s, ctx)
return Nodes.CDefExternNode(pos, return Nodes.CDefExternNode(pos,
include_file = include_file, include_file = include_file,
body = body) body = body,
namespace = ctx.namespace)
def p_c_enum_definition(s, pos, ctx): def p_c_enum_definition(s, pos, ctx):
# s.sy == ident 'enum' # s.sy == ident 'enum'
...@@ -2223,6 +2304,8 @@ def p_c_enum_definition(s, pos, ctx): ...@@ -2223,6 +2304,8 @@ def p_c_enum_definition(s, pos, ctx):
name = s.systring name = s.systring
s.next() s.next()
cname = p_opt_cname(s) cname = p_opt_cname(s)
if cname is None and ctx.namespace is not None:
cname = ctx.namespace + "::" + name
else: else:
name = None name = None
cname = None cname = None
...@@ -2277,6 +2360,8 @@ def p_c_struct_or_union_definition(s, pos, ctx): ...@@ -2277,6 +2360,8 @@ def p_c_struct_or_union_definition(s, pos, ctx):
s.next() s.next()
name = p_ident(s) name = p_ident(s)
cname = p_opt_cname(s) cname = p_opt_cname(s)
if cname is None and ctx.namespace is not None:
cname = ctx.namespace + "::" + name
attributes = None attributes = None
if s.sy == ':': if s.sy == ':':
s.next() s.next()
...@@ -2320,12 +2405,12 @@ def p_c_modifiers(s): ...@@ -2320,12 +2405,12 @@ def p_c_modifiers(s):
def p_c_func_or_var_declaration(s, pos, ctx): def p_c_func_or_var_declaration(s, pos, ctx):
cmethod_flag = ctx.level in ('c_class', 'c_class_pxd') cmethod_flag = ctx.level in ('c_class', 'c_class_pxd')
modifiers = p_c_modifiers(s) modifiers = p_c_modifiers(s)
base_type = p_c_base_type(s, nonempty = 1) base_type = p_c_base_type(s, nonempty = 1, templates = ctx.templates)
declarator = p_c_declarator(s, ctx, cmethod_flag = cmethod_flag, declarator = p_c_declarator(s, ctx, cmethod_flag = cmethod_flag,
assignable = 1, nonempty = 1) assignable = 1, nonempty = 1)
declarator.overridable = ctx.overridable declarator.overridable = ctx.overridable
if s.sy == ':': if s.sy == ':':
if ctx.level not in ('module', 'c_class', 'module_pxd', 'c_class_pxd'): if ctx.level not in ('module', 'c_class', 'module_pxd', 'c_class_pxd') and not ctx.templates:
s.error("C function definition not allowed here") s.error("C function definition not allowed here")
doc, suite = p_suite(s, Ctx(level = 'function'), with_doc = 1) doc, suite = p_suite(s, Ctx(level = 'function'), with_doc = 1)
result = Nodes.CFuncDefNode(pos, result = Nodes.CFuncDefNode(pos,
...@@ -2615,6 +2700,64 @@ def p_module(s, pxd, full_module_name): ...@@ -2615,6 +2700,64 @@ def p_module(s, pxd, full_module_name):
full_module_name = full_module_name, full_module_name = full_module_name,
directive_comments = directive_comments) directive_comments = directive_comments)
def p_cpp_class_definition(s, pos, ctx):
# s.sy == 'cppclass'
s.next()
module_path = []
class_name = p_ident(s)
cname = p_opt_cname(s)
if cname is None and ctx.namespace is not None:
cname = ctx.namespace + "::" + class_name
if s.sy == '.':
error(pos, "Qualified class name not allowed C++ class")
if s.sy == '[':
s.next()
templates = [p_ident(s)]
while s.sy == ',':
s.next()
templates.append(p_ident(s))
s.expect(']')
else:
templates = None
if s.sy == '(':
s.next()
base_classes = [p_dotted_name(s, False)[2]]
while s.sy == ',':
s.next()
base_classes.append(p_dotted_name(s, False)[2])
s.expect(')')
else:
base_classes = []
if s.sy == '[':
error(s.position(), "Name options not allowed for C++ class")
if s.sy == ':':
s.next()
s.expect('NEWLINE')
s.expect_indent()
attributes = []
body_ctx = Ctx(visibility = ctx.visibility)
body_ctx.templates = templates
while s.sy != 'DEDENT':
if s.sy != 'pass':
attributes.append(
p_c_func_or_var_declaration(s, s.position(), body_ctx))
else:
s.next()
s.expect_newline("Expected a newline")
s.expect_dedent()
else:
s.expect_newline("Syntax error in C++ class definition")
return Nodes.CppClassNode(pos,
name = class_name,
cname = cname,
base_classes = base_classes,
visibility = ctx.visibility,
in_pxd = ctx.level == 'module_pxd',
attributes = attributes,
templates = templates)
#---------------------------------------------- #----------------------------------------------
# #
# Debugging # Debugging
......
# #
# Pyrex - Types # Pyrex - Types
# #
...@@ -6,6 +8,7 @@ from Code import UtilityCode ...@@ -6,6 +8,7 @@ from Code import UtilityCode
import StringEncoding import StringEncoding
import Naming import Naming
import copy import copy
from Errors import error
class BaseType(object): class BaseType(object):
# #
...@@ -40,6 +43,7 @@ class PyrexType(BaseType): ...@@ -40,6 +43,7 @@ class PyrexType(BaseType):
# is_array boolean Is a C array type # is_array boolean Is a C array type
# is_ptr boolean Is a C pointer type # is_ptr boolean Is a C pointer type
# is_null_ptr boolean Is the type of NULL # is_null_ptr boolean Is the type of NULL
# is_reference boolean Is a C reference type
# is_cfunction boolean Is a C function type # is_cfunction boolean Is a C function type
# is_struct_or_union boolean Is a C struct or union type # is_struct_or_union boolean Is a C struct or union type
# is_struct boolean Is a C struct type # is_struct boolean Is a C struct type
...@@ -91,8 +95,10 @@ class PyrexType(BaseType): ...@@ -91,8 +95,10 @@ class PyrexType(BaseType):
is_array = 0 is_array = 0
is_ptr = 0 is_ptr = 0
is_null_ptr = 0 is_null_ptr = 0
is_reference = 0
is_cfunction = 0 is_cfunction = 0
is_struct_or_union = 0 is_struct_or_union = 0
is_cpp_class = 0
is_struct = 0 is_struct = 0
is_enum = 0 is_enum = 0
is_typedef = 0 is_typedef = 0
...@@ -109,6 +115,10 @@ class PyrexType(BaseType): ...@@ -109,6 +115,10 @@ class PyrexType(BaseType):
# If a typedef, returns the base type. # If a typedef, returns the base type.
return self return self
def specialize(self, values):
# TODO(danilo): Override wherever it makes sense.
return self
def literal_code(self, value): def literal_code(self, value):
# Returns a C code fragment representing a literal # Returns a C code fragment representing a literal
# value of this type. # value of this type.
...@@ -1328,9 +1338,19 @@ class CPtrType(CType): ...@@ -1328,9 +1338,19 @@ class CPtrType(CType):
return self.base_type.pointer_assignable_from_resolved_type(other_type) return self.base_type.pointer_assignable_from_resolved_type(other_type)
else: else:
return 0 return 0
if (self.base_type.is_cpp_class and other_type.is_ptr
and other_type.base_type.is_cpp_class and other_type.base_type.is_subclass(self.base_type)):
return 1
if other_type.is_array or other_type.is_ptr: if other_type.is_array or other_type.is_ptr:
return self.base_type.is_void or self.base_type.same_as(other_type.base_type) return self.base_type.is_void or self.base_type.same_as(other_type.base_type)
return 0 return 0
def specialize(self, values):
base_type = self.base_type.specialize(values)
if base_type == self.base_type:
return self
else:
return CPtrType(base_type)
class CNullPtrType(CPtrType): class CNullPtrType(CPtrType):
...@@ -1338,6 +1358,43 @@ class CNullPtrType(CPtrType): ...@@ -1338,6 +1358,43 @@ class CNullPtrType(CPtrType):
is_null_ptr = 1 is_null_ptr = 1
class CReferenceType(CType):
is_reference = 1
def __init__(self, base_type):
self.base_type = base_type
def __repr__(self):
return "<CReferenceType %s>" % repr(self.base_type)
def same_as_resolved_type(self, other_type):
return other_type.is_reference and self.base_type.same_as(other_type.base_type)
def declaration_code(self, entity_code,
for_display = 0, dll_linkage = None, pyrex = 0):
#print "CReferenceType.declaration_code: pointer to", self.base_type ###
return self.base_type.declaration_code(
"&%s" % entity_code,
for_display, dll_linkage, pyrex)
def assignable_from_resolved_type(self, other_type):
if other_type is error_type:
return 1
elif other_type.is_reference and self.base_type == other_type.base_type:
return 1
elif other_type == self.base_type:
return 1
else: #for now
return 0
def specialize(self, values):
base_type = self.base_type.specialize(values)
if base_type == self.base_type:
return self
else:
return CReferenceType(base_type)
class CFuncType(CType): class CFuncType(CType):
# return_type CType # return_type CType
# args [CFuncTypeArg] # args [CFuncTypeArg]
...@@ -1347,13 +1404,15 @@ class CFuncType(CType): ...@@ -1347,13 +1404,15 @@ class CFuncType(CType):
# calling_convention string Function calling convention # calling_convention string Function calling convention
# nogil boolean Can be called without gil # nogil boolean Can be called without gil
# with_gil boolean Acquire gil around function body # with_gil boolean Acquire gil around function body
# templates [string] or None
is_cfunction = 1 is_cfunction = 1
original_sig = None original_sig = None
def __init__(self, return_type, args, has_varargs = 0, def __init__(self, return_type, args, has_varargs = 0,
exception_value = None, exception_check = 0, calling_convention = "", exception_value = None, exception_check = 0, calling_convention = "",
nogil = 0, with_gil = 0, is_overridable = 0, optional_arg_count = 0): nogil = 0, with_gil = 0, is_overridable = 0, optional_arg_count = 0,
templates = None):
self.return_type = return_type self.return_type = return_type
self.args = args self.args = args
self.has_varargs = has_varargs self.has_varargs = has_varargs
...@@ -1364,6 +1423,7 @@ class CFuncType(CType): ...@@ -1364,6 +1423,7 @@ class CFuncType(CType):
self.nogil = nogil self.nogil = nogil
self.with_gil = with_gil self.with_gil = with_gil
self.is_overridable = is_overridable self.is_overridable = is_overridable
self.templates = templates
def __repr__(self): def __repr__(self):
arg_reprs = map(repr, self.args) arg_reprs = map(repr, self.args)
...@@ -1567,6 +1627,23 @@ class CFuncType(CType): ...@@ -1567,6 +1627,23 @@ class CFuncType(CType):
s = self.declaration_code("(*)", with_calling_convention=False) s = self.declaration_code("(*)", with_calling_convention=False)
return '(%s)' % s return '(%s)' % s
def specialize(self, values):
if self.templates is None:
new_templates = None
else:
new_templates = [v.specialize(values) for v in self.templates]
return CFuncType(self.return_type.specialize(values),
[arg.specialize(values) for arg in self.args],
has_varargs = 0,
exception_value = self.exception_value,
exception_check = self.exception_check,
calling_convention = self.calling_convention,
nogil = self.nogil,
with_gil = self.with_gil,
is_overridable = self.is_overridable,
optional_arg_count = self.optional_arg_count,
templates = new_templates)
def opt_arg_cname(self, arg_name): def opt_arg_cname(self, arg_name):
return self.op_arg_struct.base_type.scope.lookup(arg_name).cname return self.op_arg_struct.base_type.scope.lookup(arg_name).cname
...@@ -1593,6 +1670,9 @@ class CFuncTypeArg(object): ...@@ -1593,6 +1670,9 @@ class CFuncTypeArg(object):
def declaration_code(self, for_display = 0): def declaration_code(self, for_display = 0):
return self.type.declaration_code(self.cname, for_display) return self.type.declaration_code(self.cname, for_display)
def specialize(self, values):
return CFuncTypeArg(self.name, self.type.specialize(values), self.pos, self.cname)
class StructUtilityCode(object): class StructUtilityCode(object):
def __init__(self, type, forward_decl): def __init__(self, type, forward_decl):
...@@ -1732,6 +1812,122 @@ class CStructOrUnionType(CType): ...@@ -1732,6 +1812,122 @@ class CStructOrUnionType(CType):
for x in self.scope.var_entries] for x in self.scope.var_entries]
return max(child_depths) + 1 return max(child_depths) + 1
class CppClassType(CType):
# name string
# cname string
# scope CppClassScope
# templates [string] or None
is_cpp_class = 1
has_attributes = 1
exception_check = True
def __init__(self, name, scope, cname, base_classes, templates = None, template_type = None):
self.name = name
self.cname = cname
self.scope = scope
self.base_classes = base_classes
self.operators = []
self.templates = templates
self.template_type = template_type
self.specializations = {}
def specialize_here(self, pos, template_values = None):
if self.templates is None:
error(pos, "'%s' type is not a template" % self);
return PyrexTypes.error_type
if len(self.templates) != len(template_values):
error(pos, "%s templated type receives %d arguments, got %d" %
(self.name, len(self.templates), len(template_values)))
return error_type
return self.specialize(dict(zip(self.templates, template_values)))
def specialize(self, values):
if not self.templates:
return self
key = tuple(values.items())
if key in self.specializations:
return self.specializations[key]
template_values = [t.specialize(values) for t in self.templates]
specialized = self.specializations[key] = \
CppClassType(self.name, None, self.cname, self.base_classes, template_values, template_type=self)
specialized.scope = self.scope.specialize(values)
return specialized
def declaration_code(self, entity_code, for_display = 0, dll_linkage = None, pyrex = 0):
if self.templates:
template_strings = [param.declaration_code('', for_display, pyrex) for param in self.templates]
templates = "<" + ",".join(template_strings) + ">"
else:
templates = ""
if for_display or pyrex:
name = self.name
else:
name = self.cname
return "%s%s %s" % (name, templates, entity_code)
def is_subclass(self, other_type):
# TODO(danilo): Handle templates.
if self.same_as_resolved_type(other_type):
return 1
for base_class in self.base_classes:
if base_class.is_subclass(other_type):
return 1
return 0
def same_as_resolved_type(self, other_type):
if other_type.is_cpp_class:
if self == other_type:
return 1
elif self.template_type and self.template_type == other_type.template_type:
if self.templates == other_type.templates:
return 1
for t1, t2 in zip(self.templates, other_type.templates):
if not t1.same_as_resolved_type(t2):
return 0
return 1
return 0
def assignable_from_resolved_type(self, other_type):
# TODO: handle operator=(...) here?
return other_type.is_cpp_class and other_type.is_subclass(self)
def attributes_known(self):
return self.scope is not None
class TemplatePlaceholderType(CType):
def __init__(self, name):
self.name = name
def declaration_code(self, entity_code, for_display = 0, dll_linkage = None, pyrex = 0):
if entity_code:
return self.name + " " + entity_code
else:
return self.name
def specialize(self, values):
if self in values:
return values[self]
else:
return self
def same_as_resolved_type(self, other_type):
if isinstance(other_type, TemplatePlaceholderType):
return self.name == other_type.name
else:
return 0
def __hash__(self):
return hash(self.name)
def __cmp__(self, other):
if isinstance(other, TemplatePlaceholderType):
return cmp(self.name, other.name)
else:
return cmp(type(self), type(other))
class CEnumType(CType): class CEnumType(CType):
# name string # name string
# cname string or None # cname string or None
...@@ -2001,6 +2197,108 @@ modifiers_and_name_to_type = { ...@@ -2001,6 +2197,108 @@ modifiers_and_name_to_type = {
(1, 0, "bint"): c_bint_type, (1, 0, "bint"): c_bint_type,
} }
def is_promotion0(src_type, dst_type):
if src_type.is_numeric and dst_type.is_numeric:
if src_type.is_int and dst_type.is_int:
if src_type.is_enum:
return True
elif src_type.signed:
return dst_type.signed and src_type.rank <= dst_type.rank
elif dst_type.signed: # and not src_type.signed
src_type.rank < dst_type.rank
else:
return src_type.rank <= dst_type.rank
elif src_type.is_float and dst_type.is_float:
return src_type.rank <= dst_type.rank
else:
return False
else:
return False
def is_promotion(src_type, dst_type):
# It's hard to find a hard definition of promotion, but empirical
# evidence suggests that the below is all that's allowed.
if src_type.is_numeric:
if dst_type.same_as(c_int_type):
return src_type.is_enum or (src_type.is_int and (not src_type.signed) + src_type.rank < dst_type.rank)
elif dst_type.same_as(c_double_type):
return src_type.is_float and src_type.rank <= dst_type.rank
return False
def best_match(args, functions, pos=None):
"""
Finds the best function to be called
Error if no function fits the call or an ambiguity is find (two or more possible functions)
"""
# TODO: args should be a list of types, not a list of Nodes.
actual_nargs = len(args)
possibilities = []
bad_types = 0
from_type = None
target_type = None
for func in functions:
func_type = func.type
if func_type.is_ptr:
func_type = func_type.base_type
# Check function type
if not func_type.is_cfunction:
if not func_type.is_error and pos is not None:
error(pos, "Calling non-function type '%s'" % func_type)
return None
# Check no. of args
max_nargs = len(func_type.args)
min_nargs = max_nargs - func_type.optional_arg_count
if actual_nargs < min_nargs \
or (not func_type.has_varargs and actual_nargs > max_nargs):
if max_nargs == min_nargs and not func_type.has_varargs:
expectation = max_nargs
elif actual_nargs < min_nargs:
expectation = "at least %s" % min_nargs
else:
expectation = "at most %s" % max_nargs
error_str = "Call with wrong number of arguments (expected %s, got %s)" \
% (expectation, actual_nargs)
continue
if len(functions) == 1:
# Optimize the most common case of no overloading...
return func
score = [0,0,0]
for i in range(min(len(args), len(func_type.args))):
src_type = args[i].type
dst_type = func_type.args[i].type
if dst_type.assignable_from(src_type):
if src_type == dst_type or (dst_type.is_reference and \
src_type == dst_type.base_type) or \
dst_type.same_as(src_type):
pass # score 0
elif is_promotion(src_type, dst_type):
score[2] += 1
elif not src_type.is_pyobject:
score[1] += 1
else:
score[0] += 1
else:
bad_types = func
from_type = src_type
target_type = dst_type
break
else:
possibilities.append((score, func)) # so we can sort it
if len(possibilities):
possibilities.sort()
if len(possibilities) > 1 and possibilities[0][0] == possibilities[1][0]:
if pos is not None:
error(pos, "ambiguous overloaded method")
return None
return possibilities[0][1]
if pos is not None:
if bad_types:
error(pos, "Invalid conversion from '%s' to '%s'" % (from_type, target_type))
else:
error(pos, error_str)
return None
def widest_numeric_type(type1, type2): def widest_numeric_type(type1, type2):
# Given two numeric types, return the narrowest type # Given two numeric types, return the narrowest type
# encompassing both of them. # encompassing both of them.
...@@ -2092,6 +2390,26 @@ def c_ptr_type(base_type): ...@@ -2092,6 +2390,26 @@ def c_ptr_type(base_type):
else: else:
return CPtrType(base_type) return CPtrType(base_type)
def c_ref_type(base_type):
# Construct a C reference type
if base_type is error_type:
return error_type
else:
return CReferenceType(base_type)
def Node_to_type(node, env):
from ExprNodes import NameNode, AttributeNode, StringNode, error
if isinstance(node, StringNode):
node = NameNode(node.pos, name=node.value)
if isinstance(node, NameNode) and node.name in rank_to_type_name:
return simple_c_type(1, 0, node.name)
elif isinstance(node, (AttributeNode, NameNode)):
node.analyze_types(env)
if not node.entry.is_type:
pass
else:
error(node.pos, "Bad type")
def same_type(type1, type2): def same_type(type1, type2):
return type1.same_as(type2) return type1.same_as(type2)
......
...@@ -355,6 +355,14 @@ class PyrexScanner(Scanner): ...@@ -355,6 +355,14 @@ class PyrexScanner(Scanner):
t = "%s %s" % (self.sy, self.systring) t = "%s %s" % (self.sy, self.systring)
print("--- %3d %2d %s" % (line, col, t)) print("--- %3d %2d %s" % (line, col, t))
def peek(self):
saved = self.sy, self.systring
self.next()
next = self.sy, self.systring
self.unread(*next)
self.sy, self.systring = saved
return next
def put_back(self, sy, systring): def put_back(self, sy, systring):
self.unread(self.sy, self.systring) self.unread(self.sy, self.systring)
self.sy = sy self.sy = sy
......
...@@ -75,6 +75,7 @@ class Entry(object): ...@@ -75,6 +75,7 @@ class Entry(object):
# is_unbound_cmethod boolean Is an unbound C method of an extension type # is_unbound_cmethod boolean Is an unbound C method of an extension type
# is_type boolean Is a type definition # is_type boolean Is a type definition
# is_cclass boolean Is an extension class # is_cclass boolean Is an extension class
# is_cpp_class boolean Is a C++ class
# is_const boolean Is a constant # is_const boolean Is a constant
# is_property boolean Is a property of an extension type: # is_property boolean Is a property of an extension type:
# doc_cname string or None C const holding the docstring # doc_cname string or None C const holding the docstring
...@@ -133,6 +134,7 @@ class Entry(object): ...@@ -133,6 +134,7 @@ class Entry(object):
is_unbound_cmethod = 0 is_unbound_cmethod = 0
is_type = 0 is_type = 0
is_cclass = 0 is_cclass = 0
is_cpp_class = 0
is_const = 0 is_const = 0
is_property = 0 is_property = 0
doc_cname = None doc_cname = None
...@@ -172,6 +174,7 @@ class Entry(object): ...@@ -172,6 +174,7 @@ class Entry(object):
self.type = type self.type = type
self.pos = pos self.pos = pos
self.init = init self.init = init
self.overloaded_alternatives = []
self.assignments = [] self.assignments = []
def __repr__(self): def __repr__(self):
...@@ -180,6 +183,9 @@ class Entry(object): ...@@ -180,6 +183,9 @@ class Entry(object):
def redeclared(self, pos): def redeclared(self, pos):
error(pos, "'%s' does not match previous declaration" % self.name) error(pos, "'%s' does not match previous declaration" % self.name)
error(self.pos, "Previous declaration is here") error(self.pos, "Previous declaration is here")
def all_alternatives(self):
return [self] + self.overloaded_alternatives
class Scope(object): class Scope(object):
# name string Unqualified name # name string Unqualified name
...@@ -290,6 +296,8 @@ class Scope(object): ...@@ -290,6 +296,8 @@ class Scope(object):
# Create new entry, and add to dictionary if # Create new entry, and add to dictionary if
# name is not None. Reports a warning if already # name is not None. Reports a warning if already
# declared. # declared.
if type.is_buffer and not isinstance(self, LocalScope):
error(pos, ERR_BUF_LOCALONLY)
if not self.in_cinclude and cname and re.match("^_[_A-Z]+$", cname): if not self.in_cinclude and cname and re.match("^_[_A-Z]+$", cname):
# See http://www.gnu.org/software/libc/manual/html_node/Reserved-Names.html#Reserved-Names # See http://www.gnu.org/software/libc/manual/html_node/Reserved-Names.html#Reserved-Names
warning(pos, "'%s' is a reserved name in C." % cname, -1) warning(pos, "'%s' is a reserved name in C." % cname, -1)
...@@ -303,6 +311,10 @@ class Scope(object): ...@@ -303,6 +311,10 @@ class Scope(object):
entry.in_cinclude = self.in_cinclude entry.in_cinclude = self.in_cinclude
if name: if name:
entry.qualified_name = self.qualify_name(name) entry.qualified_name = self.qualify_name(name)
# if name in entries and self.is_cpp():
# entries[name].overloaded_alternatives.append(entry)
# else:
# entries[name] = entry
entries[name] = entry entries[name] = entry
entry.scope = self entry.scope = self
entry.visibility = visibility entry.visibility = visibility
...@@ -419,6 +431,10 @@ class Scope(object): ...@@ -419,6 +431,10 @@ class Scope(object):
cname = name cname = name
else: else:
cname = self.mangle(Naming.var_prefix, name) cname = self.mangle(Naming.var_prefix, name)
if type.is_cpp_class and visibility != 'extern':
constructor = type.scope.lookup(u'<init>')
if constructor is not None and PyrexTypes.best_match([], constructor.all_alternatives()) is None:
error(pos, "C++ class must have an empty constructor to be stack allocated")
entry = self.declare(name, cname, type, pos, visibility) entry = self.declare(name, cname, type, pos, visibility)
entry.is_variable = 1 entry.is_variable = 1
self.control_flow.set_state((), (name, 'initalized'), False) self.control_flow.set_state((), (name, 'initalized'), False)
...@@ -445,22 +461,27 @@ class Scope(object): ...@@ -445,22 +461,27 @@ class Scope(object):
cname = None, visibility = 'private', defining = 0, cname = None, visibility = 'private', defining = 0,
api = 0, in_pxd = 0, modifiers = ()): api = 0, in_pxd = 0, modifiers = ()):
# Add an entry for a C function. # Add an entry for a C function.
if not cname:
if api or visibility != 'private':
cname = name
else:
cname = self.mangle(Naming.func_prefix, name)
entry = self.lookup_here(name) entry = self.lookup_here(name)
if entry: if entry:
if visibility != 'private' and visibility != entry.visibility: if visibility != 'private' and visibility != entry.visibility:
warning(pos, "Function '%s' previously declared as '%s'" % (name, entry.visibility), 1) warning(pos, "Function '%s' previously declared as '%s'" % (name, entry.visibility), 1)
if not entry.type.same_as(type): if not entry.type.same_as(type):
if visibility == 'extern' and entry.visibility == 'extern': if visibility == 'extern' and entry.visibility == 'extern':
warning(pos, "Function signature does not match previous declaration", 1) if self.is_cpp():
entry.type = type temp = self.add_cfunction(name, type, pos, cname, visibility, modifiers)
temp.overloaded_alternatives = entry.all_alternatives()
entry = temp
else:
warning(pos, "Function signature does not match previous declaration", 1)
entry.type = type
else: else:
error(pos, "Function signature does not match previous declaration") error(pos, "Function signature does not match previous declaration")
else: else:
if not cname:
if api or visibility != 'private':
cname = name
else:
cname = self.mangle(Naming.func_prefix, name)
entry = self.add_cfunction(name, type, pos, cname, visibility, modifiers) entry = self.add_cfunction(name, type, pos, cname, visibility, modifiers)
entry.func_cname = cname entry.func_cname = cname
if in_pxd and visibility != 'extern': if in_pxd and visibility != 'extern':
...@@ -537,6 +558,21 @@ class Scope(object): ...@@ -537,6 +558,21 @@ class Scope(object):
entry = self.lookup(name) entry = self.lookup(name)
if entry and entry.is_type: if entry and entry.is_type:
return entry.type return entry.type
def lookup_operator(self, operator, operands):
if operands[0].type.is_cpp_class:
obj_type = operands[0].type
if obj_type.is_reference:
obj_type = obj_type.base_type
method = obj_type.scope.lookup("operator%s" % operator)
if method is not None:
res = PyrexTypes.best_match(operands[1:], method.all_alternatives())
if res is not None:
return res
function = self.lookup("operator%s" % operator)
if function is None:
return None
return PyrexTypes.best_match(operands, function.all_alternatives())
def use_utility_code(self, new_code): def use_utility_code(self, new_code):
self.global_scope().use_utility_code(new_code) self.global_scope().use_utility_code(new_code)
...@@ -556,6 +592,13 @@ class Scope(object): ...@@ -556,6 +592,13 @@ class Scope(object):
def infer_types(self): def infer_types(self):
from TypeInference import get_type_inferer from TypeInference import get_type_inferer
get_type_inferer().infer_types(self) get_type_inferer().infer_types(self)
def is_cpp(self):
outer = self.outer_scope
if outer is None:
return False
else:
return outer.is_cpp()
class PreImportScope(Scope): class PreImportScope(Scope):
...@@ -684,6 +727,7 @@ class ModuleScope(Scope): ...@@ -684,6 +727,7 @@ class ModuleScope(Scope):
# cimported_modules [ModuleScope] Modules imported with cimport # cimported_modules [ModuleScope] Modules imported with cimport
# types_imported {PyrexType : 1} Set of types for which import code generated # types_imported {PyrexType : 1} Set of types for which import code generated
# has_import_star boolean Module contains import * # has_import_star boolean Module contains import *
# cpp boolean Compiling a C++ file
is_module_scope = 1 is_module_scope = 1
has_import_star = 0 has_import_star = 0
...@@ -957,6 +1001,42 @@ class ModuleScope(Scope): ...@@ -957,6 +1001,42 @@ class ModuleScope(Scope):
if typedef_flag and not self.in_cinclude: if typedef_flag and not self.in_cinclude:
error(pos, "Forward-referenced type must use 'cdef', not 'ctypedef'") error(pos, "Forward-referenced type must use 'cdef', not 'ctypedef'")
def declare_cpp_class(self, name, scope,
pos, cname = None, base_classes = [],
visibility = 'extern', templates = None):
if visibility != 'extern':
error(pos, "C++ classes may only be extern")
if cname is None:
cname = name
entry = self.lookup(name)
if not entry:
type = PyrexTypes.CppClassType(
name, scope, cname, base_classes, templates = templates)
entry = self.declare_type(name, type, pos, cname,
visibility = visibility, defining = scope is not None)
else:
if not (entry.is_type and entry.type.is_cpp_class):
warning(pos, "'%s' redeclared " % name, 0)
elif scope and entry.type.scope:
warning(pos, "'%s' already defined (ignoring second definition)" % name, 0)
else:
if scope:
entry.type.scope = scope
self.type_entries.append(entry)
if not scope and not entry.type.scope:
entry.type.scope = CppClassScope(name, self)
if templates is not None:
for T in templates:
template_entry = entry.type.scope.declare(T.name, T.name, T, None, 'extern')
template_entry.is_type = 1
def declare_inherited_attributes(entry, base_classes):
for base_class in base_classes:
declare_inherited_attributes(entry, base_class.base_classes)
entry.type.scope.declare_inherited_cpp_attributes(base_class.scope)
declare_inherited_attributes(entry, base_classes)
return entry
def allocate_vtable_names(self, entry): def allocate_vtable_names(self, entry):
# If extension type has a vtable, allocate vtable struct and # If extension type has a vtable, allocate vtable struct and
# slot names for it. # slot names for it.
...@@ -1062,6 +1142,9 @@ class ModuleScope(Scope): ...@@ -1062,6 +1142,9 @@ class ModuleScope(Scope):
var_entry.is_readonly = 1 var_entry.is_readonly = 1
entry.as_variable = var_entry entry.as_variable = var_entry
def is_cpp(self):
return self.cpp
def infer_types(self): def infer_types(self):
from TypeInference import PyObjectTypeInferer from TypeInference import PyObjectTypeInferer
PyObjectTypeInferer().infer_types(self) PyObjectTypeInferer().infer_types(self)
...@@ -1281,6 +1364,8 @@ class CClassScope(ClassScope): ...@@ -1281,6 +1364,8 @@ class CClassScope(ClassScope):
cname = name cname = name
if visibility == 'private': if visibility == 'private':
cname = c_safe_identifier(cname) cname = c_safe_identifier(cname)
if type.is_cpp_class and visibility != 'extern':
error(pos, "C++ classes not allowed as members of an extension type, use a pointer or reference instead")
entry = self.declare(name, cname, type, pos, visibility) entry = self.declare(name, cname, type, pos, visibility)
entry.is_variable = 1 entry.is_variable = 1
self.var_entries.append(entry) self.var_entries.append(entry)
...@@ -1420,6 +1505,63 @@ class CClassScope(ClassScope): ...@@ -1420,6 +1505,63 @@ class CClassScope(ClassScope):
entry.is_inherited = 1 entry.is_inherited = 1
class CppClassScope(Scope):
# Namespace of a C++ class.
inherited_var_entries = []
def __init__(self, name, outer_scope):
Scope.__init__(self, name, outer_scope, None)
self.directives = outer_scope.directives
def declare_var(self, name, type, pos,
cname = None, visibility = 'extern', is_cdef = 0, allow_pyobject = 0):
# Add an entry for an attribute.
if not cname:
cname = name
if type.is_cfunction:
type = PyrexTypes.CPtrType(type)
entry = self.declare(name, cname, type, pos, visibility)
entry.is_variable = 1
self.var_entries.append(entry)
if type.is_pyobject and not allow_pyobject:
error(pos,
"C++ class member cannot be a Python object")
return entry
def declare_cfunction(self, name, type, pos,
cname = None, visibility = 'extern', defining = 0,
api = 0, in_pxd = 0, modifiers = ()):
if name == self.name.split('::')[-1] and cname is None:
name = '<init>'
entry = self.declare_var(name, type, pos, cname, visibility)
def declare_inherited_cpp_attributes(self, base_scope):
# Declare entries for all the C++ attributes of an
# inherited type, with cnames modified appropriately
# to work with this type.
for base_entry in \
base_scope.inherited_var_entries + base_scope.var_entries:
entry = self.declare(base_entry.name, base_entry.cname,
base_entry.type, None, 'extern')
entry.is_variable = 1
self.inherited_var_entries.append(entry)
for base_entry in base_scope.cfunc_entries:
entry = self.declare_cfunction(base_entry.name, base_entry.type,
base_entry.pos, base_entry.cname,
base_entry.visibility, base_entry.func_modifiers)
entry.is_inherited = 1
def specialize(self, values):
scope = CppClassScope(self.name, self.outer_scope)
for entry in self.entries.values():
scope.declare_var(entry.name,
entry.type.specialize(values),
entry.pos,
entry.cname,
entry.visibility)
return scope
class PropertyScope(Scope): class PropertyScope(Scope):
# Scope holding the __get__, __set__ and __del__ methods for # Scope holding the __get__, __set__ and __del__ methods for
# a property of an extension type. # a property of an extension type.
...@@ -1479,3 +1621,7 @@ static PyObject* __Pyx_Method_ClassMethod(PyObject *method) { ...@@ -1479,3 +1621,7 @@ static PyObject* __Pyx_Method_ClassMethod(PyObject *method) {
return NULL; return NULL;
} }
""") """)
#------------------------------------------------------------------------------------
ERR_BUF_LOCALONLY = 'Buffer types only allowed as function local variables'
...@@ -21,7 +21,7 @@ class TestBufferParsing(CythonTest): ...@@ -21,7 +21,7 @@ class TestBufferParsing(CythonTest):
def test_basic(self): def test_basic(self):
t = self.parse(u"cdef object[float, 4, ndim=2, foo=foo] x") t = self.parse(u"cdef object[float, 4, ndim=2, foo=foo] x")
bufnode = t.stats[0].base_type bufnode = t.stats[0].base_type
self.assert_(isinstance(bufnode, CBufferAccessTypeNode)) self.assert_(isinstance(bufnode, TemplatedTypeNode))
self.assertEqual(2, len(bufnode.positional_args)) self.assertEqual(2, len(bufnode.positional_args))
# print bufnode.dump() # print bufnode.dump()
# should put more here... # should put more here...
...@@ -65,7 +65,7 @@ class TestBufferOptions(CythonTest): ...@@ -65,7 +65,7 @@ class TestBufferOptions(CythonTest):
vardef = root.stats[0].body.stats[0] vardef = root.stats[0].body.stats[0]
assert isinstance(vardef, CVarDefNode) # use normal assert as this is to validate the test code assert isinstance(vardef, CVarDefNode) # use normal assert as this is to validate the test code
buftype = vardef.base_type buftype = vardef.base_type
self.assert_(isinstance(buftype, CBufferAccessTypeNode)) self.assert_(isinstance(buftype, TemplatedTypeNode))
self.assert_(isinstance(buftype.base_type_node, CSimpleBaseTypeNode)) self.assert_(isinstance(buftype.base_type_node, CSimpleBaseTypeNode))
self.assertEqual(u"object", buftype.base_type_node.name) self.assertEqual(u"object", buftype.base_type_node.name)
return buftype return buftype
......
...@@ -100,7 +100,8 @@ class TreeVisitor(BasicVisitor): ...@@ -100,7 +100,8 @@ class TreeVisitor(BasicVisitor):
def dump_node(self, node, indent=0): def dump_node(self, node, indent=0):
ignored = list(node.child_attrs) + [u'child_attrs', u'pos', ignored = list(node.child_attrs) + [u'child_attrs', u'pos',
u'gil_message', u'subexprs'] u'gil_message', u'cpp_message',
u'subexprs']
values = [] values = []
pos = node.pos pos = node.pos
if pos: if pos:
......
cdef extern from "<vector>" namespace std:
cdef cppclass vector[TYPE]:
#constructors
__init__()
__init__(vector&)
__init__(int)
__init__(int, TYPE&)
__init__(iterator, iterator)
#operators
TYPE& __getitem__(int)
TYPE& __setitem__(int, TYPE&)
vector __new__(vector&)
bool __eq__(vector&, vector&)
bool __ne__(vector&, vector&)
bool __lt__(vector&, vector&)
bool __gt__(vector&, vector&)
bool __le__(vector&, vector&)
bool __ge__(vector&, vector&)
#others
void assign(int, TYPE)
#void assign(iterator, iterator)
TYPE& at(int)
TYPE& back()
iterator begin()
int capacity()
void clear()
bool empty()
iterator end()
iterator erase(iterator)
iterator erase(iterator, iterator)
TYPE& front()
iterator insert(iterator, TYPE&)
void insert(iterator, int, TYPE&)
void insert(iterator, iterator)
int max_size()
void pop_back()
void push_back(TYPE&)
iterator rbegin()
iterator rend()
void reserve(int)
void resize(int)
void resize(int, TYPE&) #void resize(size_type num, const TYPE& = TYPE())
int size()
void swap(container&)
cdef extern from "<deque>" namespace std:
cdef cppclass deque[TYPE]:
#constructors
__init__()
__init__(deque&)
__init__(int)
__init__(int, TYPE&)
__init__(iterator, iterator)
#operators
TYPE& operator[]( size_type index );
const TYPE& operator[]( size_type index ) const;
deque __new__(deque&);
bool __eq__(deque&, deque&);
bool __ne__(deque&, deque&);
bool __lt__(deque&, deque&);
bool __gt__(deque&, deque&);
bool __le__(deque&, deque&);
bool __ge__(deque&, deque&);
#others
void assign(int, TYPE&)
void assign(iterator, iterator)
TYPE& at(int)
TYPE& back()
iterator begin()
void clear()
bool empty()
iterator end()
iterator erase(iterator)
iterator erase(iterator, iterator)
TYPE& front()
iterator insert(iterator, TYPE&)
void insert(iterator, int, TYPE&)
void insert(iterator, iterator, iterator)
int max_size()
void pop_back()
void pop_front()
void push_back(TYPE&)
void push_front(TYPE&)
iterator rbegin()
iterator rend()
void resize(int)
void resize(int, TYPE&)
int size()
void swap(container&)
...@@ -13,10 +13,11 @@ except: ...@@ -13,10 +13,11 @@ except:
ext_modules=[ ext_modules=[
Extension("primes", ["primes.pyx"]), Extension("primes", ["primes.pyx"]),
Extension("spam", ["spam.pyx"]), Extension("spam", ["spam.pyx"]),
Extension("square", ["square.pyx"], language="c++"),
] ]
for file in glob.glob("*.pyx"): for file in glob.glob("*.pyx"):
if file != "numeric_demo.pyx": if file != "numeric_demo.pyx" and file != "square.pyx":
ext_modules.append(Extension(file[:-4], [file], include_dirs = numpy_include_dirs)) ext_modules.append(Extension(file[:-4], [file], include_dirs = numpy_include_dirs))
setup( setup(
......
...@@ -649,7 +649,7 @@ class FileListExcluder: ...@@ -649,7 +649,7 @@ class FileListExcluder:
self.excludes[line.split()[0]] = True self.excludes[line.split()[0]] = True
def __call__(self, testname): def __call__(self, testname):
return testname.split('.')[-1] in self.excludes return testname in self.excludes or testname.split('.')[-1] in self.excludes
if __name__ == '__main__': if __name__ == '__main__':
from optparse import OptionParser from optparse import OptionParser
......
...@@ -8,3 +8,5 @@ unsignedbehaviour_T184 ...@@ -8,3 +8,5 @@ unsignedbehaviour_T184
missing_baseclass_in_predecl_T262 missing_baseclass_in_predecl_T262
cfunc_call_tuple_args_T408 cfunc_call_tuple_args_T408
cascaded_list_unpacking_T467 cascaded_list_unpacking_T467
compile.cpp_operators
cpp_nested_templates
cdef extern from "operators.h":
cdef cppclass Operators:
Operators(int)
Operators operator+(Operators)
Operators __add__(Operators, Operators)
Operators __sub__(Operators, Operators)
Operators __mul__(Operators, Operators)
Operators __div__(Operators, Operators)
bool __lt__(Operators, Operators)
bool __le__(Operators, Operators)
bool __eq__(Operators, Operators)
bool __ne__(Operators, Operators)
bool __gt__(Operators, Operators)
bool __ge__(Operators, Operators)
Operators __rshift__(Operators, int)
Operators __lshift__(Operators, int)
Operators __mod__(Operators, int)
cdef int v = 10
cdef Operators a
cdef Operators b
cdef Operators c
c = a + b
c = a - b
c = a * b
c = a / b
c = a << 2
c = a >> 1
c = b % 2
a < b
a <= b
a == b
a != b
a > b
a >= b
cdef extern from "templates.h":
cdef cppclass TemplateTest1[T]:
TemplateTest1()
T value
int t
T getValue()
cdef cppclass TemplateTest2[T, U]:
TemplateTest2()
T value1
U value2
T getValue1()
U getValue2()
cdef TemplateTest1[int] a
cdef TemplateTest1[int]* b = new TemplateTest1[int]()
cdef int c = a.getValue()
c = b.getValue()
cdef TemplateTest2[int, char] d
cdef TemplateTest2[int, char]* e = new TemplateTest2[int, char]()
c = d.getValue1()
c = e.getValue2()
cdef char f = d.getValue2()
f = e.getValue2()
del b, e
#ifndef _OPERATORS_H_
#define _OPERATORS_H_
class Operators
{
public:
int value;
Operators() { }
Operators(int value) { this->value = value; }
virtual ~Operators() { }
Operators operator+(Operators f) { return Operators(this->value + f.value); }
Operators operator-(Operators f) { return Operators(this->value - f.value); }
Operators operator*(Operators f) { return Operators(this->value * f.value); }
Operators operator/(Operators f) { return Operators(this->value / f.value); }
bool operator<(Operators f) { return this->value < f.value; }
bool operator<=(Operators f) { return this->value <= f.value; }
bool operator==(Operators f) { return this->value == f.value; }
bool operator!=(Operators f) { return this->value != f.value; }
bool operator>(Operators f) { return this->value > f.value; }
bool operator>=(Operators f) { return this->value >= f.value; }
Operators operator>>(int v) { return Operators(this->value >> v); }
Operators operator<<(int v) { return Operators(this->value << v); }
Operators operator%(int v) { return Operators(this->value % v); }
};
#endif
#ifndef _TEMPLATES_H_
#define _TEMPLATES_H_
template<class T>
class TemplateTest1
{
public:
T value;
int t;
TemplateTest1() { }
T getValue() { return value; }
};
template<class T, class U>
class TemplateTest2
{
public:
T value1;
U value2;
TemplateTest2() { }
T getValue1() { return value1; }
U getValue2() { return value2; }
};
#endif
...@@ -12,8 +12,8 @@ def f(): ...@@ -12,8 +12,8 @@ def f():
cdef object[int, 2, well] buf6 cdef object[int, 2, well] buf6
_ERRORS = u""" _ERRORS = u"""
1:11: Buffer types only allowed as function local variables 1:17: Buffer types only allowed as function local variables
3:15: Buffer types only allowed as function local variables 3:21: Buffer types only allowed as function local variables
6:27: "fakeoption" is not a buffer option 6:27: "fakeoption" is not a buffer option
""" """
#TODO: #TODO:
......
...@@ -12,7 +12,7 @@ def f(a): ...@@ -12,7 +12,7 @@ def f(a):
del s.m # error: deletion of non-Python object del s.m # error: deletion of non-Python object
_ERRORS = u""" _ERRORS = u"""
8:6: Cannot assign to or delete this 8:6: Cannot assign to or delete this
9:45: Deletion of non-Python object 9:45: Deletion of non-Python, non-C++ object
11:6: Deletion of non-Python object 11:6: Deletion of non-Python, non-C++ object
12:6: Deletion of non-Python object 12:6: Deletion of non-Python, non-C++ object
""" """
__doc__ = u"""
>>> test_new_del()
(2, 2)
>>> test_rect_area(3, 4)
12.0
>>> test_square_area(15)
(225.0, 225.0)
"""
cdef extern from "shapes.h" namespace shapes:
cdef cppclass Shape:
float area()
cdef cppclass Circle(Shape):
int radius
Circle(int)
cdef cppclass Rectangle(Shape):
int width
int height
Rectangle(int, int)
cdef cppclass Square(Rectangle):
int side
Square(int)
int constructor_count, destructor_count
def test_new_del():
cdef Rectangle *rect = new Rectangle(10, 20)
cdef Circle *circ = new Circle(15)
del rect, circ
return constructor_count, destructor_count
def test_rect_area(w, h):
cdef Rectangle *rect = new Rectangle(w, h)
try:
return rect.area()
finally:
del rect
def test_square_area(w):
cdef Square *sqr = new Square(w)
cdef Rectangle *rect = sqr
try:
return rect.area(), sqr.area()
finally:
del sqr
cdef double get_area(Rectangle s):
return s.area()
def test_value_call(int w):
"""
>>> test_value_call(5)
(25.0, 25.0)
"""
cdef Square *sqr = new Square(w)
cdef Rectangle *rect = sqr
try:
return get_area(sqr[0]), get_area(rect[0])
finally:
del sqr
from cython import dereference as deref
cdef extern from "cpp_templates_helper.h":
cdef cppclass Wrap[T]:
Wrap(T)
void set(T)
T get()
bint operator==(Wrap[T])
cdef cppclass Pair[T1,T2]:
Pair(T1,T2)
T1 first()
T2 second()
bint operator==(Pair[T1,T2])
bint operator!=(Pair[T1,T2])
def test_wrap_pair(int i, double x):
"""
>>> test_wrap_pair(1, 1.5)
(1, 1.5, True, False)
>>> test_wrap_pair(2, 2.25)
(2, 2.25, True, False)
"""
cdef Pair[int, double] *pair
cdef Wrap[Pair[int, double]] *wrap
try:
pair = new Pair[int, double](i, x)
warp = new Wrap[Pair[int, double]](deref(pair))
return wrap.get().first(), wrap.get().second(), deref(wrap) == deref(wrap)
finally:
del pair, wrap
cimport cython.operator
from cython.operator cimport dereference as deref
cdef extern from "cpp_operators_helper.h":
cdef cppclass TestOps:
char* operator+()
char* operator-()
char* operator*()
char* operator~()
char* operator++()
char* operator--()
char* operator++(int)
char* operator--(int)
char* operator+(int)
char* operator-(int)
char* operator*(int)
char* operator/(int)
char* operator%(int)
char* operator|(int)
char* operator&(int)
char* operator^(int)
char* operator<<(int)
char* operator>>(int)
char* operator==(int)
char* operator!=(int)
char* operator>=(int)
char* operator<=(int)
char* operator>(int)
char* operator<(int)
char* operator[](int)
char* operator()(int)
def test_unops():
"""
>>> test_unops()
unary +
unary -
unary ~
unary *
"""
cdef TestOps* t = new TestOps()
print +t[0]
print -t[0]
print ~t[0]
print deref(t[0])
del t
def test_incdec():
"""
>>> test_incdec()
unary ++
unary --
post ++
post --
"""
cdef TestOps* t = new TestOps()
print cython.operator.preincrement(t[0])
print cython.operator.predecrement(t[0])
print cython.operator.postincrement(t[0])
print cython.operator.postdecrement(t[0])
del t
def test_binop():
"""
>>> test_binop()
binary +
binary -
binary *
binary /
binary %
binary &
binary |
binary ^
binary <<
binary >>
"""
cdef TestOps* t = new TestOps()
print t[0] + 1
print t[0] - 1
print t[0] * 1
print t[0] / 1
print t[0] % 1
print t[0] & 1
print t[0] | 1
print t[0] ^ 1
print t[0] << 1
print t[0] >> 1
del t
def test_cmp():
"""
>>> test_cmp()
binary ==
binary !=
binary >=
binary >
binary <=
binary <
"""
cdef TestOps* t = new TestOps()
print t[0] == 1
print t[0] != 1
print t[0] >= 1
print t[0] > 1
print t[0] <= 1
print t[0] < 1
del t
def test_index_call():
"""
>>> test_index_call()
binary []
binary ()
"""
cdef TestOps* t = new TestOps()
print t[0][100]
print t[0](100)
del t
#define UN_OP(op) const char* operator op () { return "unary "#op; }
#define POST_UN_OP(op) const char* operator op (int x) { return "post "#op; }
#define BIN_OP(op) const char* operator op (int x) { return "binary "#op; }
class TestOps {
public:
UN_OP(-);
UN_OP(+);
UN_OP(*);
UN_OP(~);
UN_OP(!);
UN_OP(&);
UN_OP(++);
UN_OP(--);
POST_UN_OP(++);
POST_UN_OP(--);
BIN_OP(+);
BIN_OP(-);
BIN_OP(*);
BIN_OP(/);
BIN_OP(%);
BIN_OP(<<);
BIN_OP(>>);
BIN_OP(|);
BIN_OP(&);
BIN_OP(^);
BIN_OP(==);
BIN_OP(!=);
BIN_OP(<=);
BIN_OP(<);
BIN_OP(>=);
BIN_OP(>);
BIN_OP([]);
BIN_OP(());
};
__doc__ = u"""
>>> test_vector([1,10,100])
1
10
100
"""
cdef extern from "vector" namespace std:
cdef cppclass iterator[T]:
pass
cdef cppclass vector[T]:
#constructors
__init__()
T at(int)
void push_back(T t)
void assign(int, T)
void clear()
iterator end()
iterator begin()
int size()
def test_vector(L):
cdef vector[int] *V = new vector[int]()
for a in L:
V.push_back(a)
cdef int i
for i in range(len(L)):
print V.at(i)
del V
cdef extern from "<vector>" namespace std:
cdef cppclass vector[T]:
void push_back(T)
size_t size()
T operator[](size_t)
def simple_test(double x):
"""
>>> simple_test(55)
3
"""
cdef vector[double] *v
try:
v = new vector[double]()
v.push_back(1.0)
v.push_back(x)
from math import pi
v.push_back(pi)
return v.size()
finally:
del v
def list_test(L):
"""
>>> list_test([1,2,4,8])
(4, 4)
>>> list_test([])
(0, 0)
>>> list_test([-1] * 1000)
(1000, 1000)
"""
cdef vector[int] *v
try:
v = new vector[int]()
for a in L:
v.push_back(a)
return len(L), v.size()
finally:
del v
def index_test(L):
"""
>>> index_test([1,2,4,8])
(1.0, 8.0)
>>> index_test([1.25])
(1.25, 1.25)
"""
cdef vector[double] *v
try:
v = new vector[double]()
for a in L:
v.push_back(a)
return v[0][0], v[0][len(L)-1]
finally:
del v
from cython.operator import dereference as deref
cdef extern from "cpp_templates_helper.h":
cdef cppclass Wrap[T]:
Wrap(T)
void set(T)
T get()
bint operator==(Wrap[T])
cdef cppclass Pair[T1,T2]:
Pair(T1,T2)
T1 first()
T2 second()
bint operator==(Pair[T1,T2])
bint operator!=(Pair[T1,T2])
def test_int(int x, int y):
"""
>>> test_int(3, 4)
(3, 4, False)
>>> test_int(100, 100)
(100, 100, True)
"""
cdef Wrap[int] *a, *b
try:
a = new Wrap[int](x)
b = new Wrap[int](0)
b.set(y)
return a.get(), b.get(), a[0] == b[0]
finally:
del a, b
def test_double(double x, double y):
"""
>>> test_double(3, 3.5)
(3.0, 3.5, False)
>>> test_double(100, 100)
(100.0, 100.0, True)
"""
cdef Wrap[double] *a, *b
try:
a = new Wrap[double](x)
b = new Wrap[double](-1)
b.set(y)
return a.get(), b.get(), deref(a) == deref(b)
finally:
del a, b
def test_pair(int i, double x):
"""
>>> test_pair(1, 1.5)
(1, 1.5, True, False)
>>> test_pair(2, 2.25)
(2, 2.25, True, False)
"""
cdef Pair[int, double] *pair
try:
pair = new Pair[int, double](i, x)
return pair.first(), pair.second(), deref(pair) == deref(pair), deref(pair) != deref(pair)
finally:
del pair
template <class T>
class Wrap {
T value;
public:
Wrap(T v) { value = v; }
void set(T v) { value = v; }
T get(void) { return value; }
bool operator==(Wrap<T> other) { return value == other.value; }
};
template <class T1, class T2>
class Pair {
T1 _first;
T2 _second;
public:
Pair(T1 u, T2 v) { _first = u; _second = v; }
T1 first(void) { return _first; }
T2 second(void) { return _second; }
bool operator==(Pair<T1,T2> other) { return _first == other._first && _second == other._second; }
bool operator!=(Pair<T1,T2> other) { return _first != other._first || _second != other._second; }
};
cdef extern from *:
int new(int new)
def new(x):
"""
>>> new(3)
3
"""
cdef int new = x
return new
def x(new):
"""
>>> x(10)
110
>>> x(1)
1
"""
if new*new != new:
return new + new**2
return new
class A:
def new(self, n):
"""
>>> a = A()
>>> a.new(3)
6
>>> a.new(5)
120
"""
if n <= 1:
return 1
else:
return n * self.new(n-1)
#ifndef SHAPES_H
#define SHAPES_H
namespace shapes {
int constructor_count = 0;
int destructor_count = 0;
class Shape
{
public:
virtual float area() = 0;
Shape() { constructor_count++; }
virtual ~Shape() { destructor_count++; }
};
class Rectangle : public Shape
{
public:
Rectangle(int width, int height)
{
this->width = width;
this->height = height;
}
float area() { return width * height; }
int width;
int height;
};
class Square : public Rectangle
{
public:
Square(int side) : Rectangle(side, side) { this->side = side; }
int side;
};
class Circle : public Shape {
public:
Circle(int radius) { this->radius = radius; }
float area() { return 3.1415926535897931f * radius; }
int radius;
};
}
#endif
cimport cython
def test_deref(int x):
"""
>>> test_deref(3)
3
>>> test_deref(5)
5
"""
cdef int* x_ptr = &x
return cython.dereference(x_ptr)
def increment_decrement(int x):
"""
>>> increment_decrement(10)
11 11 12
11 11 10
10
"""
print cython.preincrement(x), cython.postincrement(x), x
print cython.predecrement(x), cython.postdecrement(x), x
return x
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment