From 2aad64bb954d28a28171b92deba8eb3f7c72a22a Mon Sep 17 00:00:00 2001
From: Kirill Smelkov <kirr@nexedi.com>
Date: Thu, 2 May 2019 11:25:11 +0300
Subject: [PATCH] golang: Add support for nil channel

Send/recv on the nil channel block forever; close panics.
If a nil channel is used in select - corresponding case is never selected.

Setting channel to nil is a usual idiom in Go to disable processing some
cases in select. Nil channel is also used as "done" for e.g.
context.Background() - for contexts that can be never canceled.
---
 README.rst            | 21 ++++++++----
 golang/__init__.py    | 78 +++++++++++++++++++++++++++++++------------
 golang/golang_test.py | 74 +++++++++++++++++++++++++++++++++++++++-
 3 files changed, 143 insertions(+), 30 deletions(-)

diff --git a/README.rst b/README.rst
index ac24500..42148f4 100644
--- a/README.rst
+++ b/README.rst
@@ -51,8 +51,9 @@ Goroutines and channels
 `go` spawns a coroutine, or thread if gevent was not activated. It is possible to
 exchange data in between either threads or coroutines via channels. `chan`
 creates a new channel with Go semantic - either synchronous or buffered. Use
-`chan.recv`, `chan.send` and `chan.close` for communication. `select` can be
-used to multiplex on several channels. For example::
+`chan.recv`, `chan.send` and `chan.close` for communication. `nilchan`
+stands for the nil channel. `select` can be used to multiplex on several
+channels. For example::
 
     ch1 = chan()    # synchronous channel
     ch2 = chan(3)   # channel with buffer of size 3
@@ -65,22 +66,28 @@ used to multiplex on several channels. For example::
     ch1.recv()      # will give 'a'
     ch2.recv_()     # will give ('b', True)
 
+    ch2 = nilchan   # rebind ch2 to nil channel
     _, _rx = select(
         ch1.recv,           # 0
-        ch2.recv_,          # 1
-        (ch2.send, obj2),   # 2
-        default,            # 3
+        ch1.recv_,          # 1
+        (ch1.send, obj),    # 2
+        ch2.recv,           # 3
+        default,            # 4
     )
     if _ == 0:
         # _rx is what was received from ch1
         ...
     if _ == 1:
-        # _rx is (rx, ok) of what was received from ch2
+        # _rx is (rx, ok) of what was received from ch1
         ...
     if _ == 2:
-        # we know obj2 was sent to ch2
+        # we know obj was sent to ch1
         ...
     if _ == 3:
+        # this case will be never selected because
+        # send/recv on nil channel block forever.
+        ...
+    if _ == 4:
         # default case
         ...
 
diff --git a/golang/__init__.py b/golang/__init__.py
index ea43a1e..e972f24 100644
--- a/golang/__init__.py
+++ b/golang/__init__.py
@@ -30,7 +30,7 @@
 
 __version__ = "0.0.0.dev8"
 
-__all__ = ['go', 'chan', 'select', 'default', 'defer', 'panic', 'recover', 'func', 'gimport']
+__all__ = ['go', 'chan', 'select', 'default', 'nilchan', 'defer', 'panic', 'recover', 'func', 'gimport']
 
 from golang._gopath import gimport  # make gimport available from golang
 import inspect, threading, collections, random, sys
@@ -332,6 +332,9 @@ class chan(object):
     #
     # .send(obj)
     def send(self, obj):
+        if self is nilchan:
+            _blockforever()
+
         self._mu.acquire()
         if 1:
             ok = self._trysend(obj)
@@ -356,6 +359,9 @@ class chan(object):
     #
     # .recv_() -> (rx, ok)
     def recv_(self):
+        if self is nilchan:
+            _blockforever()
+
         self._mu.acquire()
         if 1:
             rx_, ok = self._tryrecv()
@@ -449,6 +455,9 @@ class chan(object):
 
     # close closes sending side of the channel.
     def close(self):
+        if self is nilchan:
+            panic("close of nil channel")
+
         recvv = []
         sendv = []
 
@@ -481,6 +490,18 @@ class chan(object):
     def __len__(self):
         return len(self._dataq)
 
+    def __repr__(self):
+        if self is nilchan:
+            return "nilchan"
+        else:
+            return super(chan, self).__repr__()
+
+
+# nilchan is the nil channel.
+#
+# On nil channel: send/recv block forever; close panics.
+nilchan = chan(None)    # TODO -> <chan*>(NULL) after move to Cython
+
 
 # default represents default case for select.
 default  = object()
@@ -523,9 +544,6 @@ if six.PY2:
 #       # default case
 #       ...
 def select(*casev):
-    # XXX select on nil chan?
-    # XXX select{} -> block forever
-
     # select promise: if multiple cases are ready - one will be selected randomly
     ncasev = list(enumerate(casev))
     random.shuffle(ncasev)
@@ -550,14 +568,15 @@ def select(*casev):
                 panic("select: send expected: %r" % (send,))
 
             ch = send.__self__
-            ch._mu.acquire()
-            if 1:
-                ok = ch._trysend(tx)
-                if ok:
-                    return n, None
-            ch._mu.release()
+            if ch is not nilchan:   # nil chan is never ready
+                ch._mu.acquire()
+                if 1:
+                    ok = ch._trysend(tx)
+                    if ok:
+                        return n, None
+                ch._mu.release()
 
-            sendv.append((n, ch, tx))
+                sendv.append((n, ch, tx))
 
         # recv
         else:
@@ -572,22 +591,27 @@ def select(*casev):
                 panic("select: recv expected: %r" % (recv,))
 
             ch = recv.__self__
-            ch._mu.acquire()
-            if 1:
-                rx_, ok = ch._tryrecv()
-                if ok:
-                    if not commaok:
-                        rx, ok = rx_
-                        rx_ = rx
-                    return n, rx_
-            ch._mu.release()
-
-            recvv.append((n, ch, commaok))
+            if ch is not nilchan:   # nil chan is never ready
+                ch._mu.acquire()
+                if 1:
+                    rx_, ok = ch._tryrecv()
+                    if ok:
+                        if not commaok:
+                            rx, ok = rx_
+                            rx_ = rx
+                        return n, rx_
+                ch._mu.release()
+
+                recvv.append((n, ch, commaok))
 
     # execute default if we have it
     if ndefault is not None:
         return ndefault, None
 
+    # select{} or with nil-channels only -> block forever
+    if len(recvv) + len(sendv) == 0:
+        _blockforever()
+
     # second pass: subscribe and wait on all rx/tx cases
     g = _WaitGroup()
 
@@ -660,3 +684,13 @@ def select(*casev):
     finally:
         # unsubscribe not-succeeded waiters
         g.dequeAll()
+
+
+# _blockforever blocks current goroutine forever.
+def _blockforever():
+    # take a lock twice. It will forever block on the second lock attempt.
+    # Under gevent, similarly to Go, this raises "LoopExit: This operation
+    # would block forever", if there are no other greenlets scheduled to be run.
+    dead = threading.Lock()
+    dead.acquire()
+    dead.acquire()
diff --git a/golang/golang_test.py b/golang/golang_test.py
index 05d9b75..c44f81a 100644
--- a/golang/golang_test.py
+++ b/golang/golang_test.py
@@ -18,11 +18,12 @@
 # See COPYING file for full licensing terms.
 # See https://www.nexedi.com/licensing for rationale and options.
 
-from golang import go, chan, select, default, _PanicError, func, panic, defer, recover
+from golang import go, chan, select, default, nilchan, _PanicError, func, panic, defer, recover
 from pytest import raises
 from os.path import dirname
 import os, sys, time, threading, inspect, subprocess
 
+import golang
 from golang import _chan_recv, _chan_send
 from golang._pycompat import im_class
 
@@ -313,6 +314,50 @@ def test_select():
     assert len(ch2._sendq) == len(ch2._recvq) == 0
 
 
+    # blocking send + nil channel
+    z = nilchan
+    for i in range(N):
+        ch = chan()
+        done = chan()
+        def _():
+            waitBlocked(ch.send)
+            assert len(z._sendq) == len(z._recvq) == 0
+            assert ch.recv() == 'c'
+            done.close()
+        go(_)
+
+        _, _rx = select(
+                z.recv,
+                (z.send, 0),
+                (ch.send, 'c'),
+        )
+
+        assert (_, _rx) == (2, None)
+        done.recv()
+        assert len(ch._sendq) == len(ch._recvq) == 0
+
+    # blocking recv + nil channel
+    for i in range(N):
+        ch = chan()
+        done = chan()
+        def _():
+            waitBlocked(ch.recv)
+            assert len(z._sendq) == len(z._recvq) == 0
+            ch.send('d')
+            done.close()
+        go(_)
+
+        _, _rx = select(
+                z.recv,
+                (z.send, 0),
+                ch.recv,
+        )
+
+        assert (_, _rx) == (2, 'd')
+        done.recv()
+        assert len(ch._sendq) == len(ch._recvq) == 0
+
+
     # buffered ping-pong
     ch = chan(1)
     for i in range(N):
@@ -406,6 +451,33 @@ def test_select():
     assert len(ch2._sendq) == len(ch2._recvq) == 0
 
 
+# BlocksForever is used in "blocks forever" tests where golang._blockforever
+# is patched to raise instead of block.
+class BlocksForever(Exception):
+    pass
+
+def test_blockforever():
+    B = golang._blockforever
+    def _(): raise BlocksForever()
+    golang._blockforever = _
+    try:
+        _test_blockforever()
+    finally:
+        golang._blockforever = B
+
+def _test_blockforever():
+    z = nilchan
+    with raises(BlocksForever): z.send(0)
+    with raises(BlocksForever): z.recv()
+    with raises(_PanicError):   z.close()   # to fully cover nilchan ops
+
+    # select{} & nil-channel only
+    with raises(BlocksForever): select()
+    with raises(BlocksForever): select((z.send, 0))
+    with raises(BlocksForever): select(z.recv)
+    with raises(BlocksForever): select((z.send, 1), z.recv)
+
+
 def test_method():
     # test how @func(cls) works
     # this also implicitly tests just @func, since @func(cls) uses that.
-- 
2.30.9