diff --git a/arch/x86/include/asm/kvm_host.h b/arch/x86/include/asm/kvm_host.h
index f8ab0d760231ee34770e1e9afd9de58253762fee..3c9ea26c7aea73daa1cfc43392a8d90121ca6431 100644
--- a/arch/x86/include/asm/kvm_host.h
+++ b/arch/x86/include/asm/kvm_host.h
@@ -461,7 +461,6 @@ struct kvm_arch {
 	unsigned int n_requested_mmu_pages;
 	unsigned int n_max_mmu_pages;
 	unsigned int indirect_shadow_pages;
-	atomic_t invlpg_counter;
 	struct hlist_head mmu_page_hash[KVM_NUM_MMU_PAGES];
 	/*
 	 * Hash table of struct kvm_mmu_page.
@@ -757,8 +756,7 @@ int fx_init(struct kvm_vcpu *vcpu);
 
 void kvm_mmu_flush_tlb(struct kvm_vcpu *vcpu);
 void kvm_mmu_pte_write(struct kvm_vcpu *vcpu, gpa_t gpa,
-		       const u8 *new, int bytes,
-		       bool guest_initiated);
+		       const u8 *new, int bytes);
 int kvm_mmu_unprotect_page(struct kvm *kvm, gfn_t gfn);
 int kvm_mmu_unprotect_page_virt(struct kvm_vcpu *vcpu, gva_t gva);
 void __kvm_mmu_free_some_pages(struct kvm_vcpu *vcpu);
diff --git a/arch/x86/kvm/mmu.c b/arch/x86/kvm/mmu.c
index d15f908649e72068b3ea9be9002c15fb1bcdfc27..c01137f10c6b5d0f5892b68eb332d0fe9fb9c062 100644
--- a/arch/x86/kvm/mmu.c
+++ b/arch/x86/kvm/mmu.c
@@ -3531,8 +3531,7 @@ static bool last_updated_pte_accessed(struct kvm_vcpu *vcpu)
 }
 
 void kvm_mmu_pte_write(struct kvm_vcpu *vcpu, gpa_t gpa,
-		       const u8 *new, int bytes,
-		       bool guest_initiated)
+		       const u8 *new, int bytes)
 {
 	gfn_t gfn = gpa >> PAGE_SHIFT;
 	union kvm_mmu_page_role mask = { .word = 0 };
@@ -3541,7 +3540,7 @@ void kvm_mmu_pte_write(struct kvm_vcpu *vcpu, gpa_t gpa,
 	LIST_HEAD(invalid_list);
 	u64 entry, gentry, *spte;
 	unsigned pte_size, page_offset, misaligned, quadrant, offset;
-	int level, npte, invlpg_counter, r, flooded = 0;
+	int level, npte, r, flooded = 0;
 	bool remote_flush, local_flush, zap_page;
 
 	/*
@@ -3556,19 +3555,16 @@ void kvm_mmu_pte_write(struct kvm_vcpu *vcpu, gpa_t gpa,
 
 	pgprintk("%s: gpa %llx bytes %d\n", __func__, gpa, bytes);
 
-	invlpg_counter = atomic_read(&vcpu->kvm->arch.invlpg_counter);
-
 	/*
 	 * Assume that the pte write on a page table of the same type
 	 * as the current vcpu paging mode since we update the sptes only
 	 * when they have the same mode.
 	 */
-	if ((is_pae(vcpu) && bytes == 4) || !new) {
+	if (is_pae(vcpu) && bytes == 4) {
 		/* Handle a 32-bit guest writing two halves of a 64-bit gpte */
-		if (is_pae(vcpu)) {
-			gpa &= ~(gpa_t)7;
-			bytes = 8;
-		}
+		gpa &= ~(gpa_t)7;
+		bytes = 8;
+
 		r = kvm_read_guest(vcpu->kvm, gpa, &gentry, min(bytes, 8));
 		if (r)
 			gentry = 0;
@@ -3594,22 +3590,18 @@ void kvm_mmu_pte_write(struct kvm_vcpu *vcpu, gpa_t gpa,
 	 */
 	mmu_topup_memory_caches(vcpu);
 	spin_lock(&vcpu->kvm->mmu_lock);
-	if (atomic_read(&vcpu->kvm->arch.invlpg_counter) != invlpg_counter)
-		gentry = 0;
 	kvm_mmu_free_some_pages(vcpu);
 	++vcpu->kvm->stat.mmu_pte_write;
 	trace_kvm_mmu_audit(vcpu, AUDIT_PRE_PTE_WRITE);
-	if (guest_initiated) {
-		if (gfn == vcpu->arch.last_pt_write_gfn
-		    && !last_updated_pte_accessed(vcpu)) {
-			++vcpu->arch.last_pt_write_count;
-			if (vcpu->arch.last_pt_write_count >= 3)
-				flooded = 1;
-		} else {
-			vcpu->arch.last_pt_write_gfn = gfn;
-			vcpu->arch.last_pt_write_count = 1;
-			vcpu->arch.last_pte_updated = NULL;
-		}
+	if (gfn == vcpu->arch.last_pt_write_gfn
+	    && !last_updated_pte_accessed(vcpu)) {
+		++vcpu->arch.last_pt_write_count;
+		if (vcpu->arch.last_pt_write_count >= 3)
+			flooded = 1;
+	} else {
+		vcpu->arch.last_pt_write_gfn = gfn;
+		vcpu->arch.last_pt_write_count = 1;
+		vcpu->arch.last_pte_updated = NULL;
 	}
 
 	mask.cr0_wp = mask.cr4_pae = mask.nxe = 1;
diff --git a/arch/x86/kvm/paging_tmpl.h b/arch/x86/kvm/paging_tmpl.h
index d8d3906649da2a0ba6d46674aad0661bff1bc03f..9efb860357741d84815e8a2e2d3bc92837912af0 100644
--- a/arch/x86/kvm/paging_tmpl.h
+++ b/arch/x86/kvm/paging_tmpl.h
@@ -672,20 +672,27 @@ static void FNAME(invlpg)(struct kvm_vcpu *vcpu, gva_t gva)
 {
 	struct kvm_shadow_walk_iterator iterator;
 	struct kvm_mmu_page *sp;
-	gpa_t pte_gpa = -1;
 	int level;
 	u64 *sptep;
 
 	vcpu_clear_mmio_info(vcpu, gva);
 
-	spin_lock(&vcpu->kvm->mmu_lock);
+	/*
+	 * No need to check return value here, rmap_can_add() can
+	 * help us to skip pte prefetch later.
+	 */
+	mmu_topup_memory_caches(vcpu);
 
+	spin_lock(&vcpu->kvm->mmu_lock);
 	for_each_shadow_entry(vcpu, gva, iterator) {
 		level = iterator.level;
 		sptep = iterator.sptep;
 
 		sp = page_header(__pa(sptep));
 		if (is_last_spte(*sptep, level)) {
+			pt_element_t gpte;
+			gpa_t pte_gpa;
+
 			if (!sp->unsync)
 				break;
 
@@ -694,22 +701,21 @@ static void FNAME(invlpg)(struct kvm_vcpu *vcpu, gva_t gva)
 
 			if (mmu_page_zap_pte(vcpu->kvm, sp, sptep))
 				kvm_flush_remote_tlbs(vcpu->kvm);
+
+			if (!rmap_can_add(vcpu))
+				break;
+
+			if (kvm_read_guest_atomic(vcpu->kvm, pte_gpa, &gpte,
+						  sizeof(pt_element_t)))
+				break;
+
+			FNAME(update_pte)(vcpu, sp, sptep, &gpte);
 		}
 
 		if (!is_shadow_present_pte(*sptep) || !sp->unsync_children)
 			break;
 	}
-
-	atomic_inc(&vcpu->kvm->arch.invlpg_counter);
-
 	spin_unlock(&vcpu->kvm->mmu_lock);
-
-	if (pte_gpa == -1)
-		return;
-
-	if (mmu_topup_memory_caches(vcpu))
-		return;
-	kvm_mmu_pte_write(vcpu, pte_gpa, NULL, sizeof(pt_element_t), 0);
 }
 
 static gpa_t FNAME(gva_to_gpa)(struct kvm_vcpu *vcpu, gva_t vaddr, u32 access,
diff --git a/arch/x86/kvm/x86.c b/arch/x86/kvm/x86.c
index a2154487917d5d9a9f8f6adacd043b7262ad625f..9c980ce26e6171a9b3a96fb68ab0503d24285929 100644
--- a/arch/x86/kvm/x86.c
+++ b/arch/x86/kvm/x86.c
@@ -4087,7 +4087,7 @@ int emulator_write_phys(struct kvm_vcpu *vcpu, gpa_t gpa,
 	ret = kvm_write_guest(vcpu->kvm, gpa, val, bytes);
 	if (ret < 0)
 		return 0;
-	kvm_mmu_pte_write(vcpu, gpa, val, bytes, 1);
+	kvm_mmu_pte_write(vcpu, gpa, val, bytes);
 	return 1;
 }
 
@@ -4324,7 +4324,7 @@ static int emulator_cmpxchg_emulated(struct x86_emulate_ctxt *ctxt,
 	if (!exchanged)
 		return X86EMUL_CMPXCHG_FAILED;
 
-	kvm_mmu_pte_write(vcpu, gpa, new, bytes, 1);
+	kvm_mmu_pte_write(vcpu, gpa, new, bytes);
 
 	return X86EMUL_CONTINUE;