Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
C
cython
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
Analytics
Analytics
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Commits
Issue Boards
Open sidebar
Xavier Thompson
cython
Commits
82c13a65
Commit
82c13a65
authored
May 03, 2011
by
Mark Florisson
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Allow indexing of fused cdef functions
parent
f99ecc90
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
128 additions
and
12 deletions
+128
-12
Cython/Compiler/ExprNodes.py
Cython/Compiler/ExprNodes.py
+83
-5
Cython/Compiler/ParseTreeTransforms.py
Cython/Compiler/ParseTreeTransforms.py
+31
-5
Cython/Compiler/PyrexTypes.py
Cython/Compiler/PyrexTypes.py
+8
-2
tests/run/public_fused_types.srctree
tests/run/public_fused_types.srctree
+6
-0
No files found.
Cython/Compiler/ExprNodes.py
View file @
82c13a65
...
@@ -2248,11 +2248,15 @@ class IndexNode(ExprNode):
...
@@ -2248,11 +2248,15 @@ class IndexNode(ExprNode):
self
.
base
.
entry
.
buffer_aux
.
writable_needed
=
True
self
.
base
.
entry
.
buffer_aux
.
writable_needed
=
True
else
:
else
:
base_type
=
self
.
base
.
type
base_type
=
self
.
base
.
type
if
isinstance
(
self
.
index
,
TupleNode
):
self
.
index
.
analyse_types
(
env
,
skip_children
=
skip_child_analysis
)
fused_index_operation
=
base_type
.
is_cfunction
and
base_type
.
is_fused
elif
not
skip_child_analysis
:
if
not
fused_index_operation
:
self
.
index
.
analyse_types
(
env
)
if
isinstance
(
self
.
index
,
TupleNode
):
self
.
original_index_type
=
self
.
index
.
type
self
.
index
.
analyse_types
(
env
,
skip_children
=
skip_child_analysis
)
elif
not
skip_child_analysis
:
self
.
index
.
analyse_types
(
env
)
self
.
original_index_type
=
self
.
index
.
type
if
base_type
.
is_unicode_char
:
if
base_type
.
is_unicode_char
:
# we infer Py_UNICODE/Py_UCS4 for unicode strings in some
# we infer Py_UNICODE/Py_UCS4 for unicode strings in some
# cases, but indexing must still work for them
# cases, but indexing must still work for them
...
@@ -2309,12 +2313,84 @@ class IndexNode(ExprNode):
...
@@ -2309,12 +2313,84 @@ class IndexNode(ExprNode):
self
.
type
=
func_type
.
return_type
self
.
type
=
func_type
.
return_type
if
setting
and
not
func_type
.
return_type
.
is_reference
:
if
setting
and
not
func_type
.
return_type
.
is_reference
:
error
(
self
.
pos
,
"Can't set non-reference result '%s'"
%
self
.
type
)
error
(
self
.
pos
,
"Can't set non-reference result '%s'"
%
self
.
type
)
elif
fused_index_operation
:
self
.
parse_indexed_fused_cdef
(
env
)
else
:
else
:
error
(
self
.
pos
,
error
(
self
.
pos
,
"Attempting to index non-array type '%s'"
%
"Attempting to index non-array type '%s'"
%
base_type
)
base_type
)
self
.
type
=
PyrexTypes
.
error_type
self
.
type
=
PyrexTypes
.
error_type
def
parse_indexed_fused_cdef
(
self
,
env
):
"""
Interpret fused_cdef_func[specific_type1, ...]
Note that if this method is called, we are an indexed cdef function
with fused argument types, and this IndexNode will be replaced by the
NameNode with specific entry just after analysis of expressions by
AnalyseExpressionsTransform.
"""
base_type
=
self
.
base
.
type
def
err
(
msg
,
pos
=
None
):
error
(
pos
or
self
.
pos
,
msg
)
self
.
type
=
PyrexTypes
.
error_type
specific_types
=
[]
positions
=
[]
if
self
.
index
.
is_name
:
positions
.
append
(
self
.
index
.
pos
)
specific_types
.
append
(
self
.
index
.
analyse_as_type
(
env
))
elif
isinstance
(
self
.
index
,
TupleNode
):
for
arg
in
self
.
index
.
args
:
positions
.
append
(
arg
.
pos
)
specific_types
.
append
(
arg
.
analyse_as_type
(
env
))
else
:
return
err
(
"Can only index fused functions with types"
)
fused_types
=
base_type
.
get_fused_types
()
if
len
(
specific_types
)
>
len
(
fused_types
):
return
err
(
"Too many types specified"
)
# See if our index types form valid specializations
for
pos
,
specific_type
,
fused_type
in
zip
(
positions
,
specific_types
,
fused_types
):
if
not
Utils
.
any
([
specific_type
.
same_as
(
t
)
for
t
in
fused_type
.
types
]):
return
err
(
"Type not in fused type"
,
pos
=
pos
)
if
specific_type
is
None
or
specific_type
.
is_error
:
return
fused_to_specific
=
dict
(
zip
(
fused_types
,
specific_types
))
# If we are only partially fused, specialize accordingly
for
fused_type
in
fused_types
:
if
fused_type
not
in
fused_to_specific
:
fused_to_specific
[
fused_type
]
=
fused_type
type
=
base_type
.
specialize
(
fused_to_specific
)
if
type
is
not
base_type
:
import
copy
e
=
copy
.
copy
(
base_type
.
entry
)
e
.
type
=
type
type
.
entry
=
e
if
not
type
.
is_fused
:
# Fully specific, find the signature with the specialized entry
for
signature
in
self
.
base
.
type
.
get_all_specific_function_types
():
if
type
.
same_as
(
signature
):
self
.
type
=
signature
break
else
:
assert
False
else
:
# Only partially specific
self
.
type
=
type
gil_message
=
"Indexing Python object"
gil_message
=
"Indexing Python object"
def
nogil_check
(
self
,
env
):
def
nogil_check
(
self
,
env
):
...
@@ -3041,6 +3117,8 @@ class SimpleCallNode(CallNode):
...
@@ -3041,6 +3117,8 @@ class SimpleCallNode(CallNode):
return
return
elif
hasattr
(
self
.
function
,
'entry'
):
elif
hasattr
(
self
.
function
,
'entry'
):
overloaded_entry
=
self
.
function
.
entry
overloaded_entry
=
self
.
function
.
entry
elif
isinstance
(
self
.
function
,
IndexNode
)
and
self
.
function
.
type
.
is_fused
:
overloaded_entry
=
self
.
function
.
type
.
entry
else
:
else
:
overloaded_entry
=
None
overloaded_entry
=
None
...
...
Cython/Compiler/ParseTreeTransforms.py
View file @
82c13a65
...
@@ -1320,6 +1320,8 @@ if VALUE is not None:
...
@@ -1320,6 +1320,8 @@ if VALUE is not None:
class
AnalyseExpressionsTransform
(
CythonTransform
):
class
AnalyseExpressionsTransform
(
CythonTransform
):
nested_index_node
=
False
def
visit_ModuleNode
(
self
,
node
):
def
visit_ModuleNode
(
self
,
node
):
node
.
scope
.
infer_types
()
node
.
scope
.
infer_types
()
node
.
body
.
analyse_expressions
(
node
.
scope
)
node
.
body
.
analyse_expressions
(
node
.
scope
)
...
@@ -1339,6 +1341,34 @@ class AnalyseExpressionsTransform(CythonTransform):
...
@@ -1339,6 +1341,34 @@ class AnalyseExpressionsTransform(CythonTransform):
self
.
visitchildren
(
node
)
self
.
visitchildren
(
node
)
return
node
return
node
def
visit_IndexNode
(
self
,
node
):
"""
Replace index nodes used to specialize cdef functions with fused
argument types with a NameNode referring to the function with
specialized entry and type.
"""
was_nested
=
self
.
nested_index_node
self
.
nested_index_node
=
True
self
.
visit_Node
(
node
)
self
.
nested_index_node
=
was_nested
type
=
node
.
type
if
type
.
is_cfunction
and
type
.
is_fused
and
not
self
.
nested_index_node
:
error
(
node
.
pos
,
"Not enough types were specified to indicate a "
"specialized function"
)
elif
type
.
is_cfunction
and
node
.
base
.
type
.
is_fused
:
while
not
node
.
is_name
:
node
=
node
.
base
node
.
type
=
type
node
.
entry
=
type
.
entry
print
node
.
entry
.
cname
return
node
return
node
class
ExpandInplaceOperators
(
EnvTransform
):
class
ExpandInplaceOperators
(
EnvTransform
):
def
visit_InPlaceAssignmentNode
(
self
,
node
):
def
visit_InPlaceAssignmentNode
(
self
,
node
):
...
@@ -1924,11 +1954,7 @@ class ReplaceFusedTypeChecks(VisitorTransform):
...
@@ -1924,11 +1954,7 @@ class ReplaceFusedTypeChecks(VisitorTransform):
error
(
node
.
operand2
.
pos
,
error
(
node
.
operand2
.
pos
,
"Can only use 'in' or 'not in' on a fused type"
)
"Can only use 'in' or 'not in' on a fused type"
)
else
:
else
:
if
not
isinstance
(
type2
,
PyrexTypes
.
FusedType
):
types
=
PyrexTypes
.
get_specific_types
(
type2
)
# Composed fused type, get all specific versions
types
=
PyrexTypes
.
get_specific_types
(
type2
)
else
:
types
=
type2
.
types
for
specific_type
in
types
:
for
specific_type
in
types
:
if
type1
.
same_as
(
specific_type
):
if
type1
.
same_as
(
specific_type
):
...
...
Cython/Compiler/PyrexTypes.py
View file @
82c13a65
...
@@ -2079,13 +2079,16 @@ def map_with_specific_entries(entry, func, *args, **kwargs):
...
@@ -2079,13 +2079,16 @@ def map_with_specific_entries(entry, func, *args, **kwargs):
# a normal cdef or not a c function
# a normal cdef or not a c function
func
(
entry
,
*
args
,
**
kwargs
)
func
(
entry
,
*
args
,
**
kwargs
)
def
get_all_specific_permutations
(
fused_types
,
id
=
"
0
"
,
f2s
=
()):
def
get_all_specific_permutations
(
fused_types
,
id
=
""
,
f2s
=
()):
fused_type
=
fused_types
[
0
]
fused_type
=
fused_types
[
0
]
result
=
[]
result
=
[]
for
newid
,
specific_type
in
enumerate
(
fused_type
.
types
):
for
newid
,
specific_type
in
enumerate
(
fused_type
.
types
):
f2s
=
dict
(
f2s
,
**
{
fused_type
:
specific_type
})
f2s
=
dict
(
f2s
,
**
{
fused_type
:
specific_type
})
cname
=
'%s_%s'
%
(
id
,
newid
)
if
id
:
cname
=
'%s_%s'
%
(
id
,
newid
)
else
:
cname
=
newid
if
len
(
fused_types
)
>
1
:
if
len
(
fused_types
)
>
1
:
result
.
extend
(
get_all_specific_permutations
(
result
.
extend
(
get_all_specific_permutations
(
...
@@ -2098,6 +2101,9 @@ def get_all_specific_permutations(fused_types, id="0", f2s=()):
...
@@ -2098,6 +2101,9 @@ def get_all_specific_permutations(fused_types, id="0", f2s=()):
def
get_specific_types
(
type
):
def
get_specific_types
(
type
):
assert
type
.
is_fused
assert
type
.
is_fused
if
isinstance
(
type
,
FusedType
):
return
type
.
types
result
=
[]
result
=
[]
for
cname
,
f2s
in
get_all_specific_permutations
(
type
.
get_fused_types
()):
for
cname
,
f2s
in
get_all_specific_permutations
(
type
.
get_fused_types
()):
result
.
append
(
type
.
specialize
(
f2s
))
result
.
append
(
type
.
specialize
(
f2s
))
...
...
tests/run/public_fused_types.srctree
View file @
82c13a65
...
@@ -79,3 +79,9 @@ assert f(mystruct, 5).a == 10
...
@@ -79,3 +79,9 @@ assert f(mystruct, 5).a == 10
f = <mystruct_t (*)(mystruct_t, int)> add_simple
f = <mystruct_t (*)(mystruct_t, int)> add_simple
assert f(mystruct, 5).a == 10
assert f(mystruct, 5).a == 10
f = add_simple[mystruct_t, int]
assert f(mystruct, 5).a == 10
f = add_simple[mystruct_t][int]
assert f(mystruct, 5).a == 10
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment