Commit 6e5f6a86 authored by Linus Torvalds

Merge tag 'for_linus' of git://git.kernel.org/pub/scm/linux/kernel/git/mst/vhost

Pull virtio updates from Michael Tsirkin:
 "vhost,virtio and vdpa features, fixes, and cleanups:

   - mac vlan filter and stats support in mlx5 vdpa

   - irq hardening in virtio

   - performance improvements in virtio crypto

   - polling i/o support in virtio blk

   - ASID support in vhost

   - fixes, cleanups all over the place"

* tag 'for_linus' of git://git.kernel.org/pub/scm/linux/kernel/git/mst/vhost: (64 commits)
  vdpa: ifcvf: set pci driver data in probe
  vdpa/mlx5: Add RX MAC VLAN filter support
  vdpa/mlx5: Remove flow counter from steering
  vhost: rename vhost_work_dev_flush
  vhost-test: drop flush after vhost_dev_cleanup
  vhost-scsi: drop flush after vhost_dev_cleanup
  vhost_vsock: simplify vhost_vsock_flush()
  vhost_test: remove vhost_test_flush_vq()
  vhost_net: get rid of vhost_net_flush_vq() and extra flush calls
  vhost: flush dev once during vhost_dev_stop
  vhost: get rid of vhost_poll_flush() wrapper
  vhost-vdpa: return -EFAULT on copy_to_user() failure
  vdpasim: Off by one in vdpasim_set_group_asid()
  virtio: Directly use ida_alloc()/free()
  virtio: use WARN_ON() to warning illegal status value
  virtio: harden vring IRQ
  virtio: allow to unbreak virtqueue
  virtio-ccw: implement synchronize_cbs()
  virtio-mmio: implement synchronize_cbs()
  virtio-pci: implement synchronize_cbs()
  ...
parents 6f6ebb98 bd8bb9ae
...@@ -37,6 +37,10 @@ MODULE_PARM_DESC(num_request_queues, ...@@ -37,6 +37,10 @@ MODULE_PARM_DESC(num_request_queues,
"0 for no limit. " "0 for no limit. "
"Values > nr_cpu_ids truncated to nr_cpu_ids."); "Values > nr_cpu_ids truncated to nr_cpu_ids.");
static unsigned int poll_queues;
module_param(poll_queues, uint, 0644);
MODULE_PARM_DESC(poll_queues, "The number of dedicated virtqueues for polling I/O");
static int major; static int major;
static DEFINE_IDA(vd_index_ida); static DEFINE_IDA(vd_index_ida);
...@@ -74,6 +78,7 @@ struct virtio_blk { ...@@ -74,6 +78,7 @@ struct virtio_blk {
/* num of vqs */ /* num of vqs */
int num_vqs; int num_vqs;
int io_queues[HCTX_MAX_TYPES];
struct virtio_blk_vq *vqs; struct virtio_blk_vq *vqs;
}; };
@@ -96,8 +101,7 @@ static inline blk_status_t virtblk_result(struct virtblk_req *vbr)
 	}
 }
 
-static int virtblk_add_req(struct virtqueue *vq, struct virtblk_req *vbr,
-		struct scatterlist *data_sg, bool have_data)
+static int virtblk_add_req(struct virtqueue *vq, struct virtblk_req *vbr)
 {
 	struct scatterlist hdr, status, *sgs[3];
 	unsigned int num_out = 0, num_in = 0;
@@ -105,11 +109,11 @@ static int virtblk_add_req(struct virtqueue *vq, struct virtblk_req *vbr,
 	sg_init_one(&hdr, &vbr->out_hdr, sizeof(vbr->out_hdr));
 	sgs[num_out++] = &hdr;
 
-	if (have_data) {
+	if (vbr->sg_table.nents) {
 		if (vbr->out_hdr.type & cpu_to_virtio32(vq->vdev, VIRTIO_BLK_T_OUT))
-			sgs[num_out++] = data_sg;
+			sgs[num_out++] = vbr->sg_table.sgl;
 		else
-			sgs[num_out + num_in++] = data_sg;
+			sgs[num_out + num_in++] = vbr->sg_table.sgl;
 	}
 
 	sg_init_one(&status, &vbr->status, sizeof(vbr->status));
...@@ -299,6 +303,28 @@ static void virtio_commit_rqs(struct blk_mq_hw_ctx *hctx) ...@@ -299,6 +303,28 @@ static void virtio_commit_rqs(struct blk_mq_hw_ctx *hctx)
virtqueue_notify(vq->vq); virtqueue_notify(vq->vq);
} }
static blk_status_t virtblk_prep_rq(struct blk_mq_hw_ctx *hctx,
struct virtio_blk *vblk,
struct request *req,
struct virtblk_req *vbr)
{
blk_status_t status;
status = virtblk_setup_cmd(vblk->vdev, req, vbr);
if (unlikely(status))
return status;
blk_mq_start_request(req);
vbr->sg_table.nents = virtblk_map_data(hctx, req, vbr);
if (unlikely(vbr->sg_table.nents < 0)) {
virtblk_cleanup_cmd(req);
return BLK_STS_RESOURCE;
}
return BLK_STS_OK;
}
 static blk_status_t virtio_queue_rq(struct blk_mq_hw_ctx *hctx,
 				    const struct blk_mq_queue_data *bd)
 {
@@ -306,26 +332,17 @@ static blk_status_t virtio_queue_rq(struct blk_mq_hw_ctx *hctx,
 	struct request *req = bd->rq;
 	struct virtblk_req *vbr = blk_mq_rq_to_pdu(req);
 	unsigned long flags;
-	int num;
 	int qid = hctx->queue_num;
 	bool notify = false;
 	blk_status_t status;
 	int err;
 
-	status = virtblk_setup_cmd(vblk->vdev, req, vbr);
+	status = virtblk_prep_rq(hctx, vblk, req, vbr);
 	if (unlikely(status))
 		return status;
 
-	blk_mq_start_request(req);
-
-	num = virtblk_map_data(hctx, req, vbr);
-	if (unlikely(num < 0)) {
-		virtblk_cleanup_cmd(req);
-		return BLK_STS_RESOURCE;
-	}
-
 	spin_lock_irqsave(&vblk->vqs[qid].lock, flags);
-	err = virtblk_add_req(vblk->vqs[qid].vq, vbr, vbr->sg_table.sgl, num);
+	err = virtblk_add_req(vblk->vqs[qid].vq, vbr);
 	if (err) {
 		virtqueue_kick(vblk->vqs[qid].vq);
 		/* Don't stop the queue if -ENOMEM: we may have failed to
@@ -355,6 +372,75 @@ static blk_status_t virtio_queue_rq(struct blk_mq_hw_ctx *hctx,
 	return BLK_STS_OK;
 }
static bool virtblk_prep_rq_batch(struct request *req)
{
struct virtio_blk *vblk = req->mq_hctx->queue->queuedata;
struct virtblk_req *vbr = blk_mq_rq_to_pdu(req);
req->mq_hctx->tags->rqs[req->tag] = req;
return virtblk_prep_rq(req->mq_hctx, vblk, req, vbr) == BLK_STS_OK;
}
static bool virtblk_add_req_batch(struct virtio_blk_vq *vq,
struct request **rqlist,
struct request **requeue_list)
{
unsigned long flags;
int err;
bool kick;
spin_lock_irqsave(&vq->lock, flags);
while (!rq_list_empty(*rqlist)) {
struct request *req = rq_list_pop(rqlist);
struct virtblk_req *vbr = blk_mq_rq_to_pdu(req);
err = virtblk_add_req(vq->vq, vbr);
if (err) {
virtblk_unmap_data(req, vbr);
virtblk_cleanup_cmd(req);
rq_list_add(requeue_list, req);
}
}
kick = virtqueue_kick_prepare(vq->vq);
spin_unlock_irqrestore(&vq->lock, flags);
return kick;
}
static void virtio_queue_rqs(struct request **rqlist)
{
struct request *req, *next, *prev = NULL;
struct request *requeue_list = NULL;
rq_list_for_each_safe(rqlist, req, next) {
struct virtio_blk_vq *vq = req->mq_hctx->driver_data;
bool kick;
if (!virtblk_prep_rq_batch(req)) {
rq_list_move(rqlist, &requeue_list, req, prev);
req = prev;
if (!req)
continue;
}
if (!next || req->mq_hctx != next->mq_hctx) {
req->rq_next = NULL;
kick = virtblk_add_req_batch(vq, rqlist, &requeue_list);
if (kick)
virtqueue_notify(vq->vq);
*rqlist = next;
prev = NULL;
} else
prev = req;
}
*rqlist = requeue_list;
}
/* return id (s/n) string for *disk to *id_str /* return id (s/n) string for *disk to *id_str
*/ */
static int virtblk_get_id(struct gendisk *disk, char *id_str) static int virtblk_get_id(struct gendisk *disk, char *id_str)
...@@ -512,6 +598,7 @@ static int init_vq(struct virtio_blk *vblk) ...@@ -512,6 +598,7 @@ static int init_vq(struct virtio_blk *vblk)
const char **names; const char **names;
struct virtqueue **vqs; struct virtqueue **vqs;
unsigned short num_vqs; unsigned short num_vqs;
unsigned int num_poll_vqs;
struct virtio_device *vdev = vblk->vdev; struct virtio_device *vdev = vblk->vdev;
struct irq_affinity desc = { 0, }; struct irq_affinity desc = { 0, };
...@@ -520,6 +607,7 @@ static int init_vq(struct virtio_blk *vblk) ...@@ -520,6 +607,7 @@ static int init_vq(struct virtio_blk *vblk)
&num_vqs); &num_vqs);
if (err) if (err)
num_vqs = 1; num_vqs = 1;
if (!err && !num_vqs) { if (!err && !num_vqs) {
dev_err(&vdev->dev, "MQ advertised but zero queues reported\n"); dev_err(&vdev->dev, "MQ advertised but zero queues reported\n");
return -EINVAL; return -EINVAL;
...@@ -529,6 +617,17 @@ static int init_vq(struct virtio_blk *vblk) ...@@ -529,6 +617,17 @@ static int init_vq(struct virtio_blk *vblk)
min_not_zero(num_request_queues, nr_cpu_ids), min_not_zero(num_request_queues, nr_cpu_ids),
num_vqs); num_vqs);
num_poll_vqs = min_t(unsigned int, poll_queues, num_vqs - 1);
vblk->io_queues[HCTX_TYPE_DEFAULT] = num_vqs - num_poll_vqs;
vblk->io_queues[HCTX_TYPE_READ] = 0;
vblk->io_queues[HCTX_TYPE_POLL] = num_poll_vqs;
dev_info(&vdev->dev, "%d/%d/%d default/read/poll queues\n",
vblk->io_queues[HCTX_TYPE_DEFAULT],
vblk->io_queues[HCTX_TYPE_READ],
vblk->io_queues[HCTX_TYPE_POLL]);
vblk->vqs = kmalloc_array(num_vqs, sizeof(*vblk->vqs), GFP_KERNEL); vblk->vqs = kmalloc_array(num_vqs, sizeof(*vblk->vqs), GFP_KERNEL);
if (!vblk->vqs) if (!vblk->vqs)
return -ENOMEM; return -ENOMEM;
...@@ -541,12 +640,18 @@ static int init_vq(struct virtio_blk *vblk) ...@@ -541,12 +640,18 @@ static int init_vq(struct virtio_blk *vblk)
goto out; goto out;
} }
-	for (i = 0; i < num_vqs; i++) {
+	for (i = 0; i < num_vqs - num_poll_vqs; i++) {
callbacks[i] = virtblk_done; callbacks[i] = virtblk_done;
snprintf(vblk->vqs[i].name, VQ_NAME_LEN, "req.%d", i); snprintf(vblk->vqs[i].name, VQ_NAME_LEN, "req.%d", i);
names[i] = vblk->vqs[i].name; names[i] = vblk->vqs[i].name;
} }
for (; i < num_vqs; i++) {
callbacks[i] = NULL;
snprintf(vblk->vqs[i].name, VQ_NAME_LEN, "req_poll.%d", i);
names[i] = vblk->vqs[i].name;
}
/* Discover virtqueues and write information to configuration. */ /* Discover virtqueues and write information to configuration. */
err = virtio_find_vqs(vdev, num_vqs, vqs, callbacks, names, &desc); err = virtio_find_vqs(vdev, num_vqs, vqs, callbacks, names, &desc);
if (err) if (err)
...@@ -692,16 +797,90 @@ static const struct attribute_group *virtblk_attr_groups[] = { ...@@ -692,16 +797,90 @@ static const struct attribute_group *virtblk_attr_groups[] = {
static int virtblk_map_queues(struct blk_mq_tag_set *set) static int virtblk_map_queues(struct blk_mq_tag_set *set)
{ {
struct virtio_blk *vblk = set->driver_data; struct virtio_blk *vblk = set->driver_data;
int i, qoff;
for (i = 0, qoff = 0; i < set->nr_maps; i++) {
struct blk_mq_queue_map *map = &set->map[i];
map->nr_queues = vblk->io_queues[i];
map->queue_offset = qoff;
qoff += map->nr_queues;
if (map->nr_queues == 0)
continue;
/*
* Regular queues have interrupts and hence CPU affinity is
* defined by the core virtio code, but polling queues have
* no interrupts so we let the block layer assign CPU affinity.
*/
if (i == HCTX_TYPE_POLL)
blk_mq_map_queues(&set->map[i]);
else
blk_mq_virtio_map_queues(&set->map[i], vblk->vdev, 0);
}
return 0;
}
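A worked example of the resulting queue split (illustrative numbers, not part of the patch): with a device advertising four request virtqueues and the module loaded with poll_queues=2, init_vq() and virtblk_map_queues() above end up with

	/*
	 * num_poll_vqs                 = min(2, 4 - 1) = 2
	 * io_queues[HCTX_TYPE_DEFAULT] = 4 - 2         = 2   vqs 0-1, "req.N", virtblk_done()
	 * io_queues[HCTX_TYPE_READ]    = 0
	 * io_queues[HCTX_TYPE_POLL]    = 2                   vqs 2-3, "req_poll.N", no callback
	 *
	 * The poll map gets queue_offset 2 and is mapped with blk_mq_map_queues(),
	 * the default map keeps its virtio IRQ affinity via blk_mq_virtio_map_queues(),
	 * and virtblk_probe() sets tag_set.nr_maps = 3 so blk-mq sees the poll map.
	 */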
static void virtblk_complete_batch(struct io_comp_batch *iob)
{
struct request *req;
rq_list_for_each(&iob->req_list, req) {
virtblk_unmap_data(req, blk_mq_rq_to_pdu(req));
virtblk_cleanup_cmd(req);
}
blk_mq_end_request_batch(iob);
}
static int virtblk_poll(struct blk_mq_hw_ctx *hctx, struct io_comp_batch *iob)
{
struct virtio_blk *vblk = hctx->queue->queuedata;
struct virtio_blk_vq *vq = hctx->driver_data;
struct virtblk_req *vbr;
unsigned long flags;
unsigned int len;
int found = 0;
-	return blk_mq_virtio_map_queues(&set->map[HCTX_TYPE_DEFAULT],
-					vblk->vdev, 0);
+	spin_lock_irqsave(&vq->lock, flags);
while ((vbr = virtqueue_get_buf(vq->vq, &len)) != NULL) {
struct request *req = blk_mq_rq_from_pdu(vbr);
found++;
if (!blk_mq_add_to_batch(req, iob, vbr->status,
virtblk_complete_batch))
blk_mq_complete_request(req);
}
if (found)
blk_mq_start_stopped_hw_queues(vblk->disk->queue, true);
spin_unlock_irqrestore(&vq->lock, flags);
return found;
}
static int virtblk_init_hctx(struct blk_mq_hw_ctx *hctx, void *data,
unsigned int hctx_idx)
{
struct virtio_blk *vblk = data;
struct virtio_blk_vq *vq = &vblk->vqs[hctx_idx];
WARN_ON(vblk->tag_set.tags[hctx_idx] != hctx->tags);
hctx->driver_data = vq;
return 0;
 }
 
 static const struct blk_mq_ops virtio_mq_ops = {
 	.queue_rq	= virtio_queue_rq,
+	.queue_rqs	= virtio_queue_rqs,
 	.commit_rqs	= virtio_commit_rqs,
+	.init_hctx	= virtblk_init_hctx,
 	.complete	= virtblk_request_done,
 	.map_queues	= virtblk_map_queues,
+	.poll		= virtblk_poll,
 };
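The new .poll hook is only reached for requests submitted on the HCTX_TYPE_POLL map, i.e. polled I/O on an O_DIRECT file descriptor. A minimal userspace sketch of one way to exercise it, assuming liburing and a virtio-blk disk at the hypothetical path /dev/vda (nothing below is part of the patch):

	#define _GNU_SOURCE
	#include <fcntl.h>
	#include <liburing.h>
	#include <stdlib.h>

	int main(void)
	{
		struct io_uring ring;
		struct io_uring_sqe *sqe;
		struct io_uring_cqe *cqe;
		void *buf;
		int fd, ret;

		if (posix_memalign(&buf, 4096, 4096))
			return 1;

		fd = open("/dev/vda", O_RDONLY | O_DIRECT);	/* hypothetical virtio-blk disk */
		if (fd < 0)
			return 1;

		/* IORING_SETUP_IOPOLL: completions are reaped by polling the device */
		ret = io_uring_queue_init(8, &ring, IORING_SETUP_IOPOLL);
		if (ret < 0)
			return 1;

		sqe = io_uring_get_sqe(&ring);
		io_uring_prep_read(sqe, fd, buf, 4096, 0);
		io_uring_submit(&ring);

		/* waiting spins in the kernel's ->poll() path, i.e. virtblk_poll() */
		if (!io_uring_wait_cqe(&ring, &cqe))
			io_uring_cqe_seen(&ring, cqe);

		io_uring_queue_exit(&ring);
		return 0;
	}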
static unsigned int virtblk_queue_depth; static unsigned int virtblk_queue_depth;
...@@ -778,6 +957,9 @@ static int virtblk_probe(struct virtio_device *vdev) ...@@ -778,6 +957,9 @@ static int virtblk_probe(struct virtio_device *vdev)
sizeof(struct scatterlist) * VIRTIO_BLK_INLINE_SG_CNT; sizeof(struct scatterlist) * VIRTIO_BLK_INLINE_SG_CNT;
vblk->tag_set.driver_data = vblk; vblk->tag_set.driver_data = vblk;
vblk->tag_set.nr_hw_queues = vblk->num_vqs; vblk->tag_set.nr_hw_queues = vblk->num_vqs;
vblk->tag_set.nr_maps = 1;
if (vblk->io_queues[HCTX_TYPE_POLL])
vblk->tag_set.nr_maps = 3;
err = blk_mq_alloc_tag_set(&vblk->tag_set); err = blk_mq_alloc_tag_set(&vblk->tag_set);
if (err) if (err)
......
...@@ -90,9 +90,12 @@ static void virtio_crypto_dataq_akcipher_callback(struct virtio_crypto_request * ...@@ -90,9 +90,12 @@ static void virtio_crypto_dataq_akcipher_callback(struct virtio_crypto_request *
} }
akcipher_req = vc_akcipher_req->akcipher_req; akcipher_req = vc_akcipher_req->akcipher_req;
-	if (vc_akcipher_req->opcode != VIRTIO_CRYPTO_AKCIPHER_VERIFY)
+	if (vc_akcipher_req->opcode != VIRTIO_CRYPTO_AKCIPHER_VERIFY) {
+		/* actuall length maybe less than dst buffer */
+		akcipher_req->dst_len = len - sizeof(vc_req->status);
 		sg_copy_from_buffer(akcipher_req->dst, sg_nents(akcipher_req->dst),
 				    vc_akcipher_req->dst_buf, akcipher_req->dst_len);
+	}
virtio_crypto_akcipher_finalize_req(vc_akcipher_req, akcipher_req, error); virtio_crypto_akcipher_finalize_req(vc_akcipher_req, akcipher_req, error);
} }
@@ -103,54 +106,56 @@ static int virtio_crypto_alg_akcipher_init_session(struct virtio_crypto_akcipher
 	struct scatterlist outhdr_sg, key_sg, inhdr_sg, *sgs[3];
 	struct virtio_crypto *vcrypto = ctx->vcrypto;
 	uint8_t *pkey;
-	unsigned int inlen;
 	int err;
 	unsigned int num_out = 0, num_in = 0;
+	struct virtio_crypto_op_ctrl_req *ctrl;
+	struct virtio_crypto_session_input *input;
+	struct virtio_crypto_ctrl_request *vc_ctrl_req;
 
 	pkey = kmemdup(key, keylen, GFP_ATOMIC);
 	if (!pkey)
 		return -ENOMEM;
 
-	spin_lock(&vcrypto->ctrl_lock);
-	memcpy(&vcrypto->ctrl.header, header, sizeof(vcrypto->ctrl.header));
-	memcpy(&vcrypto->ctrl.u, para, sizeof(vcrypto->ctrl.u));
-	vcrypto->input.status = cpu_to_le32(VIRTIO_CRYPTO_ERR);
+	vc_ctrl_req = kzalloc(sizeof(*vc_ctrl_req), GFP_KERNEL);
+	if (!vc_ctrl_req) {
+		err = -ENOMEM;
+		goto out;
+	}
 
-	sg_init_one(&outhdr_sg, &vcrypto->ctrl, sizeof(vcrypto->ctrl));
+	ctrl = &vc_ctrl_req->ctrl;
+	memcpy(&ctrl->header, header, sizeof(ctrl->header));
+	memcpy(&ctrl->u, para, sizeof(ctrl->u));
+	input = &vc_ctrl_req->input;
+	input->status = cpu_to_le32(VIRTIO_CRYPTO_ERR);
+
+	sg_init_one(&outhdr_sg, ctrl, sizeof(*ctrl));
 	sgs[num_out++] = &outhdr_sg;
 
 	sg_init_one(&key_sg, pkey, keylen);
 	sgs[num_out++] = &key_sg;
 
-	sg_init_one(&inhdr_sg, &vcrypto->input, sizeof(vcrypto->input));
+	sg_init_one(&inhdr_sg, input, sizeof(*input));
 	sgs[num_out + num_in++] = &inhdr_sg;
 
-	err = virtqueue_add_sgs(vcrypto->ctrl_vq, sgs, num_out, num_in, vcrypto, GFP_ATOMIC);
+	err = virtio_crypto_ctrl_vq_request(vcrypto, sgs, num_out, num_in, vc_ctrl_req);
 	if (err < 0)
 		goto out;
 
-	virtqueue_kick(vcrypto->ctrl_vq);
-	while (!virtqueue_get_buf(vcrypto->ctrl_vq, &inlen) &&
-	       !virtqueue_is_broken(vcrypto->ctrl_vq))
-		cpu_relax();
-
-	if (le32_to_cpu(vcrypto->input.status) != VIRTIO_CRYPTO_OK) {
+	if (le32_to_cpu(input->status) != VIRTIO_CRYPTO_OK) {
+		pr_err("virtio_crypto: Create session failed status: %u\n",
+		       le32_to_cpu(input->status));
 		err = -EINVAL;
 		goto out;
 	}
 
-	ctx->session_id = le64_to_cpu(vcrypto->input.session_id);
+	ctx->session_id = le64_to_cpu(input->session_id);
 	ctx->session_valid = true;
 	err = 0;
 
 out:
-	spin_unlock(&vcrypto->ctrl_lock);
+	kfree(vc_ctrl_req);
 	kfree_sensitive(pkey);
 
-	if (err < 0)
-		pr_err("virtio_crypto: Create session failed status: %u\n",
-		       le32_to_cpu(vcrypto->input.status));
-
 	return err;
 }
@@ -159,37 +164,41 @@ static int virtio_crypto_alg_akcipher_close_session(struct virtio_crypto_akciphe
 	struct scatterlist outhdr_sg, inhdr_sg, *sgs[2];
 	struct virtio_crypto_destroy_session_req *destroy_session;
 	struct virtio_crypto *vcrypto = ctx->vcrypto;
-	unsigned int num_out = 0, num_in = 0, inlen;
+	unsigned int num_out = 0, num_in = 0;
 	int err;
+	struct virtio_crypto_op_ctrl_req *ctrl;
+	struct virtio_crypto_inhdr *ctrl_status;
+	struct virtio_crypto_ctrl_request *vc_ctrl_req;
 
-	spin_lock(&vcrypto->ctrl_lock);
-	if (!ctx->session_valid) {
-		err = 0;
-		goto out;
-	}
-	vcrypto->ctrl_status.status = VIRTIO_CRYPTO_ERR;
-	vcrypto->ctrl.header.opcode = cpu_to_le32(VIRTIO_CRYPTO_AKCIPHER_DESTROY_SESSION);
-	vcrypto->ctrl.header.queue_id = 0;
+	if (!ctx->session_valid)
+		return 0;
+
+	vc_ctrl_req = kzalloc(sizeof(*vc_ctrl_req), GFP_KERNEL);
+	if (!vc_ctrl_req)
+		return -ENOMEM;
+
+	ctrl_status = &vc_ctrl_req->ctrl_status;
+	ctrl_status->status = VIRTIO_CRYPTO_ERR;
+	ctrl = &vc_ctrl_req->ctrl;
+	ctrl->header.opcode = cpu_to_le32(VIRTIO_CRYPTO_AKCIPHER_DESTROY_SESSION);
+	ctrl->header.queue_id = 0;
 
-	destroy_session = &vcrypto->ctrl.u.destroy_session;
+	destroy_session = &ctrl->u.destroy_session;
 	destroy_session->session_id = cpu_to_le64(ctx->session_id);
 
-	sg_init_one(&outhdr_sg, &vcrypto->ctrl, sizeof(vcrypto->ctrl));
+	sg_init_one(&outhdr_sg, ctrl, sizeof(*ctrl));
 	sgs[num_out++] = &outhdr_sg;
 
-	sg_init_one(&inhdr_sg, &vcrypto->ctrl_status.status, sizeof(vcrypto->ctrl_status.status));
+	sg_init_one(&inhdr_sg, &ctrl_status->status, sizeof(ctrl_status->status));
 	sgs[num_out + num_in++] = &inhdr_sg;
 
-	err = virtqueue_add_sgs(vcrypto->ctrl_vq, sgs, num_out, num_in, vcrypto, GFP_ATOMIC);
+	err = virtio_crypto_ctrl_vq_request(vcrypto, sgs, num_out, num_in, vc_ctrl_req);
 	if (err < 0)
 		goto out;
 
-	virtqueue_kick(vcrypto->ctrl_vq);
-	while (!virtqueue_get_buf(vcrypto->ctrl_vq, &inlen) &&
-	       !virtqueue_is_broken(vcrypto->ctrl_vq))
-		cpu_relax();
-
-	if (vcrypto->ctrl_status.status != VIRTIO_CRYPTO_OK) {
+	if (ctrl_status->status != VIRTIO_CRYPTO_OK) {
+		pr_err("virtio_crypto: Close session failed status: %u, session_id: 0x%llx\n",
+		       ctrl_status->status, destroy_session->session_id);
 		err = -EINVAL;
 		goto out;
 	}
@@ -198,11 +207,7 @@ static int virtio_crypto_alg_akcipher_close_session(struct virtio_crypto_akciphe
 	ctx->session_valid = false;
 
 out:
-	spin_unlock(&vcrypto->ctrl_lock);
-	if (err < 0) {
-		pr_err("virtio_crypto: Close session failed status: %u, session_id: 0x%llx\n",
-		       vcrypto->ctrl_status.status, destroy_session->session_id);
-	}
+	kfree(vc_ctrl_req);
 
 	return err;
 }
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
#include <crypto/aead.h> #include <crypto/aead.h>
#include <crypto/aes.h> #include <crypto/aes.h>
#include <crypto/engine.h> #include <crypto/engine.h>
#include <uapi/linux/virtio_crypto.h>
/* Internal representation of a data virtqueue */ /* Internal representation of a data virtqueue */
@@ -65,11 +66,6 @@ struct virtio_crypto {
 	/* Maximum size of per request */
 	u64 max_size;
 
-	/* Control VQ buffers: protected by the ctrl_lock */
-	struct virtio_crypto_op_ctrl_req ctrl;
-	struct virtio_crypto_session_input input;
-	struct virtio_crypto_inhdr ctrl_status;
-
 	unsigned long status;
 	atomic_t ref_count;
 	struct list_head list;
...@@ -85,6 +81,18 @@ struct virtio_crypto_sym_session_info { ...@@ -85,6 +81,18 @@ struct virtio_crypto_sym_session_info {
__u64 session_id; __u64 session_id;
}; };
/*
* Note: there are padding fields in request, clear them to zero before
* sending to host to avoid to divulge any information.
* Ex, virtio_crypto_ctrl_request::ctrl::u::destroy_session::padding[48]
*/
struct virtio_crypto_ctrl_request {
struct virtio_crypto_op_ctrl_req ctrl;
struct virtio_crypto_session_input input;
struct virtio_crypto_inhdr ctrl_status;
struct completion compl;
};
struct virtio_crypto_request; struct virtio_crypto_request;
typedef void (*virtio_crypto_data_callback) typedef void (*virtio_crypto_data_callback)
(struct virtio_crypto_request *vc_req, int len); (struct virtio_crypto_request *vc_req, int len);
...@@ -134,5 +142,8 @@ int virtio_crypto_skcipher_algs_register(struct virtio_crypto *vcrypto); ...@@ -134,5 +142,8 @@ int virtio_crypto_skcipher_algs_register(struct virtio_crypto *vcrypto);
void virtio_crypto_skcipher_algs_unregister(struct virtio_crypto *vcrypto); void virtio_crypto_skcipher_algs_unregister(struct virtio_crypto *vcrypto);
int virtio_crypto_akcipher_algs_register(struct virtio_crypto *vcrypto); int virtio_crypto_akcipher_algs_register(struct virtio_crypto *vcrypto);
void virtio_crypto_akcipher_algs_unregister(struct virtio_crypto *vcrypto); void virtio_crypto_akcipher_algs_unregister(struct virtio_crypto *vcrypto);
int virtio_crypto_ctrl_vq_request(struct virtio_crypto *vcrypto, struct scatterlist *sgs[],
unsigned int out_sgs, unsigned int in_sgs,
struct virtio_crypto_ctrl_request *vc_ctrl_req);
#endif /* _VIRTIO_CRYPTO_COMMON_H */ #endif /* _VIRTIO_CRYPTO_COMMON_H */
...@@ -22,6 +22,56 @@ virtcrypto_clear_request(struct virtio_crypto_request *vc_req) ...@@ -22,6 +22,56 @@ virtcrypto_clear_request(struct virtio_crypto_request *vc_req)
} }
} }
static void virtio_crypto_ctrlq_callback(struct virtio_crypto_ctrl_request *vc_ctrl_req)
{
complete(&vc_ctrl_req->compl);
}
static void virtcrypto_ctrlq_callback(struct virtqueue *vq)
{
struct virtio_crypto *vcrypto = vq->vdev->priv;
struct virtio_crypto_ctrl_request *vc_ctrl_req;
unsigned long flags;
unsigned int len;
spin_lock_irqsave(&vcrypto->ctrl_lock, flags);
do {
virtqueue_disable_cb(vq);
while ((vc_ctrl_req = virtqueue_get_buf(vq, &len)) != NULL) {
spin_unlock_irqrestore(&vcrypto->ctrl_lock, flags);
virtio_crypto_ctrlq_callback(vc_ctrl_req);
spin_lock_irqsave(&vcrypto->ctrl_lock, flags);
}
if (unlikely(virtqueue_is_broken(vq)))
break;
} while (!virtqueue_enable_cb(vq));
spin_unlock_irqrestore(&vcrypto->ctrl_lock, flags);
}
int virtio_crypto_ctrl_vq_request(struct virtio_crypto *vcrypto, struct scatterlist *sgs[],
unsigned int out_sgs, unsigned int in_sgs,
struct virtio_crypto_ctrl_request *vc_ctrl_req)
{
int err;
unsigned long flags;
init_completion(&vc_ctrl_req->compl);
spin_lock_irqsave(&vcrypto->ctrl_lock, flags);
err = virtqueue_add_sgs(vcrypto->ctrl_vq, sgs, out_sgs, in_sgs, vc_ctrl_req, GFP_ATOMIC);
if (err < 0) {
spin_unlock_irqrestore(&vcrypto->ctrl_lock, flags);
return err;
}
virtqueue_kick(vcrypto->ctrl_vq);
spin_unlock_irqrestore(&vcrypto->ctrl_lock, flags);
wait_for_completion(&vc_ctrl_req->compl);
return 0;
}
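For reference, a condensed caller sketch distilled from the session-destroy paths converted later in this series (the function name is hypothetical): every control operation now owns a private, heap-allocated virtio_crypto_ctrl_request instead of the old per-device buffers guarded by ctrl_lock and a cpu_relax() busy-wait, so control requests may sleep and can be issued concurrently.

	static int example_destroy_session(struct virtio_crypto *vcrypto, __le64 session_id)
	{
		struct virtio_crypto_ctrl_request *vc_ctrl_req;
		struct scatterlist outhdr_sg, status_sg, *sgs[2];
		unsigned int num_out = 0, num_in = 0;
		int err;

		vc_ctrl_req = kzalloc(sizeof(*vc_ctrl_req), GFP_KERNEL);
		if (!vc_ctrl_req)
			return -ENOMEM;

		/* build the command; expect an error status until the device overwrites it */
		vc_ctrl_req->ctrl_status.status = VIRTIO_CRYPTO_ERR;
		vc_ctrl_req->ctrl.header.opcode = cpu_to_le32(VIRTIO_CRYPTO_CIPHER_DESTROY_SESSION);
		vc_ctrl_req->ctrl.u.destroy_session.session_id = session_id;

		sg_init_one(&outhdr_sg, &vc_ctrl_req->ctrl, sizeof(vc_ctrl_req->ctrl));
		sgs[num_out++] = &outhdr_sg;
		sg_init_one(&status_sg, &vc_ctrl_req->ctrl_status.status,
			    sizeof(vc_ctrl_req->ctrl_status.status));
		sgs[num_out + num_in++] = &status_sg;

		/* enqueue, kick and sleep on vc_ctrl_req->compl until the ctrlq callback fires */
		err = virtio_crypto_ctrl_vq_request(vcrypto, sgs, num_out, num_in, vc_ctrl_req);
		if (!err && vc_ctrl_req->ctrl_status.status != VIRTIO_CRYPTO_OK)
			err = -EINVAL;

		kfree(vc_ctrl_req);
		return err;
	}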
static void virtcrypto_dataq_callback(struct virtqueue *vq) static void virtcrypto_dataq_callback(struct virtqueue *vq)
{ {
struct virtio_crypto *vcrypto = vq->vdev->priv; struct virtio_crypto *vcrypto = vq->vdev->priv;
...@@ -73,7 +123,7 @@ static int virtcrypto_find_vqs(struct virtio_crypto *vi) ...@@ -73,7 +123,7 @@ static int virtcrypto_find_vqs(struct virtio_crypto *vi)
goto err_names; goto err_names;
/* Parameters for control virtqueue */ /* Parameters for control virtqueue */
-	callbacks[total_vqs - 1] = NULL;
+	callbacks[total_vqs - 1] = virtcrypto_ctrlq_callback;
names[total_vqs - 1] = "controlq"; names[total_vqs - 1] = "controlq";
/* Allocate/initialize parameters for data virtqueues */ /* Allocate/initialize parameters for data virtqueues */
...@@ -94,7 +144,8 @@ static int virtcrypto_find_vqs(struct virtio_crypto *vi) ...@@ -94,7 +144,8 @@ static int virtcrypto_find_vqs(struct virtio_crypto *vi)
spin_lock_init(&vi->data_vq[i].lock); spin_lock_init(&vi->data_vq[i].lock);
vi->data_vq[i].vq = vqs[i]; vi->data_vq[i].vq = vqs[i];
/* Initialize crypto engine */ /* Initialize crypto engine */
-		vi->data_vq[i].engine = crypto_engine_alloc_init(dev, 1);
+		vi->data_vq[i].engine = crypto_engine_alloc_init_and_set(dev, true, NULL, true,
+						virtqueue_get_vring_size(vqs[i]));
if (!vi->data_vq[i].engine) { if (!vi->data_vq[i].engine) {
ret = -ENOMEM; ret = -ENOMEM;
goto err_engine; goto err_engine;
......
@@ -118,11 +118,14 @@ static int virtio_crypto_alg_skcipher_init_session(
 		int encrypt)
 {
 	struct scatterlist outhdr, key_sg, inhdr, *sgs[3];
-	unsigned int tmp;
 	struct virtio_crypto *vcrypto = ctx->vcrypto;
 	int op = encrypt ? VIRTIO_CRYPTO_OP_ENCRYPT : VIRTIO_CRYPTO_OP_DECRYPT;
 	int err;
 	unsigned int num_out = 0, num_in = 0;
+	struct virtio_crypto_op_ctrl_req *ctrl;
+	struct virtio_crypto_session_input *input;
+	struct virtio_crypto_sym_create_session_req *sym_create_session;
+	struct virtio_crypto_ctrl_request *vc_ctrl_req;
 
 	/*
 	 * Avoid to do DMA from the stack, switch to using
@@ -133,26 +136,29 @@ static int virtio_crypto_alg_skcipher_init_session(
 	if (!cipher_key)
 		return -ENOMEM;
 
-	spin_lock(&vcrypto->ctrl_lock);
+	vc_ctrl_req = kzalloc(sizeof(*vc_ctrl_req), GFP_KERNEL);
+	if (!vc_ctrl_req) {
+		err = -ENOMEM;
+		goto out;
+	}
+
 	/* Pad ctrl header */
-	vcrypto->ctrl.header.opcode =
-		cpu_to_le32(VIRTIO_CRYPTO_CIPHER_CREATE_SESSION);
-	vcrypto->ctrl.header.algo = cpu_to_le32(alg);
+	ctrl = &vc_ctrl_req->ctrl;
+	ctrl->header.opcode = cpu_to_le32(VIRTIO_CRYPTO_CIPHER_CREATE_SESSION);
+	ctrl->header.algo = cpu_to_le32(alg);
 	/* Set the default dataqueue id to 0 */
-	vcrypto->ctrl.header.queue_id = 0;
+	ctrl->header.queue_id = 0;
 
-	vcrypto->input.status = cpu_to_le32(VIRTIO_CRYPTO_ERR);
+	input = &vc_ctrl_req->input;
+	input->status = cpu_to_le32(VIRTIO_CRYPTO_ERR);
+
 	/* Pad cipher's parameters */
-	vcrypto->ctrl.u.sym_create_session.op_type =
-		cpu_to_le32(VIRTIO_CRYPTO_SYM_OP_CIPHER);
-	vcrypto->ctrl.u.sym_create_session.u.cipher.para.algo =
-		vcrypto->ctrl.header.algo;
-	vcrypto->ctrl.u.sym_create_session.u.cipher.para.keylen =
-		cpu_to_le32(keylen);
-	vcrypto->ctrl.u.sym_create_session.u.cipher.para.op =
-		cpu_to_le32(op);
+	sym_create_session = &ctrl->u.sym_create_session;
+	sym_create_session->op_type = cpu_to_le32(VIRTIO_CRYPTO_SYM_OP_CIPHER);
+	sym_create_session->u.cipher.para.algo = ctrl->header.algo;
+	sym_create_session->u.cipher.para.keylen = cpu_to_le32(keylen);
+	sym_create_session->u.cipher.para.op = cpu_to_le32(op);
 
-	sg_init_one(&outhdr, &vcrypto->ctrl, sizeof(vcrypto->ctrl));
+	sg_init_one(&outhdr, ctrl, sizeof(*ctrl));
 	sgs[num_out++] = &outhdr;
 
 	/* Set key */
@@ -160,45 +166,30 @@ static int virtio_crypto_alg_skcipher_init_session(
 	sgs[num_out++] = &key_sg;
 
 	/* Return status and session id back */
-	sg_init_one(&inhdr, &vcrypto->input, sizeof(vcrypto->input));
+	sg_init_one(&inhdr, input, sizeof(*input));
 	sgs[num_out + num_in++] = &inhdr;
 
-	err = virtqueue_add_sgs(vcrypto->ctrl_vq, sgs, num_out,
-				num_in, vcrypto, GFP_ATOMIC);
-	if (err < 0) {
-		spin_unlock(&vcrypto->ctrl_lock);
-		kfree_sensitive(cipher_key);
-		return err;
-	}
-	virtqueue_kick(vcrypto->ctrl_vq);
+	err = virtio_crypto_ctrl_vq_request(vcrypto, sgs, num_out, num_in, vc_ctrl_req);
+	if (err < 0)
+		goto out;
 
-	/*
-	 * Trapping into the hypervisor, so the request should be
-	 * handled immediately.
-	 */
-	while (!virtqueue_get_buf(vcrypto->ctrl_vq, &tmp) &&
-	       !virtqueue_is_broken(vcrypto->ctrl_vq))
-		cpu_relax();
-
-	if (le32_to_cpu(vcrypto->input.status) != VIRTIO_CRYPTO_OK) {
-		spin_unlock(&vcrypto->ctrl_lock);
+	if (le32_to_cpu(input->status) != VIRTIO_CRYPTO_OK) {
 		pr_err("virtio_crypto: Create session failed status: %u\n",
-		       le32_to_cpu(vcrypto->input.status));
-		kfree_sensitive(cipher_key);
-		return -EINVAL;
+		       le32_to_cpu(input->status));
+		err = -EINVAL;
+		goto out;
 	}
 
 	if (encrypt)
-		ctx->enc_sess_info.session_id =
-			le64_to_cpu(vcrypto->input.session_id);
+		ctx->enc_sess_info.session_id = le64_to_cpu(input->session_id);
 	else
-		ctx->dec_sess_info.session_id =
-			le64_to_cpu(vcrypto->input.session_id);
-
-	spin_unlock(&vcrypto->ctrl_lock);
+		ctx->dec_sess_info.session_id = le64_to_cpu(input->session_id);
 
+	err = 0;
+out:
+	kfree(vc_ctrl_req);
 	kfree_sensitive(cipher_key);
 
-	return 0;
+	return err;
 }
static int virtio_crypto_alg_skcipher_close_session( static int virtio_crypto_alg_skcipher_close_session(
@@ -206,60 +197,55 @@ static int virtio_crypto_alg_skcipher_close_session(
 		int encrypt)
 {
 	struct scatterlist outhdr, status_sg, *sgs[2];
-	unsigned int tmp;
 	struct virtio_crypto_destroy_session_req *destroy_session;
 	struct virtio_crypto *vcrypto = ctx->vcrypto;
 	int err;
 	unsigned int num_out = 0, num_in = 0;
+	struct virtio_crypto_op_ctrl_req *ctrl;
+	struct virtio_crypto_inhdr *ctrl_status;
+	struct virtio_crypto_ctrl_request *vc_ctrl_req;
 
-	spin_lock(&vcrypto->ctrl_lock);
-	vcrypto->ctrl_status.status = VIRTIO_CRYPTO_ERR;
+	vc_ctrl_req = kzalloc(sizeof(*vc_ctrl_req), GFP_KERNEL);
+	if (!vc_ctrl_req)
+		return -ENOMEM;
+
+	ctrl_status = &vc_ctrl_req->ctrl_status;
+	ctrl_status->status = VIRTIO_CRYPTO_ERR;
 	/* Pad ctrl header */
-	vcrypto->ctrl.header.opcode =
-		cpu_to_le32(VIRTIO_CRYPTO_CIPHER_DESTROY_SESSION);
+	ctrl = &vc_ctrl_req->ctrl;
+	ctrl->header.opcode = cpu_to_le32(VIRTIO_CRYPTO_CIPHER_DESTROY_SESSION);
 	/* Set the default virtqueue id to 0 */
-	vcrypto->ctrl.header.queue_id = 0;
+	ctrl->header.queue_id = 0;
 
-	destroy_session = &vcrypto->ctrl.u.destroy_session;
+	destroy_session = &ctrl->u.destroy_session;
 	if (encrypt)
-		destroy_session->session_id =
-			cpu_to_le64(ctx->enc_sess_info.session_id);
+		destroy_session->session_id = cpu_to_le64(ctx->enc_sess_info.session_id);
 	else
-		destroy_session->session_id =
-			cpu_to_le64(ctx->dec_sess_info.session_id);
+		destroy_session->session_id = cpu_to_le64(ctx->dec_sess_info.session_id);
 
-	sg_init_one(&outhdr, &vcrypto->ctrl, sizeof(vcrypto->ctrl));
+	sg_init_one(&outhdr, ctrl, sizeof(*ctrl));
 	sgs[num_out++] = &outhdr;
 
 	/* Return status and session id back */
-	sg_init_one(&status_sg, &vcrypto->ctrl_status.status,
-		    sizeof(vcrypto->ctrl_status.status));
+	sg_init_one(&status_sg, &ctrl_status->status, sizeof(ctrl_status->status));
 	sgs[num_out + num_in++] = &status_sg;
 
-	err = virtqueue_add_sgs(vcrypto->ctrl_vq, sgs, num_out,
-				num_in, vcrypto, GFP_ATOMIC);
-	if (err < 0) {
-		spin_unlock(&vcrypto->ctrl_lock);
-		return err;
-	}
-	virtqueue_kick(vcrypto->ctrl_vq);
-
-	while (!virtqueue_get_buf(vcrypto->ctrl_vq, &tmp) &&
-	       !virtqueue_is_broken(vcrypto->ctrl_vq))
-		cpu_relax();
+	err = virtio_crypto_ctrl_vq_request(vcrypto, sgs, num_out, num_in, vc_ctrl_req);
+	if (err < 0)
+		goto out;
 
-	if (vcrypto->ctrl_status.status != VIRTIO_CRYPTO_OK) {
-		spin_unlock(&vcrypto->ctrl_lock);
+	if (ctrl_status->status != VIRTIO_CRYPTO_OK) {
 		pr_err("virtio_crypto: Close session failed status: %u, session_id: 0x%llx\n",
-		       vcrypto->ctrl_status.status,
-		       destroy_session->session_id);
+		       ctrl_status->status, destroy_session->session_id);
 
 		return -EINVAL;
 	}
 
-	spin_unlock(&vcrypto->ctrl_lock);
-	return 0;
+	err = 0;
+out:
+	kfree(vc_ctrl_req);
+
+	return err;
 }
static int virtio_crypto_alg_skcipher_init_sessions( static int virtio_crypto_alg_skcipher_init_sessions(
......
...@@ -62,6 +62,7 @@ struct virtio_ccw_device { ...@@ -62,6 +62,7 @@ struct virtio_ccw_device {
unsigned int revision; /* Transport revision */ unsigned int revision; /* Transport revision */
wait_queue_head_t wait_q; wait_queue_head_t wait_q;
spinlock_t lock; spinlock_t lock;
rwlock_t irq_lock;
struct mutex io_lock; /* Serializes I/O requests */ struct mutex io_lock; /* Serializes I/O requests */
struct list_head virtqueues; struct list_head virtqueues;
bool is_thinint; bool is_thinint;
...@@ -970,6 +971,10 @@ static void virtio_ccw_set_status(struct virtio_device *vdev, u8 status) ...@@ -970,6 +971,10 @@ static void virtio_ccw_set_status(struct virtio_device *vdev, u8 status)
ccw->flags = 0; ccw->flags = 0;
ccw->count = sizeof(status); ccw->count = sizeof(status);
ccw->cda = (__u32)(unsigned long)&vcdev->dma_area->status; ccw->cda = (__u32)(unsigned long)&vcdev->dma_area->status;
/* We use ssch for setting the status which is a serializing
* instruction that guarantees the memory writes have
* completed before ssch.
*/
ret = ccw_io_helper(vcdev, ccw, VIRTIO_CCW_DOING_WRITE_STATUS); ret = ccw_io_helper(vcdev, ccw, VIRTIO_CCW_DOING_WRITE_STATUS);
/* Write failed? We assume status is unchanged. */ /* Write failed? We assume status is unchanged. */
if (ret) if (ret)
...@@ -984,6 +989,30 @@ static const char *virtio_ccw_bus_name(struct virtio_device *vdev) ...@@ -984,6 +989,30 @@ static const char *virtio_ccw_bus_name(struct virtio_device *vdev)
return dev_name(&vcdev->cdev->dev); return dev_name(&vcdev->cdev->dev);
} }
static void virtio_ccw_synchronize_cbs(struct virtio_device *vdev)
{
struct virtio_ccw_device *vcdev = to_vc_device(vdev);
struct airq_info *info = vcdev->airq_info;
if (info) {
/*
* This device uses adapter interrupts: synchronize with
* vring_interrupt() called by virtio_airq_handler()
* via the indicator area lock.
*/
write_lock_irq(&info->lock);
write_unlock_irq(&info->lock);
} else {
/* This device uses classic interrupts: synchronize
* with vring_interrupt() called by
* virtio_ccw_int_handler() via the per-device
* irq_lock
*/
write_lock_irq(&vcdev->irq_lock);
write_unlock_irq(&vcdev->irq_lock);
}
}
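The empty lock/unlock pair works because vring_interrupt() for this device is only ever called under the corresponding read lock (the airq indicator lock, or the per-device irq_lock added to virtio_ccw_int_handler() further down), so taking and releasing the write side is a barrier that waits for every in-flight callback. A generic sketch of the idiom, not new driver code:

	static DEFINE_RWLOCK(cb_lock);

	static void callback_path(void)			/* e.g. the interrupt handler */
	{
		read_lock(&cb_lock);
		/* ... call vring_interrupt() / the virtqueue callback ... */
		read_unlock(&cb_lock);
	}

	static void synchronize_callbacks(void)		/* e.g. ->synchronize_cbs() */
	{
		/* returns only after all current callback_path() readers have unlocked */
		write_lock_irq(&cb_lock);
		write_unlock_irq(&cb_lock);
	}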
static const struct virtio_config_ops virtio_ccw_config_ops = { static const struct virtio_config_ops virtio_ccw_config_ops = {
.get_features = virtio_ccw_get_features, .get_features = virtio_ccw_get_features,
.finalize_features = virtio_ccw_finalize_features, .finalize_features = virtio_ccw_finalize_features,
...@@ -995,6 +1024,7 @@ static const struct virtio_config_ops virtio_ccw_config_ops = { ...@@ -995,6 +1024,7 @@ static const struct virtio_config_ops virtio_ccw_config_ops = {
.find_vqs = virtio_ccw_find_vqs, .find_vqs = virtio_ccw_find_vqs,
.del_vqs = virtio_ccw_del_vqs, .del_vqs = virtio_ccw_del_vqs,
.bus_name = virtio_ccw_bus_name, .bus_name = virtio_ccw_bus_name,
.synchronize_cbs = virtio_ccw_synchronize_cbs,
}; };
...@@ -1106,6 +1136,8 @@ static void virtio_ccw_int_handler(struct ccw_device *cdev, ...@@ -1106,6 +1136,8 @@ static void virtio_ccw_int_handler(struct ccw_device *cdev,
vcdev->err = -EIO; vcdev->err = -EIO;
} }
virtio_ccw_check_activity(vcdev, activity); virtio_ccw_check_activity(vcdev, activity);
/* Interrupts are disabled here */
read_lock(&vcdev->irq_lock);
for_each_set_bit(i, indicators(vcdev), for_each_set_bit(i, indicators(vcdev),
sizeof(*indicators(vcdev)) * BITS_PER_BYTE) { sizeof(*indicators(vcdev)) * BITS_PER_BYTE) {
/* The bit clear must happen before the vring kick. */ /* The bit clear must happen before the vring kick. */
...@@ -1114,6 +1146,7 @@ static void virtio_ccw_int_handler(struct ccw_device *cdev, ...@@ -1114,6 +1146,7 @@ static void virtio_ccw_int_handler(struct ccw_device *cdev,
vq = virtio_ccw_vq_by_ind(vcdev, i); vq = virtio_ccw_vq_by_ind(vcdev, i);
vring_interrupt(0, vq); vring_interrupt(0, vq);
} }
read_unlock(&vcdev->irq_lock);
if (test_bit(0, indicators2(vcdev))) { if (test_bit(0, indicators2(vcdev))) {
virtio_config_changed(&vcdev->vdev); virtio_config_changed(&vcdev->vdev);
clear_bit(0, indicators2(vcdev)); clear_bit(0, indicators2(vcdev));
...@@ -1284,6 +1317,7 @@ static int virtio_ccw_online(struct ccw_device *cdev) ...@@ -1284,6 +1317,7 @@ static int virtio_ccw_online(struct ccw_device *cdev)
init_waitqueue_head(&vcdev->wait_q); init_waitqueue_head(&vcdev->wait_q);
INIT_LIST_HEAD(&vcdev->virtqueues); INIT_LIST_HEAD(&vcdev->virtqueues);
spin_lock_init(&vcdev->lock); spin_lock_init(&vcdev->lock);
rwlock_init(&vcdev->irq_lock);
mutex_init(&vcdev->io_lock); mutex_init(&vcdev->io_lock);
spin_lock_irqsave(get_ccwdev_lock(cdev), flags); spin_lock_irqsave(get_ccwdev_lock(cdev), flags);
......
...@@ -470,7 +470,7 @@ static int eni_vdpa_probe(struct pci_dev *pdev, const struct pci_device_id *id) ...@@ -470,7 +470,7 @@ static int eni_vdpa_probe(struct pci_dev *pdev, const struct pci_device_id *id)
return ret; return ret;
 	eni_vdpa = vdpa_alloc_device(struct eni_vdpa, vdpa,
-				     dev, &eni_vdpa_ops, NULL, false);
+				     dev, &eni_vdpa_ops, 1, 1, NULL, false);
if (IS_ERR(eni_vdpa)) { if (IS_ERR(eni_vdpa)) {
ENI_ERR(pdev, "failed to allocate vDPA structure\n"); ENI_ERR(pdev, "failed to allocate vDPA structure\n");
return PTR_ERR(eni_vdpa); return PTR_ERR(eni_vdpa);
......
@@ -290,16 +290,16 @@ static int ifcvf_request_config_irq(struct ifcvf_adapter *adapter)
 	struct ifcvf_hw *vf = &adapter->vf;
 	int config_vector, ret;
 
-	if (vf->msix_vector_status == MSIX_VECTOR_DEV_SHARED)
-		return 0;
-
 	if (vf->msix_vector_status == MSIX_VECTOR_PER_VQ_AND_CONFIG)
-		/* vector 0 ~ vf->nr_vring for vqs, num vf->nr_vring vector for config interrupt */
 		config_vector = vf->nr_vring;
-
-	if (vf->msix_vector_status == MSIX_VECTOR_SHARED_VQ_AND_CONFIG)
+	else if (vf->msix_vector_status == MSIX_VECTOR_SHARED_VQ_AND_CONFIG)
 		/* vector 0 for vqs and 1 for config interrupt */
 		config_vector = 1;
+	else if (vf->msix_vector_status == MSIX_VECTOR_DEV_SHARED)
+		/* re-use the vqs vector */
+		return 0;
+	else
+		return -EINVAL;
 
 	snprintf(vf->config_msix_name, 256, "ifcvf[%s]-config\n",
 		 pci_name(pdev));
...@@ -626,6 +626,11 @@ static size_t ifcvf_vdpa_get_config_size(struct vdpa_device *vdpa_dev) ...@@ -626,6 +626,11 @@ static size_t ifcvf_vdpa_get_config_size(struct vdpa_device *vdpa_dev)
return vf->config_size; return vf->config_size;
} }
static u32 ifcvf_vdpa_get_vq_group(struct vdpa_device *vdpa, u16 idx)
{
return 0;
}
static void ifcvf_vdpa_get_config(struct vdpa_device *vdpa_dev, static void ifcvf_vdpa_get_config(struct vdpa_device *vdpa_dev,
unsigned int offset, unsigned int offset,
void *buf, unsigned int len) void *buf, unsigned int len)
...@@ -704,6 +709,7 @@ static const struct vdpa_config_ops ifc_vdpa_ops = { ...@@ -704,6 +709,7 @@ static const struct vdpa_config_ops ifc_vdpa_ops = {
.get_device_id = ifcvf_vdpa_get_device_id, .get_device_id = ifcvf_vdpa_get_device_id,
.get_vendor_id = ifcvf_vdpa_get_vendor_id, .get_vendor_id = ifcvf_vdpa_get_vendor_id,
.get_vq_align = ifcvf_vdpa_get_vq_align, .get_vq_align = ifcvf_vdpa_get_vq_align,
.get_vq_group = ifcvf_vdpa_get_vq_group,
.get_config_size = ifcvf_vdpa_get_config_size, .get_config_size = ifcvf_vdpa_get_config_size,
.get_config = ifcvf_vdpa_get_config, .get_config = ifcvf_vdpa_get_config,
.set_config = ifcvf_vdpa_set_config, .set_config = ifcvf_vdpa_set_config,
...@@ -758,14 +764,13 @@ static int ifcvf_vdpa_dev_add(struct vdpa_mgmt_dev *mdev, const char *name, ...@@ -758,14 +764,13 @@ static int ifcvf_vdpa_dev_add(struct vdpa_mgmt_dev *mdev, const char *name,
pdev = ifcvf_mgmt_dev->pdev; pdev = ifcvf_mgmt_dev->pdev;
dev = &pdev->dev; dev = &pdev->dev;
 	adapter = vdpa_alloc_device(struct ifcvf_adapter, vdpa,
-				    dev, &ifc_vdpa_ops, name, false);
+				    dev, &ifc_vdpa_ops, 1, 1, name, false);
if (IS_ERR(adapter)) { if (IS_ERR(adapter)) {
IFCVF_ERR(pdev, "Failed to allocate vDPA structure"); IFCVF_ERR(pdev, "Failed to allocate vDPA structure");
return PTR_ERR(adapter); return PTR_ERR(adapter);
} }
 	ifcvf_mgmt_dev->adapter = adapter;
-	pci_set_drvdata(pdev, ifcvf_mgmt_dev);
 
 	vf = &adapter->vf;
vf->dev_type = get_dev_type(pdev); vf->dev_type = get_dev_type(pdev);
...@@ -880,6 +885,8 @@ static int ifcvf_probe(struct pci_dev *pdev, const struct pci_device_id *id) ...@@ -880,6 +885,8 @@ static int ifcvf_probe(struct pci_dev *pdev, const struct pci_device_id *id)
goto err; goto err;
} }
pci_set_drvdata(pdev, ifcvf_mgmt_dev);
return 0; return 0;
err: err:
......
...@@ -61,6 +61,8 @@ struct mlx5_control_vq { ...@@ -61,6 +61,8 @@ struct mlx5_control_vq {
struct vringh_kiov riov; struct vringh_kiov riov;
struct vringh_kiov wiov; struct vringh_kiov wiov;
unsigned short head; unsigned short head;
unsigned int received_desc;
unsigned int completed_desc;
}; };
struct mlx5_vdpa_wq_ent { struct mlx5_vdpa_wq_ent {
......
(two further file diffs are collapsed and not shown in this view)
@@ -96,11 +96,17 @@ static void vdpasim_do_reset(struct vdpasim *vdpasim)
 {
 	int i;
 
-	for (i = 0; i < vdpasim->dev_attr.nvqs; i++)
+	spin_lock(&vdpasim->iommu_lock);
+
+	for (i = 0; i < vdpasim->dev_attr.nvqs; i++) {
 		vdpasim_vq_reset(vdpasim, &vdpasim->vqs[i]);
+		vringh_set_iotlb(&vdpasim->vqs[i].vring, &vdpasim->iommu[0],
+				 &vdpasim->iommu_lock);
+	}
+
+	for (i = 0; i < vdpasim->dev_attr.nas; i++)
+		vhost_iotlb_reset(&vdpasim->iommu[i]);
 
-	spin_lock(&vdpasim->iommu_lock);
-	vhost_iotlb_reset(vdpasim->iommu);
 	spin_unlock(&vdpasim->iommu_lock);
 
 	vdpasim->features = 0;
...@@ -145,7 +151,7 @@ static dma_addr_t vdpasim_map_range(struct vdpasim *vdpasim, phys_addr_t paddr, ...@@ -145,7 +151,7 @@ static dma_addr_t vdpasim_map_range(struct vdpasim *vdpasim, phys_addr_t paddr,
dma_addr = iova_dma_addr(&vdpasim->iova, iova); dma_addr = iova_dma_addr(&vdpasim->iova, iova);
spin_lock(&vdpasim->iommu_lock); spin_lock(&vdpasim->iommu_lock);
-	ret = vhost_iotlb_add_range(vdpasim->iommu, (u64)dma_addr,
+	ret = vhost_iotlb_add_range(&vdpasim->iommu[0], (u64)dma_addr,
 				    (u64)dma_addr + size - 1, (u64)paddr, perm);
spin_unlock(&vdpasim->iommu_lock); spin_unlock(&vdpasim->iommu_lock);
...@@ -161,7 +167,7 @@ static void vdpasim_unmap_range(struct vdpasim *vdpasim, dma_addr_t dma_addr, ...@@ -161,7 +167,7 @@ static void vdpasim_unmap_range(struct vdpasim *vdpasim, dma_addr_t dma_addr,
size_t size) size_t size)
{ {
spin_lock(&vdpasim->iommu_lock); spin_lock(&vdpasim->iommu_lock);
vhost_iotlb_del_range(vdpasim->iommu, (u64)dma_addr, vhost_iotlb_del_range(&vdpasim->iommu[0], (u64)dma_addr,
(u64)dma_addr + size - 1); (u64)dma_addr + size - 1);
spin_unlock(&vdpasim->iommu_lock); spin_unlock(&vdpasim->iommu_lock);
...@@ -251,6 +257,7 @@ struct vdpasim *vdpasim_create(struct vdpasim_dev_attr *dev_attr) ...@@ -251,6 +257,7 @@ struct vdpasim *vdpasim_create(struct vdpasim_dev_attr *dev_attr)
ops = &vdpasim_config_ops; ops = &vdpasim_config_ops;
vdpasim = vdpa_alloc_device(struct vdpasim, vdpa, NULL, ops, vdpasim = vdpa_alloc_device(struct vdpasim, vdpa, NULL, ops,
dev_attr->ngroups, dev_attr->nas,
dev_attr->name, false); dev_attr->name, false);
if (IS_ERR(vdpasim)) { if (IS_ERR(vdpasim)) {
ret = PTR_ERR(vdpasim); ret = PTR_ERR(vdpasim);
@@ -278,16 +285,20 @@ struct vdpasim *vdpasim_create(struct vdpasim_dev_attr *dev_attr)
 	if (!vdpasim->vqs)
 		goto err_iommu;
 
-	vdpasim->iommu = vhost_iotlb_alloc(max_iotlb_entries, 0);
+	vdpasim->iommu = kmalloc_array(vdpasim->dev_attr.nas,
+				       sizeof(*vdpasim->iommu), GFP_KERNEL);
 	if (!vdpasim->iommu)
 		goto err_iommu;
 
+	for (i = 0; i < vdpasim->dev_attr.nas; i++)
+		vhost_iotlb_init(&vdpasim->iommu[i], 0, 0);
+
 	vdpasim->buffer = kvmalloc(dev_attr->buffer_size, GFP_KERNEL);
 	if (!vdpasim->buffer)
 		goto err_iommu;
 
 	for (i = 0; i < dev_attr->nvqs; i++)
-		vringh_set_iotlb(&vdpasim->vqs[i].vring, vdpasim->iommu,
+		vringh_set_iotlb(&vdpasim->vqs[i].vring, &vdpasim->iommu[0],
 				 &vdpasim->iommu_lock);
ret = iova_cache_get(); ret = iova_cache_get();
@@ -353,11 +364,14 @@ static void vdpasim_set_vq_ready(struct vdpa_device *vdpa, u16 idx, bool ready)
 {
 	struct vdpasim *vdpasim = vdpa_to_sim(vdpa);
 	struct vdpasim_virtqueue *vq = &vdpasim->vqs[idx];
+	bool old_ready;
 
 	spin_lock(&vdpasim->lock);
+	old_ready = vq->ready;
 	vq->ready = ready;
-	if (vq->ready)
+	if (vq->ready && !old_ready) {
 		vdpasim_queue_ready(vdpasim, idx);
+	}
 	spin_unlock(&vdpasim->lock);
 }
...@@ -399,6 +413,15 @@ static u32 vdpasim_get_vq_align(struct vdpa_device *vdpa) ...@@ -399,6 +413,15 @@ static u32 vdpasim_get_vq_align(struct vdpa_device *vdpa)
return VDPASIM_QUEUE_ALIGN; return VDPASIM_QUEUE_ALIGN;
} }
static u32 vdpasim_get_vq_group(struct vdpa_device *vdpa, u16 idx)
{
/* RX and TX belongs to group 0, CVQ belongs to group 1 */
if (idx == 2)
return 1;
else
return 0;
}
static u64 vdpasim_get_device_features(struct vdpa_device *vdpa) static u64 vdpasim_get_device_features(struct vdpa_device *vdpa)
{ {
struct vdpasim *vdpasim = vdpa_to_sim(vdpa); struct vdpasim *vdpasim = vdpa_to_sim(vdpa);
@@ -534,20 +557,53 @@ static struct vdpa_iova_range vdpasim_get_iova_range(struct vdpa_device *vdpa)
 	return range;
 }
 
-static int vdpasim_set_map(struct vdpa_device *vdpa,
+static int vdpasim_set_group_asid(struct vdpa_device *vdpa, unsigned int group,
+				  unsigned int asid)
+{
+	struct vdpasim *vdpasim = vdpa_to_sim(vdpa);
+	struct vhost_iotlb *iommu;
+	int i;
+
+	if (group > vdpasim->dev_attr.ngroups)
+		return -EINVAL;
+
+	if (asid >= vdpasim->dev_attr.nas)
+		return -EINVAL;
+
+	iommu = &vdpasim->iommu[asid];
+
+	spin_lock(&vdpasim->lock);
+
+	for (i = 0; i < vdpasim->dev_attr.nvqs; i++)
+		if (vdpasim_get_vq_group(vdpa, i) == group)
+			vringh_set_iotlb(&vdpasim->vqs[i].vring, iommu,
+					 &vdpasim->iommu_lock);
+
+	spin_unlock(&vdpasim->lock);
+
+	return 0;
+}
+
+static int vdpasim_set_map(struct vdpa_device *vdpa, unsigned int asid,
 			   struct vhost_iotlb *iotlb)
 {
 	struct vdpasim *vdpasim = vdpa_to_sim(vdpa);
 	struct vhost_iotlb_map *map;
+	struct vhost_iotlb *iommu;
 	u64 start = 0ULL, last = 0ULL - 1;
 	int ret;
 
+	if (asid >= vdpasim->dev_attr.nas)
+		return -EINVAL;
+
 	spin_lock(&vdpasim->iommu_lock);
-	vhost_iotlb_reset(vdpasim->iommu);
+
+	iommu = &vdpasim->iommu[asid];
+	vhost_iotlb_reset(iommu);
 
 	for (map = vhost_iotlb_itree_first(iotlb, start, last); map;
 	     map = vhost_iotlb_itree_next(map, start, last)) {
-		ret = vhost_iotlb_add_range(vdpasim->iommu, map->start,
+		ret = vhost_iotlb_add_range(iommu, map->start,
 					    map->last, map->addr, map->perm);
 		if (ret)
 			goto err;
@@ -556,31 +612,39 @@ static int vdpasim_set_map(struct vdpa_device *vdpa,
 	return 0;
 
 err:
-	vhost_iotlb_reset(vdpasim->iommu);
+	vhost_iotlb_reset(iommu);
 	spin_unlock(&vdpasim->iommu_lock);
 	return ret;
 }
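Together with vdpasim_get_vq_group() above, the net simulator places the data virtqueues (0 and 1) in group 0 and the control virtqueue (2) in group 1, and each group can be routed to its own address space. An illustrative call sequence from a parent driver's point of view (hypothetical helper, not from the patch):

	static void example_isolate_cvq(struct vdpa_device *vdpa)
	{
		const struct vdpa_config_ops *ops = vdpa->config;

		/* data vqs 0-1 sit in group 0, the control vq 2 in group 1 */
		WARN_ON(ops->get_vq_group(vdpa, 0) != 0);
		WARN_ON(ops->get_vq_group(vdpa, 2) != 1);

		/* route group 1 (the CVQ) to address space 1; group 0 stays on ASID 0 */
		ops->set_group_asid(vdpa, 1, 1);
	}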
-static int vdpasim_dma_map(struct vdpa_device *vdpa, u64 iova, u64 size,
+static int vdpasim_dma_map(struct vdpa_device *vdpa, unsigned int asid,
+			   u64 iova, u64 size,
 			   u64 pa, u32 perm, void *opaque)
 {
 	struct vdpasim *vdpasim = vdpa_to_sim(vdpa);
 	int ret;
 
+	if (asid >= vdpasim->dev_attr.nas)
+		return -EINVAL;
+
 	spin_lock(&vdpasim->iommu_lock);
-	ret = vhost_iotlb_add_range_ctx(vdpasim->iommu, iova, iova + size - 1,
-					pa, perm, opaque);
+	ret = vhost_iotlb_add_range_ctx(&vdpasim->iommu[asid], iova,
+					iova + size - 1, pa, perm, opaque);
 	spin_unlock(&vdpasim->iommu_lock);
 
 	return ret;
 }
 
-static int vdpasim_dma_unmap(struct vdpa_device *vdpa, u64 iova, u64 size)
+static int vdpasim_dma_unmap(struct vdpa_device *vdpa, unsigned int asid,
+			     u64 iova, u64 size)
 {
 	struct vdpasim *vdpasim = vdpa_to_sim(vdpa);
 
+	if (asid >= vdpasim->dev_attr.nas)
+		return -EINVAL;
+
 	spin_lock(&vdpasim->iommu_lock);
-	vhost_iotlb_del_range(vdpasim->iommu, iova, iova + size - 1);
+	vhost_iotlb_del_range(&vdpasim->iommu[asid], iova, iova + size - 1);
 	spin_unlock(&vdpasim->iommu_lock);
 
 	return 0;
...@@ -604,8 +668,7 @@ static void vdpasim_free(struct vdpa_device *vdpa) ...@@ -604,8 +668,7 @@ static void vdpasim_free(struct vdpa_device *vdpa)
} }
kvfree(vdpasim->buffer); kvfree(vdpasim->buffer);
-	if (vdpasim->iommu)
-		vhost_iotlb_free(vdpasim->iommu);
+	vhost_iotlb_free(vdpasim->iommu);
kfree(vdpasim->vqs); kfree(vdpasim->vqs);
kfree(vdpasim->config); kfree(vdpasim->config);
} }
...@@ -620,6 +683,7 @@ static const struct vdpa_config_ops vdpasim_config_ops = { ...@@ -620,6 +683,7 @@ static const struct vdpa_config_ops vdpasim_config_ops = {
.set_vq_state = vdpasim_set_vq_state, .set_vq_state = vdpasim_set_vq_state,
.get_vq_state = vdpasim_get_vq_state, .get_vq_state = vdpasim_get_vq_state,
.get_vq_align = vdpasim_get_vq_align, .get_vq_align = vdpasim_get_vq_align,
.get_vq_group = vdpasim_get_vq_group,
.get_device_features = vdpasim_get_device_features, .get_device_features = vdpasim_get_device_features,
.set_driver_features = vdpasim_set_driver_features, .set_driver_features = vdpasim_set_driver_features,
.get_driver_features = vdpasim_get_driver_features, .get_driver_features = vdpasim_get_driver_features,
...@@ -635,6 +699,7 @@ static const struct vdpa_config_ops vdpasim_config_ops = { ...@@ -635,6 +699,7 @@ static const struct vdpa_config_ops vdpasim_config_ops = {
.set_config = vdpasim_set_config, .set_config = vdpasim_set_config,
.get_generation = vdpasim_get_generation, .get_generation = vdpasim_get_generation,
.get_iova_range = vdpasim_get_iova_range, .get_iova_range = vdpasim_get_iova_range,
.set_group_asid = vdpasim_set_group_asid,
.dma_map = vdpasim_dma_map, .dma_map = vdpasim_dma_map,
.dma_unmap = vdpasim_dma_unmap, .dma_unmap = vdpasim_dma_unmap,
.free = vdpasim_free, .free = vdpasim_free,
...@@ -650,6 +715,7 @@ static const struct vdpa_config_ops vdpasim_batch_config_ops = { ...@@ -650,6 +715,7 @@ static const struct vdpa_config_ops vdpasim_batch_config_ops = {
.set_vq_state = vdpasim_set_vq_state, .set_vq_state = vdpasim_set_vq_state,
.get_vq_state = vdpasim_get_vq_state, .get_vq_state = vdpasim_get_vq_state,
.get_vq_align = vdpasim_get_vq_align, .get_vq_align = vdpasim_get_vq_align,
.get_vq_group = vdpasim_get_vq_group,
.get_device_features = vdpasim_get_device_features, .get_device_features = vdpasim_get_device_features,
.set_driver_features = vdpasim_set_driver_features, .set_driver_features = vdpasim_set_driver_features,
.get_driver_features = vdpasim_get_driver_features, .get_driver_features = vdpasim_get_driver_features,
...@@ -665,6 +731,7 @@ static const struct vdpa_config_ops vdpasim_batch_config_ops = { ...@@ -665,6 +731,7 @@ static const struct vdpa_config_ops vdpasim_batch_config_ops = {
.set_config = vdpasim_set_config, .set_config = vdpasim_set_config,
.get_generation = vdpasim_get_generation, .get_generation = vdpasim_get_generation,
.get_iova_range = vdpasim_get_iova_range, .get_iova_range = vdpasim_get_iova_range,
.set_group_asid = vdpasim_set_group_asid,
.set_map = vdpasim_set_map, .set_map = vdpasim_set_map,
.free = vdpasim_free, .free = vdpasim_free,
}; };
......
...@@ -41,6 +41,8 @@ struct vdpasim_dev_attr { ...@@ -41,6 +41,8 @@ struct vdpasim_dev_attr {
size_t buffer_size; size_t buffer_size;
int nvqs; int nvqs;
u32 id; u32 id;
u32 ngroups;
u32 nas;
work_func_t work_fn; work_func_t work_fn;
void (*get_config)(struct vdpasim *vdpasim, void *config); void (*get_config)(struct vdpasim *vdpasim, void *config);
...@@ -63,6 +65,7 @@ struct vdpasim { ...@@ -63,6 +65,7 @@ struct vdpasim {
u32 status; u32 status;
u32 generation; u32 generation;
u64 features; u64 features;
u32 groups;
/* spinlock to synchronize iommu table */ /* spinlock to synchronize iommu table */
spinlock_t iommu_lock; spinlock_t iommu_lock;
}; };
......
@@ -26,9 +26,122 @@
 #define DRV_LICENSE  "GPL v2"
 
 #define VDPASIM_NET_FEATURES	(VDPASIM_FEATURES | \
-				 (1ULL << VIRTIO_NET_F_MAC))
+				 (1ULL << VIRTIO_NET_F_MAC) | \
+				 (1ULL << VIRTIO_NET_F_MTU) | \
+				 (1ULL << VIRTIO_NET_F_CTRL_VQ) | \
+				 (1ULL << VIRTIO_NET_F_CTRL_MAC_ADDR))
 
-#define VDPASIM_NET_VQ_NUM	2
+/* 3 virtqueues, 2 address spaces, 2 virtqueue groups */
+#define VDPASIM_NET_VQ_NUM	3
+#define VDPASIM_NET_AS_NUM	2
+#define VDPASIM_NET_GROUP_NUM	2
static void vdpasim_net_complete(struct vdpasim_virtqueue *vq, size_t len)
{
/* Make sure data is wrote before advancing index */
smp_wmb();
vringh_complete_iotlb(&vq->vring, vq->head, len);
/* Make sure used is visible before rasing the interrupt. */
smp_wmb();
local_bh_disable();
if (vringh_need_notify_iotlb(&vq->vring) > 0)
vringh_notify(&vq->vring);
local_bh_enable();
}
static bool receive_filter(struct vdpasim *vdpasim, size_t len)
{
bool modern = vdpasim->features & (1ULL << VIRTIO_F_VERSION_1);
size_t hdr_len = modern ? sizeof(struct virtio_net_hdr_v1) :
sizeof(struct virtio_net_hdr);
struct virtio_net_config *vio_config = vdpasim->config;
if (len < ETH_ALEN + hdr_len)
return false;
if (!strncmp(vdpasim->buffer + hdr_len, vio_config->mac, ETH_ALEN))
return true;
return false;
}
static virtio_net_ctrl_ack vdpasim_handle_ctrl_mac(struct vdpasim *vdpasim,
u8 cmd)
{
struct virtio_net_config *vio_config = vdpasim->config;
struct vdpasim_virtqueue *cvq = &vdpasim->vqs[2];
virtio_net_ctrl_ack status = VIRTIO_NET_ERR;
size_t read;
switch (cmd) {
case VIRTIO_NET_CTRL_MAC_ADDR_SET:
read = vringh_iov_pull_iotlb(&cvq->vring, &cvq->in_iov,
vio_config->mac, ETH_ALEN);
if (read == ETH_ALEN)
status = VIRTIO_NET_OK;
break;
default:
break;
}
return status;
}
static void vdpasim_handle_cvq(struct vdpasim *vdpasim)
{
struct vdpasim_virtqueue *cvq = &vdpasim->vqs[2];
virtio_net_ctrl_ack status = VIRTIO_NET_ERR;
struct virtio_net_ctrl_hdr ctrl;
size_t read, write;
int err;
if (!(vdpasim->features & (1ULL << VIRTIO_NET_F_CTRL_VQ)))
return;
if (!cvq->ready)
return;
while (true) {
err = vringh_getdesc_iotlb(&cvq->vring, &cvq->in_iov,
&cvq->out_iov,
&cvq->head, GFP_ATOMIC);
if (err <= 0)
break;
read = vringh_iov_pull_iotlb(&cvq->vring, &cvq->in_iov, &ctrl,
sizeof(ctrl));
if (read != sizeof(ctrl))
break;
switch (ctrl.class) {
case VIRTIO_NET_CTRL_MAC:
status = vdpasim_handle_ctrl_mac(vdpasim, ctrl.cmd);
break;
default:
break;
}
/* Make sure data is written before advancing index */
smp_wmb();
write = vringh_iov_push_iotlb(&cvq->vring, &cvq->out_iov,
&status, sizeof(status));
vringh_complete_iotlb(&cvq->vring, cvq->head, write);
vringh_kiov_cleanup(&cvq->in_iov);
vringh_kiov_cleanup(&cvq->out_iov);
/* Make sure used is visible before raising the interrupt. */
smp_wmb();
local_bh_disable();
if (cvq->cb)
cvq->cb(cvq->private);
local_bh_enable();
}
}
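For orientation, the control-queue exchange that vdpasim_handle_cvq() parses above follows the standard virtio-net layout: a virtio_net_ctrl_hdr driver-to-device, the command payload (here the 6-byte MAC) driver-to-device, and a single ack byte device-to-driver. On the wire these are separate descriptors; the struct below only illustrates their order, and the names are hypothetical.

/* Illustrative layout of a VIRTIO_NET_CTRL_MAC_ADDR_SET request. */
struct mac_set_request {			/* hypothetical, for illustration */
	struct virtio_net_ctrl_hdr hdr;		/* class/cmd, driver -> device */
	u8 mac[ETH_ALEN];			/* new MAC, driver -> device */
	virtio_net_ctrl_ack ack;		/* status, device -> driver */
};

static void fill_mac_set_request(struct mac_set_request *req, const u8 *mac)
{
	req->hdr.class = VIRTIO_NET_CTRL_MAC;
	req->hdr.cmd = VIRTIO_NET_CTRL_MAC_ADDR_SET;
	memcpy(req->mac, mac, ETH_ALEN);
	req->ack = VIRTIO_NET_ERR;	/* the device overwrites this with VIRTIO_NET_OK */
}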
static void vdpasim_net_work(struct work_struct *work) static void vdpasim_net_work(struct work_struct *work)
{ {
...@@ -36,7 +149,6 @@ static void vdpasim_net_work(struct work_struct *work) ...@@ -36,7 +149,6 @@ static void vdpasim_net_work(struct work_struct *work)
struct vdpasim_virtqueue *txq = &vdpasim->vqs[1]; struct vdpasim_virtqueue *txq = &vdpasim->vqs[1];
struct vdpasim_virtqueue *rxq = &vdpasim->vqs[0]; struct vdpasim_virtqueue *rxq = &vdpasim->vqs[0];
ssize_t read, write; ssize_t read, write;
size_t total_write;
int pkts = 0; int pkts = 0;
int err; int err;
...@@ -45,53 +157,40 @@ static void vdpasim_net_work(struct work_struct *work) ...@@ -45,53 +157,40 @@ static void vdpasim_net_work(struct work_struct *work)
if (!(vdpasim->status & VIRTIO_CONFIG_S_DRIVER_OK)) if (!(vdpasim->status & VIRTIO_CONFIG_S_DRIVER_OK))
goto out; goto out;
vdpasim_handle_cvq(vdpasim);
if (!txq->ready || !rxq->ready) if (!txq->ready || !rxq->ready)
goto out; goto out;
while (true) { while (true) {
total_write = 0;
err = vringh_getdesc_iotlb(&txq->vring, &txq->out_iov, NULL, err = vringh_getdesc_iotlb(&txq->vring, &txq->out_iov, NULL,
&txq->head, GFP_ATOMIC); &txq->head, GFP_ATOMIC);
if (err <= 0) if (err <= 0)
break; break;
read = vringh_iov_pull_iotlb(&txq->vring, &txq->out_iov,
vdpasim->buffer,
PAGE_SIZE);
if (!receive_filter(vdpasim, read)) {
vdpasim_net_complete(txq, 0);
continue;
}
err = vringh_getdesc_iotlb(&rxq->vring, NULL, &rxq->in_iov, err = vringh_getdesc_iotlb(&rxq->vring, NULL, &rxq->in_iov,
&rxq->head, GFP_ATOMIC); &rxq->head, GFP_ATOMIC);
if (err <= 0) { if (err <= 0) {
vringh_complete_iotlb(&txq->vring, txq->head, 0); vdpasim_net_complete(txq, 0);
break; break;
} }
while (true) { write = vringh_iov_push_iotlb(&rxq->vring, &rxq->in_iov,
read = vringh_iov_pull_iotlb(&txq->vring, &txq->out_iov, vdpasim->buffer, read);
vdpasim->buffer, if (write <= 0)
PAGE_SIZE); break;
if (read <= 0)
break;
write = vringh_iov_push_iotlb(&rxq->vring, &rxq->in_iov,
vdpasim->buffer, read);
if (write <= 0)
break;
total_write += write;
}
/* Make sure data is wrote before advancing index */
smp_wmb();
vringh_complete_iotlb(&txq->vring, txq->head, 0);
vringh_complete_iotlb(&rxq->vring, rxq->head, total_write);
/* Make sure used is visible before rasing the interrupt. */
smp_wmb();
local_bh_disable(); vdpasim_net_complete(txq, 0);
if (vringh_need_notify_iotlb(&txq->vring) > 0) vdpasim_net_complete(rxq, write);
vringh_notify(&txq->vring);
if (vringh_need_notify_iotlb(&rxq->vring) > 0)
vringh_notify(&rxq->vring);
local_bh_enable();
if (++pkts > 4) { if (++pkts > 4) {
schedule_work(&vdpasim->work); schedule_work(&vdpasim->work);
...@@ -145,6 +244,8 @@ static int vdpasim_net_dev_add(struct vdpa_mgmt_dev *mdev, const char *name, ...@@ -145,6 +244,8 @@ static int vdpasim_net_dev_add(struct vdpa_mgmt_dev *mdev, const char *name,
dev_attr.id = VIRTIO_ID_NET; dev_attr.id = VIRTIO_ID_NET;
dev_attr.supported_features = VDPASIM_NET_FEATURES; dev_attr.supported_features = VDPASIM_NET_FEATURES;
dev_attr.nvqs = VDPASIM_NET_VQ_NUM; dev_attr.nvqs = VDPASIM_NET_VQ_NUM;
dev_attr.ngroups = VDPASIM_NET_GROUP_NUM;
dev_attr.nas = VDPASIM_NET_AS_NUM;
dev_attr.config_size = sizeof(struct virtio_net_config); dev_attr.config_size = sizeof(struct virtio_net_config);
dev_attr.get_config = vdpasim_net_get_config; dev_attr.get_config = vdpasim_net_get_config;
dev_attr.work_fn = vdpasim_net_work; dev_attr.work_fn = vdpasim_net_work;
......
...@@ -693,6 +693,7 @@ static u32 vduse_vdpa_get_generation(struct vdpa_device *vdpa) ...@@ -693,6 +693,7 @@ static u32 vduse_vdpa_get_generation(struct vdpa_device *vdpa)
} }
static int vduse_vdpa_set_map(struct vdpa_device *vdpa, static int vduse_vdpa_set_map(struct vdpa_device *vdpa,
unsigned int asid,
struct vhost_iotlb *iotlb) struct vhost_iotlb *iotlb)
{ {
struct vduse_dev *dev = vdpa_to_vduse(vdpa); struct vduse_dev *dev = vdpa_to_vduse(vdpa);
...@@ -1495,7 +1496,7 @@ static int vduse_dev_init_vdpa(struct vduse_dev *dev, const char *name) ...@@ -1495,7 +1496,7 @@ static int vduse_dev_init_vdpa(struct vduse_dev *dev, const char *name)
return -EEXIST; return -EEXIST;
vdev = vdpa_alloc_device(struct vduse_vdpa, vdpa, dev->dev, vdev = vdpa_alloc_device(struct vduse_vdpa, vdpa, dev->dev,
&vduse_vdpa_config_ops, name, true); &vduse_vdpa_config_ops, 1, 1, name, true);
if (IS_ERR(vdev)) if (IS_ERR(vdev))
return PTR_ERR(vdev); return PTR_ERR(vdev);
......
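As the hunk above shows, .set_map (like the per-ASID DMA mapping ops) now carries the target address space id, and vdpa_alloc_device() takes the number of virtqueue groups and address spaces the parent exposes (VDUSE advertises one of each). A parent that keeps the old single-address-space behaviour can simply reject any non-zero asid; a hedged sketch, with the example_* names being hypothetical:

static int example_set_map(struct vdpa_device *vdpa, unsigned int asid,
			   struct vhost_iotlb *iotlb)
{
	/* Only one address space was advertised via vdpa_alloc_device(). */
	if (asid != 0)
		return -EINVAL;

	return example_apply_iotlb(vdpa, iotlb);	/* hypothetical helper */
}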
...@@ -32,7 +32,7 @@ struct vp_vring { ...@@ -32,7 +32,7 @@ struct vp_vring {
struct vp_vdpa { struct vp_vdpa {
struct vdpa_device vdpa; struct vdpa_device vdpa;
struct virtio_pci_modern_device mdev; struct virtio_pci_modern_device *mdev;
struct vp_vring *vring; struct vp_vring *vring;
struct vdpa_callback config_cb; struct vdpa_callback config_cb;
char msix_name[VP_VDPA_NAME_SIZE]; char msix_name[VP_VDPA_NAME_SIZE];
...@@ -41,6 +41,12 @@ struct vp_vdpa { ...@@ -41,6 +41,12 @@ struct vp_vdpa {
int vectors; int vectors;
}; };
struct vp_vdpa_mgmtdev {
struct vdpa_mgmt_dev mgtdev;
struct virtio_pci_modern_device *mdev;
struct vp_vdpa *vp_vdpa;
};
static struct vp_vdpa *vdpa_to_vp(struct vdpa_device *vdpa) static struct vp_vdpa *vdpa_to_vp(struct vdpa_device *vdpa)
{ {
return container_of(vdpa, struct vp_vdpa, vdpa); return container_of(vdpa, struct vp_vdpa, vdpa);
...@@ -50,7 +56,12 @@ static struct virtio_pci_modern_device *vdpa_to_mdev(struct vdpa_device *vdpa) ...@@ -50,7 +56,12 @@ static struct virtio_pci_modern_device *vdpa_to_mdev(struct vdpa_device *vdpa)
{ {
struct vp_vdpa *vp_vdpa = vdpa_to_vp(vdpa); struct vp_vdpa *vp_vdpa = vdpa_to_vp(vdpa);
return &vp_vdpa->mdev; return vp_vdpa->mdev;
}
static struct virtio_pci_modern_device *vp_vdpa_to_mdev(struct vp_vdpa *vp_vdpa)
{
return vp_vdpa->mdev;
} }
static u64 vp_vdpa_get_device_features(struct vdpa_device *vdpa) static u64 vp_vdpa_get_device_features(struct vdpa_device *vdpa)
...@@ -96,7 +107,7 @@ static int vp_vdpa_get_vq_irq(struct vdpa_device *vdpa, u16 idx) ...@@ -96,7 +107,7 @@ static int vp_vdpa_get_vq_irq(struct vdpa_device *vdpa, u16 idx)
static void vp_vdpa_free_irq(struct vp_vdpa *vp_vdpa) static void vp_vdpa_free_irq(struct vp_vdpa *vp_vdpa)
{ {
struct virtio_pci_modern_device *mdev = &vp_vdpa->mdev; struct virtio_pci_modern_device *mdev = vp_vdpa_to_mdev(vp_vdpa);
struct pci_dev *pdev = mdev->pci_dev; struct pci_dev *pdev = mdev->pci_dev;
int i; int i;
...@@ -143,7 +154,7 @@ static irqreturn_t vp_vdpa_config_handler(int irq, void *arg) ...@@ -143,7 +154,7 @@ static irqreturn_t vp_vdpa_config_handler(int irq, void *arg)
static int vp_vdpa_request_irq(struct vp_vdpa *vp_vdpa) static int vp_vdpa_request_irq(struct vp_vdpa *vp_vdpa)
{ {
struct virtio_pci_modern_device *mdev = &vp_vdpa->mdev; struct virtio_pci_modern_device *mdev = vp_vdpa_to_mdev(vp_vdpa);
struct pci_dev *pdev = mdev->pci_dev; struct pci_dev *pdev = mdev->pci_dev;
int i, ret, irq; int i, ret, irq;
int queues = vp_vdpa->queues; int queues = vp_vdpa->queues;
...@@ -198,7 +209,7 @@ static int vp_vdpa_request_irq(struct vp_vdpa *vp_vdpa) ...@@ -198,7 +209,7 @@ static int vp_vdpa_request_irq(struct vp_vdpa *vp_vdpa)
static void vp_vdpa_set_status(struct vdpa_device *vdpa, u8 status) static void vp_vdpa_set_status(struct vdpa_device *vdpa, u8 status)
{ {
struct vp_vdpa *vp_vdpa = vdpa_to_vp(vdpa); struct vp_vdpa *vp_vdpa = vdpa_to_vp(vdpa);
struct virtio_pci_modern_device *mdev = &vp_vdpa->mdev; struct virtio_pci_modern_device *mdev = vp_vdpa_to_mdev(vp_vdpa);
u8 s = vp_vdpa_get_status(vdpa); u8 s = vp_vdpa_get_status(vdpa);
if (status & VIRTIO_CONFIG_S_DRIVER_OK && if (status & VIRTIO_CONFIG_S_DRIVER_OK &&
...@@ -212,7 +223,7 @@ static void vp_vdpa_set_status(struct vdpa_device *vdpa, u8 status) ...@@ -212,7 +223,7 @@ static void vp_vdpa_set_status(struct vdpa_device *vdpa, u8 status)
static int vp_vdpa_reset(struct vdpa_device *vdpa) static int vp_vdpa_reset(struct vdpa_device *vdpa)
{ {
struct vp_vdpa *vp_vdpa = vdpa_to_vp(vdpa); struct vp_vdpa *vp_vdpa = vdpa_to_vp(vdpa);
struct virtio_pci_modern_device *mdev = &vp_vdpa->mdev; struct virtio_pci_modern_device *mdev = vp_vdpa_to_mdev(vp_vdpa);
u8 s = vp_vdpa_get_status(vdpa); u8 s = vp_vdpa_get_status(vdpa);
vp_modern_set_status(mdev, 0); vp_modern_set_status(mdev, 0);
...@@ -372,7 +383,7 @@ static void vp_vdpa_get_config(struct vdpa_device *vdpa, ...@@ -372,7 +383,7 @@ static void vp_vdpa_get_config(struct vdpa_device *vdpa,
void *buf, unsigned int len) void *buf, unsigned int len)
{ {
struct vp_vdpa *vp_vdpa = vdpa_to_vp(vdpa); struct vp_vdpa *vp_vdpa = vdpa_to_vp(vdpa);
struct virtio_pci_modern_device *mdev = &vp_vdpa->mdev; struct virtio_pci_modern_device *mdev = vp_vdpa_to_mdev(vp_vdpa);
u8 old, new; u8 old, new;
u8 *p; u8 *p;
int i; int i;
...@@ -392,7 +403,7 @@ static void vp_vdpa_set_config(struct vdpa_device *vdpa, ...@@ -392,7 +403,7 @@ static void vp_vdpa_set_config(struct vdpa_device *vdpa,
unsigned int len) unsigned int len)
{ {
struct vp_vdpa *vp_vdpa = vdpa_to_vp(vdpa); struct vp_vdpa *vp_vdpa = vdpa_to_vp(vdpa);
struct virtio_pci_modern_device *mdev = &vp_vdpa->mdev; struct virtio_pci_modern_device *mdev = vp_vdpa_to_mdev(vp_vdpa);
const u8 *p = buf; const u8 *p = buf;
int i; int i;
...@@ -412,7 +423,7 @@ static struct vdpa_notification_area ...@@ -412,7 +423,7 @@ static struct vdpa_notification_area
vp_vdpa_get_vq_notification(struct vdpa_device *vdpa, u16 qid) vp_vdpa_get_vq_notification(struct vdpa_device *vdpa, u16 qid)
{ {
struct vp_vdpa *vp_vdpa = vdpa_to_vp(vdpa); struct vp_vdpa *vp_vdpa = vdpa_to_vp(vdpa);
struct virtio_pci_modern_device *mdev = &vp_vdpa->mdev; struct virtio_pci_modern_device *mdev = vp_vdpa_to_mdev(vp_vdpa);
struct vdpa_notification_area notify; struct vdpa_notification_area notify;
notify.addr = vp_vdpa->vring[qid].notify_pa; notify.addr = vp_vdpa->vring[qid].notify_pa;
...@@ -454,38 +465,31 @@ static void vp_vdpa_free_irq_vectors(void *data) ...@@ -454,38 +465,31 @@ static void vp_vdpa_free_irq_vectors(void *data)
pci_free_irq_vectors(data); pci_free_irq_vectors(data);
} }
static int vp_vdpa_probe(struct pci_dev *pdev, const struct pci_device_id *id) static int vp_vdpa_dev_add(struct vdpa_mgmt_dev *v_mdev, const char *name,
const struct vdpa_dev_set_config *add_config)
{ {
struct virtio_pci_modern_device *mdev; struct vp_vdpa_mgmtdev *vp_vdpa_mgtdev =
container_of(v_mdev, struct vp_vdpa_mgmtdev, mgtdev);
struct virtio_pci_modern_device *mdev = vp_vdpa_mgtdev->mdev;
struct pci_dev *pdev = mdev->pci_dev;
struct device *dev = &pdev->dev; struct device *dev = &pdev->dev;
struct vp_vdpa *vp_vdpa; struct vp_vdpa *vp_vdpa = NULL;
int ret, i; int ret, i;
ret = pcim_enable_device(pdev);
if (ret)
return ret;
vp_vdpa = vdpa_alloc_device(struct vp_vdpa, vdpa, vp_vdpa = vdpa_alloc_device(struct vp_vdpa, vdpa,
dev, &vp_vdpa_ops, NULL, false); dev, &vp_vdpa_ops, 1, 1, name, false);
if (IS_ERR(vp_vdpa)) { if (IS_ERR(vp_vdpa)) {
dev_err(dev, "vp_vdpa: Failed to allocate vDPA structure\n"); dev_err(dev, "vp_vdpa: Failed to allocate vDPA structure\n");
return PTR_ERR(vp_vdpa); return PTR_ERR(vp_vdpa);
} }
mdev = &vp_vdpa->mdev; vp_vdpa_mgtdev->vp_vdpa = vp_vdpa;
mdev->pci_dev = pdev;
ret = vp_modern_probe(mdev);
if (ret) {
dev_err(&pdev->dev, "Failed to probe modern PCI device\n");
goto err;
}
pci_set_master(pdev);
pci_set_drvdata(pdev, vp_vdpa);
vp_vdpa->vdpa.dma_dev = &pdev->dev; vp_vdpa->vdpa.dma_dev = &pdev->dev;
vp_vdpa->queues = vp_modern_get_num_queues(mdev); vp_vdpa->queues = vp_modern_get_num_queues(mdev);
vp_vdpa->mdev = mdev;
ret = devm_add_action_or_reset(dev, vp_vdpa_free_irq_vectors, pdev); ret = devm_add_action_or_reset(dev, vp_vdpa_free_irq_vectors, pdev);
if (ret) { if (ret) {
...@@ -516,7 +520,8 @@ static int vp_vdpa_probe(struct pci_dev *pdev, const struct pci_device_id *id) ...@@ -516,7 +520,8 @@ static int vp_vdpa_probe(struct pci_dev *pdev, const struct pci_device_id *id)
} }
vp_vdpa->config_irq = VIRTIO_MSI_NO_VECTOR; vp_vdpa->config_irq = VIRTIO_MSI_NO_VECTOR;
ret = vdpa_register_device(&vp_vdpa->vdpa, vp_vdpa->queues); vp_vdpa->vdpa.mdev = &vp_vdpa_mgtdev->mgtdev;
ret = _vdpa_register_device(&vp_vdpa->vdpa, vp_vdpa->queues);
if (ret) { if (ret) {
dev_err(&pdev->dev, "Failed to register to vdpa bus\n"); dev_err(&pdev->dev, "Failed to register to vdpa bus\n");
goto err; goto err;
...@@ -529,12 +534,104 @@ static int vp_vdpa_probe(struct pci_dev *pdev, const struct pci_device_id *id) ...@@ -529,12 +534,104 @@ static int vp_vdpa_probe(struct pci_dev *pdev, const struct pci_device_id *id)
return ret; return ret;
} }
static void vp_vdpa_dev_del(struct vdpa_mgmt_dev *v_mdev,
struct vdpa_device *dev)
{
struct vp_vdpa_mgmtdev *vp_vdpa_mgtdev =
container_of(v_mdev, struct vp_vdpa_mgmtdev, mgtdev);
struct vp_vdpa *vp_vdpa = vp_vdpa_mgtdev->vp_vdpa;
_vdpa_unregister_device(&vp_vdpa->vdpa);
vp_vdpa_mgtdev->vp_vdpa = NULL;
}
static const struct vdpa_mgmtdev_ops vp_vdpa_mdev_ops = {
.dev_add = vp_vdpa_dev_add,
.dev_del = vp_vdpa_dev_del,
};
static int vp_vdpa_probe(struct pci_dev *pdev, const struct pci_device_id *id)
{
struct vp_vdpa_mgmtdev *vp_vdpa_mgtdev = NULL;
struct vdpa_mgmt_dev *mgtdev;
struct device *dev = &pdev->dev;
struct virtio_pci_modern_device *mdev = NULL;
struct virtio_device_id *mdev_id = NULL;
int err;
vp_vdpa_mgtdev = kzalloc(sizeof(*vp_vdpa_mgtdev), GFP_KERNEL);
if (!vp_vdpa_mgtdev)
return -ENOMEM;
mgtdev = &vp_vdpa_mgtdev->mgtdev;
mgtdev->ops = &vp_vdpa_mdev_ops;
mgtdev->device = dev;
mdev = kzalloc(sizeof(struct virtio_pci_modern_device), GFP_KERNEL);
if (!mdev) {
err = -ENOMEM;
goto mdev_err;
}
mdev_id = kzalloc(sizeof(struct virtio_device_id), GFP_KERNEL);
if (!mdev_id) {
err = -ENOMEM;
goto mdev_id_err;
}
vp_vdpa_mgtdev->mdev = mdev;
mdev->pci_dev = pdev;
err = pcim_enable_device(pdev);
if (err) {
goto probe_err;
}
err = vp_modern_probe(mdev);
if (err) {
dev_err(&pdev->dev, "Failed to probe modern PCI device\n");
goto probe_err;
}
mdev_id->device = mdev->id.device;
mdev_id->vendor = mdev->id.vendor;
mgtdev->id_table = mdev_id;
mgtdev->max_supported_vqs = vp_modern_get_num_queues(mdev);
mgtdev->supported_features = vp_modern_get_features(mdev);
pci_set_master(pdev);
pci_set_drvdata(pdev, vp_vdpa_mgtdev);
err = vdpa_mgmtdev_register(mgtdev);
if (err) {
dev_err(&pdev->dev, "Failed to register vdpa mgmtdev device\n");
goto register_err;
}
return 0;
register_err:
vp_modern_remove(vp_vdpa_mgtdev->mdev);
probe_err:
kfree(mdev_id);
mdev_id_err:
kfree(mdev);
mdev_err:
kfree(vp_vdpa_mgtdev);
return err;
}
static void vp_vdpa_remove(struct pci_dev *pdev) static void vp_vdpa_remove(struct pci_dev *pdev)
{ {
struct vp_vdpa *vp_vdpa = pci_get_drvdata(pdev); struct vp_vdpa_mgmtdev *vp_vdpa_mgtdev = pci_get_drvdata(pdev);
struct virtio_pci_modern_device *mdev = NULL;
vp_modern_remove(&vp_vdpa->mdev); mdev = vp_vdpa_mgtdev->mdev;
vdpa_unregister_device(&vp_vdpa->vdpa); vp_modern_remove(mdev);
vdpa_mgmtdev_unregister(&vp_vdpa_mgtdev->mgtdev);
kfree(vp_vdpa_mgtdev->mgtdev.id_table);
kfree(mdev);
kfree(vp_vdpa_mgtdev);
} }
static struct pci_driver vp_vdpa_driver = { static struct pci_driver vp_vdpa_driver = {
......
...@@ -125,6 +125,23 @@ void vhost_iotlb_del_range(struct vhost_iotlb *iotlb, u64 start, u64 last) ...@@ -125,6 +125,23 @@ void vhost_iotlb_del_range(struct vhost_iotlb *iotlb, u64 start, u64 last)
} }
EXPORT_SYMBOL_GPL(vhost_iotlb_del_range); EXPORT_SYMBOL_GPL(vhost_iotlb_del_range);
/**
* vhost_iotlb_init - initialize a vhost IOTLB
* @iotlb: the IOTLB that needs to be initialized
* @limit: maximum number of IOTLB entries
* @flags: VHOST_IOTLB_FLAG_XXX
*/
void vhost_iotlb_init(struct vhost_iotlb *iotlb, unsigned int limit,
unsigned int flags)
{
iotlb->root = RB_ROOT_CACHED;
iotlb->limit = limit;
iotlb->nmaps = 0;
iotlb->flags = flags;
INIT_LIST_HEAD(&iotlb->list);
}
EXPORT_SYMBOL_GPL(vhost_iotlb_init);
/** /**
* vhost_iotlb_alloc - add a new vhost IOTLB * vhost_iotlb_alloc - add a new vhost IOTLB
* @limit: maximum number of IOTLB entries * @limit: maximum number of IOTLB entries
...@@ -139,11 +156,7 @@ struct vhost_iotlb *vhost_iotlb_alloc(unsigned int limit, unsigned int flags) ...@@ -139,11 +156,7 @@ struct vhost_iotlb *vhost_iotlb_alloc(unsigned int limit, unsigned int flags)
if (!iotlb) if (!iotlb)
return NULL; return NULL;
iotlb->root = RB_ROOT_CACHED; vhost_iotlb_init(iotlb, limit, flags);
iotlb->limit = limit;
iotlb->nmaps = 0;
iotlb->flags = flags;
INIT_LIST_HEAD(&iotlb->list);
return iotlb; return iotlb;
} }
......
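Splitting the initialisation out of vhost_iotlb_alloc() lets a user embed an array of IOTLBs (for instance one per address space) in its own structure instead of allocating each one separately. A short sketch of that pattern, with the example_* names being hypothetical:

struct example_parent {
	unsigned int nas;
	struct vhost_iotlb *iotlb;	/* array of @nas embedded IOTLBs */
};

static int example_parent_init_iotlbs(struct example_parent *p,
				      unsigned int nas, unsigned int limit)
{
	unsigned int i;

	p->iotlb = kcalloc(nas, sizeof(*p->iotlb), GFP_KERNEL);
	if (!p->iotlb)
		return -ENOMEM;

	p->nas = nas;
	for (i = 0; i < nas; i++)
		vhost_iotlb_init(&p->iotlb[i], limit, 0);

	return 0;
}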
...@@ -1374,16 +1374,9 @@ static void vhost_net_stop(struct vhost_net *n, struct socket **tx_sock, ...@@ -1374,16 +1374,9 @@ static void vhost_net_stop(struct vhost_net *n, struct socket **tx_sock,
*rx_sock = vhost_net_stop_vq(n, &n->vqs[VHOST_NET_VQ_RX].vq); *rx_sock = vhost_net_stop_vq(n, &n->vqs[VHOST_NET_VQ_RX].vq);
} }
static void vhost_net_flush_vq(struct vhost_net *n, int index)
{
vhost_poll_flush(n->poll + index);
vhost_poll_flush(&n->vqs[index].vq.poll);
}
static void vhost_net_flush(struct vhost_net *n) static void vhost_net_flush(struct vhost_net *n)
{ {
vhost_net_flush_vq(n, VHOST_NET_VQ_TX); vhost_dev_flush(&n->dev);
vhost_net_flush_vq(n, VHOST_NET_VQ_RX);
if (n->vqs[VHOST_NET_VQ_TX].ubufs) { if (n->vqs[VHOST_NET_VQ_TX].ubufs) {
mutex_lock(&n->vqs[VHOST_NET_VQ_TX].vq.mutex); mutex_lock(&n->vqs[VHOST_NET_VQ_TX].vq.mutex);
n->tx_flush = true; n->tx_flush = true;
...@@ -1572,7 +1565,7 @@ static long vhost_net_set_backend(struct vhost_net *n, unsigned index, int fd) ...@@ -1572,7 +1565,7 @@ static long vhost_net_set_backend(struct vhost_net *n, unsigned index, int fd)
} }
if (oldsock) { if (oldsock) {
vhost_net_flush_vq(n, index); vhost_dev_flush(&n->dev);
sockfd_put(oldsock); sockfd_put(oldsock);
} }
......
...@@ -1436,7 +1436,7 @@ static void vhost_scsi_flush(struct vhost_scsi *vs) ...@@ -1436,7 +1436,7 @@ static void vhost_scsi_flush(struct vhost_scsi *vs)
kref_put(&old_inflight[i]->kref, vhost_scsi_done_inflight); kref_put(&old_inflight[i]->kref, vhost_scsi_done_inflight);
/* Flush both the vhost poll and vhost work */ /* Flush both the vhost poll and vhost work */
vhost_work_dev_flush(&vs->dev); vhost_dev_flush(&vs->dev);
/* Wait for all reqs issued before the flush to be finished */ /* Wait for all reqs issued before the flush to be finished */
for (i = 0; i < VHOST_SCSI_MAX_VQ; i++) for (i = 0; i < VHOST_SCSI_MAX_VQ; i++)
...@@ -1827,8 +1827,6 @@ static int vhost_scsi_release(struct inode *inode, struct file *f) ...@@ -1827,8 +1827,6 @@ static int vhost_scsi_release(struct inode *inode, struct file *f)
vhost_scsi_clear_endpoint(vs, &t); vhost_scsi_clear_endpoint(vs, &t);
vhost_dev_stop(&vs->dev); vhost_dev_stop(&vs->dev);
vhost_dev_cleanup(&vs->dev); vhost_dev_cleanup(&vs->dev);
/* Jobs can re-queue themselves in evt kick handler. Do extra flush. */
vhost_scsi_flush(vs);
kfree(vs->dev.vqs); kfree(vs->dev.vqs);
kvfree(vs); kvfree(vs);
return 0; return 0;
......
...@@ -144,14 +144,9 @@ static void vhost_test_stop(struct vhost_test *n, void **privatep) ...@@ -144,14 +144,9 @@ static void vhost_test_stop(struct vhost_test *n, void **privatep)
*privatep = vhost_test_stop_vq(n, n->vqs + VHOST_TEST_VQ); *privatep = vhost_test_stop_vq(n, n->vqs + VHOST_TEST_VQ);
} }
static void vhost_test_flush_vq(struct vhost_test *n, int index)
{
vhost_poll_flush(&n->vqs[index].poll);
}
static void vhost_test_flush(struct vhost_test *n) static void vhost_test_flush(struct vhost_test *n)
{ {
vhost_test_flush_vq(n, VHOST_TEST_VQ); vhost_dev_flush(&n->dev);
} }
static int vhost_test_release(struct inode *inode, struct file *f) static int vhost_test_release(struct inode *inode, struct file *f)
...@@ -163,9 +158,6 @@ static int vhost_test_release(struct inode *inode, struct file *f) ...@@ -163,9 +158,6 @@ static int vhost_test_release(struct inode *inode, struct file *f)
vhost_test_flush(n); vhost_test_flush(n);
vhost_dev_stop(&n->dev); vhost_dev_stop(&n->dev);
vhost_dev_cleanup(&n->dev); vhost_dev_cleanup(&n->dev);
/* We do an extra flush before freeing memory,
* since jobs can re-queue themselves. */
vhost_test_flush(n);
kfree(n->dev.vqs); kfree(n->dev.vqs);
kfree(n); kfree(n);
return 0; return 0;
...@@ -210,7 +202,7 @@ static long vhost_test_run(struct vhost_test *n, int test) ...@@ -210,7 +202,7 @@ static long vhost_test_run(struct vhost_test *n, int test)
goto err; goto err;
if (oldpriv) { if (oldpriv) {
vhost_test_flush_vq(n, index); vhost_test_flush(n);
} }
} }
...@@ -303,7 +295,7 @@ static long vhost_test_set_backend(struct vhost_test *n, unsigned index, int fd) ...@@ -303,7 +295,7 @@ static long vhost_test_set_backend(struct vhost_test *n, unsigned index, int fd)
mutex_unlock(&vq->mutex); mutex_unlock(&vq->mutex);
if (enable) { if (enable) {
vhost_test_flush_vq(n, index); vhost_test_flush(n);
} }
mutex_unlock(&n->dev.mutex); mutex_unlock(&n->dev.mutex);
......
...@@ -231,7 +231,7 @@ void vhost_poll_stop(struct vhost_poll *poll) ...@@ -231,7 +231,7 @@ void vhost_poll_stop(struct vhost_poll *poll)
} }
EXPORT_SYMBOL_GPL(vhost_poll_stop); EXPORT_SYMBOL_GPL(vhost_poll_stop);
void vhost_work_dev_flush(struct vhost_dev *dev) void vhost_dev_flush(struct vhost_dev *dev)
{ {
struct vhost_flush_struct flush; struct vhost_flush_struct flush;
...@@ -243,15 +243,7 @@ void vhost_work_dev_flush(struct vhost_dev *dev) ...@@ -243,15 +243,7 @@ void vhost_work_dev_flush(struct vhost_dev *dev)
wait_for_completion(&flush.wait_event); wait_for_completion(&flush.wait_event);
} }
} }
EXPORT_SYMBOL_GPL(vhost_work_dev_flush); EXPORT_SYMBOL_GPL(vhost_dev_flush);
/* Flush any work that has been scheduled. When calling this, don't hold any
* locks that are also used by the callback. */
void vhost_poll_flush(struct vhost_poll *poll)
{
vhost_work_dev_flush(poll->dev);
}
EXPORT_SYMBOL_GPL(vhost_poll_flush);
void vhost_work_queue(struct vhost_dev *dev, struct vhost_work *work) void vhost_work_queue(struct vhost_dev *dev, struct vhost_work *work)
{ {
...@@ -468,7 +460,7 @@ void vhost_dev_init(struct vhost_dev *dev, ...@@ -468,7 +460,7 @@ void vhost_dev_init(struct vhost_dev *dev,
struct vhost_virtqueue **vqs, int nvqs, struct vhost_virtqueue **vqs, int nvqs,
int iov_limit, int weight, int byte_weight, int iov_limit, int weight, int byte_weight,
bool use_worker, bool use_worker,
int (*msg_handler)(struct vhost_dev *dev, int (*msg_handler)(struct vhost_dev *dev, u32 asid,
struct vhost_iotlb_msg *msg)) struct vhost_iotlb_msg *msg))
{ {
struct vhost_virtqueue *vq; struct vhost_virtqueue *vq;
...@@ -538,7 +530,7 @@ static int vhost_attach_cgroups(struct vhost_dev *dev) ...@@ -538,7 +530,7 @@ static int vhost_attach_cgroups(struct vhost_dev *dev)
attach.owner = current; attach.owner = current;
vhost_work_init(&attach.work, vhost_attach_cgroups_work); vhost_work_init(&attach.work, vhost_attach_cgroups_work);
vhost_work_queue(dev, &attach.work); vhost_work_queue(dev, &attach.work);
vhost_work_dev_flush(dev); vhost_dev_flush(dev);
return attach.ret; return attach.ret;
} }
...@@ -661,11 +653,11 @@ void vhost_dev_stop(struct vhost_dev *dev) ...@@ -661,11 +653,11 @@ void vhost_dev_stop(struct vhost_dev *dev)
int i; int i;
for (i = 0; i < dev->nvqs; ++i) { for (i = 0; i < dev->nvqs; ++i) {
if (dev->vqs[i]->kick && dev->vqs[i]->handle_kick) { if (dev->vqs[i]->kick && dev->vqs[i]->handle_kick)
vhost_poll_stop(&dev->vqs[i]->poll); vhost_poll_stop(&dev->vqs[i]->poll);
vhost_poll_flush(&dev->vqs[i]->poll);
}
} }
vhost_dev_flush(dev);
} }
EXPORT_SYMBOL_GPL(vhost_dev_stop); EXPORT_SYMBOL_GPL(vhost_dev_stop);
...@@ -1090,11 +1082,14 @@ static bool umem_access_ok(u64 uaddr, u64 size, int access) ...@@ -1090,11 +1082,14 @@ static bool umem_access_ok(u64 uaddr, u64 size, int access)
return true; return true;
} }
static int vhost_process_iotlb_msg(struct vhost_dev *dev, static int vhost_process_iotlb_msg(struct vhost_dev *dev, u32 asid,
struct vhost_iotlb_msg *msg) struct vhost_iotlb_msg *msg)
{ {
int ret = 0; int ret = 0;
if (asid != 0)
return -EINVAL;
mutex_lock(&dev->mutex); mutex_lock(&dev->mutex);
vhost_dev_lock_vqs(dev); vhost_dev_lock_vqs(dev);
switch (msg->type) { switch (msg->type) {
...@@ -1141,6 +1136,7 @@ ssize_t vhost_chr_write_iter(struct vhost_dev *dev, ...@@ -1141,6 +1136,7 @@ ssize_t vhost_chr_write_iter(struct vhost_dev *dev,
struct vhost_iotlb_msg msg; struct vhost_iotlb_msg msg;
size_t offset; size_t offset;
int type, ret; int type, ret;
u32 asid = 0;
ret = copy_from_iter(&type, sizeof(type), from); ret = copy_from_iter(&type, sizeof(type), from);
if (ret != sizeof(type)) { if (ret != sizeof(type)) {
...@@ -1156,7 +1152,16 @@ ssize_t vhost_chr_write_iter(struct vhost_dev *dev, ...@@ -1156,7 +1152,16 @@ ssize_t vhost_chr_write_iter(struct vhost_dev *dev,
offset = offsetof(struct vhost_msg, iotlb) - sizeof(int); offset = offsetof(struct vhost_msg, iotlb) - sizeof(int);
break; break;
case VHOST_IOTLB_MSG_V2: case VHOST_IOTLB_MSG_V2:
offset = sizeof(__u32); if (vhost_backend_has_feature(dev->vqs[0],
VHOST_BACKEND_F_IOTLB_ASID)) {
ret = copy_from_iter(&asid, sizeof(asid), from);
if (ret != sizeof(asid)) {
ret = -EINVAL;
goto done;
}
offset = 0;
} else
offset = sizeof(__u32);
break; break;
default: default:
ret = -EINVAL; ret = -EINVAL;
...@@ -1178,9 +1183,9 @@ ssize_t vhost_chr_write_iter(struct vhost_dev *dev, ...@@ -1178,9 +1183,9 @@ ssize_t vhost_chr_write_iter(struct vhost_dev *dev,
} }
if (dev->msg_handler) if (dev->msg_handler)
ret = dev->msg_handler(dev, &msg); ret = dev->msg_handler(dev, asid, &msg);
else else
ret = vhost_process_iotlb_msg(dev, &msg); ret = vhost_process_iotlb_msg(dev, asid, &msg);
if (ret) { if (ret) {
ret = -EFAULT; ret = -EFAULT;
goto done; goto done;
...@@ -1719,7 +1724,7 @@ long vhost_vring_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *arg ...@@ -1719,7 +1724,7 @@ long vhost_vring_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *arg
mutex_unlock(&vq->mutex); mutex_unlock(&vq->mutex);
if (pollstop && vq->handle_kick) if (pollstop && vq->handle_kick)
vhost_poll_flush(&vq->poll); vhost_dev_flush(vq->poll.dev);
return r; return r;
} }
EXPORT_SYMBOL_GPL(vhost_vring_ioctl); EXPORT_SYMBOL_GPL(vhost_vring_ioctl);
......
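On the uapi side, the extra 32-bit word that vhost_chr_write_iter() now consumes sits between the type field and the IOTLB payload of struct vhost_msg_v2, and is only interpreted as an ASID once VHOST_BACKEND_F_IOTLB_ASID has been negotiated; without the feature the old layout (and address space 0) still applies. A rough userspace sketch, assuming vhost_fd is a vhost device fd that was opened and configured elsewhere:

#include <linux/vhost.h>
#include <stdint.h>
#include <unistd.h>

/* Sketch: map @buf at @iova into address space 1 of a vhost device. */
static int iotlb_update_asid1(int vhost_fd, void *buf, uint64_t iova,
			      uint64_t size)
{
	struct vhost_msg_v2 msg = {
		.type = VHOST_IOTLB_MSG_V2,
		.asid = 1,			/* target address space */
		.iotlb = {
			.iova  = iova,
			.size  = size,
			.uaddr = (uint64_t)(uintptr_t)buf,
			.perm  = VHOST_ACCESS_RW,
			.type  = VHOST_IOTLB_UPDATE,
		},
	};

	return write(vhost_fd, &msg, sizeof(msg)) == sizeof(msg) ? 0 : -1;
}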
...@@ -44,9 +44,8 @@ void vhost_poll_init(struct vhost_poll *poll, vhost_work_fn_t fn, ...@@ -44,9 +44,8 @@ void vhost_poll_init(struct vhost_poll *poll, vhost_work_fn_t fn,
__poll_t mask, struct vhost_dev *dev); __poll_t mask, struct vhost_dev *dev);
int vhost_poll_start(struct vhost_poll *poll, struct file *file); int vhost_poll_start(struct vhost_poll *poll, struct file *file);
void vhost_poll_stop(struct vhost_poll *poll); void vhost_poll_stop(struct vhost_poll *poll);
void vhost_poll_flush(struct vhost_poll *poll);
void vhost_poll_queue(struct vhost_poll *poll); void vhost_poll_queue(struct vhost_poll *poll);
void vhost_work_dev_flush(struct vhost_dev *dev); void vhost_dev_flush(struct vhost_dev *dev);
struct vhost_log { struct vhost_log {
u64 addr; u64 addr;
...@@ -161,7 +160,7 @@ struct vhost_dev { ...@@ -161,7 +160,7 @@ struct vhost_dev {
int byte_weight; int byte_weight;
u64 kcov_handle; u64 kcov_handle;
bool use_worker; bool use_worker;
int (*msg_handler)(struct vhost_dev *dev, int (*msg_handler)(struct vhost_dev *dev, u32 asid,
struct vhost_iotlb_msg *msg); struct vhost_iotlb_msg *msg);
}; };
...@@ -169,7 +168,7 @@ bool vhost_exceeds_weight(struct vhost_virtqueue *vq, int pkts, int total_len); ...@@ -169,7 +168,7 @@ bool vhost_exceeds_weight(struct vhost_virtqueue *vq, int pkts, int total_len);
void vhost_dev_init(struct vhost_dev *, struct vhost_virtqueue **vqs, void vhost_dev_init(struct vhost_dev *, struct vhost_virtqueue **vqs,
int nvqs, int iov_limit, int weight, int byte_weight, int nvqs, int iov_limit, int weight, int byte_weight,
bool use_worker, bool use_worker,
int (*msg_handler)(struct vhost_dev *dev, int (*msg_handler)(struct vhost_dev *dev, u32 asid,
struct vhost_iotlb_msg *msg)); struct vhost_iotlb_msg *msg));
long vhost_dev_set_owner(struct vhost_dev *dev); long vhost_dev_set_owner(struct vhost_dev *dev);
bool vhost_dev_has_owner(struct vhost_dev *dev); bool vhost_dev_has_owner(struct vhost_dev *dev);
......
...@@ -705,12 +705,7 @@ static int vhost_vsock_dev_open(struct inode *inode, struct file *file) ...@@ -705,12 +705,7 @@ static int vhost_vsock_dev_open(struct inode *inode, struct file *file)
static void vhost_vsock_flush(struct vhost_vsock *vsock) static void vhost_vsock_flush(struct vhost_vsock *vsock)
{ {
int i; vhost_dev_flush(&vsock->dev);
for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++)
if (vsock->vqs[i].handle_kick)
vhost_poll_flush(&vsock->vqs[i].poll);
vhost_work_dev_flush(&vsock->dev);
} }
static void vhost_vsock_reset_orphans(struct sock *sk) static void vhost_vsock_reset_orphans(struct sock *sk)
......
...@@ -169,7 +169,7 @@ EXPORT_SYMBOL_GPL(virtio_add_status); ...@@ -169,7 +169,7 @@ EXPORT_SYMBOL_GPL(virtio_add_status);
/* Do some validation, then set FEATURES_OK */ /* Do some validation, then set FEATURES_OK */
static int virtio_features_ok(struct virtio_device *dev) static int virtio_features_ok(struct virtio_device *dev)
{ {
unsigned status; unsigned int status;
int ret; int ret;
might_sleep(); might_sleep();
...@@ -220,6 +220,15 @@ static int virtio_features_ok(struct virtio_device *dev) ...@@ -220,6 +220,15 @@ static int virtio_features_ok(struct virtio_device *dev)
* */ * */
void virtio_reset_device(struct virtio_device *dev) void virtio_reset_device(struct virtio_device *dev)
{ {
/*
* The below virtio_synchronize_cbs() guarantees that any
* interrupt for this line that arrives after
* virtio_synchronize_cbs() has completed will see vq->broken
* as true.
*/
virtio_break_device(dev);
virtio_synchronize_cbs(dev);
dev->config->reset(dev); dev->config->reset(dev);
} }
EXPORT_SYMBOL_GPL(virtio_reset_device); EXPORT_SYMBOL_GPL(virtio_reset_device);
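Together with the vq->broken changes further down (vqs now start out broken and vring_interrupt() reports and ignores spurious interrupts), the reset path above breaks the device and synchronizes callbacks before touching the transport, so no stale IRQ can race with teardown. From a driver's perspective nothing changes as long as it follows the usual ordering; a minimal sketch with illustrative names:

static void example_vq_callback(struct virtqueue *vq)
{
	/* process used buffers, kick a workqueue, etc. */
}

static int example_virtio_probe(struct virtio_device *vdev)
{
	struct virtqueue *vq;

	/* The vq is created "broken": callbacks raised now are ignored. */
	vq = virtio_find_single_vq(vdev, example_vq_callback, "requests");
	if (IS_ERR(vq))
		return PTR_ERR(vq);

	/* ... allocate driver state, prime the vq, ... */

	/* DRIVER_OK also un-breaks the vqs; callbacks are delivered from here on. */
	virtio_device_ready(vdev);
	return 0;
}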
...@@ -413,7 +422,7 @@ int register_virtio_device(struct virtio_device *dev) ...@@ -413,7 +422,7 @@ int register_virtio_device(struct virtio_device *dev)
device_initialize(&dev->dev); device_initialize(&dev->dev);
/* Assign a unique device index and hence name. */ /* Assign a unique device index and hence name. */
err = ida_simple_get(&virtio_index_ida, 0, 0, GFP_KERNEL); err = ida_alloc(&virtio_index_ida, GFP_KERNEL);
if (err < 0) if (err < 0)
goto out; goto out;
...@@ -428,16 +437,16 @@ int register_virtio_device(struct virtio_device *dev) ...@@ -428,16 +437,16 @@ int register_virtio_device(struct virtio_device *dev)
dev->config_enabled = false; dev->config_enabled = false;
dev->config_change_pending = false; dev->config_change_pending = false;
INIT_LIST_HEAD(&dev->vqs);
spin_lock_init(&dev->vqs_list_lock);
/* We always start by resetting the device, in case a previous /* We always start by resetting the device, in case a previous
* driver messed it up. This also tests that code path a little. */ * driver messed it up. This also tests that code path a little. */
dev->config->reset(dev); virtio_reset_device(dev);
/* Acknowledge that we've seen the device. */ /* Acknowledge that we've seen the device. */
virtio_add_status(dev, VIRTIO_CONFIG_S_ACKNOWLEDGE); virtio_add_status(dev, VIRTIO_CONFIG_S_ACKNOWLEDGE);
INIT_LIST_HEAD(&dev->vqs);
spin_lock_init(&dev->vqs_list_lock);
/* /*
* device_add() causes the bus infrastructure to look for a matching * device_add() causes the bus infrastructure to look for a matching
* driver. * driver.
...@@ -451,7 +460,7 @@ int register_virtio_device(struct virtio_device *dev) ...@@ -451,7 +460,7 @@ int register_virtio_device(struct virtio_device *dev)
out_of_node_put: out_of_node_put:
of_node_put(dev->dev.of_node); of_node_put(dev->dev.of_node);
out_ida_remove: out_ida_remove:
ida_simple_remove(&virtio_index_ida, dev->index); ida_free(&virtio_index_ida, dev->index);
out: out:
virtio_add_status(dev, VIRTIO_CONFIG_S_FAILED); virtio_add_status(dev, VIRTIO_CONFIG_S_FAILED);
return err; return err;
...@@ -469,7 +478,7 @@ void unregister_virtio_device(struct virtio_device *dev) ...@@ -469,7 +478,7 @@ void unregister_virtio_device(struct virtio_device *dev)
int index = dev->index; /* save for after device release */ int index = dev->index; /* save for after device release */
device_unregister(&dev->dev); device_unregister(&dev->dev);
ida_simple_remove(&virtio_index_ida, index); ida_free(&virtio_index_ida, index);
} }
EXPORT_SYMBOL_GPL(unregister_virtio_device); EXPORT_SYMBOL_GPL(unregister_virtio_device);
...@@ -496,7 +505,7 @@ int virtio_device_restore(struct virtio_device *dev) ...@@ -496,7 +505,7 @@ int virtio_device_restore(struct virtio_device *dev)
/* We always start by resetting the device, in case a previous /* We always start by resetting the device, in case a previous
* driver messed it up. */ * driver messed it up. */
dev->config->reset(dev); virtio_reset_device(dev);
/* Acknowledge that we've seen the device. */ /* Acknowledge that we've seen the device. */
virtio_add_status(dev, VIRTIO_CONFIG_S_ACKNOWLEDGE); virtio_add_status(dev, VIRTIO_CONFIG_S_ACKNOWLEDGE);
...@@ -526,8 +535,9 @@ int virtio_device_restore(struct virtio_device *dev) ...@@ -526,8 +535,9 @@ int virtio_device_restore(struct virtio_device *dev)
goto err; goto err;
} }
/* Finally, tell the device we're all set */ /* If restore didn't do it, mark device DRIVER_OK ourselves. */
virtio_add_status(dev, VIRTIO_CONFIG_S_DRIVER_OK); if (!(dev->config->get_status(dev) & VIRTIO_CONFIG_S_DRIVER_OK))
virtio_device_ready(dev);
virtio_config_enable(dev); virtio_config_enable(dev);
......
...@@ -27,7 +27,7 @@ ...@@ -27,7 +27,7 @@
* multiple balloon pages. All memory counters in this driver are in balloon * multiple balloon pages. All memory counters in this driver are in balloon
* page units. * page units.
*/ */
#define VIRTIO_BALLOON_PAGES_PER_PAGE (unsigned)(PAGE_SIZE >> VIRTIO_BALLOON_PFN_SHIFT) #define VIRTIO_BALLOON_PAGES_PER_PAGE (unsigned int)(PAGE_SIZE >> VIRTIO_BALLOON_PFN_SHIFT)
#define VIRTIO_BALLOON_ARRAY_PFNS_MAX 256 #define VIRTIO_BALLOON_ARRAY_PFNS_MAX 256
/* Maximum number of (4k) pages to deflate on OOM notifications. */ /* Maximum number of (4k) pages to deflate on OOM notifications. */
#define VIRTIO_BALLOON_OOM_NR_PAGES 256 #define VIRTIO_BALLOON_OOM_NR_PAGES 256
...@@ -208,10 +208,10 @@ static void set_page_pfns(struct virtio_balloon *vb, ...@@ -208,10 +208,10 @@ static void set_page_pfns(struct virtio_balloon *vb,
page_to_balloon_pfn(page) + i); page_to_balloon_pfn(page) + i);
} }
static unsigned fill_balloon(struct virtio_balloon *vb, size_t num) static unsigned int fill_balloon(struct virtio_balloon *vb, size_t num)
{ {
unsigned num_allocated_pages; unsigned int num_allocated_pages;
unsigned num_pfns; unsigned int num_pfns;
struct page *page; struct page *page;
LIST_HEAD(pages); LIST_HEAD(pages);
...@@ -272,9 +272,9 @@ static void release_pages_balloon(struct virtio_balloon *vb, ...@@ -272,9 +272,9 @@ static void release_pages_balloon(struct virtio_balloon *vb,
} }
} }
static unsigned leak_balloon(struct virtio_balloon *vb, size_t num) static unsigned int leak_balloon(struct virtio_balloon *vb, size_t num)
{ {
unsigned num_freed_pages; unsigned int num_freed_pages;
struct page *page; struct page *page;
struct balloon_dev_info *vb_dev_info = &vb->vb_dev_info; struct balloon_dev_info *vb_dev_info = &vb->vb_dev_info;
LIST_HEAD(pages); LIST_HEAD(pages);
......
...@@ -144,8 +144,8 @@ static int vm_finalize_features(struct virtio_device *vdev) ...@@ -144,8 +144,8 @@ static int vm_finalize_features(struct virtio_device *vdev)
return 0; return 0;
} }
static void vm_get(struct virtio_device *vdev, unsigned offset, static void vm_get(struct virtio_device *vdev, unsigned int offset,
void *buf, unsigned len) void *buf, unsigned int len)
{ {
struct virtio_mmio_device *vm_dev = to_virtio_mmio_device(vdev); struct virtio_mmio_device *vm_dev = to_virtio_mmio_device(vdev);
void __iomem *base = vm_dev->base + VIRTIO_MMIO_CONFIG; void __iomem *base = vm_dev->base + VIRTIO_MMIO_CONFIG;
...@@ -186,8 +186,8 @@ static void vm_get(struct virtio_device *vdev, unsigned offset, ...@@ -186,8 +186,8 @@ static void vm_get(struct virtio_device *vdev, unsigned offset,
} }
} }
static void vm_set(struct virtio_device *vdev, unsigned offset, static void vm_set(struct virtio_device *vdev, unsigned int offset,
const void *buf, unsigned len) const void *buf, unsigned int len)
{ {
struct virtio_mmio_device *vm_dev = to_virtio_mmio_device(vdev); struct virtio_mmio_device *vm_dev = to_virtio_mmio_device(vdev);
void __iomem *base = vm_dev->base + VIRTIO_MMIO_CONFIG; void __iomem *base = vm_dev->base + VIRTIO_MMIO_CONFIG;
...@@ -253,6 +253,11 @@ static void vm_set_status(struct virtio_device *vdev, u8 status) ...@@ -253,6 +253,11 @@ static void vm_set_status(struct virtio_device *vdev, u8 status)
/* We should never be setting status to 0. */ /* We should never be setting status to 0. */
BUG_ON(status == 0); BUG_ON(status == 0);
/*
* Per memory-barriers.txt, wmb() is not needed to guarantee
* that the cache-coherent memory writes have completed
* before writing to the MMIO region.
*/
writel(status, vm_dev->base + VIRTIO_MMIO_STATUS); writel(status, vm_dev->base + VIRTIO_MMIO_STATUS);
} }
...@@ -345,7 +350,14 @@ static void vm_del_vqs(struct virtio_device *vdev) ...@@ -345,7 +350,14 @@ static void vm_del_vqs(struct virtio_device *vdev)
free_irq(platform_get_irq(vm_dev->pdev, 0), vm_dev); free_irq(platform_get_irq(vm_dev->pdev, 0), vm_dev);
} }
static struct virtqueue *vm_setup_vq(struct virtio_device *vdev, unsigned index, static void vm_synchronize_cbs(struct virtio_device *vdev)
{
struct virtio_mmio_device *vm_dev = to_virtio_mmio_device(vdev);
synchronize_irq(platform_get_irq(vm_dev->pdev, 0));
}
static struct virtqueue *vm_setup_vq(struct virtio_device *vdev, unsigned int index,
void (*callback)(struct virtqueue *vq), void (*callback)(struct virtqueue *vq),
const char *name, bool ctx) const char *name, bool ctx)
{ {
...@@ -455,7 +467,7 @@ static struct virtqueue *vm_setup_vq(struct virtio_device *vdev, unsigned index, ...@@ -455,7 +467,7 @@ static struct virtqueue *vm_setup_vq(struct virtio_device *vdev, unsigned index,
return ERR_PTR(err); return ERR_PTR(err);
} }
static int vm_find_vqs(struct virtio_device *vdev, unsigned nvqs, static int vm_find_vqs(struct virtio_device *vdev, unsigned int nvqs,
struct virtqueue *vqs[], struct virtqueue *vqs[],
vq_callback_t *callbacks[], vq_callback_t *callbacks[],
const char * const names[], const char * const names[],
...@@ -541,6 +553,7 @@ static const struct virtio_config_ops virtio_mmio_config_ops = { ...@@ -541,6 +553,7 @@ static const struct virtio_config_ops virtio_mmio_config_ops = {
.finalize_features = vm_finalize_features, .finalize_features = vm_finalize_features,
.bus_name = vm_bus_name, .bus_name = vm_bus_name,
.get_shm_region = vm_get_shm_region, .get_shm_region = vm_get_shm_region,
.synchronize_cbs = vm_synchronize_cbs,
}; };
...@@ -657,7 +670,7 @@ static int vm_cmdline_set(const char *device, ...@@ -657,7 +670,7 @@ static int vm_cmdline_set(const char *device,
int err; int err;
struct resource resources[2] = {}; struct resource resources[2] = {};
char *str; char *str;
long long int base, size; long long base, size;
unsigned int irq; unsigned int irq;
int processed, consumed = 0; int processed, consumed = 0;
struct platform_device *pdev; struct platform_device *pdev;
......
...@@ -104,8 +104,8 @@ static int vp_request_msix_vectors(struct virtio_device *vdev, int nvectors, ...@@ -104,8 +104,8 @@ static int vp_request_msix_vectors(struct virtio_device *vdev, int nvectors,
{ {
struct virtio_pci_device *vp_dev = to_vp_device(vdev); struct virtio_pci_device *vp_dev = to_vp_device(vdev);
const char *name = dev_name(&vp_dev->vdev.dev); const char *name = dev_name(&vp_dev->vdev.dev);
unsigned flags = PCI_IRQ_MSIX; unsigned int flags = PCI_IRQ_MSIX;
unsigned i, v; unsigned int i, v;
int err = -ENOMEM; int err = -ENOMEM;
vp_dev->msix_vectors = nvectors; vp_dev->msix_vectors = nvectors;
...@@ -171,7 +171,7 @@ static int vp_request_msix_vectors(struct virtio_device *vdev, int nvectors, ...@@ -171,7 +171,7 @@ static int vp_request_msix_vectors(struct virtio_device *vdev, int nvectors,
return err; return err;
} }
static struct virtqueue *vp_setup_vq(struct virtio_device *vdev, unsigned index, static struct virtqueue *vp_setup_vq(struct virtio_device *vdev, unsigned int index,
void (*callback)(struct virtqueue *vq), void (*callback)(struct virtqueue *vq),
const char *name, const char *name,
bool ctx, bool ctx,
...@@ -254,8 +254,7 @@ void vp_del_vqs(struct virtio_device *vdev) ...@@ -254,8 +254,7 @@ void vp_del_vqs(struct virtio_device *vdev)
if (vp_dev->msix_affinity_masks) { if (vp_dev->msix_affinity_masks) {
for (i = 0; i < vp_dev->msix_vectors; i++) for (i = 0; i < vp_dev->msix_vectors; i++)
if (vp_dev->msix_affinity_masks[i]) free_cpumask_var(vp_dev->msix_affinity_masks[i]);
free_cpumask_var(vp_dev->msix_affinity_masks[i]);
} }
if (vp_dev->msix_enabled) { if (vp_dev->msix_enabled) {
...@@ -276,7 +275,7 @@ void vp_del_vqs(struct virtio_device *vdev) ...@@ -276,7 +275,7 @@ void vp_del_vqs(struct virtio_device *vdev)
vp_dev->vqs = NULL; vp_dev->vqs = NULL;
} }
static int vp_find_vqs_msix(struct virtio_device *vdev, unsigned nvqs, static int vp_find_vqs_msix(struct virtio_device *vdev, unsigned int nvqs,
struct virtqueue *vqs[], vq_callback_t *callbacks[], struct virtqueue *vqs[], vq_callback_t *callbacks[],
const char * const names[], bool per_vq_vectors, const char * const names[], bool per_vq_vectors,
const bool *ctx, const bool *ctx,
...@@ -350,7 +349,7 @@ static int vp_find_vqs_msix(struct virtio_device *vdev, unsigned nvqs, ...@@ -350,7 +349,7 @@ static int vp_find_vqs_msix(struct virtio_device *vdev, unsigned nvqs,
return err; return err;
} }
static int vp_find_vqs_intx(struct virtio_device *vdev, unsigned nvqs, static int vp_find_vqs_intx(struct virtio_device *vdev, unsigned int nvqs,
struct virtqueue *vqs[], vq_callback_t *callbacks[], struct virtqueue *vqs[], vq_callback_t *callbacks[],
const char * const names[], const bool *ctx) const char * const names[], const bool *ctx)
{ {
...@@ -389,7 +388,7 @@ static int vp_find_vqs_intx(struct virtio_device *vdev, unsigned nvqs, ...@@ -389,7 +388,7 @@ static int vp_find_vqs_intx(struct virtio_device *vdev, unsigned nvqs,
} }
/* the config->find_vqs() implementation */ /* the config->find_vqs() implementation */
int vp_find_vqs(struct virtio_device *vdev, unsigned nvqs, int vp_find_vqs(struct virtio_device *vdev, unsigned int nvqs,
struct virtqueue *vqs[], vq_callback_t *callbacks[], struct virtqueue *vqs[], vq_callback_t *callbacks[],
const char * const names[], const bool *ctx, const char * const names[], const bool *ctx,
struct irq_affinity *desc) struct irq_affinity *desc)
......
...@@ -38,7 +38,7 @@ struct virtio_pci_vq_info { ...@@ -38,7 +38,7 @@ struct virtio_pci_vq_info {
struct list_head node; struct list_head node;
/* MSI-X vector (or none) */ /* MSI-X vector (or none) */
unsigned msix_vector; unsigned int msix_vector;
}; };
/* Our device structure */ /* Our device structure */
...@@ -68,16 +68,16 @@ struct virtio_pci_device { ...@@ -68,16 +68,16 @@ struct virtio_pci_device {
* and I'm too lazy to allocate each name separately. */ * and I'm too lazy to allocate each name separately. */
char (*msix_names)[256]; char (*msix_names)[256];
/* Number of available vectors */ /* Number of available vectors */
unsigned msix_vectors; unsigned int msix_vectors;
/* Vectors allocated, excluding per-vq vectors if any */ /* Vectors allocated, excluding per-vq vectors if any */
unsigned msix_used_vectors; unsigned int msix_used_vectors;
/* Whether we have vector per vq */ /* Whether we have vector per vq */
bool per_vq_vectors; bool per_vq_vectors;
struct virtqueue *(*setup_vq)(struct virtio_pci_device *vp_dev, struct virtqueue *(*setup_vq)(struct virtio_pci_device *vp_dev,
struct virtio_pci_vq_info *info, struct virtio_pci_vq_info *info,
unsigned idx, unsigned int idx,
void (*callback)(struct virtqueue *vq), void (*callback)(struct virtqueue *vq),
const char *name, const char *name,
bool ctx, bool ctx,
...@@ -108,7 +108,7 @@ bool vp_notify(struct virtqueue *vq); ...@@ -108,7 +108,7 @@ bool vp_notify(struct virtqueue *vq);
/* the config->del_vqs() implementation */ /* the config->del_vqs() implementation */
void vp_del_vqs(struct virtio_device *vdev); void vp_del_vqs(struct virtio_device *vdev);
/* the config->find_vqs() implementation */ /* the config->find_vqs() implementation */
int vp_find_vqs(struct virtio_device *vdev, unsigned nvqs, int vp_find_vqs(struct virtio_device *vdev, unsigned int nvqs,
struct virtqueue *vqs[], vq_callback_t *callbacks[], struct virtqueue *vqs[], vq_callback_t *callbacks[],
const char * const names[], const bool *ctx, const char * const names[], const bool *ctx,
struct irq_affinity *desc); struct irq_affinity *desc);
......
...@@ -45,8 +45,8 @@ static int vp_finalize_features(struct virtio_device *vdev) ...@@ -45,8 +45,8 @@ static int vp_finalize_features(struct virtio_device *vdev)
} }
/* virtio config->get() implementation */ /* virtio config->get() implementation */
static void vp_get(struct virtio_device *vdev, unsigned offset, static void vp_get(struct virtio_device *vdev, unsigned int offset,
void *buf, unsigned len) void *buf, unsigned int len)
{ {
struct virtio_pci_device *vp_dev = to_vp_device(vdev); struct virtio_pci_device *vp_dev = to_vp_device(vdev);
void __iomem *ioaddr = vp_dev->ldev.ioaddr + void __iomem *ioaddr = vp_dev->ldev.ioaddr +
...@@ -61,8 +61,8 @@ static void vp_get(struct virtio_device *vdev, unsigned offset, ...@@ -61,8 +61,8 @@ static void vp_get(struct virtio_device *vdev, unsigned offset,
/* the config->set() implementation. it's symmetric to the config->get() /* the config->set() implementation. it's symmetric to the config->get()
* implementation */ * implementation */
static void vp_set(struct virtio_device *vdev, unsigned offset, static void vp_set(struct virtio_device *vdev, unsigned int offset,
const void *buf, unsigned len) const void *buf, unsigned int len)
{ {
struct virtio_pci_device *vp_dev = to_vp_device(vdev); struct virtio_pci_device *vp_dev = to_vp_device(vdev);
void __iomem *ioaddr = vp_dev->ldev.ioaddr + void __iomem *ioaddr = vp_dev->ldev.ioaddr +
...@@ -109,7 +109,7 @@ static u16 vp_config_vector(struct virtio_pci_device *vp_dev, u16 vector) ...@@ -109,7 +109,7 @@ static u16 vp_config_vector(struct virtio_pci_device *vp_dev, u16 vector)
static struct virtqueue *setup_vq(struct virtio_pci_device *vp_dev, static struct virtqueue *setup_vq(struct virtio_pci_device *vp_dev,
struct virtio_pci_vq_info *info, struct virtio_pci_vq_info *info,
unsigned index, unsigned int index,
void (*callback)(struct virtqueue *vq), void (*callback)(struct virtqueue *vq),
const char *name, const char *name,
bool ctx, bool ctx,
...@@ -192,6 +192,7 @@ static const struct virtio_config_ops virtio_pci_config_ops = { ...@@ -192,6 +192,7 @@ static const struct virtio_config_ops virtio_pci_config_ops = {
.reset = vp_reset, .reset = vp_reset,
.find_vqs = vp_find_vqs, .find_vqs = vp_find_vqs,
.del_vqs = vp_del_vqs, .del_vqs = vp_del_vqs,
.synchronize_cbs = vp_synchronize_vectors,
.get_features = vp_get_features, .get_features = vp_get_features,
.finalize_features = vp_finalize_features, .finalize_features = vp_finalize_features,
.bus_name = vp_bus_name, .bus_name = vp_bus_name,
......
...@@ -60,8 +60,8 @@ static int vp_finalize_features(struct virtio_device *vdev) ...@@ -60,8 +60,8 @@ static int vp_finalize_features(struct virtio_device *vdev)
} }
/* virtio config->get() implementation */ /* virtio config->get() implementation */
static void vp_get(struct virtio_device *vdev, unsigned offset, static void vp_get(struct virtio_device *vdev, unsigned int offset,
void *buf, unsigned len) void *buf, unsigned int len)
{ {
struct virtio_pci_device *vp_dev = to_vp_device(vdev); struct virtio_pci_device *vp_dev = to_vp_device(vdev);
struct virtio_pci_modern_device *mdev = &vp_dev->mdev; struct virtio_pci_modern_device *mdev = &vp_dev->mdev;
...@@ -98,8 +98,8 @@ static void vp_get(struct virtio_device *vdev, unsigned offset, ...@@ -98,8 +98,8 @@ static void vp_get(struct virtio_device *vdev, unsigned offset,
/* the config->set() implementation. it's symmetric to the config->get() /* the config->set() implementation. it's symmetric to the config->get()
* implementation */ * implementation */
static void vp_set(struct virtio_device *vdev, unsigned offset, static void vp_set(struct virtio_device *vdev, unsigned int offset,
const void *buf, unsigned len) const void *buf, unsigned int len)
{ {
struct virtio_pci_device *vp_dev = to_vp_device(vdev); struct virtio_pci_device *vp_dev = to_vp_device(vdev);
struct virtio_pci_modern_device *mdev = &vp_dev->mdev; struct virtio_pci_modern_device *mdev = &vp_dev->mdev;
...@@ -183,7 +183,7 @@ static u16 vp_config_vector(struct virtio_pci_device *vp_dev, u16 vector) ...@@ -183,7 +183,7 @@ static u16 vp_config_vector(struct virtio_pci_device *vp_dev, u16 vector)
static struct virtqueue *setup_vq(struct virtio_pci_device *vp_dev, static struct virtqueue *setup_vq(struct virtio_pci_device *vp_dev,
struct virtio_pci_vq_info *info, struct virtio_pci_vq_info *info,
unsigned index, unsigned int index,
void (*callback)(struct virtqueue *vq), void (*callback)(struct virtqueue *vq),
const char *name, const char *name,
bool ctx, bool ctx,
...@@ -248,7 +248,7 @@ static struct virtqueue *setup_vq(struct virtio_pci_device *vp_dev, ...@@ -248,7 +248,7 @@ static struct virtqueue *setup_vq(struct virtio_pci_device *vp_dev,
return ERR_PTR(err); return ERR_PTR(err);
} }
static int vp_modern_find_vqs(struct virtio_device *vdev, unsigned nvqs, static int vp_modern_find_vqs(struct virtio_device *vdev, unsigned int nvqs,
struct virtqueue *vqs[], struct virtqueue *vqs[],
vq_callback_t *callbacks[], vq_callback_t *callbacks[],
const char * const names[], const bool *ctx, const char * const names[], const bool *ctx,
...@@ -394,6 +394,7 @@ static const struct virtio_config_ops virtio_pci_config_nodev_ops = { ...@@ -394,6 +394,7 @@ static const struct virtio_config_ops virtio_pci_config_nodev_ops = {
.reset = vp_reset, .reset = vp_reset,
.find_vqs = vp_modern_find_vqs, .find_vqs = vp_modern_find_vqs,
.del_vqs = vp_del_vqs, .del_vqs = vp_del_vqs,
.synchronize_cbs = vp_synchronize_vectors,
.get_features = vp_get_features, .get_features = vp_get_features,
.finalize_features = vp_finalize_features, .finalize_features = vp_finalize_features,
.bus_name = vp_bus_name, .bus_name = vp_bus_name,
...@@ -411,6 +412,7 @@ static const struct virtio_config_ops virtio_pci_config_ops = { ...@@ -411,6 +412,7 @@ static const struct virtio_config_ops virtio_pci_config_ops = {
.reset = vp_reset, .reset = vp_reset,
.find_vqs = vp_modern_find_vqs, .find_vqs = vp_modern_find_vqs,
.del_vqs = vp_del_vqs, .del_vqs = vp_del_vqs,
.synchronize_cbs = vp_synchronize_vectors,
.get_features = vp_get_features, .get_features = vp_get_features,
.finalize_features = vp_finalize_features, .finalize_features = vp_finalize_features,
.bus_name = vp_bus_name, .bus_name = vp_bus_name,
......
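Both PCI config-ops tables point the new synchronize_cbs hook at the pre-existing vp_synchronize_vectors() helper, which (roughly) calls synchronize_irq() on the shared INTx line or on every allocated MSI-X vector; that is exactly the barrier semantics synchronize_cbs needs. Its approximate shape, for orientation only:

/* Approximate shape of vp_synchronize_vectors() in virtio_pci_common.c. */
void vp_synchronize_vectors(struct virtio_device *vdev)
{
	struct virtio_pci_device *vp_dev = to_vp_device(vdev);
	int i;

	if (vp_dev->intx_enabled)
		synchronize_irq(vp_dev->pci_dev->irq);

	for (i = 0; i < vp_dev->msix_vectors; ++i)
		synchronize_irq(pci_irq_vector(vp_dev->pci_dev, i));
}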
...@@ -347,6 +347,7 @@ int vp_modern_probe(struct virtio_pci_modern_device *mdev) ...@@ -347,6 +347,7 @@ int vp_modern_probe(struct virtio_pci_modern_device *mdev)
err_map_isr: err_map_isr:
pci_iounmap(pci_dev, mdev->common); pci_iounmap(pci_dev, mdev->common);
err_map_common: err_map_common:
pci_release_selected_regions(pci_dev, mdev->modern_bars);
return err; return err;
} }
EXPORT_SYMBOL_GPL(vp_modern_probe); EXPORT_SYMBOL_GPL(vp_modern_probe);
...@@ -466,6 +467,11 @@ void vp_modern_set_status(struct virtio_pci_modern_device *mdev, ...@@ -466,6 +467,11 @@ void vp_modern_set_status(struct virtio_pci_modern_device *mdev,
{ {
struct virtio_pci_common_cfg __iomem *cfg = mdev->common; struct virtio_pci_common_cfg __iomem *cfg = mdev->common;
/*
* Per memory-barriers.txt, wmb() is not needed to guarantee
* that the cache-coherent memory writes have completed
* before writing to the MMIO region.
*/
vp_iowrite8(status, &cfg->device_status); vp_iowrite8(status, &cfg->device_status);
} }
EXPORT_SYMBOL_GPL(vp_modern_set_status); EXPORT_SYMBOL_GPL(vp_modern_set_status);
......
...@@ -205,11 +205,9 @@ struct vring_virtqueue { ...@@ -205,11 +205,9 @@ struct vring_virtqueue {
#define to_vvq(_vq) container_of(_vq, struct vring_virtqueue, vq) #define to_vvq(_vq) container_of(_vq, struct vring_virtqueue, vq)
static inline bool virtqueue_use_indirect(struct virtqueue *_vq, static inline bool virtqueue_use_indirect(struct vring_virtqueue *vq,
unsigned int total_sg) unsigned int total_sg)
{ {
struct vring_virtqueue *vq = to_vvq(_vq);
/* /*
* If the host supports indirect descriptor tables, and we have multiple * If the host supports indirect descriptor tables, and we have multiple
* buffers, then go indirect. FIXME: tune this threshold * buffers, then go indirect. FIXME: tune this threshold
...@@ -499,7 +497,7 @@ static inline int virtqueue_add_split(struct virtqueue *_vq, ...@@ -499,7 +497,7 @@ static inline int virtqueue_add_split(struct virtqueue *_vq,
head = vq->free_head; head = vq->free_head;
if (virtqueue_use_indirect(_vq, total_sg)) if (virtqueue_use_indirect(vq, total_sg))
desc = alloc_indirect_split(_vq, total_sg, gfp); desc = alloc_indirect_split(_vq, total_sg, gfp);
else { else {
desc = NULL; desc = NULL;
...@@ -519,7 +517,7 @@ static inline int virtqueue_add_split(struct virtqueue *_vq, ...@@ -519,7 +517,7 @@ static inline int virtqueue_add_split(struct virtqueue *_vq,
descs_used = total_sg; descs_used = total_sg;
} }
if (vq->vq.num_free < descs_used) { if (unlikely(vq->vq.num_free < descs_used)) {
pr_debug("Can't add buf len %i - avail = %i\n", pr_debug("Can't add buf len %i - avail = %i\n",
descs_used, vq->vq.num_free); descs_used, vq->vq.num_free);
/* FIXME: for historical reasons, we force a notify here if /* FIXME: for historical reasons, we force a notify here if
...@@ -811,7 +809,7 @@ static void virtqueue_disable_cb_split(struct virtqueue *_vq) ...@@ -811,7 +809,7 @@ static void virtqueue_disable_cb_split(struct virtqueue *_vq)
} }
} }
static unsigned virtqueue_enable_cb_prepare_split(struct virtqueue *_vq) static unsigned int virtqueue_enable_cb_prepare_split(struct virtqueue *_vq)
{ {
struct vring_virtqueue *vq = to_vvq(_vq); struct vring_virtqueue *vq = to_vvq(_vq);
u16 last_used_idx; u16 last_used_idx;
...@@ -836,7 +834,7 @@ static unsigned virtqueue_enable_cb_prepare_split(struct virtqueue *_vq) ...@@ -836,7 +834,7 @@ static unsigned virtqueue_enable_cb_prepare_split(struct virtqueue *_vq)
return last_used_idx; return last_used_idx;
} }
static bool virtqueue_poll_split(struct virtqueue *_vq, unsigned last_used_idx) static bool virtqueue_poll_split(struct virtqueue *_vq, unsigned int last_used_idx)
{ {
struct vring_virtqueue *vq = to_vvq(_vq); struct vring_virtqueue *vq = to_vvq(_vq);
...@@ -1178,7 +1176,7 @@ static inline int virtqueue_add_packed(struct virtqueue *_vq, ...@@ -1178,7 +1176,7 @@ static inline int virtqueue_add_packed(struct virtqueue *_vq,
BUG_ON(total_sg == 0); BUG_ON(total_sg == 0);
if (virtqueue_use_indirect(_vq, total_sg)) { if (virtqueue_use_indirect(vq, total_sg)) {
err = virtqueue_add_indirect_packed(vq, sgs, total_sg, out_sgs, err = virtqueue_add_indirect_packed(vq, sgs, total_sg, out_sgs,
in_sgs, data, gfp); in_sgs, data, gfp);
if (err != -ENOMEM) { if (err != -ENOMEM) {
...@@ -1488,7 +1486,7 @@ static void virtqueue_disable_cb_packed(struct virtqueue *_vq) ...@@ -1488,7 +1486,7 @@ static void virtqueue_disable_cb_packed(struct virtqueue *_vq)
} }
} }
static unsigned virtqueue_enable_cb_prepare_packed(struct virtqueue *_vq) static unsigned int virtqueue_enable_cb_prepare_packed(struct virtqueue *_vq)
{ {
struct vring_virtqueue *vq = to_vvq(_vq); struct vring_virtqueue *vq = to_vvq(_vq);
...@@ -1690,7 +1688,7 @@ static struct virtqueue *vring_create_virtqueue_packed( ...@@ -1690,7 +1688,7 @@ static struct virtqueue *vring_create_virtqueue_packed(
vq->we_own_ring = true; vq->we_own_ring = true;
vq->notify = notify; vq->notify = notify;
vq->weak_barriers = weak_barriers; vq->weak_barriers = weak_barriers;
vq->broken = false; vq->broken = true;
vq->last_used_idx = 0; vq->last_used_idx = 0;
vq->event_triggered = false; vq->event_triggered = false;
vq->num_added = 0; vq->num_added = 0;
...@@ -2027,7 +2025,7 @@ EXPORT_SYMBOL_GPL(virtqueue_disable_cb); ...@@ -2027,7 +2025,7 @@ EXPORT_SYMBOL_GPL(virtqueue_disable_cb);
* Caller must ensure we don't call this with other virtqueue * Caller must ensure we don't call this with other virtqueue
* operations at the same time (except where noted). * operations at the same time (except where noted).
*/ */
unsigned virtqueue_enable_cb_prepare(struct virtqueue *_vq) unsigned int virtqueue_enable_cb_prepare(struct virtqueue *_vq)
{ {
struct vring_virtqueue *vq = to_vvq(_vq); struct vring_virtqueue *vq = to_vvq(_vq);
...@@ -2048,7 +2046,7 @@ EXPORT_SYMBOL_GPL(virtqueue_enable_cb_prepare); ...@@ -2048,7 +2046,7 @@ EXPORT_SYMBOL_GPL(virtqueue_enable_cb_prepare);
* *
* This does not need to be serialized. * This does not need to be serialized.
*/ */
bool virtqueue_poll(struct virtqueue *_vq, unsigned last_used_idx) bool virtqueue_poll(struct virtqueue *_vq, unsigned int last_used_idx)
{ {
struct vring_virtqueue *vq = to_vvq(_vq); struct vring_virtqueue *vq = to_vvq(_vq);
...@@ -2074,7 +2072,7 @@ EXPORT_SYMBOL_GPL(virtqueue_poll); ...@@ -2074,7 +2072,7 @@ EXPORT_SYMBOL_GPL(virtqueue_poll);
*/ */
bool virtqueue_enable_cb(struct virtqueue *_vq) bool virtqueue_enable_cb(struct virtqueue *_vq)
{ {
unsigned last_used_idx = virtqueue_enable_cb_prepare(_vq); unsigned int last_used_idx = virtqueue_enable_cb_prepare(_vq);
return !virtqueue_poll(_vq, last_used_idx); return !virtqueue_poll(_vq, last_used_idx);
} }
...@@ -2136,8 +2134,11 @@ irqreturn_t vring_interrupt(int irq, void *_vq) ...@@ -2136,8 +2134,11 @@ irqreturn_t vring_interrupt(int irq, void *_vq)
return IRQ_NONE; return IRQ_NONE;
} }
if (unlikely(vq->broken)) if (unlikely(vq->broken)) {
return IRQ_HANDLED; dev_warn_once(&vq->vq.vdev->dev,
"virtio vring IRQ raised before DRIVER_OK");
return IRQ_NONE;
}
/* Just a hint for performance: so it's ok that this can be racy! */ /* Just a hint for performance: so it's ok that this can be racy! */
if (vq->event) if (vq->event)
...@@ -2179,7 +2180,7 @@ struct virtqueue *__vring_new_virtqueue(unsigned int index, ...@@ -2179,7 +2180,7 @@ struct virtqueue *__vring_new_virtqueue(unsigned int index,
vq->we_own_ring = false; vq->we_own_ring = false;
vq->notify = notify; vq->notify = notify;
vq->weak_barriers = weak_barriers; vq->weak_barriers = weak_barriers;
vq->broken = false; vq->broken = true;
vq->last_used_idx = 0; vq->last_used_idx = 0;
vq->event_triggered = false; vq->event_triggered = false;
vq->num_added = 0; vq->num_added = 0;
...@@ -2397,6 +2398,28 @@ void virtio_break_device(struct virtio_device *dev) ...@@ -2397,6 +2398,28 @@ void virtio_break_device(struct virtio_device *dev)
} }
EXPORT_SYMBOL_GPL(virtio_break_device); EXPORT_SYMBOL_GPL(virtio_break_device);
/*
* This should allow the device to be used by the driver. You may
* need to grab appropriate locks to flush the write to
vq->broken. This should only be used in specific cases, e.g.
probing and restoring. This function should only be called by the
* core, not directly by the driver.
*/
void __virtio_unbreak_device(struct virtio_device *dev)
{
struct virtqueue *_vq;
spin_lock(&dev->vqs_list_lock);
list_for_each_entry(_vq, &dev->vqs, list) {
struct vring_virtqueue *vq = to_vvq(_vq);
/* Pairs with READ_ONCE() in virtqueue_is_broken(). */
WRITE_ONCE(vq->broken, false);
}
spin_unlock(&dev->vqs_list_lock);
}
EXPORT_SYMBOL_GPL(__virtio_unbreak_device);
dma_addr_t virtqueue_get_desc_addr(struct virtqueue *_vq) dma_addr_t virtqueue_get_desc_addr(struct virtqueue *_vq)
{ {
struct vring_virtqueue *vq = to_vvq(_vq); struct vring_virtqueue *vq = to_vvq(_vq);
......
...@@ -53,16 +53,16 @@ static struct vdpa_device *vd_get_vdpa(struct virtio_device *vdev) ...@@ -53,16 +53,16 @@ static struct vdpa_device *vd_get_vdpa(struct virtio_device *vdev)
return to_virtio_vdpa_device(vdev)->vdpa; return to_virtio_vdpa_device(vdev)->vdpa;
} }
static void virtio_vdpa_get(struct virtio_device *vdev, unsigned offset, static void virtio_vdpa_get(struct virtio_device *vdev, unsigned int offset,
void *buf, unsigned len) void *buf, unsigned int len)
{ {
struct vdpa_device *vdpa = vd_get_vdpa(vdev); struct vdpa_device *vdpa = vd_get_vdpa(vdev);
vdpa_get_config(vdpa, offset, buf, len); vdpa_get_config(vdpa, offset, buf, len);
} }
static void virtio_vdpa_set(struct virtio_device *vdev, unsigned offset, static void virtio_vdpa_set(struct virtio_device *vdev, unsigned int offset,
const void *buf, unsigned len) const void *buf, unsigned int len)
{ {
struct vdpa_device *vdpa = vd_get_vdpa(vdev); struct vdpa_device *vdpa = vd_get_vdpa(vdev);
...@@ -184,7 +184,7 @@ virtio_vdpa_setup_vq(struct virtio_device *vdev, unsigned int index, ...@@ -184,7 +184,7 @@ virtio_vdpa_setup_vq(struct virtio_device *vdev, unsigned int index,
} }
/* Setup virtqueue callback */ /* Setup virtqueue callback */
cb.callback = virtio_vdpa_virtqueue_cb; cb.callback = callback ? virtio_vdpa_virtqueue_cb : NULL;
cb.private = info; cb.private = info;
ops->set_vq_cb(vdpa, index, &cb); ops->set_vq_cb(vdpa, index, &cb);
ops->set_vq_num(vdpa, index, virtqueue_get_vring_size(vq)); ops->set_vq_num(vdpa, index, virtqueue_get_vring_size(vq));
...@@ -263,7 +263,7 @@ static void virtio_vdpa_del_vqs(struct virtio_device *vdev) ...@@ -263,7 +263,7 @@ static void virtio_vdpa_del_vqs(struct virtio_device *vdev)
virtio_vdpa_del_vq(vq); virtio_vdpa_del_vq(vq);
} }
static int virtio_vdpa_find_vqs(struct virtio_device *vdev, unsigned nvqs, static int virtio_vdpa_find_vqs(struct virtio_device *vdev, unsigned int nvqs,
struct virtqueue *vqs[], struct virtqueue *vqs[],
vq_callback_t *callbacks[], vq_callback_t *callbacks[],
const char * const names[], const char * const names[],
......
...@@ -87,6 +87,7 @@ enum { ...@@ -87,6 +87,7 @@ enum {
enum { enum {
MLX5_OBJ_TYPE_GENEVE_TLV_OPT = 0x000b, MLX5_OBJ_TYPE_GENEVE_TLV_OPT = 0x000b,
MLX5_OBJ_TYPE_VIRTIO_NET_Q = 0x000d, MLX5_OBJ_TYPE_VIRTIO_NET_Q = 0x000d,
MLX5_OBJ_TYPE_VIRTIO_Q_COUNTERS = 0x001c,
MLX5_OBJ_TYPE_MATCH_DEFINER = 0x0018, MLX5_OBJ_TYPE_MATCH_DEFINER = 0x0018,
MLX5_OBJ_TYPE_MKEY = 0xff01, MLX5_OBJ_TYPE_MKEY = 0xff01,
MLX5_OBJ_TYPE_QP = 0xff02, MLX5_OBJ_TYPE_QP = 0xff02,
......
...@@ -165,4 +165,43 @@ struct mlx5_ifc_modify_virtio_net_q_out_bits { ...@@ -165,4 +165,43 @@ struct mlx5_ifc_modify_virtio_net_q_out_bits {
struct mlx5_ifc_general_obj_out_cmd_hdr_bits general_obj_out_cmd_hdr; struct mlx5_ifc_general_obj_out_cmd_hdr_bits general_obj_out_cmd_hdr;
}; };
struct mlx5_ifc_virtio_q_counters_bits {
u8 modify_field_select[0x40];
u8 reserved_at_40[0x40];
u8 received_desc[0x40];
u8 completed_desc[0x40];
u8 error_cqes[0x20];
u8 bad_desc_errors[0x20];
u8 exceed_max_chain[0x20];
u8 invalid_buffer[0x20];
u8 reserved_at_180[0x280];
};
struct mlx5_ifc_create_virtio_q_counters_in_bits {
struct mlx5_ifc_general_obj_in_cmd_hdr_bits hdr;
struct mlx5_ifc_virtio_q_counters_bits virtio_q_counters;
};
struct mlx5_ifc_create_virtio_q_counters_out_bits {
struct mlx5_ifc_general_obj_in_cmd_hdr_bits hdr;
struct mlx5_ifc_virtio_q_counters_bits virtio_q_counters;
};
struct mlx5_ifc_destroy_virtio_q_counters_in_bits {
struct mlx5_ifc_general_obj_in_cmd_hdr_bits hdr;
};
struct mlx5_ifc_destroy_virtio_q_counters_out_bits {
struct mlx5_ifc_general_obj_out_cmd_hdr_bits hdr;
};
struct mlx5_ifc_query_virtio_q_counters_in_bits {
struct mlx5_ifc_general_obj_in_cmd_hdr_bits hdr;
};
struct mlx5_ifc_query_virtio_q_counters_out_bits {
struct mlx5_ifc_general_obj_in_cmd_hdr_bits hdr;
struct mlx5_ifc_virtio_q_counters_bits counters;
};
#endif /* __MLX5_IFC_VDPA_H_ */ #endif /* __MLX5_IFC_VDPA_H_ */
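These layouts let the mlx5 vdpa driver create and query per-virtqueue hardware counters through the mlx5 general-object command interface. Below is a hedged sketch of a query helper, modeled on how such objects are normally driven with MLX5_SET()/MLX5_GET64() and mlx5_cmd_exec(); counters_obj_id is assumed to come from the earlier create command, and a real driver would also fill the command uid.

/* Sketch: read two counters from a previously created virtio_q
 * counters object. Minimal error handling; counters_obj_id is assumed
 * to have been returned by the create general-object command.
 */
static int example_query_vq_counters(struct mlx5_core_dev *mdev,
                                     u32 counters_obj_id,
                                     u64 *received_desc, u64 *completed_desc)
{
        u32 in[MLX5_ST_SZ_DW(query_virtio_q_counters_in)] = {};
        u32 out[MLX5_ST_SZ_DW(query_virtio_q_counters_out)] = {};
        void *cmd_hdr, *ctx;
        int err;

        cmd_hdr = MLX5_ADDR_OF(query_virtio_q_counters_in, in, hdr);
        MLX5_SET(general_obj_in_cmd_hdr, cmd_hdr, opcode,
                 MLX5_CMD_OP_QUERY_GENERAL_OBJECT);
        MLX5_SET(general_obj_in_cmd_hdr, cmd_hdr, obj_type,
                 MLX5_OBJ_TYPE_VIRTIO_Q_COUNTERS);
        MLX5_SET(general_obj_in_cmd_hdr, cmd_hdr, obj_id, counters_obj_id);

        err = mlx5_cmd_exec(mdev, in, sizeof(in), out, sizeof(out));
        if (err)
                return err;

        ctx = MLX5_ADDR_OF(query_virtio_q_counters_out, out, counters);
        *received_desc = MLX5_GET64(virtio_q_counters, ctx, received_desc);
        *completed_desc = MLX5_GET64(virtio_q_counters, ctx, completed_desc);
        return 0;
}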
...@@ -66,9 +66,11 @@ struct vdpa_mgmt_dev; ...@@ -66,9 +66,11 @@ struct vdpa_mgmt_dev;
* @dma_dev: the actual device that is performing DMA * @dma_dev: the actual device that is performing DMA
* @driver_override: driver name to force a match * @driver_override: driver name to force a match
* @config: the configuration ops for this device. * @config: the configuration ops for this device.
* @cf_mutex: Protects get and set access to configuration layout. * @cf_lock: Protects get and set access to configuration layout.
* @index: device index * @index: device index
* @features_valid: were features initialized? for legacy guests * @features_valid: were features initialized? for legacy guests
* @ngroups: the number of virtqueue groups
* @nas: the number of address spaces
* @use_va: indicate whether virtual address must be used by this device * @use_va: indicate whether virtual address must be used by this device
* @nvqs: maximum number of supported virtqueues * @nvqs: maximum number of supported virtqueues
* @mdev: management device pointer; caller must setup when registering device as part * @mdev: management device pointer; caller must setup when registering device as part
...@@ -79,12 +81,14 @@ struct vdpa_device { ...@@ -79,12 +81,14 @@ struct vdpa_device {
struct device *dma_dev; struct device *dma_dev;
const char *driver_override; const char *driver_override;
const struct vdpa_config_ops *config; const struct vdpa_config_ops *config;
struct mutex cf_mutex; /* Protects get/set config */ struct rw_semaphore cf_lock; /* Protects get/set config */
unsigned int index; unsigned int index;
bool features_valid; bool features_valid;
bool use_va; bool use_va;
u32 nvqs; u32 nvqs;
struct vdpa_mgmt_dev *mdev; struct vdpa_mgmt_dev *mdev;
unsigned int ngroups;
unsigned int nas;
}; };
/** /**
...@@ -172,6 +176,10 @@ struct vdpa_map_file { ...@@ -172,6 +176,10 @@ struct vdpa_map_file {
* for the device * for the device
* @vdev: vdpa device * @vdev: vdpa device
* Returns virtqueue align requirement * Returns virtqueue align requirement
* @get_vq_group: Get the group id for a specific virtqueue
* @vdev: vdpa device
* @idx: virtqueue index
* Returns u32: group id for this virtqueue
* @get_device_features: Get virtio features supported by the device * @get_device_features: Get virtio features supported by the device
* @vdev: vdpa device * @vdev: vdpa device
* Returns the virtio features support by the * Returns the virtio features support by the
...@@ -232,10 +240,17 @@ struct vdpa_map_file { ...@@ -232,10 +240,17 @@ struct vdpa_map_file {
* @vdev: vdpa device * @vdev: vdpa device
* Returns the iova range supported by * Returns the iova range supported by
* the device. * the device.
* @set_group_asid: Set address space identifier for a
* virtqueue group
* @vdev: vdpa device
* @group: virtqueue group
* @asid: address space id for this group
* Returns integer: success (0) or error (< 0)
* @set_map: Set device memory mapping (optional) * @set_map: Set device memory mapping (optional)
* Needed for devices that use device * Needed for devices that use device
* specific DMA translation (on-chip IOMMU) * specific DMA translation (on-chip IOMMU)
* @vdev: vdpa device * @vdev: vdpa device
* @asid: address space identifier
* @iotlb: vhost memory mapping to be * @iotlb: vhost memory mapping to be
* used by the vDPA * used by the vDPA
* Returns integer: success (0) or error (< 0) * Returns integer: success (0) or error (< 0)
...@@ -244,6 +259,7 @@ struct vdpa_map_file { ...@@ -244,6 +259,7 @@ struct vdpa_map_file {
* specific DMA translation (on-chip IOMMU) * specific DMA translation (on-chip IOMMU)
* and preferring incremental map. * and preferring incremental map.
* @vdev: vdpa device * @vdev: vdpa device
* @asid: address space identifier
* @iova: iova to be mapped * @iova: iova to be mapped
* @size: size of the area * @size: size of the area
* @pa: physical address for the map * @pa: physical address for the map
...@@ -255,6 +271,7 @@ struct vdpa_map_file { ...@@ -255,6 +271,7 @@ struct vdpa_map_file {
* specific DMA translation (on-chip IOMMU) * specific DMA translation (on-chip IOMMU)
* and preferring incremental unmap. * and preferring incremental unmap.
* @vdev: vdpa device * @vdev: vdpa device
* @asid: address space identifier
* @iova: iova to be unmapped * @iova: iova to be unmapped
* @size: size of the area * @size: size of the area
* Returns integer: success (0) or error (< 0) * Returns integer: success (0) or error (< 0)
...@@ -276,6 +293,9 @@ struct vdpa_config_ops { ...@@ -276,6 +293,9 @@ struct vdpa_config_ops {
const struct vdpa_vq_state *state); const struct vdpa_vq_state *state);
int (*get_vq_state)(struct vdpa_device *vdev, u16 idx, int (*get_vq_state)(struct vdpa_device *vdev, u16 idx,
struct vdpa_vq_state *state); struct vdpa_vq_state *state);
int (*get_vendor_vq_stats)(struct vdpa_device *vdev, u16 idx,
struct sk_buff *msg,
struct netlink_ext_ack *extack);
struct vdpa_notification_area struct vdpa_notification_area
(*get_vq_notification)(struct vdpa_device *vdev, u16 idx); (*get_vq_notification)(struct vdpa_device *vdev, u16 idx);
/* vq irq is not expected to be changed once DRIVER_OK is set */ /* vq irq is not expected to be changed once DRIVER_OK is set */
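The new get_vendor_vq_stats() op lets a parent driver report vendor-specific per-virtqueue statistics over the vdpa netlink channel (surfaced to userspace through VDPA_CMD_DEV_VSTATS_GET and the VDPA_ATTR_DEV_VENDOR_ATTR_NAME/VALUE attributes added further down). A hedged implementation sketch follows; struct my_vdpa and my_read_counter() are placeholders, and VDPA_ATTR_PAD is assumed to be the enum's padding entry.

/* Sketch: report one named 64-bit counter for virtqueue @idx. */
static int my_get_vendor_vq_stats(struct vdpa_device *vdev, u16 idx,
                                  struct sk_buff *msg,
                                  struct netlink_ext_ack *extack)
{
        struct my_vdpa *mydev = container_of(vdev, struct my_vdpa, vdpa);
        u64 received = my_read_counter(mydev, idx);   /* hypothetical helper */

        if (nla_put_string(msg, VDPA_ATTR_DEV_VENDOR_ATTR_NAME,
                           "received_desc"))
                return -EMSGSIZE;
        /* VDPA_ATTR_PAD assumed; check the uapi header for the real pad attr */
        if (nla_put_u64_64bit(msg, VDPA_ATTR_DEV_VENDOR_ATTR_VALUE,
                              received, VDPA_ATTR_PAD))
                return -EMSGSIZE;
        return 0;
}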
...@@ -283,6 +303,7 @@ struct vdpa_config_ops { ...@@ -283,6 +303,7 @@ struct vdpa_config_ops {
/* Device ops */ /* Device ops */
u32 (*get_vq_align)(struct vdpa_device *vdev); u32 (*get_vq_align)(struct vdpa_device *vdev);
u32 (*get_vq_group)(struct vdpa_device *vdev, u16 idx);
u64 (*get_device_features)(struct vdpa_device *vdev); u64 (*get_device_features)(struct vdpa_device *vdev);
int (*set_driver_features)(struct vdpa_device *vdev, u64 features); int (*set_driver_features)(struct vdpa_device *vdev, u64 features);
u64 (*get_driver_features)(struct vdpa_device *vdev); u64 (*get_driver_features)(struct vdpa_device *vdev);
...@@ -304,10 +325,14 @@ struct vdpa_config_ops { ...@@ -304,10 +325,14 @@ struct vdpa_config_ops {
struct vdpa_iova_range (*get_iova_range)(struct vdpa_device *vdev); struct vdpa_iova_range (*get_iova_range)(struct vdpa_device *vdev);
/* DMA ops */ /* DMA ops */
int (*set_map)(struct vdpa_device *vdev, struct vhost_iotlb *iotlb); int (*set_map)(struct vdpa_device *vdev, unsigned int asid,
int (*dma_map)(struct vdpa_device *vdev, u64 iova, u64 size, struct vhost_iotlb *iotlb);
u64 pa, u32 perm, void *opaque); int (*dma_map)(struct vdpa_device *vdev, unsigned int asid,
int (*dma_unmap)(struct vdpa_device *vdev, u64 iova, u64 size); u64 iova, u64 size, u64 pa, u32 perm, void *opaque);
int (*dma_unmap)(struct vdpa_device *vdev, unsigned int asid,
u64 iova, u64 size);
int (*set_group_asid)(struct vdpa_device *vdev, unsigned int group,
unsigned int asid);
/* Free device resources */ /* Free device resources */
void (*free)(struct vdpa_device *vdev); void (*free)(struct vdpa_device *vdev);
...@@ -315,6 +340,7 @@ struct vdpa_config_ops { ...@@ -315,6 +340,7 @@ struct vdpa_config_ops {
struct vdpa_device *__vdpa_alloc_device(struct device *parent, struct vdpa_device *__vdpa_alloc_device(struct device *parent,
const struct vdpa_config_ops *config, const struct vdpa_config_ops *config,
unsigned int ngroups, unsigned int nas,
size_t size, const char *name, size_t size, const char *name,
bool use_va); bool use_va);
...@@ -325,17 +351,20 @@ struct vdpa_device *__vdpa_alloc_device(struct device *parent, ...@@ -325,17 +351,20 @@ struct vdpa_device *__vdpa_alloc_device(struct device *parent,
* @member: the name of struct vdpa_device within the @dev_struct * @member: the name of struct vdpa_device within the @dev_struct
* @parent: the parent device * @parent: the parent device
* @config: the bus operations that is supported by this device * @config: the bus operations that is supported by this device
* @ngroups: the number of virtqueue groups supported by this device
* @nas: the number of address spaces
* @name: name of the vdpa device * @name: name of the vdpa device
* @use_va: indicate whether virtual address must be used by this device * @use_va: indicate whether virtual address must be used by this device
* *
* Return allocated data structure or ERR_PTR upon error * Return allocated data structure or ERR_PTR upon error
*/ */
#define vdpa_alloc_device(dev_struct, member, parent, config, name, use_va) \ #define vdpa_alloc_device(dev_struct, member, parent, config, ngroups, nas, \
container_of(__vdpa_alloc_device( \ name, use_va) \
parent, config, \ container_of((__vdpa_alloc_device( \
sizeof(dev_struct) + \ parent, config, ngroups, nas, \
(sizeof(dev_struct) + \
BUILD_BUG_ON_ZERO(offsetof( \ BUILD_BUG_ON_ZERO(offsetof( \
dev_struct, member)), name, use_va), \ dev_struct, member))), name, use_va)), \
dev_struct, member) dev_struct, member)
int vdpa_register_device(struct vdpa_device *vdev, u32 nvqs); int vdpa_register_device(struct vdpa_device *vdev, u32 nvqs);
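With this change the allocation macro takes the number of virtqueue groups and address spaces up front. A minimal sketch of a parent driver allocating a device with two groups and two address spaces; struct my_vdpa and my_config_ops are placeholders, not part of this series.

struct my_vdpa {
        struct vdpa_device vdpa;   /* first member, so IS_ERR() on the container works */
        /* driver private state ... */
};

static struct my_vdpa *my_alloc(struct device *parent,
                                const struct vdpa_config_ops *my_config_ops,
                                const char *name)
{
        struct my_vdpa *mydev;

        /* 2 virtqueue groups, 2 address spaces, physical addresses (use_va = false) */
        mydev = vdpa_alloc_device(struct my_vdpa, vdpa, parent,
                                  my_config_ops, 2, 2, name, false);
        if (IS_ERR(mydev))
                return NULL;
        return mydev;
}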
...@@ -395,10 +424,10 @@ static inline int vdpa_reset(struct vdpa_device *vdev) ...@@ -395,10 +424,10 @@ static inline int vdpa_reset(struct vdpa_device *vdev)
const struct vdpa_config_ops *ops = vdev->config; const struct vdpa_config_ops *ops = vdev->config;
int ret; int ret;
mutex_lock(&vdev->cf_mutex); down_write(&vdev->cf_lock);
vdev->features_valid = false; vdev->features_valid = false;
ret = ops->reset(vdev); ret = ops->reset(vdev);
mutex_unlock(&vdev->cf_mutex); up_write(&vdev->cf_lock);
return ret; return ret;
} }
...@@ -417,9 +446,9 @@ static inline int vdpa_set_features(struct vdpa_device *vdev, u64 features) ...@@ -417,9 +446,9 @@ static inline int vdpa_set_features(struct vdpa_device *vdev, u64 features)
{ {
int ret; int ret;
mutex_lock(&vdev->cf_mutex); down_write(&vdev->cf_lock);
ret = vdpa_set_features_unlocked(vdev, features); ret = vdpa_set_features_unlocked(vdev, features);
mutex_unlock(&vdev->cf_mutex); up_write(&vdev->cf_lock);
return ret; return ret;
} }
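Replacing cf_mutex with a rw_semaphore keeps reset and feature negotiation exclusive while allowing configuration reads to run concurrently. A hedged reader-side sketch of what a config read helper can now look like:

/* Sketch only: read part of the device config under the shared lock. */
static void example_get_config(struct vdpa_device *vdev, unsigned int offset,
                               void *buf, unsigned int len)
{
        const struct vdpa_config_ops *ops = vdev->config;

        down_read(&vdev->cf_lock);      /* multiple readers may proceed */
        ops->get_config(vdev, offset, buf, len);
        up_read(&vdev->cf_lock);
}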
...@@ -463,7 +492,7 @@ struct vdpa_mgmtdev_ops { ...@@ -463,7 +492,7 @@ struct vdpa_mgmtdev_ops {
struct vdpa_mgmt_dev { struct vdpa_mgmt_dev {
struct device *device; struct device *device;
const struct vdpa_mgmtdev_ops *ops; const struct vdpa_mgmtdev_ops *ops;
const struct virtio_device_id *id_table; struct virtio_device_id *id_table;
u64 config_attr_mask; u64 config_attr_mask;
struct list_head list; struct list_head list;
u64 supported_features; u64 supported_features;
......
...@@ -36,6 +36,8 @@ int vhost_iotlb_add_range(struct vhost_iotlb *iotlb, u64 start, u64 last, ...@@ -36,6 +36,8 @@ int vhost_iotlb_add_range(struct vhost_iotlb *iotlb, u64 start, u64 last,
u64 addr, unsigned int perm); u64 addr, unsigned int perm);
void vhost_iotlb_del_range(struct vhost_iotlb *iotlb, u64 start, u64 last); void vhost_iotlb_del_range(struct vhost_iotlb *iotlb, u64 start, u64 last);
void vhost_iotlb_init(struct vhost_iotlb *iotlb, unsigned int limit,
unsigned int flags);
struct vhost_iotlb *vhost_iotlb_alloc(unsigned int limit, unsigned int flags); struct vhost_iotlb *vhost_iotlb_alloc(unsigned int limit, unsigned int flags);
void vhost_iotlb_free(struct vhost_iotlb *iotlb); void vhost_iotlb_free(struct vhost_iotlb *iotlb);
void vhost_iotlb_reset(struct vhost_iotlb *iotlb); void vhost_iotlb_reset(struct vhost_iotlb *iotlb);
......
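vhost_iotlb_init() lets an IOTLB that is embedded in a larger structure be initialized in place rather than heap-allocated with vhost_iotlb_alloc(). A small hedged sketch; struct my_address_space is a placeholder.

struct my_address_space {
        struct vhost_iotlb iotlb;       /* embedded, not allocated */
        /* ... */
};

static void my_as_setup(struct my_address_space *as)
{
        /* 0 = no entry limit, no flags; same arguments as vhost_iotlb_alloc() */
        vhost_iotlb_init(&as->iotlb, 0, 0);
}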
...@@ -131,6 +131,7 @@ void unregister_virtio_device(struct virtio_device *dev); ...@@ -131,6 +131,7 @@ void unregister_virtio_device(struct virtio_device *dev);
bool is_virtio_device(struct device *dev); bool is_virtio_device(struct device *dev);
void virtio_break_device(struct virtio_device *dev); void virtio_break_device(struct virtio_device *dev);
void __virtio_unbreak_device(struct virtio_device *dev);
void virtio_config_changed(struct virtio_device *dev); void virtio_config_changed(struct virtio_device *dev);
#ifdef CONFIG_PM_SLEEP #ifdef CONFIG_PM_SLEEP
......
...@@ -57,6 +57,11 @@ struct virtio_shm_region { ...@@ -57,6 +57,11 @@ struct virtio_shm_region {
* include a NULL entry for vqs unused by driver * include a NULL entry for vqs unused by driver
* Returns 0 on success or error status * Returns 0 on success or error status
* @del_vqs: free virtqueues found by find_vqs(). * @del_vqs: free virtqueues found by find_vqs().
* @synchronize_cbs: synchronize with the virtqueue callbacks (optional)
* The function guarantees that all memory operations on the
* queue before it are visible to the vring_interrupt() that is
* called after it.
* vdev: the virtio_device
* @get_features: get the array of feature bits for this device. * @get_features: get the array of feature bits for this device.
* vdev: the virtio_device * vdev: the virtio_device
* Returns the first 64 feature bits (all we currently need). * Returns the first 64 feature bits (all we currently need).
...@@ -89,6 +94,7 @@ struct virtio_config_ops { ...@@ -89,6 +94,7 @@ struct virtio_config_ops {
const char * const names[], const bool *ctx, const char * const names[], const bool *ctx,
struct irq_affinity *desc); struct irq_affinity *desc);
void (*del_vqs)(struct virtio_device *); void (*del_vqs)(struct virtio_device *);
void (*synchronize_cbs)(struct virtio_device *);
u64 (*get_features)(struct virtio_device *vdev); u64 (*get_features)(struct virtio_device *vdev);
int (*finalize_features)(struct virtio_device *vdev); int (*finalize_features)(struct virtio_device *vdev);
const char *(*bus_name)(struct virtio_device *vdev); const char *(*bus_name)(struct virtio_device *vdev);
...@@ -217,6 +223,25 @@ int virtio_find_vqs_ctx(struct virtio_device *vdev, unsigned nvqs, ...@@ -217,6 +223,25 @@ int virtio_find_vqs_ctx(struct virtio_device *vdev, unsigned nvqs,
desc); desc);
} }
/**
* virtio_synchronize_cbs - synchronize with virtqueue callbacks
* @vdev: the device
*/
static inline
void virtio_synchronize_cbs(struct virtio_device *dev)
{
if (dev->config->synchronize_cbs) {
dev->config->synchronize_cbs(dev);
} else {
/*
* A best effort fallback to synchronize with
* interrupts, preemption and softirq disabled
* regions. See comment above synchronize_rcu().
*/
synchronize_rcu();
}
}
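A transport implements synchronize_cbs() by synchronizing against whatever context runs vring_interrupt(); for an MSI-X based transport that amounts to synchronize_irq() on every per-virtqueue vector. A hedged sketch is below; my_vq_irq() is a placeholder for however the transport maps a virtqueue to its IRQ, not an existing API.

static void my_synchronize_cbs(struct virtio_device *vdev)
{
        struct virtqueue *vq;

        list_for_each_entry(vq, &vdev->vqs, list) {
                /* Wait for any in-flight vring_interrupt() on this vq's IRQ. */
                synchronize_irq(my_vq_irq(vdev, vq));
        }
}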
/** /**
* virtio_device_ready - enable vq use in probe function * virtio_device_ready - enable vq use in probe function
* @vdev: the device * @vdev: the device
...@@ -230,7 +255,27 @@ void virtio_device_ready(struct virtio_device *dev) ...@@ -230,7 +255,27 @@ void virtio_device_ready(struct virtio_device *dev)
{ {
unsigned status = dev->config->get_status(dev); unsigned status = dev->config->get_status(dev);
BUG_ON(status & VIRTIO_CONFIG_S_DRIVER_OK); WARN_ON(status & VIRTIO_CONFIG_S_DRIVER_OK);
/*
* The virtio_synchronize_cbs() makes sure vring_interrupt()
* will see the driver specific setup if it sees vq->broken
* as false (even if the notifications come before DRIVER_OK).
*/
virtio_synchronize_cbs(dev);
__virtio_unbreak_device(dev);
/*
* The transport should ensure the visibility of vq->broken
* before setting DRIVER_OK. See the comments for the transport
* specific set_status() method.
*
* A well-behaved device will only notify a virtqueue after
* DRIVER_OK; this means the device should "see" the coherent
* memory write that set vq->broken as false which is done by
* the driver when it sees DRIVER_OK, then the following
* driver's vring_interrupt() will see vq->broken as false so
* we won't lose any notification.
*/
dev->config->set_status(dev, status | VIRTIO_CONFIG_S_DRIVER_OK); dev->config->set_status(dev, status | VIRTIO_CONFIG_S_DRIVER_OK);
} }
......
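With virtqueues now created in the broken state, the probe-time contract becomes explicit: finish all driver-side setup, call virtio_device_ready(), and only then kick the queues. A minimal hedged probe skeleton, with my_probe and my_recv_done as placeholder names:

static void my_recv_done(struct virtqueue *vq)
{
        /* ... process used buffers ... */
}

static int my_probe(struct virtio_device *vdev)
{
        struct virtqueue *vq;
        vq_callback_t *cbs[] = { my_recv_done };
        static const char * const names[] = { "requests" };
        int err;

        err = virtio_find_vqs(vdev, 1, &vq, cbs, names, NULL);
        if (err)
                return err;

        /* ... allocate and publish state used by my_recv_done() ... */

        /* Unbreaks the vqs and sets DRIVER_OK; before this point any stray
         * notification is ignored by the hardened vring_interrupt(). */
        virtio_device_ready(vdev);

        virtqueue_kick(vq);     /* only after DRIVER_OK */
        return 0;
}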
...@@ -18,6 +18,7 @@ enum vdpa_command { ...@@ -18,6 +18,7 @@ enum vdpa_command {
VDPA_CMD_DEV_DEL, VDPA_CMD_DEV_DEL,
VDPA_CMD_DEV_GET, /* can dump */ VDPA_CMD_DEV_GET, /* can dump */
VDPA_CMD_DEV_CONFIG_GET, /* can dump */ VDPA_CMD_DEV_CONFIG_GET, /* can dump */
VDPA_CMD_DEV_VSTATS_GET,
}; };
enum vdpa_attr { enum vdpa_attr {
...@@ -46,6 +47,11 @@ enum vdpa_attr { ...@@ -46,6 +47,11 @@ enum vdpa_attr {
VDPA_ATTR_DEV_NEGOTIATED_FEATURES, /* u64 */ VDPA_ATTR_DEV_NEGOTIATED_FEATURES, /* u64 */
VDPA_ATTR_DEV_MGMTDEV_MAX_VQS, /* u32 */ VDPA_ATTR_DEV_MGMTDEV_MAX_VQS, /* u32 */
VDPA_ATTR_DEV_SUPPORTED_FEATURES, /* u64 */ VDPA_ATTR_DEV_SUPPORTED_FEATURES, /* u64 */
VDPA_ATTR_DEV_QUEUE_INDEX, /* u32 */
VDPA_ATTR_DEV_VENDOR_ATTR_NAME, /* string */
VDPA_ATTR_DEV_VENDOR_ATTR_VALUE, /* u64 */
/* new attributes must be added above here */ /* new attributes must be added above here */
VDPA_ATTR_MAX, VDPA_ATTR_MAX,
}; };
......
...@@ -89,11 +89,6 @@ ...@@ -89,11 +89,6 @@
/* Set or get vhost backend capability */ /* Set or get vhost backend capability */
/* Use message type V2 */
#define VHOST_BACKEND_F_IOTLB_MSG_V2 0x1
/* IOTLB can accept batching hints */
#define VHOST_BACKEND_F_IOTLB_BATCH 0x2
#define VHOST_SET_BACKEND_FEATURES _IOW(VHOST_VIRTIO, 0x25, __u64) #define VHOST_SET_BACKEND_FEATURES _IOW(VHOST_VIRTIO, 0x25, __u64)
#define VHOST_GET_BACKEND_FEATURES _IOR(VHOST_VIRTIO, 0x26, __u64) #define VHOST_GET_BACKEND_FEATURES _IOR(VHOST_VIRTIO, 0x26, __u64)
...@@ -150,11 +145,30 @@ ...@@ -150,11 +145,30 @@
/* Get the valid iova range */ /* Get the valid iova range */
#define VHOST_VDPA_GET_IOVA_RANGE _IOR(VHOST_VIRTIO, 0x78, \ #define VHOST_VDPA_GET_IOVA_RANGE _IOR(VHOST_VIRTIO, 0x78, \
struct vhost_vdpa_iova_range) struct vhost_vdpa_iova_range)
/* Get the config size */ /* Get the config size */
#define VHOST_VDPA_GET_CONFIG_SIZE _IOR(VHOST_VIRTIO, 0x79, __u32) #define VHOST_VDPA_GET_CONFIG_SIZE _IOR(VHOST_VIRTIO, 0x79, __u32)
/* Get the count of all virtqueues */ /* Get the count of all virtqueues */
#define VHOST_VDPA_GET_VQS_COUNT _IOR(VHOST_VIRTIO, 0x80, __u32) #define VHOST_VDPA_GET_VQS_COUNT _IOR(VHOST_VIRTIO, 0x80, __u32)
/* Get the number of virtqueue groups. */
#define VHOST_VDPA_GET_GROUP_NUM _IOR(VHOST_VIRTIO, 0x81, __u32)
/* Get the number of address spaces. */
#define VHOST_VDPA_GET_AS_NUM _IOR(VHOST_VIRTIO, 0x7A, unsigned int)
/* Get the group for a virtqueue: read index, write group in num.
* The virtqueue index is stored in the index field of
* vhost_vring_state. The group for this specific virtqueue is
* returned via num field of vhost_vring_state.
*/
#define VHOST_VDPA_GET_VRING_GROUP _IOWR(VHOST_VIRTIO, 0x7B, \
struct vhost_vring_state)
/* Set the ASID for a virtqueue group. The group index is stored in
 * the index field of vhost_vring_state, and the ASID associated with this
 * group is stored in the num field of vhost_vring_state.
*/
#define VHOST_VDPA_SET_GROUP_ASID _IOW(VHOST_VIRTIO, 0x7C, \
struct vhost_vring_state)
#endif #endif
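Taken together, the new ioctls let userspace discover how many virtqueue groups and address spaces a vhost-vdpa device exposes, look up the group a given virtqueue belongs to, and bind that group to an ASID. A hedged userspace sketch; the device path is only an example and error handling is omitted.

#include <fcntl.h>
#include <stdio.h>
#include <unistd.h>
#include <sys/ioctl.h>
#include <linux/vhost.h>

int main(void)
{
        struct vhost_vring_state state;
        unsigned int groups, as_num;
        int fd = open("/dev/vhost-vdpa-0", O_RDWR);

        if (fd < 0)
                return 1;

        ioctl(fd, VHOST_VDPA_GET_GROUP_NUM, &groups);
        ioctl(fd, VHOST_VDPA_GET_AS_NUM, &as_num);
        printf("%u virtqueue groups, %u address spaces\n", groups, as_num);

        /* Which group does virtqueue 0 belong to? */
        state.index = 0;
        ioctl(fd, VHOST_VDPA_GET_VRING_GROUP, &state);

        /* Attach that group to address space 1 (assuming as_num > 1). */
        state.index = state.num;        /* group index */
        state.num = 1;                  /* ASID */
        ioctl(fd, VHOST_VDPA_SET_GROUP_ASID, &state);

        close(fd);
        return 0;
}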
...@@ -87,7 +87,7 @@ struct vhost_msg { ...@@ -87,7 +87,7 @@ struct vhost_msg {
struct vhost_msg_v2 { struct vhost_msg_v2 {
__u32 type; __u32 type;
__u32 reserved; __u32 asid;
union { union {
struct vhost_iotlb_msg iotlb; struct vhost_iotlb_msg iotlb;
__u8 padding[64]; __u8 padding[64];
...@@ -153,4 +153,13 @@ struct vhost_vdpa_iova_range { ...@@ -153,4 +153,13 @@ struct vhost_vdpa_iova_range {
/* vhost-net should add virtio_net_hdr for RX, and strip for TX packets. */ /* vhost-net should add virtio_net_hdr for RX, and strip for TX packets. */
#define VHOST_NET_F_VIRTIO_NET_HDR 27 #define VHOST_NET_F_VIRTIO_NET_HDR 27
/* Use message type V2 */
#define VHOST_BACKEND_F_IOTLB_MSG_V2 0x1
/* IOTLB can accept batching hints */
#define VHOST_BACKEND_F_IOTLB_BATCH 0x2
/* IOTLB can accept an address space identifier through the V2 type of
 * IOTLB message
*/
#define VHOST_BACKEND_F_IOTLB_ASID 0x3
#endif #endif
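The former reserved field of vhost_msg_v2 is repurposed as the target ASID, and a backend advertises support via VHOST_BACKEND_F_IOTLB_ASID. A hedged userspace sketch of an IOTLB update aimed at a specific address space follows; the fd is assumed to be an open vhost-vdpa device with the ASID backend feature acked, and the addresses are illustrative.

#include <string.h>
#include <unistd.h>
#include <linux/vhost.h>
#include <linux/vhost_types.h>

/* Map one range into a given address space of an open vhost-vdpa fd. */
static int map_into_asid(int fd, unsigned int asid,
                         __u64 iova, __u64 size, __u64 uaddr)
{
        struct vhost_msg_v2 msg;

        memset(&msg, 0, sizeof(msg));
        msg.type = VHOST_IOTLB_MSG_V2;
        msg.asid = asid;                /* was "reserved" before this series */
        msg.iotlb.iova = iova;
        msg.iotlb.size = size;
        msg.iotlb.uaddr = uaddr;
        msg.iotlb.perm = VHOST_ACCESS_RW;
        msg.iotlb.type = VHOST_IOTLB_UPDATE;

        /* vhost-vdpa consumes IOTLB messages via write() on the device fd. */
        return write(fd, &msg, sizeof(msg)) == sizeof(msg) ? 0 : -1;
}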