Commit f340208f authored by Linus Torvalds

Merge tag 'for_linus' of git://git.kernel.org/pub/scm/linux/kernel/git/mst/vhost

Pull virtio fixes from Michael Tsirkin:
 "Several fixes, some of them for CVEs"

* tag 'for_linus' of git://git.kernel.org/pub/scm/linux/kernel/git/mst/vhost:
  vhost: scsi: add weight support
  vhost: vsock: add weight support
  vhost_net: fix possible infinite loop
  vhost: introduce vhost_exceeds_weight()
  virtio: Fix indentation of VIRTIO_MMIO
  virtio: add unlikely() to WARN_ON_ONCE()
parents f2c7c76c c1ea02f1
...@@ -604,12 +604,6 @@ static size_t init_iov_iter(struct vhost_virtqueue *vq, struct iov_iter *iter, ...@@ -604,12 +604,6 @@ static size_t init_iov_iter(struct vhost_virtqueue *vq, struct iov_iter *iter,
return iov_iter_count(iter); return iov_iter_count(iter);
} }
/* Report whether this TX/RX run has used up its processing budget:
 * true once either the byte budget (VHOST_NET_WEIGHT) or the packet
 * budget (VHOST_NET_PKT_WEIGHT) is reached, so the caller can requeue
 * the work and stop one virtqueue from starving the others.
 */
static bool vhost_exceeds_weight(int pkts, int total_len)
{
	if (pkts >= VHOST_NET_PKT_WEIGHT)
		return true;

	return total_len >= VHOST_NET_WEIGHT;
}
static int get_tx_bufs(struct vhost_net *net, static int get_tx_bufs(struct vhost_net *net,
struct vhost_net_virtqueue *nvq, struct vhost_net_virtqueue *nvq,
struct msghdr *msg, struct msghdr *msg,
...@@ -779,7 +773,7 @@ static void handle_tx_copy(struct vhost_net *net, struct socket *sock) ...@@ -779,7 +773,7 @@ static void handle_tx_copy(struct vhost_net *net, struct socket *sock)
int sent_pkts = 0; int sent_pkts = 0;
bool sock_can_batch = (sock->sk->sk_sndbuf == INT_MAX); bool sock_can_batch = (sock->sk->sk_sndbuf == INT_MAX);
for (;;) { do {
bool busyloop_intr = false; bool busyloop_intr = false;
if (nvq->done_idx == VHOST_NET_BATCH) if (nvq->done_idx == VHOST_NET_BATCH)
...@@ -845,11 +839,7 @@ static void handle_tx_copy(struct vhost_net *net, struct socket *sock) ...@@ -845,11 +839,7 @@ static void handle_tx_copy(struct vhost_net *net, struct socket *sock)
vq->heads[nvq->done_idx].id = cpu_to_vhost32(vq, head); vq->heads[nvq->done_idx].id = cpu_to_vhost32(vq, head);
vq->heads[nvq->done_idx].len = 0; vq->heads[nvq->done_idx].len = 0;
++nvq->done_idx; ++nvq->done_idx;
if (vhost_exceeds_weight(++sent_pkts, total_len)) { } while (likely(!vhost_exceeds_weight(vq, ++sent_pkts, total_len)));
vhost_poll_queue(&vq->poll);
break;
}
}
vhost_tx_batch(net, nvq, sock, &msg); vhost_tx_batch(net, nvq, sock, &msg);
} }
...@@ -874,7 +864,7 @@ static void handle_tx_zerocopy(struct vhost_net *net, struct socket *sock) ...@@ -874,7 +864,7 @@ static void handle_tx_zerocopy(struct vhost_net *net, struct socket *sock)
bool zcopy_used; bool zcopy_used;
int sent_pkts = 0; int sent_pkts = 0;
for (;;) { do {
bool busyloop_intr; bool busyloop_intr;
/* Release DMAs done buffers first */ /* Release DMAs done buffers first */
...@@ -951,11 +941,7 @@ static void handle_tx_zerocopy(struct vhost_net *net, struct socket *sock) ...@@ -951,11 +941,7 @@ static void handle_tx_zerocopy(struct vhost_net *net, struct socket *sock)
else else
vhost_zerocopy_signal_used(net, vq); vhost_zerocopy_signal_used(net, vq);
vhost_net_tx_packet(net); vhost_net_tx_packet(net);
if (unlikely(vhost_exceeds_weight(++sent_pkts, total_len))) { } while (likely(!vhost_exceeds_weight(vq, ++sent_pkts, total_len)));
vhost_poll_queue(&vq->poll);
break;
}
}
} }
/* Expects to be always run from workqueue - which acts as /* Expects to be always run from workqueue - which acts as
...@@ -1153,8 +1139,11 @@ static void handle_rx(struct vhost_net *net) ...@@ -1153,8 +1139,11 @@ static void handle_rx(struct vhost_net *net)
vq->log : NULL; vq->log : NULL;
mergeable = vhost_has_feature(vq, VIRTIO_NET_F_MRG_RXBUF); mergeable = vhost_has_feature(vq, VIRTIO_NET_F_MRG_RXBUF);
while ((sock_len = vhost_net_rx_peek_head_len(net, sock->sk, do {
&busyloop_intr))) { sock_len = vhost_net_rx_peek_head_len(net, sock->sk,
&busyloop_intr);
if (!sock_len)
break;
sock_len += sock_hlen; sock_len += sock_hlen;
vhost_len = sock_len + vhost_hlen; vhost_len = sock_len + vhost_hlen;
headcount = get_rx_bufs(vq, vq->heads + nvq->done_idx, headcount = get_rx_bufs(vq, vq->heads + nvq->done_idx,
...@@ -1239,14 +1228,11 @@ static void handle_rx(struct vhost_net *net) ...@@ -1239,14 +1228,11 @@ static void handle_rx(struct vhost_net *net)
vhost_log_write(vq, vq_log, log, vhost_len, vhost_log_write(vq, vq_log, log, vhost_len,
vq->iov, in); vq->iov, in);
total_len += vhost_len; total_len += vhost_len;
if (unlikely(vhost_exceeds_weight(++recv_pkts, total_len))) { } while (likely(!vhost_exceeds_weight(vq, ++recv_pkts, total_len)));
vhost_poll_queue(&vq->poll);
goto out;
}
}
if (unlikely(busyloop_intr)) if (unlikely(busyloop_intr))
vhost_poll_queue(&vq->poll); vhost_poll_queue(&vq->poll);
else else if (!sock_len)
vhost_net_enable_vq(net, vq); vhost_net_enable_vq(net, vq);
out: out:
vhost_net_signal_used(nvq); vhost_net_signal_used(nvq);
...@@ -1338,7 +1324,8 @@ static int vhost_net_open(struct inode *inode, struct file *f) ...@@ -1338,7 +1324,8 @@ static int vhost_net_open(struct inode *inode, struct file *f)
vhost_net_buf_init(&n->vqs[i].rxq); vhost_net_buf_init(&n->vqs[i].rxq);
} }
vhost_dev_init(dev, vqs, VHOST_NET_VQ_MAX, vhost_dev_init(dev, vqs, VHOST_NET_VQ_MAX,
UIO_MAXIOV + VHOST_NET_BATCH); UIO_MAXIOV + VHOST_NET_BATCH,
VHOST_NET_PKT_WEIGHT, VHOST_NET_WEIGHT);
vhost_poll_init(n->poll + VHOST_NET_VQ_TX, handle_tx_net, EPOLLOUT, dev); vhost_poll_init(n->poll + VHOST_NET_VQ_TX, handle_tx_net, EPOLLOUT, dev);
vhost_poll_init(n->poll + VHOST_NET_VQ_RX, handle_rx_net, EPOLLIN, dev); vhost_poll_init(n->poll + VHOST_NET_VQ_RX, handle_rx_net, EPOLLIN, dev);
......
...@@ -57,6 +57,12 @@ ...@@ -57,6 +57,12 @@
#define VHOST_SCSI_PREALLOC_UPAGES 2048 #define VHOST_SCSI_PREALLOC_UPAGES 2048
#define VHOST_SCSI_PREALLOC_PROT_SGLS 2048 #define VHOST_SCSI_PREALLOC_PROT_SGLS 2048
/* Max number of requests before requeueing the job.
* Using this limit prevents one virtqueue from starving others with
* request.
*/
#define VHOST_SCSI_WEIGHT 256
struct vhost_scsi_inflight { struct vhost_scsi_inflight {
/* Wait for the flush operation to finish */ /* Wait for the flush operation to finish */
struct completion comp; struct completion comp;
...@@ -912,7 +918,7 @@ vhost_scsi_handle_vq(struct vhost_scsi *vs, struct vhost_virtqueue *vq) ...@@ -912,7 +918,7 @@ vhost_scsi_handle_vq(struct vhost_scsi *vs, struct vhost_virtqueue *vq)
struct iov_iter in_iter, prot_iter, data_iter; struct iov_iter in_iter, prot_iter, data_iter;
u64 tag; u64 tag;
u32 exp_data_len, data_direction; u32 exp_data_len, data_direction;
int ret, prot_bytes; int ret, prot_bytes, c = 0;
u16 lun; u16 lun;
u8 task_attr; u8 task_attr;
bool t10_pi = vhost_has_feature(vq, VIRTIO_SCSI_F_T10_PI); bool t10_pi = vhost_has_feature(vq, VIRTIO_SCSI_F_T10_PI);
...@@ -932,7 +938,7 @@ vhost_scsi_handle_vq(struct vhost_scsi *vs, struct vhost_virtqueue *vq) ...@@ -932,7 +938,7 @@ vhost_scsi_handle_vq(struct vhost_scsi *vs, struct vhost_virtqueue *vq)
vhost_disable_notify(&vs->dev, vq); vhost_disable_notify(&vs->dev, vq);
for (;;) { do {
ret = vhost_scsi_get_desc(vs, vq, &vc); ret = vhost_scsi_get_desc(vs, vq, &vc);
if (ret) if (ret)
goto err; goto err;
...@@ -1112,7 +1118,7 @@ vhost_scsi_handle_vq(struct vhost_scsi *vs, struct vhost_virtqueue *vq) ...@@ -1112,7 +1118,7 @@ vhost_scsi_handle_vq(struct vhost_scsi *vs, struct vhost_virtqueue *vq)
break; break;
else if (ret == -EIO) else if (ret == -EIO)
vhost_scsi_send_bad_target(vs, vq, vc.head, vc.out); vhost_scsi_send_bad_target(vs, vq, vc.head, vc.out);
} } while (likely(!vhost_exceeds_weight(vq, ++c, 0)));
out: out:
mutex_unlock(&vq->mutex); mutex_unlock(&vq->mutex);
} }
...@@ -1171,7 +1177,7 @@ vhost_scsi_ctl_handle_vq(struct vhost_scsi *vs, struct vhost_virtqueue *vq) ...@@ -1171,7 +1177,7 @@ vhost_scsi_ctl_handle_vq(struct vhost_scsi *vs, struct vhost_virtqueue *vq)
} v_req; } v_req;
struct vhost_scsi_ctx vc; struct vhost_scsi_ctx vc;
size_t typ_size; size_t typ_size;
int ret; int ret, c = 0;
mutex_lock(&vq->mutex); mutex_lock(&vq->mutex);
/* /*
...@@ -1185,7 +1191,7 @@ vhost_scsi_ctl_handle_vq(struct vhost_scsi *vs, struct vhost_virtqueue *vq) ...@@ -1185,7 +1191,7 @@ vhost_scsi_ctl_handle_vq(struct vhost_scsi *vs, struct vhost_virtqueue *vq)
vhost_disable_notify(&vs->dev, vq); vhost_disable_notify(&vs->dev, vq);
for (;;) { do {
ret = vhost_scsi_get_desc(vs, vq, &vc); ret = vhost_scsi_get_desc(vs, vq, &vc);
if (ret) if (ret)
goto err; goto err;
...@@ -1264,7 +1270,7 @@ vhost_scsi_ctl_handle_vq(struct vhost_scsi *vs, struct vhost_virtqueue *vq) ...@@ -1264,7 +1270,7 @@ vhost_scsi_ctl_handle_vq(struct vhost_scsi *vs, struct vhost_virtqueue *vq)
break; break;
else if (ret == -EIO) else if (ret == -EIO)
vhost_scsi_send_bad_target(vs, vq, vc.head, vc.out); vhost_scsi_send_bad_target(vs, vq, vc.head, vc.out);
} } while (likely(!vhost_exceeds_weight(vq, ++c, 0)));
out: out:
mutex_unlock(&vq->mutex); mutex_unlock(&vq->mutex);
} }
...@@ -1621,7 +1627,8 @@ static int vhost_scsi_open(struct inode *inode, struct file *f) ...@@ -1621,7 +1627,8 @@ static int vhost_scsi_open(struct inode *inode, struct file *f)
vqs[i] = &vs->vqs[i].vq; vqs[i] = &vs->vqs[i].vq;
vs->vqs[i].vq.handle_kick = vhost_scsi_handle_kick; vs->vqs[i].vq.handle_kick = vhost_scsi_handle_kick;
} }
vhost_dev_init(&vs->dev, vqs, VHOST_SCSI_MAX_VQ, UIO_MAXIOV); vhost_dev_init(&vs->dev, vqs, VHOST_SCSI_MAX_VQ, UIO_MAXIOV,
VHOST_SCSI_WEIGHT, 0);
vhost_scsi_init_inflight(vs, NULL); vhost_scsi_init_inflight(vs, NULL);
......
...@@ -413,8 +413,24 @@ static void vhost_dev_free_iovecs(struct vhost_dev *dev) ...@@ -413,8 +413,24 @@ static void vhost_dev_free_iovecs(struct vhost_dev *dev)
vhost_vq_free_iovecs(dev->vqs[i]); vhost_vq_free_iovecs(dev->vqs[i]);
} }
/* Check whether a virtqueue handler has consumed its processing budget.
 *
 * Compares the running packet count and byte count against the per-device
 * limits set at vhost_dev_init() time (a byte_weight of 0 disables the
 * byte limit).  When a limit is hit, the vq is queued for another worker
 * pass via vhost_poll_queue() and true is returned so the caller breaks
 * out of its service loop, preventing one vq from starving the others.
 */
bool vhost_exceeds_weight(struct vhost_virtqueue *vq,
			  int pkts, int total_len)
{
	struct vhost_dev *dev = vq->dev;
	bool bytes_done = dev->byte_weight && total_len >= dev->byte_weight;

	if (!bytes_done && pkts < dev->weight)
		return false;

	/* Budget exhausted: reschedule the vq and tell the caller to stop. */
	vhost_poll_queue(&vq->poll);
	return true;
}
EXPORT_SYMBOL_GPL(vhost_exceeds_weight);
void vhost_dev_init(struct vhost_dev *dev, void vhost_dev_init(struct vhost_dev *dev,
struct vhost_virtqueue **vqs, int nvqs, int iov_limit) struct vhost_virtqueue **vqs, int nvqs,
int iov_limit, int weight, int byte_weight)
{ {
struct vhost_virtqueue *vq; struct vhost_virtqueue *vq;
int i; int i;
...@@ -428,6 +444,8 @@ void vhost_dev_init(struct vhost_dev *dev, ...@@ -428,6 +444,8 @@ void vhost_dev_init(struct vhost_dev *dev,
dev->mm = NULL; dev->mm = NULL;
dev->worker = NULL; dev->worker = NULL;
dev->iov_limit = iov_limit; dev->iov_limit = iov_limit;
dev->weight = weight;
dev->byte_weight = byte_weight;
init_llist_head(&dev->work_list); init_llist_head(&dev->work_list);
init_waitqueue_head(&dev->wait); init_waitqueue_head(&dev->wait);
INIT_LIST_HEAD(&dev->read_list); INIT_LIST_HEAD(&dev->read_list);
......
...@@ -171,10 +171,13 @@ struct vhost_dev { ...@@ -171,10 +171,13 @@ struct vhost_dev {
struct list_head pending_list; struct list_head pending_list;
wait_queue_head_t wait; wait_queue_head_t wait;
int iov_limit; int iov_limit;
int weight;
int byte_weight;
}; };
bool vhost_exceeds_weight(struct vhost_virtqueue *vq, int pkts, int total_len);
void vhost_dev_init(struct vhost_dev *, struct vhost_virtqueue **vqs, void vhost_dev_init(struct vhost_dev *, struct vhost_virtqueue **vqs,
int nvqs, int iov_limit); int nvqs, int iov_limit, int weight, int byte_weight);
long vhost_dev_set_owner(struct vhost_dev *dev); long vhost_dev_set_owner(struct vhost_dev *dev);
bool vhost_dev_has_owner(struct vhost_dev *dev); bool vhost_dev_has_owner(struct vhost_dev *dev);
long vhost_dev_check_owner(struct vhost_dev *); long vhost_dev_check_owner(struct vhost_dev *);
......
...@@ -21,6 +21,14 @@ ...@@ -21,6 +21,14 @@
#include "vhost.h" #include "vhost.h"
#define VHOST_VSOCK_DEFAULT_HOST_CID 2 #define VHOST_VSOCK_DEFAULT_HOST_CID 2
/* Max number of bytes transferred before requeueing the job.
* Using this limit prevents one virtqueue from starving others. */
#define VHOST_VSOCK_WEIGHT 0x80000
/* Max number of packets transferred before requeueing the job.
* Using this limit prevents one virtqueue from starving others with
* small pkts.
*/
#define VHOST_VSOCK_PKT_WEIGHT 256
enum { enum {
VHOST_VSOCK_FEATURES = VHOST_FEATURES, VHOST_VSOCK_FEATURES = VHOST_FEATURES,
...@@ -78,6 +86,7 @@ vhost_transport_do_send_pkt(struct vhost_vsock *vsock, ...@@ -78,6 +86,7 @@ vhost_transport_do_send_pkt(struct vhost_vsock *vsock,
struct vhost_virtqueue *vq) struct vhost_virtqueue *vq)
{ {
struct vhost_virtqueue *tx_vq = &vsock->vqs[VSOCK_VQ_TX]; struct vhost_virtqueue *tx_vq = &vsock->vqs[VSOCK_VQ_TX];
int pkts = 0, total_len = 0;
bool added = false; bool added = false;
bool restart_tx = false; bool restart_tx = false;
...@@ -89,7 +98,7 @@ vhost_transport_do_send_pkt(struct vhost_vsock *vsock, ...@@ -89,7 +98,7 @@ vhost_transport_do_send_pkt(struct vhost_vsock *vsock,
/* Avoid further vmexits, we're already processing the virtqueue */ /* Avoid further vmexits, we're already processing the virtqueue */
vhost_disable_notify(&vsock->dev, vq); vhost_disable_notify(&vsock->dev, vq);
for (;;) { do {
struct virtio_vsock_pkt *pkt; struct virtio_vsock_pkt *pkt;
struct iov_iter iov_iter; struct iov_iter iov_iter;
unsigned out, in; unsigned out, in;
...@@ -174,8 +183,9 @@ vhost_transport_do_send_pkt(struct vhost_vsock *vsock, ...@@ -174,8 +183,9 @@ vhost_transport_do_send_pkt(struct vhost_vsock *vsock,
*/ */
virtio_transport_deliver_tap_pkt(pkt); virtio_transport_deliver_tap_pkt(pkt);
total_len += pkt->len;
virtio_transport_free_pkt(pkt); virtio_transport_free_pkt(pkt);
} } while(likely(!vhost_exceeds_weight(vq, ++pkts, total_len)));
if (added) if (added)
vhost_signal(&vsock->dev, vq); vhost_signal(&vsock->dev, vq);
...@@ -350,7 +360,7 @@ static void vhost_vsock_handle_tx_kick(struct vhost_work *work) ...@@ -350,7 +360,7 @@ static void vhost_vsock_handle_tx_kick(struct vhost_work *work)
struct vhost_vsock *vsock = container_of(vq->dev, struct vhost_vsock, struct vhost_vsock *vsock = container_of(vq->dev, struct vhost_vsock,
dev); dev);
struct virtio_vsock_pkt *pkt; struct virtio_vsock_pkt *pkt;
int head; int head, pkts = 0, total_len = 0;
unsigned int out, in; unsigned int out, in;
bool added = false; bool added = false;
...@@ -360,7 +370,7 @@ static void vhost_vsock_handle_tx_kick(struct vhost_work *work) ...@@ -360,7 +370,7 @@ static void vhost_vsock_handle_tx_kick(struct vhost_work *work)
goto out; goto out;
vhost_disable_notify(&vsock->dev, vq); vhost_disable_notify(&vsock->dev, vq);
for (;;) { do {
u32 len; u32 len;
if (!vhost_vsock_more_replies(vsock)) { if (!vhost_vsock_more_replies(vsock)) {
...@@ -401,9 +411,11 @@ static void vhost_vsock_handle_tx_kick(struct vhost_work *work) ...@@ -401,9 +411,11 @@ static void vhost_vsock_handle_tx_kick(struct vhost_work *work)
else else
virtio_transport_free_pkt(pkt); virtio_transport_free_pkt(pkt);
vhost_add_used(vq, head, sizeof(pkt->hdr) + len); len += sizeof(pkt->hdr);
vhost_add_used(vq, head, len);
total_len += len;
added = true; added = true;
} } while(likely(!vhost_exceeds_weight(vq, ++pkts, total_len)));
no_more_replies: no_more_replies:
if (added) if (added)
...@@ -531,7 +543,9 @@ static int vhost_vsock_dev_open(struct inode *inode, struct file *file) ...@@ -531,7 +543,9 @@ static int vhost_vsock_dev_open(struct inode *inode, struct file *file)
vsock->vqs[VSOCK_VQ_TX].handle_kick = vhost_vsock_handle_tx_kick; vsock->vqs[VSOCK_VQ_TX].handle_kick = vhost_vsock_handle_tx_kick;
vsock->vqs[VSOCK_VQ_RX].handle_kick = vhost_vsock_handle_rx_kick; vsock->vqs[VSOCK_VQ_RX].handle_kick = vhost_vsock_handle_rx_kick;
vhost_dev_init(&vsock->dev, vqs, ARRAY_SIZE(vsock->vqs), UIO_MAXIOV); vhost_dev_init(&vsock->dev, vqs, ARRAY_SIZE(vsock->vqs),
UIO_MAXIOV, VHOST_VSOCK_PKT_WEIGHT,
VHOST_VSOCK_WEIGHT);
file->private_data = vsock; file->private_data = vsock;
spin_lock_init(&vsock->send_pkt_list_lock); spin_lock_init(&vsock->send_pkt_list_lock);
......
...@@ -63,12 +63,12 @@ config VIRTIO_INPUT ...@@ -63,12 +63,12 @@ config VIRTIO_INPUT
If unsure, say M. If unsure, say M.
config VIRTIO_MMIO config VIRTIO_MMIO
tristate "Platform bus driver for memory mapped virtio devices" tristate "Platform bus driver for memory mapped virtio devices"
depends on HAS_IOMEM && HAS_DMA depends on HAS_IOMEM && HAS_DMA
select VIRTIO select VIRTIO
---help--- ---help---
This drivers provides support for memory mapped virtio This drivers provides support for memory mapped virtio
platform device driver. platform device driver.
If unsure, say N. If unsure, say N.
......
...@@ -127,7 +127,7 @@ static inline void free_page(unsigned long addr) ...@@ -127,7 +127,7 @@ static inline void free_page(unsigned long addr)
#define dev_err(dev, format, ...) fprintf (stderr, format, ## __VA_ARGS__) #define dev_err(dev, format, ...) fprintf (stderr, format, ## __VA_ARGS__)
#define dev_warn(dev, format, ...) fprintf (stderr, format, ## __VA_ARGS__) #define dev_warn(dev, format, ...) fprintf (stderr, format, ## __VA_ARGS__)
#define WARN_ON_ONCE(cond) ((cond) ? fprintf (stderr, "WARNING\n") : 0) #define WARN_ON_ONCE(cond) (unlikely(cond) ? fprintf (stderr, "WARNING\n") : 0)
#define min(x, y) ({ \ #define min(x, y) ({ \
typeof(x) _min1 = (x); \ typeof(x) _min1 = (x); \
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment