Commit ffb006d7 authored by Stefan Behnel's avatar Stefan Behnel

refactor constant string slicing and guard it against platform specific unicode string length

parent d1cc779e
...@@ -1071,6 +1071,12 @@ class BytesNode(ConstNode): ...@@ -1071,6 +1071,12 @@ class BytesNode(ConstNode):
def calculate_constant_result(self): def calculate_constant_result(self):
self.constant_result = self.value self.constant_result = self.value
def as_sliced_node(self, start, stop, step=None):
value = StringEncoding.BytesLiteral(self.value[start:stop:step])
value.encoding = self.value.encoding
return BytesNode(
self.pos, value=value, constant_result=value)
def compile_time_value(self, denv): def compile_time_value(self, denv):
return self.value return self.value
...@@ -1155,6 +1161,22 @@ class UnicodeNode(PyConstNode): ...@@ -1155,6 +1161,22 @@ class UnicodeNode(PyConstNode):
def calculate_constant_result(self): def calculate_constant_result(self):
self.constant_result = self.value self.constant_result = self.value
def as_sliced_node(self, start, stop, step=None):
if _string_contains_surrogates(self.value[:stop]):
# this is unsafe as it may give different results in different runtimes
return None
value = StringEncoding.EncodedString(self.value[start:stop:step])
value.encoding = self.value.encoding
if self.bytes_value is not None:
bytes_value = StringEncoding.BytesLiteral(
self.bytes_value[start:stop:step])
bytes_value.encoding = self.bytes_value.encoding
else:
bytes_value = None
return UnicodeNode(
self.pos, value=value, bytes_value=bytes_value,
constant_result=value)
def coerce_to(self, dst_type, env): def coerce_to(self, dst_type, env):
if dst_type is self.type: if dst_type is self.type:
pass pass
...@@ -1181,21 +1203,7 @@ class UnicodeNode(PyConstNode): ...@@ -1181,21 +1203,7 @@ class UnicodeNode(PyConstNode):
## and (0xDC00 <= self.value[1] <= 0xDFFF)) ## and (0xDC00 <= self.value[1] <= 0xDFFF))
def contains_surrogates(self): def contains_surrogates(self):
# Check if the unicode string contains surrogate code points return _string_contains_surrogates(self.value)
# on a CPython platform with wide (UCS-4) or narrow (UTF-16)
# Unicode, i.e. characters that would be spelled as two
# separate code units on a narrow platform.
for c in map(ord, self.value):
if c > 65535: # can only happen on wide platforms
return True
# We only look for the first code unit (D800-DBFF) of a
# surrogate pair - if we find one, the other one
# (DC00-DFFF) is likely there, too. If we don't find it,
# any second code unit cannot make for a surrogate pair by
# itself.
if 0xD800 <= c <= 0xDBFF:
return True
return False
def generate_evaluation_code(self, code): def generate_evaluation_code(self, code):
self.result_code = code.get_py_string_const(self.value) self.result_code = code.get_py_string_const(self.value)
...@@ -1223,6 +1231,21 @@ class StringNode(PyConstNode): ...@@ -1223,6 +1231,21 @@ class StringNode(PyConstNode):
def calculate_constant_result(self): def calculate_constant_result(self):
self.constant_result = self.value self.constant_result = self.value
def as_sliced_node(self, start, stop, step=None):
value = type(self.value)(self.value[start:stop:step])
value.encoding = self.value.encoding
if self.unicode_value is not None:
if _string_contains_surrogates(self.unicode_value[:stop]):
# this is unsafe as it may give different results in different runtimes
return None
unicode_value = StringEncoding.EncodedString(
self.unicode_value[start:stop:step])
else:
unicode_value = None
return StringNode(
self.pos, value=value, unicode_value=unicode_value,
constant_result=value, is_identifier=self.is_identifier)
def coerce_to(self, dst_type, env): def coerce_to(self, dst_type, env):
if dst_type is not py_object_type and not str_type.subtype_of(dst_type): if dst_type is not py_object_type and not str_type.subtype_of(dst_type):
# if dst_type is Builtin.bytes_type: # if dst_type is Builtin.bytes_type:
...@@ -1257,6 +1280,26 @@ class IdentifierStringNode(StringNode): ...@@ -1257,6 +1280,26 @@ class IdentifierStringNode(StringNode):
is_identifier = True is_identifier = True
def _string_contains_surrogates(ustring):
"""
Check if the unicode string contains surrogate code points
on a CPython platform with wide (UCS-4) or narrow (UTF-16)
Unicode, i.e. characters that would be spelled as two
separate code units on a narrow platform.
"""
for c in map(ord, ustring):
if c > 65535: # can only happen on wide platforms
return True
# We only look for the first code unit (D800-DBFF) of a
# surrogate pair - if we find one, the other one
# (DC00-DFFF) is likely there, too. If we don't find it,
# any second code unit cannot make for a surrogate pair by
# itself.
if 0xD800 <= c <= 0xDBFF:
return True
return False
class ImagNode(AtomicExprNode): class ImagNode(AtomicExprNode):
# Imaginary number literal # Imaginary number literal
# #
......
...@@ -3206,18 +3206,9 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations): ...@@ -3206,18 +3206,9 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
base.args = base.args[start:stop] base.args = base.args[start:stop]
return base return base
elif base.is_string_literal: elif base.is_string_literal:
value = type(base.value)(node.constant_result) base = base.as_sliced_node(start, stop)
value.encoding = base.value.encoding if base is not None:
base.value = value return base
if isinstance(base, ExprNodes.StringNode):
if base.unicode_value is not None:
base.unicode_value = EncodedString(
base.unicode_value[start:stop])
elif isinstance(base, ExprNodes.UnicodeNode):
if base.bytes_value is not None:
base.bytes_value = BytesLiteral(
base.bytes_value[start:stop])
return base
return node return node
def visit_ForInStatNode(self, node): def visit_ForInStatNode(self, node):
......
...@@ -7,6 +7,7 @@ cimport cython ...@@ -7,6 +7,7 @@ cimport cython
bstring = b'abc\xE9def' bstring = b'abc\xE9def'
ustring = u'abc\xE9def' ustring = u'abc\xE9def'
surrogates_ustring = u'abc\U00010000def'
@cython.test_fail_if_path_exists( @cython.test_fail_if_path_exists(
...@@ -53,3 +54,29 @@ def unicode_slicing2(): ...@@ -53,3 +54,29 @@ def unicode_slicing2():
str3 = u'abc\xE9def'[2:4] str3 = u'abc\xE9def'[2:4]
return str0, str1, str2, str3 return str0, str1, str2, str3
@cython.test_assert_path_exists(
"//SliceIndexNode",
)
def unicode_slicing_unsafe_surrogates2():
"""
>>> unicode_slicing_unsafe_surrogates2() == surrogates_ustring[2:]
True
"""
ustring = u'abc\U00010000def'[2:]
return ustring
@cython.test_fail_if_path_exists(
"//SliceIndexNode",
)
def unicode_slicing_safe_surrogates2():
"""
>>> unicode_slicing_safe_surrogates2() == surrogates_ustring[:2]
True
>>> print(unicode_slicing_safe_surrogates2())
ab
"""
ustring = u'abc\U00010000def'[:2]
return ustring
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment