Commit b8d4bd87 authored by Xavier Thompson's avatar Xavier Thompson

Support __getitem__ and __setitem__ for cypclasses

parent ebc2f033
...@@ -3973,6 +3973,8 @@ class IndexNode(_IndexingBaseNode): ...@@ -3973,6 +3973,8 @@ class IndexNode(_IndexingBaseNode):
def analyse_as_cpp(self, env, setting): def analyse_as_cpp(self, env, setting):
base_type = self.base.type base_type = self.base.type
if base_type.is_cyp_class and setting:
return self.analyse_as_cyp_setitem(env)
function = env.lookup_operator("[]", [self.base, self.index]) function = env.lookup_operator("[]", [self.base, self.index])
if function is None: if function is None:
error(self.pos, "Indexing '%s' not supported for index type '%s'" % (base_type, self.index.type)) error(self.pos, "Indexing '%s' not supported for index type '%s'" % (base_type, self.index.type))
...@@ -3995,6 +3997,34 @@ class IndexNode(_IndexingBaseNode): ...@@ -3995,6 +3997,34 @@ class IndexNode(_IndexingBaseNode):
error(self.pos, "Can't set non-reference result '%s'" % self.type) error(self.pos, "Can't set non-reference result '%s'" % self.type)
return self return self
def analyse_as_cyp_setitem(self, env):
base_type = self.base.type
function = base_type.scope.lookup_here("__setitem__")
if function is None:
error(self.pos, "Setting item '%s' not supported for index type '%s'" % (base_type, self.index.type))
self.type = PyrexTypes.error_type
self.result_code = "<error>"
return self
if len(function.all_alternatives()) > 1:
error(self.pos, "%s.__setitem__ has several alternatives" % base_type)
self.type = PyrexTypes.error_type
self.result_code = "<error>"
return self
func_type = function.type
self.exception_check = func_type.exception_check
self.exception_value = func_type.exception_value
if self.exception_check:
if self.exception_value is None:
env.use_utility_code(UtilityCode.load_cached("CppExceptionConversion", "CppSupport.cpp"))
if len(func_type.args) != 2:
error(self.pos, "%s.__setitem__ takes wrong number of arguments" % base_type)
self.type = PyrexTypes.error_type
self.result_code = "<error>"
return self
self.index = self.index.coerce_to(func_type.args[0].type, env)
self.type = func_type.args[1].type
return self
def analyse_as_c_function(self, env): def analyse_as_c_function(self, env):
base_type = self.base.type base_type = self.base.type
if base_type.is_fused: if base_type.is_fused:
...@@ -4355,12 +4385,30 @@ class IndexNode(_IndexingBaseNode): ...@@ -4355,12 +4385,30 @@ class IndexNode(_IndexingBaseNode):
self.extra_index_params(code)), self.extra_index_params(code)),
self.pos)) self.pos))
def generate_cyp_setitem_code(self, value_code, code):
function = self.base.type.scope.lookup_here("__setitem__")
function_code = function.cname
setitem_code = "%s->%s(%s, %s);" % (
self.base.result(),
function_code,
self.index.result(),
value_code)
if self.exception_check and self.exception_check == "+":
translate_cpp_exception(code, self.pos,
setitem_code,
None,
self.exception_value, self.in_nogil_context)
else:
code.putln(setitem_code)
def generate_assignment_code(self, rhs, code, overloaded_assignment=False, def generate_assignment_code(self, rhs, code, overloaded_assignment=False,
exception_check=None, exception_value=None): exception_check=None, exception_value=None):
self.generate_subexpr_evaluation_code(code) self.generate_subexpr_evaluation_code(code)
if self.type.is_pyobject: if self.type.is_pyobject:
self.generate_setitem_code(rhs.py_result(), code) self.generate_setitem_code(rhs.py_result(), code)
elif self.base.type.is_cyp_class:
self.generate_cyp_setitem_code(rhs.result(), code)
elif self.base.type is bytearray_type: elif self.base.type is bytearray_type:
value_code = self._check_byte_value(code, rhs) value_code = self._check_byte_value(code, rhs)
self.generate_setitem_code(value_code, code) self.generate_setitem_code(value_code, code)
......
...@@ -2794,7 +2794,8 @@ class CppClassScope(Scope): ...@@ -2794,7 +2794,8 @@ class CppClassScope(Scope):
'__gt__': '>', '__gt__': '>',
'__le__': '<=', '__le__': '<=',
'__ge__': '>=', '__ge__': '>=',
'__call__':'()' '__call__':'()',
'__getitem__':'[]'
} }
def __init__(self, name, outer_scope, templates=None): def __init__(self, name, outer_scope, templates=None):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment