Commit c9d6f7ca authored by Vincent Pelletier's avatar Vincent Pelletier

ZSQLCatalog: Assorted assertion improvements.

Accelerate debugging by providing relevant values right away.
parent 54055607
...@@ -109,14 +109,14 @@ class ColumnMap(object): ...@@ -109,14 +109,14 @@ class ColumnMap(object):
) )
def registerColumn(self, raw_column, group=DEFAULT_GROUP_ID, simple_query=None): def registerColumn(self, raw_column, group=DEFAULT_GROUP_ID, simple_query=None):
assert ' as ' not in raw_column.lower() assert ' as ' not in raw_column.lower(), raw_column
# Sanitize input: extract column from raw column (might contain COUNT, ...). # Sanitize input: extract column from raw column (might contain COUNT, ...).
# XXX This is not enough to parse something like: # XXX This is not enough to parse something like:
# GROUP_CONCAT(DISTINCT foo ORDER BY bar) # GROUP_CONCAT(DISTINCT foo ORDER BY bar)
if '(' in raw_column: if '(' in raw_column:
function, column = raw_column.split('(') function, column = raw_column.split('(')
column = column.strip() column = column.strip()
assert column[-1] == ')' assert column[-1] == ')', column
column = column[:-1].strip() column = column[:-1].strip()
else: else:
function = None function = None
...@@ -156,8 +156,9 @@ class ColumnMap(object): ...@@ -156,8 +156,9 @@ class ColumnMap(object):
order = self.related_key_order_dict.get(real_related_column, 0) + 1 order = self.related_key_order_dict.get(real_related_column, 0) + 1
related_column = '%s_%s' % (related_column, order) related_column = '%s_%s' % (related_column, order)
group = 'related_%s' % (related_column, ) group = 'related_%s' % (related_column, )
assert group not in self.registry assert group not in self.registry, (group, self.registry)
assert group not in self.related_group_dict assert group not in self.related_group_dict, (group,
self.related_group_dict)
self.related_key_order_dict[real_related_column] = order self.related_key_order_dict[real_related_column] = order
self.related_key_dict[real_related_column] = (group, column) self.related_key_dict[real_related_column] = (group, column)
self.registerColumn(column, group=group) self.registerColumn(column, group=group)
...@@ -185,9 +186,10 @@ class ColumnMap(object): ...@@ -185,9 +186,10 @@ class ColumnMap(object):
self.resolveTable(self.catalog_table_name, self.catalog_table_name) self.resolveTable(self.catalog_table_name, self.catalog_table_name)
def registerRelatedKeyColumn(self, related_column, position, group): def registerRelatedKeyColumn(self, related_column, position, group):
assert group in self.related_group_dict assert group in self.related_group_dict, (group, self.related_group_dict)
group = self.getRelatedKeyGroup(position, group) group = self.getRelatedKeyGroup(position, group)
assert group not in self.related_group_dict assert group not in self.related_group_dict, (group,
self.related_group_dict)
self.related_group_dict[group] = related_column self.related_group_dict[group] = related_column
return group return group
...@@ -494,11 +496,13 @@ class ColumnMap(object): ...@@ -494,11 +496,13 @@ class ColumnMap(object):
return None return None
def resolveColumn(self, column, table_name, group=DEFAULT_GROUP_ID): def resolveColumn(self, column, table_name, group=DEFAULT_GROUP_ID):
assert group in self.registry assert group in self.registry, (group, self.registry)
assert column in self.registry[group] assert column in self.registry[group], (column, group,
self.registry[group])
column_map_key = (group, column) column_map_key = (group, column)
column_map = self.column_map column_map = self.column_map
assert (group, table_name) in self.table_alias_dict assert (group, table_name) in self.table_alias_dict, (group, table_name,
self.table_alias_dict)
previous_value = column_map.get(column_map_key) previous_value = column_map.get(column_map_key)
if previous_value is None: if previous_value is None:
column_map[column_map_key] = table_name column_map[column_map_key] = table_name
...@@ -510,10 +514,13 @@ class ColumnMap(object): ...@@ -510,10 +514,13 @@ class ColumnMap(object):
def resolveTable(self, table_name, alias, group=DEFAULT_GROUP_ID): def resolveTable(self, table_name, alias, group=DEFAULT_GROUP_ID):
table_alias_key = (group, table_name) table_alias_key = (group, table_name)
assert table_alias_key in self.table_alias_dict assert table_alias_key in self.table_alias_dict, (table_alias_key,
assert self.table_alias_dict[table_alias_key] in (None, alias) self.table_alias_dict)
assert self.table_alias_dict[table_alias_key] in (None, alias), (
table_alias_key, self.table_alias_dict[table_alias_key], alias)
self.table_alias_dict[table_alias_key] = alias self.table_alias_dict[table_alias_key] = alias
assert self.table_map.get(alias) in (None, table_name) assert self.table_map.get(alias) in (None, table_name), (alias,
self.table_map.get(alias), table_name)
self.table_map[alias] = table_name self.table_map[alias] = table_name
def getTableAlias(self, table_name, group=DEFAULT_GROUP_ID): def getTableAlias(self, table_name, group=DEFAULT_GROUP_ID):
......
...@@ -133,7 +133,8 @@ class RelatedKey(SearchKey): ...@@ -133,7 +133,8 @@ class RelatedKey(SearchKey):
# value of the "group" variable) to be the same as the table used # value of the "group" variable) to be the same as the table used
# in join_condition. # in join_condition.
if table_alias_list is not None: if table_alias_list is not None:
assert len(self.table_list) == len(table_alias_list) assert len(self.table_list) == len(table_alias_list), (self.table_list,
table_alias_list)
# XXX-Leo: remove the rest of this 'if' branch after making sure # XXX-Leo: remove the rest of this 'if' branch after making sure
# that ColumnMap.addRelatedKeyJoin() can handle collapsing # that ColumnMap.addRelatedKeyJoin() can handle collapsing
# chains of inner-joins that are subsets of one another based on # chains of inner-joins that are subsets of one another based on
...@@ -150,13 +151,13 @@ class RelatedKey(SearchKey): ...@@ -150,13 +151,13 @@ class RelatedKey(SearchKey):
if table_alias_list is not None: if table_alias_list is not None:
# Pre-resolve all tables with given aliases # Pre-resolve all tables with given aliases
given_name, given_alias = table_alias_list[table_position] given_name, given_alias = table_alias_list[table_position]
assert table_name == given_name assert table_name == given_name, (table_name, given_name)
column_map.resolveTable(table_name, given_alias, group=local_group) column_map.resolveTable(table_name, given_alias, group=local_group)
table_name = self.table_list[-1] table_name = self.table_list[-1]
column_map.registerTable(table_name, group=group) column_map.registerTable(table_name, group=group)
if table_alias_list is not None: if table_alias_list is not None:
given_name, given_alias = table_alias_list[-1] given_name, given_alias = table_alias_list[-1]
assert table_name == given_name assert table_name == given_name, (table_name, given_name)
column_map.resolveTable(table_name, given_alias, group=group) column_map.resolveTable(table_name, given_alias, group=group)
# Resolve (and register) related key column in related key group with its last table. # Resolve (and register) related key column in related key group with its last table.
column_map.registerColumn(self.real_column, group=group) column_map.registerColumn(self.real_column, group=group)
...@@ -171,7 +172,7 @@ class RelatedKey(SearchKey): ...@@ -171,7 +172,7 @@ class RelatedKey(SearchKey):
right = column_map.makeTableAliasDefinition(table, alias) right = column_map.makeTableAliasDefinition(table, alias)
if not join_query_list: if not join_query_list:
# nothing to do, just return the table alias # nothing to do, just return the table alias
assert len(table_alias_list) == 1 assert len(table_alias_list) == 1, table_alias_list
return right return right
else: else:
# create an InnerJoin of the last element of the alias list with # create an InnerJoin of the last element of the alias list with
...@@ -214,7 +215,8 @@ class RelatedKey(SearchKey): ...@@ -214,7 +215,8 @@ class RelatedKey(SearchKey):
table_alias_dict = {'table_%s' % index: alias[0] table_alias_dict = {'table_%s' % index: alias[0]
for index, alias in enumerate(table_alias_list)} for index, alias in enumerate(table_alias_list)}
assert len(table_alias_list) == len(table_alias_dict) assert len(table_alias_list) == len(table_alias_dict), (table_alias_list,
table_alias_dict)
query_table=column_map.getCatalogTableAlias() query_table=column_map.getCatalogTableAlias()
rendered_related_key = related_key( rendered_related_key = related_key(
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment