Commit 5aa3bd9b authored by David S. Miller's avatar David S. Miller

Merge branch 'virtio-vsock-seqpacket'

Arseny Krasnov says:

====================
virtio/vsock: introduce SOCK_SEQPACKET support

This patchset implements support of SOCK_SEQPACKET for virtio
transport.
	Since SOCK_SEQPACKET guarantees that record boundaries are preserved,
a new bit, SEQ_EOR, was added to the 'flags' field. This bit is set to 1
in the last RW packet of a message.
	Since packets of one socket are reordered neither on the vsock nor
on the vhost transport layer, this bit allows the receiver to restore the
original message. If the user's buffer is smaller than the message
length, then all data that does not fit is dropped.
	Maximum length of datagram is limited by 'peer_buf_alloc' value.
	The implementation also supports the 'MSG_TRUNC' flag.
	Tests also implemented.

	Thanks to stsp2@yandex.ru for encouragements and initial design
recommendations.
====================
Signed-off-by: David S. Miller <davem@davemloft.net>
parents 57806b28 184039ee
...@@ -31,7 +31,8 @@ ...@@ -31,7 +31,8 @@
enum { enum {
VHOST_VSOCK_FEATURES = VHOST_FEATURES | VHOST_VSOCK_FEATURES = VHOST_FEATURES |
(1ULL << VIRTIO_F_ACCESS_PLATFORM) (1ULL << VIRTIO_F_ACCESS_PLATFORM) |
(1ULL << VIRTIO_VSOCK_F_SEQPACKET)
}; };
enum { enum {
...@@ -56,6 +57,7 @@ struct vhost_vsock { ...@@ -56,6 +57,7 @@ struct vhost_vsock {
atomic_t queued_replies; atomic_t queued_replies;
u32 guest_cid; u32 guest_cid;
bool seqpacket_allow;
}; };
static u32 vhost_transport_get_local_cid(void) static u32 vhost_transport_get_local_cid(void)
...@@ -112,6 +114,7 @@ vhost_transport_do_send_pkt(struct vhost_vsock *vsock, ...@@ -112,6 +114,7 @@ vhost_transport_do_send_pkt(struct vhost_vsock *vsock,
size_t nbytes; size_t nbytes;
size_t iov_len, payload_len; size_t iov_len, payload_len;
int head; int head;
bool restore_flag = false;
spin_lock_bh(&vsock->send_pkt_list_lock); spin_lock_bh(&vsock->send_pkt_list_lock);
if (list_empty(&vsock->send_pkt_list)) { if (list_empty(&vsock->send_pkt_list)) {
...@@ -168,9 +171,26 @@ vhost_transport_do_send_pkt(struct vhost_vsock *vsock, ...@@ -168,9 +171,26 @@ vhost_transport_do_send_pkt(struct vhost_vsock *vsock,
/* If the packet is greater than the space available in the /* If the packet is greater than the space available in the
* buffer, we split it using multiple buffers. * buffer, we split it using multiple buffers.
*/ */
if (payload_len > iov_len - sizeof(pkt->hdr)) if (payload_len > iov_len - sizeof(pkt->hdr)) {
payload_len = iov_len - sizeof(pkt->hdr); payload_len = iov_len - sizeof(pkt->hdr);
/* As we are copying pieces of a large packet's buffer to
* small rx buffers, headers of packets in the rx queue are
* created dynamically and are initialized with the header
* of the current packet (except length). But in case of
* SOCK_SEQPACKET, we must also clear the record delimiter
* bit (VIRTIO_VSOCK_SEQ_EOR). Otherwise, instead of one
* packet with the delimiter (which marks end of record),
* there will be a sequence of packets with the delimiter
* bit set. After the initialized header has been copied to
* the rx buffer, this bit will be restored.
*/
if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SEQ_EOR) {
pkt->hdr.flags &= ~cpu_to_le32(VIRTIO_VSOCK_SEQ_EOR);
restore_flag = true;
}
}
/* Set the correct length in the header */ /* Set the correct length in the header */
pkt->hdr.len = cpu_to_le32(payload_len); pkt->hdr.len = cpu_to_le32(payload_len);
...@@ -204,6 +224,9 @@ vhost_transport_do_send_pkt(struct vhost_vsock *vsock, ...@@ -204,6 +224,9 @@ vhost_transport_do_send_pkt(struct vhost_vsock *vsock,
* to send it with the next available buffer. * to send it with the next available buffer.
*/ */
if (pkt->off < pkt->len) { if (pkt->off < pkt->len) {
if (restore_flag)
pkt->hdr.flags |= cpu_to_le32(VIRTIO_VSOCK_SEQ_EOR);
/* We are queueing the same virtio_vsock_pkt to handle /* We are queueing the same virtio_vsock_pkt to handle
* the remaining bytes, and we want to deliver it * the remaining bytes, and we want to deliver it
* to monitoring devices in the next iteration. * to monitoring devices in the next iteration.
...@@ -354,7 +377,6 @@ vhost_vsock_alloc_pkt(struct vhost_virtqueue *vq, ...@@ -354,7 +377,6 @@ vhost_vsock_alloc_pkt(struct vhost_virtqueue *vq,
return NULL; return NULL;
} }
if (le16_to_cpu(pkt->hdr.type) == VIRTIO_VSOCK_TYPE_STREAM)
pkt->len = le32_to_cpu(pkt->hdr.len); pkt->len = le32_to_cpu(pkt->hdr.len);
/* No payload */ /* No payload */
...@@ -398,6 +420,8 @@ static bool vhost_vsock_more_replies(struct vhost_vsock *vsock) ...@@ -398,6 +420,8 @@ static bool vhost_vsock_more_replies(struct vhost_vsock *vsock)
return val < vq->num; return val < vq->num;
} }
static bool vhost_transport_seqpacket_allow(u32 remote_cid);
static struct virtio_transport vhost_transport = { static struct virtio_transport vhost_transport = {
.transport = { .transport = {
.module = THIS_MODULE, .module = THIS_MODULE,
...@@ -424,6 +448,11 @@ static struct virtio_transport vhost_transport = { ...@@ -424,6 +448,11 @@ static struct virtio_transport vhost_transport = {
.stream_is_active = virtio_transport_stream_is_active, .stream_is_active = virtio_transport_stream_is_active,
.stream_allow = virtio_transport_stream_allow, .stream_allow = virtio_transport_stream_allow,
.seqpacket_dequeue = virtio_transport_seqpacket_dequeue,
.seqpacket_enqueue = virtio_transport_seqpacket_enqueue,
.seqpacket_allow = vhost_transport_seqpacket_allow,
.seqpacket_has_data = virtio_transport_seqpacket_has_data,
.notify_poll_in = virtio_transport_notify_poll_in, .notify_poll_in = virtio_transport_notify_poll_in,
.notify_poll_out = virtio_transport_notify_poll_out, .notify_poll_out = virtio_transport_notify_poll_out,
.notify_recv_init = virtio_transport_notify_recv_init, .notify_recv_init = virtio_transport_notify_recv_init,
...@@ -441,6 +470,22 @@ static struct virtio_transport vhost_transport = { ...@@ -441,6 +470,22 @@ static struct virtio_transport vhost_transport = {
.send_pkt = vhost_transport_send_pkt, .send_pkt = vhost_transport_send_pkt,
}; };
/* Report whether SOCK_SEQPACKET is allowed towards @remote_cid.
 *
 * Looks up the vhost_vsock that vhost_vsock_get() returns for @remote_cid
 * and returns its seqpacket_allow flag; returns false when the lookup
 * yields NULL.
 */
static bool vhost_transport_seqpacket_allow(u32 remote_cid)
{
	struct vhost_vsock *dev;
	bool allowed;

	/* Both the lookup and the field read stay inside the RCU read section. */
	rcu_read_lock();
	dev = vhost_vsock_get(remote_cid);
	allowed = dev ? dev->seqpacket_allow : false;
	rcu_read_unlock();

	return allowed;
}
static void vhost_vsock_handle_tx_kick(struct vhost_work *work) static void vhost_vsock_handle_tx_kick(struct vhost_work *work)
{ {
struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue, struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
...@@ -785,6 +830,9 @@ static int vhost_vsock_set_features(struct vhost_vsock *vsock, u64 features) ...@@ -785,6 +830,9 @@ static int vhost_vsock_set_features(struct vhost_vsock *vsock, u64 features)
goto err; goto err;
} }
if (features & (1ULL << VIRTIO_VSOCK_F_SEQPACKET))
vsock->seqpacket_allow = true;
for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) { for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) {
vq = &vsock->vqs[i]; vq = &vsock->vqs[i];
mutex_lock(&vq->mutex); mutex_lock(&vq->mutex);
......
...@@ -36,6 +36,7 @@ struct virtio_vsock_sock { ...@@ -36,6 +36,7 @@ struct virtio_vsock_sock {
u32 rx_bytes; u32 rx_bytes;
u32 buf_alloc; u32 buf_alloc;
struct list_head rx_queue; struct list_head rx_queue;
u32 msg_count;
}; };
struct virtio_vsock_pkt { struct virtio_vsock_pkt {
...@@ -80,8 +81,17 @@ virtio_transport_dgram_dequeue(struct vsock_sock *vsk, ...@@ -80,8 +81,17 @@ virtio_transport_dgram_dequeue(struct vsock_sock *vsk,
struct msghdr *msg, struct msghdr *msg,
size_t len, int flags); size_t len, int flags);
int
virtio_transport_seqpacket_enqueue(struct vsock_sock *vsk,
struct msghdr *msg,
size_t len);
ssize_t
virtio_transport_seqpacket_dequeue(struct vsock_sock *vsk,
struct msghdr *msg,
int flags);
s64 virtio_transport_stream_has_data(struct vsock_sock *vsk); s64 virtio_transport_stream_has_data(struct vsock_sock *vsk);
s64 virtio_transport_stream_has_space(struct vsock_sock *vsk); s64 virtio_transport_stream_has_space(struct vsock_sock *vsk);
u32 virtio_transport_seqpacket_has_data(struct vsock_sock *vsk);
int virtio_transport_do_socket_init(struct vsock_sock *vsk, int virtio_transport_do_socket_init(struct vsock_sock *vsk,
struct vsock_sock *psk); struct vsock_sock *psk);
......
...@@ -135,6 +135,14 @@ struct vsock_transport { ...@@ -135,6 +135,14 @@ struct vsock_transport {
bool (*stream_is_active)(struct vsock_sock *); bool (*stream_is_active)(struct vsock_sock *);
bool (*stream_allow)(u32 cid, u32 port); bool (*stream_allow)(u32 cid, u32 port);
/* SEQ_PACKET. */
ssize_t (*seqpacket_dequeue)(struct vsock_sock *vsk, struct msghdr *msg,
int flags);
int (*seqpacket_enqueue)(struct vsock_sock *vsk, struct msghdr *msg,
size_t len);
bool (*seqpacket_allow)(u32 remote_cid);
u32 (*seqpacket_has_data)(struct vsock_sock *vsk);
/* Notification. */ /* Notification. */
int (*notify_poll_in)(struct vsock_sock *, size_t, bool *); int (*notify_poll_in)(struct vsock_sock *, size_t, bool *);
int (*notify_poll_out)(struct vsock_sock *, size_t, bool *); int (*notify_poll_out)(struct vsock_sock *, size_t, bool *);
......
...@@ -9,9 +9,12 @@ ...@@ -9,9 +9,12 @@
#include <linux/tracepoint.h> #include <linux/tracepoint.h>
TRACE_DEFINE_ENUM(VIRTIO_VSOCK_TYPE_STREAM); TRACE_DEFINE_ENUM(VIRTIO_VSOCK_TYPE_STREAM);
TRACE_DEFINE_ENUM(VIRTIO_VSOCK_TYPE_SEQPACKET);
#define show_type(val) \ #define show_type(val) \
__print_symbolic(val, { VIRTIO_VSOCK_TYPE_STREAM, "STREAM" }) __print_symbolic(val, \
{ VIRTIO_VSOCK_TYPE_STREAM, "STREAM" }, \
{ VIRTIO_VSOCK_TYPE_SEQPACKET, "SEQPACKET" })
TRACE_DEFINE_ENUM(VIRTIO_VSOCK_OP_INVALID); TRACE_DEFINE_ENUM(VIRTIO_VSOCK_OP_INVALID);
TRACE_DEFINE_ENUM(VIRTIO_VSOCK_OP_REQUEST); TRACE_DEFINE_ENUM(VIRTIO_VSOCK_OP_REQUEST);
......
...@@ -38,6 +38,9 @@ ...@@ -38,6 +38,9 @@
#include <linux/virtio_ids.h> #include <linux/virtio_ids.h>
#include <linux/virtio_config.h> #include <linux/virtio_config.h>
/* The feature bitmap for virtio vsock */
#define VIRTIO_VSOCK_F_SEQPACKET 1 /* SOCK_SEQPACKET supported */
struct virtio_vsock_config { struct virtio_vsock_config {
__le64 guest_cid; __le64 guest_cid;
} __attribute__((packed)); } __attribute__((packed));
...@@ -65,6 +68,7 @@ struct virtio_vsock_hdr { ...@@ -65,6 +68,7 @@ struct virtio_vsock_hdr {
enum virtio_vsock_type { enum virtio_vsock_type {
VIRTIO_VSOCK_TYPE_STREAM = 1, VIRTIO_VSOCK_TYPE_STREAM = 1,
VIRTIO_VSOCK_TYPE_SEQPACKET = 2,
}; };
enum virtio_vsock_op { enum virtio_vsock_op {
...@@ -91,4 +95,9 @@ enum virtio_vsock_shutdown { ...@@ -91,4 +95,9 @@ enum virtio_vsock_shutdown {
VIRTIO_VSOCK_SHUTDOWN_SEND = 2, VIRTIO_VSOCK_SHUTDOWN_SEND = 2,
}; };
/* VIRTIO_VSOCK_OP_RW flags values */
enum virtio_vsock_rw {
/* Record delimiter for SOCK_SEQPACKET: marks the packet that ends a record. */
VIRTIO_VSOCK_SEQ_EOR = 1,
};
#endif /* _UAPI_LINUX_VIRTIO_VSOCK_H */ #endif /* _UAPI_LINUX_VIRTIO_VSOCK_H */
This diff is collapsed.
...@@ -62,6 +62,7 @@ struct virtio_vsock { ...@@ -62,6 +62,7 @@ struct virtio_vsock {
struct virtio_vsock_event event_list[8]; struct virtio_vsock_event event_list[8];
u32 guest_cid; u32 guest_cid;
bool seqpacket_allow;
}; };
static u32 virtio_transport_get_local_cid(void) static u32 virtio_transport_get_local_cid(void)
...@@ -443,6 +444,8 @@ static void virtio_vsock_rx_done(struct virtqueue *vq) ...@@ -443,6 +444,8 @@ static void virtio_vsock_rx_done(struct virtqueue *vq)
queue_work(virtio_vsock_workqueue, &vsock->rx_work); queue_work(virtio_vsock_workqueue, &vsock->rx_work);
} }
static bool virtio_transport_seqpacket_allow(u32 remote_cid);
static struct virtio_transport virtio_transport = { static struct virtio_transport virtio_transport = {
.transport = { .transport = {
.module = THIS_MODULE, .module = THIS_MODULE,
...@@ -469,6 +472,11 @@ static struct virtio_transport virtio_transport = { ...@@ -469,6 +472,11 @@ static struct virtio_transport virtio_transport = {
.stream_is_active = virtio_transport_stream_is_active, .stream_is_active = virtio_transport_stream_is_active,
.stream_allow = virtio_transport_stream_allow, .stream_allow = virtio_transport_stream_allow,
.seqpacket_dequeue = virtio_transport_seqpacket_dequeue,
.seqpacket_enqueue = virtio_transport_seqpacket_enqueue,
.seqpacket_allow = virtio_transport_seqpacket_allow,
.seqpacket_has_data = virtio_transport_seqpacket_has_data,
.notify_poll_in = virtio_transport_notify_poll_in, .notify_poll_in = virtio_transport_notify_poll_in,
.notify_poll_out = virtio_transport_notify_poll_out, .notify_poll_out = virtio_transport_notify_poll_out,
.notify_recv_init = virtio_transport_notify_recv_init, .notify_recv_init = virtio_transport_notify_recv_init,
...@@ -485,6 +493,19 @@ static struct virtio_transport virtio_transport = { ...@@ -485,6 +493,19 @@ static struct virtio_transport virtio_transport = {
.send_pkt = virtio_transport_send_pkt, .send_pkt = virtio_transport_send_pkt,
}; };
/* Report whether SOCK_SEQPACKET may be used over the virtio transport.
 *
 * Returns the seqpacket_allow flag of the device currently published in
 * the_virtio_vsock (the probe path sets that flag when the device offers
 * VIRTIO_VSOCK_F_SEQPACKET). Returns false when no device is published.
 * @remote_cid is not consulted.
 */
static bool virtio_transport_seqpacket_allow(u32 remote_cid)
{
	struct virtio_vsock *vsock;
	bool seqpacket_allow = false;

	rcu_read_lock();
	vsock = rcu_dereference(the_virtio_vsock);
	/* the_virtio_vsock is only assigned once probe has completed, so it
	 * can still be NULL here; dereferencing it unconditionally would be a
	 * NULL pointer dereference. Treat "no device" as "not allowed", the
	 * same way the vhost transport does.
	 */
	if (vsock)
		seqpacket_allow = vsock->seqpacket_allow;
	rcu_read_unlock();

	return seqpacket_allow;
}
static void virtio_transport_rx_work(struct work_struct *work) static void virtio_transport_rx_work(struct work_struct *work)
{ {
struct virtio_vsock *vsock = struct virtio_vsock *vsock =
...@@ -608,10 +629,14 @@ static int virtio_vsock_probe(struct virtio_device *vdev) ...@@ -608,10 +629,14 @@ static int virtio_vsock_probe(struct virtio_device *vdev)
vsock->event_run = true; vsock->event_run = true;
mutex_unlock(&vsock->event_lock); mutex_unlock(&vsock->event_lock);
if (virtio_has_feature(vdev, VIRTIO_VSOCK_F_SEQPACKET))
vsock->seqpacket_allow = true;
vdev->priv = vsock; vdev->priv = vsock;
rcu_assign_pointer(the_virtio_vsock, vsock); rcu_assign_pointer(the_virtio_vsock, vsock);
mutex_unlock(&the_virtio_vsock_mutex); mutex_unlock(&the_virtio_vsock_mutex);
return 0; return 0;
out: out:
...@@ -695,6 +720,7 @@ static struct virtio_device_id id_table[] = { ...@@ -695,6 +720,7 @@ static struct virtio_device_id id_table[] = {
}; };
static unsigned int features[] = { static unsigned int features[] = {
VIRTIO_VSOCK_F_SEQPACKET
}; };
static struct virtio_driver virtio_vsock_driver = { static struct virtio_driver virtio_vsock_driver = {
......
...@@ -74,6 +74,10 @@ virtio_transport_alloc_pkt(struct virtio_vsock_pkt_info *info, ...@@ -74,6 +74,10 @@ virtio_transport_alloc_pkt(struct virtio_vsock_pkt_info *info,
err = memcpy_from_msg(pkt->buf, info->msg, len); err = memcpy_from_msg(pkt->buf, info->msg, len);
if (err) if (err)
goto out; goto out;
if (msg_data_left(info->msg) == 0 &&
info->type == VIRTIO_VSOCK_TYPE_SEQPACKET)
pkt->hdr.flags |= cpu_to_le32(VIRTIO_VSOCK_SEQ_EOR);
} }
trace_virtio_transport_alloc_pkt(src_cid, src_port, trace_virtio_transport_alloc_pkt(src_cid, src_port,
...@@ -165,6 +169,14 @@ void virtio_transport_deliver_tap_pkt(struct virtio_vsock_pkt *pkt) ...@@ -165,6 +169,14 @@ void virtio_transport_deliver_tap_pkt(struct virtio_vsock_pkt *pkt)
} }
EXPORT_SYMBOL_GPL(virtio_transport_deliver_tap_pkt); EXPORT_SYMBOL_GPL(virtio_transport_deliver_tap_pkt);
/* Translate the socket type of @sk into a virtio vsock packet type:
 * SOCK_STREAM maps to VIRTIO_VSOCK_TYPE_STREAM, every other socket type
 * maps to VIRTIO_VSOCK_TYPE_SEQPACKET.
 */
static u16 virtio_transport_get_type(struct sock *sk)
{
	return sk->sk_type == SOCK_STREAM ? VIRTIO_VSOCK_TYPE_STREAM :
					    VIRTIO_VSOCK_TYPE_SEQPACKET;
}
/* This function can only be used on connecting/connected sockets, /* This function can only be used on connecting/connected sockets,
* since a socket assigned to a transport is required. * since a socket assigned to a transport is required.
* *
...@@ -179,6 +191,8 @@ static int virtio_transport_send_pkt_info(struct vsock_sock *vsk, ...@@ -179,6 +191,8 @@ static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
struct virtio_vsock_pkt *pkt; struct virtio_vsock_pkt *pkt;
u32 pkt_len = info->pkt_len; u32 pkt_len = info->pkt_len;
info->type = virtio_transport_get_type(sk_vsock(vsk));
t_ops = virtio_transport_get_ops(vsk); t_ops = virtio_transport_get_ops(vsk);
if (unlikely(!t_ops)) if (unlikely(!t_ops))
return -EFAULT; return -EFAULT;
...@@ -269,13 +283,10 @@ void virtio_transport_put_credit(struct virtio_vsock_sock *vvs, u32 credit) ...@@ -269,13 +283,10 @@ void virtio_transport_put_credit(struct virtio_vsock_sock *vvs, u32 credit)
} }
EXPORT_SYMBOL_GPL(virtio_transport_put_credit); EXPORT_SYMBOL_GPL(virtio_transport_put_credit);
static int virtio_transport_send_credit_update(struct vsock_sock *vsk, static int virtio_transport_send_credit_update(struct vsock_sock *vsk)
int type,
struct virtio_vsock_hdr *hdr)
{ {
struct virtio_vsock_pkt_info info = { struct virtio_vsock_pkt_info info = {
.op = VIRTIO_VSOCK_OP_CREDIT_UPDATE, .op = VIRTIO_VSOCK_OP_CREDIT_UPDATE,
.type = type,
.vsk = vsk, .vsk = vsk,
}; };
...@@ -383,11 +394,8 @@ virtio_transport_stream_do_dequeue(struct vsock_sock *vsk, ...@@ -383,11 +394,8 @@ virtio_transport_stream_do_dequeue(struct vsock_sock *vsk,
* messages, we set the limit to a high value. TODO: experiment * messages, we set the limit to a high value. TODO: experiment
* with different values. * with different values.
*/ */
if (free_space < VIRTIO_VSOCK_MAX_PKT_BUF_SIZE) { if (free_space < VIRTIO_VSOCK_MAX_PKT_BUF_SIZE)
virtio_transport_send_credit_update(vsk, virtio_transport_send_credit_update(vsk);
VIRTIO_VSOCK_TYPE_STREAM,
NULL);
}
return total; return total;
...@@ -397,6 +405,78 @@ virtio_transport_stream_do_dequeue(struct vsock_sock *vsk, ...@@ -397,6 +405,78 @@ virtio_transport_stream_do_dequeue(struct vsock_sock *vsk,
return err; return err;
} }
/* Dequeue one complete SOCK_SEQPACKET record from vvs->rx_queue into @msg.
 *
 * Returns 0 when no complete record is queued (vvs->msg_count == 0).
 * Otherwise every packet up to and including the one carrying
 * VIRTIO_VSOCK_SEQ_EOR is removed from the queue and freed. The return
 * value is then the sum of the hdr.len fields of those packets, which can
 * exceed the number of bytes actually copied when the user buffer is
 * shorter than the record, or the error returned by memcpy_to_msg() if a
 * copy failed. @flags is not examined in this function.
 */
static int virtio_transport_seqpacket_do_dequeue(struct vsock_sock *vsk,
struct msghdr *msg,
int flags)
{
struct virtio_vsock_sock *vvs = vsk->trans;
struct virtio_vsock_pkt *pkt;
int dequeued_len = 0;
size_t user_buf_len = msg_data_left(msg);
bool copy_failed = false;
bool msg_ready = false;
spin_lock_bh(&vvs->rx_lock);
/* msg_count counts queued packets that carry the EOR bit; zero means
 * there is no finished record to hand out yet.
 */
if (vvs->msg_count == 0) {
spin_unlock_bh(&vvs->rx_lock);
return 0;
}
/* Consume packets from the head of the queue until the EOR packet. */
while (!msg_ready) {
pkt = list_first_entry(&vvs->rx_queue, struct virtio_vsock_pkt, list);
if (!copy_failed) {
size_t pkt_len;
size_t bytes_to_copy;
pkt_len = (size_t)le32_to_cpu(pkt->hdr.len);
/* Copy only what still fits; once user_buf_len reaches zero the
 * remaining fragments are counted but not copied.
 */
bytes_to_copy = min(user_buf_len, pkt_len);
if (bytes_to_copy) {
int err;
/* sk_lock is held by caller so no one else can dequeue.
* Unlock rx_lock since memcpy_to_msg() may sleep.
*/
spin_unlock_bh(&vvs->rx_lock);
err = memcpy_to_msg(msg, pkt->buf, bytes_to_copy);
if (err) {
/* Copy of message failed, set flag to skip
* copy path for rest of fragments. Rest of
* fragments will be freed without copy.
*/
copy_failed = true;
dequeued_len = err;
} else {
user_buf_len -= bytes_to_copy;
}
spin_lock_bh(&vvs->rx_lock);
}
/* Accumulate the full packet length, not just the copied part;
 * skipped once dequeued_len holds a (negative) copy error.
 */
if (dequeued_len >= 0)
dequeued_len += pkt_len;
}
if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SEQ_EOR) {
msg_ready = true;
vvs->msg_count--;
}
/* The packet is dropped from the queue whether or not it was copied. */
virtio_transport_dec_rx_pkt(vvs, pkt);
list_del(&pkt->list);
virtio_transport_free_pkt(pkt);
}
spin_unlock_bh(&vvs->rx_lock);
/* Sent unconditionally after a record has been consumed. */
virtio_transport_send_credit_update(vsk);
return dequeued_len;
}
ssize_t ssize_t
virtio_transport_stream_dequeue(struct vsock_sock *vsk, virtio_transport_stream_dequeue(struct vsock_sock *vsk,
struct msghdr *msg, struct msghdr *msg,
...@@ -409,6 +489,38 @@ virtio_transport_stream_dequeue(struct vsock_sock *vsk, ...@@ -409,6 +489,38 @@ virtio_transport_stream_dequeue(struct vsock_sock *vsk,
} }
EXPORT_SYMBOL_GPL(virtio_transport_stream_dequeue); EXPORT_SYMBOL_GPL(virtio_transport_stream_dequeue);
ssize_t
virtio_transport_seqpacket_dequeue(struct vsock_sock *vsk,
struct msghdr *msg,
int flags)
{
if (flags & MSG_PEEK)
return -EOPNOTSUPP;
return virtio_transport_seqpacket_do_dequeue(vsk, msg, flags);
}
EXPORT_SYMBOL_GPL(virtio_transport_seqpacket_dequeue);
/* SOCK_SEQPACKET send entry point of the virtio transports.
 *
 * A message longer than the peer's advertised buffer (vvs->peer_buf_alloc)
 * is refused with -EMSGSIZE. Otherwise the message is handed to
 * virtio_transport_stream_enqueue() and its result is returned.
 */
int
virtio_transport_seqpacket_enqueue(struct vsock_sock *vsk,
				   struct msghdr *msg,
				   size_t len)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	bool too_big;

	/* peer_buf_alloc is sampled under tx_lock; the decision is acted on
	 * after the lock has been released.
	 */
	spin_lock_bh(&vvs->tx_lock);
	too_big = len > vvs->peer_buf_alloc;
	spin_unlock_bh(&vvs->tx_lock);

	if (too_big)
		return -EMSGSIZE;

	return virtio_transport_stream_enqueue(vsk, msg, len);
}
EXPORT_SYMBOL_GPL(virtio_transport_seqpacket_enqueue);
int int
virtio_transport_dgram_dequeue(struct vsock_sock *vsk, virtio_transport_dgram_dequeue(struct vsock_sock *vsk,
struct msghdr *msg, struct msghdr *msg,
...@@ -431,6 +543,19 @@ s64 virtio_transport_stream_has_data(struct vsock_sock *vsk) ...@@ -431,6 +543,19 @@ s64 virtio_transport_stream_has_data(struct vsock_sock *vsk)
} }
EXPORT_SYMBOL_GPL(virtio_transport_stream_has_data); EXPORT_SYMBOL_GPL(virtio_transport_stream_has_data);
/* Return a snapshot of vvs->msg_count for @vsk, read under vvs->rx_lock. */
u32 virtio_transport_seqpacket_has_data(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	u32 pending;

	spin_lock_bh(&vvs->rx_lock);
	pending = vvs->msg_count;
	spin_unlock_bh(&vvs->rx_lock);

	return pending;
}
EXPORT_SYMBOL_GPL(virtio_transport_seqpacket_has_data);
static s64 virtio_transport_has_space(struct vsock_sock *vsk) static s64 virtio_transport_has_space(struct vsock_sock *vsk)
{ {
struct virtio_vsock_sock *vvs = vsk->trans; struct virtio_vsock_sock *vvs = vsk->trans;
...@@ -496,8 +621,7 @@ void virtio_transport_notify_buffer_size(struct vsock_sock *vsk, u64 *val) ...@@ -496,8 +621,7 @@ void virtio_transport_notify_buffer_size(struct vsock_sock *vsk, u64 *val)
vvs->buf_alloc = *val; vvs->buf_alloc = *val;
virtio_transport_send_credit_update(vsk, VIRTIO_VSOCK_TYPE_STREAM, virtio_transport_send_credit_update(vsk);
NULL);
} }
EXPORT_SYMBOL_GPL(virtio_transport_notify_buffer_size); EXPORT_SYMBOL_GPL(virtio_transport_notify_buffer_size);
...@@ -624,7 +748,6 @@ int virtio_transport_connect(struct vsock_sock *vsk) ...@@ -624,7 +748,6 @@ int virtio_transport_connect(struct vsock_sock *vsk)
{ {
struct virtio_vsock_pkt_info info = { struct virtio_vsock_pkt_info info = {
.op = VIRTIO_VSOCK_OP_REQUEST, .op = VIRTIO_VSOCK_OP_REQUEST,
.type = VIRTIO_VSOCK_TYPE_STREAM,
.vsk = vsk, .vsk = vsk,
}; };
...@@ -636,7 +759,6 @@ int virtio_transport_shutdown(struct vsock_sock *vsk, int mode) ...@@ -636,7 +759,6 @@ int virtio_transport_shutdown(struct vsock_sock *vsk, int mode)
{ {
struct virtio_vsock_pkt_info info = { struct virtio_vsock_pkt_info info = {
.op = VIRTIO_VSOCK_OP_SHUTDOWN, .op = VIRTIO_VSOCK_OP_SHUTDOWN,
.type = VIRTIO_VSOCK_TYPE_STREAM,
.flags = (mode & RCV_SHUTDOWN ? .flags = (mode & RCV_SHUTDOWN ?
VIRTIO_VSOCK_SHUTDOWN_RCV : 0) | VIRTIO_VSOCK_SHUTDOWN_RCV : 0) |
(mode & SEND_SHUTDOWN ? (mode & SEND_SHUTDOWN ?
...@@ -665,7 +787,6 @@ virtio_transport_stream_enqueue(struct vsock_sock *vsk, ...@@ -665,7 +787,6 @@ virtio_transport_stream_enqueue(struct vsock_sock *vsk,
{ {
struct virtio_vsock_pkt_info info = { struct virtio_vsock_pkt_info info = {
.op = VIRTIO_VSOCK_OP_RW, .op = VIRTIO_VSOCK_OP_RW,
.type = VIRTIO_VSOCK_TYPE_STREAM,
.msg = msg, .msg = msg,
.pkt_len = len, .pkt_len = len,
.vsk = vsk, .vsk = vsk,
...@@ -688,7 +809,6 @@ static int virtio_transport_reset(struct vsock_sock *vsk, ...@@ -688,7 +809,6 @@ static int virtio_transport_reset(struct vsock_sock *vsk,
{ {
struct virtio_vsock_pkt_info info = { struct virtio_vsock_pkt_info info = {
.op = VIRTIO_VSOCK_OP_RST, .op = VIRTIO_VSOCK_OP_RST,
.type = VIRTIO_VSOCK_TYPE_STREAM,
.reply = !!pkt, .reply = !!pkt,
.vsk = vsk, .vsk = vsk,
}; };
...@@ -848,7 +968,7 @@ void virtio_transport_release(struct vsock_sock *vsk) ...@@ -848,7 +968,7 @@ void virtio_transport_release(struct vsock_sock *vsk)
struct sock *sk = &vsk->sk; struct sock *sk = &vsk->sk;
bool remove_sock = true; bool remove_sock = true;
if (sk->sk_type == SOCK_STREAM) if (sk->sk_type == SOCK_STREAM || sk->sk_type == SOCK_SEQPACKET)
remove_sock = virtio_transport_close(vsk); remove_sock = virtio_transport_close(vsk);
if (remove_sock) { if (remove_sock) {
...@@ -912,6 +1032,9 @@ virtio_transport_recv_enqueue(struct vsock_sock *vsk, ...@@ -912,6 +1032,9 @@ virtio_transport_recv_enqueue(struct vsock_sock *vsk,
goto out; goto out;
} }
if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SEQ_EOR)
vvs->msg_count++;
/* Try to copy small packets into the buffer of last packet queued, /* Try to copy small packets into the buffer of last packet queued,
* to avoid wasting memory queueing the entire buffer with a small * to avoid wasting memory queueing the entire buffer with a small
* payload. * payload.
...@@ -923,13 +1046,18 @@ virtio_transport_recv_enqueue(struct vsock_sock *vsk, ...@@ -923,13 +1046,18 @@ virtio_transport_recv_enqueue(struct vsock_sock *vsk,
struct virtio_vsock_pkt, list); struct virtio_vsock_pkt, list);
/* If there is space in the last packet queued, we copy the /* If there is space in the last packet queued, we copy the
* new packet in its buffer. * new packet in its buffer. We avoid this if the last packet
* queued has VIRTIO_VSOCK_SEQ_EOR set, because this is
* delimiter of SEQPACKET record, so 'pkt' is the first packet
* of a new record.
*/ */
if (pkt->len <= last_pkt->buf_len - last_pkt->len) { if ((pkt->len <= last_pkt->buf_len - last_pkt->len) &&
!(le32_to_cpu(last_pkt->hdr.flags) & VIRTIO_VSOCK_SEQ_EOR)) {
memcpy(last_pkt->buf + last_pkt->len, pkt->buf, memcpy(last_pkt->buf + last_pkt->len, pkt->buf,
pkt->len); pkt->len);
last_pkt->len += pkt->len; last_pkt->len += pkt->len;
free_pkt = true; free_pkt = true;
last_pkt->hdr.flags |= pkt->hdr.flags;
goto out; goto out;
} }
} }
...@@ -1000,7 +1128,6 @@ virtio_transport_send_response(struct vsock_sock *vsk, ...@@ -1000,7 +1128,6 @@ virtio_transport_send_response(struct vsock_sock *vsk,
{ {
struct virtio_vsock_pkt_info info = { struct virtio_vsock_pkt_info info = {
.op = VIRTIO_VSOCK_OP_RESPONSE, .op = VIRTIO_VSOCK_OP_RESPONSE,
.type = VIRTIO_VSOCK_TYPE_STREAM,
.remote_cid = le64_to_cpu(pkt->hdr.src_cid), .remote_cid = le64_to_cpu(pkt->hdr.src_cid),
.remote_port = le32_to_cpu(pkt->hdr.src_port), .remote_port = le32_to_cpu(pkt->hdr.src_port),
.reply = true, .reply = true,
...@@ -1096,6 +1223,12 @@ virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt, ...@@ -1096,6 +1223,12 @@ virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt,
return 0; return 0;
} }
/* Return true if @type is one of the packet types this code accepts:
 * VIRTIO_VSOCK_TYPE_STREAM or VIRTIO_VSOCK_TYPE_SEQPACKET.
 */
static bool virtio_transport_valid_type(u16 type)
{
	switch (type) {
	case VIRTIO_VSOCK_TYPE_STREAM:
	case VIRTIO_VSOCK_TYPE_SEQPACKET:
		return true;
	default:
		return false;
	}
}
/* We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's vq->mutex /* We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's vq->mutex
* lock. * lock.
*/ */
...@@ -1121,7 +1254,7 @@ void virtio_transport_recv_pkt(struct virtio_transport *t, ...@@ -1121,7 +1254,7 @@ void virtio_transport_recv_pkt(struct virtio_transport *t,
le32_to_cpu(pkt->hdr.buf_alloc), le32_to_cpu(pkt->hdr.buf_alloc),
le32_to_cpu(pkt->hdr.fwd_cnt)); le32_to_cpu(pkt->hdr.fwd_cnt));
if (le16_to_cpu(pkt->hdr.type) != VIRTIO_VSOCK_TYPE_STREAM) { if (!virtio_transport_valid_type(le16_to_cpu(pkt->hdr.type))) {
(void)virtio_transport_reset_no_sock(t, pkt); (void)virtio_transport_reset_no_sock(t, pkt);
goto free_pkt; goto free_pkt;
} }
...@@ -1138,6 +1271,12 @@ void virtio_transport_recv_pkt(struct virtio_transport *t, ...@@ -1138,6 +1271,12 @@ void virtio_transport_recv_pkt(struct virtio_transport *t,
} }
} }
if (virtio_transport_get_type(sk) != le16_to_cpu(pkt->hdr.type)) {
(void)virtio_transport_reset_no_sock(t, pkt);
sock_put(sk);
goto free_pkt;
}
vsk = vsock_sk(sk); vsk = vsock_sk(sk);
lock_sock(sk); lock_sock(sk);
......
...@@ -63,6 +63,8 @@ static int vsock_loopback_cancel_pkt(struct vsock_sock *vsk) ...@@ -63,6 +63,8 @@ static int vsock_loopback_cancel_pkt(struct vsock_sock *vsk)
return 0; return 0;
} }
static bool vsock_loopback_seqpacket_allow(u32 remote_cid);
static struct virtio_transport loopback_transport = { static struct virtio_transport loopback_transport = {
.transport = { .transport = {
.module = THIS_MODULE, .module = THIS_MODULE,
...@@ -89,6 +91,11 @@ static struct virtio_transport loopback_transport = { ...@@ -89,6 +91,11 @@ static struct virtio_transport loopback_transport = {
.stream_is_active = virtio_transport_stream_is_active, .stream_is_active = virtio_transport_stream_is_active,
.stream_allow = virtio_transport_stream_allow, .stream_allow = virtio_transport_stream_allow,
.seqpacket_dequeue = virtio_transport_seqpacket_dequeue,
.seqpacket_enqueue = virtio_transport_seqpacket_enqueue,
.seqpacket_allow = vsock_loopback_seqpacket_allow,
.seqpacket_has_data = virtio_transport_seqpacket_has_data,
.notify_poll_in = virtio_transport_notify_poll_in, .notify_poll_in = virtio_transport_notify_poll_in,
.notify_poll_out = virtio_transport_notify_poll_out, .notify_poll_out = virtio_transport_notify_poll_out,
.notify_recv_init = virtio_transport_notify_recv_init, .notify_recv_init = virtio_transport_notify_recv_init,
...@@ -105,6 +112,11 @@ static struct virtio_transport loopback_transport = { ...@@ -105,6 +112,11 @@ static struct virtio_transport loopback_transport = {
.send_pkt = vsock_loopback_send_pkt, .send_pkt = vsock_loopback_send_pkt,
}; };
/* SOCK_SEQPACKET is always allowed on the loopback transport; @remote_cid
 * is ignored.
 */
static bool vsock_loopback_seqpacket_allow(u32 remote_cid)
{
return true;
}
static void vsock_loopback_work(struct work_struct *work) static void vsock_loopback_work(struct work_struct *work)
{ {
struct vsock_loopback *vsock = struct vsock_loopback *vsock =
......
...@@ -84,7 +84,7 @@ void vsock_wait_remote_close(int fd) ...@@ -84,7 +84,7 @@ void vsock_wait_remote_close(int fd)
} }
/* Connect to <cid, port> and return the file descriptor. */ /* Connect to <cid, port> and return the file descriptor. */
int vsock_stream_connect(unsigned int cid, unsigned int port) static int vsock_connect(unsigned int cid, unsigned int port, int type)
{ {
union { union {
struct sockaddr sa; struct sockaddr sa;
...@@ -101,7 +101,7 @@ int vsock_stream_connect(unsigned int cid, unsigned int port) ...@@ -101,7 +101,7 @@ int vsock_stream_connect(unsigned int cid, unsigned int port)
control_expectln("LISTENING"); control_expectln("LISTENING");
fd = socket(AF_VSOCK, SOCK_STREAM, 0); fd = socket(AF_VSOCK, type, 0);
timeout_begin(TIMEOUT); timeout_begin(TIMEOUT);
do { do {
...@@ -120,11 +120,21 @@ int vsock_stream_connect(unsigned int cid, unsigned int port) ...@@ -120,11 +120,21 @@ int vsock_stream_connect(unsigned int cid, unsigned int port)
return fd; return fd;
} }
/* Connect to <cid, port> with a SOCK_STREAM socket; thin wrapper that
 * returns whatever vsock_connect() returns.
 */
int vsock_stream_connect(unsigned int cid, unsigned int port)
{
return vsock_connect(cid, port, SOCK_STREAM);
}
/* Connect to <cid, port> with a SOCK_SEQPACKET socket; thin wrapper that
 * returns whatever vsock_connect() returns.
 */
int vsock_seqpacket_connect(unsigned int cid, unsigned int port)
{
return vsock_connect(cid, port, SOCK_SEQPACKET);
}
/* Listen on <cid, port> and return the first incoming connection. The remote /* Listen on <cid, port> and return the first incoming connection. The remote
* address is stored to clientaddrp. clientaddrp may be NULL. * address is stored to clientaddrp. clientaddrp may be NULL.
*/ */
int vsock_stream_accept(unsigned int cid, unsigned int port, static int vsock_accept(unsigned int cid, unsigned int port,
struct sockaddr_vm *clientaddrp) struct sockaddr_vm *clientaddrp, int type)
{ {
union { union {
struct sockaddr sa; struct sockaddr sa;
...@@ -145,7 +155,7 @@ int vsock_stream_accept(unsigned int cid, unsigned int port, ...@@ -145,7 +155,7 @@ int vsock_stream_accept(unsigned int cid, unsigned int port,
int client_fd; int client_fd;
int old_errno; int old_errno;
fd = socket(AF_VSOCK, SOCK_STREAM, 0); fd = socket(AF_VSOCK, type, 0);
if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) { if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) {
perror("bind"); perror("bind");
...@@ -189,6 +199,18 @@ int vsock_stream_accept(unsigned int cid, unsigned int port, ...@@ -189,6 +199,18 @@ int vsock_stream_accept(unsigned int cid, unsigned int port,
return client_fd; return client_fd;
} }
/* SOCK_STREAM variant of vsock_accept(): forwards cid, port and clientaddrp
 * unchanged and returns its result.
 */
int vsock_stream_accept(unsigned int cid, unsigned int port,
struct sockaddr_vm *clientaddrp)
{
return vsock_accept(cid, port, clientaddrp, SOCK_STREAM);
}
/* SOCK_SEQPACKET variant of vsock_accept(): forwards cid, port and
 * clientaddrp unchanged and returns its result.
 */
int vsock_seqpacket_accept(unsigned int cid, unsigned int port,
struct sockaddr_vm *clientaddrp)
{
return vsock_accept(cid, port, clientaddrp, SOCK_SEQPACKET);
}
/* Transmit one byte and check the return value. /* Transmit one byte and check the return value.
* *
* expected_ret: * expected_ret:
......
...@@ -36,8 +36,11 @@ struct test_case { ...@@ -36,8 +36,11 @@ struct test_case {
void init_signals(void); void init_signals(void);
unsigned int parse_cid(const char *str); unsigned int parse_cid(const char *str);
int vsock_stream_connect(unsigned int cid, unsigned int port); int vsock_stream_connect(unsigned int cid, unsigned int port);
int vsock_seqpacket_connect(unsigned int cid, unsigned int port);
int vsock_stream_accept(unsigned int cid, unsigned int port, int vsock_stream_accept(unsigned int cid, unsigned int port,
struct sockaddr_vm *clientaddrp); struct sockaddr_vm *clientaddrp);
int vsock_seqpacket_accept(unsigned int cid, unsigned int port,
struct sockaddr_vm *clientaddrp);
void vsock_wait_remote_close(int fd); void vsock_wait_remote_close(int fd);
void send_byte(int fd, int expected_ret, int flags); void send_byte(int fd, int expected_ret, int flags);
void recv_byte(int fd, int expected_ret, int flags); void recv_byte(int fd, int expected_ret, int flags);
......
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
#include <errno.h> #include <errno.h>
#include <unistd.h> #include <unistd.h>
#include <linux/kernel.h> #include <linux/kernel.h>
#include <sys/types.h>
#include <sys/socket.h>
#include "timeout.h" #include "timeout.h"
#include "control.h" #include "control.h"
...@@ -279,6 +281,110 @@ static void test_stream_msg_peek_server(const struct test_opts *opts) ...@@ -279,6 +281,110 @@ static void test_stream_msg_peek_server(const struct test_opts *opts)
close(fd); close(fd);
} }
#define MESSAGES_CNT 7
/* Client half of the "SOCK_SEQPACKET msg bounds" test: connect to the peer
 * on port 1234, issue MESSAGES_CNT send_byte() calls, announce "SENDDONE"
 * on the control channel and close the socket.
 */
static void test_seqpacket_msg_bounds_client(const struct test_opts *opts)
{
	int sock = vsock_seqpacket_connect(opts->peer_cid, 1234);
	int remaining = MESSAGES_CNT;

	if (sock < 0) {
		perror("connect");
		exit(EXIT_FAILURE);
	}

	/* Every message is sent with flags == 0; MSG_EOR is not set on any */
	while (remaining-- > 0)
		send_byte(sock, 1, 0);

	control_writeln("SENDDONE");
	close(sock);
}
/* Server half of the "SOCK_SEQPACKET msg bounds" test: accept a connection
 * on port 1234, wait for the client's "SENDDONE" control line, then verify
 * that each of MESSAGES_CNT recvmsg() calls returns exactly one byte even
 * though the receive buffer has room for more.
 */
static void test_seqpacket_msg_bounds_server(const struct test_opts *opts)
{
	int fd;
	char buf[16];
	struct msghdr msg = {0};
	struct iovec iov = {0};

	fd = vsock_seqpacket_accept(VMADDR_CID_ANY, 1234, NULL);
	if (fd < 0) {
		perror("accept");
		exit(EXIT_FAILURE);
	}

	control_expectln("SENDDONE");
	iov.iov_base = buf;
	iov.iov_len = sizeof(buf);
	msg.msg_iov = &iov;
	msg.msg_iovlen = 1;

	for (int i = 0; i < MESSAGES_CNT; i++) {
		ssize_t ret = recvmsg(fd, &msg, 0);

		/* errno is only meaningful when recvmsg() itself failed */
		if (ret < 0) {
			perror("recvmsg");
			exit(EXIT_FAILURE);
		}

		/* Any other length means the record boundary was not kept */
		if (ret != 1) {
			fprintf(stderr,
				"message bound violated: expected 1 byte, got %zd\n",
				ret);
			exit(EXIT_FAILURE);
		}
	}

	close(fd);
}
/* Size in bytes of the message sent by the MSG_TRUNC test client; the server
 * receives it into a buffer of half this size.
 */
#define MESSAGE_TRUNC_SZ 32
/* Client half of the "SOCK_SEQPACKET MSG_TRUNC flag" test: connect to the
 * peer on port 1234, send one MESSAGE_TRUNC_SZ-byte message, announce
 * "SENDDONE" on the control channel and close the socket.
 */
static void test_seqpacket_msg_trunc_client(const struct test_opts *opts)
{
	int fd;
	/* Zero-filled so that no indeterminate stack bytes are transmitted;
	 * this function only checks the length reported by send().
	 */
	char buf[MESSAGE_TRUNC_SZ] = {0};

	fd = vsock_seqpacket_connect(opts->peer_cid, 1234);
	if (fd < 0) {
		perror("connect");
		exit(EXIT_FAILURE);
	}

	if (send(fd, buf, sizeof(buf), 0) != sizeof(buf)) {
		perror("send failed");
		exit(EXIT_FAILURE);
	}

	control_writeln("SENDDONE");
	close(fd);
}
/* Server half of the "SOCK_SEQPACKET MSG_TRUNC flag" test: accept a
 * connection on port 1234, wait for the client's "SENDDONE" control line,
 * then receive into a buffer of MESSAGE_TRUNC_SZ / 2 bytes with MSG_TRUNC.
 * The test passes when recvmsg() reports MESSAGE_TRUNC_SZ bytes and sets
 * MSG_TRUNC in msg_flags.
 */
static void test_seqpacket_msg_trunc_server(const struct test_opts *opts)
{
	int fd;
	char buf[MESSAGE_TRUNC_SZ / 2];
	struct msghdr msg = {0};
	struct iovec iov = {0};
	ssize_t ret;

	fd = vsock_seqpacket_accept(VMADDR_CID_ANY, 1234, NULL);
	if (fd < 0) {
		perror("accept");
		exit(EXIT_FAILURE);
	}

	control_expectln("SENDDONE");
	iov.iov_base = buf;
	iov.iov_len = sizeof(buf);
	msg.msg_iov = &iov;
	msg.msg_iovlen = 1;

	ret = recvmsg(fd, &msg, MSG_TRUNC);

	/* errno is only meaningful when recvmsg() itself failed */
	if (ret < 0) {
		perror("recvmsg");
		exit(EXIT_FAILURE);
	}

	/* The reported length must be the full message, not the buffer size */
	if (ret != MESSAGE_TRUNC_SZ) {
		fprintf(stderr,
			"MSG_TRUNC doesn't work: expected %d bytes, got %zd\n",
			MESSAGE_TRUNC_SZ, ret);
		exit(EXIT_FAILURE);
	}

	if (!(msg.msg_flags & MSG_TRUNC)) {
		fprintf(stderr, "MSG_TRUNC expected\n");
		exit(EXIT_FAILURE);
	}

	close(fd);
}
static struct test_case test_cases[] = { static struct test_case test_cases[] = {
{ {
.name = "SOCK_STREAM connection reset", .name = "SOCK_STREAM connection reset",
...@@ -309,6 +415,16 @@ static struct test_case test_cases[] = { ...@@ -309,6 +415,16 @@ static struct test_case test_cases[] = {
.run_client = test_stream_msg_peek_client, .run_client = test_stream_msg_peek_client,
.run_server = test_stream_msg_peek_server, .run_server = test_stream_msg_peek_server,
}, },
{
.name = "SOCK_SEQPACKET msg bounds",
.run_client = test_seqpacket_msg_bounds_client,
.run_server = test_seqpacket_msg_bounds_server,
},
{
.name = "SOCK_SEQPACKET MSG_TRUNC flag",
.run_client = test_seqpacket_msg_trunc_client,
.run_server = test_seqpacket_msg_trunc_server,
},
{}, {},
}; };
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment