Commit cb27f386 authored by Martín Ferrari's avatar Martín Ferrari

Some cleaning, fixed some tests, a little more coverage.

parent 2cb6d712
#!/usr/bin/env python #!/usr/bin/env python
# vim:ts=4:sw=4:et:ai:sts=4 # vim:ts=4:sw=4:et:ai:sts=4
try: try: # pragma: no cover
from yaml import CLoader as Loader from yaml import CLoader as Loader
from yaml import CDumper as Dumper from yaml import CDumper as Dumper
except ImportError: except ImportError:
from yaml import Loader, Dumper from yaml import Loader, Dumper
import base64, os, passfd, re, signal, sys, traceback, unshare, yaml import base64, os, passfd, re, signal, sys, traceback, unshare, yaml
import netns.subprocess_, netns.iproute, netns.interface import netns.subprocess_, netns.iproute, netns.interface
...@@ -91,12 +92,12 @@ class Server(object): ...@@ -91,12 +92,12 @@ class Server(object):
for i in range(len(clean) - 1): for i in range(len(clean) - 1):
s = str(code) + "-" + clean[i] + "\n" s = str(code) + "-" + clean[i] + "\n"
self._wfd.write(s) self._wfd.write(s)
if self.debug: if self.debug: # pragma: no cover
sys.stderr.write("<ans> %s" % s) sys.stderr.write("<ans> %s" % s)
s = str(code) + " " + clean[-1] + "\n" s = str(code) + " " + clean[-1] + "\n"
self._wfd.write(s) self._wfd.write(s)
if self.debug: if self.debug: # pragma: no cover
sys.stderr.write("<ans> %s" % s) sys.stderr.write("<ans> %s" % s)
return return
...@@ -106,28 +107,10 @@ class Server(object): ...@@ -106,28 +107,10 @@ class Server(object):
if not line: if not line:
self.closed = True self.closed = True
return None return None
if self.debug: if self.debug: # pragma: no cover
sys.stderr.write("<C> %s" % line) sys.stderr.write("<C> %s" % line)
return line.rstrip() return line.rstrip()
def readchunk(self, size):
"Read a chunk of data limited by size or by an empty line."
read = 0
res = ""
while True:
line = self._rfd.readline()
if not line:
self.closed = True
return None
if size == None and line == "\n":
break
read += len(line)
res += line
if size != None and read >= size:
break
return res
def readcmd(self): def readcmd(self):
"""Main entry point: read and parse a line from the client, handle """Main entry point: read and parse a line from the client, handle
argument validation and return a tuple (function, command_name, argument validation and return a tuple (function, command_name,
...@@ -165,7 +148,7 @@ class Server(object): ...@@ -165,7 +148,7 @@ class Server(object):
cmdname = cmd1 cmdname = cmd1
funcname = "do_%s" % cmd1 funcname = "do_%s" % cmd1
if not hasattr(self, funcname): if not hasattr(self, funcname): # pragma: no cover
self.reply(500, "Not implemented.") self.reply(500, "Not implemented.")
return None return None
...@@ -181,7 +164,6 @@ class Server(object): ...@@ -181,7 +164,6 @@ class Server(object):
for i in range(len(args)): for i in range(len(args)):
if argstemplate[j] == '*': if argstemplate[j] == '*':
j = j - 1 j = j - 1
if argstemplate[j] == 'i': if argstemplate[j] == 'i':
try: try:
args[i] = int(args[i]) args[i] = int(args[i])
...@@ -189,24 +171,20 @@ class Server(object): ...@@ -189,24 +171,20 @@ class Server(object):
self.reply(500, "Invalid parameter %s: must be an integer." self.reply(500, "Invalid parameter %s: must be an integer."
% args[i]) % args[i])
return None return None
elif argstemplate[j] == 's':
pass
elif argstemplate[j] == 'b': elif argstemplate[j] == 'b':
try: try:
if args[i][0] == '=': if args[i][0] == '=':
args[i] = base64.b64decode(args[i][1:]) args[i] = base64.b64decode(args[i][1:])
# if len(args[i]) == 0:
# self.reply(500, "Invalid parameter: empty.")
# return None
except TypeError: except TypeError:
self.reply(500, "Invalid parameter: not base-64 encoded.") self.reply(500, "Invalid parameter: not base-64 encoded.")
return None return None
else: elif argstemplate[j] != 's': # pragma: no cover
raise RuntimeError("Invalid argument template: %s" % _argstmpl) raise RuntimeError("Invalid argument template: %s" % _argstmpl)
# Nothing done for "s" parameters
j += 1 j += 1
func = getattr(self, funcname) func = getattr(self, funcname)
if self.debug: if self.debug: # pragma: no cover
sys.stderr.write("<cmd> %s, args: %s\n" % (cmdname, args)) sys.stderr.write("<cmd> %s, args: %s\n" % (cmdname, args))
return (func, cmdname, args) return (func, cmdname, args)
...@@ -276,7 +254,6 @@ class Server(object): ...@@ -276,7 +254,6 @@ class Server(object):
self.reply(354, self.reply(354,
"Pass the file descriptor now, with `%s\\n' as payload." % "Pass the file descriptor now, with `%s\\n' as payload." %
cmdname) cmdname)
try: try:
fd, payload = passfd.recvfd(self._rfd, len(cmdname) + 1) fd, payload = passfd.recvfd(self._rfd, len(cmdname) + 1)
except (IOError, BaseException), e: # FIXME except (IOError, BaseException), e: # FIXME
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# vim:ts=4:sw=4:et:ai:sts=4 # vim:ts=4:sw=4:et:ai:sts=4
import netns.protocol import netns.protocol
import os, socket, sys, unittest import os, socket, sys, threading, unittest
class TestServer(unittest.TestCase): class TestServer(unittest.TestCase):
def test_server_startup(self): def test_server_startup(self):
...@@ -10,40 +10,36 @@ class TestServer(unittest.TestCase): ...@@ -10,40 +10,36 @@ class TestServer(unittest.TestCase):
# the file descriptor; and check the banner. # the file descriptor; and check the banner.
(s0, s1) = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM, 0) (s0, s1) = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM, 0)
(s2, s3) = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM, 0) (s2, s3) = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM, 0)
pid = os.fork()
if not pid: def run_server():
s1.close()
srv = netns.protocol.Server(s0, s0) srv = netns.protocol.Server(s0, s0)
srv.run() srv.run()
s3.close()
srv = netns.protocol.Server(s2.fileno(), s2.fileno()) srv = netns.protocol.Server(s2.fileno(), s2.fileno())
srv.run() srv.run()
t = threading.Thread(target = run_server)
t.start()
os._exit(0)
s0.close()
s = os.fdopen(s1.fileno(), "r+", 1) s = os.fdopen(s1.fileno(), "r+", 1)
self.assertEquals(s.readline()[0:4], "220 ") self.assertEquals(s.readline()[0:4], "220 ")
s.close() s.close()
s0.close()
s2.close()
s = os.fdopen(s3.fileno(), "r+", 1) s = os.fdopen(s3.fileno(), "r+", 1)
self.assertEquals(s.readline()[0:4], "220 ") self.assertEquals(s.readline()[0:4], "220 ")
s.close() s.close()
pid, ret = os.waitpid(pid, 0) s2.close()
self.assertEquals(ret, 0) t.join()
def test_spawn_recovery(self): def test_spawn_recovery(self):
(s0, s1) = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM, 0) (s0, s1) = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM, 0)
pid = os.fork()
if not pid: def run_server():
s1.close() netns.protocol.Server(s0, s0, debug = 0).run()
srv = netns.protocol.Server(s0, s0) t = threading.Thread(target = run_server)
srv.run() t.start()
os._exit(0)
cli = netns.protocol.Client(s1, s1) cli = netns.protocol.Client(s1, s1)
s0.close()
# make PROC SIN fail # make PROC SIN fail
self.assertRaises(OSError, cli.spawn, "/bin/true", stdin = -1) self.assertRaises(OSError, cli.spawn, "/bin/true", stdin = -1)
...@@ -51,13 +47,17 @@ class TestServer(unittest.TestCase): ...@@ -51,13 +47,17 @@ class TestServer(unittest.TestCase):
# PROC CWD should not be valid # PROC CWD should not be valid
cli._send_cmd("PROC", "CWD", "/") cli._send_cmd("PROC", "CWD", "/")
self.assertRaises(RuntimeError, cli._read_and_check_reply) self.assertRaises(RuntimeError, cli._read_and_check_reply)
# Force a KeyError, and check that the exception is received correctly
cli._send_cmd("IF", "LIST", "-1")
self.assertRaises(KeyError, cli._read_and_check_reply)
cli.shutdown() cli.shutdown()
pid, ret = os.waitpid(pid, 0)
self.assertEquals(ret, 0) t.join()
def test_basic_stuff(self): def test_basic_stuff(self):
(s0, s1) = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM, 0) (s0, s1) = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM, 0)
srv = netns.protocol.Server(s0, s0) srv = netns.protocol.Server(s0, s0, debug = 0)
s1 = s1.makefile("r+", 1) s1 = s1.makefile("r+", 1)
def check_error(self, cmd, code = 500): def check_error(self, cmd, code = 500):
...@@ -93,6 +93,9 @@ class TestServer(unittest.TestCase): ...@@ -93,6 +93,9 @@ class TestServer(unittest.TestCase):
check_error(self, "proc poll") # missing arg check_error(self, "proc poll") # missing arg
check_error(self, "proc poll 1 2") # too many args check_error(self, "proc poll 1 2") # too many args
check_error(self, "proc poll a") # invalid type check_error(self, "proc poll a") # invalid type
check_error(self, "proc") # incomplete command
check_error(self, "proc foo") # unknown subcommand
check_error(self, "foo bar") # unknown
check_ok(self, "proc crte /bin/sh", srv.do_PROC_CRTE, check_ok(self, "proc crte /bin/sh", srv.do_PROC_CRTE,
['/bin/sh']) ['/bin/sh'])
......
...@@ -155,6 +155,14 @@ class TestSubprocess(unittest.TestCase): ...@@ -155,6 +155,14 @@ class TestSubprocess(unittest.TestCase):
os.close(r) os.close(r)
p.wait() p.wait()
# cwd
r, w = os.pipe()
p = Subprocess(node, '/bin/pwd', stdout = w, cwd = "/")
os.close(w)
self.assertEquals(_readall(r), "/\n")
os.close(r)
p.wait()
p = Subprocess(node, ['sleep', '100']) p = Subprocess(node, ['sleep', '100'])
self.assertTrue(p.pid > 0) self.assertTrue(p.pid > 0)
self.assertEquals(p.poll(), None) # not finished self.assertEquals(p.poll(), None) # not finished
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment