Commit 5aa3bd9b authored by David S. Miller's avatar David S. Miller

Merge branch 'virtio-vsock-seqpacket'

Arseny Krasnov says:

====================
virtio/vsock: introduce SOCK_SEQPACKET support

This patchset implements support of SOCK_SEQPACKET for virtio
transport.
	As SOCK_SEQPACKET guarantees to save record boundaries, so to
do it, new bit for field 'flags' was added: SEQ_EOR. This bit is
set to 1 in last RW packet of message.
	Now as  packets of one socket are not reordered neither on vsock
nor on vhost transport layers, such bit allows to restore original
message on receiver's side. If user's buffer is smaller than message
length, when all out of size data is dropped.
	Maximum length of datagram is limited by 'peer_buf_alloc' value.
	Implementation also supports 'MSG_TRUNC' flags.
	Tests also implemented.

	Thanks to stsp2@yandex.ru for encouragements and initial design
recommendations.
====================
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parents 57806b28 184039ee
......@@ -31,7 +31,8 @@
enum {
VHOST_VSOCK_FEATURES = VHOST_FEATURES |
(1ULL << VIRTIO_F_ACCESS_PLATFORM)
(1ULL << VIRTIO_F_ACCESS_PLATFORM) |
(1ULL << VIRTIO_VSOCK_F_SEQPACKET)
};
enum {
......@@ -56,6 +57,7 @@ struct vhost_vsock {
atomic_t queued_replies;
u32 guest_cid;
bool seqpacket_allow;
};
static u32 vhost_transport_get_local_cid(void)
......@@ -112,6 +114,7 @@ vhost_transport_do_send_pkt(struct vhost_vsock *vsock,
size_t nbytes;
size_t iov_len, payload_len;
int head;
bool restore_flag = false;
spin_lock_bh(&vsock->send_pkt_list_lock);
if (list_empty(&vsock->send_pkt_list)) {
......@@ -168,9 +171,26 @@ vhost_transport_do_send_pkt(struct vhost_vsock *vsock,
/* If the packet is greater than the space available in the
* buffer, we split it using multiple buffers.
*/
if (payload_len > iov_len - sizeof(pkt->hdr))
if (payload_len > iov_len - sizeof(pkt->hdr)) {
payload_len = iov_len - sizeof(pkt->hdr);
/* As we are copying pieces of large packet's buffer to
* small rx buffers, headers of packets in rx queue are
* created dynamically and are initialized with header
* of current packet(except length). But in case of
* SOCK_SEQPACKET, we also must clear record delimeter
* bit(VIRTIO_VSOCK_SEQ_EOR). Otherwise, instead of one
* packet with delimeter(which marks end of record),
* there will be sequence of packets with delimeter
* bit set. After initialized header will be copied to
* rx buffer, this bit will be restored.
*/
if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SEQ_EOR) {
pkt->hdr.flags &= ~cpu_to_le32(VIRTIO_VSOCK_SEQ_EOR);
restore_flag = true;
}
}
/* Set the correct length in the header */
pkt->hdr.len = cpu_to_le32(payload_len);
......@@ -204,6 +224,9 @@ vhost_transport_do_send_pkt(struct vhost_vsock *vsock,
* to send it with the next available buffer.
*/
if (pkt->off < pkt->len) {
if (restore_flag)
pkt->hdr.flags |= cpu_to_le32(VIRTIO_VSOCK_SEQ_EOR);
/* We are queueing the same virtio_vsock_pkt to handle
* the remaining bytes, and we want to deliver it
* to monitoring devices in the next iteration.
......@@ -354,8 +377,7 @@ vhost_vsock_alloc_pkt(struct vhost_virtqueue *vq,
return NULL;
}
if (le16_to_cpu(pkt->hdr.type) == VIRTIO_VSOCK_TYPE_STREAM)
pkt->len = le32_to_cpu(pkt->hdr.len);
pkt->len = le32_to_cpu(pkt->hdr.len);
/* No payload */
if (!pkt->len)
......@@ -398,6 +420,8 @@ static bool vhost_vsock_more_replies(struct vhost_vsock *vsock)
return val < vq->num;
}
static bool vhost_transport_seqpacket_allow(u32 remote_cid);
static struct virtio_transport vhost_transport = {
.transport = {
.module = THIS_MODULE,
......@@ -424,6 +448,11 @@ static struct virtio_transport vhost_transport = {
.stream_is_active = virtio_transport_stream_is_active,
.stream_allow = virtio_transport_stream_allow,
.seqpacket_dequeue = virtio_transport_seqpacket_dequeue,
.seqpacket_enqueue = virtio_transport_seqpacket_enqueue,
.seqpacket_allow = vhost_transport_seqpacket_allow,
.seqpacket_has_data = virtio_transport_seqpacket_has_data,
.notify_poll_in = virtio_transport_notify_poll_in,
.notify_poll_out = virtio_transport_notify_poll_out,
.notify_recv_init = virtio_transport_notify_recv_init,
......@@ -441,6 +470,22 @@ static struct virtio_transport vhost_transport = {
.send_pkt = vhost_transport_send_pkt,
};
static bool vhost_transport_seqpacket_allow(u32 remote_cid)
{
struct vhost_vsock *vsock;
bool seqpacket_allow = false;
rcu_read_lock();
vsock = vhost_vsock_get(remote_cid);
if (vsock)
seqpacket_allow = vsock->seqpacket_allow;
rcu_read_unlock();
return seqpacket_allow;
}
static void vhost_vsock_handle_tx_kick(struct vhost_work *work)
{
struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
......@@ -785,6 +830,9 @@ static int vhost_vsock_set_features(struct vhost_vsock *vsock, u64 features)
goto err;
}
if (features & (1ULL << VIRTIO_VSOCK_F_SEQPACKET))
vsock->seqpacket_allow = true;
for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) {
vq = &vsock->vqs[i];
mutex_lock(&vq->mutex);
......
......@@ -36,6 +36,7 @@ struct virtio_vsock_sock {
u32 rx_bytes;
u32 buf_alloc;
struct list_head rx_queue;
u32 msg_count;
};
struct virtio_vsock_pkt {
......@@ -80,8 +81,17 @@ virtio_transport_dgram_dequeue(struct vsock_sock *vsk,
struct msghdr *msg,
size_t len, int flags);
int
virtio_transport_seqpacket_enqueue(struct vsock_sock *vsk,
struct msghdr *msg,
size_t len);
ssize_t
virtio_transport_seqpacket_dequeue(struct vsock_sock *vsk,
struct msghdr *msg,
int flags);
s64 virtio_transport_stream_has_data(struct vsock_sock *vsk);
s64 virtio_transport_stream_has_space(struct vsock_sock *vsk);
u32 virtio_transport_seqpacket_has_data(struct vsock_sock *vsk);
int virtio_transport_do_socket_init(struct vsock_sock *vsk,
struct vsock_sock *psk);
......
......@@ -135,6 +135,14 @@ struct vsock_transport {
bool (*stream_is_active)(struct vsock_sock *);
bool (*stream_allow)(u32 cid, u32 port);
/* SEQ_PACKET. */
ssize_t (*seqpacket_dequeue)(struct vsock_sock *vsk, struct msghdr *msg,
int flags);
int (*seqpacket_enqueue)(struct vsock_sock *vsk, struct msghdr *msg,
size_t len);
bool (*seqpacket_allow)(u32 remote_cid);
u32 (*seqpacket_has_data)(struct vsock_sock *vsk);
/* Notification. */
int (*notify_poll_in)(struct vsock_sock *, size_t, bool *);
int (*notify_poll_out)(struct vsock_sock *, size_t, bool *);
......
......@@ -9,9 +9,12 @@
#include <linux/tracepoint.h>
TRACE_DEFINE_ENUM(VIRTIO_VSOCK_TYPE_STREAM);
TRACE_DEFINE_ENUM(VIRTIO_VSOCK_TYPE_SEQPACKET);
#define show_type(val) \
__print_symbolic(val, { VIRTIO_VSOCK_TYPE_STREAM, "STREAM" })
__print_symbolic(val, \
{ VIRTIO_VSOCK_TYPE_STREAM, "STREAM" }, \
{ VIRTIO_VSOCK_TYPE_SEQPACKET, "SEQPACKET" })
TRACE_DEFINE_ENUM(VIRTIO_VSOCK_OP_INVALID);
TRACE_DEFINE_ENUM(VIRTIO_VSOCK_OP_REQUEST);
......
......@@ -38,6 +38,9 @@
#include <linux/virtio_ids.h>
#include <linux/virtio_config.h>
/* The feature bitmap for virtio vsock */
#define VIRTIO_VSOCK_F_SEQPACKET 1 /* SOCK_SEQPACKET supported */
struct virtio_vsock_config {
__le64 guest_cid;
} __attribute__((packed));
......@@ -65,6 +68,7 @@ struct virtio_vsock_hdr {
enum virtio_vsock_type {
VIRTIO_VSOCK_TYPE_STREAM = 1,
VIRTIO_VSOCK_TYPE_SEQPACKET = 2,
};
enum virtio_vsock_op {
......@@ -91,4 +95,9 @@ enum virtio_vsock_shutdown {
VIRTIO_VSOCK_SHUTDOWN_SEND = 2,
};
/* VIRTIO_VSOCK_OP_RW flags values */
enum virtio_vsock_rw {
VIRTIO_VSOCK_SEQ_EOR = 1,
};
#endif /* _UAPI_LINUX_VIRTIO_VSOCK_H */
This diff is collapsed.
......@@ -62,6 +62,7 @@ struct virtio_vsock {
struct virtio_vsock_event event_list[8];
u32 guest_cid;
bool seqpacket_allow;
};
static u32 virtio_transport_get_local_cid(void)
......@@ -443,6 +444,8 @@ static void virtio_vsock_rx_done(struct virtqueue *vq)
queue_work(virtio_vsock_workqueue, &vsock->rx_work);
}
static bool virtio_transport_seqpacket_allow(u32 remote_cid);
static struct virtio_transport virtio_transport = {
.transport = {
.module = THIS_MODULE,
......@@ -469,6 +472,11 @@ static struct virtio_transport virtio_transport = {
.stream_is_active = virtio_transport_stream_is_active,
.stream_allow = virtio_transport_stream_allow,
.seqpacket_dequeue = virtio_transport_seqpacket_dequeue,
.seqpacket_enqueue = virtio_transport_seqpacket_enqueue,
.seqpacket_allow = virtio_transport_seqpacket_allow,
.seqpacket_has_data = virtio_transport_seqpacket_has_data,
.notify_poll_in = virtio_transport_notify_poll_in,
.notify_poll_out = virtio_transport_notify_poll_out,
.notify_recv_init = virtio_transport_notify_recv_init,
......@@ -485,6 +493,19 @@ static struct virtio_transport virtio_transport = {
.send_pkt = virtio_transport_send_pkt,
};
static bool virtio_transport_seqpacket_allow(u32 remote_cid)
{
struct virtio_vsock *vsock;
bool seqpacket_allow;
rcu_read_lock();
vsock = rcu_dereference(the_virtio_vsock);
seqpacket_allow = vsock->seqpacket_allow;
rcu_read_unlock();
return seqpacket_allow;
}
static void virtio_transport_rx_work(struct work_struct *work)
{
struct virtio_vsock *vsock =
......@@ -608,10 +629,14 @@ static int virtio_vsock_probe(struct virtio_device *vdev)
vsock->event_run = true;
mutex_unlock(&vsock->event_lock);
if (virtio_has_feature(vdev, VIRTIO_VSOCK_F_SEQPACKET))
vsock->seqpacket_allow = true;
vdev->priv = vsock;
rcu_assign_pointer(the_virtio_vsock, vsock);
mutex_unlock(&the_virtio_vsock_mutex);
return 0;
out:
......@@ -695,6 +720,7 @@ static struct virtio_device_id id_table[] = {
};
static unsigned int features[] = {
VIRTIO_VSOCK_F_SEQPACKET
};
static struct virtio_driver virtio_vsock_driver = {
......
......@@ -74,6 +74,10 @@ virtio_transport_alloc_pkt(struct virtio_vsock_pkt_info *info,
err = memcpy_from_msg(pkt->buf, info->msg, len);
if (err)
goto out;
if (msg_data_left(info->msg) == 0 &&
info->type == VIRTIO_VSOCK_TYPE_SEQPACKET)
pkt->hdr.flags |= cpu_to_le32(VIRTIO_VSOCK_SEQ_EOR);
}
trace_virtio_transport_alloc_pkt(src_cid, src_port,
......@@ -165,6 +169,14 @@ void virtio_transport_deliver_tap_pkt(struct virtio_vsock_pkt *pkt)
}
EXPORT_SYMBOL_GPL(virtio_transport_deliver_tap_pkt);
static u16 virtio_transport_get_type(struct sock *sk)
{
if (sk->sk_type == SOCK_STREAM)
return VIRTIO_VSOCK_TYPE_STREAM;
else
return VIRTIO_VSOCK_TYPE_SEQPACKET;
}
/* This function can only be used on connecting/connected sockets,
* since a socket assigned to a transport is required.
*
......@@ -179,6 +191,8 @@ static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
struct virtio_vsock_pkt *pkt;
u32 pkt_len = info->pkt_len;
info->type = virtio_transport_get_type(sk_vsock(vsk));
t_ops = virtio_transport_get_ops(vsk);
if (unlikely(!t_ops))
return -EFAULT;
......@@ -269,13 +283,10 @@ void virtio_transport_put_credit(struct virtio_vsock_sock *vvs, u32 credit)
}
EXPORT_SYMBOL_GPL(virtio_transport_put_credit);
static int virtio_transport_send_credit_update(struct vsock_sock *vsk,
int type,
struct virtio_vsock_hdr *hdr)
static int virtio_transport_send_credit_update(struct vsock_sock *vsk)
{
struct virtio_vsock_pkt_info info = {
.op = VIRTIO_VSOCK_OP_CREDIT_UPDATE,
.type = type,
.vsk = vsk,
};
......@@ -383,11 +394,8 @@ virtio_transport_stream_do_dequeue(struct vsock_sock *vsk,
* messages, we set the limit to a high value. TODO: experiment
* with different values.
*/
if (free_space < VIRTIO_VSOCK_MAX_PKT_BUF_SIZE) {
virtio_transport_send_credit_update(vsk,
VIRTIO_VSOCK_TYPE_STREAM,
NULL);
}
if (free_space < VIRTIO_VSOCK_MAX_PKT_BUF_SIZE)
virtio_transport_send_credit_update(vsk);
return total;
......@@ -397,6 +405,78 @@ virtio_transport_stream_do_dequeue(struct vsock_sock *vsk,
return err;
}
static int virtio_transport_seqpacket_do_dequeue(struct vsock_sock *vsk,
struct msghdr *msg,
int flags)
{
struct virtio_vsock_sock *vvs = vsk->trans;
struct virtio_vsock_pkt *pkt;
int dequeued_len = 0;
size_t user_buf_len = msg_data_left(msg);
bool copy_failed = false;
bool msg_ready = false;
spin_lock_bh(&vvs->rx_lock);
if (vvs->msg_count == 0) {
spin_unlock_bh(&vvs->rx_lock);
return 0;
}
while (!msg_ready) {
pkt = list_first_entry(&vvs->rx_queue, struct virtio_vsock_pkt, list);
if (!copy_failed) {
size_t pkt_len;
size_t bytes_to_copy;
pkt_len = (size_t)le32_to_cpu(pkt->hdr.len);
bytes_to_copy = min(user_buf_len, pkt_len);
if (bytes_to_copy) {
int err;
/* sk_lock is held by caller so no one else can dequeue.
* Unlock rx_lock since memcpy_to_msg() may sleep.
*/
spin_unlock_bh(&vvs->rx_lock);
err = memcpy_to_msg(msg, pkt->buf, bytes_to_copy);
if (err) {
/* Copy of message failed, set flag to skip
* copy path for rest of fragments. Rest of
* fragments will be freed without copy.
*/
copy_failed = true;
dequeued_len = err;
} else {
user_buf_len -= bytes_to_copy;
}
spin_lock_bh(&vvs->rx_lock);
}
if (dequeued_len >= 0)
dequeued_len += pkt_len;
}
if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SEQ_EOR) {
msg_ready = true;
vvs->msg_count--;
}
virtio_transport_dec_rx_pkt(vvs, pkt);
list_del(&pkt->list);
virtio_transport_free_pkt(pkt);
}
spin_unlock_bh(&vvs->rx_lock);
virtio_transport_send_credit_update(vsk);
return dequeued_len;
}
ssize_t
virtio_transport_stream_dequeue(struct vsock_sock *vsk,
struct msghdr *msg,
......@@ -409,6 +489,38 @@ virtio_transport_stream_dequeue(struct vsock_sock *vsk,
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_dequeue);
ssize_t
virtio_transport_seqpacket_dequeue(struct vsock_sock *vsk,
struct msghdr *msg,
int flags)
{
if (flags & MSG_PEEK)
return -EOPNOTSUPP;
return virtio_transport_seqpacket_do_dequeue(vsk, msg, flags);
}
EXPORT_SYMBOL_GPL(virtio_transport_seqpacket_dequeue);
int
virtio_transport_seqpacket_enqueue(struct vsock_sock *vsk,
struct msghdr *msg,
size_t len)
{
struct virtio_vsock_sock *vvs = vsk->trans;
spin_lock_bh(&vvs->tx_lock);
if (len > vvs->peer_buf_alloc) {
spin_unlock_bh(&vvs->tx_lock);
return -EMSGSIZE;
}
spin_unlock_bh(&vvs->tx_lock);
return virtio_transport_stream_enqueue(vsk, msg, len);
}
EXPORT_SYMBOL_GPL(virtio_transport_seqpacket_enqueue);
int
virtio_transport_dgram_dequeue(struct vsock_sock *vsk,
struct msghdr *msg,
......@@ -431,6 +543,19 @@ s64 virtio_transport_stream_has_data(struct vsock_sock *vsk)
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_has_data);
u32 virtio_transport_seqpacket_has_data(struct vsock_sock *vsk)
{
struct virtio_vsock_sock *vvs = vsk->trans;
u32 msg_count;
spin_lock_bh(&vvs->rx_lock);
msg_count = vvs->msg_count;
spin_unlock_bh(&vvs->rx_lock);
return msg_count;
}
EXPORT_SYMBOL_GPL(virtio_transport_seqpacket_has_data);
static s64 virtio_transport_has_space(struct vsock_sock *vsk)
{
struct virtio_vsock_sock *vvs = vsk->trans;
......@@ -496,8 +621,7 @@ void virtio_transport_notify_buffer_size(struct vsock_sock *vsk, u64 *val)
vvs->buf_alloc = *val;
virtio_transport_send_credit_update(vsk, VIRTIO_VSOCK_TYPE_STREAM,
NULL);
virtio_transport_send_credit_update(vsk);
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_buffer_size);
......@@ -624,7 +748,6 @@ int virtio_transport_connect(struct vsock_sock *vsk)
{
struct virtio_vsock_pkt_info info = {
.op = VIRTIO_VSOCK_OP_REQUEST,
.type = VIRTIO_VSOCK_TYPE_STREAM,
.vsk = vsk,
};
......@@ -636,7 +759,6 @@ int virtio_transport_shutdown(struct vsock_sock *vsk, int mode)
{
struct virtio_vsock_pkt_info info = {
.op = VIRTIO_VSOCK_OP_SHUTDOWN,
.type = VIRTIO_VSOCK_TYPE_STREAM,
.flags = (mode & RCV_SHUTDOWN ?
VIRTIO_VSOCK_SHUTDOWN_RCV : 0) |
(mode & SEND_SHUTDOWN ?
......@@ -665,7 +787,6 @@ virtio_transport_stream_enqueue(struct vsock_sock *vsk,
{
struct virtio_vsock_pkt_info info = {
.op = VIRTIO_VSOCK_OP_RW,
.type = VIRTIO_VSOCK_TYPE_STREAM,
.msg = msg,
.pkt_len = len,
.vsk = vsk,
......@@ -688,7 +809,6 @@ static int virtio_transport_reset(struct vsock_sock *vsk,
{
struct virtio_vsock_pkt_info info = {
.op = VIRTIO_VSOCK_OP_RST,
.type = VIRTIO_VSOCK_TYPE_STREAM,
.reply = !!pkt,
.vsk = vsk,
};
......@@ -848,7 +968,7 @@ void virtio_transport_release(struct vsock_sock *vsk)
struct sock *sk = &vsk->sk;
bool remove_sock = true;
if (sk->sk_type == SOCK_STREAM)
if (sk->sk_type == SOCK_STREAM || sk->sk_type == SOCK_SEQPACKET)
remove_sock = virtio_transport_close(vsk);
if (remove_sock) {
......@@ -912,6 +1032,9 @@ virtio_transport_recv_enqueue(struct vsock_sock *vsk,
goto out;
}
if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SEQ_EOR)
vvs->msg_count++;
/* Try to copy small packets into the buffer of last packet queued,
* to avoid wasting memory queueing the entire buffer with a small
* payload.
......@@ -923,13 +1046,18 @@ virtio_transport_recv_enqueue(struct vsock_sock *vsk,
struct virtio_vsock_pkt, list);
/* If there is space in the last packet queued, we copy the
* new packet in its buffer.
* new packet in its buffer. We avoid this if the last packet
* queued has VIRTIO_VSOCK_SEQ_EOR set, because this is
* delimiter of SEQPACKET record, so 'pkt' is the first packet
* of a new record.
*/
if (pkt->len <= last_pkt->buf_len - last_pkt->len) {
if ((pkt->len <= last_pkt->buf_len - last_pkt->len) &&
!(le32_to_cpu(last_pkt->hdr.flags) & VIRTIO_VSOCK_SEQ_EOR)) {
memcpy(last_pkt->buf + last_pkt->len, pkt->buf,
pkt->len);
last_pkt->len += pkt->len;
free_pkt = true;
last_pkt->hdr.flags |= pkt->hdr.flags;
goto out;
}
}
......@@ -1000,7 +1128,6 @@ virtio_transport_send_response(struct vsock_sock *vsk,
{
struct virtio_vsock_pkt_info info = {
.op = VIRTIO_VSOCK_OP_RESPONSE,
.type = VIRTIO_VSOCK_TYPE_STREAM,
.remote_cid = le64_to_cpu(pkt->hdr.src_cid),
.remote_port = le32_to_cpu(pkt->hdr.src_port),
.reply = true,
......@@ -1096,6 +1223,12 @@ virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt,
return 0;
}
static bool virtio_transport_valid_type(u16 type)
{
return (type == VIRTIO_VSOCK_TYPE_STREAM) ||
(type == VIRTIO_VSOCK_TYPE_SEQPACKET);
}
/* We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's vq->mutex
* lock.
*/
......@@ -1121,7 +1254,7 @@ void virtio_transport_recv_pkt(struct virtio_transport *t,
le32_to_cpu(pkt->hdr.buf_alloc),
le32_to_cpu(pkt->hdr.fwd_cnt));
if (le16_to_cpu(pkt->hdr.type) != VIRTIO_VSOCK_TYPE_STREAM) {
if (!virtio_transport_valid_type(le16_to_cpu(pkt->hdr.type))) {
(void)virtio_transport_reset_no_sock(t, pkt);
goto free_pkt;
}
......@@ -1138,6 +1271,12 @@ void virtio_transport_recv_pkt(struct virtio_transport *t,
}
}
if (virtio_transport_get_type(sk) != le16_to_cpu(pkt->hdr.type)) {
(void)virtio_transport_reset_no_sock(t, pkt);
sock_put(sk);
goto free_pkt;
}
vsk = vsock_sk(sk);
lock_sock(sk);
......
......@@ -63,6 +63,8 @@ static int vsock_loopback_cancel_pkt(struct vsock_sock *vsk)
return 0;
}
static bool vsock_loopback_seqpacket_allow(u32 remote_cid);
static struct virtio_transport loopback_transport = {
.transport = {
.module = THIS_MODULE,
......@@ -89,6 +91,11 @@ static struct virtio_transport loopback_transport = {
.stream_is_active = virtio_transport_stream_is_active,
.stream_allow = virtio_transport_stream_allow,
.seqpacket_dequeue = virtio_transport_seqpacket_dequeue,
.seqpacket_enqueue = virtio_transport_seqpacket_enqueue,
.seqpacket_allow = vsock_loopback_seqpacket_allow,
.seqpacket_has_data = virtio_transport_seqpacket_has_data,
.notify_poll_in = virtio_transport_notify_poll_in,
.notify_poll_out = virtio_transport_notify_poll_out,
.notify_recv_init = virtio_transport_notify_recv_init,
......@@ -105,6 +112,11 @@ static struct virtio_transport loopback_transport = {
.send_pkt = vsock_loopback_send_pkt,
};
static bool vsock_loopback_seqpacket_allow(u32 remote_cid)
{
return true;
}
static void vsock_loopback_work(struct work_struct *work)
{
struct vsock_loopback *vsock =
......
......@@ -84,7 +84,7 @@ void vsock_wait_remote_close(int fd)
}
/* Connect to <cid, port> and return the file descriptor. */
int vsock_stream_connect(unsigned int cid, unsigned int port)
static int vsock_connect(unsigned int cid, unsigned int port, int type)
{
union {
struct sockaddr sa;
......@@ -101,7 +101,7 @@ int vsock_stream_connect(unsigned int cid, unsigned int port)
control_expectln("LISTENING");
fd = socket(AF_VSOCK, SOCK_STREAM, 0);
fd = socket(AF_VSOCK, type, 0);
timeout_begin(TIMEOUT);
do {
......@@ -120,11 +120,21 @@ int vsock_stream_connect(unsigned int cid, unsigned int port)
return fd;
}
int vsock_stream_connect(unsigned int cid, unsigned int port)
{
return vsock_connect(cid, port, SOCK_STREAM);
}
int vsock_seqpacket_connect(unsigned int cid, unsigned int port)
{
return vsock_connect(cid, port, SOCK_SEQPACKET);
}
/* Listen on <cid, port> and return the first incoming connection. The remote
* address is stored to clientaddrp. clientaddrp may be NULL.
*/
int vsock_stream_accept(unsigned int cid, unsigned int port,
struct sockaddr_vm *clientaddrp)
static int vsock_accept(unsigned int cid, unsigned int port,
struct sockaddr_vm *clientaddrp, int type)
{
union {
struct sockaddr sa;
......@@ -145,7 +155,7 @@ int vsock_stream_accept(unsigned int cid, unsigned int port,
int client_fd;
int old_errno;
fd = socket(AF_VSOCK, SOCK_STREAM, 0);
fd = socket(AF_VSOCK, type, 0);
if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) {
perror("bind");
......@@ -189,6 +199,18 @@ int vsock_stream_accept(unsigned int cid, unsigned int port,
return client_fd;
}
int vsock_stream_accept(unsigned int cid, unsigned int port,
struct sockaddr_vm *clientaddrp)
{
return vsock_accept(cid, port, clientaddrp, SOCK_STREAM);
}
int vsock_seqpacket_accept(unsigned int cid, unsigned int port,
struct sockaddr_vm *clientaddrp)
{
return vsock_accept(cid, port, clientaddrp, SOCK_SEQPACKET);
}
/* Transmit one byte and check the return value.
*
* expected_ret:
......
......@@ -36,8 +36,11 @@ struct test_case {
void init_signals(void);
unsigned int parse_cid(const char *str);
int vsock_stream_connect(unsigned int cid, unsigned int port);
int vsock_seqpacket_connect(unsigned int cid, unsigned int port);
int vsock_stream_accept(unsigned int cid, unsigned int port,
struct sockaddr_vm *clientaddrp);
int vsock_seqpacket_accept(unsigned int cid, unsigned int port,
struct sockaddr_vm *clientaddrp);
void vsock_wait_remote_close(int fd);
void send_byte(int fd, int expected_ret, int flags);
void recv_byte(int fd, int expected_ret, int flags);
......
......@@ -14,6 +14,8 @@
#include <errno.h>
#include <unistd.h>
#include <linux/kernel.h>
#include <sys/types.h>
#include <sys/socket.h>
#include "timeout.h"
#include "control.h"
......@@ -279,6 +281,110 @@ static void test_stream_msg_peek_server(const struct test_opts *opts)
close(fd);
}
#define MESSAGES_CNT 7
static void test_seqpacket_msg_bounds_client(const struct test_opts *opts)
{
int fd;
fd = vsock_seqpacket_connect(opts->peer_cid, 1234);
if (fd < 0) {
perror("connect");
exit(EXIT_FAILURE);
}
/* Send several messages, one with MSG_EOR flag */
for (int i = 0; i < MESSAGES_CNT; i++)
send_byte(fd, 1, 0);
control_writeln("SENDDONE");
close(fd);
}
static void test_seqpacket_msg_bounds_server(const struct test_opts *opts)
{
int fd;
char buf[16];
struct msghdr msg = {0};
struct iovec iov = {0};
fd = vsock_seqpacket_accept(VMADDR_CID_ANY, 1234, NULL);
if (fd < 0) {
perror("accept");
exit(EXIT_FAILURE);
}
control_expectln("SENDDONE");
iov.iov_base = buf;
iov.iov_len = sizeof(buf);
msg.msg_iov = &iov;
msg.msg_iovlen = 1;
for (int i = 0; i < MESSAGES_CNT; i++) {
if (recvmsg(fd, &msg, 0) != 1) {
perror("message bound violated");
exit(EXIT_FAILURE);
}
}
close(fd);
}
#define MESSAGE_TRUNC_SZ 32
static void test_seqpacket_msg_trunc_client(const struct test_opts *opts)
{
int fd;
char buf[MESSAGE_TRUNC_SZ];
fd = vsock_seqpacket_connect(opts->peer_cid, 1234);
if (fd < 0) {
perror("connect");
exit(EXIT_FAILURE);
}
if (send(fd, buf, sizeof(buf), 0) != sizeof(buf)) {
perror("send failed");
exit(EXIT_FAILURE);
}
control_writeln("SENDDONE");
close(fd);
}
static void test_seqpacket_msg_trunc_server(const struct test_opts *opts)
{
int fd;
char buf[MESSAGE_TRUNC_SZ / 2];
struct msghdr msg = {0};
struct iovec iov = {0};
fd = vsock_seqpacket_accept(VMADDR_CID_ANY, 1234, NULL);
if (fd < 0) {
perror("accept");
exit(EXIT_FAILURE);
}
control_expectln("SENDDONE");
iov.iov_base = buf;
iov.iov_len = sizeof(buf);
msg.msg_iov = &iov;
msg.msg_iovlen = 1;
ssize_t ret = recvmsg(fd, &msg, MSG_TRUNC);
if (ret != MESSAGE_TRUNC_SZ) {
printf("%zi\n", ret);
perror("MSG_TRUNC doesn't work");
exit(EXIT_FAILURE);
}
if (!(msg.msg_flags & MSG_TRUNC)) {
fprintf(stderr, "MSG_TRUNC expected\n");
exit(EXIT_FAILURE);
}
close(fd);
}
static struct test_case test_cases[] = {
{
.name = "SOCK_STREAM connection reset",
......@@ -309,6 +415,16 @@ static struct test_case test_cases[] = {
.run_client = test_stream_msg_peek_client,
.run_server = test_stream_msg_peek_server,
},
{
.name = "SOCK_SEQPACKET msg bounds",
.run_client = test_seqpacket_msg_bounds_client,
.run_server = test_seqpacket_msg_bounds_server,
},
{
.name = "SOCK_SEQPACKET MSG_TRUNC flag",
.run_client = test_seqpacket_msg_trunc_client,
.run_server = test_seqpacket_msg_trunc_server,
},
{},
};
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment