Commit 6b18662e authored by Al Viro's avatar Al Viro

9p connect fixes

* if we fail in p9_conn_create(), we shouldn't leak references to struct file.
  Logics in ->close() doesn't help - ->trans is already gone by the time it's
  called.
* sock_create_kern() can fail.
* use of sock_map_fd() is all fscked up; I'd fixed most of that, but the
  rest will have to wait for a bit more work in net/socket.c (we still are
  violating the basic rule of working with descriptor table: "once the reference
  is installed there, don't rely on finding it there again").
Signed-off-by: default avatarAl Viro <viro@zeniv.linux.org.uk>
parent 7cbe66b6
...@@ -42,6 +42,8 @@ ...@@ -42,6 +42,8 @@
#include <net/9p/client.h> #include <net/9p/client.h>
#include <net/9p/transport.h> #include <net/9p/transport.h>
#include <linux/syscalls.h> /* killme */
#define P9_PORT 564 #define P9_PORT 564
#define MAX_SOCK_BUF (64*1024) #define MAX_SOCK_BUF (64*1024)
#define MAXPOLLWADDR 2 #define MAXPOLLWADDR 2
...@@ -788,24 +790,41 @@ static int p9_fd_open(struct p9_client *client, int rfd, int wfd) ...@@ -788,24 +790,41 @@ static int p9_fd_open(struct p9_client *client, int rfd, int wfd)
static int p9_socket_open(struct p9_client *client, struct socket *csocket) static int p9_socket_open(struct p9_client *client, struct socket *csocket)
{ {
int fd, ret; struct p9_trans_fd *p;
int ret, fd;
p = kmalloc(sizeof(struct p9_trans_fd), GFP_KERNEL);
if (!p)
return -ENOMEM;
csocket->sk->sk_allocation = GFP_NOIO; csocket->sk->sk_allocation = GFP_NOIO;
fd = sock_map_fd(csocket, 0); fd = sock_map_fd(csocket, 0);
if (fd < 0) { if (fd < 0) {
P9_EPRINTK(KERN_ERR, "p9_socket_open: failed to map fd\n"); P9_EPRINTK(KERN_ERR, "p9_socket_open: failed to map fd\n");
sock_release(csocket);
kfree(p);
return fd; return fd;
} }
ret = p9_fd_open(client, fd, fd); get_file(csocket->file);
if (ret < 0) { get_file(csocket->file);
P9_EPRINTK(KERN_ERR, "p9_socket_open: failed to open fd\n"); p->wr = p->rd = csocket->file;
client->trans = p;
client->status = Connected;
sys_close(fd); /* still racy */
p->rd->f_flags |= O_NONBLOCK;
p->conn = p9_conn_create(client);
if (IS_ERR(p->conn)) {
ret = PTR_ERR(p->conn);
p->conn = NULL;
kfree(p);
sockfd_put(csocket);
sockfd_put(csocket); sockfd_put(csocket);
return ret; return ret;
} }
((struct p9_trans_fd *)client->trans)->rd->f_flags |= O_NONBLOCK;
return 0; return 0;
} }
...@@ -883,7 +902,6 @@ p9_fd_create_tcp(struct p9_client *client, const char *addr, char *args) ...@@ -883,7 +902,6 @@ p9_fd_create_tcp(struct p9_client *client, const char *addr, char *args)
struct socket *csocket; struct socket *csocket;
struct sockaddr_in sin_server; struct sockaddr_in sin_server;
struct p9_fd_opts opts; struct p9_fd_opts opts;
struct p9_trans_fd *p = NULL; /* this gets allocated in p9_fd_open */
err = parse_opts(args, &opts); err = parse_opts(args, &opts);
if (err < 0) if (err < 0)
...@@ -897,12 +915,11 @@ p9_fd_create_tcp(struct p9_client *client, const char *addr, char *args) ...@@ -897,12 +915,11 @@ p9_fd_create_tcp(struct p9_client *client, const char *addr, char *args)
sin_server.sin_family = AF_INET; sin_server.sin_family = AF_INET;
sin_server.sin_addr.s_addr = in_aton(addr); sin_server.sin_addr.s_addr = in_aton(addr);
sin_server.sin_port = htons(opts.port); sin_server.sin_port = htons(opts.port);
sock_create_kern(PF_INET, SOCK_STREAM, IPPROTO_TCP, &csocket); err = sock_create_kern(PF_INET, SOCK_STREAM, IPPROTO_TCP, &csocket);
if (!csocket) { if (err) {
P9_EPRINTK(KERN_ERR, "p9_trans_tcp: problem creating socket\n"); P9_EPRINTK(KERN_ERR, "p9_trans_tcp: problem creating socket\n");
err = -EIO; return err;
goto error;
} }
err = csocket->ops->connect(csocket, err = csocket->ops->connect(csocket,
...@@ -912,30 +929,11 @@ p9_fd_create_tcp(struct p9_client *client, const char *addr, char *args) ...@@ -912,30 +929,11 @@ p9_fd_create_tcp(struct p9_client *client, const char *addr, char *args)
P9_EPRINTK(KERN_ERR, P9_EPRINTK(KERN_ERR,
"p9_trans_tcp: problem connecting socket to %s\n", "p9_trans_tcp: problem connecting socket to %s\n",
addr); addr);
goto error;
}
err = p9_socket_open(client, csocket);
if (err < 0)
goto error;
p = (struct p9_trans_fd *) client->trans;
p->conn = p9_conn_create(client);
if (IS_ERR(p->conn)) {
err = PTR_ERR(p->conn);
p->conn = NULL;
goto error;
}
return 0;
error:
if (csocket)
sock_release(csocket); sock_release(csocket);
return err;
}
kfree(p); return p9_socket_open(client, csocket);
return err;
} }
static int static int
...@@ -944,49 +942,33 @@ p9_fd_create_unix(struct p9_client *client, const char *addr, char *args) ...@@ -944,49 +942,33 @@ p9_fd_create_unix(struct p9_client *client, const char *addr, char *args)
int err; int err;
struct socket *csocket; struct socket *csocket;
struct sockaddr_un sun_server; struct sockaddr_un sun_server;
struct p9_trans_fd *p = NULL; /* this gets allocated in p9_fd_open */
csocket = NULL; csocket = NULL;
if (strlen(addr) > UNIX_PATH_MAX) { if (strlen(addr) > UNIX_PATH_MAX) {
P9_EPRINTK(KERN_ERR, "p9_trans_unix: address too long: %s\n", P9_EPRINTK(KERN_ERR, "p9_trans_unix: address too long: %s\n",
addr); addr);
err = -ENAMETOOLONG; return -ENAMETOOLONG;
goto error;
} }
sun_server.sun_family = PF_UNIX; sun_server.sun_family = PF_UNIX;
strcpy(sun_server.sun_path, addr); strcpy(sun_server.sun_path, addr);
sock_create_kern(PF_UNIX, SOCK_STREAM, 0, &csocket); err = sock_create_kern(PF_UNIX, SOCK_STREAM, 0, &csocket);
if (err < 0) {
P9_EPRINTK(KERN_ERR, "p9_trans_unix: problem creating socket\n");
return err;
}
err = csocket->ops->connect(csocket, (struct sockaddr *)&sun_server, err = csocket->ops->connect(csocket, (struct sockaddr *)&sun_server,
sizeof(struct sockaddr_un) - 1, 0); sizeof(struct sockaddr_un) - 1, 0);
if (err < 0) { if (err < 0) {
P9_EPRINTK(KERN_ERR, P9_EPRINTK(KERN_ERR,
"p9_trans_unix: problem connecting socket: %s: %d\n", "p9_trans_unix: problem connecting socket: %s: %d\n",
addr, err); addr, err);
goto error;
}
err = p9_socket_open(client, csocket);
if (err < 0)
goto error;
p = (struct p9_trans_fd *) client->trans;
p->conn = p9_conn_create(client);
if (IS_ERR(p->conn)) {
err = PTR_ERR(p->conn);
p->conn = NULL;
goto error;
}
return 0;
error:
if (csocket)
sock_release(csocket); sock_release(csocket);
return err;
}
kfree(p); return p9_socket_open(client, csocket);
return err;
} }
static int static int
...@@ -994,7 +976,7 @@ p9_fd_create(struct p9_client *client, const char *addr, char *args) ...@@ -994,7 +976,7 @@ p9_fd_create(struct p9_client *client, const char *addr, char *args)
{ {
int err; int err;
struct p9_fd_opts opts; struct p9_fd_opts opts;
struct p9_trans_fd *p = NULL; /* this get allocated in p9_fd_open */ struct p9_trans_fd *p;
parse_opts(args, &opts); parse_opts(args, &opts);
...@@ -1005,21 +987,19 @@ p9_fd_create(struct p9_client *client, const char *addr, char *args) ...@@ -1005,21 +987,19 @@ p9_fd_create(struct p9_client *client, const char *addr, char *args)
err = p9_fd_open(client, opts.rfd, opts.wfd); err = p9_fd_open(client, opts.rfd, opts.wfd);
if (err < 0) if (err < 0)
goto error; return err;
p = (struct p9_trans_fd *) client->trans; p = (struct p9_trans_fd *) client->trans;
p->conn = p9_conn_create(client); p->conn = p9_conn_create(client);
if (IS_ERR(p->conn)) { if (IS_ERR(p->conn)) {
err = PTR_ERR(p->conn); err = PTR_ERR(p->conn);
p->conn = NULL; p->conn = NULL;
goto error; fput(p->rd);
fput(p->wr);
return err;
} }
return 0; return 0;
error:
kfree(p);
return err;
} }
static struct p9_trans_module p9_tcp_trans = { static struct p9_trans_module p9_tcp_trans = {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment