Commit 943dbd7d authored by Xavier Thompson's avatar Xavier Thompson

Inject wrapper classes into the AST only after the declaration analysis phase

parent f82fb075
...@@ -93,56 +93,55 @@ underlying_name = EncodedString("nogil_cyobject") ...@@ -93,56 +93,55 @@ underlying_name = EncodedString("nogil_cyobject")
# - Insert additional cclass wrapper nodes by returning lists of nodes # - Insert additional cclass wrapper nodes by returning lists of nodes
# => must run after NormalizeTree (otherwise single statements might not be held in a list) # => must run after NormalizeTree (otherwise single statements might not be held in a list)
# #
class CypclassWrapperInjection(VisitorTransform): class CypclassWrapperInjection(CythonTransform):
""" """
Synthesize and insert a wrapper c class at the module level for each cypclass that supports it. Synthesize and insert a wrapper c class at the module level for each cypclass that supports it.
- Even nested cypclasses have their wrapper at the module level. - Even nested cypclasses have their wrapper at the module level.
- Must run after NormalizeTree. - Must run after NormalizeTree.
- The root node passed when calling this visitor should not be lower than a ModuleNode.
""" """
def __call__(self, root): def __call__(self, root):
self.cypclass_wrappers_stack = [] from .ParseTreeTransforms import AnalyseDeclarationsTransform
self.nesting_stack = [] self.analyser = AnalyseDeclarationsTransform(self.context)
self.module_scope = root.scope
return super(CypclassWrapperInjection, self).__call__(root) return super(CypclassWrapperInjection, self).__call__(root)
def visit_Node(self, node): def visit_ModuleNode(self, node):
self.cypclass_wrappers = []
self.nesting_stack = []
self.module_scope = node.scope
self.visitchildren(node) self.visitchildren(node)
self.inject_cypclass_wrappers(node)
return node return node
def inject_cypclass_wrappers(self, module_node):
fake_module_node = module_node.clone_node()
fake_module_node.body = Nodes.StatListNode(
module_node.body.pos,
stats = self.cypclass_wrappers
)
self.analyser(fake_module_node)
module_node.body.stats.extend(fake_module_node.body.stats)
# TODO: can cypclasses be nested in something other than this ? # TODO: can cypclasses be nested in something other than this ?
# can cypclasses even be nested in non-cypclass cpp classes, or structs ? # can cypclasses even be nested in non-cypclass cpp classes, or structs ?
def visit_CStructOrUnionDefNode(self, node): def visit_CStructOrUnionDefNode(self, node):
self.nesting_stack.append(node) self.nesting_stack.append(node)
self.visitchildren(node) self.visitchildren(node)
self.nesting_stack.pop() self.nesting_stack.pop()
top_level = not self.nesting_stack
if top_level:
return_nodes = [node]
for wrapper in self.cypclass_wrappers_stack:
return_nodes.append(wrapper)
self.cypclass_wrappers_stack.clear()
return return_nodes
return node return node
def visit_CppClassNode(self, node): def visit_CppClassNode(self, node):
if node.cypclass: if node.cypclass:
wrapper = self.synthesize_wrapper_cclass(node) wrapper = self.synthesize_wrapper_cclass(node)
if wrapper is not None: if wrapper is not None:
self.cypclass_wrappers_stack.append(wrapper) # forward-declare the wrapper
# visit children and return all wrappers when at the top level wrapper.declare(self.module_scope)
self.cypclass_wrappers.append(wrapper)
# visit children and keep track of nesting
return self.visit_CStructOrUnionDefNode(node) return self.visit_CStructOrUnionDefNode(node)
def find_module_scope(self, scope):
module_scope = scope
while module_scope and not module_scope.is_module_scope:
module_scope = module_scope.outer_scope
return module_scope
def iter_wrapper_methods(self, wrapper_cclass):
for node in wrapper_cclass.body.stats:
if isinstance(node, Nodes.DefNode):
yield node
def synthesize_wrapper_cclass(self, node): def synthesize_wrapper_cclass(self, node):
if node.templates: if node.templates:
# Python wrapper for templated cypclasses not supported yet # Python wrapper for templated cypclasses not supported yet
...@@ -152,39 +151,36 @@ class CypclassWrapperInjection(VisitorTransform): ...@@ -152,39 +151,36 @@ class CypclassWrapperInjection(VisitorTransform):
# whether the is declared with ':' and a suite, or just a forward declaration # whether the is declared with ':' and a suite, or just a forward declaration
node_has_suite = node.attributes is not None node_has_suite = node.attributes is not None
# TODO: forward declare wrapper classes too ?
if not node_has_suite: if not node_has_suite:
return None return None
# TODO: take nesting into account for the name # TODO: take nesting into account for the name
# TODO: check that there is no collision with another name
cclass_name = EncodedString("%s_cyp_wrapper" % node.name) cclass_name = EncodedString("%s_cyp_wrapper" % node.name)
from .ExprNodes import TupleNode from .ExprNodes import TupleNode
bases_args = [] bases_args = []
if node.base_classes: node_type = node.entry.type
first_base = node.base_classes[0] node_type.find_wrapped_base_type()
if isinstance(first_base, Nodes.CSimpleBaseTypeNode) and first_base.templates is None: first_wrapped_base = node_type.first_wrapped_base
first_base_name = first_base.name if first_wrapped_base:
builtin_entry = self.module_scope.lookup(first_base_name) first_base_wrapper_name = first_wrapped_base.wrapper_type.name
if builtin_entry is not None:
return
wrapped_first_base = Nodes.CSimpleBaseTypeNode( wrapped_first_base = Nodes.CSimpleBaseTypeNode(
first_base.pos, node.pos,
name = "%s_cyp_wrapper" % first_base_name, name = first_base_wrapper_name,
module_path = [], module_path = [],
is_basic_c_type = first_base.is_basic_c_type, is_basic_c_type = 0,
signed = first_base.signed, signed = 1,
complex = first_base.complex, complex = 0,
longness = first_base.longness, longness = 0,
is_self_arg = first_base.is_self_arg, is_self_arg = 0,
templates = None templates = None
) )
bases_args.append(wrapped_first_base) bases_args.append(wrapped_first_base)
cclass_bases = TupleNode(node.pos, args=bases_args) cclass_bases = TupleNode(node.pos, args=bases_args)
# the underlying cyobject must come first thing after PyObject_HEAD in the memory layout
# long term, only the base class will declare the underlying attribute
stats = [] stats = []
if not bases_args: if not bases_args:
underlying_cyobject = self.synthesize_underlying_cyobject_attribute(node) underlying_cyobject = self.synthesize_underlying_cyobject_attribute(node)
...@@ -211,6 +207,9 @@ class CypclassWrapperInjection(VisitorTransform): ...@@ -211,6 +207,9 @@ class CypclassWrapperInjection(VisitorTransform):
wrapped_cypclass = node, wrapped_cypclass = node,
) )
# indicate that the cypclass will have a wrapper
node.entry.type.support_wrapper = True
return wrapper return wrapper
def synthesize_underlying_cyobject_attribute(self, node): def synthesize_underlying_cyobject_attribute(self, node):
......
...@@ -5532,6 +5532,14 @@ class CypclassWrapperDefNode(CClassDefNode): ...@@ -5532,6 +5532,14 @@ class CypclassWrapperDefNode(CClassDefNode):
is_cyp_wrapper = 1 is_cyp_wrapper = 1
def declare(self, env):
# > declare the same way as a standard c class
super(CypclassWrapperDefNode, self).declare(env)
# > mark the wrapper type as such
self.entry.type.is_cyp_wrapper = 1
# > associate the wrapper type to the wrapped type
self.wrapped_cypclass.entry.type.wrapper_type = self.entry.type
def analyse_declarations(self, env): def analyse_declarations(self, env):
# > analyse declarations before inserting methods # > analyse declarations before inserting methods
super(CypclassWrapperDefNode, self).analyse_declarations(env) super(CypclassWrapperDefNode, self).analyse_declarations(env)
......
...@@ -182,7 +182,6 @@ def create_pipeline(context, mode, exclude_classes=()): ...@@ -182,7 +182,6 @@ def create_pipeline(context, mode, exclude_classes=()):
# compilation stage. # compilation stage.
stages = [ stages = [
NormalizeTree(context), NormalizeTree(context),
CypclassWrapperInjection(),
PostParse(context), PostParse(context),
_specific_post_parse, _specific_post_parse,
TrackNumpyAttributes(), TrackNumpyAttributes(),
...@@ -199,6 +198,7 @@ def create_pipeline(context, mode, exclude_classes=()): ...@@ -199,6 +198,7 @@ def create_pipeline(context, mode, exclude_classes=()):
ForwardDeclareTypes(context), ForwardDeclareTypes(context),
InjectGilHandling(), InjectGilHandling(),
AnalyseDeclarationsTransform(context), AnalyseDeclarationsTransform(context),
CypclassWrapperInjection(context),
AutoTestDictTransform(context), AutoTestDictTransform(context),
EmbedSignature(context), EmbedSignature(context),
EarlyReplaceBuiltinCalls(context), ## Necessary? EarlyReplaceBuiltinCalls(context), ## Necessary?
......
...@@ -4107,7 +4107,9 @@ def compute_mro_generic(cls): ...@@ -4107,7 +4107,9 @@ def compute_mro_generic(cls):
class CypClassType(CppClassType): class CypClassType(CppClassType):
# lock_mode string (tri-state: "nolock"/"checklock"/"autolock") # lock_mode string (tri-state: "nolock"/"checklock"/"autolock")
# _mro [CppClassType] or None the Method Resolution Order of this cypclass according to Python # _mro [CppClassType] or None the Method Resolution Order of this cypclass according to Python
# support_wrapper boolean whether this cypclass will be wrapped
# wrapper_type PyExtensionType or None the type of the cclass wrapper # wrapper_type PyExtensionType or None the type of the cclass wrapper
# first_wrapped_base CypClassType or None the first cypclass base that has a wrapper if there is one
# wrapped_base_type CypClassType or None the type of the oldest wrapped cypclass base # wrapped_base_type CypClassType or None the type of the oldest wrapped cypclass base
is_cyp_class = 1 is_cyp_class = 1
...@@ -4115,24 +4117,29 @@ class CypClassType(CppClassType): ...@@ -4115,24 +4117,29 @@ class CypClassType(CppClassType):
def __init__(self, name, scope, cname, base_classes, templates=None, template_type=None, nogil=0, lock_mode=None, activable=False): def __init__(self, name, scope, cname, base_classes, templates=None, template_type=None, nogil=0, lock_mode=None, activable=False):
CppClassType.__init__(self, name, scope, cname, base_classes, templates, template_type, nogil) CppClassType.__init__(self, name, scope, cname, base_classes, templates, template_type, nogil)
if base_classes:
self.find_wrapped_base_type(base_classes)
self.lock_mode = lock_mode if lock_mode else "autolock" self.lock_mode = lock_mode if lock_mode else "autolock"
self.activable = activable self.activable = activable
self._mro = None self._mro = None
self.support_wrapper = False
self.wrapper_type = None self.wrapper_type = None
self.wrapped_base_type = None self.wrapped_base_type = None
def find_wrapped_base_type(self, base_classes): # find the first base that has a wrapper, if there is one
first_wrapped_cypclass_base = None # find the oldest superclass such that all intervening classes have a wrapper
for base_type in base_classes: def find_wrapped_base_type(self):
if base_type.is_cyp_class and base_type.wrapper_type: # default: the oldest superclass is self and there are no bases
first_wrapped_cypclass_base = base_type
break
if first_wrapped_cypclass_base:
self.wrapped_base_type = first_wrapped_cypclass_base.wrapped_base_type
else:
self.wrapped_base_type = self self.wrapped_base_type = self
self.first_wrapped_base = None
# if there are no bases, no need to look further
if not self.base_classes:
return
# otherwise, find the first wrapped base (if there is one) and take the same oldest superclass
for base_type in self.base_classes:
if base_type.is_cyp_class and base_type.support_wrapper:
# this base type is the first wrapped base
self.first_wrapped_base = base_type
self.wrapped_base_type = base_type.wrapped_base_type
break
# Return the MRO for this cypclass # Return the MRO for this cypclass
# Compute all the mro needed when a previous computation is not available # Compute all the mro needed when a previous computation is not available
......
...@@ -762,8 +762,6 @@ class Scope(object): ...@@ -762,8 +762,6 @@ class Scope(object):
entry.already_declared_here() entry.already_declared_here()
else: else:
entry.type.base_classes = base_classes entry.type.base_classes = base_classes
if cypclass:
entry.type.find_wrapped_base_type(base_classes)
if templates or entry.type.templates: if templates or entry.type.templates:
if templates != entry.type.templates: if templates != entry.type.templates:
error(pos, "Template parameters do not match previous declaration") error(pos, "Template parameters do not match previous declaration")
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment