Commit f499a021 authored by Jens Axboe's avatar Jens Axboe

io_uring: ensure async punted connect requests copy data

Just like commit f67676d1 for read/write requests, this one ensures
that the sockaddr data has been copied for IORING_OP_CONNECT if we need
to punt the request to async context.
Signed-off-by: default avatarJens Axboe <axboe@kernel.dk>
parent 03b1230c
...@@ -308,6 +308,10 @@ struct io_timeout { ...@@ -308,6 +308,10 @@ struct io_timeout {
struct io_timeout_data *data; struct io_timeout_data *data;
}; };
struct io_async_connect {
struct sockaddr_storage address;
};
struct io_async_msghdr { struct io_async_msghdr {
struct iovec fast_iov[UIO_FASTIOV]; struct iovec fast_iov[UIO_FASTIOV];
struct iovec *iov; struct iovec *iov;
...@@ -327,6 +331,7 @@ struct io_async_ctx { ...@@ -327,6 +331,7 @@ struct io_async_ctx {
union { union {
struct io_async_rw rw; struct io_async_rw rw;
struct io_async_msghdr msg; struct io_async_msghdr msg;
struct io_async_connect connect;
}; };
}; };
...@@ -2195,11 +2200,26 @@ static int io_accept(struct io_kiocb *req, const struct io_uring_sqe *sqe, ...@@ -2195,11 +2200,26 @@ static int io_accept(struct io_kiocb *req, const struct io_uring_sqe *sqe,
#endif #endif
} }
static int io_connect_prep(struct io_kiocb *req, struct io_async_ctx *io)
{
#if defined(CONFIG_NET)
const struct io_uring_sqe *sqe = req->sqe;
struct sockaddr __user *addr;
int addr_len;
addr = (struct sockaddr __user *) (unsigned long) READ_ONCE(sqe->addr);
addr_len = READ_ONCE(sqe->addr2);
return move_addr_to_kernel(addr, addr_len, &io->connect.address);
#else
return 0;
#endif
}
static int io_connect(struct io_kiocb *req, const struct io_uring_sqe *sqe, static int io_connect(struct io_kiocb *req, const struct io_uring_sqe *sqe,
struct io_kiocb **nxt, bool force_nonblock) struct io_kiocb **nxt, bool force_nonblock)
{ {
#if defined(CONFIG_NET) #if defined(CONFIG_NET)
struct sockaddr __user *addr; struct io_async_ctx __io, *io;
unsigned file_flags; unsigned file_flags;
int addr_len, ret; int addr_len, ret;
...@@ -2208,15 +2228,35 @@ static int io_connect(struct io_kiocb *req, const struct io_uring_sqe *sqe, ...@@ -2208,15 +2228,35 @@ static int io_connect(struct io_kiocb *req, const struct io_uring_sqe *sqe,
if (sqe->ioprio || sqe->len || sqe->buf_index || sqe->rw_flags) if (sqe->ioprio || sqe->len || sqe->buf_index || sqe->rw_flags)
return -EINVAL; return -EINVAL;
addr = (struct sockaddr __user *) (unsigned long) READ_ONCE(sqe->addr);
addr_len = READ_ONCE(sqe->addr2); addr_len = READ_ONCE(sqe->addr2);
file_flags = force_nonblock ? O_NONBLOCK : 0; file_flags = force_nonblock ? O_NONBLOCK : 0;
ret = __sys_connect_file(req->file, addr, addr_len, file_flags); if (req->io) {
if (ret == -EAGAIN && force_nonblock) io = req->io;
} else {
ret = io_connect_prep(req, &__io);
if (ret)
goto out;
io = &__io;
}
ret = __sys_connect_file(req->file, &io->connect.address, addr_len,
file_flags);
if (ret == -EAGAIN && force_nonblock) {
io = kmalloc(sizeof(*io), GFP_KERNEL);
if (!io) {
ret = -ENOMEM;
goto out;
}
memcpy(&io->connect, &__io.connect, sizeof(io->connect));
req->io = io;
memcpy(&io->sqe, req->sqe, sizeof(*req->sqe));
req->sqe = &io->sqe;
return -EAGAIN; return -EAGAIN;
}
if (ret == -ERESTARTSYS) if (ret == -ERESTARTSYS)
ret = -EINTR; ret = -EINTR;
out:
if (ret < 0 && (req->flags & REQ_F_LINK)) if (ret < 0 && (req->flags & REQ_F_LINK))
req->flags |= REQ_F_FAIL_LINK; req->flags |= REQ_F_FAIL_LINK;
io_cqring_add_event(req, ret); io_cqring_add_event(req, ret);
...@@ -2832,6 +2872,9 @@ static int io_req_defer_prep(struct io_kiocb *req, struct io_async_ctx *io) ...@@ -2832,6 +2872,9 @@ static int io_req_defer_prep(struct io_kiocb *req, struct io_async_ctx *io)
case IORING_OP_RECVMSG: case IORING_OP_RECVMSG:
ret = io_recvmsg_prep(req, io); ret = io_recvmsg_prep(req, io);
break; break;
case IORING_OP_CONNECT:
ret = io_connect_prep(req, io);
break;
default: default:
req->io = io; req->io = io;
return 0; return 0;
......
...@@ -406,9 +406,8 @@ extern int __sys_accept4(int fd, struct sockaddr __user *upeer_sockaddr, ...@@ -406,9 +406,8 @@ extern int __sys_accept4(int fd, struct sockaddr __user *upeer_sockaddr,
int __user *upeer_addrlen, int flags); int __user *upeer_addrlen, int flags);
extern int __sys_socket(int family, int type, int protocol); extern int __sys_socket(int family, int type, int protocol);
extern int __sys_bind(int fd, struct sockaddr __user *umyaddr, int addrlen); extern int __sys_bind(int fd, struct sockaddr __user *umyaddr, int addrlen);
extern int __sys_connect_file(struct file *file, extern int __sys_connect_file(struct file *file, struct sockaddr_storage *addr,
struct sockaddr __user *uservaddr, int addrlen, int addrlen, int file_flags);
int file_flags);
extern int __sys_connect(int fd, struct sockaddr __user *uservaddr, extern int __sys_connect(int fd, struct sockaddr __user *uservaddr,
int addrlen); int addrlen);
extern int __sys_listen(int fd, int backlog); extern int __sys_listen(int fd, int backlog);
......
...@@ -1826,26 +1826,22 @@ SYSCALL_DEFINE3(accept, int, fd, struct sockaddr __user *, upeer_sockaddr, ...@@ -1826,26 +1826,22 @@ SYSCALL_DEFINE3(accept, int, fd, struct sockaddr __user *, upeer_sockaddr,
* include the -EINPROGRESS status for such sockets. * include the -EINPROGRESS status for such sockets.
*/ */
int __sys_connect_file(struct file *file, struct sockaddr __user *uservaddr, int __sys_connect_file(struct file *file, struct sockaddr_storage *address,
int addrlen, int file_flags) int addrlen, int file_flags)
{ {
struct socket *sock; struct socket *sock;
struct sockaddr_storage address;
int err; int err;
sock = sock_from_file(file, &err); sock = sock_from_file(file, &err);
if (!sock) if (!sock)
goto out; goto out;
err = move_addr_to_kernel(uservaddr, addrlen, &address);
if (err < 0)
goto out;
err = err =
security_socket_connect(sock, (struct sockaddr *)&address, addrlen); security_socket_connect(sock, (struct sockaddr *)address, addrlen);
if (err) if (err)
goto out; goto out;
err = sock->ops->connect(sock, (struct sockaddr *)&address, addrlen, err = sock->ops->connect(sock, (struct sockaddr *)address, addrlen,
sock->file->f_flags | file_flags); sock->file->f_flags | file_flags);
out: out:
return err; return err;
...@@ -1858,7 +1854,11 @@ int __sys_connect(int fd, struct sockaddr __user *uservaddr, int addrlen) ...@@ -1858,7 +1854,11 @@ int __sys_connect(int fd, struct sockaddr __user *uservaddr, int addrlen)
f = fdget(fd); f = fdget(fd);
if (f.file) { if (f.file) {
ret = __sys_connect_file(f.file, uservaddr, addrlen, 0); struct sockaddr_storage address;
ret = move_addr_to_kernel(uservaddr, addrlen, &address);
if (!ret)
ret = __sys_connect_file(f.file, &address, addrlen, 0);
if (f.flags) if (f.flags)
fput(f.file); fput(f.file);
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment