Commit fe502c4a authored by Stefano Garzarella's avatar Stefano Garzarella Committed by David S. Miller

vsock: add 'transport' member in the struct vsock_sock

As a preparation to support multiple transports, this patch adds
the 'transport' member at the 'struct vsock_sock'.
This new field is initialized during the creation in the
__vsock_create() function.

This patch also renames the global 'transport' pointer to
'transport_single', since for now we're only supporting a single
transport registered at run-time.
Reviewed-by: default avatarStefan Hajnoczi <stefanha@redhat.com>
Reviewed-by: default avatarJorgen Hansen <jhansen@vmware.com>
Signed-off-by: default avatarStefano Garzarella <sgarzare@redhat.com>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent 3603a2e9
...@@ -27,6 +27,7 @@ extern spinlock_t vsock_table_lock; ...@@ -27,6 +27,7 @@ extern spinlock_t vsock_table_lock;
struct vsock_sock { struct vsock_sock {
/* sk must be the first member. */ /* sk must be the first member. */
struct sock sk; struct sock sk;
const struct vsock_transport *transport;
struct sockaddr_vm local_addr; struct sockaddr_vm local_addr;
struct sockaddr_vm remote_addr; struct sockaddr_vm remote_addr;
/* Links for the global tables of bound and connected sockets. */ /* Links for the global tables of bound and connected sockets. */
......
...@@ -126,7 +126,7 @@ static struct proto vsock_proto = { ...@@ -126,7 +126,7 @@ static struct proto vsock_proto = {
*/ */
#define VSOCK_DEFAULT_CONNECT_TIMEOUT (2 * HZ) #define VSOCK_DEFAULT_CONNECT_TIMEOUT (2 * HZ)
static const struct vsock_transport *transport; static const struct vsock_transport *transport_single;
static DEFINE_MUTEX(vsock_register_mutex); static DEFINE_MUTEX(vsock_register_mutex);
/**** UTILS ****/ /**** UTILS ****/
...@@ -408,7 +408,9 @@ static bool vsock_is_pending(struct sock *sk) ...@@ -408,7 +408,9 @@ static bool vsock_is_pending(struct sock *sk)
static int vsock_send_shutdown(struct sock *sk, int mode) static int vsock_send_shutdown(struct sock *sk, int mode)
{ {
return transport->shutdown(vsock_sk(sk), mode); struct vsock_sock *vsk = vsock_sk(sk);
return vsk->transport->shutdown(vsk, mode);
} }
static void vsock_pending_work(struct work_struct *work) static void vsock_pending_work(struct work_struct *work)
...@@ -518,7 +520,7 @@ static int __vsock_bind_stream(struct vsock_sock *vsk, ...@@ -518,7 +520,7 @@ static int __vsock_bind_stream(struct vsock_sock *vsk,
static int __vsock_bind_dgram(struct vsock_sock *vsk, static int __vsock_bind_dgram(struct vsock_sock *vsk,
struct sockaddr_vm *addr) struct sockaddr_vm *addr)
{ {
return transport->dgram_bind(vsk, addr); return vsk->transport->dgram_bind(vsk, addr);
} }
static int __vsock_bind(struct sock *sk, struct sockaddr_vm *addr) static int __vsock_bind(struct sock *sk, struct sockaddr_vm *addr)
...@@ -536,7 +538,7 @@ static int __vsock_bind(struct sock *sk, struct sockaddr_vm *addr) ...@@ -536,7 +538,7 @@ static int __vsock_bind(struct sock *sk, struct sockaddr_vm *addr)
* like AF_INET prevents binding to a non-local IP address (in most * like AF_INET prevents binding to a non-local IP address (in most
* cases), we only allow binding to the local CID. * cases), we only allow binding to the local CID.
*/ */
cid = transport->get_local_cid(); cid = vsk->transport->get_local_cid();
if (addr->svm_cid != cid && addr->svm_cid != VMADDR_CID_ANY) if (addr->svm_cid != cid && addr->svm_cid != VMADDR_CID_ANY)
return -EADDRNOTAVAIL; return -EADDRNOTAVAIL;
...@@ -586,6 +588,7 @@ struct sock *__vsock_create(struct net *net, ...@@ -586,6 +588,7 @@ struct sock *__vsock_create(struct net *net,
sk->sk_type = type; sk->sk_type = type;
vsk = vsock_sk(sk); vsk = vsock_sk(sk);
vsk->transport = transport_single;
vsock_addr_init(&vsk->local_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY); vsock_addr_init(&vsk->local_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY);
vsock_addr_init(&vsk->remote_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY); vsock_addr_init(&vsk->remote_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY);
...@@ -616,7 +619,7 @@ struct sock *__vsock_create(struct net *net, ...@@ -616,7 +619,7 @@ struct sock *__vsock_create(struct net *net,
vsk->connect_timeout = VSOCK_DEFAULT_CONNECT_TIMEOUT; vsk->connect_timeout = VSOCK_DEFAULT_CONNECT_TIMEOUT;
} }
if (transport->init(vsk, psk) < 0) { if (vsk->transport->init(vsk, psk) < 0) {
sk_free(sk); sk_free(sk);
return NULL; return NULL;
} }
...@@ -640,7 +643,7 @@ static void __vsock_release(struct sock *sk, int level) ...@@ -640,7 +643,7 @@ static void __vsock_release(struct sock *sk, int level)
/* The release call is supposed to use lock_sock_nested() /* The release call is supposed to use lock_sock_nested()
* rather than lock_sock(), if a sock lock should be acquired. * rather than lock_sock(), if a sock lock should be acquired.
*/ */
transport->release(vsk); vsk->transport->release(vsk);
/* When "level" is SINGLE_DEPTH_NESTING, use the nested /* When "level" is SINGLE_DEPTH_NESTING, use the nested
* version to avoid the warning "possible recursive locking * version to avoid the warning "possible recursive locking
...@@ -668,7 +671,7 @@ static void vsock_sk_destruct(struct sock *sk) ...@@ -668,7 +671,7 @@ static void vsock_sk_destruct(struct sock *sk)
{ {
struct vsock_sock *vsk = vsock_sk(sk); struct vsock_sock *vsk = vsock_sk(sk);
transport->destruct(vsk); vsk->transport->destruct(vsk);
/* When clearing these addresses, there's no need to set the family and /* When clearing these addresses, there's no need to set the family and
* possibly register the address family with the kernel. * possibly register the address family with the kernel.
...@@ -692,13 +695,13 @@ static int vsock_queue_rcv_skb(struct sock *sk, struct sk_buff *skb) ...@@ -692,13 +695,13 @@ static int vsock_queue_rcv_skb(struct sock *sk, struct sk_buff *skb)
s64 vsock_stream_has_data(struct vsock_sock *vsk) s64 vsock_stream_has_data(struct vsock_sock *vsk)
{ {
return transport->stream_has_data(vsk); return vsk->transport->stream_has_data(vsk);
} }
EXPORT_SYMBOL_GPL(vsock_stream_has_data); EXPORT_SYMBOL_GPL(vsock_stream_has_data);
s64 vsock_stream_has_space(struct vsock_sock *vsk) s64 vsock_stream_has_space(struct vsock_sock *vsk)
{ {
return transport->stream_has_space(vsk); return vsk->transport->stream_has_space(vsk);
} }
EXPORT_SYMBOL_GPL(vsock_stream_has_space); EXPORT_SYMBOL_GPL(vsock_stream_has_space);
...@@ -867,6 +870,7 @@ static __poll_t vsock_poll(struct file *file, struct socket *sock, ...@@ -867,6 +870,7 @@ static __poll_t vsock_poll(struct file *file, struct socket *sock,
mask |= EPOLLOUT | EPOLLWRNORM | EPOLLWRBAND; mask |= EPOLLOUT | EPOLLWRNORM | EPOLLWRBAND;
} else if (sock->type == SOCK_STREAM) { } else if (sock->type == SOCK_STREAM) {
const struct vsock_transport *transport = vsk->transport;
lock_sock(sk); lock_sock(sk);
/* Listening sockets that have connections in their accept /* Listening sockets that have connections in their accept
...@@ -942,6 +946,7 @@ static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg, ...@@ -942,6 +946,7 @@ static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
struct sock *sk; struct sock *sk;
struct vsock_sock *vsk; struct vsock_sock *vsk;
struct sockaddr_vm *remote_addr; struct sockaddr_vm *remote_addr;
const struct vsock_transport *transport;
if (msg->msg_flags & MSG_OOB) if (msg->msg_flags & MSG_OOB)
return -EOPNOTSUPP; return -EOPNOTSUPP;
...@@ -950,6 +955,7 @@ static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg, ...@@ -950,6 +955,7 @@ static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
err = 0; err = 0;
sk = sock->sk; sk = sock->sk;
vsk = vsock_sk(sk); vsk = vsock_sk(sk);
transport = vsk->transport;
lock_sock(sk); lock_sock(sk);
...@@ -1034,7 +1040,7 @@ static int vsock_dgram_connect(struct socket *sock, ...@@ -1034,7 +1040,7 @@ static int vsock_dgram_connect(struct socket *sock,
if (err) if (err)
goto out; goto out;
if (!transport->dgram_allow(remote_addr->svm_cid, if (!vsk->transport->dgram_allow(remote_addr->svm_cid,
remote_addr->svm_port)) { remote_addr->svm_port)) {
err = -EINVAL; err = -EINVAL;
goto out; goto out;
...@@ -1051,7 +1057,9 @@ static int vsock_dgram_connect(struct socket *sock, ...@@ -1051,7 +1057,9 @@ static int vsock_dgram_connect(struct socket *sock,
static int vsock_dgram_recvmsg(struct socket *sock, struct msghdr *msg, static int vsock_dgram_recvmsg(struct socket *sock, struct msghdr *msg,
size_t len, int flags) size_t len, int flags)
{ {
return transport->dgram_dequeue(vsock_sk(sock->sk), msg, len, flags); struct vsock_sock *vsk = vsock_sk(sock->sk);
return vsk->transport->dgram_dequeue(vsk, msg, len, flags);
} }
static const struct proto_ops vsock_dgram_ops = { static const struct proto_ops vsock_dgram_ops = {
...@@ -1077,6 +1085,8 @@ static const struct proto_ops vsock_dgram_ops = { ...@@ -1077,6 +1085,8 @@ static const struct proto_ops vsock_dgram_ops = {
static int vsock_transport_cancel_pkt(struct vsock_sock *vsk) static int vsock_transport_cancel_pkt(struct vsock_sock *vsk)
{ {
const struct vsock_transport *transport = vsk->transport;
if (!transport->cancel_pkt) if (!transport->cancel_pkt)
return -EOPNOTSUPP; return -EOPNOTSUPP;
...@@ -1113,6 +1123,7 @@ static int vsock_stream_connect(struct socket *sock, struct sockaddr *addr, ...@@ -1113,6 +1123,7 @@ static int vsock_stream_connect(struct socket *sock, struct sockaddr *addr,
int err; int err;
struct sock *sk; struct sock *sk;
struct vsock_sock *vsk; struct vsock_sock *vsk;
const struct vsock_transport *transport;
struct sockaddr_vm *remote_addr; struct sockaddr_vm *remote_addr;
long timeout; long timeout;
DEFINE_WAIT(wait); DEFINE_WAIT(wait);
...@@ -1120,6 +1131,7 @@ static int vsock_stream_connect(struct socket *sock, struct sockaddr *addr, ...@@ -1120,6 +1131,7 @@ static int vsock_stream_connect(struct socket *sock, struct sockaddr *addr,
err = 0; err = 0;
sk = sock->sk; sk = sock->sk;
vsk = vsock_sk(sk); vsk = vsock_sk(sk);
transport = vsk->transport;
lock_sock(sk); lock_sock(sk);
...@@ -1363,6 +1375,7 @@ static int vsock_stream_setsockopt(struct socket *sock, ...@@ -1363,6 +1375,7 @@ static int vsock_stream_setsockopt(struct socket *sock,
int err; int err;
struct sock *sk; struct sock *sk;
struct vsock_sock *vsk; struct vsock_sock *vsk;
const struct vsock_transport *transport;
u64 val; u64 val;
if (level != AF_VSOCK) if (level != AF_VSOCK)
...@@ -1383,6 +1396,7 @@ static int vsock_stream_setsockopt(struct socket *sock, ...@@ -1383,6 +1396,7 @@ static int vsock_stream_setsockopt(struct socket *sock,
err = 0; err = 0;
sk = sock->sk; sk = sock->sk;
vsk = vsock_sk(sk); vsk = vsock_sk(sk);
transport = vsk->transport;
lock_sock(sk); lock_sock(sk);
...@@ -1440,6 +1454,7 @@ static int vsock_stream_getsockopt(struct socket *sock, ...@@ -1440,6 +1454,7 @@ static int vsock_stream_getsockopt(struct socket *sock,
int len; int len;
struct sock *sk; struct sock *sk;
struct vsock_sock *vsk; struct vsock_sock *vsk;
const struct vsock_transport *transport;
u64 val; u64 val;
if (level != AF_VSOCK) if (level != AF_VSOCK)
...@@ -1463,6 +1478,7 @@ static int vsock_stream_getsockopt(struct socket *sock, ...@@ -1463,6 +1478,7 @@ static int vsock_stream_getsockopt(struct socket *sock,
err = 0; err = 0;
sk = sock->sk; sk = sock->sk;
vsk = vsock_sk(sk); vsk = vsock_sk(sk);
transport = vsk->transport;
switch (optname) { switch (optname) {
case SO_VM_SOCKETS_BUFFER_SIZE: case SO_VM_SOCKETS_BUFFER_SIZE:
...@@ -1507,6 +1523,7 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg, ...@@ -1507,6 +1523,7 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg,
{ {
struct sock *sk; struct sock *sk;
struct vsock_sock *vsk; struct vsock_sock *vsk;
const struct vsock_transport *transport;
ssize_t total_written; ssize_t total_written;
long timeout; long timeout;
int err; int err;
...@@ -1515,6 +1532,7 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg, ...@@ -1515,6 +1532,7 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg,
sk = sock->sk; sk = sock->sk;
vsk = vsock_sk(sk); vsk = vsock_sk(sk);
transport = vsk->transport;
total_written = 0; total_written = 0;
err = 0; err = 0;
...@@ -1646,6 +1664,7 @@ vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg, size_t len, ...@@ -1646,6 +1664,7 @@ vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
{ {
struct sock *sk; struct sock *sk;
struct vsock_sock *vsk; struct vsock_sock *vsk;
const struct vsock_transport *transport;
int err; int err;
size_t target; size_t target;
ssize_t copied; ssize_t copied;
...@@ -1656,6 +1675,7 @@ vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg, size_t len, ...@@ -1656,6 +1675,7 @@ vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
sk = sock->sk; sk = sock->sk;
vsk = vsock_sk(sk); vsk = vsock_sk(sk);
transport = vsk->transport;
err = 0; err = 0;
lock_sock(sk); lock_sock(sk);
...@@ -1870,7 +1890,7 @@ static long vsock_dev_do_ioctl(struct file *filp, ...@@ -1870,7 +1890,7 @@ static long vsock_dev_do_ioctl(struct file *filp,
switch (cmd) { switch (cmd) {
case IOCTL_VM_SOCKETS_GET_LOCAL_CID: case IOCTL_VM_SOCKETS_GET_LOCAL_CID:
if (put_user(transport->get_local_cid(), p) != 0) if (put_user(transport_single->get_local_cid(), p) != 0)
retval = -EFAULT; retval = -EFAULT;
break; break;
...@@ -1917,7 +1937,7 @@ int __vsock_core_init(const struct vsock_transport *t, struct module *owner) ...@@ -1917,7 +1937,7 @@ int __vsock_core_init(const struct vsock_transport *t, struct module *owner)
if (err) if (err)
return err; return err;
if (transport) { if (transport_single) {
err = -EBUSY; err = -EBUSY;
goto err_busy; goto err_busy;
} }
...@@ -1926,7 +1946,7 @@ int __vsock_core_init(const struct vsock_transport *t, struct module *owner) ...@@ -1926,7 +1946,7 @@ int __vsock_core_init(const struct vsock_transport *t, struct module *owner)
* unload while there are open sockets. * unload while there are open sockets.
*/ */
vsock_proto.owner = owner; vsock_proto.owner = owner;
transport = t; transport_single = t;
vsock_device.minor = MISC_DYNAMIC_MINOR; vsock_device.minor = MISC_DYNAMIC_MINOR;
err = misc_register(&vsock_device); err = misc_register(&vsock_device);
...@@ -1956,7 +1976,7 @@ int __vsock_core_init(const struct vsock_transport *t, struct module *owner) ...@@ -1956,7 +1976,7 @@ int __vsock_core_init(const struct vsock_transport *t, struct module *owner)
err_deregister_misc: err_deregister_misc:
misc_deregister(&vsock_device); misc_deregister(&vsock_device);
err_reset_transport: err_reset_transport:
transport = NULL; transport_single = NULL;
err_busy: err_busy:
mutex_unlock(&vsock_register_mutex); mutex_unlock(&vsock_register_mutex);
return err; return err;
...@@ -1973,7 +1993,7 @@ void vsock_core_exit(void) ...@@ -1973,7 +1993,7 @@ void vsock_core_exit(void)
/* We do not want the assignment below re-ordered. */ /* We do not want the assignment below re-ordered. */
mb(); mb();
transport = NULL; transport_single = NULL;
mutex_unlock(&vsock_register_mutex); mutex_unlock(&vsock_register_mutex);
} }
...@@ -1984,7 +2004,7 @@ const struct vsock_transport *vsock_core_get_transport(void) ...@@ -1984,7 +2004,7 @@ const struct vsock_transport *vsock_core_get_transport(void)
/* vsock_register_mutex not taken since only the transport uses this /* vsock_register_mutex not taken since only the transport uses this
* function and only while registered. * function and only while registered.
*/ */
return transport; return transport_single;
} }
EXPORT_SYMBOL_GPL(vsock_core_get_transport); EXPORT_SYMBOL_GPL(vsock_core_get_transport);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment