Commit 941e3e79 authored by Linus Torvalds

Merge tag 'for_linus' of git://git.kernel.org/pub/scm/linux/kernel/git/mst/vhost

Pull virtio fixes from Michael Tsirkin:
 "Fixes all over the place, most notably we are disabling
  IRQ hardening (again!)"

* tag 'for_linus' of git://git.kernel.org/pub/scm/linux/kernel/git/mst/vhost:
  virtio_ring: make vring_create_virtqueue_split prettier
  vhost-vdpa: call vhost_vdpa_cleanup during the release
  virtio_mmio: Restore guest page size on resume
  virtio_mmio: Add missing PM calls to freeze/restore
  caif_virtio: fix race between virtio_device_ready() and ndo_open()
  virtio-net: fix race between ndo_open() and virtio_device_ready()
  virtio: disable notification hardening by default
  virtio: Remove unnecessary variable assignments
  virtio_ring: keep used_wrap_counter in vq->last_used_idx
  vduse: Tie vduse mgmtdev and its device
  vdpa/mlx5: Initialize CVQ vringh only once
  vdpa/mlx5: Update Control VQ callback information
parents 23900951 c7cc29aa
@@ -722,13 +722,21 @@ static int cfv_probe(struct virtio_device *vdev)
 	/* Carrier is off until netdevice is opened */
 	netif_carrier_off(netdev);
 
+	/* serialize netdev register + virtio_device_ready() with ndo_open() */
+	rtnl_lock();
+
 	/* register Netdev */
-	err = register_netdev(netdev);
+	err = register_netdevice(netdev);
 	if (err) {
+		rtnl_unlock();
 		dev_err(&vdev->dev, "Unable to register netdev (%d)\n", err);
 		goto err;
 	}
 
+	virtio_device_ready(vdev);
+
+	rtnl_unlock();
+
 	debugfs_init(cfv);
 
 	return 0;
...
@@ -3642,14 +3642,20 @@ static int virtnet_probe(struct virtio_device *vdev)
 	if (vi->has_rss || vi->has_rss_hash_report)
 		virtnet_init_default_rss(vi);
 
-	err = register_netdev(dev);
+	/* serialize netdev register + virtio_device_ready() with ndo_open() */
+	rtnl_lock();
+
+	err = register_netdevice(dev);
 	if (err) {
 		pr_debug("virtio_net: registering device failed\n");
+		rtnl_unlock();
 		goto free_failover;
 	}
 
 	virtio_device_ready(vdev);
 
+	rtnl_unlock();
+
 	err = virtnet_cpu_notif_add(vi);
 	if (err) {
 		pr_debug("virtio_net: registering cpu notifier failed\n");
...
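Both hunks above close the same race: after register_netdev(), userspace can bring the interface up, so ndo_open() may enable the virtqueues before virtio_device_ready() has set DRIVER_OK. Since ndo_open() always runs under RTNL, taking rtnl_lock() around the RTNL-held registration variant register_netdevice() plus virtio_device_ready() makes the two steps atomic with respect to it. A condensed sketch of the pattern (my_probe() and its arguments are illustrative, not code from the patch):

#include <linux/netdevice.h>	/* register_netdevice() */
#include <linux/rtnetlink.h>	/* rtnl_lock()/rtnl_unlock() */
#include <linux/virtio_config.h>	/* virtio_device_ready() */

/* Sketch only: the locking pattern used by the two hunks above. */
static int my_probe(struct virtio_device *vdev, struct net_device *netdev)
{
	int err;

	/* serialize netdev register + virtio_device_ready() with ndo_open() */
	rtnl_lock();
	err = register_netdevice(netdev); /* RTNL-held flavor of register_netdev() */
	if (err) {
		rtnl_unlock();
		return err;
	}
	virtio_device_ready(vdev);	/* DRIVER_OK set before ndo_open() can run */
	rtnl_unlock();
	return 0;
}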
@@ -1136,8 +1136,13 @@ static void virtio_ccw_int_handler(struct ccw_device *cdev,
 			vcdev->err = -EIO;
 	}
 	virtio_ccw_check_activity(vcdev, activity);
-	/* Interrupts are disabled here */
+#ifdef CONFIG_VIRTIO_HARDEN_NOTIFICATION
+	/*
+	 * Paired with virtio_ccw_synchronize_cbs() and interrupts are
+	 * disabled here.
+	 */
 	read_lock(&vcdev->irq_lock);
+#endif
 	for_each_set_bit(i, indicators(vcdev),
 			 sizeof(*indicators(vcdev)) * BITS_PER_BYTE) {
 		/* The bit clear must happen before the vring kick. */
@@ -1146,7 +1151,9 @@ static void virtio_ccw_int_handler(struct ccw_device *cdev,
 		vq = virtio_ccw_vq_by_ind(vcdev, i);
 		vring_interrupt(0, vq);
 	}
+#ifdef CONFIG_VIRTIO_HARDEN_NOTIFICATION
 	read_unlock(&vcdev->irq_lock);
+#endif
 	if (test_bit(0, indicators2(vcdev))) {
 		virtio_config_changed(&vcdev->vdev);
 		clear_bit(0, indicators2(vcdev));
...
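The new comment points at a lock pairing rather than plain IRQ disabling: the interrupt handler dispatches vring_interrupt() under the read side of irq_lock, and the driver's synchronize_cbs() implementation takes the write side to drain in-flight callbacks. A minimal sketch of that write-side counterpart, assuming an rwlock-based synchronize_cbs() like the one virtio-ccw pairs this with (function body shown for illustration):

/* Sketch: the write-side pairing for the read_lock() above. Assumes
 * irq_lock is an rwlock_t and that the interrupt handler only calls
 * vring_interrupt() while holding read_lock(&vcdev->irq_lock). */
static void my_synchronize_cbs(struct virtio_device *vdev)
{
	struct virtio_ccw_device *vcdev = to_vc_device(vdev);
	unsigned long flags;

	/* Acquiring the write lock waits for every reader, i.e. for any
	 * vring_interrupt() dispatch still running; once it is dropped,
	 * no callback that started before this call can be in flight. */
	write_lock_irqsave(&vcdev->irq_lock, flags);
	write_unlock_irqrestore(&vcdev->irq_lock, flags);
}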
@@ -1962,6 +1962,8 @@ static void mlx5_vdpa_set_vq_cb(struct vdpa_device *vdev, u16 idx, struct vdpa_c
 	struct mlx5_vdpa_net *ndev = to_mlx5_vdpa_ndev(mvdev);
 
 	ndev->event_cbs[idx] = *cb;
+	if (is_ctrl_vq_idx(mvdev, idx))
+		mvdev->cvq.event_cb = *cb;
 }
 
 static void mlx5_cvq_notify(struct vringh *vring)
@@ -2174,7 +2176,6 @@ static int verify_driver_features(struct mlx5_vdpa_dev *mvdev, u64 features)
 static int setup_virtqueues(struct mlx5_vdpa_dev *mvdev)
 {
 	struct mlx5_vdpa_net *ndev = to_mlx5_vdpa_ndev(mvdev);
-	struct mlx5_control_vq *cvq = &mvdev->cvq;
 	int err;
 	int i;
@@ -2184,16 +2185,6 @@ static int setup_virtqueues(struct mlx5_vdpa_dev *mvdev)
 			goto err_vq;
 	}
 
-	if (mvdev->actual_features & BIT_ULL(VIRTIO_NET_F_CTRL_VQ)) {
-		err = vringh_init_iotlb(&cvq->vring, mvdev->actual_features,
-					MLX5_CVQ_MAX_ENT, false,
-					(struct vring_desc *)(uintptr_t)cvq->desc_addr,
-					(struct vring_avail *)(uintptr_t)cvq->driver_addr,
-					(struct vring_used *)(uintptr_t)cvq->device_addr);
-		if (err)
-			goto err_vq;
-	}
-
 	return 0;
 
 err_vq:
@@ -2466,6 +2457,21 @@ static void clear_vqs_ready(struct mlx5_vdpa_net *ndev)
 	ndev->mvdev.cvq.ready = false;
 }
 
+static int setup_cvq_vring(struct mlx5_vdpa_dev *mvdev)
+{
+	struct mlx5_control_vq *cvq = &mvdev->cvq;
+	int err = 0;
+
+	if (mvdev->actual_features & BIT_ULL(VIRTIO_NET_F_CTRL_VQ))
+		err = vringh_init_iotlb(&cvq->vring, mvdev->actual_features,
+					MLX5_CVQ_MAX_ENT, false,
+					(struct vring_desc *)(uintptr_t)cvq->desc_addr,
+					(struct vring_avail *)(uintptr_t)cvq->driver_addr,
+					(struct vring_used *)(uintptr_t)cvq->device_addr);
+
+	return err;
+}
+
 static void mlx5_vdpa_set_status(struct vdpa_device *vdev, u8 status)
 {
 	struct mlx5_vdpa_dev *mvdev = to_mvdev(vdev);
@@ -2478,6 +2484,11 @@ static void mlx5_vdpa_set_status(struct vdpa_device *vdev, u8 status)
 	if ((status ^ ndev->mvdev.status) & VIRTIO_CONFIG_S_DRIVER_OK) {
 		if (status & VIRTIO_CONFIG_S_DRIVER_OK) {
+			err = setup_cvq_vring(mvdev);
+			if (err) {
+				mlx5_vdpa_warn(mvdev, "failed to setup control VQ vring\n");
+				goto err_setup;
+			}
 			err = setup_driver(mvdev);
 			if (err) {
 				mlx5_vdpa_warn(mvdev, "failed to setup driver\n");
...
@@ -1476,16 +1476,12 @@ static char *vduse_devnode(struct device *dev, umode_t *mode)
 	return kasprintf(GFP_KERNEL, "vduse/%s", dev_name(dev));
 }
 
-static void vduse_mgmtdev_release(struct device *dev)
-{
-}
-
-static struct device vduse_mgmtdev = {
-	.init_name = "vduse",
-	.release = vduse_mgmtdev_release,
+struct vduse_mgmt_dev {
+	struct vdpa_mgmt_dev mgmt_dev;
+	struct device dev;
 };
 
-static struct vdpa_mgmt_dev mgmt_dev;
+static struct vduse_mgmt_dev *vduse_mgmt;
 
 static int vduse_dev_init_vdpa(struct vduse_dev *dev, const char *name)
 {
@@ -1510,7 +1506,7 @@ static int vduse_dev_init_vdpa(struct vduse_dev *dev, const char *name)
 	}
 	set_dma_ops(&vdev->vdpa.dev, &vduse_dev_dma_ops);
 	vdev->vdpa.dma_dev = &vdev->vdpa.dev;
-	vdev->vdpa.mdev = &mgmt_dev;
+	vdev->vdpa.mdev = &vduse_mgmt->mgmt_dev;
 
 	return 0;
 }
@@ -1556,34 +1552,52 @@ static struct virtio_device_id id_table[] = {
 	{ 0 },
 };
 
-static struct vdpa_mgmt_dev mgmt_dev = {
-	.device = &vduse_mgmtdev,
-	.id_table = id_table,
-	.ops = &vdpa_dev_mgmtdev_ops,
-};
+static void vduse_mgmtdev_release(struct device *dev)
+{
+	struct vduse_mgmt_dev *mgmt_dev;
+
+	mgmt_dev = container_of(dev, struct vduse_mgmt_dev, dev);
+	kfree(mgmt_dev);
+}
 
 static int vduse_mgmtdev_init(void)
 {
 	int ret;
 
-	ret = device_register(&vduse_mgmtdev);
-	if (ret)
+	vduse_mgmt = kzalloc(sizeof(*vduse_mgmt), GFP_KERNEL);
+	if (!vduse_mgmt)
+		return -ENOMEM;
+
+	ret = dev_set_name(&vduse_mgmt->dev, "vduse");
+	if (ret) {
+		kfree(vduse_mgmt);
 		return ret;
+	}
 
-	ret = vdpa_mgmtdev_register(&mgmt_dev);
+	vduse_mgmt->dev.release = vduse_mgmtdev_release;
+
+	ret = device_register(&vduse_mgmt->dev);
 	if (ret)
-		goto err;
+		goto dev_reg_err;
 
-	return 0;
-err:
-	device_unregister(&vduse_mgmtdev);
+	vduse_mgmt->mgmt_dev.id_table = id_table;
+	vduse_mgmt->mgmt_dev.ops = &vdpa_dev_mgmtdev_ops;
+	vduse_mgmt->mgmt_dev.device = &vduse_mgmt->dev;
+	ret = vdpa_mgmtdev_register(&vduse_mgmt->mgmt_dev);
+	if (ret)
+		device_unregister(&vduse_mgmt->dev);
+
+	return ret;
+
+dev_reg_err:
+	put_device(&vduse_mgmt->dev);
 	return ret;
 }
 
 static void vduse_mgmtdev_exit(void)
 {
-	vdpa_mgmtdev_unregister(&mgmt_dev);
-	device_unregister(&vduse_mgmtdev);
+	vdpa_mgmtdev_unregister(&vduse_mgmt->mgmt_dev);
+	device_unregister(&vduse_mgmt->dev);
 }
 
 static int vduse_init(void)
...
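The rewrite above follows the standard driver-core lifetime rule: a registered struct device may only be freed from its release() callback, because other code can still hold references to it after unregistration. Embedding the mgmtdev and the device in one allocation and recovering it with container_of() ties the mgmtdev's lifetime to the device's refcount. A condensed sketch of the pattern (struct names reused from the hunk; mgmtdev_release()/mgmtdev_init() are illustrative and most error handling is trimmed):

/* Sketch: one allocation, freed only from the release() callback. */
struct vduse_mgmt_dev {
	struct vdpa_mgmt_dev mgmt_dev;
	struct device dev;
};

static void mgmtdev_release(struct device *dev)
{
	/* Recover the enclosing allocation and free everything at once. */
	kfree(container_of(dev, struct vduse_mgmt_dev, dev));
}

static int mgmtdev_init(void)
{
	struct vduse_mgmt_dev *m = kzalloc(sizeof(*m), GFP_KERNEL);

	if (!m)
		return -ENOMEM;
	dev_set_name(&m->dev, "vduse");	/* error handling trimmed */
	m->dev.release = mgmtdev_release;
	if (device_register(&m->dev)) {
		/* Once device_register() has run, only put_device() may
		 * trigger the free; a kfree() here would double-free when
		 * release() eventually fires. */
		put_device(&m->dev);
		return -ENODEV;
	}
	return 0;
}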
@@ -1209,7 +1209,7 @@ static int vhost_vdpa_release(struct inode *inode, struct file *filep)
 	vhost_dev_stop(&v->vdev);
 	vhost_vdpa_free_domain(v);
 	vhost_vdpa_config_put(v);
-	vhost_dev_cleanup(&v->vdev);
+	vhost_vdpa_cleanup(v);
 	mutex_unlock(&d->mutex);
 
 	atomic_dec(&v->opened);
...
@@ -29,6 +29,19 @@ menuconfig VIRTIO_MENU
 
 if VIRTIO_MENU
 
+config VIRTIO_HARDEN_NOTIFICATION
+	bool "Harden virtio notification"
+	help
+	  Enable this to harden the device notifications and suppress
+	  those that happen at a time where notifications are illegal.
+
+	  Experimental: Note that several drivers still have bugs that
+	  may cause crashes or hangs when correct handling of
+	  notifications is enforced; depending on the subset of
+	  drivers and devices you use, this may or may not work.
+
+	  If unsure, say N.
+
 config VIRTIO_PCI
 	tristate "PCI driver for virtio devices"
 	depends on PCI
...
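VIRTIO_HARDEN_NOTIFICATION is a bool symbol with no default line, so it defaults to n: the hardening that was previously unconditional is now opt-in. Enabling it for a test build is a one-line config change (sketch; the menu path assumes the "Virtio drivers" menu this file defines):

# .config fragment; also reachable via make menuconfig under
# Device Drivers -> Virtio drivers -> "Harden virtio notification"
CONFIG_VIRTIO_HARDEN_NOTIFICATION=y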
@@ -219,6 +219,7 @@ static int virtio_features_ok(struct virtio_device *dev)
  * */
 void virtio_reset_device(struct virtio_device *dev)
 {
+#ifdef CONFIG_VIRTIO_HARDEN_NOTIFICATION
 	/*
 	 * The below virtio_synchronize_cbs() guarantees that any
 	 * interrupt for this line arriving after
@@ -227,6 +228,7 @@ void virtio_reset_device(struct virtio_device *dev)
 	 */
 	virtio_break_device(dev);
 	virtio_synchronize_cbs(dev);
+#endif
 
 	dev->config->reset(dev);
 }
...
@@ -62,6 +62,7 @@
 #include <linux/list.h>
 #include <linux/module.h>
 #include <linux/platform_device.h>
+#include <linux/pm.h>
 #include <linux/slab.h>
 #include <linux/spinlock.h>
 #include <linux/virtio.h>
@@ -556,6 +557,28 @@ static const struct virtio_config_ops virtio_mmio_config_ops = {
 	.synchronize_cbs = vm_synchronize_cbs,
 };
 
+#ifdef CONFIG_PM_SLEEP
+static int virtio_mmio_freeze(struct device *dev)
+{
+	struct virtio_mmio_device *vm_dev = dev_get_drvdata(dev);
+
+	return virtio_device_freeze(&vm_dev->vdev);
+}
+
+static int virtio_mmio_restore(struct device *dev)
+{
+	struct virtio_mmio_device *vm_dev = dev_get_drvdata(dev);
+
+	if (vm_dev->version == 1)
+		writel(PAGE_SIZE, vm_dev->base + VIRTIO_MMIO_GUEST_PAGE_SIZE);
+
+	return virtio_device_restore(&vm_dev->vdev);
+}
+
+static const struct dev_pm_ops virtio_mmio_pm_ops = {
+	SET_SYSTEM_SLEEP_PM_OPS(virtio_mmio_freeze, virtio_mmio_restore)
+};
+#endif
+
 static void virtio_mmio_release_dev(struct device *_d)
 {
@@ -799,6 +822,9 @@ static struct platform_driver virtio_mmio_driver = {
 		.name	= "virtio-mmio",
 		.of_match_table	= virtio_mmio_match,
 		.acpi_match_table = ACPI_PTR(virtio_mmio_acpi_match),
+#ifdef CONFIG_PM_SLEEP
+		.pm	= &virtio_mmio_pm_ops,
+#endif
 	},
 };
...
@@ -220,8 +220,6 @@ int vp_modern_probe(struct virtio_pci_modern_device *mdev)
 
 	check_offsets();
 
-	mdev->pci_dev = pci_dev;
-
 	/* We only own devices >= 0x1000 and <= 0x107f: leave the rest. */
 	if (pci_dev->device < 0x1000 || pci_dev->device > 0x107f)
 		return -ENODEV;
...
@@ -111,7 +111,12 @@ struct vring_virtqueue {
 	/* Number we've added since last sync. */
 	unsigned int num_added;
 
-	/* Last used index we've seen. */
+	/* Last used index we've seen.
+	 * for split ring, it just contains last used index
+	 * for packed ring:
+	 * bits up to VRING_PACKED_EVENT_F_WRAP_CTR include the last used index.
+	 * bits from VRING_PACKED_EVENT_F_WRAP_CTR include the used wrap counter.
+	 */
 	u16 last_used_idx;
 
 	/* Hint for event idx: already triggered no need to disable. */
@@ -154,9 +159,6 @@ struct vring_virtqueue {
 			/* Driver ring wrap counter. */
 			bool avail_wrap_counter;
 
-			/* Device ring wrap counter. */
-			bool used_wrap_counter;
-
 			/* Avail used flags. */
 			u16 avail_used_flags;
@@ -933,7 +935,7 @@ static struct virtqueue *vring_create_virtqueue_split(
 	for (; num && vring_size(num, vring_align) > PAGE_SIZE; num /= 2) {
 		queue = vring_alloc_queue(vdev, vring_size(num, vring_align),
 					  &dma_addr,
-					  GFP_KERNEL|__GFP_NOWARN|__GFP_ZERO);
+					  GFP_KERNEL | __GFP_NOWARN | __GFP_ZERO);
 		if (queue)
 			break;
 		if (!may_reduce_num)
@@ -973,6 +975,15 @@ static struct virtqueue *vring_create_virtqueue_split(
 /*
  * Packed ring specific functions - *_packed().
  */
+static inline bool packed_used_wrap_counter(u16 last_used_idx)
+{
+	return !!(last_used_idx & (1 << VRING_PACKED_EVENT_F_WRAP_CTR));
+}
+
+static inline u16 packed_last_used(u16 last_used_idx)
+{
+	return last_used_idx & ~(-(1 << VRING_PACKED_EVENT_F_WRAP_CTR));
+}
 
 static void vring_unmap_extra_packed(const struct vring_virtqueue *vq,
 				     struct vring_desc_extra *extra)
@@ -1406,8 +1417,14 @@ static inline bool is_used_desc_packed(const struct vring_virtqueue *vq,
 
 static inline bool more_used_packed(const struct vring_virtqueue *vq)
 {
-	return is_used_desc_packed(vq, vq->last_used_idx,
-			vq->packed.used_wrap_counter);
+	u16 last_used;
+	u16 last_used_idx;
+	bool used_wrap_counter;
+
+	last_used_idx = READ_ONCE(vq->last_used_idx);
+	last_used = packed_last_used(last_used_idx);
+	used_wrap_counter = packed_used_wrap_counter(last_used_idx);
+	return is_used_desc_packed(vq, last_used, used_wrap_counter);
 }
 
 static void *virtqueue_get_buf_ctx_packed(struct virtqueue *_vq,
@@ -1415,7 +1432,8 @@ static void *virtqueue_get_buf_ctx_packed(struct virtqueue *_vq,
 				   void **ctx)
 {
 	struct vring_virtqueue *vq = to_vvq(_vq);
-	u16 last_used, id;
+	u16 last_used, id, last_used_idx;
+	bool used_wrap_counter;
 	void *ret;
 
 	START_USE(vq);
@@ -1434,7 +1452,9 @@ static void *virtqueue_get_buf_ctx_packed(struct virtqueue *_vq,
 	/* Only get used elements after they have been exposed by host. */
 	virtio_rmb(vq->weak_barriers);
 
-	last_used = vq->last_used_idx;
+	last_used_idx = READ_ONCE(vq->last_used_idx);
+	used_wrap_counter = packed_used_wrap_counter(last_used_idx);
+	last_used = packed_last_used(last_used_idx);
 	id = le16_to_cpu(vq->packed.vring.desc[last_used].id);
 	*len = le32_to_cpu(vq->packed.vring.desc[last_used].len);
@@ -1451,12 +1471,15 @@ static void *virtqueue_get_buf_ctx_packed(struct virtqueue *_vq,
 	ret = vq->packed.desc_state[id].data;
 	detach_buf_packed(vq, id, ctx);
 
-	vq->last_used_idx += vq->packed.desc_state[id].num;
-	if (unlikely(vq->last_used_idx >= vq->packed.vring.num)) {
-		vq->last_used_idx -= vq->packed.vring.num;
-		vq->packed.used_wrap_counter ^= 1;
+	last_used += vq->packed.desc_state[id].num;
+	if (unlikely(last_used >= vq->packed.vring.num)) {
+		last_used -= vq->packed.vring.num;
+		used_wrap_counter ^= 1;
 	}
 
+	last_used = (last_used | (used_wrap_counter << VRING_PACKED_EVENT_F_WRAP_CTR));
+	WRITE_ONCE(vq->last_used_idx, last_used);
+
 	/*
 	 * If we expect an interrupt for the next entry, tell host
 	 * by writing event index and flush out the write before
@@ -1465,9 +1488,7 @@ static void *virtqueue_get_buf_ctx_packed(struct virtqueue *_vq,
 	if (vq->packed.event_flags_shadow == VRING_PACKED_EVENT_FLAG_DESC)
 		virtio_store_mb(vq->weak_barriers,
 				&vq->packed.vring.driver->off_wrap,
-				cpu_to_le16(vq->last_used_idx |
-					(vq->packed.used_wrap_counter <<
-					 VRING_PACKED_EVENT_F_WRAP_CTR)));
+				cpu_to_le16(vq->last_used_idx));
 
 	LAST_ADD_TIME_INVALID(vq);
@@ -1499,9 +1520,7 @@ static unsigned int virtqueue_enable_cb_prepare_packed(struct virtqueue *_vq)
 
 	if (vq->event) {
 		vq->packed.vring.driver->off_wrap =
-			cpu_to_le16(vq->last_used_idx |
-				(vq->packed.used_wrap_counter <<
-				 VRING_PACKED_EVENT_F_WRAP_CTR));
+			cpu_to_le16(vq->last_used_idx);
 		/*
 		 * We need to update event offset and event wrap
 		 * counter first before updating event flags.
@@ -1518,8 +1537,7 @@ static unsigned int virtqueue_enable_cb_prepare_packed(struct virtqueue *_vq)
 	}
 
 	END_USE(vq);
-	return vq->last_used_idx | ((u16)vq->packed.used_wrap_counter <<
-			VRING_PACKED_EVENT_F_WRAP_CTR);
+	return vq->last_used_idx;
 }
 
 static bool virtqueue_poll_packed(struct virtqueue *_vq, u16 off_wrap)
@@ -1537,7 +1555,7 @@ static bool virtqueue_poll_packed(struct virtqueue *_vq, u16 off_wrap)
 static bool virtqueue_enable_cb_delayed_packed(struct virtqueue *_vq)
 {
 	struct vring_virtqueue *vq = to_vvq(_vq);
-	u16 used_idx, wrap_counter;
+	u16 used_idx, wrap_counter, last_used_idx;
 	u16 bufs;
 
 	START_USE(vq);
@@ -1550,9 +1568,10 @@ static bool virtqueue_enable_cb_delayed_packed(struct virtqueue *_vq)
 	if (vq->event) {
 		/* TODO: tune this threshold */
 		bufs = (vq->packed.vring.num - vq->vq.num_free) * 3 / 4;
-		wrap_counter = vq->packed.used_wrap_counter;
+		last_used_idx = READ_ONCE(vq->last_used_idx);
+		wrap_counter = packed_used_wrap_counter(last_used_idx);
 
-		used_idx = vq->last_used_idx + bufs;
+		used_idx = packed_last_used(last_used_idx) + bufs;
 		if (used_idx >= vq->packed.vring.num) {
 			used_idx -= vq->packed.vring.num;
 			wrap_counter ^= 1;
@@ -1582,9 +1601,10 @@ static bool virtqueue_enable_cb_delayed_packed(struct virtqueue *_vq)
 	 */
 	virtio_mb(vq->weak_barriers);
 
-	if (is_used_desc_packed(vq,
-				vq->last_used_idx,
-				vq->packed.used_wrap_counter)) {
+	last_used_idx = READ_ONCE(vq->last_used_idx);
+	wrap_counter = packed_used_wrap_counter(last_used_idx);
+	used_idx = packed_last_used(last_used_idx);
+	if (is_used_desc_packed(vq, used_idx, wrap_counter)) {
 		END_USE(vq);
 		return false;
 	}
@@ -1688,8 +1708,12 @@ static struct virtqueue *vring_create_virtqueue_packed(
 	vq->we_own_ring = true;
 	vq->notify = notify;
 	vq->weak_barriers = weak_barriers;
+#ifdef CONFIG_VIRTIO_HARDEN_NOTIFICATION
 	vq->broken = true;
-	vq->last_used_idx = 0;
+#else
+	vq->broken = false;
+#endif
+	vq->last_used_idx = 0 | (1 << VRING_PACKED_EVENT_F_WRAP_CTR);
 	vq->event_triggered = false;
 	vq->num_added = 0;
 	vq->packed_ring = true;
@@ -1720,7 +1744,6 @@ static struct virtqueue *vring_create_virtqueue_packed(
 
 	vq->packed.next_avail_idx = 0;
 	vq->packed.avail_wrap_counter = 1;
-	vq->packed.used_wrap_counter = 1;
 	vq->packed.event_flags_shadow = 0;
 	vq->packed.avail_used_flags = 1 << VRING_PACKED_DESC_F_AVAIL;
@@ -2135,9 +2158,13 @@ irqreturn_t vring_interrupt(int irq, void *_vq)
 	}
 
 	if (unlikely(vq->broken)) {
+#ifdef CONFIG_VIRTIO_HARDEN_NOTIFICATION
 		dev_warn_once(&vq->vq.vdev->dev,
 			      "virtio vring IRQ raised before DRIVER_OK");
 		return IRQ_NONE;
+#else
+		return IRQ_HANDLED;
+#endif
 	}
 
 	/* Just a hint for performance: so it's ok that this can be racy! */
@@ -2180,7 +2207,11 @@ struct virtqueue *__vring_new_virtqueue(unsigned int index,
 	vq->we_own_ring = false;
 	vq->notify = notify;
 	vq->weak_barriers = weak_barriers;
+#ifdef CONFIG_VIRTIO_HARDEN_NOTIFICATION
 	vq->broken = true;
+#else
+	vq->broken = false;
+#endif
	vq->last_used_idx = 0;
 	vq->event_triggered = false;
 	vq->num_added = 0;
...
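The packed-ring rework above folds the used-side wrap counter into bit VRING_PACKED_EVENT_F_WRAP_CTR (15) of last_used_idx, so a reader can snapshot index and counter together with a single READ_ONCE() instead of reading two fields that an interrupt handler could update in between. A standalone illustration of the encoding (userspace C; the constant's value is copied from include/uapi/linux/virtio_ring.h, the helper names mirror the new inlines above):

#include <assert.h>
#include <stdint.h>
#include <stdio.h>

/* From include/uapi/linux/virtio_ring.h */
#define VRING_PACKED_EVENT_F_WRAP_CTR 15

/* Mirrors packed_used_wrap_counter()/packed_last_used() above. */
static int used_wrap_counter(uint16_t last_used_idx)
{
	return !!(last_used_idx & (1 << VRING_PACKED_EVENT_F_WRAP_CTR));
}

static uint16_t last_used(uint16_t last_used_idx)
{
	/* ~(-(1 << 15)) == 0x7fff: the low 15 bits hold the index. */
	return last_used_idx & ~(-(1 << VRING_PACKED_EVENT_F_WRAP_CTR));
}

int main(void)
{
	/* Encode index 5 with wrap counter 1, matching the ring-init
	 * value 0 | (1 << VRING_PACKED_EVENT_F_WRAP_CTR) above. */
	uint16_t idx = 5 | (1u << VRING_PACKED_EVENT_F_WRAP_CTR);

	assert(last_used(idx) == 5);
	assert(used_wrap_counter(idx) == 1);
	printf("idx=%u wrap=%d\n", last_used(idx), used_wrap_counter(idx));
	return 0;
}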
@@ -257,6 +257,7 @@ void virtio_device_ready(struct virtio_device *dev)
 
 	WARN_ON(status & VIRTIO_CONFIG_S_DRIVER_OK);
 
+#ifdef CONFIG_VIRTIO_HARDEN_NOTIFICATION
 	/*
 	 * The virtio_synchronize_cbs() makes sure vring_interrupt()
 	 * will see the driver specific setup if it sees vq->broken
@@ -264,6 +265,7 @@ void virtio_device_ready(struct virtio_device *dev)
 	 */
 	virtio_synchronize_cbs(dev);
 	__virtio_unbreak_device(dev);
+#endif
 
 	/*
 	 * The transport should ensure the visibility of vq->broken
 	 * before setting DRIVER_OK. See the comments for the transport
...
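Read together with the earlier hunks, this completes the hardened lifecycle. The outline below only restates what the #ifdef blocks above do when the option is set; it is a comment-only summary, not new API:

/*
 * Hardened-notification lifecycle (CONFIG_VIRTIO_HARDEN_NOTIFICATION=y):
 *
 *   vring_create_virtqueue_packed() /   vq->broken = true;
 *   __vring_new_virtqueue():            (notifications are illegal so far)
 *
 *   vring_interrupt():                  if (vq->broken)
 *                                               return IRQ_NONE; (warn once)
 *
 *   virtio_device_ready():              virtio_synchronize_cbs(dev);
 *                                       __virtio_unbreak_device(dev);
 *                                       then set DRIVER_OK
 *
 *   virtio_reset_device():              virtio_break_device(dev);
 *                                       virtio_synchronize_cbs(dev);
 *                                       (no callback survives the reset)
 *
 * Without the option, vq->broken starts out false and a stray early
 * notification is simply acknowledged (IRQ_HANDLED), as before.
 */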