Commit 196d9e35 authored by Stefan Behnel's avatar Stefan Behnel

Clean up and test type identifier escaping.

- hash() hashing lead to unpredictable random prefixes for long names across multiple runs
- use a single regex run instead of repeated calls to replace()
parent c9e107b4
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
from __future__ import absolute_import from __future__ import absolute_import
import hashlib
import copy import copy
import re import re
...@@ -11,6 +12,7 @@ try: ...@@ -11,6 +12,7 @@ try:
reduce reduce
except NameError: except NameError:
from functools import reduce from functools import reduce
from functools import partial
from Cython.Utils import cached_function from Cython.Utils import cached_function
from .Code import UtilityCode, LazyUtilityCode, TempitaUtilityCode from .Code import UtilityCode, LazyUtilityCode, TempitaUtilityCode
...@@ -45,7 +47,9 @@ class BaseType(object): ...@@ -45,7 +47,9 @@ class BaseType(object):
def cast_code(self, expr_code): def cast_code(self, expr_code):
return "((%s)%s)" % (self.empty_declaration_code(), expr_code) return "((%s)%s)" % (self.empty_declaration_code(), expr_code)
def empty_declaration_code(self): def empty_declaration_code(self, pyrex=False):
if pyrex:
return self.declaration_code('', pyrex=True)
if self._empty_declaration is None: if self._empty_declaration is None:
self._empty_declaration = self.declaration_code('') self._empty_declaration = self.declaration_code('')
return self._empty_declaration return self._empty_declaration
...@@ -3248,8 +3252,7 @@ class CFuncType(CType): ...@@ -3248,8 +3252,7 @@ class CFuncType(CType):
if not self.can_coerce_to_pyobject(env): if not self.can_coerce_to_pyobject(env):
return False return False
from .UtilityCode import CythonUtilityCode from .UtilityCode import CythonUtilityCode
safe_typename = type_identifier_from_declaration(self.declaration_code("", pyrex=1)) to_py_function = "__Pyx_CFunc_%s_to_py" % type_identifier(self, pyrex=True)
to_py_function = "__Pyx_CFunc_%s_to_py" % safe_typename
for arg in self.args: for arg in self.args:
if not arg.type.is_pyobject and not arg.type.create_from_py_utility_code(env): if not arg.type.is_pyobject and not arg.type.create_from_py_utility_code(env):
...@@ -4986,8 +4989,32 @@ def typecast(to_type, from_type, expr_code): ...@@ -4986,8 +4989,32 @@ def typecast(to_type, from_type, expr_code):
def type_list_identifier(types): def type_list_identifier(types):
return cap_length('__and_'.join(type_identifier(type) for type in types)) return cap_length('__and_'.join(type_identifier(type) for type in types))
def type_identifier(type): _special_type_characters = {
decl = type.empty_declaration_code() '__': '__dunder',
'const ': '__const_',
' ': '__space_',
'*': '__ptr',
'&': '__ref',
'&&': '__fwref',
'[': '__lArr',
']': '__rArr',
'<': '__lAng',
'>': '__rAng',
'(': '__lParen',
')': '__rParen',
',': '__comma_',
'...': '__EL',
'::': '__in_',
':': '__D',
}
_escape_special_type_characters = partial(re.compile(
# join substrings in reverse order to put longer matches first, e.g. "::" before ":"
" ?(%s) ?" % "|".join(re.escape(s) for s in sorted(_special_type_characters, reverse=True))
).sub, lambda match: _special_type_characters[match.group(1)])
def type_identifier(type, pyrex=False):
decl = type.empty_declaration_code(pyrex=pyrex)
return type_identifier_from_declaration(decl) return type_identifier_from_declaration(decl)
_type_identifier_cache = {} _type_identifier_cache = {}
...@@ -4996,21 +5023,8 @@ def type_identifier_from_declaration(decl): ...@@ -4996,21 +5023,8 @@ def type_identifier_from_declaration(decl):
if safe is None: if safe is None:
safe = decl safe = decl
safe = re.sub(' +', ' ', safe) safe = re.sub(' +', ' ', safe)
safe = re.sub(' ([^a-zA-Z0-9_])', r'\1', safe) safe = re.sub(' ?([^a-zA-Z0-9_]) ?', r'\1', safe)
safe = re.sub('([^a-zA-Z0-9_]) ', r'\1', safe) safe = _escape_special_type_characters(safe)
safe = (safe.replace('__', '__dunder')
.replace('const ', '__const_')
.replace(' ', '__space_')
.replace('*', '__ptr')
.replace('&', '__ref')
.replace('[', '__lArr')
.replace(']', '__rArr')
.replace('<', '__lAng')
.replace('>', '__rAng')
.replace('(', '__lParen')
.replace(')', '__rParen')
.replace(',', '__comma_')
.replace('::', '__in_'))
safe = cap_length(re.sub('[^a-zA-Z0-9_]', lambda x: '__%X' % ord(x.group(0)), safe)) safe = cap_length(re.sub('[^a-zA-Z0-9_]', lambda x: '__%X' % ord(x.group(0)), safe))
_type_identifier_cache[decl] = safe _type_identifier_cache[decl] = safe
return safe return safe
...@@ -5018,5 +5032,5 @@ def type_identifier_from_declaration(decl): ...@@ -5018,5 +5032,5 @@ def type_identifier_from_declaration(decl):
def cap_length(s, max_prefix=63, max_len=1024): def cap_length(s, max_prefix=63, max_len=1024):
if len(s) <= max_prefix: if len(s) <= max_prefix:
return s return s
else: hash_prefix = hashlib.sha1(s.encode('ascii')).hexdigest()[:6]
return '%x__%s__etc' % (abs(hash(s)) % (1<<20), s[:max_len-17]) return '%s__%s__etc' % (hash_prefix, s[:max_len-17])
...@@ -17,3 +17,59 @@ class TestMethodDispatcherTransform(unittest.TestCase): ...@@ -17,3 +17,59 @@ class TestMethodDispatcherTransform(unittest.TestCase):
cenum = PT.CEnumType("E", "cenum", typedef_flag=False) cenum = PT.CEnumType("E", "cenum", typedef_flag=False)
assert_widest(PT.c_int_type, cenum, PT.c_int_type) assert_widest(PT.c_int_type, cenum, PT.c_int_type)
class TestTypeIdentifiers(unittest.TestCase):
TEST_DATA = [
("char*", "char__ptr"),
("char *", "char__ptr"),
("char **", "char__ptr__ptr"),
("_typedef", "_typedef"),
("__typedef", "__dundertypedef"),
("___typedef", "__dunder_typedef"),
("____typedef", "__dunder__dundertypedef"),
("_____typedef", "__dunder__dunder_typedef"),
("const __typedef", "__const___dundertypedef"),
("int[42]", "int__lArr42__rArr"),
("int[:]", "int__lArr__D__rArr"),
("int[:,:]", "int__lArr__D__comma___D__rArr"),
("int[:,:,:]", "int__lArr__D__comma___D__comma___D__rArr"),
("int[:,:,...]", "int__lArr__D__comma___D__comma___EL__rArr"),
("std::vector", "std__in_vector"),
("std::vector&&", "std__in_vector__fwref"),
("const std::vector", "__const_std__in_vector"),
("const std::vector&", "__const_std__in_vector__ref"),
("const_std", "const_std"),
]
def test_escape_special_type_characters(self):
test_func = PT._escape_special_type_characters # keep test usage visible for IDEs
function_name = "_escape_special_type_characters"
self._test_escape(function_name)
def test_type_identifier_for_declaration(self):
test_func = PT.type_identifier_from_declaration # keep test usage visible for IDEs
function_name = test_func.__name__
self._test_escape(function_name)
# differences due to whitespace removal
test_data = [
("const &std::vector", "const__refstd__in_vector"),
("const &std::vector<int>", "const__refstd__in_vector__lAngint__rAng"),
("const &&std::vector", "const__fwrefstd__in_vector"),
("const &&&std::vector", "const__fwref__refstd__in_vector"),
("const &&std::vector", "const__fwrefstd__in_vector"),
("void (*func)(int x, float y)",
"07d63e__void__lParen__ptrfunc__rParen__lParenint__space_x__comma_float__space_y__rParen__etc"),
("float ** (*func)(int x, int[:] y)",
"79b33d__float__ptr__ptr__lParen__ptrfunc__rParen__lParenint__space_x__comma_int__lArr__D__rArry__rParen__etc"),
]
self._test_escape(function_name, test_data)
def _test_escape(self, func_name, test_data=TEST_DATA):
escape = getattr(PT, func_name)
for declaration, expected in test_data:
escaped_value = escape(declaration)
self.assertEqual(escaped_value, expected, "%s('%s') == '%s' != '%s'" % (
func_name, declaration, escaped_value, expected))
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment