Commit 6eb80104 authored by Kirill Smelkov's avatar Kirill Smelkov

sync.WorkGroup: Provide "with" support

So that it becomes possible to write

    with WorkGroup(ctx) as wg:
        wg.go(f1)
        wg.go(f2)

instead of

    wg = WorkGroup(ctx)
    defer(wg.wait)
    wg.go(f1)
    wg.go(f2)

or

    wg = WorkGroup(ctx)
    wg.go(f1)
    wg.go(f2)
    wg.wait()

This is sometimes handy and is referred to as "structured concurrency"
in Python world.

sync.Sema, sync.Mutex, sync.RWMutex already support "with".
sync.WaitGroup is imho too low-level, but we might consider adding
"with" support for it in the future as well.

In general pygolang way is to use defer instead of plugging all classes
with __enter__/__exit__ "with" support, but for small well-known class of
concurrency-related things its seems "with" support is worth it:

- having "with" for sync.Mutex+co allows it to be used as a drop-in
  replacement instead of threading.Lock+co, and
- having "with" for sync.WorkGroup - the most commonly-used tool to
  spawn jobs and wait for their completion - makes it on-par with
  "structured concurrency".

/reviewed-on !12
parent 85257b2a
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
from __future__ import print_function, absolute_import from __future__ import print_function, absolute_import
from cython cimport final from cython cimport final
from cpython cimport PyObject from cpython cimport PyObject, PY_MAJOR_VERSION
from golang cimport nil, newref, topyexc from golang cimport nil, newref, topyexc
from golang cimport context from golang cimport context
from golang.pyx cimport runtime from golang.pyx cimport runtime
...@@ -34,6 +34,8 @@ cdef extern from "golang/sync.h" namespace "golang::sync" nogil: ...@@ -34,6 +34,8 @@ cdef extern from "golang/sync.h" namespace "golang::sync" nogil:
from libcpp.cast cimport dynamic_cast from libcpp.cast cimport dynamic_cast
import sys as pysys
@final @final
cdef class PySema: cdef class PySema:
...@@ -196,6 +198,13 @@ cdef class PyWorkGroup: ...@@ -196,6 +198,13 @@ cdef class PyWorkGroup:
work. .wait() waits for all spawned goroutines to complete and returns/raises work. .wait() waits for all spawned goroutines to complete and returns/raises
error, if any, from the first failed subtask. error, if any, from the first failed subtask.
WorkGroup can be also used via `with` statement where .wait() is
automatically called at the end of the block, for example:
with WorkGroup(ctx) as wg:
wg.go(f1)
wg.go(f2)
WorkGroup is modelled after https://godoc.org/golang.org/x/sync/errgroup but WorkGroup is modelled after https://godoc.org/golang.org/x/sync/errgroup but
is not equal to it. is not equal to it.
""" """
...@@ -236,6 +245,40 @@ cdef class PyWorkGroup: ...@@ -236,6 +245,40 @@ cdef class PyWorkGroup:
# reraise pyerr with original traceback # reraise pyerr with original traceback
pyerr_reraise(pyerr) pyerr_reraise(pyerr)
# with support
def __enter__(PyWorkGroup pyg):
return pyg
def __exit__(PyWorkGroup pyg, exc_typ, exc_val, exc_tb):
# py2: prepare exc_val to be chained into
if PY_MAJOR_VERSION == 2 and exc_val is not None:
_pyexc_contextify(exc_val, None)
# if .wait() raises, we want raised exception to be chained into
# exc_val via .__context__, so that
#
# wg = sync.WorkGroup(ctx)
# defer(wg.wait)
# ...
#
# and
#
# with sync.WorkGroup(ctx) as wg:
# ...
#
# are equivalent.
#
# Even if Python3 implements exception chaining natively, it does not
# automatically chain exceptions in __exit__. Implement the chaining ourselves.
try:
pyg.wait()
except:
if PY_MAJOR_VERSION == 2:
if exc_val is not None and not hasattr(exc_val, '__traceback__'):
exc_val.__traceback__ = exc_tb
exc = pysys.exc_info()[1]
_pyexc_contextify(exc, exc_val)
raise
# _PyCtxFunc complements PyWorkGroup.go() : it's operator()(ctx) verifies that # _PyCtxFunc complements PyWorkGroup.go() : it's operator()(ctx) verifies that
# ctx is expected context and further calls python function without any arguments. # ctx is expected context and further calls python function without any arguments.
# PyWorkGroup.go() arranges to use python functions that are bound to PyContext # PyWorkGroup.go() arranges to use python functions that are bound to PyContext
...@@ -271,6 +314,19 @@ cdef extern from * nogil: ...@@ -271,6 +314,19 @@ cdef extern from * nogil:
# ---- misc ---- # ---- misc ----
# _pyexc_contextify makes sure pyexc has .__context__, .__cause__ and
# .__suppress_context__ attributes.
#
# .__context__ if not already present, or if it was previously None, is set to pyexccontext.
cdef _pyexc_contextify(object pyexc, pyexccontext):
if not hasattr(pyexc, '__context__') or pyexc.__context__ is None:
pyexc.__context__ = pyexccontext
if not hasattr(pyexc, '__cause__'):
pyexc.__cause__ = None
if not hasattr(pyexc, '__suppress_context__'):
pyexc.__suppress_context__ = False
cdef nogil: cdef nogil:
void semaacquire_pyexc(Sema *sema) except +topyexc: void semaacquire_pyexc(Sema *sema) except +topyexc:
......
...@@ -20,9 +20,10 @@ ...@@ -20,9 +20,10 @@
from __future__ import print_function, absolute_import from __future__ import print_function, absolute_import
from golang import go, chan, select, default from golang import go, chan, select, default, func, defer
from golang import sync, context, time from golang import sync, context, time
from pytest import raises, mark from pytest import raises, mark
from _pytest._code import Traceback
from golang.golang_test import import_pyx_tests, panics from golang.golang_test import import_pyx_tests, panics
from golang.time_test import dt from golang.time_test import dt
from six.moves import range as xrange from six.moves import range as xrange
...@@ -245,6 +246,17 @@ PyErr_Restore_traceback_ok = True ...@@ -245,6 +246,17 @@ PyErr_Restore_traceback_ok = True
if 'PyPy' in sys.version and sys.pypy_version_info < (7,3): if 'PyPy' in sys.version and sys.pypy_version_info < (7,3):
PyErr_Restore_traceback_ok = False PyErr_Restore_traceback_ok = False
# WorkGroup must catch/propagate all exception classes.
# Python2 allows to raise old-style classes not derived from BaseException.
# Python3 allows to raise only BaseException derivatives.
if six.PY2:
class MyError:
def __init__(self, *args):
self.args = args
else:
class MyError(BaseException):
pass
def test_workgroup(): def test_workgroup():
ctx, cancel = context.with_cancel(context.background()) ctx, cancel = context.with_cancel(context.background())
mu = sync.Mutex() mu = sync.Mutex()
...@@ -260,16 +272,6 @@ def test_workgroup(): ...@@ -260,16 +272,6 @@ def test_workgroup():
wg.wait() wg.wait()
assert l == [1, 2] assert l == [1, 2]
# WorkGroup must catch/propagate all exception classes.
# Python2 allows to raise old-style classes not derived from BaseException.
# Python3 allows to raise only BaseException derivatives.
if six.PY2:
class MyError:
def __init__(self, *args):
self.args = args
else:
class MyError(BaseException):
pass
# t1=fail, t2=ok, does not look at ctx # t1=fail, t2=ok, does not look at ctx
wg = sync.WorkGroup(ctx) wg = sync.WorkGroup(ctx)
...@@ -337,6 +339,92 @@ def test_workgroup(): ...@@ -337,6 +339,92 @@ def test_workgroup():
wg.wait() wg.wait()
assert l == [1, 2] assert l == [1, 2]
@func
def test_workgroup_with():
# verify with support for sync.WorkGroup
ctx, cancel = context.with_cancel(context.background())
defer(cancel)
mu = sync.Mutex()
# t1=ok, t2=ok
l = [0, 0]
with sync.WorkGroup(ctx) as wg:
for i in range(2):
def _(ctx, i):
with mu:
l[i] = i+1
wg.go(_, i)
assert l == [1, 2]
# t1=fail, t2=wait cancel, fail
with raises(MyError) as exci:
with sync.WorkGroup(ctx) as wg:
def _(ctx):
Iam_t1 = 0
raise MyError('hello (fail)')
wg.go(_)
def _(ctx):
ctx.done().recv()
raise MyError('world (after zzz)')
wg.go(_)
e = exci.value
assert e.__class__ is MyError
assert e.args == ('hello (fail)',)
assert e.__cause__ is None
assert e.__context__ is None
assert e.__suppress_context__ == False
if PyErr_Restore_traceback_ok:
assert 'Iam_t1' in exci.traceback[-1].locals
# t=ok, but code from under with raises
l = [0]
with raises(MyError) as exci:
with sync.WorkGroup(ctx) as wg:
def _(ctx):
l[0] = 1
wg.go(_)
def bad():
raise MyError('wow')
bad()
e = exci.value
assert e.__class__ is MyError
assert e.args == ('wow',)
assert e.__cause__ is None
assert e.__context__ is None
assert e.__suppress_context__ == False
assert exci.traceback[-1].name == 'bad'
assert l[0] == 1
# t=fail, code from under with also raises
with raises(MyError) as exci:
with sync.WorkGroup(ctx) as wg:
def f(ctx):
raise MyError('fail from go')
wg.go(f)
def g():
raise MyError('just raise')
g()
e = exci.value
assert e.__class__ is MyError
assert e.args == ('fail from go',)
assert e.__cause__ is None
assert e.__context__ is not None
assert e.__suppress_context__ == False
assert exci.traceback[-1].name == 'f'
e2 = e.__context__
assert e2.__class__ is MyError
assert e2.args == ('just raise',)
assert e2.__cause__ is None
assert e2.__context__ is None
assert e2.__suppress_context__ == False
assert e2.__traceback__ is not None
t2 = Traceback(e2.__traceback__)
assert t2[-1].name == 'g'
# create/wait workgroup with 1 empty worker. # create/wait workgroup with 1 empty worker.
def bench_workgroup_empty(b): def bench_workgroup_empty(b):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment