Commit a7722890 authored by huangjie.albert's avatar huangjie.albert Committed by Michael S. Tsirkin

virtio_ring : keep used_wrap_counter in vq->last_used_idx

the used_wrap_counter and the vq->last_used_idx may get
out of sync if they are separate assignment,and interrupt
might use an incorrect value to check for the used index.

for example:OOB access
ksoftirqd may consume the packet and it will call:
virtnet_poll
	-->virtnet_receive
		-->virtqueue_get_buf_ctx
			-->virtqueue_get_buf_ctx_packed
and in virtqueue_get_buf_ctx_packed:

vq->last_used_idx += vq->packed.desc_state[id].num;
if (unlikely(vq->last_used_idx >= vq->packed.vring.num)) {
         vq->last_used_idx -= vq->packed.vring.num;
         vq->packed.used_wrap_counter ^= 1;
}

if at the same time, there comes a vring interrupt,in vring_interrupt:
we will call:
vring_interrupt
	-->more_used
		-->more_used_packed
			-->is_used_desc_packed
in is_used_desc_packed, the last_used_idx maybe >= vq->packed.vring.num.
so this could case a memory out of bounds bug.

this patch is to keep the used_wrap_counter in vq->last_used_idx
so we can get the correct value to check for used index in interrupt.

v3->v4:
- use READ_ONCE/WRITE_ONCE to get/set vq->last_used_idx

v2->v3:
- add inline function to get used_wrap_counter and last_used
- when use vq->last_used_idx, only read once
  if vq->last_used_idx is read twice, the values can be inconsistent.
- use last_used_idx & ~(-(1 << VRING_PACKED_EVENT_F_WRAP_CTR))
  to get the all bits below VRING_PACKED_EVENT_F_WRAP_CTR

v1->v2:
- reuse the VRING_PACKED_EVENT_F_WRAP_CTR
- Remove parameter judgment in is_used_desc_packed,
because it can't be illegal
Signed-off-by: default avatarhuangjie.albert <huangjie.albert@bytedance.com>
Message-Id: <20220617020411.80367-1-huangjie.albert@bytedance.com>
Signed-off-by: default avatarMichael S. Tsirkin <mst@redhat.com>
parent 0e0348ac
...@@ -111,7 +111,12 @@ struct vring_virtqueue { ...@@ -111,7 +111,12 @@ struct vring_virtqueue {
/* Number we've added since last sync. */ /* Number we've added since last sync. */
unsigned int num_added; unsigned int num_added;
/* Last used index we've seen. */ /* Last used index we've seen.
* for split ring, it just contains last used index
* for packed ring:
* bits up to VRING_PACKED_EVENT_F_WRAP_CTR include the last used index.
* bits from VRING_PACKED_EVENT_F_WRAP_CTR include the used wrap counter.
*/
u16 last_used_idx; u16 last_used_idx;
/* Hint for event idx: already triggered no need to disable. */ /* Hint for event idx: already triggered no need to disable. */
...@@ -154,9 +159,6 @@ struct vring_virtqueue { ...@@ -154,9 +159,6 @@ struct vring_virtqueue {
/* Driver ring wrap counter. */ /* Driver ring wrap counter. */
bool avail_wrap_counter; bool avail_wrap_counter;
/* Device ring wrap counter. */
bool used_wrap_counter;
/* Avail used flags. */ /* Avail used flags. */
u16 avail_used_flags; u16 avail_used_flags;
...@@ -973,6 +975,15 @@ static struct virtqueue *vring_create_virtqueue_split( ...@@ -973,6 +975,15 @@ static struct virtqueue *vring_create_virtqueue_split(
/* /*
* Packed ring specific functions - *_packed(). * Packed ring specific functions - *_packed().
*/ */
static inline bool packed_used_wrap_counter(u16 last_used_idx)
{
return !!(last_used_idx & (1 << VRING_PACKED_EVENT_F_WRAP_CTR));
}
static inline u16 packed_last_used(u16 last_used_idx)
{
return last_used_idx & ~(-(1 << VRING_PACKED_EVENT_F_WRAP_CTR));
}
static void vring_unmap_extra_packed(const struct vring_virtqueue *vq, static void vring_unmap_extra_packed(const struct vring_virtqueue *vq,
struct vring_desc_extra *extra) struct vring_desc_extra *extra)
...@@ -1406,8 +1417,14 @@ static inline bool is_used_desc_packed(const struct vring_virtqueue *vq, ...@@ -1406,8 +1417,14 @@ static inline bool is_used_desc_packed(const struct vring_virtqueue *vq,
static inline bool more_used_packed(const struct vring_virtqueue *vq) static inline bool more_used_packed(const struct vring_virtqueue *vq)
{ {
return is_used_desc_packed(vq, vq->last_used_idx, u16 last_used;
vq->packed.used_wrap_counter); u16 last_used_idx;
bool used_wrap_counter;
last_used_idx = READ_ONCE(vq->last_used_idx);
last_used = packed_last_used(last_used_idx);
used_wrap_counter = packed_used_wrap_counter(last_used_idx);
return is_used_desc_packed(vq, last_used, used_wrap_counter);
} }
static void *virtqueue_get_buf_ctx_packed(struct virtqueue *_vq, static void *virtqueue_get_buf_ctx_packed(struct virtqueue *_vq,
...@@ -1415,7 +1432,8 @@ static void *virtqueue_get_buf_ctx_packed(struct virtqueue *_vq, ...@@ -1415,7 +1432,8 @@ static void *virtqueue_get_buf_ctx_packed(struct virtqueue *_vq,
void **ctx) void **ctx)
{ {
struct vring_virtqueue *vq = to_vvq(_vq); struct vring_virtqueue *vq = to_vvq(_vq);
u16 last_used, id; u16 last_used, id, last_used_idx;
bool used_wrap_counter;
void *ret; void *ret;
START_USE(vq); START_USE(vq);
...@@ -1434,7 +1452,9 @@ static void *virtqueue_get_buf_ctx_packed(struct virtqueue *_vq, ...@@ -1434,7 +1452,9 @@ static void *virtqueue_get_buf_ctx_packed(struct virtqueue *_vq,
/* Only get used elements after they have been exposed by host. */ /* Only get used elements after they have been exposed by host. */
virtio_rmb(vq->weak_barriers); virtio_rmb(vq->weak_barriers);
last_used = vq->last_used_idx; last_used_idx = READ_ONCE(vq->last_used_idx);
used_wrap_counter = packed_used_wrap_counter(last_used_idx);
last_used = packed_last_used(last_used_idx);
id = le16_to_cpu(vq->packed.vring.desc[last_used].id); id = le16_to_cpu(vq->packed.vring.desc[last_used].id);
*len = le32_to_cpu(vq->packed.vring.desc[last_used].len); *len = le32_to_cpu(vq->packed.vring.desc[last_used].len);
...@@ -1451,12 +1471,15 @@ static void *virtqueue_get_buf_ctx_packed(struct virtqueue *_vq, ...@@ -1451,12 +1471,15 @@ static void *virtqueue_get_buf_ctx_packed(struct virtqueue *_vq,
ret = vq->packed.desc_state[id].data; ret = vq->packed.desc_state[id].data;
detach_buf_packed(vq, id, ctx); detach_buf_packed(vq, id, ctx);
vq->last_used_idx += vq->packed.desc_state[id].num; last_used += vq->packed.desc_state[id].num;
if (unlikely(vq->last_used_idx >= vq->packed.vring.num)) { if (unlikely(last_used >= vq->packed.vring.num)) {
vq->last_used_idx -= vq->packed.vring.num; last_used -= vq->packed.vring.num;
vq->packed.used_wrap_counter ^= 1; used_wrap_counter ^= 1;
} }
last_used = (last_used | (used_wrap_counter << VRING_PACKED_EVENT_F_WRAP_CTR));
WRITE_ONCE(vq->last_used_idx, last_used);
/* /*
* If we expect an interrupt for the next entry, tell host * If we expect an interrupt for the next entry, tell host
* by writing event index and flush out the write before * by writing event index and flush out the write before
...@@ -1465,9 +1488,7 @@ static void *virtqueue_get_buf_ctx_packed(struct virtqueue *_vq, ...@@ -1465,9 +1488,7 @@ static void *virtqueue_get_buf_ctx_packed(struct virtqueue *_vq,
if (vq->packed.event_flags_shadow == VRING_PACKED_EVENT_FLAG_DESC) if (vq->packed.event_flags_shadow == VRING_PACKED_EVENT_FLAG_DESC)
virtio_store_mb(vq->weak_barriers, virtio_store_mb(vq->weak_barriers,
&vq->packed.vring.driver->off_wrap, &vq->packed.vring.driver->off_wrap,
cpu_to_le16(vq->last_used_idx | cpu_to_le16(vq->last_used_idx));
(vq->packed.used_wrap_counter <<
VRING_PACKED_EVENT_F_WRAP_CTR)));
LAST_ADD_TIME_INVALID(vq); LAST_ADD_TIME_INVALID(vq);
...@@ -1499,9 +1520,7 @@ static unsigned int virtqueue_enable_cb_prepare_packed(struct virtqueue *_vq) ...@@ -1499,9 +1520,7 @@ static unsigned int virtqueue_enable_cb_prepare_packed(struct virtqueue *_vq)
if (vq->event) { if (vq->event) {
vq->packed.vring.driver->off_wrap = vq->packed.vring.driver->off_wrap =
cpu_to_le16(vq->last_used_idx | cpu_to_le16(vq->last_used_idx);
(vq->packed.used_wrap_counter <<
VRING_PACKED_EVENT_F_WRAP_CTR));
/* /*
* We need to update event offset and event wrap * We need to update event offset and event wrap
* counter first before updating event flags. * counter first before updating event flags.
...@@ -1518,8 +1537,7 @@ static unsigned int virtqueue_enable_cb_prepare_packed(struct virtqueue *_vq) ...@@ -1518,8 +1537,7 @@ static unsigned int virtqueue_enable_cb_prepare_packed(struct virtqueue *_vq)
} }
END_USE(vq); END_USE(vq);
return vq->last_used_idx | ((u16)vq->packed.used_wrap_counter << return vq->last_used_idx;
VRING_PACKED_EVENT_F_WRAP_CTR);
} }
static bool virtqueue_poll_packed(struct virtqueue *_vq, u16 off_wrap) static bool virtqueue_poll_packed(struct virtqueue *_vq, u16 off_wrap)
...@@ -1537,7 +1555,7 @@ static bool virtqueue_poll_packed(struct virtqueue *_vq, u16 off_wrap) ...@@ -1537,7 +1555,7 @@ static bool virtqueue_poll_packed(struct virtqueue *_vq, u16 off_wrap)
static bool virtqueue_enable_cb_delayed_packed(struct virtqueue *_vq) static bool virtqueue_enable_cb_delayed_packed(struct virtqueue *_vq)
{ {
struct vring_virtqueue *vq = to_vvq(_vq); struct vring_virtqueue *vq = to_vvq(_vq);
u16 used_idx, wrap_counter; u16 used_idx, wrap_counter, last_used_idx;
u16 bufs; u16 bufs;
START_USE(vq); START_USE(vq);
...@@ -1550,9 +1568,10 @@ static bool virtqueue_enable_cb_delayed_packed(struct virtqueue *_vq) ...@@ -1550,9 +1568,10 @@ static bool virtqueue_enable_cb_delayed_packed(struct virtqueue *_vq)
if (vq->event) { if (vq->event) {
/* TODO: tune this threshold */ /* TODO: tune this threshold */
bufs = (vq->packed.vring.num - vq->vq.num_free) * 3 / 4; bufs = (vq->packed.vring.num - vq->vq.num_free) * 3 / 4;
wrap_counter = vq->packed.used_wrap_counter; last_used_idx = READ_ONCE(vq->last_used_idx);
wrap_counter = packed_used_wrap_counter(last_used_idx);
used_idx = vq->last_used_idx + bufs; used_idx = packed_last_used(last_used_idx) + bufs;
if (used_idx >= vq->packed.vring.num) { if (used_idx >= vq->packed.vring.num) {
used_idx -= vq->packed.vring.num; used_idx -= vq->packed.vring.num;
wrap_counter ^= 1; wrap_counter ^= 1;
...@@ -1582,9 +1601,10 @@ static bool virtqueue_enable_cb_delayed_packed(struct virtqueue *_vq) ...@@ -1582,9 +1601,10 @@ static bool virtqueue_enable_cb_delayed_packed(struct virtqueue *_vq)
*/ */
virtio_mb(vq->weak_barriers); virtio_mb(vq->weak_barriers);
if (is_used_desc_packed(vq, last_used_idx = READ_ONCE(vq->last_used_idx);
vq->last_used_idx, wrap_counter = packed_used_wrap_counter(last_used_idx);
vq->packed.used_wrap_counter)) { used_idx = packed_last_used(last_used_idx);
if (is_used_desc_packed(vq, used_idx, wrap_counter)) {
END_USE(vq); END_USE(vq);
return false; return false;
} }
...@@ -1689,7 +1709,7 @@ static struct virtqueue *vring_create_virtqueue_packed( ...@@ -1689,7 +1709,7 @@ static struct virtqueue *vring_create_virtqueue_packed(
vq->notify = notify; vq->notify = notify;
vq->weak_barriers = weak_barriers; vq->weak_barriers = weak_barriers;
vq->broken = true; vq->broken = true;
vq->last_used_idx = 0; vq->last_used_idx = 0 | (1 << VRING_PACKED_EVENT_F_WRAP_CTR);
vq->event_triggered = false; vq->event_triggered = false;
vq->num_added = 0; vq->num_added = 0;
vq->packed_ring = true; vq->packed_ring = true;
...@@ -1720,7 +1740,6 @@ static struct virtqueue *vring_create_virtqueue_packed( ...@@ -1720,7 +1740,6 @@ static struct virtqueue *vring_create_virtqueue_packed(
vq->packed.next_avail_idx = 0; vq->packed.next_avail_idx = 0;
vq->packed.avail_wrap_counter = 1; vq->packed.avail_wrap_counter = 1;
vq->packed.used_wrap_counter = 1;
vq->packed.event_flags_shadow = 0; vq->packed.event_flags_shadow = 0;
vq->packed.avail_used_flags = 1 << VRING_PACKED_DESC_F_AVAIL; vq->packed.avail_used_flags = 1 << VRING_PACKED_DESC_F_AVAIL;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment