Commit 848fbbf3 authored by gsamain's avatar gsamain Committed by Xavier Thompson

Cypclass constructor wrapper

parent 76091374
...@@ -5579,15 +5579,40 @@ class CallNode(ExprNode): ...@@ -5579,15 +5579,40 @@ class CallNode(ExprNode):
elif type and type.is_cpp_class: elif type and type.is_cpp_class:
self.args = [ arg.analyse_types(env) for arg in self.args ] self.args = [ arg.analyse_types(env) for arg in self.args ]
constructor = type.scope.lookup("<init>") constructor = type.scope.lookup("<init>")
if not constructor: constructor_type = None
constructor_cname = None
if type.is_cyp_class:
constructor = wrapper = type.scope.lookup_here("<constructor>")
if not wrapper:
error(self.function.pos, "no constructor wrapper found for Cypclass type '%s'" % self.function.name)
namespace_list = wrapper.func_cname.split('::')
templates = ''
if type.templates:
templates = '<' + ','.join([param.declaration_code('')
for param in type.templates
if not PyrexTypes.is_optional_template_param(param) and not param.is_fused]) + '>'
if len(namespace_list) > 2:
# We do this because cypclass wrappers are outside of the class namespace
# in the C++ code, but they are declared within the class scope
constructor_cname = '::'.join(namespace_list[:-2] + [namespace_list[-1]]) + templates
else:
constructor_cname = namespace_list[-1] + templates
constructor_type = wrapper.type
elif not constructor:
error(self.function.pos, "no constructor found for C++ type '%s'" % self.function.name) error(self.function.pos, "no constructor found for C++ type '%s'" % self.function.name)
self.type = error_type self.type = error_type
return self return self
self.function = RawCNameExprNode(self.function.pos, constructor.type) else:
constructor_type = constructor.type
constructor_cname = type.empty_declaration_code()
self.function = RawCNameExprNode(self.function.pos, constructor_type)
self.function.entry = constructor self.function.entry = constructor
self.function.set_cname(type.empty_declaration_code()) self.function.set_cname(constructor_cname)
self.analyse_c_function_call(env) self.analyse_c_function_call(env)
self.type = type if type.is_cyp_class:
self.type = constructor_type.return_type
else:
self.type = type
return True return True
def is_lvalue(self): def is_lvalue(self):
......
...@@ -651,6 +651,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -651,6 +651,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
type_entries.append(entry) type_entries.append(entry)
type_entries = [t for t in type_entries if t not in vtabslot_entries] type_entries = [t for t in type_entries if t not in vtabslot_entries]
self.generate_type_header_code(type_entries, code) self.generate_type_header_code(type_entries, code)
self.generate_cyp_class_wrapper_definitions(type_entries, code)
for entry in vtabslot_list: for entry in vtabslot_list:
self.generate_objstruct_definition(entry.type, code) self.generate_objstruct_definition(entry.type, code)
self.generate_typeobj_predeclaration(entry, code) self.generate_typeobj_predeclaration(entry, code)
...@@ -889,6 +890,22 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -889,6 +890,22 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
elif type.is_extension_type: elif type.is_extension_type:
self.generate_objstruct_definition(type, code) self.generate_objstruct_definition(type, code)
def generate_cyp_class_wrapper_definitions(self, type_entries, code):
for entry in type_entries:
if entry.type.is_cyp_class:
# Generate wrapper constructor
scope = entry.type.scope
wrapper = scope.lookup_here("<constructor>")
constructor = scope.lookup_here("<init>")
new = scope.lookup_here("__new__")
alloc = scope.lookup_here("<alloc>")
if not wrapper:
error(self.pos, "No constructor wrapper found for cypclass %s, did you write an __init__ method ?" % type.name)
return
for wrapper_entry in wrapper.all_alternatives():
if wrapper_entry.used or entry.type.templates:
self.generate_cyp_class_wrapper_definition(entry.type, wrapper_entry, constructor, new, alloc, code)
def generate_gcc33_hack(self, env, code): def generate_gcc33_hack(self, env, code):
# Workaround for spurious warning generation in gcc 3.3 # Workaround for spurious warning generation in gcc 3.3
code.putln("") code.putln("")
...@@ -973,6 +990,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -973,6 +990,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.mark_pos(entry.pos) code.mark_pos(entry.pos)
type = entry.type type = entry.type
scope = type.scope scope = type.scope
default_constructor = False
if scope: if scope:
if type.templates: if type.templates:
code.putln("template <class %s>" % ", class ".join( code.putln("template <class %s>" % ", class ".join(
...@@ -996,8 +1014,6 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -996,8 +1014,6 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
has_virtual_methods = False has_virtual_methods = False
constructor = None constructor = None
destructor = None destructor = None
if type.is_cyp_class:
has_virtual_methods = True
for attr in scope.var_entries: for attr in scope.var_entries:
cname = attr.cname cname = attr.cname
if attr.type.is_cfunction and attr.type.is_static_method: if attr.type.is_cfunction and attr.type.is_static_method:
...@@ -1030,7 +1046,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1030,7 +1046,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
else: else:
code.putln("%s(%s);" % (type.cname, ", ".join(arg_decls))) code.putln("%s(%s);" % (type.cname, ", ".join(arg_decls)))
if constructor or py_attrs: if not type.is_cyp_class and (constructor or py_attrs):
if constructor: if constructor:
for constructor_alternative in constructor.all_alternatives(): for constructor_alternative in constructor.all_alternatives():
arg_decls = [] arg_decls = []
...@@ -1095,12 +1111,158 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1095,12 +1111,158 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
else: else:
code.putln("%s(const %s& __Pyx_other);" % (type.cname, type.cname)) code.putln("%s(const %s& __Pyx_other);" % (type.cname, type.cname))
code.putln("%s& operator=(const %s& __Pyx_other);" % (type.cname, type.cname)) code.putln("%s& operator=(const %s& __Pyx_other);" % (type.cname, type.cname))
if type.is_cyp_class:
code.putln("// Auto generating default constructor to have Python-like behaviour")
code.putln("%s(){}" % type.cname)
code.putln("// Generating __alloc__ function (used for __new__ calls)")
alloc_entry = scope.lookup_here("<alloc>")
code.putln("static %s { return new %s(); }" % (alloc_entry.type.declaration_code(alloc_entry.cname), type.declaration_code("", deref=1)))
code.putln("};") code.putln("};")
if type.is_cyp_class: if type.is_cyp_class:
code.globalstate.use_utility_code( code.globalstate.use_utility_code(
UtilityCode.load("CyObjects", "CyObjects.cpp", proto_block="utility_code_proto_before_types")) UtilityCode.load("CyObjects", "CyObjects.cpp", proto_block="utility_code_proto_before_types"))
def generate_cyp_class_wrapper_definition(self, type, wrapper_entry, constructor_entry, new_entry, alloc_entry, code):
if type.templates:
code.putln("template <typename %s>" % ", class ".join(
[T.empty_declaration_code() for T in type.templates]))
init_entry = constructor_entry
self_type = wrapper_entry.type.return_type.declaration_code('')
type_string = type.empty_declaration_code()
class_name = type.name
wrapper_cname = "%s__constructor__%s" % (Naming.func_prefix, class_name)
wrapper_type = wrapper_entry.type
arg_decls = []
arg_names = []
for arg in wrapper_type.args[:len(wrapper_type.args)-wrapper_type.optional_arg_count]:
arg_decl = arg.declaration_code()
arg_decls.append(arg_decl)
arg_names.append(arg.cname)
if wrapper_type.optional_arg_count:
arg_decls.append(wrapper_type.op_arg_struct.declaration_code(Naming.optional_args_cname))
arg_names.append(Naming.optional_args_cname)
if wrapper_type.has_varargs:
# We can't safely handle varargs because we need
# to know where the size argument is to start a va_list
error(wrapper_entry.pos,
"Cypclass cannot handle variable arguments constructors, but you can use optional arguments (arg=some_value)")
if not arg_decls:
arg_decls = ["void"]
decl_arg_string = ', '.join(arg_decls)
code.putln("static %s %s(%s)" % (self_type, wrapper_cname, decl_arg_string))
code.putln("{")
wrapper_arg_types = [arg.type for arg in wrapper_entry.type.args]
pos = wrapper_entry.pos or type.entry.pos
if new_entry:
alloc_type = alloc_entry.type
new_arg_types = [alloc_type] + wrapper_arg_types
new_entry = PyrexTypes.best_match(new_arg_types,
new_entry.all_alternatives(), pos)
if new_entry:
alloc_call_string = "(" + new_entry.type.original_alloc_type.type.declaration_code("") + ") %s" % alloc_entry.func_cname
new_arg_names = [alloc_call_string] + arg_names
new_arg_string = ', '.join(new_arg_names)
code.putln("%s self =(%s) %s(%s);" % (self_type, self_type, new_entry.func_cname, new_arg_string))
else:
code.putln("%s self = new %s();" % (self_type, type_string))
if init_entry:
init_entry = PyrexTypes.best_match(wrapper_arg_types,
init_entry.all_alternatives(), None)
if init_entry and (not new_entry or new_entry.type.return_type == type):
# Calling __init__
max_init_nargs = len(init_entry.type.args)
min_init_nargs = max_init_nargs - init_entry.type.optional_arg_count
max_wrapper_nargs = len(wrapper_entry.type.args)
min_wrapper_nargs = max_wrapper_nargs - wrapper_entry.type.optional_arg_count
if min_init_nargs == min_wrapper_nargs:
# The optional arguments begin at the same rank for both function
# => just pass the wrapper opt args structure, and everything will be fine.
if max_wrapper_nargs > min_wrapper_nargs:
# The wrapper has optional args
arg_names[-1] = "(%s) %s" % (init_entry.type.op_arg_struct.declaration_code(''), arg_names[-1])
elif max_init_nargs > min_init_nargs:
# The wrapper has no optional args but the __init__ function does
arg_names.append("(%s) NULL" % init_entry.type.op_arg_struct.declaration_code(''))
# else, neither __init__ nor __new__ have optional arguments, nothing to do
elif min_wrapper_nargs < min_init_nargs:
# It means some args from the wrapper should be at
# their default values, which we cannot know from here,
# so shout and stop, sadly.
error(init_entry.pos, "Could not call this __init__ function because the corresponding __new__ wrapper isn't aware of default values")
error(wrapper_entry.pos, "Wrapped __new__ is here (some args passed to __init__ could be at their default values)")
elif min_wrapper_nargs > min_init_nargs:
# Here, the __init__ optional arguments start before
# the __new__ ones. We have to unpack the __new__ opt args struct
# in some variables and then repack in the __init__ opt args struct.
init_opt_args_name_list = [arg.cname for arg in wrapper_entry.type.args[min_init_nargs:]]
# The first __init__ optional arguments are mandatory
# in the __new__ signature, so they will always appear
# in the __init__ optional arguments structure
init_opt_args_number = "init_opt_n"
code.putln("int %s = %s;" % (init_opt_args_number, min_wrapper_nargs - min_init_nargs))
if wrapper_entry.type.optional_arg_count:
for i, arg in enumerate(wrapper_entry.type.args[min_wrapper_nargs:]):
# It's an opt arg => it's not declared in the (c++) function scope => declare a variable for it
arg_name = arg.cname
code.putln("%s;" % arg.type.declaration_code(arg_name))
# Arguments unpacking
optional_struct_name = arg_names.pop()
code.putln("if (%s) {" % optional_struct_name)
# This is necessary to keep __init__ informed of
# how many optional arguments were explicitely given
code.putln("%s += %s->%sn;" % (init_opt_args_number, optional_struct_name, Naming.pyrex_prefix))
braces_number = 1 + max_wrapper_nargs - min_wrapper_nargs
for i, arg in enumerate(wrapper_entry.type.args[min_wrapper_nargs:]):
code.putln("if(%s->%sn > %s) {" % (optional_struct_name, Naming.pyrex_prefix, i))
code.putln("%s = %s->%s;" % (
arg.cname,
optional_struct_name,
wrapper_entry.type.op_arg_struct.base_type.scope.var_entries[i+1].cname
))
for _ in range(braces_number):
code.putln('}')
# Arguments packing
init_opt_args_struct_name = "init_opt_args"
code.putln("%s;" % init_entry.type.op_arg_struct.base_type.declaration_code(init_opt_args_struct_name))
code.putln("%s.%sn = %s;" % (init_opt_args_struct_name, Naming.pyrex_prefix, init_opt_args_number))
for i, arg_name in enumerate(init_opt_args_name_list):
# The second tuple member is a bit tricky.
# Actually, the only way we have to precisely know the attribute cname
# which corresponds to the argument in the opt args struct
# is to rely on the declaration order in the struct scope.
# FuncDefNode doesn't do this because it has it's declarator node,
# which is not our case here.
code.putln("%s.%s = %s;" % (
init_opt_args_struct_name,
init_entry.type.opt_arg_cname(init_entry.type.args[min_init_nargs+i].name),
arg_name
))
arg_names = arg_names[:min_init_nargs] + ["&"+init_opt_args_struct_name]
init_arg_string = ','.join(arg_names)
code.putln("self->%s(%s);" % (init_entry.cname, init_arg_string))
code.putln("return self;")
code.putln("}")
def generate_enum_definition(self, entry, code): def generate_enum_definition(self, entry, code):
code.mark_pos(entry.pos) code.mark_pos(entry.pos)
type = entry.type type = entry.type
......
...@@ -680,7 +680,28 @@ class Scope(object): ...@@ -680,7 +680,28 @@ class Scope(object):
if scope: if scope:
entry.type.set_scope(scope) entry.type.set_scope(scope)
declare_inherited_attributes(entry, base_classes) declare_inherited_attributes(entry, base_classes)
scope.declare_var(name="this", cname="this", type=PyrexTypes.CPtrType(entry.type), pos=entry.pos) this_type = PyrexTypes.CPtrType(entry.type) if not cypclass else entry.type
scope.declare_var(name="this", cname="this", type=this_type, pos=entry.pos)
if cypclass:
# Declare a shadow default constructor
wrapper_type = PyrexTypes.CFuncType(entry.type, [], nogil=1)
wrapper_cname = "%s__constructor__%s" % (Naming.func_prefix, name)
wrapper_name = "<constructor>"
wrapper_entry = scope.declare(wrapper_name, wrapper_cname, wrapper_type, pos, visibility)
wrapper_type.entry = wrapper_entry
wrapper_entry.is_inherited = 1
wrapper_entry.is_cfunction = 1
wrapper_entry.func_cname = "%s::%s" % (entry.type.empty_declaration_code(), wrapper_cname)
# Declare the default __alloc__ method
alloc_type = wrapper_type
alloc_cname = "%s__alloc__%s" % (Naming.func_prefix, name)
alloc_name = "<alloc>"
alloc_entry = scope.declare(alloc_name, alloc_cname, alloc_type, pos, visibility)
alloc_type.entry = alloc_entry
alloc_entry.is_cfunction = 1
alloc_entry.func_cname = "%s::%s" % (entry.type.empty_declaration_code(), alloc_cname)
if self.is_cpp_class_scope: if self.is_cpp_class_scope:
entry.type.namespace = self.outer_scope.lookup(self.name).type entry.type.namespace = self.outer_scope.lookup(self.name).type
return entry return entry
...@@ -2543,6 +2564,26 @@ class CppClassScope(Scope): ...@@ -2543,6 +2564,26 @@ class CppClassScope(Scope):
self.var_entries.append(entry) self.var_entries.append(entry)
return entry return entry
def declare_constructor_wrapper(self, args, pos, defining=0, has_varargs=0, optional_arg_count=0, op_arg_struct = None, return_type=None):
if not return_type:
return_type = self.type
class_type = self.parent_type
class_name = self.name.split('::')[-1]
wrapper_cname = "%s__constructor__%s" % (Naming.func_prefix, class_name)
wrapper_name = "<constructor>"
wrapper_type = PyrexTypes.CFuncType(return_type, args, nogil=1,
has_varargs=has_varargs,
optional_arg_count=optional_arg_count)
if op_arg_struct:
wrapper_type.op_arg_struct = op_arg_struct
wrapper_entry = self.declare(wrapper_name, wrapper_cname, wrapper_type,
pos, 'extern')
wrapper_type.entry = wrapper_entry
wrapper_entry.is_cfunction = 1
if defining:
wrapper_entry.func_cname = "%s::%s" % (class_type.empty_declaration_code(), wrapper_cname)
return wrapper_entry
def declare_cfunction(self, name, type, pos, def declare_cfunction(self, name, type, pos,
cname=None, visibility='extern', api=0, in_pxd=0, cname=None, visibility='extern', api=0, in_pxd=0,
defining=0, modifiers=(), utility_code=None, overridable=False): defining=0, modifiers=(), utility_code=None, overridable=False):
...@@ -2555,16 +2596,40 @@ class CppClassScope(Scope): ...@@ -2555,16 +2596,40 @@ class CppClassScope(Scope):
# arguments that cannot by called by value. # arguments that cannot by called by value.
type.original_args = type.args type.original_args = type.args
def maybe_ref(arg): def maybe_ref(arg):
if arg.type.is_cpp_class and not arg.type.is_reference: if arg.type.is_cpp_class and not arg.type.is_reference and not arg.type.is_cyp_class:
return PyrexTypes.CFuncTypeArg( return PyrexTypes.CFuncTypeArg(
arg.name, PyrexTypes.c_ref_type(arg.type), arg.pos) arg.name, PyrexTypes.c_ref_type(arg.type), arg.pos)
else: else:
return arg return arg
type.args = [maybe_ref(arg) for arg in type.args] type.args = [maybe_ref(arg) for arg in type.args]
if self.type.is_cyp_class and not self.lookup_here("__new__"):
self.declare_constructor_wrapper(type.args, pos, defining,
type.has_varargs, type.optional_arg_count,
getattr(type, 'op_arg_struct', None))
elif name == '__dealloc__' and cname is None: elif name == '__dealloc__' and cname is None:
cname = "%s__dealloc__%s" % (Naming.func_prefix, class_name) cname = "%s__dealloc__%s" % (Naming.func_prefix, class_name)
name = EncodedString('<del>') name = EncodedString('<del>')
type.return_type = PyrexTypes.c_void_type type.return_type = PyrexTypes.c_void_type
elif name == '__alloc__' and self.type.is_cyp_class:
cname = "%s__alloc__%s" % (Naming.func_prefix, class_name)
name = '<alloc>'
elif name == '__new__' and self.type.is_cyp_class:
if name in self.entries:
if self.entries[name].is_inherited:
del self.entries[name]
else:
error(pos, "Couldn't have more than one __new__ function")
if self.lookup_here("<constructor>"):
del self.entries["<constructor>"]
self.declare_constructor_wrapper(type.args[1:], pos, defining,
type.has_varargs, type.optional_arg_count,
getattr(type, 'op_arg_struct', None),
return_type=type.return_type)
type.original_alloc_type = type.args[0]
if name in ('<init>', '<del>') and type.nogil: if name in ('<init>', '<del>') and type.nogil:
for base in self.type.base_classes: for base in self.type.base_classes:
if base is cy_object_type: if base is cy_object_type:
...@@ -2598,14 +2663,50 @@ class CppClassScope(Scope): ...@@ -2598,14 +2663,50 @@ class CppClassScope(Scope):
# to work with this type. # to work with this type.
for base_entry in \ for base_entry in \
base_scope.inherited_var_entries + base_scope.var_entries: base_scope.inherited_var_entries + base_scope.var_entries:
base_entry_type = base_entry.type
#constructor/destructor is not inherited #constructor/destructor is not inherited
if base_entry.name in ("<init>", "<del>"): if base_entry.name == "<del>"\
or base_entry.name == "<init>" and not self.parent_type.is_cyp_class\
or base_entry.name in ("<constructor>", "<alloc>") and self.parent_type.is_cyp_class:
continue continue
elif base_entry.name == "<init>" and not self.lookup_here("__new__"):
wrapper_entry = self.declare_constructor_wrapper(base_entry_type.args, base_entry.pos,
defining=1, has_varargs = base_entry_type.has_varargs,
optional_arg_count = base_entry_type.optional_arg_count,
op_arg_struct = getattr(base_entry_type, 'op_arg_struct', None),
return_type=self.parent_type)
wrapper_entry.is_inherited = 1
#print base_entry.name, self.entries #print base_entry.name, self.entries
elif base_entry.name == "__new__" and self.parent_type.is_cyp_class:
# Rewrite first argument for __new__
alloc_type = PyrexTypes.CPtrType(PyrexTypes.CFuncType(
self.parent_type, [], nogil=1))
alloc_arg = PyrexTypes.CFuncTypeArg(base_entry.type.args[0].name, alloc_type,
base_entry.type.args[0].pos, cname=base_entry.type.args[0].cname)
base_entry_type = PyrexTypes.CFuncType(base_entry_type.return_type,
[alloc_arg] + base_entry_type.args[1:], nogil=1,
has_varargs=base_entry_type.has_varargs,
optional_arg_count=base_entry_type.optional_arg_count)
if hasattr(base_entry.type, 'op_arg_struct'):
base_entry_type.op_arg_struct = base_entry.type.op_arg_struct
base_entry_type.original_alloc_type = base_entry.type.original_alloc_type
if base_entry.name in self.entries:
del self.entries[base_entry.name]
del self.entries["<constructor>"]
elif "<init>" in self.entries:
del self.entries["<constructor>"]
wrapper_entry = self.declare_constructor_wrapper(base_entry_type.args[1:],
base_entry.pos, defining=1,
has_varargs = base_entry_type.has_varargs,
optional_arg_count = base_entry_type.optional_arg_count,
op_arg_struct = getattr(base_entry_type, 'op_arg_struct', None),
return_type=base_entry_type.return_type)
wrapper_entry.is_inherited = 1
if base_entry.name in self.entries: if base_entry.name in self.entries:
base_entry.name # FIXME: is there anything to do in this case? base_entry.name # FIXME: is there anything to do in this case?
entry = self.declare(base_entry.name, base_entry.cname, entry = self.declare(base_entry.name, base_entry.cname,
base_entry.type, None, 'extern') base_entry_type, None, 'extern')
entry.is_variable = 1 entry.is_variable = 1
entry.is_inherited = 1 entry.is_inherited = 1
entry.is_cfunction = base_entry.is_cfunction entry.is_cfunction = base_entry.is_cfunction
...@@ -2659,7 +2760,9 @@ class CppClassScope(Scope): ...@@ -2659,7 +2760,9 @@ class CppClassScope(Scope):
name = "<init>" name = "<init>"
elif name == "__dealloc__": elif name == "__dealloc__":
name = "<del>" name = "<del>"
return super(CppClassScope, self).lookup_here(name) elif name == "__alloc__":
name = "<alloc>"
return super(CppClassScope,self).lookup_here(name)
class PropertyScope(Scope): class PropertyScope(Scope):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment