Commit 659ba3a5 authored by Xavier Thompson's avatar Xavier Thompson

Add concatenation operators for cypclass builtin list and associated unit tests

parent 1000a609
......@@ -97,6 +97,58 @@ cdef cypclass cyplist[V]:
with gil:
raise RuntimeError("Modifying a list with active iterators")
cyplist[V] __add__(self, const cyplist[V] other) const:
result = cyplist[V]()
result._elements.reserve(self._elements.size() + other._elements.size())
for value in self._elements:
Cy_INCREF(value)
result._elements.push_back(value)
for value in other._elements:
Cy_INCREF(value)
result._elements.push_back(value)
return result
cyplist[V] __iadd__(self, const cyplist[V] other):
if self._active_iterators == 0:
self._elements.reserve(self._elements.size() + other._elements.size())
for value in other._elements:
Cy_INCREF(value)
self._elements.push_back(value)
return self
else:
with gil:
raise RuntimeError("Modifying a list with active iterators")
cyplist[V] __mul__(self, size_type n) const:
result = cyplist[V]()
result._elements.reserve(self._elements.size() * n)
for i in range(n):
for value in self._elements:
Cy_INCREF(value)
result._elements.push_back(value)
return result
cyplist[V] __imul__(self, size_type n):
if self._active_iterators == 0:
if n > 1:
elements = self._elements
self._elements.reserve(elements.size() * n)
for i in range(1, n):
for value in elements:
Cy_INCREF(value)
self._elements.push_back(value)
return self
elif n == 1:
return self
else:
for value in self._elements:
Cy_DECREF(value)
self._elements.clear()
return self
else:
with gil:
raise RuntimeError("Modifying a list with active iterators")
list_iterator_t[cyplist[V], vector[value_type].iterator, value_type] begin(self) const:
return list_iterator_t[cyplist[V], vector[value_type].iterator, value_type](self._elements.begin(), self)
......
......@@ -103,6 +103,56 @@ def test_contains():
return 0
return 1
def test_add():
"""
>>> test_add()
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
"""
l1 = cyplist[Value]()
for i in range(5):
l1.append(Value(i))
l2 = cyplist[Value]()
for i in range(5, 10):
l2.append(Value(i))
l = l1 + l2
return [v.value for v in l]
def test_iadd():
"""
>>> test_iadd()
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
"""
l1 = cyplist[Value]()
for i in range(5):
l1.append(Value(i))
l2 = cyplist[Value]()
for i in range(5, 10):
l2.append(Value(i))
l1 += l2
return [v.value for v in l1]
def test_mul():
"""
>>> test_mul()
[0, 1, 0, 1, 0, 1]
"""
l1 = cyplist[Value]()
for i in range(2):
l1.append(Value(i))
l = l1 * 3
return [v.value for v in l]
def test_imul():
"""
>>> test_imul()
[0, 1, 0, 1, 0, 1]
"""
l = cyplist[Value]()
for i in range(2):
l.append(Value(i))
l *= 3
return [v.value for v in l]
def test_getitem_out_of_range():
"""
>>> test_getitem_out_of_range()
......@@ -337,3 +387,41 @@ def test_iterator_refcount():
return 0
return 1
def test_concatenation_refcount():
"""
>>> test_concatenation_refcount()
1
"""
value = Value(1)
l1 = cyplist[Value]()
if Cy_GETREF(value) != 2:
return 0
l1.append(value)
if Cy_GETREF(value) != 3:
return 0
l2 = cyplist[Value]()
l2.append(value)
if Cy_GETREF(value) != 4:
return 0
l3 = l1 + l2
if Cy_GETREF(value) != 6:
return 0
l3 += l1
if Cy_GETREF(value) != 7:
return 0
l4 = l3 * 3
if Cy_GETREF(value) != 16:
return 0
l4 *= 2
if Cy_GETREF(value) != 25:
return 0
return 1
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment