Commit 3f4957eb authored by David S. Miller's avatar David S. Miller

Merge branch 'vsock-virtio-fixes'

Stefano Garzarella says:

====================
vsock/virtio: several fixes in the .probe() and .remove()

During the review of "[PATCH] vsock/virtio: Initialize core virtio vsock
before registering the driver", Stefan pointed out some possible issues
in the .probe() and .remove() callbacks of the virtio-vsock driver.

This series tries to solve these issues:
- Patch 1 adds RCU critical sections to avoid use-after-free of
  'the_virtio_vsock' pointer.
- Patch 2 stops workers before to call vdev->config->reset(vdev) to
  be sure that no one is accessing the device.
- Patch 3 moves the works flush at the end of the .remove() to avoid
  use-after-free of 'vsock' object.

v3:
- Patch 1: use rcu_dereference_protected() to get the_virtio_vosck value in
           the virtio_vsock_probe() [Jason]

v2: https://patchwork.kernel.org/cover/11022343/

v1: https://patchwork.kernel.org/cover/10964733/

Before this series the guest crashes in a few second. After this series the
test runs (~12h) without issues.
Tested on an SMP guest (-smp 4 -monitor tcp:127.0.0.1:1234,server,nowait)
with these scripts to stress the .probe()/.remove() path:

- guest
  while true; do
      cat /dev/urandom | nc-vsock -l 4321 > /dev/null &
      cat /dev/urandom | nc-vsock -l 5321 > /dev/null &
      cat /dev/urandom | nc-vsock -l 6321 > /dev/null &
      cat /dev/urandom | nc-vsock -l 7321 > /dev/null &
      wait
  done

- host
  while true; do
      cat /dev/urandom | nc-vsock 3 4321 > /dev/null &
      cat /dev/urandom | nc-vsock 3 5321 > /dev/null &
      cat /dev/urandom | nc-vsock 3 6321 > /dev/null &
      cat /dev/urandom | nc-vsock 3 7321 > /dev/null &
      sleep 2
      echo "device_del v1" | nc 127.0.0.1 1234
      sleep 1
      echo "device_add vhost-vsock-pci,id=v1,guest-cid=3" | nc 127.0.0.1 1234
      sleep 1
  done
====================
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parents 1a2d405c e226121f
...@@ -38,6 +38,7 @@ struct virtio_vsock { ...@@ -38,6 +38,7 @@ struct virtio_vsock {
* must be accessed with tx_lock held. * must be accessed with tx_lock held.
*/ */
struct mutex tx_lock; struct mutex tx_lock;
bool tx_run;
struct work_struct send_pkt_work; struct work_struct send_pkt_work;
spinlock_t send_pkt_list_lock; spinlock_t send_pkt_list_lock;
...@@ -53,6 +54,7 @@ struct virtio_vsock { ...@@ -53,6 +54,7 @@ struct virtio_vsock {
* must be accessed with rx_lock held. * must be accessed with rx_lock held.
*/ */
struct mutex rx_lock; struct mutex rx_lock;
bool rx_run;
int rx_buf_nr; int rx_buf_nr;
int rx_buf_max_nr; int rx_buf_max_nr;
...@@ -60,24 +62,28 @@ struct virtio_vsock { ...@@ -60,24 +62,28 @@ struct virtio_vsock {
* vqs[VSOCK_VQ_EVENT] must be accessed with event_lock held. * vqs[VSOCK_VQ_EVENT] must be accessed with event_lock held.
*/ */
struct mutex event_lock; struct mutex event_lock;
bool event_run;
struct virtio_vsock_event event_list[8]; struct virtio_vsock_event event_list[8];
u32 guest_cid; u32 guest_cid;
}; };
static struct virtio_vsock *virtio_vsock_get(void)
{
return the_virtio_vsock;
}
static u32 virtio_transport_get_local_cid(void) static u32 virtio_transport_get_local_cid(void)
{ {
struct virtio_vsock *vsock = virtio_vsock_get(); struct virtio_vsock *vsock;
u32 ret;
if (!vsock) rcu_read_lock();
return VMADDR_CID_ANY; vsock = rcu_dereference(the_virtio_vsock);
if (!vsock) {
ret = VMADDR_CID_ANY;
goto out_rcu;
}
return vsock->guest_cid; ret = vsock->guest_cid;
out_rcu:
rcu_read_unlock();
return ret;
} }
static void virtio_transport_loopback_work(struct work_struct *work) static void virtio_transport_loopback_work(struct work_struct *work)
...@@ -91,6 +97,10 @@ static void virtio_transport_loopback_work(struct work_struct *work) ...@@ -91,6 +97,10 @@ static void virtio_transport_loopback_work(struct work_struct *work)
spin_unlock_bh(&vsock->loopback_list_lock); spin_unlock_bh(&vsock->loopback_list_lock);
mutex_lock(&vsock->rx_lock); mutex_lock(&vsock->rx_lock);
if (!vsock->rx_run)
goto out;
while (!list_empty(&pkts)) { while (!list_empty(&pkts)) {
struct virtio_vsock_pkt *pkt; struct virtio_vsock_pkt *pkt;
...@@ -99,6 +109,7 @@ static void virtio_transport_loopback_work(struct work_struct *work) ...@@ -99,6 +109,7 @@ static void virtio_transport_loopback_work(struct work_struct *work)
virtio_transport_recv_pkt(pkt); virtio_transport_recv_pkt(pkt);
} }
out:
mutex_unlock(&vsock->rx_lock); mutex_unlock(&vsock->rx_lock);
} }
...@@ -127,6 +138,9 @@ virtio_transport_send_pkt_work(struct work_struct *work) ...@@ -127,6 +138,9 @@ virtio_transport_send_pkt_work(struct work_struct *work)
mutex_lock(&vsock->tx_lock); mutex_lock(&vsock->tx_lock);
if (!vsock->tx_run)
goto out;
vq = vsock->vqs[VSOCK_VQ_TX]; vq = vsock->vqs[VSOCK_VQ_TX];
for (;;) { for (;;) {
...@@ -185,6 +199,7 @@ virtio_transport_send_pkt_work(struct work_struct *work) ...@@ -185,6 +199,7 @@ virtio_transport_send_pkt_work(struct work_struct *work)
if (added) if (added)
virtqueue_kick(vq); virtqueue_kick(vq);
out:
mutex_unlock(&vsock->tx_lock); mutex_unlock(&vsock->tx_lock);
if (restart_rx) if (restart_rx)
...@@ -197,14 +212,18 @@ virtio_transport_send_pkt(struct virtio_vsock_pkt *pkt) ...@@ -197,14 +212,18 @@ virtio_transport_send_pkt(struct virtio_vsock_pkt *pkt)
struct virtio_vsock *vsock; struct virtio_vsock *vsock;
int len = pkt->len; int len = pkt->len;
vsock = virtio_vsock_get(); rcu_read_lock();
vsock = rcu_dereference(the_virtio_vsock);
if (!vsock) { if (!vsock) {
virtio_transport_free_pkt(pkt); virtio_transport_free_pkt(pkt);
return -ENODEV; len = -ENODEV;
goto out_rcu;
} }
if (le64_to_cpu(pkt->hdr.dst_cid) == vsock->guest_cid) if (le64_to_cpu(pkt->hdr.dst_cid) == vsock->guest_cid) {
return virtio_transport_send_pkt_loopback(vsock, pkt); len = virtio_transport_send_pkt_loopback(vsock, pkt);
goto out_rcu;
}
if (pkt->reply) if (pkt->reply)
atomic_inc(&vsock->queued_replies); atomic_inc(&vsock->queued_replies);
...@@ -214,6 +233,9 @@ virtio_transport_send_pkt(struct virtio_vsock_pkt *pkt) ...@@ -214,6 +233,9 @@ virtio_transport_send_pkt(struct virtio_vsock_pkt *pkt)
spin_unlock_bh(&vsock->send_pkt_list_lock); spin_unlock_bh(&vsock->send_pkt_list_lock);
queue_work(virtio_vsock_workqueue, &vsock->send_pkt_work); queue_work(virtio_vsock_workqueue, &vsock->send_pkt_work);
out_rcu:
rcu_read_unlock();
return len; return len;
} }
...@@ -222,12 +244,14 @@ virtio_transport_cancel_pkt(struct vsock_sock *vsk) ...@@ -222,12 +244,14 @@ virtio_transport_cancel_pkt(struct vsock_sock *vsk)
{ {
struct virtio_vsock *vsock; struct virtio_vsock *vsock;
struct virtio_vsock_pkt *pkt, *n; struct virtio_vsock_pkt *pkt, *n;
int cnt = 0; int cnt = 0, ret;
LIST_HEAD(freeme); LIST_HEAD(freeme);
vsock = virtio_vsock_get(); rcu_read_lock();
vsock = rcu_dereference(the_virtio_vsock);
if (!vsock) { if (!vsock) {
return -ENODEV; ret = -ENODEV;
goto out_rcu;
} }
spin_lock_bh(&vsock->send_pkt_list_lock); spin_lock_bh(&vsock->send_pkt_list_lock);
...@@ -255,7 +279,11 @@ virtio_transport_cancel_pkt(struct vsock_sock *vsk) ...@@ -255,7 +279,11 @@ virtio_transport_cancel_pkt(struct vsock_sock *vsk)
queue_work(virtio_vsock_workqueue, &vsock->rx_work); queue_work(virtio_vsock_workqueue, &vsock->rx_work);
} }
return 0; ret = 0;
out_rcu:
rcu_read_unlock();
return ret;
} }
static void virtio_vsock_rx_fill(struct virtio_vsock *vsock) static void virtio_vsock_rx_fill(struct virtio_vsock *vsock)
...@@ -307,6 +335,10 @@ static void virtio_transport_tx_work(struct work_struct *work) ...@@ -307,6 +335,10 @@ static void virtio_transport_tx_work(struct work_struct *work)
vq = vsock->vqs[VSOCK_VQ_TX]; vq = vsock->vqs[VSOCK_VQ_TX];
mutex_lock(&vsock->tx_lock); mutex_lock(&vsock->tx_lock);
if (!vsock->tx_run)
goto out;
do { do {
struct virtio_vsock_pkt *pkt; struct virtio_vsock_pkt *pkt;
unsigned int len; unsigned int len;
...@@ -317,6 +349,8 @@ static void virtio_transport_tx_work(struct work_struct *work) ...@@ -317,6 +349,8 @@ static void virtio_transport_tx_work(struct work_struct *work)
added = true; added = true;
} }
} while (!virtqueue_enable_cb(vq)); } while (!virtqueue_enable_cb(vq));
out:
mutex_unlock(&vsock->tx_lock); mutex_unlock(&vsock->tx_lock);
if (added) if (added)
...@@ -345,6 +379,9 @@ static void virtio_transport_rx_work(struct work_struct *work) ...@@ -345,6 +379,9 @@ static void virtio_transport_rx_work(struct work_struct *work)
mutex_lock(&vsock->rx_lock); mutex_lock(&vsock->rx_lock);
if (!vsock->rx_run)
goto out;
do { do {
virtqueue_disable_cb(vq); virtqueue_disable_cb(vq);
for (;;) { for (;;) {
...@@ -454,6 +491,9 @@ static void virtio_transport_event_work(struct work_struct *work) ...@@ -454,6 +491,9 @@ static void virtio_transport_event_work(struct work_struct *work)
mutex_lock(&vsock->event_lock); mutex_lock(&vsock->event_lock);
if (!vsock->event_run)
goto out;
do { do {
struct virtio_vsock_event *event; struct virtio_vsock_event *event;
unsigned int len; unsigned int len;
...@@ -468,7 +508,7 @@ static void virtio_transport_event_work(struct work_struct *work) ...@@ -468,7 +508,7 @@ static void virtio_transport_event_work(struct work_struct *work)
} while (!virtqueue_enable_cb(vq)); } while (!virtqueue_enable_cb(vq));
virtqueue_kick(vsock->vqs[VSOCK_VQ_EVENT]); virtqueue_kick(vsock->vqs[VSOCK_VQ_EVENT]);
out:
mutex_unlock(&vsock->event_lock); mutex_unlock(&vsock->event_lock);
} }
...@@ -565,7 +605,8 @@ static int virtio_vsock_probe(struct virtio_device *vdev) ...@@ -565,7 +605,8 @@ static int virtio_vsock_probe(struct virtio_device *vdev)
return ret; return ret;
/* Only one virtio-vsock device per guest is supported */ /* Only one virtio-vsock device per guest is supported */
if (the_virtio_vsock) { if (rcu_dereference_protected(the_virtio_vsock,
lockdep_is_held(&the_virtio_vsock_mutex))) {
ret = -EBUSY; ret = -EBUSY;
goto out; goto out;
} }
...@@ -590,8 +631,6 @@ static int virtio_vsock_probe(struct virtio_device *vdev) ...@@ -590,8 +631,6 @@ static int virtio_vsock_probe(struct virtio_device *vdev)
vsock->rx_buf_max_nr = 0; vsock->rx_buf_max_nr = 0;
atomic_set(&vsock->queued_replies, 0); atomic_set(&vsock->queued_replies, 0);
vdev->priv = vsock;
the_virtio_vsock = vsock;
mutex_init(&vsock->tx_lock); mutex_init(&vsock->tx_lock);
mutex_init(&vsock->rx_lock); mutex_init(&vsock->rx_lock);
mutex_init(&vsock->event_lock); mutex_init(&vsock->event_lock);
...@@ -605,14 +644,23 @@ static int virtio_vsock_probe(struct virtio_device *vdev) ...@@ -605,14 +644,23 @@ static int virtio_vsock_probe(struct virtio_device *vdev)
INIT_WORK(&vsock->send_pkt_work, virtio_transport_send_pkt_work); INIT_WORK(&vsock->send_pkt_work, virtio_transport_send_pkt_work);
INIT_WORK(&vsock->loopback_work, virtio_transport_loopback_work); INIT_WORK(&vsock->loopback_work, virtio_transport_loopback_work);
mutex_lock(&vsock->tx_lock);
vsock->tx_run = true;
mutex_unlock(&vsock->tx_lock);
mutex_lock(&vsock->rx_lock); mutex_lock(&vsock->rx_lock);
virtio_vsock_rx_fill(vsock); virtio_vsock_rx_fill(vsock);
vsock->rx_run = true;
mutex_unlock(&vsock->rx_lock); mutex_unlock(&vsock->rx_lock);
mutex_lock(&vsock->event_lock); mutex_lock(&vsock->event_lock);
virtio_vsock_event_fill(vsock); virtio_vsock_event_fill(vsock);
vsock->event_run = true;
mutex_unlock(&vsock->event_lock); mutex_unlock(&vsock->event_lock);
vdev->priv = vsock;
rcu_assign_pointer(the_virtio_vsock, vsock);
mutex_unlock(&the_virtio_vsock_mutex); mutex_unlock(&the_virtio_vsock_mutex);
return 0; return 0;
...@@ -627,15 +675,33 @@ static void virtio_vsock_remove(struct virtio_device *vdev) ...@@ -627,15 +675,33 @@ static void virtio_vsock_remove(struct virtio_device *vdev)
struct virtio_vsock *vsock = vdev->priv; struct virtio_vsock *vsock = vdev->priv;
struct virtio_vsock_pkt *pkt; struct virtio_vsock_pkt *pkt;
flush_work(&vsock->loopback_work); mutex_lock(&the_virtio_vsock_mutex);
flush_work(&vsock->rx_work);
flush_work(&vsock->tx_work); vdev->priv = NULL;
flush_work(&vsock->event_work); rcu_assign_pointer(the_virtio_vsock, NULL);
flush_work(&vsock->send_pkt_work); synchronize_rcu();
/* Reset all connected sockets when the device disappear */ /* Reset all connected sockets when the device disappear */
vsock_for_each_connected_socket(virtio_vsock_reset_sock); vsock_for_each_connected_socket(virtio_vsock_reset_sock);
/* Stop all work handlers to make sure no one is accessing the device,
* so we can safely call vdev->config->reset().
*/
mutex_lock(&vsock->rx_lock);
vsock->rx_run = false;
mutex_unlock(&vsock->rx_lock);
mutex_lock(&vsock->tx_lock);
vsock->tx_run = false;
mutex_unlock(&vsock->tx_lock);
mutex_lock(&vsock->event_lock);
vsock->event_run = false;
mutex_unlock(&vsock->event_lock);
/* Flush all device writes and interrupts, device will not use any
* more buffers.
*/
vdev->config->reset(vdev); vdev->config->reset(vdev);
mutex_lock(&vsock->rx_lock); mutex_lock(&vsock->rx_lock);
...@@ -666,12 +732,20 @@ static void virtio_vsock_remove(struct virtio_device *vdev) ...@@ -666,12 +732,20 @@ static void virtio_vsock_remove(struct virtio_device *vdev)
} }
spin_unlock_bh(&vsock->loopback_list_lock); spin_unlock_bh(&vsock->loopback_list_lock);
mutex_lock(&the_virtio_vsock_mutex); /* Delete virtqueues and flush outstanding callbacks if any */
the_virtio_vsock = NULL;
mutex_unlock(&the_virtio_vsock_mutex);
vdev->config->del_vqs(vdev); vdev->config->del_vqs(vdev);
/* Other works can be queued before 'config->del_vqs()', so we flush
* all works before to free the vsock object to avoid use after free.
*/
flush_work(&vsock->loopback_work);
flush_work(&vsock->rx_work);
flush_work(&vsock->tx_work);
flush_work(&vsock->event_work);
flush_work(&vsock->send_pkt_work);
mutex_unlock(&the_virtio_vsock_mutex);
kfree(vsock); kfree(vsock);
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment