Commit 03d0dd98 authored by Martín Ferrari's avatar Martín Ferrari

Make sure the protocol remains in a sane state during PROC commands; add tests for that

parent b17d2af5
...@@ -418,14 +418,21 @@ class Client(object): ...@@ -418,14 +418,21 @@ class Client(object):
def shutdown(self): def shutdown(self):
"Tell the client to quit." "Tell the client to quit."
self._send_cmd(("QUIT", )) self._send_cmd("QUIT")
self._read_and_check_reply() self._read_and_check_reply()
def _send_fd(self, type, fd): def _send_fd(self, name, fd):
"Pass a file descriptor" "Pass a file descriptor"
self._send_cmd("PROC", type) self._send_cmd("PROC", name)
self._read_and_check_reply(3) self._read_and_check_reply(3)
passfd.sendfd(self._fd, fd, "PROC " + type) try:
passfd.sendfd(self._fd, fd, "PROC " + name)
except:
# need to fill the buffer on the other side, nevertheless
self._fd.write("=" * (len(name) + 5))
# And also read the expected error
self._read_and_check_reply(5)
raise
self._read_and_check_reply() self._read_and_check_reply()
def spawn(self, executable, argv = None, cwd = None, env = None, def spawn(self, executable, argv = None, cwd = None, env = None,
...@@ -443,32 +450,38 @@ class Client(object): ...@@ -443,32 +450,38 @@ class Client(object):
self._send_cmd(*params) self._send_cmd(*params)
self._read_and_check_reply() self._read_and_check_reply()
if user != None: # After this, if we get an error, we have to abort the PROC
self._send_cmd("PROC", "USER", _b64(user)) try:
self._read_and_check_reply() if user != None:
self._send_cmd("PROC", "USER", _b64(user))
if cwd != None: self._read_and_check_reply()
self._send_cmd("PROC", "CWD", _b64(cwd))
self._read_and_check_reply() if cwd != None:
self._send_cmd("PROC", "CWD", _b64(cwd))
if env != None: self._read_and_check_reply()
params = []
for i in env: if env != None:
params.append(_b64(i)) params = []
self._send_cmd("PROC", "ENV", params) for i in env:
params.append(_b64(i))
self._send_cmd("PROC", "ENV", params)
self._read_and_check_reply()
if stdin != None:
self._send_fd("SIN", stdin)
if stdout != None:
self._send_fd("SOUT", stdout)
if stderr != None:
self._send_fd("SERR", stderr)
self._send_cmd("PROC", "RUN")
pid = self._read_and_check_reply().split()[0]
return pid
except:
self._send_cmd("PROC", "ABRT")
self._read_and_check_reply() self._read_and_check_reply()
raise
if stdin != None:
self._send_fd("SIN", stdin)
if stdout != None:
self._send_fd("SOUT", stdout)
if stderr != None:
self._send_fd("SERR", stderr)
self._send_cmd("PROC", "RUN")
pid = self._read_and_check_reply().split()[0]
return pid
def poll(self, pid): def poll(self, pid):
"""Equivalent to Popen.poll(), checks if the process has finished. """Equivalent to Popen.poll(), checks if the process has finished.
......
...@@ -34,6 +34,27 @@ class TestServer(unittest.TestCase): ...@@ -34,6 +34,27 @@ class TestServer(unittest.TestCase):
pid, ret = os.waitpid(pid, 0) pid, ret = os.waitpid(pid, 0)
self.assertEquals(ret, 0) self.assertEquals(ret, 0)
def test_spawn_recovery(self):
(s0, s1) = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM, 0)
pid = os.fork()
if not pid:
s1.close()
srv = netns.protocol.Server(s0)
srv.run()
os._exit(0)
cli = netns.protocol.Client(s1)
s0.close()
# make PROC SIN fail
self.assertRaises(OSError, cli.spawn, "/bin/true", stdin = -1)
# check if the protocol is in a sane state:
# PROC CWD should not be valid
cli._send_cmd("PROC", "CWD", "/")
self.assertRaises(RuntimeError, cli._read_and_check_reply)
cli.shutdown()
pid, ret = os.waitpid(pid, 0)
self.assertEquals(ret, 0)
def test_basic_stuff(self): def test_basic_stuff(self):
(s0, s1) = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM, 0) (s0, s1) = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM, 0)
srv = netns.protocol.Server(s0) srv = netns.protocol.Server(s0)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment