Commit 03d0dd98 authored by Martín Ferrari's avatar Martín Ferrari

Make sure the protocol remains in a sane state during PROC commands; add tests for that

parent b17d2af5
......@@ -418,14 +418,21 @@ class Client(object):
def shutdown(self):
"Tell the client to quit."
self._send_cmd(("QUIT", ))
self._send_cmd("QUIT")
self._read_and_check_reply()
def _send_fd(self, type, fd):
def _send_fd(self, name, fd):
"Pass a file descriptor"
self._send_cmd("PROC", type)
self._send_cmd("PROC", name)
self._read_and_check_reply(3)
passfd.sendfd(self._fd, fd, "PROC " + type)
try:
passfd.sendfd(self._fd, fd, "PROC " + name)
except:
# need to fill the buffer on the other side, nevertheless
self._fd.write("=" * (len(name) + 5))
# And also read the expected error
self._read_and_check_reply(5)
raise
self._read_and_check_reply()
def spawn(self, executable, argv = None, cwd = None, env = None,
......@@ -443,6 +450,8 @@ class Client(object):
self._send_cmd(*params)
self._read_and_check_reply()
# After this, if we get an error, we have to abort the PROC
try:
if user != None:
self._send_cmd("PROC", "USER", _b64(user))
self._read_and_check_reply()
......@@ -469,6 +478,10 @@ class Client(object):
pid = self._read_and_check_reply().split()[0]
return pid
except:
self._send_cmd("PROC", "ABRT")
self._read_and_check_reply()
raise
def poll(self, pid):
"""Equivalent to Popen.poll(), checks if the process has finished.
......
......@@ -34,6 +34,27 @@ class TestServer(unittest.TestCase):
pid, ret = os.waitpid(pid, 0)
self.assertEquals(ret, 0)
def test_spawn_recovery(self):
(s0, s1) = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM, 0)
pid = os.fork()
if not pid:
s1.close()
srv = netns.protocol.Server(s0)
srv.run()
os._exit(0)
cli = netns.protocol.Client(s1)
s0.close()
# make PROC SIN fail
self.assertRaises(OSError, cli.spawn, "/bin/true", stdin = -1)
# check if the protocol is in a sane state:
# PROC CWD should not be valid
cli._send_cmd("PROC", "CWD", "/")
self.assertRaises(RuntimeError, cli._read_and_check_reply)
cli.shutdown()
pid, ret = os.waitpid(pid, 0)
self.assertEquals(ret, 0)
def test_basic_stuff(self):
(s0, s1) = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM, 0)
srv = netns.protocol.Server(s0)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment