From 80453523e3b6c3aa0af9816c999373875d31817a Mon Sep 17 00:00:00 2001
From: Julien Muchembled <jm@nexedi.com>
Date: Wed, 4 May 2016 15:37:08 +0200
Subject: [PATCH] Base: do not check security for unrestricted category value
 getters

---
 product/CMFCategory/CategoryTool.py    |  36 ++++---
 product/ERP5Type/Accessor/Value.py     |   4 +-
 product/ERP5Type/Base.py               | 129 +++++++++++++++++--------
 product/ERP5Type/Core/Predicate.py     |  27 ++----
 product/ERP5Type/tests/testERP5Type.py |   2 +-
 5 files changed, 122 insertions(+), 76 deletions(-)

diff --git a/product/CMFCategory/CategoryTool.py b/product/CMFCategory/CategoryTool.py
index d9c15036be..4cd0d02acf 100644
--- a/product/CMFCategory/CategoryTool.py
+++ b/product/CMFCategory/CategoryTool.py
@@ -1678,11 +1678,14 @@ class CategoryTool( UniqueObject, Folder, Base ):
                                 display_id = 'title')
 
     security.declarePublic('resolveCategory')
-    def resolveCategory(self, relative_url,  default=_marker):
+    def resolveCategory(self, relative_url):
         """
           Finds an object from a relative_url
           Method is public since we use restrictedTraverse
         """
+        return self._resolveCategory(relative_url, True)
+
+    def _resolveCategory(self, relative_url, restricted=False):
         if not isinstance(relative_url, str):
           # Handle parent base category is a special way
           return relative_url
@@ -1759,15 +1762,16 @@ class CategoryTool( UniqueObject, Folder, Base ):
         stack.reverse()
         __traceback_info__ = relative_url
 
-        validate = getSecurityManager().validate
-        def restrictedGetOb(container):
-          obj = container._getOb(key, None)
-          if obj is None or validate(container, container, key, obj):
-            return obj
-          # if user can't access object try to return default passed
-          if default is _marker:
+        if restricted:
+          validate = getSecurityManager().validate
+          def getOb(container):
+            obj = container._getOb(key, None)
+            if obj is None or validate(container, container, key, obj):
+              return obj
             raise Unauthorized('unauthorized access to element %s' % key)
-          return default
+        else:
+          def getOb(container):
+            return container._getOb(key, None)
 
         # XXX Currently, resolveCategory accepts that a category might
         # not start with a Base Category, but with a Module. This is
@@ -1778,32 +1782,32 @@ class CategoryTool( UniqueObject, Folder, Base ):
         if stack:
           portal = aq_inner(self.getPortalObject())
           key = stack.pop()
-          obj = restrictedGetOb(self)
+          obj = getOb(self)
           if obj is None:
-            obj = restrictedGetOb(portal)
+            obj = getOb(portal)
             if obj is not None:
               obj = obj.__of__(self)
           else:
             while stack:
               container = obj
               key = stack.pop()
-              obj = restrictedGetOb(container)
+              obj = getOb(container)
               if obj is not None:
                 break
-              obj = restrictedGetOb(self)
+              obj = getOb(self)
               if obj is None:
-                obj = restrictedGetOb(portal)
+                obj = getOb(portal)
                 if obj is not None:
                   obj = obj.__of__(container)
                 break
 
           while obj is not None and stack:
             key = stack.pop()
-            obj = restrictedGetOb(obj)
+            obj = getOb(obj)
 
         if obj is None:
           LOG('CMFCategory', WARNING,
-              'Could not access object %s' % relative_url)
+              'Could not get object %s' % relative_url)
 
         if cache is not None:
           cache[cache_key] = obj
diff --git a/product/ERP5Type/Accessor/Value.py b/product/ERP5Type/Accessor/Value.py
index cd45a53be6..5064ee0c7e 100644
--- a/product/ERP5Type/Accessor/Value.py
+++ b/product/ERP5Type/Accessor/Value.py
@@ -122,7 +122,7 @@ class DefaultGetter(BaseGetter):
         LOG("ERP5Type Deprecated Getter Id:",0, self._id)
       if args:
         kw['default'] = args[0]
-      return instance._getDefaultAcquiredValue(self._key, **kw)
+      return instance.getDefaultAcquiredValue(self._key, **kw)
 
     psyco.bind(__call__)
 
@@ -153,7 +153,7 @@ class ListGetter(BaseGetter):
         LOG("ERP5Type Deprecated Getter Id:",0, self._id)
       if args:
         kw['default'] = args[0]
-      return instance._getAcquiredValueList(self._key, **kw)
+      return instance.getAcquiredValueList(self._key, **kw)
 
     psyco.bind(__call__)
 
diff --git a/product/ERP5Type/Base.py b/product/ERP5Type/Base.py
index 12cc04198d..c55d016c43 100644
--- a/product/ERP5Type/Base.py
+++ b/product/ERP5Type/Base.py
@@ -1002,8 +1002,8 @@ class Base( CopyContainer,
       #LOG("Get Acquired Property self",0,str(self))
       #LOG("Get Acquired Property portal_type",0,str(portal_type))
       #LOG("Get Acquired Property base_category",0,str(base_category))
-      #super_list = self._getValueList(base_category, portal_type=portal_type) # We only do a single jump
-      super_list = self._getAcquiredValueList(base_category, portal_type=portal_type,
+      #super_list = self.getValueList(base_category, portal_type=portal_type) # We only do a single jump
+      super_list = self.getAcquiredValueList(base_category, portal_type=portal_type,
                                               checked_permission=checked_permission) # Full acquisition
       super_list = [o for o in super_list if o.getPhysicalPath() != self.getPhysicalPath()] # Make sure we do not create stupid loop here
       #LOG("Get Acquired Property super_list",0,str(super_list))
@@ -1131,11 +1131,11 @@ class Base( CopyContainer,
             super_list.append(acquisition_object)
           except (KeyError, AttributeError):
             pass
-      super_list.extend(self._getAcquiredValueList(
-                                          base_category,
-                                          portal_type=portal_type,
-                                          checked_permission=checked_permission))
-                                          # Full acquisition
+      super_list += self.getAcquiredValueList(
+        base_category,
+        portal_type=portal_type,
+        checked_permission=checked_permission,
+        ) # Full acquisition
       super_list = [o for o in super_list if o.getPhysicalPath() != self.getPhysicalPath()] # Make sure we do not create stupid loop here
       if len(super_list) > 0:
         value = []
@@ -1891,60 +1891,111 @@ class Base( CopyContainer,
                               checked_permission=None)
     self.reindexObject()
 
+  # Unrestricted category value getters
+
   def _getDefaultValue(self, id, spec=(), filter=None, default=_MARKER, **kw):
     path = self._getDefaultCategoryMembership(id, base=1, spec=spec,
                                               filter=filter, **kw)
     if path:
-      return self._getCategoryTool().resolveCategory(path)
+      return self._getCategoryTool()._resolveCategory(path)
     if default is not _MARKER:
       return default
 
-  security.declareProtected(Permissions.AccessContentsInformation,
-                            'getDefaultValue')
-  getDefaultValue = _getDefaultValue
-
   def _getValueList(self, id, spec=(), filter=None, default=_MARKER, **kw):
-    ref_list = []
-    for path in self._getCategoryMembershipList(id, base=1, spec=spec,
-                                                filter=filter, **kw):
-      category = self._getCategoryTool().resolveCategory(path)
-      if category is not None:
-        ref_list.append(category)
-    return ref_list if ref_list or default is _MARKER else default
-
-  security.declareProtected(Permissions.AccessContentsInformation,
-                            'getValueList')
-  getValueList = _getValueList
+    ref_list = self._getCategoryMembershipList(id, base=1, spec=spec,
+                                               filter=filter, **kw)
+    if ref_list:
+      resolveCategory = self._getCategoryTool()._resolveCategory
+      value_list = []
+      for path in ref_list:
+        value = resolveCategory(path)
+        if value is not None:
+          value_list.append(value)
+      return value_list if value_list or default is _MARKER else default
+    return ref_list if default is _MARKER else default
 
   def _getDefaultAcquiredValue(self, id, spec=(), filter=None, portal_type=(),
                                evaluate=1, checked_permission=None,
                                default=None, **kw):
-    path = self._getDefaultAcquiredCategoryMembership(id, spec=spec, filter=filter,
-                                                  portal_type=portal_type, base=1,
-                                                  checked_permission=checked_permission,
-                                                  **kw)
+    path = self._getDefaultAcquiredCategoryMembership(
+      id, spec=spec, filter=filter, portal_type=portal_type,
+      base=1, checked_permission=checked_permission, **kw)
+    if path:
+      return self._getCategoryTool()._resolveCategory(path)
+    if default is not _MARKER:
+      return default
+
+  def _getAcquiredValueList(self, id, spec=(), filter=None, default=_MARKER,
+                            **kw):
+    ref_list = self._getAcquiredCategoryMembershipList(id, base=1, spec=spec,
+                                                       filter=filter, **kw)
+    if ref_list:
+      resolveCategory = self._getCategoryTool()._resolveCategory
+      value_list = []
+      for path in ref_list:
+        value = resolveCategory(path)
+        if value is not None:
+          value_list.append(value)
+      return value_list if value_list or default is _MARKER else default
+    return ref_list if default is _MARKER else default
+
+  # Restricted category value getters
+
+  security.declareProtected(Permissions.AccessContentsInformation,
+                            'getDefaultValue')
+  def getDefaultValue(self, id, spec=(), filter=None, default=_MARKER, **kw):
+    path = self._getDefaultCategoryMembership(id, base=1, spec=spec,
+                                              filter=filter, **kw)
     if path:
       return self._getCategoryTool().resolveCategory(path)
     if default is not _MARKER:
       return default
 
   security.declareProtected(Permissions.AccessContentsInformation,
-                            'getDefaultAcquiredValue')
-  getDefaultAcquiredValue = _getDefaultAcquiredValue
+                            'getValueList')
+  def getValueList(self, id, spec=(), filter=None, default=_MARKER, **kw):
+    ref_list = self._getCategoryMembershipList(id, base=1, spec=spec,
+                                               filter=filter, **kw)
+    if ref_list:
+      resolveCategory = self._getCategoryTool().resolveCategory
+      value_list = []
+      for path in ref_list:
+        value = resolveCategory(path)
+        if value is not None:
+          value_list.append(value)
+      return value_list if value_list or default is _MARKER else default
+    return ref_list if default is _MARKER else default
 
-  def _getAcquiredValueList(self, id, spec=(), filter=None, default=_MARKER,
-                            **kw):
-    ref_list = []
-    for path in self._getAcquiredCategoryMembershipList(id, base=1,
-                                                spec=spec,  filter=filter, **kw):
-      category = self._getCategoryTool().resolveCategory(path)
-      if category is not None:
-        ref_list.append(category)
-    return ref_list if ref_list or default is _MARKER else default
+  security.declareProtected(Permissions.AccessContentsInformation,
+                            'getDefaultAcquiredValue')
+  def getDefaultAcquiredValue(self, id, spec=(), filter=None, portal_type=(),
+                              evaluate=1, checked_permission=None,
+                              default=None, **kw):
+    path = self._getDefaultAcquiredCategoryMembership(
+      id, spec=spec, filter=filter, portal_type=portal_type,
+      base=1, checked_permission=checked_permission, **kw)
+    if path:
+      return self._getCategoryTool().resolveCategory(path)
+    if default is not _MARKER:
+      return default
 
   security.declareProtected(Permissions.AccessContentsInformation,
                             'getAcquiredValueList')
-  getAcquiredValueList = _getAcquiredValueList
+  def getAcquiredValueList(self, id, spec=(), filter=None, default=_MARKER,
+                            **kw):
+    ref_list = self._getAcquiredCategoryMembershipList(id, base=1, spec=spec,
+                                                       filter=filter, **kw)
+    if ref_list:
+      resolveCategory = self._getCategoryTool().resolveCategory
+      value_list = []
+      for path in ref_list:
+        value = resolveCategory(path)
+        if value is not None:
+          value_list.append(value)
+      return value_list if value_list or default is _MARKER else default
+    return ref_list if default is _MARKER else default
+
+  ###
 
   def _getDefaultRelatedValue(self, id, spec=(), filter=None, portal_type=(),
                               strict_membership=0, strict="deprecated",
diff --git a/product/ERP5Type/Core/Predicate.py b/product/ERP5Type/Core/Predicate.py
index fbb96bc44b..7877245622 100644
--- a/product/ERP5Type/Core/Predicate.py
+++ b/product/ERP5Type/Core/Predicate.py
@@ -33,8 +33,6 @@ from warnings import warn
 from AccessControl import ClassSecurityInfo
 from Acquisition import aq_base, aq_inner
 
-from Products.CMFCore.utils import getToolByName
-
 from Products.ERP5Type import Permissions, PropertySheet, interfaces
 from Products.ERP5Type.Accessor.Constant import PropertyGetter as ConstantGetter
 from Products.ERP5Type.Document import newTempBase
@@ -44,7 +42,6 @@ from Products.ERP5Type.Cache import readOnlyTransactionCache
 from Products.ERP5Type.TransactionalVariable import getTransactionalVariable
 from Products.ZSQLCatalog.SQLCatalog import SQLQuery
 from Products.ERP5Type.Globals import PersistentMapping
-from Products.ERP5Type.UnrestrictedMethod import UnrestrictedMethod
 from Products.ERP5Type.UnrestrictedMethod import unrestricted_apply
 from Products.CMFCore.Expression import Expression
 
@@ -199,13 +196,6 @@ class Predicate(XMLObject):
       result = expression(createExpressionContext(context))
     return result
 
-  @UnrestrictedMethod
-  def _unrestrictedResolveCategory(self, *args):
-    # Categories used on predicate can be not available to user query, which
-    # shall be applied with predicate.
-    portal_categories = getToolByName(self, 'portal_categories')
-    return portal_categories.resolveCategory(*args)
-
   security.declareProtected( Permissions.AccessContentsInformation,
                              'buildSQLQuery' )
   def buildSQLQuery(self, strict_membership=0, table='category',
@@ -248,7 +238,8 @@ class Predicate(XMLObject):
               f = (i,) if i in f else ()
             catalog_kw[p] = list(f)
 
-    portal_catalog = getToolByName(self, 'portal_catalog')
+    portal = self.getPortalObject()
+    resolveCategory = portal.portal_categories._resolveCategory
 
     from_table_dict = {}
 
@@ -260,7 +251,7 @@ class Predicate(XMLObject):
     for category in self.getMembershipCriterionCategoryList():
       base_category = category.split('/')[0] # Retrieve base category
       if membership_dict.has_key(base_category):
-        category_value = self._unrestrictedResolveCategory(category, None)
+        category_value = resolveCategory(category)
         if category_value is not None:
           table_alias = "single_%s_%s" % (table, base_category)
           from_table_dict[table_alias] = 'category'
@@ -282,7 +273,7 @@ class Predicate(XMLObject):
     for category in self.getMembershipCriterionCategoryList():
       base_category = category.split('/')[0] # Retrieve base category
       if multimembership_dict.has_key(base_category):
-        category_value = self._unrestrictedResolveCategory(category)
+        category_value = resolveCategory(category)
         if category_value is not None:
           join_count += 1
           table_alias = "multi_%s_%s" % (table, join_count)
@@ -317,7 +308,7 @@ class Predicate(XMLObject):
       catalog_kw['where_expression'] = SQLQuery(sql_text)
     # force implicit join
     catalog_kw['implicit_join'] = True
-    sql_query = portal_catalog.buildSQLQuery(**catalog_kw)
+    sql_query = portal.portal_catalog.buildSQLQuery(**catalog_kw)
     # XXX from_table_list is None most of the time after the explicit_join work
     for alias, table in sql_query['from_table_list']:
       if from_table_dict.has_key(alias):
@@ -359,15 +350,15 @@ class Predicate(XMLObject):
   def searchResults(self, **kw):
     """
     """
-    portal_catalog = getToolByName(self, 'portal_catalog')
-    return portal_catalog.searchResults(build_sql_query_method=self.buildSQLQuery,**kw)
+    return self.getPortalObject().portal_catalog.searchResults(
+      build_sql_query_method=self.buildSQLQuery, **kw)
 
   security.declareProtected(Permissions.AccessContentsInformation, 'countResults')
   def countResults(self, REQUEST=None, used=None, **kw):
     """
     """
-    portal_catalog = getToolByName(self, 'portal_catalog')
-    return portal_catalog.countResults(build_sql_query_method=self.buildSQLQuery,**kw)
+    return self.getPortalObject().portal_catalog.countResults(
+      build_sql_query_method=self.buildSQLQuery, **kw)
 
   security.declareProtected( Permissions.AccessContentsInformation, 'getCriterionList' )
   def getCriterionList(self, **kw):
diff --git a/product/ERP5Type/tests/testERP5Type.py b/product/ERP5Type/tests/testERP5Type.py
index d7ce9b5ce1..e4a3cdc054 100644
--- a/product/ERP5Type/tests/testERP5Type.py
+++ b/product/ERP5Type/tests/testERP5Type.py
@@ -2347,7 +2347,7 @@ class TestERP5Type(PropertySheetTestCase, LogInterceptor):
       self._ignore_log_errors()
       logged_errors = [ logrecord for logrecord in self.logged
                         if logrecord.name == 'CMFCategory' ]
-      self.assertEqual('Could not access object region/gamma',
+      self.assertEqual('Could not get object region/gamma',
                         logged_errors[0].getMessage())
 
     def test_list_accessors(self):
-- 
2.30.9