Commit 1f8ffaa1 authored by Stefan Behnel's avatar Stefan Behnel

ticket 436: efficiently support char*.decode() through C-API calls

parent eed632a1
......@@ -136,7 +136,7 @@ class Context(object):
IntroduceBufferAuxiliaryVars(self),
_check_c_declarations,
AnalyseExpressionsTransform(self),
OptimizeBuiltinCalls(),
OptimizeBuiltinCalls(self),
IterationTransform(),
SwitchTransform(),
DropRefcountingTransform(),
......
......@@ -305,7 +305,7 @@ class IterationTransform(Visitor.VisitorTransform):
if dest_type != obj_node.type:
if dest_type.is_extension_type or dest_type.is_builtin_type:
obj_node = ExprNodes.PyTypeTestNode(
obj_node, dest_type, FakePythonEnv(), notnone=True)
obj_node, dest_type, self.current_scope, notnone=True)
result = ExprNodes.TypecastNode(
obj_node.pos,
operand = obj_node,
......@@ -320,7 +320,7 @@ class IterationTransform(Visitor.VisitorTransform):
return temp_result.result()
def generate_execution_code(self, code):
self.generate_result_code(code)
return (temp_result, CoercedTempNode(dest_type, obj_node, FakePythonEnv()))
return (temp_result, CoercedTempNode(dest_type, obj_node, self.current_scope))
if isinstance(node.body, Nodes.StatListNode):
body = node.body
......@@ -633,7 +633,7 @@ class DropRefcountingTransform(Visitor.VisitorTransform):
return (base.name, index_val)
class OptimizeBuiltinCalls(Visitor.VisitorTransform):
class OptimizeBuiltinCalls(Visitor.EnvTransform):
"""Optimize some common methods calls and instantiation patterns
for builtin types.
"""
......@@ -961,33 +961,158 @@ class OptimizeBuiltinCalls(Visitor.VisitorTransform):
_special_encodings = ['UTF8', 'UTF16', 'Latin1', 'ASCII',
'unicode_escape', 'raw_unicode_escape']
_special_encoders = [ (name, codecs.getencoder(name))
for name in _special_encodings ]
_special_codecs = [ (name, codecs.getencoder(name))
for name in _special_encodings ]
def _handle_simple_method_unicode_encode(self, node, args, is_unbound_method):
if len(args) < 1 or len(args) > 3:
self._error_wrong_arg_count('unicode.encode', node, args, '1-3')
return node
null_node = ExprNodes.NullNode(node.pos)
string_node = args[0]
if len(args) == 1:
null_node = ExprNodes.NullNode(node.pos)
return self._substitute_method_call(
node, "PyUnicode_AsEncodedString",
self.PyUnicode_AsEncodedString_func_type,
'encode', is_unbound_method, [string_node, null_node, null_node])
parameters = self._unpack_encoding_and_error_mode(node.pos, args)
if parameters is None:
return node
encoding, encoding_node, error_handling, error_handling_node = parameters
if isinstance(string_node, ExprNodes.UnicodeNode):
# constant, so try to do the encoding at compile time
try:
value = string_node.value.encode(encoding, error_handling)
except:
# well, looks like we can't
pass
else:
value = BytesLiteral(value)
value.encoding = encoding
return ExprNodes.BytesNode(
string_node.pos, value=value, type=Builtin.bytes_type)
if error_handling == 'strict':
# try to find a specific encoder function
codec_name = self._find_special_codec_name(encoding)
if codec_name is not None:
encode_function = "PyUnicode_As%sString" % codec_name
return self._substitute_method_call(
node, encode_function,
self.PyUnicode_AsXyzString_func_type,
'encode', is_unbound_method, [string_node])
return self._substitute_method_call(
node, "PyUnicode_AsEncodedString",
self.PyUnicode_AsEncodedString_func_type,
'encode', is_unbound_method,
[string_node, encoding_node, error_handling_node])
PyUnicode_DecodeXyz_func_type = PyrexTypes.CFuncType(
Builtin.unicode_type, [
PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None),
PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None),
PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
],
exception_value = "NULL")
PyUnicode_Decode_func_type = PyrexTypes.CFuncType(
Builtin.unicode_type, [
PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None),
PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None),
PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
],
exception_value = "NULL")
def _handle_simple_method_bytes_decode(self, node, args, is_unbound_method):
if len(args) < 1 or len(args) > 3:
self._error_wrong_arg_count('bytes.decode', node, args, '1-3')
return node
if is_unbound_method:
return node
if not isinstance(args[0], ExprNodes.SliceIndexNode):
# we need the string length as a slice end index
return node
index_node = args[0]
string_node = index_node.base
if not string_node.type.is_string:
# nothing to optimise here
return node
start, stop = index_node.start, index_node.stop
if not stop:
# FIXME: could use strlen() - although Python will do that anyway ...
return node
if stop.type.is_pyobject:
stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.env_stack[-1])
if start and start.constant_result != 0:
# FIXME: put start into a temp and do the math
return node
parameters = self._unpack_encoding_and_error_mode(node.pos, args)
if parameters is None:
return node
encoding, encoding_node, error_handling, error_handling_node = parameters
# try to find a specific encoder function
codec_name = self._find_special_codec_name(encoding)
if codec_name is not None:
decode_function = "PyUnicode_Decode%s" % codec_name
return ExprNodes.PythonCapiCallNode(
node.pos, decode_function,
self.PyUnicode_DecodeXyz_func_type,
args = [string_node, stop, error_handling_node],
is_temp = node.is_temp,
)
return self._substitute_method_call(
node, decode_function,
self.PyUnicode_DecodeXyz_func_type,
'decode', is_unbound_method,
[string_node, stop, error_handling_node])
return ExprNodes.PythonCapiCallNode(
node.pos, "PyUnicode_Decode",
self.PyUnicode_Decode_func_type,
args = [string_node, stop, encoding_node, error_handling_node],
is_temp = node.is_temp,
)
return self._substitute_method_call(
node, "PyUnicode_Decode",
self.PyUnicode_Decode_func_type,
'decode', is_unbound_method,
[string_node, stop, encoding_node, error_handling_node])
def _find_special_codec_name(self, encoding):
try:
requested_codec = codecs.getencoder(encoding)
except:
return None
for name, codec in self._special_codecs:
if codec == requested_codec:
if '_' in name:
name = ''.join([ s.capitalize()
for s in name.split('_')])
return name
return None
def _unpack_encoding_and_error_mode(self, pos, args):
encoding_node = args[1]
if isinstance(encoding_node, ExprNodes.CoerceToPyTypeNode):
encoding_node = encoding_node.arg
if not isinstance(encoding_node, (ExprNodes.UnicodeNode, ExprNodes.StringNode,
ExprNodes.BytesNode)):
return node
return None
encoding = encoding_node.value
encoding_node = ExprNodes.BytesNode(encoding_node.pos, value=encoding,
type=PyrexTypes.c_char_ptr_type)
null_node = ExprNodes.NullNode(pos)
if len(args) == 3:
error_handling_node = args[2]
if isinstance(error_handling_node, ExprNodes.CoerceToPyTypeNode):
......@@ -995,7 +1120,7 @@ class OptimizeBuiltinCalls(Visitor.VisitorTransform):
if not isinstance(error_handling_node,
(ExprNodes.UnicodeNode, ExprNodes.StringNode,
ExprNodes.BytesNode)):
return node
return None
error_handling = error_handling_node.value
if error_handling == 'strict':
error_handling_node = null_node
......@@ -1007,43 +1132,7 @@ class OptimizeBuiltinCalls(Visitor.VisitorTransform):
error_handling = 'strict'
error_handling_node = null_node
if isinstance(string_node, ExprNodes.UnicodeNode):
# constant, so try to do the encoding at compile time
try:
value = string_node.value.encode(encoding, error_handling)
except:
# well, looks like we can't
pass
else:
value = BytesLiteral(value)
value.encoding = encoding
return ExprNodes.BytesNode(
string_node.pos, value=value, type=Builtin.bytes_type)
if error_handling == 'strict':
# try to find a specific encoder function
try: requested_encoder = codecs.getencoder(encoding)
except: pass
else:
encode_function = None
for name, encoder in self._special_encoders:
if encoder == requested_encoder:
if '_' in name:
name = ''.join([ s.capitalize()
for s in name.split('_')])
encode_function = "PyUnicode_As%sString" % name
break
if encode_function is not None:
return self._substitute_method_call(
node, encode_function,
self.PyUnicode_AsXyzString_func_type,
'encode', is_unbound_method, [string_node])
return self._substitute_method_call(
node, "PyUnicode_AsEncodedString",
self.PyUnicode_AsEncodedString_func_type,
'encode', is_unbound_method,
[string_node, encoding_node, error_handling_node])
return (encoding, encoding_node, error_handling, error_handling_node)
def _substitute_method_call(self, node, name, func_type,
attr_name, is_unbound_method, args=()):
......
from Cython.Compiler.Visitor import VisitorTransform, CythonTransform, TreeVisitor
from Cython.Compiler.Visitor import VisitorTransform, TreeVisitor
from Cython.Compiler.Visitor import CythonTransform, EnvTransform
from Cython.Compiler.ModuleNode import ModuleNode
from Cython.Compiler.Nodes import *
from Cython.Compiler.ExprNodes import *
......@@ -938,21 +939,6 @@ class GilCheck(VisitorTransform):
return node
class EnvTransform(CythonTransform):
"""
This transformation keeps a stack of the environments.
"""
def __call__(self, root):
self.env_stack = [root.scope]
return super(EnvTransform, self).__call__(root)
def visit_FuncDefNode(self, node):
self.env_stack.append(node.local_scope)
self.visitchildren(node)
self.env_stack.pop()
return node
class TransformBuiltinMethods(EnvTransform):
def visit_SingleAssignmentNode(self, node):
......
......@@ -306,6 +306,22 @@ class ScopeTrackingTransform(CythonTransform):
def visit_CStructOrUnionDefNode(self, node):
return self.visit_scope(node, 'struct')
class EnvTransform(CythonTransform):
"""
This transformation keeps a stack of the environments.
"""
def __call__(self, root):
self.env_stack = [root.scope]
return super(EnvTransform, self).__call__(root)
def visit_FuncDefNode(self, node):
self.env_stack.append(node.local_scope)
self.visitchildren(node)
self.env_stack.pop()
return node
class RecursiveNodeReplacer(VisitorTransform):
"""
Recursively replace all occurrences of a node in a subtree by
......
cdef char* cstring = "abcABCqtp"
def slice_charptr_end():
"""
>>> print str(slice_charptr_end()).replace("b'", "'")
('a', 'abc', 'abcABCqtp')
"""
return cstring[:1], cstring[:3], cstring[:9]
def slice_charptr_decode():
"""
>>> print str(slice_charptr_decode()).replace("u'", "'")
('a', 'abc', 'abcABCqtp')
"""
return (cstring[:1].decode('UTF-8'),
cstring[:3].decode('UTF-8'),
cstring[:9].decode('UTF-8'))
def slice_charptr_decode_errormode():
"""
>>> print str(slice_charptr_decode_errormode()).replace("u'", "'")
('a', 'abc', 'abcABCqtp')
"""
return (cstring[:1].decode('UTF-8', 'strict'),
cstring[:3].decode('UTF-8', 'replace'),
cstring[:9].decode('UTF-8', 'unicode_escape'))
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment