Commit a8e96fd3 authored by Robert Bradshaw's avatar Robert Bradshaw

A couple more template inference fixes.

parent 9361decf
...@@ -3517,6 +3517,8 @@ class CppClassType(CType): ...@@ -3517,6 +3517,8 @@ class CppClassType(CType):
return {} return {}
elif actual.is_cpp_class: elif actual.is_cpp_class:
self_template_type = self.template_type or self self_template_type = self.template_type or self
while getattr(self_template_type, 'template_type', None):
self_template_type = self_template_type.template_type
def all_bases(cls): def all_bases(cls):
yield cls yield cls
for parent in cls.base_classes: for parent in cls.base_classes:
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
cimport cython cimport cython
from libcpp.pair cimport pair from libcpp.pair cimport pair
from libcpp.vector cimport vector
cdef extern from "cpp_template_functions_helper.h": cdef extern from "cpp_template_functions_helper.h":
cdef T no_arg[T]() cdef T no_arg[T]()
...@@ -11,6 +12,9 @@ cdef extern from "cpp_template_functions_helper.h": ...@@ -11,6 +12,9 @@ cdef extern from "cpp_template_functions_helper.h":
pair[T, U] method[U](T, U) pair[T, U] method[U](T, U)
U part_method[U](pair[T, U]) U part_method[U](pair[T, U])
U part_method_ref[U](pair[T, U]&) U part_method_ref[U](pair[T, U]&)
int overloaded(double x)
T overloaded(pair[T, T])
U overloaded[U](vector[U])
cdef T nested_deduction[T](const T*) cdef T nested_deduction[T](const T*)
pair[T, U] pair_arg[T, U](pair[T, U] a) pair[T, U] pair_arg[T, U](pair[T, U] a)
cdef T* pointer_param[T](T*) cdef T* pointer_param[T](T*)
...@@ -99,3 +103,16 @@ def test_inference(int k): ...@@ -99,3 +103,16 @@ def test_inference(int k):
res = one_param(&k) res = one_param(&k)
assert cython.typeof(res) == 'int *', cython.typeof(res) assert cython.typeof(res) == 'int *', cython.typeof(res)
return res[0] return res[0]
def test_overload_GH1583():
"""
>>> test_overload_GH1583()
"""
cdef A[int] a
assert a.overloaded(1.5) == 1
cdef pair[int, int] p = (2, 3)
assert a.overloaded(p) == 2
cdef vector[double] v = [0.25, 0.125]
assert a.overloaded(v) == 0.25
# GH Issue #1584
# assert a.overloaded[double](v) == 0.25
...@@ -28,6 +28,17 @@ class A { ...@@ -28,6 +28,17 @@ class A {
U part_method_ref(const std::pair<T, U>& p) { U part_method_ref(const std::pair<T, U>& p) {
return p.second; return p.second;
} }
int overloaded(double d) {
return (int) d;
}
T overloaded(std::pair<T, T> p) {
return p.first;
}
template <typename U>
U overloaded(std::vector<U> v) {
return v[0];
}
}; };
template <typename T> template <typename T>
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment