diff --git a/drivers/vhost/net.c b/drivers/vhost/net.c
index f0fd52cdfadc05b6fcd57b9efff0eddee4173d2a..70ac60437d174411f891f20f8b477b20b74e8e8e 100644
--- a/drivers/vhost/net.c
+++ b/drivers/vhost/net.c
@@ -703,6 +703,10 @@ static long vhost_net_set_backend(struct vhost_net *n, unsigned index, int fd)
 		vhost_net_disable_vq(n, vq);
 		rcu_assign_pointer(vq->private_data, sock);
 		vhost_net_enable_vq(n, vq);
+
+		r = vhost_init_used(vq);
+		if (r)
+			goto err_vq;
 	}
 
 	mutex_unlock(&vq->mutex);
diff --git a/drivers/vhost/test.c b/drivers/vhost/test.c
index 734e1d74ad805a1547ed867dd8363f09bcd9e2f8..fc9a1d75281f33d57be305b3c724195c578334de 100644
--- a/drivers/vhost/test.c
+++ b/drivers/vhost/test.c
@@ -195,8 +195,13 @@ static long vhost_test_run(struct vhost_test *n, int test)
 						    lockdep_is_held(&vq->mutex));
 		rcu_assign_pointer(vq->private_data, priv);
 
+		r = vhost_init_used(&n->vqs[index]);
+
 		mutex_unlock(&vq->mutex);
 
+		if (r)
+			goto err;
+
 		if (oldpriv) {
 			vhost_test_flush_vq(n, index);
 		}
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
index 5ef2f62becf4b94c69082b4dc1b1a2f6eacd84ac..9a108038fe527393f6ac2b04a2104916b1518290 100644
--- a/drivers/vhost/vhost.c
+++ b/drivers/vhost/vhost.c
@@ -629,15 +629,17 @@ static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)
 	return 0;
 }
 
-static int init_used(struct vhost_virtqueue *vq,
-		     struct vring_used __user *used)
+int vhost_init_used(struct vhost_virtqueue *vq)
 {
-	int r = put_user(vq->used_flags, &used->flags);
+	int r;
+	if (!vq->private_data)
+		return 0;
 
+	r = put_user(vq->used_flags, &vq->used->flags);
 	if (r)
 		return r;
 	vq->signalled_used_valid = false;
-	return get_user(vq->last_used_idx, &used->idx);
+	return get_user(vq->last_used_idx, &vq->used->idx);
 }
 
 static long vhost_set_vring(struct vhost_dev *d, int ioctl, void __user *argp)
@@ -752,10 +754,6 @@ static long vhost_set_vring(struct vhost_dev *d, int ioctl, void __user *argp)
 			}
 		}
 
-		r = init_used(vq, (struct vring_used __user *)(unsigned long)
-			      a.used_user_addr);
-		if (r)
-			break;
 		vq->log_used = !!(a.flags & (0x1 << VHOST_VRING_F_LOG));
 		vq->desc = (void __user *)(unsigned long)a.desc_user_addr;
 		vq->avail = (void __user *)(unsigned long)a.avail_user_addr;
diff --git a/drivers/vhost/vhost.h b/drivers/vhost/vhost.h
index 1544b782529b5803730ac6e447ac698caa7c7239..14c9abf0d80025fd863460c1bd935d5b49f88c26 100644
--- a/drivers/vhost/vhost.h
+++ b/drivers/vhost/vhost.h
@@ -174,6 +174,7 @@ int vhost_get_vq_desc(struct vhost_dev *, struct vhost_virtqueue *,
 		      struct vhost_log *log, unsigned int *log_num);
 void vhost_discard_vq_desc(struct vhost_virtqueue *, int n);
 
+int vhost_init_used(struct vhost_virtqueue *);
 int vhost_add_used(struct vhost_virtqueue *, unsigned int head, int len);
 int vhost_add_used_n(struct vhost_virtqueue *, struct vring_used_elem *heads,
 		     unsigned count);