Commit 51ac02b7 authored by Stefan Behnel's avatar Stefan Behnel Committed by GitHub

Merge pull request #2640 from mattip/ctypedef-class-getter2

ENH: allow @property decorator on external ctypedef classes
parents 3a5a2e3b 499fd678
...@@ -7154,6 +7154,8 @@ class AttributeNode(ExprNode): ...@@ -7154,6 +7154,8 @@ class AttributeNode(ExprNode):
obj_code = obj.result_as(obj.type) obj_code = obj.result_as(obj.type)
#print "...obj_code =", obj_code ### #print "...obj_code =", obj_code ###
if self.entry and self.entry.is_cmethod: if self.entry and self.entry.is_cmethod:
if self.entry.is_cgetter:
return "%s(%s)" % (self.entry.func_cname, obj_code)
if obj.type.is_extension_type and not self.entry.is_builtin_cmethod: if obj.type.is_extension_type and not self.entry.is_builtin_cmethod:
if self.entry.final_func_cname: if self.entry.final_func_cname:
return self.entry.final_func_cname return self.entry.final_func_cname
...@@ -11242,6 +11244,10 @@ class NumBinopNode(BinopNode): ...@@ -11242,6 +11244,10 @@ class NumBinopNode(BinopNode):
self.operand2 = self.operand2.coerce_to(self.type, env) self.operand2 = self.operand2.coerce_to(self.type, env)
def compute_c_result_type(self, type1, type2): def compute_c_result_type(self, type1, type2):
if type1.is_cfunction and type1.entry.is_cgetter:
type1 = type1.return_type
if type2.is_cfunction and type2.entry.is_cgetter:
type2 = type2.return_type
if self.c_types_okay(type1, type2): if self.c_types_okay(type1, type2):
widest_type = PyrexTypes.widest_numeric_type(type1, type2) widest_type = PyrexTypes.widest_numeric_type(type1, type2)
if widest_type is PyrexTypes.c_bint_type: if widest_type is PyrexTypes.c_bint_type:
......
...@@ -2321,7 +2321,8 @@ class CFuncDefNode(FuncDefNode): ...@@ -2321,7 +2321,8 @@ class CFuncDefNode(FuncDefNode):
# is_static_method whether this is a static method # is_static_method whether this is a static method
# is_c_class_method whether this is a cclass method # is_c_class_method whether this is a cclass method
child_attrs = ["base_type", "declarator", "body", "py_func_stat"] child_attrs = ["base_type", "declarator", "body", "py_func_stat", "decorators"]
outer_attrs = ["decorators"]
inline_in_pxd = False inline_in_pxd = False
decorators = None decorators = None
...@@ -2341,6 +2342,21 @@ class CFuncDefNode(FuncDefNode): ...@@ -2341,6 +2342,21 @@ class CFuncDefNode(FuncDefNode):
return self.py_func.code_object if self.py_func else None return self.py_func.code_object if self.py_func else None
def analyse_declarations(self, env): def analyse_declarations(self, env):
is_property = 0
if self.decorators:
for decorator in self.decorators:
func = decorator.decorator
if func.is_name:
if func.name == 'property':
is_property = 1
elif func.name == 'staticmethod':
pass
else:
error(self.pos, "Cannot handle %s decorators yet" % func.name)
else:
error(self.pos,
"Cannot handle %s decorators yet" % type(func).__name__)
self.is_c_class_method = env.is_c_class_scope self.is_c_class_method = env.is_c_class_scope
if self.directive_locals is None: if self.directive_locals is None:
self.directive_locals = {} self.directive_locals = {}
...@@ -2355,20 +2371,20 @@ class CFuncDefNode(FuncDefNode): ...@@ -2355,20 +2371,20 @@ class CFuncDefNode(FuncDefNode):
self.is_static_method = 'staticmethod' in env.directives and not env.lookup_here('staticmethod') self.is_static_method = 'staticmethod' in env.directives and not env.lookup_here('staticmethod')
# The 2 here is because we need both function and argument names. # The 2 here is because we need both function and argument names.
if isinstance(self.declarator, CFuncDeclaratorNode): if isinstance(self.declarator, CFuncDeclaratorNode):
name_declarator, type = self.declarator.analyse( name_declarator, typ = self.declarator.analyse(
base_type, env, nonempty=2 * (self.body is not None), base_type, env, nonempty=2 * (self.body is not None),
directive_locals=self.directive_locals, visibility=self.visibility) directive_locals=self.directive_locals, visibility=self.visibility)
else: else:
name_declarator, type = self.declarator.analyse( name_declarator, typ = self.declarator.analyse(
base_type, env, nonempty=2 * (self.body is not None), visibility=self.visibility) base_type, env, nonempty=2 * (self.body is not None), visibility=self.visibility)
if not type.is_cfunction: if not typ.is_cfunction:
error(self.pos, "Suite attached to non-function declaration") error(self.pos, "Suite attached to non-function declaration")
# Remember the actual type according to the function header # Remember the actual type according to the function header
# written here, because the type in the symbol table entry # written here, because the type in the symbol table entry
# may be different if we're overriding a C method inherited # may be different if we're overriding a C method inherited
# from the base type of an extension type. # from the base type of an extension type.
self.type = type self.type = typ
type.is_overridable = self.overridable typ.is_overridable = self.overridable
declarator = self.declarator declarator = self.declarator
while not hasattr(declarator, 'args'): while not hasattr(declarator, 'args'):
declarator = declarator.base declarator = declarator.base
...@@ -2381,11 +2397,11 @@ class CFuncDefNode(FuncDefNode): ...@@ -2381,11 +2397,11 @@ class CFuncDefNode(FuncDefNode):
error(self.cfunc_declarator.pos, error(self.cfunc_declarator.pos,
"Function with optional arguments may not be declared public or api") "Function with optional arguments may not be declared public or api")
if type.exception_check == '+' and self.visibility != 'extern': if typ.exception_check == '+' and self.visibility != 'extern':
warning(self.cfunc_declarator.pos, warning(self.cfunc_declarator.pos,
"Only extern functions can throw C++ exceptions.") "Only extern functions can throw C++ exceptions.")
for formal_arg, type_arg in zip(self.args, type.args): for formal_arg, type_arg in zip(self.args, typ.args):
self.align_argument_type(env, type_arg) self.align_argument_type(env, type_arg)
formal_arg.type = type_arg.type formal_arg.type = type_arg.type
formal_arg.name = type_arg.name formal_arg.name = type_arg.name
...@@ -2406,20 +2422,25 @@ class CFuncDefNode(FuncDefNode): ...@@ -2406,20 +2422,25 @@ class CFuncDefNode(FuncDefNode):
elif 'inline' in self.modifiers: elif 'inline' in self.modifiers:
warning(formal_arg.pos, "Buffer unpacking not optimized away.", 1) warning(formal_arg.pos, "Buffer unpacking not optimized away.", 1)
self._validate_type_visibility(type.return_type, self.pos, env) self._validate_type_visibility(typ.return_type, self.pos, env)
name = name_declarator.name name = name_declarator.name
cname = name_declarator.cname cname = name_declarator.cname
type.is_const_method = self.is_const_method typ.is_const_method = self.is_const_method
type.is_static_method = self.is_static_method typ.is_static_method = self.is_static_method
self.entry = env.declare_cfunction( self.entry = env.declare_cfunction(
name, type, self.pos, name, typ, self.pos,
cname=cname, visibility=self.visibility, api=self.api, cname=cname, visibility=self.visibility, api=self.api,
defining=self.body is not None, modifiers=self.modifiers, defining=self.body is not None, modifiers=self.modifiers,
overridable=self.overridable) overridable=self.overridable)
if is_property:
self.entry.is_property = 1
env.property_entries.append(self.entry)
env.cfunc_entries.remove(self.entry)
self.entry.inline_func_in_pxd = self.inline_in_pxd self.entry.inline_func_in_pxd = self.inline_in_pxd
self.return_type = type.return_type self.return_type = typ.return_type
if self.return_type.is_array and self.visibility != 'extern': if self.return_type.is_array and self.visibility != 'extern':
error(self.pos, "Function cannot return an array") error(self.pos, "Function cannot return an array")
if self.return_type.is_cpp_class: if self.return_type.is_cpp_class:
......
...@@ -1031,7 +1031,7 @@ class InterpretCompilerDirectives(CythonTransform): ...@@ -1031,7 +1031,7 @@ class InterpretCompilerDirectives(CythonTransform):
else: else:
realdecs.append(dec) realdecs.append(dec)
if realdecs and (scope_name == 'cclass' or if realdecs and (scope_name == 'cclass' or
isinstance(node, (Nodes.CFuncDefNode, Nodes.CClassDefNode, Nodes.CVarDefNode))): isinstance(node, (Nodes.CClassDefNode, Nodes.CVarDefNode))):
raise PostParseError(realdecs[0].pos, "Cdef functions/classes cannot take arbitrary decorators.") raise PostParseError(realdecs[0].pos, "Cdef functions/classes cannot take arbitrary decorators.")
node.decorators = realdecs[::-1] + both[::-1] node.decorators = realdecs[::-1] + both[::-1]
# merge or override repeated directives # merge or override repeated directives
...@@ -2239,6 +2239,29 @@ class AnalyseExpressionsTransform(CythonTransform): ...@@ -2239,6 +2239,29 @@ class AnalyseExpressionsTransform(CythonTransform):
node = node.base node = node.base
return node return node
class ReplacePropertyNode(CythonTransform):
def visit_CFuncDefNode(self, node):
if not node.decorators:
return node
decorator = self.find_first_decorator(node, 'property')
if decorator:
# transform class functions into c-getters
if len(node.decorators) > 1:
# raises
self._reject_decorated_property(node, decorator_node)
node.entry.is_cgetter = True
# Add a func_cname to be output instead of the attribute
node.entry.func_cname = node.body.stats[0].value.function.name
node.decorators.remove(decorator)
return node
def find_first_decorator(self, node, name):
for decorator_node in node.decorators[::-1]:
decorator = decorator_node.decorator
if decorator.is_name and decorator.name == name:
return decorator_node
return None
class FindInvalidUseOfFusedTypes(CythonTransform): class FindInvalidUseOfFusedTypes(CythonTransform):
......
...@@ -146,7 +146,7 @@ def create_pipeline(context, mode, exclude_classes=()): ...@@ -146,7 +146,7 @@ def create_pipeline(context, mode, exclude_classes=()):
from .ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform from .ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform
from .ParseTreeTransforms import TrackNumpyAttributes, InterpretCompilerDirectives, TransformBuiltinMethods from .ParseTreeTransforms import TrackNumpyAttributes, InterpretCompilerDirectives, TransformBuiltinMethods
from .ParseTreeTransforms import ExpandInplaceOperators, ParallelRangeTransform from .ParseTreeTransforms import ExpandInplaceOperators, ParallelRangeTransform
from .ParseTreeTransforms import CalculateQualifiedNamesTransform from .ParseTreeTransforms import CalculateQualifiedNamesTransform, ReplacePropertyNode
from .TypeInference import MarkParallelAssignments, MarkOverflowingArithmetic from .TypeInference import MarkParallelAssignments, MarkOverflowingArithmetic
from .ParseTreeTransforms import AdjustDefByDirectives, AlignFunctionDefinitions from .ParseTreeTransforms import AdjustDefByDirectives, AlignFunctionDefinitions
from .ParseTreeTransforms import RemoveUnreachableCode, GilCheck from .ParseTreeTransforms import RemoveUnreachableCode, GilCheck
...@@ -198,6 +198,7 @@ def create_pipeline(context, mode, exclude_classes=()): ...@@ -198,6 +198,7 @@ def create_pipeline(context, mode, exclude_classes=()):
AnalyseDeclarationsTransform(context), AnalyseDeclarationsTransform(context),
AutoTestDictTransform(context), AutoTestDictTransform(context),
EmbedSignature(context), EmbedSignature(context),
ReplacePropertyNode(context),
EarlyReplaceBuiltinCalls(context), ## Necessary? EarlyReplaceBuiltinCalls(context), ## Necessary?
TransformBuiltinMethods(context), TransformBuiltinMethods(context),
MarkParallelAssignments(context), MarkParallelAssignments(context),
......
...@@ -134,6 +134,7 @@ class Entry(object): ...@@ -134,6 +134,7 @@ class Entry(object):
# cf_used boolean Entry is used # cf_used boolean Entry is used
# is_fused_specialized boolean Whether this entry of a cdef or def function # is_fused_specialized boolean Whether this entry of a cdef or def function
# is a specialization # is a specialization
# is_cgetter boolean Is a c-level getter function
# TODO: utility_code and utility_code_definition serves the same purpose... # TODO: utility_code and utility_code_definition serves the same purpose...
...@@ -203,6 +204,7 @@ class Entry(object): ...@@ -203,6 +204,7 @@ class Entry(object):
error_on_uninitialized = False error_on_uninitialized = False
cf_used = True cf_used = True
outer_entry = None outer_entry = None
is_cgetter = False
def __init__(self, name, cname, type, pos = None, init = None): def __init__(self, name, cname, type, pos = None, init = None):
self.name = name self.name = name
...@@ -829,7 +831,8 @@ class Scope(object): ...@@ -829,7 +831,8 @@ class Scope(object):
type.entry = entry type.entry = entry
return entry return entry
def add_cfunction(self, name, type, pos, cname, visibility, modifiers, inherited=False): def add_cfunction(self, name, type, pos, cname, visibility, modifiers,
inherited=False):
# Add a C function entry without giving it a func_cname. # Add a C function entry without giving it a func_cname.
entry = self.declare(name, cname, type, pos, visibility) entry = self.declare(name, cname, type, pos, visibility)
entry.is_cfunction = 1 entry.is_cfunction = 1
...@@ -1435,7 +1438,8 @@ class ModuleScope(Scope): ...@@ -1435,7 +1438,8 @@ class ModuleScope(Scope):
def declare_cfunction(self, name, type, pos, def declare_cfunction(self, name, type, pos,
cname=None, visibility='private', api=0, in_pxd=0, cname=None, visibility='private', api=0, in_pxd=0,
defining=0, modifiers=(), utility_code=None, overridable=False): defining=0, modifiers=(), utility_code=None,
overridable=False):
if not defining and 'inline' in modifiers: if not defining and 'inline' in modifiers:
# TODO(github/1736): Make this an error. # TODO(github/1736): Make this an error.
warning(pos, "Declarations should not be declared inline.", 1) warning(pos, "Declarations should not be declared inline.", 1)
...@@ -1933,7 +1937,8 @@ class StructOrUnionScope(Scope): ...@@ -1933,7 +1937,8 @@ class StructOrUnionScope(Scope):
def declare_cfunction(self, name, type, pos, def declare_cfunction(self, name, type, pos,
cname=None, visibility='private', api=0, in_pxd=0, cname=None, visibility='private', api=0, in_pxd=0,
defining=0, modifiers=(), overridable=False): # currently no utility code ... defining=0, modifiers=(),
overridable=False): # currently no utility code ...
if overridable: if overridable:
error(pos, "C struct/union member cannot be declared 'cpdef'") error(pos, "C struct/union member cannot be declared 'cpdef'")
return self.declare_var(name, type, pos, return self.declare_var(name, type, pos,
...@@ -2214,7 +2219,8 @@ class CClassScope(ClassScope): ...@@ -2214,7 +2219,8 @@ class CClassScope(ClassScope):
def declare_cfunction(self, name, type, pos, def declare_cfunction(self, name, type, pos,
cname=None, visibility='private', api=0, in_pxd=0, cname=None, visibility='private', api=0, in_pxd=0,
defining=0, modifiers=(), utility_code=None, overridable=False): defining=0, modifiers=(), utility_code=None,
overridable=False):
if get_special_method_signature(name) and not self.parent_type.is_builtin_type: if get_special_method_signature(name) and not self.parent_type.is_builtin_type:
error(pos, "Special methods must be declared with 'def', not 'cdef'") error(pos, "Special methods must be declared with 'def', not 'cdef'")
args = type.args args = type.args
...@@ -2258,7 +2264,8 @@ class CClassScope(ClassScope): ...@@ -2258,7 +2264,8 @@ class CClassScope(ClassScope):
error(pos, error(pos,
"C method '%s' not previously declared in definition part of" "C method '%s' not previously declared in definition part of"
" extension type '%s'" % (name, self.class_name)) " extension type '%s'" % (name, self.class_name))
entry = self.add_cfunction(name, type, pos, cname, visibility, modifiers) entry = self.add_cfunction(name, type, pos, cname, visibility,
modifiers)
if defining: if defining:
entry.func_cname = self.mangle(Naming.func_prefix, name) entry.func_cname = self.mangle(Naming.func_prefix, name)
entry.utility_code = utility_code entry.utility_code = utility_code
...@@ -2274,11 +2281,13 @@ class CClassScope(ClassScope): ...@@ -2274,11 +2281,13 @@ class CClassScope(ClassScope):
return entry return entry
def add_cfunction(self, name, type, pos, cname, visibility, modifiers, inherited=False): def add_cfunction(self, name, type, pos, cname, visibility, modifiers,
inherited=False):
# Add a cfunction entry without giving it a func_cname. # Add a cfunction entry without giving it a func_cname.
prev_entry = self.lookup_here(name) prev_entry = self.lookup_here(name)
entry = ClassScope.add_cfunction(self, name, type, pos, cname, entry = ClassScope.add_cfunction(self, name, type, pos, cname,
visibility, modifiers, inherited=inherited) visibility, modifiers,
inherited=inherited)
entry.is_cmethod = 1 entry.is_cmethod = 1
entry.prev_entry = prev_entry entry.prev_entry = prev_entry
return entry return entry
......
...@@ -4,14 +4,22 @@ PYTHON -c "import runner" ...@@ -4,14 +4,22 @@ PYTHON -c "import runner"
######## setup.py ######## ######## setup.py ########
from Cython.Build.Dependencies import cythonize from Cython.Build.Dependencies import cythonize
from Cython.Compiler.Errors import CompileError
from distutils.core import setup from distutils.core import setup
# force the build order # force the build order
setup(ext_modules= cythonize("foo_extension.pyx")) setup(ext_modules= cythonize("foo_extension.pyx", language_level=3))
setup(ext_modules = cythonize("getter*.pyx")) setup(ext_modules = cythonize("getter[0-9].pyx", language_level=3))
######## foo_nominal.h ######## for name in ("getter_fail0.pyx", "getter_fail1.pyx"):
try:
cythonize(name, language_level=3)
assert False
except CompileError as e:
print("\nGot expected exception, continuing\n")
######## foo.h ########
#include <Python.h> #include <Python.h>
...@@ -26,6 +34,30 @@ typedef struct { ...@@ -26,6 +34,30 @@ typedef struct {
int f2; int f2;
} FooStructNominal; } FooStructNominal;
typedef struct {
PyObject_HEAD
} FooStructOpaque;
#define PyFoo_GET0M(a) ((FooStructNominal*)a)->f0
#define PyFoo_GET1M(a) ((FooStructNominal*)a)->f1
#define PyFoo_GET2M(a) ((FooStructNominal*)a)->f2
int PyFoo_Get0F(FooStructOpaque *f)
{
return PyFoo_GET0M(f);
}
int PyFoo_Get1F(FooStructOpaque *f)
{
return PyFoo_GET1M(f);
}
int PyFoo_Get2F(FooStructOpaque *f)
{
return PyFoo_GET2M(f);
}
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif
...@@ -33,21 +65,24 @@ typedef struct { ...@@ -33,21 +65,24 @@ typedef struct {
######## foo_extension.pyx ######## ######## foo_extension.pyx ########
cdef class Foo: cdef class Foo:
cdef public int field0, field1, field2; cdef public int _field0, _field1, _field2;
def __init__(self, f0, f1, f2): @property
self.field0 = f0 def field0(self):
self.field1 = f1 return self._field0
self.field2 = f2
cdef get_field0(Foo f): @property
return f.field0 def field1(self):
return self._field1
cdef get_field1(Foo f): @property
return f.field1 def field2(self):
return self._field2
cdef get_field2(Foo f): def __init__(self, f0, f1, f2):
return f.field2 self._field0 = f0
self._field1 = f1
self._field2 = f2
# A pure-python class that disallows direct access to fields # A pure-python class that disallows direct access to fields
class OpaqueFoo(Foo): class OpaqueFoo(Foo):
...@@ -64,12 +99,11 @@ class OpaqueFoo(Foo): ...@@ -64,12 +99,11 @@ class OpaqueFoo(Foo):
def field2(self): def field2(self):
raise AttributeError('no direct access to field2') raise AttributeError('no direct access to field2')
######## getter0.pyx ######## ######## getter0.pyx ########
# Access base Foo fields from C via aliased field names # Access base Foo fields from C via aliased field names
cdef extern from "foo_nominal.h": cdef extern from "foo.h":
ctypedef class foo_extension.Foo [object FooStructNominal]: ctypedef class foo_extension.Foo [object FooStructNominal]:
cdef: cdef:
...@@ -78,13 +112,70 @@ cdef extern from "foo_nominal.h": ...@@ -78,13 +112,70 @@ cdef extern from "foo_nominal.h":
int field2 "f2" int field2 "f2"
def sum(Foo f): def sum(Foo f):
# the f.__getattr__('field0') is replaced in c by f->f0 # Note - not a cdef function but compiling the f.__getattr__('field0')
# notices the alias and replaces the __getattr__ in c by f->f0 anyway
return f.field0 + f.field1 + f.field2 return f.field0 + f.field1 + f.field2
######## getter1.pyx ########
# Access base Foo fields from C via getter functions
cdef extern from "foo.h":
ctypedef class foo_extension.Foo [object FooStructOpaque, check_size ignore]:
@property
cdef int fieldM0(self):
return PyFoo_GET0M(self)
@property
cdef int fieldF1(self):
return PyFoo_Get1F(self)
@property
cdef int fieldM2(self):
return PyFoo_GET2M(self)
int PyFoo_GET0M(Foo); # this is actually a macro !
int PyFoo_Get1F(Foo);
int PyFoo_GET2M(Foo); # this is actually a macro !
def sum(Foo f):
# Note - not a cdef function but compiling the f.__getattr__('field0')
# notices the getter and replaces the __getattr__ in c by PyFoo_GET anyway
return f.fieldM0 + f.fieldF1 + f.fieldM2
######## getter_fail0.pyx ########
# Make sure not all decorators are accepted
cdef extern from "foo.h":
ctypedef class foo_extension.Foo [object FooStructOpaque]:
@classmethod
cdef void field0():
print('in staticmethod of Foo')
######## getter_fail1.pyx ########
# Make sure not all decorators are accepted
cimport cython
cdef extern from "foo.h":
ctypedef class foo_extension.Foo [object FooStructOpaque]:
@prop.getter
cdef void field0(self):
pass
######## runner.py ######## ######## runner.py ########
import foo_extension, getter0 import warnings
import foo_extension, getter0, getter1
def sum(f):
# pure python field access, but code is identical to cython cdef sum
return f.field0 + f.field1 + f.field2
# Baseline test: if this fails something else is wrong
foo = foo_extension.Foo(23, 123, 1023) foo = foo_extension.Foo(23, 123, 1023)
assert foo.field0 == 23 assert foo.field0 == 23
...@@ -92,18 +183,28 @@ assert foo.field1 == 123 ...@@ -92,18 +183,28 @@ assert foo.field1 == 123
assert foo.field2 == 1023 assert foo.field2 == 1023
ret = getter0.sum(foo) ret = getter0.sum(foo)
assert ret == foo.field0 + foo.field1 + foo.field2 assert ret == sum(foo)
# Aliasing test. Check 'cdef int field0 "f0" works as advertised:
# - C can access the fields through the aliases
# - Python cannot access the fields at all
opaque_foo = foo_extension.OpaqueFoo(23, 123, 1023) opaque_foo = foo_extension.OpaqueFoo(23, 123, 1023)
# C can access the fields through the aliases
opaque_ret = getter0.sum(opaque_foo) opaque_ret = getter0.sum(opaque_foo)
assert opaque_ret == ret assert opaque_ret == ret
try: try:
# Python cannot access the fields
f0 = opaque_ret.field0 f0 = opaque_ret.field0
assert False assert False
except AttributeError as e: except AttributeError as e:
pass pass
# Getter test. Check C-level getter works as advertised:
# - C accesses the fields through getter calls (maybe macros)
# - Python accesses the fields through attribute lookup
opaque_foo = foo_extension.OpaqueFoo(23, 123, 1023)
opaque_ret = getter1.sum(opaque_foo)
assert opaque_ret == ret
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment