Commit f2e83334 authored by Stefan Behnel's avatar Stefan Behnel

Add Cython.Utils to the list of compiled modules and include a faster...

Add Cython.Utils to the list of compiled modules and include a faster @contextmanager for try-finally cases.
parent 47c1d85d
......@@ -2,7 +2,7 @@ import unittest
from Cython.Utils import (
_CACHE_NAME_PATTERN, _build_cache_name, _find_cache_attributes,
build_hex_version, cached_method, clear_method_caches)
build_hex_version, cached_method, clear_method_caches, try_finally_contextmanager)
METHOD_NAME = "cached_next"
CACHE_NAME = _build_cache_name(METHOD_NAME)
......@@ -94,3 +94,35 @@ class TestCythonUtils(unittest.TestCase):
clear_method_caches(obj)
self.set_of_names_equal(obj, {names})
def test_try_finally_contextmanager(self):
states = []
@try_finally_contextmanager
def gen(*args, **kwargs):
states.append("enter")
yield (args, kwargs)
states.append("exit")
with gen(1, 2, 3, x=4) as call_args:
assert states == ["enter"]
self.assertEqual(call_args, ((1, 2, 3), {'x': 4}))
assert states == ["enter", "exit"]
class MyException(RuntimeError):
pass
del states[:]
with self.assertRaises(MyException):
with gen(1, 2, y=4) as call_args:
assert states == ["enter"]
self.assertEqual(call_args, ((1, 2), {'y': 4}))
raise MyException("FAIL INSIDE")
assert states == ["enter", "exit"]
del states[:]
with self.assertRaises(StopIteration):
with gen(1, 2, y=4) as call_args:
assert states == ["enter"]
self.assertEqual(call_args, ((1, 2), {'y': 4}))
raise StopIteration("STOP")
assert states == ["enter", "exit"]
cdef class _TryFinallyGeneratorContextManager:
cdef object _gen
"""
Cython -- Things that don't belong
anywhere else in particular
Cython -- Things that don't belong anywhere else in particular
"""
from __future__ import absolute_import
import cython
cython.declare(
basestring=object,
os=object, sys=object, re=object, io=object, codecs=object, glob=object, shutil=object, tempfile=object,
cython_version=object,
_function_caches=list, _parse_file_version=object, _match_file_encoding=object,
)
try:
from __builtin__ import basestring
except ImportError:
......@@ -23,7 +31,7 @@ import codecs
import glob
import shutil
import tempfile
from contextlib import contextmanager
from functools import wraps
from . import __version__ as cython_version
......@@ -34,6 +42,31 @@ _CACHE_NAME_PATTERN = re.compile(r"^__(.+)_cache$")
modification_time = os.path.getmtime
class _TryFinallyGeneratorContextManager(object):
"""
Fast, bare minimum @contextmanager, only for try-finally, not for exception handling.
"""
def __init__(self, gen):
self._gen = gen
def __enter__(self):
return next(self._gen)
def __exit__(self, exc_type, exc_val, exc_tb):
try:
next(self._gen)
except (StopIteration, GeneratorExit):
pass
def try_finally_contextmanager(gen_func):
@wraps(gen_func)
def make_gen(*args, **kwargs):
return _TryFinallyGeneratorContextManager(gen_func(*args, **kwargs))
return make_gen
_function_caches = []
......@@ -47,6 +80,7 @@ def cached_function(f):
_function_caches.append(cache)
uncomputed = object()
@wraps(f)
def wrapper(*args):
res = cache.get(args, uncomputed)
if res is uncomputed:
......@@ -443,7 +477,7 @@ def get_cython_cache_dir():
return os.path.expanduser(os.path.join('~', '.cython'))
@contextmanager
@try_finally_contextmanager
def captured_fd(stream=2, encoding=None):
orig_stream = os.dup(stream) # keep copy of original stream
try:
......@@ -455,13 +489,12 @@ def captured_fd(stream=2, encoding=None):
return _output[0]
os.dup2(temp_file.fileno(), stream) # replace stream by copy of pipe
try:
def get_output():
result = read_output()
return result.decode(encoding) if encoding else result
yield get_output
finally:
# note: @contextlib.contextmanager requires try-finally here
os.dup2(orig_stream, stream) # restore original stream
read_output() # keep the output in case it's used after closing the context manager
finally:
......@@ -514,23 +547,6 @@ def print_bytes(s, header_text=None, end=b'\n', file=sys.stdout, flush=True):
out.flush()
class LazyStr:
def __init__(self, callback):
self.callback = callback
def __str__(self):
return self.callback()
def __repr__(self):
return self.callback()
def __add__(self, right):
return self.callback() + right
def __radd__(self, left):
return left + self.callback()
class OrderedSet(object):
def __init__(self, elements=()):
self._list = []
......
......@@ -94,6 +94,7 @@ def compile_cython_modules(profile=False, coverage=False, compile_more=False, cy
"Cython.Compiler.FusedNode",
"Cython.Tempita._tempita",
"Cython.StringIOTree",
"Cython.Utils",
]
if compile_more:
compiled_modules.extend([
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment