Commit a3dada56 authored by Vincent Pelletier's avatar Vincent Pelletier

ZSQLCatalog: Sanitise more EntireQuery parameters.

Also, make some existing sanitation happen earlier than before, to store
values in a cleaner format.
parent 990f7aee
...@@ -28,6 +28,8 @@ ...@@ -28,6 +28,8 @@
# #
############################################################################## ##############################################################################
import functools
import re
import warnings import warnings
from Products.ZSQLCatalog.SQLExpression import SQLExpression from Products.ZSQLCatalog.SQLExpression import SQLExpression
from Products.ZSQLCatalog.ColumnMap import ColumnMap from Products.ZSQLCatalog.ColumnMap import ColumnMap
...@@ -37,6 +39,41 @@ from zope.interface.verify import verifyClass ...@@ -37,6 +39,41 @@ from zope.interface.verify import verifyClass
from zope.interface import implements from zope.interface import implements
from Products.ZSQLCatalog.TableDefinition import LegacyTableDefinition from Products.ZSQLCatalog.TableDefinition import LegacyTableDefinition
# SQL identifier
# ZSQLCatalog only allows unquote-safe identifiers as table and column names,
# even though it may internally quote them.
# Also, this is a subset of what is accepted as SQL99 identifier: we restrict
# ourselves to ASCII.
UNQUOTED_SQL99_IDENTIFIER = '[0-9a-z$_]+'
COLUMN = '(' + UNQUOTED_SQL99_IDENTIFIER + r'\.)?' + UNQUOTED_SQL99_IDENTIFIER
def _check(value, match):
if value is not None and match(value) is None:
raise ValueError(repr(value))
return value
checkIdentifier = functools.partial(
_check,
match=re.compile('^' + UNQUOTED_SQL99_IDENTIFIER + '$', re.I).match,
)
checkColumn = functools.partial(
_check,
match=re.compile('^' + COLUMN + '$', re.I).match,
)
# Are selectable:
# - "<identifier>([DISTINCT ]{<column>,*})" (ex: "COUNT(DISTINCT foo.reference)")
# - "<column>" (ex: "foo.reference" or "reference")
# - "''", a dirty hack to block acquisition on brain of values which should
# not be available (see stat methods in ListBox)
checkSelectable = functools.partial(
_check,
match=re.compile(
'^(' + UNQUOTED_SQL99_IDENTIFIER + r'\((DISTINCT )?(' + COLUMN + r'|\*)\)|' + COLUMN + "|'')$",
re.I,
).match,
)
del _check
del COLUMN
del UNQUOTED_SQL99_IDENTIFIER
def defaultDict(value): def defaultDict(value):
if value is None: if value is None:
return {} return {}
...@@ -53,6 +90,7 @@ class EntireQuery(object): ...@@ -53,6 +90,7 @@ class EntireQuery(object):
implements(IEntireQuery) implements(IEntireQuery)
column_map = None column_map = None
limit = None
def __init__(self, query, def __init__(self, query,
order_by_list=(), order_by_list=(),
...@@ -66,16 +104,46 @@ class EntireQuery(object): ...@@ -66,16 +104,46 @@ class EntireQuery(object):
extra_column_list=(), extra_column_list=(),
implicit_join=False): implicit_join=False):
self.query = query self.query = query
self.order_by_list = list(order_by_list) self.order_by_list = my_order_by_list = []
self.group_by_list = list(group_by_list) for order_by in order_by_list:
self.select_dict = defaultDict(select_dict) assert isinstance(order_by, (tuple, list))
column, direction, cast = (tuple(order_by) + (None, None))[:3]
my_order_by_list.append((
checkColumn(column),
checkIdentifier(direction),
checkIdentifier(cast) if cast else None,
))
self.group_by_list = [checkColumn(x) for x in group_by_list]
self.select_dict = {
checkIdentifier(alias): checkSelectable(column)
for alias, column in defaultDict(select_dict).iteritems()
}
# No need to sanitize, it's compared against columns and not included in SQL
self.left_join_list = left_join_list self.left_join_list = left_join_list
# No need to sanitize, it's compared against columns and not included in SQL
self.inner_join_list = inner_join_list self.inner_join_list = inner_join_list
self.limit = limit if limit:
self.catalog_table_name = catalog_table_name if not isinstance(limit, (list, tuple)):
self.catalog_table_alias = catalog_table_alias limit = (limit, )
self.extra_column_list = list(extra_column_list) self.limit = [int(x) for x in limit]
self.implicit_join = implicit_join self.catalog_table_name = checkIdentifier(catalog_table_name)
self.catalog_table_alias = checkIdentifier(catalog_table_alias) # XXX: check as quoted identifier ?
self.extra_column_list = my_extra_column_list = []
for extra_column in extra_column_list:
table, column = extra_column.replace('`', '').split('.')
if table != self.catalog_table_name:
raise ValueError('Extra columns must be catalog columns. %r does not follow this rule (catalog=%r, extra_column_list=%r)' % (extra_column, self.catalog_table_name, extra_column_list))
my_extra_column_list.append(
'`%s`.`%s`' % (
# table == self.catalog_table_name, and self.catalog_table_name
# is already checked.
table,
# Note: this is really and identifier and not a column as we
# stripped table name
checkIdentifier(column),
),
)
self.implicit_join = bool(implicit_join)
def asSearchTextExpression(self, sql_catalog): def asSearchTextExpression(self, sql_catalog):
return self.query.asSearchTextExpression(sql_catalog) return self.query.asSearchTextExpression(sql_catalog)
...@@ -98,9 +166,6 @@ class EntireQuery(object): ...@@ -98,9 +166,6 @@ class EntireQuery(object):
self.catalog_table_alias, self.catalog_table_alias,
) )
for extra_column in self.extra_column_list: for extra_column in self.extra_column_list:
table, column = extra_column.replace('`', '').split('.')
if table != self.catalog_table_name:
raise ValueError, 'Extra columns must be catalog columns. %r does not follow this rule (catalog=%r, extra_column_list=%r)' % (extra_column, self.catalog_table_name, self.extra_column_list)
column_map.registerColumn(extra_column) column_map.registerColumn(extra_column)
for column in self.group_by_list: for column in self.group_by_list:
column_map.registerColumn(column) column_map.registerColumn(column)
...@@ -111,8 +176,6 @@ class EntireQuery(object): ...@@ -111,8 +176,6 @@ class EntireQuery(object):
column_map.ignoreColumn(alias) column_map.ignoreColumn(alias)
column_map.registerColumn(column) column_map.registerColumn(column)
for order_by in self.order_by_list: for order_by in self.order_by_list:
assert isinstance(order_by, (tuple, list))
assert len(order_by)
column_map.registerColumn(order_by[0]) column_map.registerColumn(order_by[0])
self.query.registerColumnMap(sql_catalog, column_map) self.query.registerColumnMap(sql_catalog, column_map)
column_map.build(sql_catalog) column_map.build(sql_catalog)
...@@ -154,8 +217,7 @@ class EntireQuery(object): ...@@ -154,8 +217,7 @@ class EntireQuery(object):
LOG('EntireQuery', WARNING, 'Order by %r ignored: it could not be mapped to a known column.' % (order_by, )) LOG('EntireQuery', WARNING, 'Order by %r ignored: it could not be mapped to a known column.' % (order_by, ))
rendered = None rendered = None
if rendered is not None: if rendered is not None:
append((rendered, ) + tuple(order_by[1:]) + ( append((rendered, ) + tuple(order_by[1:]))
None, ) * (3 - len(order_by)))
self.order_by_list = new_order_by_list self.order_by_list = new_order_by_list
# generate SQLExpression from query # generate SQLExpression from query
sql_expression_list = [self.query.asSQLExpression(sql_catalog, sql_expression_list = [self.query.asSQLExpression(sql_catalog,
......
...@@ -270,7 +270,7 @@ class SQLExpression(object): ...@@ -270,7 +270,7 @@ class SQLExpression(object):
for (column, direction, cast) in self.getOrderByList(): for (column, direction, cast) in self.getOrderByList():
expression = conflictSafeGet(order_by_dict, column, str(column)) expression = conflictSafeGet(order_by_dict, column, str(column))
expression = self._reversed_select_dict.get(expression, expression) expression = self._reversed_select_dict.get(expression, expression)
if cast not in (None, ''): if cast is not None:
expression = 'CAST(%s AS %s)' % (expression, cast) expression = 'CAST(%s AS %s)' % (expression, cast)
if direction is not None: if direction is not None:
expression = '%s %s' % (expression, direction) expression = '%s %s' % (expression, direction)
......
...@@ -694,25 +694,6 @@ class TestSQLCatalog(ERP5TypeTestCase): ...@@ -694,25 +694,6 @@ class TestSQLCatalog(ERP5TypeTestCase):
select_dict = sql_expression.getSelectDict() select_dict = sql_expression.getSelectDict()
self.assertTrue('ambiguous_mapping' in select_dict, select_dict) self.assertTrue('ambiguous_mapping' in select_dict, select_dict)
self.assertTrue('bar' in select_dict['ambiguous_mapping'], select_dict['ambiguous_mapping']) self.assertTrue('bar' in select_dict['ambiguous_mapping'], select_dict['ambiguous_mapping'])
# Dotted alias: table name must get stripped. This is required to have an
# upgrade path from old ZSQLCatalog versions where pre-mapped columns were
# used in their select_expression. This must only happen in the
# "{column: None}" form, as otherwise it's the user explicitely asking for
# such alias (which is not strictly invalid).
sql_expression = self.asSQLExpression({'select_dict': {
'foo.default': None,
'foo.keyword': 'foo.keyword',
}}, query_table='foo')
select_dict = sql_expression.getSelectDict()
self.assertTrue('default' in select_dict, select_dict)
self.assertFalse('foo.default' in select_dict, select_dict)
self.assertTrue('foo.keyword' in select_dict, select_dict)
# Variant: same operation, but this time stripping generates an ambiguity.
# That must be detected and cause a mapping exception.
self.assertRaises(ValueError, self.asSQLExpression, {'select_dict': {
'foo.ambiguous_mapping': None,
'bar.ambiguous_mapping': None,
}}, query_table='foo')
def test_hasColumn(self): def test_hasColumn(self):
self.assertTrue(self._catalog.hasColumn('uid')) self.assertTrue(self._catalog.hasColumn('uid'))
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment