summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r-- drivers/vfio/vfio_iommu_type1.c 84
1 files changed, 72 insertions, 12 deletions
diff --git a/drivers/vfio/vfio_iommu_type1.c b/drivers/vfio/vfio_iommu_type1.c
index f8d68fe77b41..7829b5e268c2 100644
--- a/drivers/vfio/vfio_iommu_type1.c
+++ b/drivers/vfio/vfio_iommu_type1.c
@@ -37,6 +37,7 @@
#include <linux/vfio.h>
#include <linux/workqueue.h>
#include <linux/notifier.h>
+#include <linux/mm_inline.h>
#include "vfio.h"
#define DRIVER_VERSION "0.2"
@@ -318,7 +319,13 @@ static void vfio_dma_bitmap_free_all(struct vfio_iommu *iommu)
/*
* Helper Functions for host iova-pfn list
*/
-static struct vfio_pfn *vfio_find_vpfn(struct vfio_dma *dma, dma_addr_t iova)
+
+/*
+ * Find the topmost vfio_pfn (the one nearest the rb tree root) whose
+ * iova lies in the range [iova_start, iova_end).
+ */
+static struct vfio_pfn *vfio_find_vpfn_range(struct vfio_dma *dma,
+ dma_addr_t iova_start, dma_addr_t iova_end)
{
struct vfio_pfn *vpfn;
struct rb_node *node = dma->pfn_list.rb_node;
@@ -326,9 +333,9 @@ static struct vfio_pfn *vfio_find_vpfn(struct vfio_dma *dma, dma_addr_t iova)
while (node) {
vpfn = rb_entry(node, struct vfio_pfn, node);
- if (iova < vpfn->iova)
+ if (iova_end <= vpfn->iova)
node = node->rb_left;
- else if (iova > vpfn->iova)
+ else if (iova_start > vpfn->iova)
node = node->rb_right;
else
return vpfn;
@@ -336,6 +343,11 @@ static struct vfio_pfn *vfio_find_vpfn(struct vfio_dma *dma, dma_addr_t iova)
return NULL;
}
+static inline struct vfio_pfn *vfio_find_vpfn(struct vfio_dma *dma, dma_addr_t iova)
+{
+ return vfio_find_vpfn_range(dma, iova, iova + 1);
+}
+
static void vfio_link_pfn(struct vfio_dma *dma,
struct vfio_pfn *new)
{
@@ -614,6 +626,39 @@ done:
return ret;
}
+
+static long vpfn_pages(struct vfio_dma *dma,
+ dma_addr_t iova_start, long nr_pages)
+{
+ dma_addr_t iova_end = iova_start + (nr_pages << PAGE_SHIFT);
+ struct vfio_pfn *top = vfio_find_vpfn_range(dma, iova_start, iova_end);
+ long ret = 1;
+ struct vfio_pfn *vpfn;
+ struct rb_node *prev;
+ struct rb_node *next;
+
+ if (likely(!top))
+ return 0;
+
+ prev = next = &top->node;
+
+ while ((prev = rb_prev(prev))) {
+ vpfn = rb_entry(prev, struct vfio_pfn, node);
+ if (vpfn->iova < iova_start)
+ break;
+ ret++;
+ }
+
+ while ((next = rb_next(next))) {
+ vpfn = rb_entry(next, struct vfio_pfn, node);
+ if (vpfn->iova >= iova_end)
+ break;
+ ret++;
+ }
+
+ return ret;
+}
+
/*
* Attempt to pin pages. We really don't want to track all the pfns and
* the iommu can only map chunks of consecutive pfns anyway, so get the
@@ -687,32 +732,47 @@ static long vfio_pin_pages_remote(struct vfio_dma *dma, unsigned long vaddr,
* and rsvd here, and therefore continues to use the batch.
*/
while (true) {
+ long nr_pages, acct_pages = 0;
+
if (pfn != *pfn_base + pinned ||
rsvd != is_invalid_reserved_pfn(pfn))
goto out;
/*
+ * GUP with FOLL_LONGTERM, as used in
+ * vaddr_get_pfns(), will not return invalid
+ * or reserved pages.
+ */
+ nr_pages = num_pages_contiguous(
+ &batch->pages[batch->offset],
+ batch->size);
+ if (!rsvd) {
+ acct_pages = nr_pages;
+ acct_pages -= vpfn_pages(dma, iova, nr_pages);
+ }
+
+ /*
* Reserved pages aren't counted against the user,
* externally pinned pages are already counted against
* the user.
*/
- if (!rsvd && !vfio_find_vpfn(dma, iova)) {
+ if (acct_pages) {
if (!dma->lock_cap &&
- mm->locked_vm + lock_acct + 1 > limit) {
+ mm->locked_vm + lock_acct + acct_pages > limit) {
pr_warn("%s: RLIMIT_MEMLOCK (%ld) exceeded\n",
__func__, limit << PAGE_SHIFT);
ret = -ENOMEM;
goto unpin_out;
}
- lock_acct++;
+ lock_acct += acct_pages;
}
- pinned++;
- npage--;
- vaddr += PAGE_SIZE;
- iova += PAGE_SIZE;
- batch->offset++;
- batch->size--;
+ pinned += nr_pages;
+ npage -= nr_pages;
+ vaddr += PAGE_SIZE * nr_pages;
+ iova += PAGE_SIZE * nr_pages;
+ batch->offset += nr_pages;
+ batch->size -= nr_pages;
if (!batch->size)
break;