Commit 943dbd7d authored by Xavier Thompson's avatar Xavier Thompson

Inject wrapper classes into the AST only after the declaration analysis phase

parent f82fb075
......@@ -93,56 +93,55 @@ underlying_name = EncodedString("nogil_cyobject")
# - Insert additional cclass wrapper nodes by returning lists of nodes
# => must run after NormalizeTree (otherwise single statements might not be held in a list)
#
class CypclassWrapperInjection(VisitorTransform):
class CypclassWrapperInjection(CythonTransform):
"""
Synthesize and insert a wrapper c class at the module level for each cypclass that supports it.
- Even nested cypclasses have their wrapper at the module level.
- Must run after NormalizeTree.
- The root node passed when calling this visitor should not be lower than a ModuleNode.
"""
def __call__(self, root):
self.cypclass_wrappers_stack = []
self.nesting_stack = []
self.module_scope = root.scope
from .ParseTreeTransforms import AnalyseDeclarationsTransform
self.analyser = AnalyseDeclarationsTransform(self.context)
return super(CypclassWrapperInjection, self).__call__(root)
def visit_Node(self, node):
def visit_ModuleNode(self, node):
self.cypclass_wrappers = []
self.nesting_stack = []
self.module_scope = node.scope
self.visitchildren(node)
self.inject_cypclass_wrappers(node)
return node
def inject_cypclass_wrappers(self, module_node):
fake_module_node = module_node.clone_node()
fake_module_node.body = Nodes.StatListNode(
module_node.body.pos,
stats = self.cypclass_wrappers
)
self.analyser(fake_module_node)
module_node.body.stats.extend(fake_module_node.body.stats)
# TODO: can cypclasses be nested in something other than this ?
# can cypclasses even be nested in non-cypclass cpp classes, or structs ?
def visit_CStructOrUnionDefNode(self, node):
self.nesting_stack.append(node)
self.visitchildren(node)
self.nesting_stack.pop()
top_level = not self.nesting_stack
if top_level:
return_nodes = [node]
for wrapper in self.cypclass_wrappers_stack:
return_nodes.append(wrapper)
self.cypclass_wrappers_stack.clear()
return return_nodes
return node
def visit_CppClassNode(self, node):
if node.cypclass:
wrapper = self.synthesize_wrapper_cclass(node)
if wrapper is not None:
self.cypclass_wrappers_stack.append(wrapper)
# visit children and return all wrappers when at the top level
# forward-declare the wrapper
wrapper.declare(self.module_scope)
self.cypclass_wrappers.append(wrapper)
# visit children and keep track of nesting
return self.visit_CStructOrUnionDefNode(node)
def find_module_scope(self, scope):
module_scope = scope
while module_scope and not module_scope.is_module_scope:
module_scope = module_scope.outer_scope
return module_scope
def iter_wrapper_methods(self, wrapper_cclass):
for node in wrapper_cclass.body.stats:
if isinstance(node, Nodes.DefNode):
yield node
def synthesize_wrapper_cclass(self, node):
if node.templates:
# Python wrapper for templated cypclasses not supported yet
......@@ -152,39 +151,36 @@ class CypclassWrapperInjection(VisitorTransform):
# whether the is declared with ':' and a suite, or just a forward declaration
node_has_suite = node.attributes is not None
# TODO: forward declare wrapper classes too ?
if not node_has_suite:
return None
# TODO: take nesting into account for the name
# TODO: check that there is no collision with another name
cclass_name = EncodedString("%s_cyp_wrapper" % node.name)
from .ExprNodes import TupleNode
bases_args = []
if node.base_classes:
first_base = node.base_classes[0]
if isinstance(first_base, Nodes.CSimpleBaseTypeNode) and first_base.templates is None:
first_base_name = first_base.name
builtin_entry = self.module_scope.lookup(first_base_name)
if builtin_entry is not None:
return
node_type = node.entry.type
node_type.find_wrapped_base_type()
first_wrapped_base = node_type.first_wrapped_base
if first_wrapped_base:
first_base_wrapper_name = first_wrapped_base.wrapper_type.name
wrapped_first_base = Nodes.CSimpleBaseTypeNode(
first_base.pos,
name = "%s_cyp_wrapper" % first_base_name,
node.pos,
name = first_base_wrapper_name,
module_path = [],
is_basic_c_type = first_base.is_basic_c_type,
signed = first_base.signed,
complex = first_base.complex,
longness = first_base.longness,
is_self_arg = first_base.is_self_arg,
is_basic_c_type = 0,
signed = 1,
complex = 0,
longness = 0,
is_self_arg = 0,
templates = None
)
bases_args.append(wrapped_first_base)
cclass_bases = TupleNode(node.pos, args=bases_args)
# the underlying cyobject must come first thing after PyObject_HEAD in the memory layout
# long term, only the base class will declare the underlying attribute
stats = []
if not bases_args:
underlying_cyobject = self.synthesize_underlying_cyobject_attribute(node)
......@@ -211,6 +207,9 @@ class CypclassWrapperInjection(VisitorTransform):
wrapped_cypclass = node,
)
# indicate that the cypclass will have a wrapper
node.entry.type.support_wrapper = True
return wrapper
def synthesize_underlying_cyobject_attribute(self, node):
......
......@@ -5532,6 +5532,14 @@ class CypclassWrapperDefNode(CClassDefNode):
is_cyp_wrapper = 1
def declare(self, env):
# > declare the same way as a standard c class
super(CypclassWrapperDefNode, self).declare(env)
# > mark the wrapper type as such
self.entry.type.is_cyp_wrapper = 1
# > associate the wrapper type to the wrapped type
self.wrapped_cypclass.entry.type.wrapper_type = self.entry.type
def analyse_declarations(self, env):
# > analyse declarations before inserting methods
super(CypclassWrapperDefNode, self).analyse_declarations(env)
......
......@@ -182,7 +182,6 @@ def create_pipeline(context, mode, exclude_classes=()):
# compilation stage.
stages = [
NormalizeTree(context),
CypclassWrapperInjection(),
PostParse(context),
_specific_post_parse,
TrackNumpyAttributes(),
......@@ -199,6 +198,7 @@ def create_pipeline(context, mode, exclude_classes=()):
ForwardDeclareTypes(context),
InjectGilHandling(),
AnalyseDeclarationsTransform(context),
CypclassWrapperInjection(context),
AutoTestDictTransform(context),
EmbedSignature(context),
EarlyReplaceBuiltinCalls(context), ## Necessary?
......
......@@ -4107,7 +4107,9 @@ def compute_mro_generic(cls):
class CypClassType(CppClassType):
# lock_mode string (tri-state: "nolock"/"checklock"/"autolock")
# _mro [CppClassType] or None the Method Resolution Order of this cypclass according to Python
# support_wrapper boolean whether this cypclass will be wrapped
# wrapper_type PyExtensionType or None the type of the cclass wrapper
# first_wrapped_base CypClassType or None the first cypclass base that has a wrapper if there is one
# wrapped_base_type CypClassType or None the type of the oldest wrapped cypclass base
is_cyp_class = 1
......@@ -4115,24 +4117,29 @@ class CypClassType(CppClassType):
def __init__(self, name, scope, cname, base_classes, templates=None, template_type=None, nogil=0, lock_mode=None, activable=False):
CppClassType.__init__(self, name, scope, cname, base_classes, templates, template_type, nogil)
if base_classes:
self.find_wrapped_base_type(base_classes)
self.lock_mode = lock_mode if lock_mode else "autolock"
self.activable = activable
self._mro = None
self.support_wrapper = False
self.wrapper_type = None
self.wrapped_base_type = None
def find_wrapped_base_type(self, base_classes):
first_wrapped_cypclass_base = None
for base_type in base_classes:
if base_type.is_cyp_class and base_type.wrapper_type:
first_wrapped_cypclass_base = base_type
break
if first_wrapped_cypclass_base:
self.wrapped_base_type = first_wrapped_cypclass_base.wrapped_base_type
else:
# find the first base that has a wrapper, if there is one
# find the oldest superclass such that all intervening classes have a wrapper
def find_wrapped_base_type(self):
# default: the oldest superclass is self and there are no bases
self.wrapped_base_type = self
self.first_wrapped_base = None
# if there are no bases, no need to look further
if not self.base_classes:
return
# otherwise, find the first wrapped base (if there is one) and take the same oldest superclass
for base_type in self.base_classes:
if base_type.is_cyp_class and base_type.support_wrapper:
# this base type is the first wrapped base
self.first_wrapped_base = base_type
self.wrapped_base_type = base_type.wrapped_base_type
break
# Return the MRO for this cypclass
# Compute all the mro needed when a previous computation is not available
......
......@@ -762,8 +762,6 @@ class Scope(object):
entry.already_declared_here()
else:
entry.type.base_classes = base_classes
if cypclass:
entry.type.find_wrapped_base_type(base_classes)
if templates or entry.type.templates:
if templates != entry.type.templates:
error(pos, "Template parameters do not match previous declaration")
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment