Commit 82c13a65 authored by Mark Florisson's avatar Mark Florisson

Allow indexing of fused cdef functions

parent f99ecc90
...@@ -2248,11 +2248,15 @@ class IndexNode(ExprNode): ...@@ -2248,11 +2248,15 @@ class IndexNode(ExprNode):
self.base.entry.buffer_aux.writable_needed = True self.base.entry.buffer_aux.writable_needed = True
else: else:
base_type = self.base.type base_type = self.base.type
if isinstance(self.index, TupleNode):
self.index.analyse_types(env, skip_children=skip_child_analysis) fused_index_operation = base_type.is_cfunction and base_type.is_fused
elif not skip_child_analysis: if not fused_index_operation:
self.index.analyse_types(env) if isinstance(self.index, TupleNode):
self.original_index_type = self.index.type self.index.analyse_types(env, skip_children=skip_child_analysis)
elif not skip_child_analysis:
self.index.analyse_types(env)
self.original_index_type = self.index.type
if base_type.is_unicode_char: if base_type.is_unicode_char:
# we infer Py_UNICODE/Py_UCS4 for unicode strings in some # we infer Py_UNICODE/Py_UCS4 for unicode strings in some
# cases, but indexing must still work for them # cases, but indexing must still work for them
...@@ -2309,12 +2313,84 @@ class IndexNode(ExprNode): ...@@ -2309,12 +2313,84 @@ class IndexNode(ExprNode):
self.type = func_type.return_type self.type = func_type.return_type
if setting and not func_type.return_type.is_reference: if setting and not func_type.return_type.is_reference:
error(self.pos, "Can't set non-reference result '%s'" % self.type) error(self.pos, "Can't set non-reference result '%s'" % self.type)
elif fused_index_operation:
self.parse_indexed_fused_cdef(env)
else: else:
error(self.pos, error(self.pos,
"Attempting to index non-array type '%s'" % "Attempting to index non-array type '%s'" %
base_type) base_type)
self.type = PyrexTypes.error_type self.type = PyrexTypes.error_type
def parse_indexed_fused_cdef(self, env):
"""
Interpret fused_cdef_func[specific_type1, ...]
Note that if this method is called, we are an indexed cdef function
with fused argument types, and this IndexNode will be replaced by the
NameNode with specific entry just after analysis of expressions by
AnalyseExpressionsTransform.
"""
base_type = self.base.type
def err(msg, pos=None):
error(pos or self.pos, msg)
self.type = PyrexTypes.error_type
specific_types = []
positions = []
if self.index.is_name:
positions.append(self.index.pos)
specific_types.append(self.index.analyse_as_type(env))
elif isinstance(self.index, TupleNode):
for arg in self.index.args:
positions.append(arg.pos)
specific_types.append(arg.analyse_as_type(env))
else:
return err("Can only index fused functions with types")
fused_types = base_type.get_fused_types()
if len(specific_types) > len(fused_types):
return err("Too many types specified")
# See if our index types form valid specializations
for pos, specific_type, fused_type in zip(positions,
specific_types,
fused_types):
if not Utils.any([specific_type.same_as(t)
for t in fused_type.types]):
return err("Type not in fused type", pos=pos)
if specific_type is None or specific_type.is_error:
return
fused_to_specific = dict(zip(fused_types, specific_types))
# If we are only partially fused, specialize accordingly
for fused_type in fused_types:
if fused_type not in fused_to_specific:
fused_to_specific[fused_type] = fused_type
type = base_type.specialize(fused_to_specific)
if type is not base_type:
import copy
e = copy.copy(base_type.entry)
e.type = type
type.entry = e
if not type.is_fused:
# Fully specific, find the signature with the specialized entry
for signature in self.base.type.get_all_specific_function_types():
if type.same_as(signature):
self.type = signature
break
else:
assert False
else:
# Only partially specific
self.type = type
gil_message = "Indexing Python object" gil_message = "Indexing Python object"
def nogil_check(self, env): def nogil_check(self, env):
...@@ -3041,6 +3117,8 @@ class SimpleCallNode(CallNode): ...@@ -3041,6 +3117,8 @@ class SimpleCallNode(CallNode):
return return
elif hasattr(self.function, 'entry'): elif hasattr(self.function, 'entry'):
overloaded_entry = self.function.entry overloaded_entry = self.function.entry
elif isinstance(self.function, IndexNode) and self.function.type.is_fused:
overloaded_entry = self.function.type.entry
else: else:
overloaded_entry = None overloaded_entry = None
......
...@@ -1320,6 +1320,8 @@ if VALUE is not None: ...@@ -1320,6 +1320,8 @@ if VALUE is not None:
class AnalyseExpressionsTransform(CythonTransform): class AnalyseExpressionsTransform(CythonTransform):
nested_index_node = False
def visit_ModuleNode(self, node): def visit_ModuleNode(self, node):
node.scope.infer_types() node.scope.infer_types()
node.body.analyse_expressions(node.scope) node.body.analyse_expressions(node.scope)
...@@ -1339,6 +1341,34 @@ class AnalyseExpressionsTransform(CythonTransform): ...@@ -1339,6 +1341,34 @@ class AnalyseExpressionsTransform(CythonTransform):
self.visitchildren(node) self.visitchildren(node)
return node return node
def visit_IndexNode(self, node):
"""
Replace index nodes used to specialize cdef functions with fused
argument types with a NameNode referring to the function with
specialized entry and type.
"""
was_nested = self.nested_index_node
self.nested_index_node = True
self.visit_Node(node)
self.nested_index_node = was_nested
type = node.type
if type.is_cfunction and type.is_fused and not self.nested_index_node:
error(node.pos, "Not enough types were specified to indicate a "
"specialized function")
elif type.is_cfunction and node.base.type.is_fused:
while not node.is_name:
node = node.base
node.type = type
node.entry = type.entry
print node.entry.cname
return node
return node
class ExpandInplaceOperators(EnvTransform): class ExpandInplaceOperators(EnvTransform):
def visit_InPlaceAssignmentNode(self, node): def visit_InPlaceAssignmentNode(self, node):
...@@ -1924,11 +1954,7 @@ class ReplaceFusedTypeChecks(VisitorTransform): ...@@ -1924,11 +1954,7 @@ class ReplaceFusedTypeChecks(VisitorTransform):
error(node.operand2.pos, error(node.operand2.pos,
"Can only use 'in' or 'not in' on a fused type") "Can only use 'in' or 'not in' on a fused type")
else: else:
if not isinstance(type2, PyrexTypes.FusedType): types = PyrexTypes.get_specific_types(type2)
# Composed fused type, get all specific versions
types = PyrexTypes.get_specific_types(type2)
else:
types = type2.types
for specific_type in types: for specific_type in types:
if type1.same_as(specific_type): if type1.same_as(specific_type):
......
...@@ -2079,13 +2079,16 @@ def map_with_specific_entries(entry, func, *args, **kwargs): ...@@ -2079,13 +2079,16 @@ def map_with_specific_entries(entry, func, *args, **kwargs):
# a normal cdef or not a c function # a normal cdef or not a c function
func(entry, *args, **kwargs) func(entry, *args, **kwargs)
def get_all_specific_permutations(fused_types, id="0", f2s=()): def get_all_specific_permutations(fused_types, id="", f2s=()):
fused_type = fused_types[0] fused_type = fused_types[0]
result = [] result = []
for newid, specific_type in enumerate(fused_type.types): for newid, specific_type in enumerate(fused_type.types):
f2s = dict(f2s, **{ fused_type: specific_type }) f2s = dict(f2s, **{ fused_type: specific_type })
cname = '%s_%s' % (id, newid) if id:
cname = '%s_%s' % (id, newid)
else:
cname = newid
if len(fused_types) > 1: if len(fused_types) > 1:
result.extend(get_all_specific_permutations( result.extend(get_all_specific_permutations(
...@@ -2098,6 +2101,9 @@ def get_all_specific_permutations(fused_types, id="0", f2s=()): ...@@ -2098,6 +2101,9 @@ def get_all_specific_permutations(fused_types, id="0", f2s=()):
def get_specific_types(type): def get_specific_types(type):
assert type.is_fused assert type.is_fused
if isinstance(type, FusedType):
return type.types
result = [] result = []
for cname, f2s in get_all_specific_permutations(type.get_fused_types()): for cname, f2s in get_all_specific_permutations(type.get_fused_types()):
result.append(type.specialize(f2s)) result.append(type.specialize(f2s))
......
...@@ -79,3 +79,9 @@ assert f(mystruct, 5).a == 10 ...@@ -79,3 +79,9 @@ assert f(mystruct, 5).a == 10
f = <mystruct_t (*)(mystruct_t, int)> add_simple f = <mystruct_t (*)(mystruct_t, int)> add_simple
assert f(mystruct, 5).a == 10 assert f(mystruct, 5).a == 10
f = add_simple[mystruct_t, int]
assert f(mystruct, 5).a == 10
f = add_simple[mystruct_t][int]
assert f(mystruct, 5).a == 10
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment