Commit ffad2575 authored by Casey Duncan's avatar Casey Duncan

Port bugfix for mergeResults KeyError bug with small sort indexes.

Port mergeResults() tests
parent 01c2cd13
...@@ -592,7 +592,8 @@ class Catalog(Persistent, Acquisition.Implicit, ExtensionClass.Base): ...@@ -592,7 +592,8 @@ class Catalog(Persistent, Acquisition.Implicit, ExtensionClass.Base):
rs = rs.keys() rs = rs.keys()
rlen = len(rs) rlen = len(rs)
if limit is None and (rlen > (len(sort_index) * (rlen / 100 + 1))): if merge and limit is None and (
rlen > (len(sort_index) * (rlen / 100 + 1))):
# The result set is much larger than the sorted index, # The result set is much larger than the sorted index,
# so iterate over the sorted index for speed. # so iterate over the sorted index for speed.
# This is rarely exercised in practice... # This is rarely exercised in practice...
...@@ -620,13 +621,10 @@ class Catalog(Persistent, Acquisition.Implicit, ExtensionClass.Base): ...@@ -620,13 +621,10 @@ class Catalog(Persistent, Acquisition.Implicit, ExtensionClass.Base):
append((k, intset, _self__getitem__)) append((k, intset, _self__getitem__))
# Note that sort keys are unique. # Note that sort keys are unique.
if merge:
result.sort() result.sort()
if reverse: if reverse:
result.reverse() result.reverse()
result = LazyCat(LazyValues(result), length) result = LazyCat(LazyValues(result), length)
else:
return result
elif limit is None or (limit * 4 > rlen): elif limit is None or (limit * 4 > rlen):
# Iterate over the result set getting sort keys from the index # Iterate over the result set getting sort keys from the index
for did in rs: for did in rs:
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
import os import os
import random import random
import unittest import unittest
from itertools import chain
import ZODB, OFS.Application import ZODB, OFS.Application
from ZODB.DemoStorage import DemoStorage from ZODB.DemoStorage import DemoStorage
...@@ -41,24 +42,10 @@ def createDatabase(): ...@@ -41,24 +42,10 @@ def createDatabase():
app = createDatabase() app = createDatabase()
def sort(iterable):
################################################################################ L = list(iterable)
# Stuff of Chris L.sort()
# XXX What's this mean? What does this comment apply to? return L
################################################################################
# XXX These imports and class don't appear to be needed?
## from AccessControl.SecurityManagement import newSecurityManager
## from AccessControl.SecurityManagement import noSecurityManager
## class DummyUser:
## def __init__( self, name ):
## self._name = name
## def getUserName( self ):
## return self._name
class CatalogBase: class CatalogBase:
def setUp(self): def setUp(self):
...@@ -445,7 +432,7 @@ class objRS(ExtensionClass.Base): ...@@ -445,7 +432,7 @@ class objRS(ExtensionClass.Base):
def __init__(self,num): def __init__(self,num):
self.number = num self.number = num
class testRS(unittest.TestCase): class TestRS(unittest.TestCase):
def setUp(self): def setUp(self):
self._vocabulary = Vocabulary.Vocabulary('Vocabulary','Vocabulary' self._vocabulary = Vocabulary.Vocabulary('Vocabulary','Vocabulary'
...@@ -472,6 +459,85 @@ class testRS(unittest.TestCase): ...@@ -472,6 +459,85 @@ class testRS(unittest.TestCase):
self.assert_(m<=size and size<=n, self.assert_(m<=size and size<=n,
"%d vs [%d,%d]" % (r.number,m,n)) "%d vs [%d,%d]" % (r.number,m,n))
class TestMerge(unittest.TestCase):
# Test merging results from multiple catalogs
def setUp(self):
vocabulary = Vocabulary.Vocabulary(
'Vocabulary','Vocabulary', globbing=1)
self.catalogs = []
for i in range(3):
cat = Catalog()
cat.addIndex('num', FieldIndex('num'))
cat.addIndex('big', FieldIndex('big'))
cat.addIndex('title', TextIndex('title'))
cat.vocabulary = vocabulary
cat.aq_parent = zdummy(16336)
for i in range(10):
obj = zdummy(i)
obj.big = i > 5
cat.catalogObject(obj, str(i))
self.catalogs.append(cat)
def testNoFilterOrSort(self):
from Products.ZCatalog.Catalog import mergeResults
results = [cat.searchResults(_merge=0) for cat in self.catalogs]
merged_rids = [r.getRID() for r in mergeResults(
results, has_sort_keys=False, reverse=False)]
expected = [r.getRID() for r in chain(*results)]
self.assertEqual(sort(merged_rids), sort(expected))
def testSortedOnly(self):
from Products.ZCatalog.Catalog import mergeResults
results = [cat.searchResults(sort_on='num', _merge=0)
for cat in self.catalogs]
merged_rids = [r.getRID() for r in mergeResults(
results, has_sort_keys=True, reverse=False)]
expected = sort(chain(*results))
expected = [rid for sortkey, rid, getitem in expected]
self.assertEqual(merged_rids, expected)
def testSortReverse(self):
from Products.ZCatalog.Catalog import mergeResults
results = [cat.searchResults(sort_on='num', _merge=0)
for cat in self.catalogs]
merged_rids = [r.getRID() for r in mergeResults(
results, has_sort_keys=True, reverse=True)]
expected = sort(chain(*results))
expected.reverse()
expected = [rid for sortkey, rid, getitem in expected]
self.assertEqual(merged_rids, expected)
def testLimitSort(self):
from Products.ZCatalog.Catalog import mergeResults
results = [cat.searchResults(sort_on='num', sort_limit=2, _merge=0)
for cat in self.catalogs]
merged_rids = [r.getRID() for r in mergeResults(
results, has_sort_keys=True, reverse=False)]
expected = sort(chain(*results))
expected = [rid for sortkey, rid, getitem in expected]
self.assertEqual(merged_rids, expected)
def testScored(self):
from Products.ZCatalog.Catalog import mergeResults
results = [cat.searchResults(title='4 or 5 or 6', _merge=0)
for cat in self.catalogs]
merged_rids = [r.getRID() for r in mergeResults(
results, has_sort_keys=True, reverse=False)]
expected = sort(chain(*results))
expected = [rid for sortkey, (nscore, score, rid), getitem in expected]
self.assertEqual(merged_rids, expected)
def testSmallIndexSort(self):
# Test that small index sort optimization is not used for merging
from Products.ZCatalog.Catalog import mergeResults
results = [cat.searchResults(sort_on='big', _merge=0)
for cat in self.catalogs]
merged_rids = [r.getRID() for r in mergeResults(
results, has_sort_keys=True, reverse=False)]
expected = sort(chain(*results))
expected = [rid for sortkey, rid, getitem in expected]
self.assertEqual(merged_rids, expected)
def test_suite(): def test_suite():
suite = unittest.TestSuite() suite = unittest.TestSuite()
...@@ -479,7 +545,8 @@ def test_suite(): ...@@ -479,7 +545,8 @@ def test_suite():
suite.addTest( unittest.makeSuite( TestAddDelIndexes ) ) suite.addTest( unittest.makeSuite( TestAddDelIndexes ) )
suite.addTest( unittest.makeSuite( TestZCatalog ) ) suite.addTest( unittest.makeSuite( TestZCatalog ) )
suite.addTest( unittest.makeSuite( TestCatalogObject ) ) suite.addTest( unittest.makeSuite( TestCatalogObject ) )
suite.addTest( unittest.makeSuite( testRS ) ) suite.addTest( unittest.makeSuite( TestRS ) )
suite.addTest( unittest.makeSuite( TestMerge ) )
return suite return suite
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment