Commit e166d581 authored by Tim Peters's avatar Tim Peters

Added many new tests of set operations (they weren't tested at all before).

Noted that the way difference() treats None doesn't match the docs and is
almost certainly wrong.
parent 9fc60bc3
...@@ -11,17 +11,169 @@ ...@@ -11,17 +11,169 @@
# FOR A PARTICULAR PURPOSE # FOR A PARTICULAR PURPOSE
# #
############################################################################## ##############################################################################
import sys, os, time, random import random
from unittest import TestCase, TestSuite, TextTestRunner, makeSuite from unittest import TestCase, TestSuite, TextTestRunner, makeSuite
from BTrees.IIBTree import IIBTree, IIBucket, IISet, IITreeSet, \ from BTrees.OOBTree import OOBTree, OOBucket, OOSet, OOTreeSet
union, intersection, difference, weightedUnion, weightedIntersection, \ from BTrees.IOBTree import IOBTree, IOBucket, IOSet, IOTreeSet
multiunion from BTrees.IIBTree import IIBTree, IIBucket, IISet, IITreeSet
from BTrees.OIBTree import OIBTree, OIBucket, OISet, OITreeSet
from BTrees.IIBTree import multiunion
# XXX TODO Needs more tests. # XXX TODO Needs more tests.
# This file was created when multiunion was added. The other set operations # This file was created when multiunion was added. The other set operations
# don't appear to be tested anywhere yet. # don't appear to be tested anywhere yet.
# Subclasses have to set up:
# builders - functions to build inputs, taking an optional keys arg
# intersection, union, difference - set to the type-correct versions
class SetResult(TestCase):
def setUp(self):
self.Akeys = [1, 3, 5, 6 ]
self.Bkeys = [ 2, 3, 4, 6, 7]
self.As = [makeset(self.Akeys) for makeset in self.builders]
self.Bs = [makeset(self.Bkeys) for makeset in self.builders]
self.emptys = [makeset() for makeset in self.builders]
# Slow but obviously correct Python implementations of basic ops.
def _union(self, x, y):
result = list(x.keys())
for e in y.keys():
if e not in result:
result.append(e)
result.sort()
return result
def _intersection(self, x, y):
result = []
ykeys = y.keys()
for e in x.keys():
if e in ykeys:
result.append(e)
return result
def _difference(self, x, y):
result = list(x.keys())
for e in y.keys():
if e in result:
result.remove(e)
# Difference preserves LHS values.
if hasattr(x, "values"):
result = [(k, x[k]) for k in result]
return result
def testNone(self):
for op in self.union, self.intersection, self.difference:
C = op(None, None)
self.assert_(C is None)
for op in self.union, self.intersection:
for A in self.As:
C = op(A, None)
self.assert_(C is A)
C = op(None, A)
self.assert_(C is A)
# XXX These difference results contradict the docs. The implementation
# XXX is almost certainly wrong, but can we change it?
for A in self.As:
C = self.difference(A, None)
self.assert_(C is None)
C = self.difference(None, A)
self.assert_(C is None)
def testEmptyUnion(self):
for A in self.As:
for E in self.emptys:
C = self.union(A, E)
self.assert_(not hasattr(C, "values"))
self.assertEqual(list(C), self.Akeys)
C = self.union(E, A)
self.assert_(not hasattr(C, "values"))
self.assertEqual(list(C), self.Akeys)
def testEmptyIntersection(self):
for A in self.As:
for E in self.emptys:
C = self.intersection(A, E)
self.assert_(not hasattr(C, "values"))
self.assertEqual(list(C), [])
C = self.intersection(E, A)
self.assert_(not hasattr(C, "values"))
self.assertEqual(list(C), [])
def testEmptyDifference(self):
for A in self.As:
for E in self.emptys:
C = self.difference(A, E)
# Difference preserves LHS values.
self.assertEqual(hasattr(C, "values"), hasattr(A, "values"))
if hasattr(A, "values"):
self.assertEqual(list(C.items()), list(A.items()))
else:
self.assertEqual(list(C), self.Akeys)
C = self.difference(E, A)
self.assertEqual(hasattr(C, "values"), hasattr(E, "values"))
self.assertEqual(list(C.keys()), [])
def testUnion(self):
inputs = self.As + self.Bs
for A in inputs:
for B in inputs:
C = self.union(A, B)
self.assert_(not hasattr(C, "values"))
self.assertEqual(list(C), self._union(A, B))
def testIntersection(self):
inputs = self.As + self.Bs
for A in inputs:
for B in inputs:
C = self.intersection(A, B)
self.assert_(not hasattr(C, "values"))
self.assertEqual(list(C), self._intersection(A, B))
def testDifference(self):
inputs = self.As + self.Bs
for A in inputs:
for B in inputs:
C = self.difference(A, B)
# Difference preserves LHS values.
self.assertEqual(hasattr(C, "values"), hasattr(A, "values"))
want = self._difference(A, B)
if hasattr(A, "values"):
self.assertEqual(list(C.items()), want)
else:
self.assertEqual(list(C), want)
# Given a mapping builder (IIBTree, OOBucket, etc), return a function
# that builds an object of that type given only a list of keys.
def makeBuilder(mapbuilder):
def result(keys=[], mapbuilder=mapbuilder):
return mapbuilder(zip(keys, keys))
return result
class PureII(SetResult):
from BTrees.IIBTree import union, intersection, difference
builders = IISet, IITreeSet, makeBuilder(IIBTree), makeBuilder(IIBucket)
class PureIO(SetResult):
from BTrees.IOBTree import union, intersection, difference
builders = IOSet, IOTreeSet, makeBuilder(IOBTree), makeBuilder(IOBucket)
class PureOO(SetResult):
from BTrees.OOBTree import union, intersection, difference
builders = OOSet, OOTreeSet, makeBuilder(OOBTree), makeBuilder(OOBucket)
class PureOI(SetResult):
from BTrees.OIBTree import union, intersection, difference
builders = OISet, OITreeSet, makeBuilder(OIBTree), makeBuilder(OIBucket)
class TestMultiUnion(TestCase): class TestMultiUnion(TestCase):
def testEmpty(self): def testEmpty(self):
...@@ -72,6 +224,7 @@ class TestMultiUnion(TestCase): ...@@ -72,6 +224,7 @@ class TestMultiUnion(TestCase):
def testFunkyKeyIteration(self): def testFunkyKeyIteration(self):
# The internal set iteration protocol allows "iterating over" a # The internal set iteration protocol allows "iterating over" a
# a single key as if it were a set. # a single key as if it were a set.
from BTrees.IIBTree import union
N = 100 N = 100
slow = IISet() slow = IISet()
for i in range(N): for i in range(N):
...@@ -83,10 +236,14 @@ class TestMultiUnion(TestCase): ...@@ -83,10 +236,14 @@ class TestMultiUnion(TestCase):
self.assertEqual(list(fast.keys()), range(N)) self.assertEqual(list(fast.keys()), range(N))
def test_suite(): def test_suite():
return makeSuite(TestMultiUnion, 'test') s = TestSuite()
for klass in (TestMultiUnion,
PureII, PureIO, PureOI, PureOO):
s.addTest(makeSuite(klass))
return s
def main(): def main():
TextTestRunner().run(test_suite()) TextTestRunner().run(test_suite())
if __name__ == '__main__': if __name__ == '__main__':
main() main()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment