Commit 3f026222 authored by scoder's avatar scoder

Merge pull request #408 from insertinterestingnamehere/operator_exceptions

Fix exception handling for overloaded operators.
parents a07634b3 b20ed656
...@@ -701,6 +701,8 @@ class FunctionState(object): ...@@ -701,6 +701,8 @@ class FunctionState(object):
""" """
if type.is_const and not type.is_reference: if type.is_const and not type.is_reference:
type = type.const_base_type type = type.const_base_type
elif type.is_reference and not type.is_fake_reference:
type = type.ref_base_type
if not type.is_pyobject and not type.is_memoryviewslice: if not type.is_pyobject and not type.is_memoryviewslice:
# Make manage_ref canonical, so that manage_ref will always mean # Make manage_ref canonical, so that manage_ref will always mean
# a decref is needed. # a decref is needed.
......
...@@ -182,6 +182,58 @@ def infer_sequence_item_type(env, seq_node, index_node=None, seq_type=None): ...@@ -182,6 +182,58 @@ def infer_sequence_item_type(env, seq_node, index_node=None, seq_type=None):
return item_types.pop() return item_types.pop()
return None return None
def get_exception_handler(exception_value):
if exception_value is None:
return "__Pyx_CppExn2PyErr();"
elif exception_value.type.is_pyobject:
return 'try { throw; } catch(const std::exception& exn) { PyErr_SetString(%s, exn.what()); } catch(...) { PyErr_SetNone(%s); }' % (
exception_value.entry.cname,
exception_value.entry.cname)
else:
return '%s(); if (!PyErr_Occurred()) PyErr_SetString(PyExc_RuntimeError , "Error converting c++ exception.");' % exception_value.entry.cname
def translate_cpp_exception(code, pos, inside, exception_value, nogil):
raise_py_exception = get_exception_handler(exception_value)
code.putln("try {")
code.putln("%s" % inside)
code.putln("} catch(...) {")
if nogil:
code.put_ensure_gil(declare_gilstate=True)
code.putln(raise_py_exception)
if nogil:
code.put_release_ensured_gil()
code.putln(code.error_goto(pos))
code.putln("}")
# Used to handle the case where an lvalue expression and an overloaded assignment
# both have an exception declaration.
def translate_double_cpp_exception(code, pos, lhs_type, lhs_code, rhs_code,
lhs_exc_val, assign_exc_val, nogil):
handle_lhs_exc = get_exception_handler(lhs_exc_val)
handle_assignment_exc = get_exception_handler(assign_exc_val)
code.putln("try {")
code.putln(lhs_type.declaration_code("__pyx_local_lvalue = %s;" % lhs_code))
code.putln("try {")
code.putln("__pyx_local_lvalue = %s;" % rhs_code)
# Catch any exception from the overloaded assignment.
code.putln("} catch(...) {")
if nogil:
code.put_ensure_gil(declare_gilstate=True)
code.putln(handle_assignment_exc)
if nogil:
code.put_release_ensured_gil()
code.putln(code.error_goto(pos))
code.putln("}")
# Catch any exception from evaluating lhs.
code.putln("} catch(...) {")
if nogil:
code.put_ensure_gil(declare_gilstate=True)
code.putln(handle_lhs_exc)
if nogil:
code.put_release_ensured_gil()
code.putln(code.error_goto(pos))
code.putln('}')
class ExprNode(Node): class ExprNode(Node):
# subexprs [string] Class var holding names of subexpr node attrs # subexprs [string] Class var holding names of subexpr node attrs
...@@ -700,7 +752,8 @@ class ExprNode(Node): ...@@ -700,7 +752,8 @@ class ExprNode(Node):
else: else:
self.generate_subexpr_disposal_code(code) self.generate_subexpr_disposal_code(code)
def generate_assignment_code(self, rhs, code, overloaded_assignment=False): def generate_assignment_code(self, rhs, code, overloaded_assignment=False,
exception_check=None, exception_value=None):
# Stub method for nodes which are not legal as # Stub method for nodes which are not legal as
# the LHS of an assignment. An error will have # the LHS of an assignment. An error will have
# been reported earlier. # been reported earlier.
...@@ -2037,7 +2090,8 @@ class NameNode(AtomicExprNode): ...@@ -2037,7 +2090,8 @@ class NameNode(AtomicExprNode):
if null_code and raise_unbound and (entry.type.is_pyobject or memslice_check): if null_code and raise_unbound and (entry.type.is_pyobject or memslice_check):
code.put_error_if_unbound(self.pos, entry, self.in_nogil_context) code.put_error_if_unbound(self.pos, entry, self.in_nogil_context)
def generate_assignment_code(self, rhs, code, overloaded_assignment=False): def generate_assignment_code(self, rhs, code, overloaded_assignment=False,
exception_check=None, exception_value=None):
#print "NameNode.generate_assignment_code:", self.name ### #print "NameNode.generate_assignment_code:", self.name ###
entry = self.entry entry = self.entry
if entry is None: if entry is None:
...@@ -2127,8 +2181,15 @@ class NameNode(AtomicExprNode): ...@@ -2127,8 +2181,15 @@ class NameNode(AtomicExprNode):
code.put_giveref(rhs.py_result()) code.put_giveref(rhs.py_result())
if not self.type.is_memoryviewslice: if not self.type.is_memoryviewslice:
if not assigned: if not assigned:
result = rhs.result() if overloaded_assignment else rhs.result_as(self.ctype()) if overloaded_assignment:
code.putln('%s = %s;' % (self.result(), result)) result = rhs.result()
if exception_check == '+':
translate_cpp_exception(code, self.pos, '%s = %s;' % (self.result(), result), exception_value, self.in_nogil_context)
else:
code.putln('%s = %s;' % (self.result(), result))
else:
result = rhs.result_as(self.ctype())
code.putln('%s = %s;' % (self.result(), result))
if debug_disposal_code: if debug_disposal_code:
print("NameNode.generate_assignment_code:") print("NameNode.generate_assignment_code:")
print("...generating post-assignment code for %s" % rhs) print("...generating post-assignment code for %s" % rhs)
...@@ -3261,6 +3322,13 @@ class IndexNode(_IndexingBaseNode): ...@@ -3261,6 +3322,13 @@ class IndexNode(_IndexingBaseNode):
func_type = function.type func_type = function.type
if func_type.is_ptr: if func_type.is_ptr:
func_type = func_type.base_type func_type = func_type.base_type
self.exception_check = func_type.exception_check
self.exception_value = func_type.exception_value
if self.exception_check:
if not setting:
self.is_temp = True
if self.exception_value is None:
env.use_utility_code(UtilityCode.load_cached("CppExceptionConversion", "CppSupport.cpp"))
self.index = self.index.coerce_to(func_type.args[0].type, env) self.index = self.index.coerce_to(func_type.args[0].type, env)
self.type = func_type.return_type self.type = func_type.return_type
if setting and not func_type.return_type.is_reference: if setting and not func_type.return_type.is_reference:
...@@ -3530,7 +3598,7 @@ class IndexNode(_IndexingBaseNode): ...@@ -3530,7 +3598,7 @@ class IndexNode(_IndexingBaseNode):
error_value = '-1' error_value = '-1'
code.globalstate.use_utility_code( code.globalstate.use_utility_code(
UtilityCode.load_cached("GetItemIntByteArray", "StringTools.c")) UtilityCode.load_cached("GetItemIntByteArray", "StringTools.c"))
else: elif not (self.base.type.is_cpp_class and self.exception_check):
assert False, "unexpected type %s and base type %s for indexing" % ( assert False, "unexpected type %s and base type %s for indexing" % (
self.type, self.base.type) self.type, self.base.type)
...@@ -3539,16 +3607,22 @@ class IndexNode(_IndexingBaseNode): ...@@ -3539,16 +3607,22 @@ class IndexNode(_IndexingBaseNode):
else: else:
index_code = self.index.py_result() index_code = self.index.py_result()
code.putln( if self.base.type.is_cpp_class and self.exception_check:
"%s = %s(%s, %s%s); if (unlikely(%s == %s)) %s;" % ( translate_cpp_exception(code, self.pos,
self.result(), "%s = %s[%s];" % (self.result(), self.base.result(),
function, self.index.result()),
self.base.py_result(), self.exception_value, self.in_nogil_context)
index_code, else:
self.extra_index_params(code), code.putln(
self.result(), "%s = %s(%s, %s%s); if (unlikely(%s == %s)) %s;" % (
error_value, self.result(),
code.error_goto(self.pos))) function,
self.base.py_result(),
index_code,
self.extra_index_params(code),
self.result(),
error_value,
code.error_goto(self.pos)))
if self.type.is_pyobject: if self.type.is_pyobject:
code.put_gotref(self.py_result()) code.put_gotref(self.py_result())
...@@ -3585,7 +3659,8 @@ class IndexNode(_IndexingBaseNode): ...@@ -3585,7 +3659,8 @@ class IndexNode(_IndexingBaseNode):
self.extra_index_params(code), self.extra_index_params(code),
code.error_goto(self.pos))) code.error_goto(self.pos)))
def generate_assignment_code(self, rhs, code, overloaded_assignment=False): def generate_assignment_code(self, rhs, code, overloaded_assignment=False,
exception_check=None, exception_value=None):
self.generate_subexpr_evaluation_code(code) self.generate_subexpr_evaluation_code(code)
if self.type.is_pyobject: if self.type.is_pyobject:
...@@ -3593,6 +3668,21 @@ class IndexNode(_IndexingBaseNode): ...@@ -3593,6 +3668,21 @@ class IndexNode(_IndexingBaseNode):
elif self.base.type is bytearray_type: elif self.base.type is bytearray_type:
value_code = self._check_byte_value(code, rhs) value_code = self._check_byte_value(code, rhs)
self.generate_setitem_code(value_code, code) self.generate_setitem_code(value_code, code)
elif self.base.type.is_cpp_class and self.exception_check and self.exception_check == '+':
if overloaded_assignment and exception_check and \
self.exception_value != exception_value:
# Handle the case that both the index operator and the assignment
# operator have a c++ exception handler and they are not the same.
translate_double_cpp_exception(code, self.pos, self.type,
self.result(), rhs.result(), self.exception_value,
exception_value, self.in_nogil_context)
else:
# Handle the case that only the index operator has a
# c++ exception handler, or that
# both exception handlers are the same.
translate_cpp_exception(code, self.pos,
"%s = %s;" % (self.result(), rhs.result()),
self.exception_value, self.in_nogil_context)
else: else:
code.putln( code.putln(
"%s = %s;" % (self.result(), rhs.result())) "%s = %s;" % (self.result(), rhs.result()))
...@@ -4403,7 +4493,8 @@ class SliceIndexNode(ExprNode): ...@@ -4403,7 +4493,8 @@ class SliceIndexNode(ExprNode):
code.error_goto_if_null(result, self.pos))) code.error_goto_if_null(result, self.pos)))
code.put_gotref(self.py_result()) code.put_gotref(self.py_result())
def generate_assignment_code(self, rhs, code, overloaded_assignment=False): def generate_assignment_code(self, rhs, code, overloaded_assignment=False,
exception_check=None, exception_value=None):
self.generate_subexpr_evaluation_code(code) self.generate_subexpr_evaluation_code(code)
if self.type.is_pyobject: if self.type.is_pyobject:
code.globalstate.use_utility_code(self.set_slice_utility_code) code.globalstate.use_utility_code(self.set_slice_utility_code)
...@@ -4823,6 +4914,8 @@ class SimpleCallNode(CallNode): ...@@ -4823,6 +4914,8 @@ class SimpleCallNode(CallNode):
else: else:
self.args = [ arg.analyse_types(env) for arg in self.args ] self.args = [ arg.analyse_types(env) for arg in self.args ]
self.analyse_c_function_call(env) self.analyse_c_function_call(env)
if func_type.exception_check == '+':
self.is_temp = True
return self return self
def function_type(self): def function_type(self):
...@@ -5149,24 +5242,8 @@ class SimpleCallNode(CallNode): ...@@ -5149,24 +5242,8 @@ class SimpleCallNode(CallNode):
else: else:
lhs = "" lhs = ""
if func_type.exception_check == '+': if func_type.exception_check == '+':
if func_type.exception_value is None: translate_cpp_exception(code, self.pos, '%s%s;' % (lhs, rhs),
raise_py_exception = "__Pyx_CppExn2PyErr();" func_type.exception_value, self.nogil)
elif func_type.exception_value.type.is_pyobject:
raise_py_exception = 'try { throw; } catch(const std::exception& exn) { PyErr_SetString(%s, exn.what()); } catch(...) { PyErr_SetNone(%s); }' % (
func_type.exception_value.entry.cname,
func_type.exception_value.entry.cname)
else:
raise_py_exception = '%s(); if (!PyErr_Occurred()) PyErr_SetString(PyExc_RuntimeError , "Error converting c++ exception.");' % func_type.exception_value.entry.cname
code.putln("try {")
code.putln("%s%s;" % (lhs, rhs))
code.putln("} catch(...) {")
if self.nogil:
code.put_ensure_gil(declare_gilstate=True)
code.putln(raise_py_exception)
if self.nogil:
code.put_release_ensured_gil()
code.putln(code.error_goto(self.pos))
code.putln("}")
else: else:
if exc_checks: if exc_checks:
goto_error = code.error_goto_if(" && ".join(exc_checks), self.pos) goto_error = code.error_goto_if(" && ".join(exc_checks), self.pos)
...@@ -6371,7 +6448,8 @@ class AttributeNode(ExprNode): ...@@ -6371,7 +6448,8 @@ class AttributeNode(ExprNode):
else: else:
ExprNode.generate_disposal_code(self, code) ExprNode.generate_disposal_code(self, code)
def generate_assignment_code(self, rhs, code, overloaded_assignment=False): def generate_assignment_code(self, rhs, code, overloaded_assignment=False,
exception_check=None, exception_value=None):
self.obj.generate_evaluation_code(code) self.obj.generate_evaluation_code(code)
if self.is_py_attr: if self.is_py_attr:
code.globalstate.use_utility_code( code.globalstate.use_utility_code(
...@@ -6708,7 +6786,8 @@ class SequenceNode(ExprNode): ...@@ -6708,7 +6786,8 @@ class SequenceNode(ExprNode):
if self.mult_factor: if self.mult_factor:
self.mult_factor.generate_disposal_code(code) self.mult_factor.generate_disposal_code(code)
def generate_assignment_code(self, rhs, code, overloaded_assignment=False): def generate_assignment_code(self, rhs, code, overloaded_assignment=False,
exception_check=None, exception_value=None):
if self.starred_assignment: if self.starred_assignment:
self.generate_starred_assignment_code(rhs, code) self.generate_starred_assignment_code(rhs, code)
else: else:
...@@ -9187,6 +9266,13 @@ class UnopNode(ExprNode): ...@@ -9187,6 +9266,13 @@ class UnopNode(ExprNode):
def generate_result_code(self, code): def generate_result_code(self, code):
if self.operand.type.is_pyobject: if self.operand.type.is_pyobject:
self.generate_py_operation_code(code) self.generate_py_operation_code(code)
elif self.is_temp:
if self.is_cpp_operation() and self.exception_check == '+':
translate_cpp_exception(code, self.pos,
"%s = %s %s;" % (self.result(), self.operator, self.operand.result()),
self.exception_value, self.in_nogil_context)
else:
code.putln("%s = %s %s;" % (self.result(), self.operator, self.operand.result()))
def generate_py_operation_code(self, code): def generate_py_operation_code(self, code):
function = self.py_operation_function(code) function = self.py_operation_function(code)
...@@ -9204,9 +9290,23 @@ class UnopNode(ExprNode): ...@@ -9204,9 +9290,23 @@ class UnopNode(ExprNode):
(self.operator, self.operand.type)) (self.operator, self.operand.type))
self.type = PyrexTypes.error_type self.type = PyrexTypes.error_type
def analyse_cpp_operation(self, env): def analyse_cpp_operation(self, env, overload_check=True):
entry = env.lookup_operator(self.operator, [self.operand])
if overload_check and not entry:
self.type_error()
return
if entry:
self.exception_check = entry.type.exception_check
self.exception_value = entry.type.exception_value
if self.exception_check == '+':
self.is_temp = True
if self.exception_value is None:
env.use_utility_code(UtilityCode.load_cached("CppExceptionConversion", "CppSupport.cpp"))
else:
self.exception_check = ''
self.exception_value = ''
cpp_type = self.operand.type.find_cpp_operation_type(self.operator) cpp_type = self.operand.type.find_cpp_operation_type(self.operator)
if cpp_type is None: if overload_check and cpp_type is None:
error(self.pos, "'%s' operator not defined for %s" % ( error(self.pos, "'%s' operator not defined for %s" % (
self.operator, type)) self.operator, type))
self.type_error() self.type_error()
...@@ -9239,12 +9339,7 @@ class NotNode(UnopNode): ...@@ -9239,12 +9339,7 @@ class NotNode(UnopNode):
self.operand = self.operand.analyse_types(env) self.operand = self.operand.analyse_types(env)
operand_type = self.operand.type operand_type = self.operand.type
if operand_type.is_cpp_class: if operand_type.is_cpp_class:
cpp_type = operand_type.find_cpp_operation_type(self.operator) self.analyse_cpp_operation(env)
if not cpp_type:
error(self.pos, "'!' operator not defined for %s" % operand_type)
self.type = PyrexTypes.error_type
return
self.type = cpp_type
else: else:
self.operand = self.operand.coerce_to_boolean(env) self.operand = self.operand.coerce_to_boolean(env)
return self return self
...@@ -9252,9 +9347,6 @@ class NotNode(UnopNode): ...@@ -9252,9 +9347,6 @@ class NotNode(UnopNode):
def calculate_result_code(self): def calculate_result_code(self):
return "(!%s)" % self.operand.result() return "(!%s)" % self.operand.result()
def generate_result_code(self, code):
pass
class UnaryPlusNode(UnopNode): class UnaryPlusNode(UnopNode):
# unary '+' operator # unary '+' operator
...@@ -9385,10 +9477,7 @@ class AmpersandNode(CUnopNode): ...@@ -9385,10 +9477,7 @@ class AmpersandNode(CUnopNode):
self.operand = self.operand.analyse_types(env) self.operand = self.operand.analyse_types(env)
argtype = self.operand.type argtype = self.operand.type
if argtype.is_cpp_class: if argtype.is_cpp_class:
cpp_type = argtype.find_cpp_operation_type(self.operator) self.analyse_cpp_operation(env, overload_check=False)
if cpp_type is not None:
self.type = cpp_type
return self
if not (argtype.is_cfunction or argtype.is_reference or self.operand.is_addressable()): if not (argtype.is_cfunction or argtype.is_reference or self.operand.is_addressable()):
if argtype.is_memoryviewslice: if argtype.is_memoryviewslice:
self.error("Cannot take address of memoryview slice") self.error("Cannot take address of memoryview slice")
...@@ -9398,7 +9487,8 @@ class AmpersandNode(CUnopNode): ...@@ -9398,7 +9487,8 @@ class AmpersandNode(CUnopNode):
if argtype.is_pyobject: if argtype.is_pyobject:
self.error("Cannot take address of Python variable") self.error("Cannot take address of Python variable")
return self return self
self.type = PyrexTypes.c_ptr_type(argtype) if not argtype.is_cpp_class or not self.type:
self.type = PyrexTypes.c_ptr_type(argtype)
return self return self
def check_const(self): def check_const(self):
...@@ -9413,7 +9503,10 @@ class AmpersandNode(CUnopNode): ...@@ -9413,7 +9503,10 @@ class AmpersandNode(CUnopNode):
return "(&%s)" % self.operand.result() return "(&%s)" % self.operand.result()
def generate_result_code(self, code): def generate_result_code(self, code):
pass if (self.operand.type.is_cpp_class and self.exception_check == '+'):
translate_cpp_exception(code, self.pos,
"%s = %s %s;" % (self.result(), self.operator, self.operand.result()),
self.exception_value, self.in_nogil_context)
unop_node_classes = { unop_node_classes = {
...@@ -10037,6 +10130,14 @@ class BinopNode(ExprNode): ...@@ -10037,6 +10130,14 @@ class BinopNode(ExprNode):
self.type_error() self.type_error()
return return
func_type = entry.type func_type = entry.type
self.exception_check = func_type.exception_check
self.exception_value = func_type.exception_value
if self.exception_check == '+':
# Used by NumBinopNodes to break up expressions involving multiple
# operators so that exceptions can be handled properly.
self.is_temp = 1
if self.exception_value is None:
env.use_utility_code(UtilityCode.load_cached("CppExceptionConversion", "CppSupport.cpp"))
if func_type.is_ptr: if func_type.is_ptr:
func_type = func_type.base_type func_type = func_type.base_type
if len(func_type.args) == 1: if len(func_type.args) == 1:
...@@ -10103,7 +10204,14 @@ class BinopNode(ExprNode): ...@@ -10103,7 +10204,14 @@ class BinopNode(ExprNode):
code.error_goto_if_null(self.result(), self.pos))) code.error_goto_if_null(self.result(), self.pos)))
code.put_gotref(self.py_result()) code.put_gotref(self.py_result())
elif self.is_temp: elif self.is_temp:
code.putln("%s = %s;" % (self.result(), self.calculate_result_code())) # C++ overloaded operators with exception values are currently all
# handled through temporaries.
if self.is_cpp_operation() and self.exception_check == '+':
translate_cpp_exception(code, self.pos,
"%s = %s;" % (self.result(), self.calculate_result_code()),
self.exception_value, self.in_nogil_context)
else:
code.putln("%s = %s;" % (self.result(), self.calculate_result_code()))
def type_error(self): def type_error(self):
if not (self.operand1.type.is_error if not (self.operand1.type.is_error
...@@ -10241,7 +10349,7 @@ class NumBinopNode(BinopNode): ...@@ -10241,7 +10349,7 @@ class NumBinopNode(BinopNode):
self.operand1.result(), self.operand1.result(),
self.operand2.result(), self.operand2.result(),
self.overflow_bit_node.overflow_bit) self.overflow_bit_node.overflow_bit)
elif self.infix: elif self.type.is_cpp_class or self.infix:
return "(%s %s %s)" % ( return "(%s %s %s)" % (
self.operand1.result(), self.operand1.result(),
self.operator, self.operator,
...@@ -11364,12 +11472,15 @@ class CmpNode(object): ...@@ -11364,12 +11472,15 @@ class CmpNode(object):
common_type = type1 common_type = type1
code1 = operand1.result_as(common_type) code1 = operand1.result_as(common_type)
code2 = operand2.result_as(common_type) code2 = operand2.result_as(common_type)
code.putln("%s = %s(%s %s %s);" % ( statement = "%s = %s(%s %s %s);" % (
result_code, result_code,
coerce_result, coerce_result,
code1, code1,
self.c_operator(op), self.c_operator(op),
code2)) code2)
if self.is_cpp_comparison() and self.exception_check == '+':
translate_cpp_exception(code, self.pos, statement, self.exception_value, self.in_nogil_context)
code.putln(statement)
def c_operator(self, op): def c_operator(self, op):
if op == 'is': if op == 'is':
...@@ -11507,6 +11618,12 @@ class PrimaryCmpNode(ExprNode, CmpNode): ...@@ -11507,6 +11618,12 @@ class PrimaryCmpNode(ExprNode, CmpNode):
func_type = entry.type func_type = entry.type
if func_type.is_ptr: if func_type.is_ptr:
func_type = func_type.base_type func_type = func_type.base_type
self.exception_check = func_type.exception_check
self.exception_value = func_type.exception_value
if self.exception_check == '+':
self.is_temp = True
if self.exception_value is None:
env.use_utility_code(UtilityCode.load_cached("CppExceptionConversion", "CppSupport.cpp"))
if len(func_type.args) == 1: if len(func_type.args) == 1:
self.operand2 = self.operand2.coerce_to(func_type.args[0].type, env) self.operand2 = self.operand2.coerce_to(func_type.args[0].type, env)
else: else:
...@@ -11665,6 +11782,10 @@ class CascadedCmpNode(Node, CmpNode): ...@@ -11665,6 +11782,10 @@ class CascadedCmpNode(Node, CmpNode):
def has_python_operands(self): def has_python_operands(self):
return self.operand2.type.is_pyobject return self.operand2.type.is_pyobject
def is_cpp_comparison(self):
# cascaded comparisons aren't currently implemented for c++ classes.
return False
def optimise_comparison(self, operand1, env, result_is_bool=False): def optimise_comparison(self, operand1, env, result_is_bool=False):
if self.find_special_bool_compare_function(env, operand1, result_is_bool): if self.find_special_bool_compare_function(env, operand1, result_is_bool):
self.is_pycmp = False self.is_pycmp = False
......
...@@ -4797,6 +4797,8 @@ class SingleAssignmentNode(AssignmentNode): ...@@ -4797,6 +4797,8 @@ class SingleAssignmentNode(AssignmentNode):
# rhs ExprNode Right hand side # rhs ExprNode Right hand side
# first bool Is this guaranteed the first assignment to lhs? # first bool Is this guaranteed the first assignment to lhs?
# is_overloaded_assignment bool Is this assignment done via an overloaded operator= # is_overloaded_assignment bool Is this assignment done via an overloaded operator=
# exception_check
# exception_value
child_attrs = ["lhs", "rhs"] child_attrs = ["lhs", "rhs"]
first = False first = False
...@@ -4910,6 +4912,10 @@ class SingleAssignmentNode(AssignmentNode): ...@@ -4910,6 +4912,10 @@ class SingleAssignmentNode(AssignmentNode):
if op: if op:
rhs = self.rhs rhs = self.rhs
self.is_overloaded_assignment = True self.is_overloaded_assignment = True
self.exception_check = op.type.exception_check
self.exception_value = op.type.exception_value
if self.exception_check == '+' and self.exception_value is None:
env.use_utility_code(UtilityCode.load_cached("CppExceptionConversion", "CppSupport.cpp"))
else: else:
rhs = self.rhs.coerce_to(self.lhs.type, env) rhs = self.rhs.coerce_to(self.lhs.type, env)
else: else:
...@@ -5062,8 +5068,15 @@ class SingleAssignmentNode(AssignmentNode): ...@@ -5062,8 +5068,15 @@ class SingleAssignmentNode(AssignmentNode):
self.rhs.generate_evaluation_code(code) self.rhs.generate_evaluation_code(code)
def generate_assignment_code(self, code, overloaded_assignment=False): def generate_assignment_code(self, code, overloaded_assignment=False):
self.lhs.generate_assignment_code( if self.is_overloaded_assignment:
self.rhs, code, overloaded_assignment=self.is_overloaded_assignment) self.lhs.generate_assignment_code(
self.rhs,
code,
overloaded_assignment=self.is_overloaded_assignment,
exception_check=self.exception_check,
exception_value=self.exception_value)
else:
self.lhs.generate_assignment_code(self.rhs, code)
def generate_function_definitions(self, env, code): def generate_function_definitions(self, env, code):
self.rhs.generate_function_definitions(env, code) self.rhs.generate_function_definitions(env, code)
......
# mode: run
# tag: cpp, werror
from cython.operator import (preincrement, predecrement,
postincrement, postdecrement)
from libcpp cimport bool
cdef extern from "cpp_operator_exc_handling_helper.hpp" nogil:
cppclass wrapped_int:
long long val
wrapped_int()
wrapped_int(long long val)
wrapped_int(long long v1, long long v2) except +
wrapped_int operator+(wrapped_int &other) except +ValueError
wrapped_int operator+() except +RuntimeError
wrapped_int operator-(wrapped_int &other) except +
wrapped_int operator-() except +
wrapped_int operator*(wrapped_int &other) except +OverflowError
wrapped_int operator/(wrapped_int &other) except +
wrapped_int operator%(wrapped_int &other) except +
long long operator^(wrapped_int &other) except +
long long operator&(wrapped_int &other) except +
long long operator|(wrapped_int &other) except +
wrapped_int operator~() except +
long long operator&() except +
long long operator==(wrapped_int &other) except +
long long operator!=(wrapped_int &other) except +
long long operator<(wrapped_int &other) except +
long long operator<=(wrapped_int &other) except +
long long operator>(wrapped_int &other) except +
long long operator>=(wrapped_int &other) except +
wrapped_int operator<<(long long shift) except +
wrapped_int operator>>(long long shift) except +
wrapped_int &operator++() except +
wrapped_int &operator--() except +
wrapped_int operator++(int) except +
wrapped_int operator--(int) except +
wrapped_int operator!() except +
bool operator bool() except +
wrapped_int &operator[](long long &index) except +IndexError
long long &operator()() except +AttributeError
wrapped_int &operator=(const wrapped_int &other) except +ArithmeticError
wrapped_int &operator=(const long long &vao) except +
def assert_raised(f, *args, **kwargs):
err = kwargs.get('err', None)
if err is None:
try:
f(*args)
raised = False
except:
raised = True
else:
try:
f(*args)
raised = False
except err:
raised = True
assert raised
def initialization(long long a, long long b):
cdef wrapped_int w = wrapped_int(a, b)
return w.val
def addition(long long a, long long b):
cdef wrapped_int wa = wrapped_int(a)
cdef wrapped_int wb = wrapped_int(b)
return (wa + wb).val
def subtraction(long long a, long long b):
cdef wrapped_int wa = wrapped_int(a)
cdef wrapped_int wb = wrapped_int(b)
return (wa - wb).val
def multiplication(long long a, long long b):
cdef wrapped_int wa = wrapped_int(a)
cdef wrapped_int wb = wrapped_int(b)
return (wa * wb).val
def division(long long a, long long b):
cdef wrapped_int wa = wrapped_int(a)
cdef wrapped_int wb = wrapped_int(b)
return (wa / wb).val
def mod(long long a, long long b):
cdef wrapped_int wa = wrapped_int(a)
cdef wrapped_int wb = wrapped_int(b)
return (wa % wb).val
def minus(long long a):
cdef wrapped_int wa = wrapped_int(a)
return (-wa).val
def plus(long long a):
cdef wrapped_int wa = wrapped_int(a)
return (+wa).val
def xor(long long a, long long b):
cdef wrapped_int wa = wrapped_int(a)
cdef wrapped_int wb = wrapped_int(b)
return wa ^ wb
def bitwise_and(long long a, long long b):
cdef wrapped_int wa = wrapped_int(a)
cdef wrapped_int wb = wrapped_int(b)
return wa & wb
def bitwise_or(long long a, long long b):
cdef wrapped_int wa = wrapped_int(a)
cdef wrapped_int wb = wrapped_int(b)
return wa | wb
def bitwise_not(long long a):
cdef wrapped_int wa = wrapped_int(a)
return (~a).val
def address(long long a):
cdef wrapped_int wa = wrapped_int(a)
return &wa
def iseq(long long a, long long b):
cdef wrapped_int wa = wrapped_int(a)
cdef wrapped_int wb = wrapped_int(b)
return wa == wb
def neq(long long a, long long b):
cdef wrapped_int wa = wrapped_int(a)
cdef wrapped_int wb = wrapped_int(b)
return wa != wb
def less(long long a, long long b):
cdef wrapped_int wa = wrapped_int(a)
cdef wrapped_int wb = wrapped_int(b)
return wa < wb
def leq(long long a, long long b):
cdef wrapped_int wa = wrapped_int(a)
cdef wrapped_int wb = wrapped_int(b)
return wa <= wb
def greater(long long a, long long b):
cdef wrapped_int wa = wrapped_int(a)
cdef wrapped_int wb = wrapped_int(b)
return wa > wb
def geq(long long a, long long b):
cdef wrapped_int wa = wrapped_int(a)
cdef wrapped_int wb = wrapped_int(b)
return wa < wb
def left_shift(long long a, long long b):
cdef wrapped_int wa = wrapped_int(a)
return (wa << b).val
def right_shift(long long a, long long b):
cdef wrapped_int wa = wrapped_int(a)
return (wa >> b).val
def cpp_preincrement(long long a):
cdef wrapped_int wa = wrapped_int(a)
return preincrement(wa).val
def cpp_predecrement(long long a):
cdef wrapped_int wa = wrapped_int(a)
return predecrement(wa).val
def cpp_postincrement(long long a):
cdef wrapped_int wa = wrapped_int(a)
return postincrement(wa).val
def cpp_postdecrement(long long a):
cdef wrapped_int wa = wrapped_int(a)
return postdecrement(wa).val
def negate(long long a):
cdef wrapped_int wa = wrapped_int(a)
return (not wa).val
def bool_cast(long long a):
cdef wrapped_int wa = wrapped_int(a)
if wa:
return True
else:
return False
def index(long long a, long long b):
cdef wrapped_int wa = wrapped_int(a)
return wa[b].val
def assign_index(long long a, long long b, long long c):
cdef wrapped_int wa = wrapped_int(a)
cdef wrapped_int wb = wrapped_int(b)
wb[c] = wa
return wb.val
def call(long long a):
cdef wrapped_int wa = wrapped_int(a)
return wa()
def assign_same(long long a, long long b):
cdef wrapped_int wa = wrapped_int(a)
cdef wrapped_int wb = wrapped_int(b)
wa = wb
return wa.val
def assign_different(long long a, long long b):
cdef wrapped_int wa = wrapped_int(a)
wa = b
return wa.val
def cascaded_assign(long long a, long long b, long long c):
cdef wrapped_int wa = wrapped_int(a)
a = b = c
return a.val
def separate_exceptions(long long a, long long b, long long c, long long d, long long e):
cdef:
wrapped_int wa = wrapped_int(a)
wrapped_int wc = wrapped_int(c)
wrapped_int wd = wrapped_int(d)
wrapped_int we = wrapped_int(e)
wa[b] = (+wc) * wd + we
return a.val
def call_temp_separation(long long a, long long b, long long c):
cdef:
wrapped_int wa = wrapped_int(a)
wrapped_int wc = wrapped_int(c)
wa[b] = wc()
return wa.val
def test_operator_exception_handling():
"""
>>> test_operator_exception_handling()
"""
assert_raised(initialization, 1, 4)
assert_raised(addition, 1, 4)
assert_raised(subtraction, 1, 4)
assert_raised(multiplication, 1, 4)
assert_raised(division, 1, 4)
assert_raised(mod, 1, 4)
assert_raised(minus, 4)
assert_raised(plus, 4)
assert_raised(xor, 1, 4)
assert_raised(address, 4)
assert_raised(iseq, 1, 4)
assert_raised(neq, 1, 4)
assert_raised(left_shift, 1, 4)
assert_raised(right_shift, 1, 4)
assert_raised(cpp_preincrement, 4)
assert_raised(cpp_predecrement, 4)
assert_raised(cpp_postincrement, 4)
assert_raised(cpp_postdecrement, 4)
assert_raised(negate, 4)
assert_raised(bool_cast, 4)
assert_raised(index, 1, 4)
assert_raised(assign_index, 1, 4, 4)
assert_raised(call, 4)
assert_raised(assign_same, 4, 4)
assert_raised(assign_different, 4, 4)
assert_raised(cascaded_assign, 4, 4, 1)
assert_raised(cascaded_assign, 4, 1, 4)
assert_raised(separate_exceptions, 1, 1, 1, 1, 4, err=ValueError)
assert_raised(separate_exceptions, 1, 1, 1, 4, 1, err=OverflowError)
assert_raised(separate_exceptions, 1, 1, 4, 1, 1, err=RuntimeError)
assert_raised(separate_exceptions, 1, 4, 1, 1, 1, err=IndexError)
assert_raised(separate_exceptions, 4, 1, 1, 1, 3, err=ArithmeticError)
assert_raised(call_temp_separation, 2, 1, 4, err=AttributeError)
assert_raised(call_temp_separation, 2, 4, 1, err=IndexError)
#pragma once
#include <stdexcept>
class wrapped_int {
public:
long long val;
wrapped_int() { val = 0; }
wrapped_int(long long val) { this->val = val; }
wrapped_int(long long v1, long long v2) {
if (v2 == 4) {
throw std::domain_error("4 isn't good for initialization!");
}
this->val = v1;
}
wrapped_int operator+(wrapped_int &other) {
if (other.val == 4) {
throw std::invalid_argument("tried to add 4");
}
return wrapped_int(this->val + other.val);
}
wrapped_int operator+() {
if (this->val == 4) {
throw std::domain_error("'4' not in valid domain.");
}
return *this;
}
wrapped_int operator-(wrapped_int &other) {
if (other.val == 4) {
throw std::overflow_error("Value '4' is no good.");
}
return *this;
}
wrapped_int operator-() {
if (this->val == 4) {
throw std::range_error("Can't take the negative of 4.");
}
return wrapped_int(-this->val);
}
wrapped_int operator*(wrapped_int &other) {
if (other.val == 4) {
throw std::out_of_range("Multiplying by 4 isn't going to work.");
}
return wrapped_int(this->val * other.val);
}
wrapped_int operator/(wrapped_int &other) {
if (other.val == 4) {
throw std::out_of_range("Multiplying by 4 isn't going to work.");
}
return wrapped_int(this->val / other.val);
}
wrapped_int operator%(wrapped_int &other) {
if (other.val == 4) {
throw std::out_of_range("Multiplying by 4 isn't going to work.");
}
return wrapped_int(this->val % other.val);
}
long long operator^(wrapped_int &other) {
if (other.val == 4) {
throw std::out_of_range("Multiplying by 4 isn't going to work.");
}
return this->val ^ other.val;
}
long long operator&(wrapped_int &other) {
if (other.val == 4) {
throw std::underflow_error("Can't do this with 4!");
}
return this->val & other.val;
}
long long operator|(wrapped_int &other) {
if (other.val == 4) {
throw std::underflow_error("Can't do this with 4!");
}
return this->val & other.val;
}
wrapped_int operator~() {
if (this->val == 4) {
throw std::range_error("4 is really just no good for this!");
}
return *this;
}
long long operator&() {
if (this->val == 4) {
throw std::out_of_range("4 cannot be located!");
}
return this->val;
}
long long operator==(wrapped_int &other) {
if (other.val == 4) {
throw std::invalid_argument("4 isn't logical and can't be equal to anything!");
}
return this->val == other.val;
}
long long operator!=(wrapped_int &other) {
if (other.val == 4) {
throw std::invalid_argument("4 isn't logical and can'd be not equal to anything either!");
}
return this->val != other.val;
}
long long operator<(wrapped_int &other) {
if (other.val == 4) {
throw std::invalid_argument("Can't compare with 4!");
}
return this->val < other.val;
}
long long operator<=(wrapped_int &other) {
if (other.val == 4) {
throw std::invalid_argument("Can't compare with 4!");
}
return this->val <= other.val;
}
long long operator>(wrapped_int &other) {
if (other.val == 4) {
throw std::invalid_argument("Can't compare with 4!");
}
return this->val > other.val;
}
long long operator>=(wrapped_int &other) {
if (other.val == 4) {
throw std::invalid_argument("Can't compare with 4!");
}
return this->val >= other.val;
}
wrapped_int operator<<(long long &shift) {
if (shift == 4) {
throw std::overflow_error("Shifting by 4 is just bad.");
}
return wrapped_int(this->val << shift);
}
wrapped_int operator>>(long long &shift) {
if (shift == 4) {
throw std::underflow_error("Shifting by 4 is just bad.");
}
return wrapped_int(this->val >> shift);
}
wrapped_int &operator++() {
if (this->val == 4) {
throw std::out_of_range("Can't increment 4!");
}
this->val += 1;
return *this;
}
wrapped_int &operator--() {
if (this->val == 4) {
throw std::out_of_range("Can't decrement 4!");
}
this->val -= 1;
return *this;
}
wrapped_int operator++(int) {
if (this->val == 4) {
throw std::out_of_range("Can't increment 4!");
}
wrapped_int t = *this;
this->val += 1;
return t;
}
wrapped_int operator--(int) {
if (this->val == 4) {
throw std::out_of_range("Can't decrement 4!");
}
wrapped_int t = *this;
this->val -= 1;
return t;
}
wrapped_int operator!() {
if (this->val == 4) {
throw std::out_of_range("Can't negate 4!");
}
return wrapped_int(!this->val);
}
operator bool() {
if (this->val == 4) {
throw std::invalid_argument("4 can't be cast to a boolean value!");
}
return (this->val != 0);
}
wrapped_int &operator[](long long &idx) {
if (idx == 4) {
throw std::invalid_argument("Index of 4 not allowed.");
}
return *this;
}
long long &operator()() {
if (this->val == 4) {
throw std::range_error("Can't call 4!");
}
return this->val;
}
wrapped_int &operator=(const wrapped_int &other) {
if ((other.val == 4) && (this->val == 4)) {
throw std::overflow_error("Can't assign 4 to 4!");
}
this->val = other.val;
return *this;
}
wrapped_int &operator=(const long long &v) {
if ((v == 4) && (this->val == 4)) {
throw std::overflow_error("Can't assign 4 to 4!");
}
this->val = v;
return *this;
}
};
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment