Commit ad24a17c authored by scoder, committed by GitHub

Allow None to coerce to C types separately from other object values. (GH-4740)

This is used by some optimisations for builtins that call C-API functions directly but need to convert None arguments to NULL or special integer values in order to mimic the original Python interface.

Also add and backport the CPython macros for None checks (and True/False, while we're at it):
https://docs.python.org/3/c-api/structures.html#c.Py_Is

Closes https://github.com/cython/cython/issues/4737
See https://github.com/cython/cython/issues/4706
parent 69cb05b3
...@@ -13614,6 +13614,9 @@ class CoerceFromPyTypeNode(CoercionNode): ...@@ -13614,6 +13614,9 @@ class CoerceFromPyTypeNode(CoercionNode):
# This node is used to convert a Python object # This node is used to convert a Python object
# to a C data type. # to a C data type.
# Allow 'None' to map to a different C value independent of the coercion, e.g. to 'NULL' or '0'.
special_none_cvalue = None
def __init__(self, result_type, arg, env): def __init__(self, result_type, arg, env):
CoercionNode.__init__(self, arg) CoercionNode.__init__(self, arg)
self.type = result_type self.type = result_type
...@@ -13643,7 +13646,10 @@ class CoerceFromPyTypeNode(CoercionNode): ...@@ -13643,7 +13646,10 @@ class CoerceFromPyTypeNode(CoercionNode):
NoneCheckNode.generate_if_needed(self.arg, code, "expected bytes, NoneType found") NoneCheckNode.generate_if_needed(self.arg, code, "expected bytes, NoneType found")
code.putln(self.type.from_py_call_code( code.putln(self.type.from_py_call_code(
self.arg.py_result(), self.result(), self.pos, code, from_py_function=from_py_function)) self.arg.py_result(), self.result(), self.pos, code,
from_py_function=from_py_function,
special_none_cvalue=self.special_none_cvalue,
))
if self.type.is_pyobject: if self.type.is_pyobject:
self.generate_gotref(code) self.generate_gotref(code)
......
...@@ -3620,6 +3620,8 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, ...@@ -3620,6 +3620,8 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin,
return node return node
if len(args) < 2: if len(args) < 2:
args.append(ExprNodes.NullNode(node.pos)) args.append(ExprNodes.NullNode(node.pos))
else:
self._inject_null_for_none(args, 1)
self._inject_int_default_argument( self._inject_int_default_argument(
node, args, 2, PyrexTypes.c_py_ssize_t_type, "-1") node, args, 2, PyrexTypes.c_py_ssize_t_type, "-1")
...@@ -4135,13 +4137,35 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, ...@@ -4135,13 +4137,35 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin,
format_args=[attr_name]) format_args=[attr_name])
return self_arg return self_arg
obj_to_obj_func_type = PyrexTypes.CFuncType(
PyrexTypes.py_object_type, [
PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None)
])
# Replace the argument at position 'index' with a node that yields NULL for None.
# A literal None is swapped for a NullNode directly; any other argument is wrapped
# in a PythonCapiCallNode that calls '__Pyx_NoneAsNull' on it.
# Does nothing if 'args' has no entry at 'index'.
def _inject_null_for_none(self, args, index):
if len(args) <= index:
return
arg = args[index]
args[index] = ExprNodes.NullNode(arg.pos) if arg.is_none else ExprNodes.PythonCapiCallNode(
arg.pos, "__Pyx_NoneAsNull",
self.obj_to_obj_func_type,
# NOTE(review): presumably coerced to a simple expression because the
# '__Pyx_NoneAsNull' macro uses its argument more than once - confirm.
args=[arg.coerce_to_simple(self.current_env())],
is_temp=0,
)
def _inject_int_default_argument(self, node, args, arg_index, type, default_value): def _inject_int_default_argument(self, node, args, arg_index, type, default_value):
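# Python usually allows passing None for range bounds,
# so we treat that as requesting the default.
# NOTE(review): when args[arg_index] exists and is a literal None, the branch below
# appends the default to the end of 'args' instead of replacing args[arg_index] -
# confirm whether an in-place replacement is intended here.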
assert len(args) >= arg_index assert len(args) >= arg_index
if len(args) == arg_index: if len(args) == arg_index or args[arg_index].is_none:
args.append(ExprNodes.IntNode(node.pos, value=str(default_value), args.append(ExprNodes.IntNode(node.pos, value=str(default_value),
type=type, constant_result=default_value)) type=type, constant_result=default_value))
else: else:
args[arg_index] = args[arg_index].coerce_to(type, self.current_env()) arg = args[arg_index].coerce_to(type, self.current_env())
if isinstance(arg, ExprNodes.CoerceFromPyTypeNode):
# Add a runtime check for None and map it to the default value.
arg.special_none_cvalue = str(default_value)
args[arg_index] = arg
def _inject_bint_default_argument(self, node, args, arg_index, default_value): def _inject_bint_default_argument(self, node, args, arg_index, default_value):
assert len(args) >= arg_index assert len(args) >= arg_index
......
...@@ -341,7 +341,8 @@ class PyrexType(BaseType): ...@@ -341,7 +341,8 @@ class PyrexType(BaseType):
return 0 return 0
def _assign_from_py_code(self, source_code, result_code, error_pos, code, def _assign_from_py_code(self, source_code, result_code, error_pos, code,
from_py_function=None, error_condition=None, extra_args=None): from_py_function=None, error_condition=None, extra_args=None,
special_none_cvalue=None):
args = ', ' + ', '.join('%s' % arg for arg in extra_args) if extra_args else '' args = ', ' + ', '.join('%s' % arg for arg in extra_args) if extra_args else ''
convert_call = "%s(%s%s)" % ( convert_call = "%s(%s%s)" % (
from_py_function or self.from_py_function, from_py_function or self.from_py_function,
...@@ -350,6 +351,10 @@ class PyrexType(BaseType): ...@@ -350,6 +351,10 @@ class PyrexType(BaseType):
) )
if self.is_enum: if self.is_enum:
convert_call = typecast(self, c_long_type, convert_call) convert_call = typecast(self, c_long_type, convert_call)
if special_none_cvalue:
# NOTE: requires 'source_code' to be simple!
convert_call = "(__Pyx_Py_IsNone(%s) ? (%s) : (%s))" % (
source_code, special_none_cvalue, convert_call)
return '%s = %s; %s' % ( return '%s = %s; %s' % (
result_code, result_code,
convert_call, convert_call,
...@@ -555,11 +560,13 @@ class CTypedefType(BaseType): ...@@ -555,11 +560,13 @@ class CTypedefType(BaseType):
source_code, result_code, result_type, to_py_function) source_code, result_code, result_type, to_py_function)
def from_py_call_code(self, source_code, result_code, error_pos, code, def from_py_call_code(self, source_code, result_code, error_pos, code,
from_py_function=None, error_condition=None): from_py_function=None, error_condition=None,
special_none_cvalue=None):
return self.typedef_base_type.from_py_call_code( return self.typedef_base_type.from_py_call_code(
source_code, result_code, error_pos, code, source_code, result_code, error_pos, code,
from_py_function or self.from_py_function, from_py_function or self.from_py_function,
error_condition or self.error_condition(result_code) error_condition or self.error_condition(result_code),
special_none_cvalue=special_none_cvalue,
) )
def overflow_check_binop(self, binop, env, const_rhs=False): def overflow_check_binop(self, binop, env, const_rhs=False):
...@@ -978,13 +985,16 @@ class MemoryViewSliceType(PyrexType): ...@@ -978,13 +985,16 @@ class MemoryViewSliceType(PyrexType):
return True return True
def from_py_call_code(self, source_code, result_code, error_pos, code, def from_py_call_code(self, source_code, result_code, error_pos, code,
from_py_function=None, error_condition=None): from_py_function=None, error_condition=None,
special_none_cvalue=None):
# NOTE: auto-detection of readonly buffers is disabled: # NOTE: auto-detection of readonly buffers is disabled:
# writable = self.writable_needed or not self.dtype.is_const # writable = self.writable_needed or not self.dtype.is_const
writable = not self.dtype.is_const writable = not self.dtype.is_const
return self._assign_from_py_code( return self._assign_from_py_code(
source_code, result_code, error_pos, code, from_py_function, error_condition, source_code, result_code, error_pos, code, from_py_function, error_condition,
extra_args=['PyBUF_WRITABLE' if writable else '0']) extra_args=['PyBUF_WRITABLE' if writable else '0'],
special_none_cvalue=special_none_cvalue,
)
def create_to_py_utility_code(self, env): def create_to_py_utility_code(self, env):
self._dtype_to_py_func, self._dtype_from_py_func = self.dtype_object_conversion_funcs(env) self._dtype_to_py_func, self._dtype_from_py_func = self.dtype_object_conversion_funcs(env)
...@@ -1674,9 +1684,11 @@ class CType(PyrexType): ...@@ -1674,9 +1684,11 @@ class CType(PyrexType):
source_code or 'NULL') source_code or 'NULL')
def from_py_call_code(self, source_code, result_code, error_pos, code, def from_py_call_code(self, source_code, result_code, error_pos, code,
from_py_function=None, error_condition=None): from_py_function=None, error_condition=None,
special_none_cvalue=None):
return self._assign_from_py_code( return self._assign_from_py_code(
source_code, result_code, error_pos, code, from_py_function, error_condition) source_code, result_code, error_pos, code, from_py_function, error_condition,
special_none_cvalue=special_none_cvalue)
...@@ -2675,8 +2687,10 @@ class CArrayType(CPointerBaseType): ...@@ -2675,8 +2687,10 @@ class CArrayType(CPointerBaseType):
return True return True
def from_py_call_code(self, source_code, result_code, error_pos, code, def from_py_call_code(self, source_code, result_code, error_pos, code,
from_py_function=None, error_condition=None): from_py_function=None, error_condition=None,
special_none_cvalue=None):
assert not error_condition, '%s: %s' % (error_pos, error_condition) assert not error_condition, '%s: %s' % (error_pos, error_condition)
assert not special_none_cvalue, '%s: %s' % (error_pos, special_none_cvalue) # not currently supported
call_code = "%s(%s, %s, %s)" % ( call_code = "%s(%s, %s, %s)" % (
from_py_function or self.from_py_function, from_py_function or self.from_py_function,
source_code, result_code, self.size) source_code, result_code, self.size)
......
...@@ -628,6 +628,28 @@ class __Pyx_FakeReference { ...@@ -628,6 +628,28 @@ class __Pyx_FakeReference {
#define __Pyx_IS_TYPE(ob, type) (((const PyObject*)ob)->ob_type == (type)) #define __Pyx_IS_TYPE(ob, type) (((const PyObject*)ob)->ob_type == (type))
#endif #endif
/* Identity and singleton checks mirroring CPython's Py_Is(), Py_IsNone(), Py_IsTrue()
 * and Py_IsFalse(). The CPython versions are used when PY_VERSION_HEX is at least
 * 0x030A00B1 (3.10.0b1) or the corresponding name is already defined as a macro;
 * otherwise a plain pointer comparison is used as the fallback. */
#if PY_VERSION_HEX >= 0x030A00B1 || defined(Py_Is)
#define __Pyx_Py_Is(x, y) Py_Is(x, y)
#else
#define __Pyx_Py_Is(x, y) ((x) == (y))
#endif
/* True if 'ob' is the None singleton. */
#if PY_VERSION_HEX >= 0x030A00B1 || defined(Py_IsNone)
#define __Pyx_Py_IsNone(ob) Py_IsNone(ob)
#else
#define __Pyx_Py_IsNone(ob) __Pyx_Py_Is((ob), Py_None)
#endif
/* True if 'ob' is the True singleton. */
#if PY_VERSION_HEX >= 0x030A00B1 || defined(Py_IsTrue)
#define __Pyx_Py_IsTrue(ob) Py_IsTrue(ob)
#else
#define __Pyx_Py_IsTrue(ob) __Pyx_Py_Is((ob), Py_True)
#endif
/* True if 'ob' is the False singleton. */
#if PY_VERSION_HEX >= 0x030A00B1 || defined(Py_IsFalse)
#define __Pyx_Py_IsFalse(ob) Py_IsFalse(ob)
#else
#define __Pyx_Py_IsFalse(ob) __Pyx_Py_Is((ob), Py_False)
#endif
/* Maps None to NULL and passes any other object through unchanged.
 * 'obj' appears twice in the expansion, so the argument must be free of side effects. */
#define __Pyx_NoneAsNull(obj) (__Pyx_Py_IsNone(obj) ? NULL : (obj))
#ifndef Py_TPFLAGS_CHECKTYPES #ifndef Py_TPFLAGS_CHECKTYPES
#define Py_TPFLAGS_CHECKTYPES 0 #define Py_TPFLAGS_CHECKTYPES 0
#endif #endif
......
...@@ -59,6 +59,24 @@ def split_sep(unicode s, sep): ...@@ -59,6 +59,24 @@ def split_sep(unicode s, sep):
ab jd ab jd
sdflk as sa sdflk as sa
sadas asdas fsdf\x20 sadas asdas fsdf\x20
>>> print_all( text.split(None) )
ab
jd
sdflk
as
sa
sadas
asdas
fsdf
>>> print_all( split_sep(text, None) )
ab
jd
sdflk
as
sa
sadas
asdas
fsdf
""" """
return s.split(sep) return s.split(sep)
...@@ -76,6 +94,14 @@ def split_sep_max(unicode s, sep, max): ...@@ -76,6 +94,14 @@ def split_sep_max(unicode s, sep, max):
>>> print_all( split_sep_max(text, sep, 1) ) >>> print_all( split_sep_max(text, sep, 1) )
ab jd ab jd
sdflk as sa sadas asdas fsdf\x20 sdflk as sa sadas asdas fsdf\x20
>>> print_all( text.split(None, 2) )
ab
jd
sdflk as sa sadas asdas fsdf\x20
>>> print_all( split_sep_max(text, None, 2) )
ab
jd
sdflk as sa sadas asdas fsdf\x20
""" """
return s.split(sep, max) return s.split(sep, max)
...@@ -92,6 +118,12 @@ def split_sep_max_int(unicode s, sep): ...@@ -92,6 +118,12 @@ def split_sep_max_int(unicode s, sep):
>>> print_all( split_sep_max_int(text, sep) ) >>> print_all( split_sep_max_int(text, sep) )
ab jd ab jd
sdflk as sa sadas asdas fsdf\x20 sdflk as sa sadas asdas fsdf\x20
>>> print_all( text.split(None, 1) )
ab
jd sdflk as sa sadas asdas fsdf\x20
>>> print_all( split_sep_max_int(text, None) )
ab
jd sdflk as sa sadas asdas fsdf\x20
""" """
return s.split(sep, 1) return s.split(sep, 1)
...@@ -337,6 +369,11 @@ def startswith_start_end(unicode s, sub, start, end): ...@@ -337,6 +369,11 @@ def startswith_start_end(unicode s, sub, start, end):
False False
>>> startswith_start_end(text, 'b X', 1, 5) >>> startswith_start_end(text, 'b X', 1, 5)
'NO MATCH' 'NO MATCH'
>>> text.startswith('ab ', None, None)
True
>>> startswith_start_end(text, 'ab ', None, None)
'MATCH'
""" """
if s.startswith(sub, start, end): if s.startswith(sub, start, end):
return 'MATCH' return 'MATCH'
...@@ -407,6 +444,11 @@ def endswith_start_end(unicode s, sub, start, end): ...@@ -407,6 +444,11 @@ def endswith_start_end(unicode s, sub, start, end):
True True
>>> endswith_start_end(text, ('fsdf ', 'fsdf X'), 10, len(text)-1) >>> endswith_start_end(text, ('fsdf ', 'fsdf X'), 10, len(text)-1)
'NO MATCH' 'NO MATCH'
>>> text.endswith('fsdf ', None, None)
True
>>> endswith_start_end(text, 'fsdf ', None, None)
'MATCH'
""" """
if s.endswith(sub, start, end): if s.endswith(sub, start, end):
return 'MATCH' return 'MATCH'
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment