From dd4c4d7eb649d4d0ae1b381a501f431c684ab27c Mon Sep 17 00:00:00 2001
From: Benny Ng <benny.tpng@gmail.com>
Date: Tue, 1 Nov 2016 12:34:39 +0800
Subject: [PATCH] proxy: record request Body for retry (fixes #1229)

---
 caddyhttp/proxy/body.go       | 40 ++++++++++++++++++++
 caddyhttp/proxy/body_test.go  | 69 +++++++++++++++++++++++++++++++++++
 caddyhttp/proxy/proxy.go      | 14 +++++++
 caddyhttp/proxy/proxy_test.go | 58 +++++++++++++++++++++++++++++
 4 files changed, 181 insertions(+)
 create mode 100644 caddyhttp/proxy/body.go
 create mode 100644 caddyhttp/proxy/body_test.go

diff --git a/caddyhttp/proxy/body.go b/caddyhttp/proxy/body.go
new file mode 100644
index 0000000..38d0016
--- /dev/null
+++ b/caddyhttp/proxy/body.go
@@ -0,0 +1,40 @@
+package proxy
+
+import (
+	"bytes"
+	"io"
+	"io/ioutil"
+)
+
+type bufferedBody struct {
+	*bytes.Reader
+}
+
+func (*bufferedBody) Close() error {
+	return nil
+}
+
+// rewind allows bufferedBody to be read again.
+func (b *bufferedBody) rewind() error {
+	if b == nil {
+		return nil
+	}
+	_, err := b.Seek(0, io.SeekStart)
+	return err
+}
+
+// newBufferedBody returns *bufferedBody to use in place of src. Closes src
+// and returns Read error on src. All content from src is buffered.
+func newBufferedBody(src io.ReadCloser) (*bufferedBody, error) {
+	if src == nil {
+		return nil, nil
+	}
+	b, err := ioutil.ReadAll(src)
+	src.Close()
+	if err != nil {
+		return nil, err
+	}
+	return &bufferedBody{
+		Reader: bytes.NewReader(b),
+	}, nil
+}
diff --git a/caddyhttp/proxy/body_test.go b/caddyhttp/proxy/body_test.go
new file mode 100644
index 0000000..5b72784
--- /dev/null
+++ b/caddyhttp/proxy/body_test.go
@@ -0,0 +1,69 @@
+package proxy
+
+import (
+	"bytes"
+	"io"
+	"io/ioutil"
+	"net/http"
+	"net/http/httptest"
+	"testing"
+)
+
+func TestBodyRetry(t *testing.T) {
+	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		io.Copy(w, r.Body)
+		r.Body.Close()
+	}))
+	defer ts.Close()
+
+	testcase := "test content"
+	req, err := http.NewRequest(http.MethodPost, ts.URL, bytes.NewBufferString(testcase))
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	body, err := newBufferedBody(req.Body)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if body != nil {
+		req.Body = body
+	}
+
+	// simulate fail request
+	host := req.URL.Host
+	req.URL.Host = "example.com"
+	body.rewind()
+	_, _ = http.DefaultTransport.RoundTrip(req)
+
+	// retry request
+	req.URL.Host = host
+	body.rewind()
+	resp, err := http.DefaultTransport.RoundTrip(req)
+	if err != nil {
+		t.Fatal(err)
+	}
+	result, err := ioutil.ReadAll(resp.Body)
+	if err != nil {
+		t.Fatal(err)
+	}
+	resp.Body.Close()
+	if string(result) != testcase {
+		t.Fatalf("result = %s, want %s", result, testcase)
+	}
+
+	// try one more time for body reuse
+	body.rewind()
+	resp, err = http.DefaultTransport.RoundTrip(req)
+	if err != nil {
+		t.Fatal(err)
+	}
+	result, err = ioutil.ReadAll(resp.Body)
+	if err != nil {
+		t.Fatal(err)
+	}
+	resp.Body.Close()
+	if string(result) != testcase {
+		t.Fatalf("result = %s, want %s", result, testcase)
+	}
+}
diff --git a/caddyhttp/proxy/proxy.go b/caddyhttp/proxy/proxy.go
index 71c7476..11f2d5d 100644
--- a/caddyhttp/proxy/proxy.go
+++ b/caddyhttp/proxy/proxy.go
@@ -94,6 +94,15 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
 	// outreq is the request that makes a roundtrip to the backend
 	outreq := createUpstreamRequest(r)
 
+	// record and replace outreq body
+	body, err := newBufferedBody(outreq.Body)
+	if err != nil {
+		return http.StatusBadRequest, errors.New("failed to read downstream request body")
+	}
+	if body != nil {
+		outreq.Body = body
+	}
+
 	// The keepRetrying function will return true if we should
 	// loop and try to select another host, or false if we
 	// should break and stop retrying.
@@ -164,6 +173,11 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
 			downHeaderUpdateFn = createRespHeaderUpdateFn(host.DownstreamHeaders, replacer)
 		}
 
+		// rewind request body to its beginning
+		if err := body.rewind(); err != nil {
+			return http.StatusInternalServerError, errors.New("unable to rewind downstream request body")
+		}
+
 		// tell the proxy to serve the request
 		atomic.AddInt64(&host.Conns, 1)
 		backendErr = proxy.ServeHTTP(w, outreq, downHeaderUpdateFn)
diff --git a/caddyhttp/proxy/proxy_test.go b/caddyhttp/proxy/proxy_test.go
index af02a17..290cae9 100644
--- a/caddyhttp/proxy/proxy_test.go
+++ b/caddyhttp/proxy/proxy_test.go
@@ -20,6 +20,7 @@ import (
 	"testing"
 	"time"
 
+	"github.com/mholt/caddy/caddyfile"
 	"github.com/mholt/caddy/caddyhttp/httpserver"
 
 	"golang.org/x/net/websocket"
@@ -836,6 +837,63 @@ func TestProxyDirectorURL(t *testing.T) {
 	}
 }
 
+func TestReverseProxyRetry(t *testing.T) {
+	log.SetOutput(ioutil.Discard)
+	defer log.SetOutput(os.Stderr)
+
+	// set up proxy
+	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		io.Copy(w, r.Body)
+		r.Body.Close()
+	}))
+	defer backend.Close()
+
+	su, err := NewStaticUpstreams(caddyfile.NewDispenser("Testfile", strings.NewReader(`
+	proxy / localhost:65535 localhost:65534 `+backend.URL+` {
+		policy round_robin
+		fail_timeout 5s
+		max_fails 1
+		try_duration 5s
+		try_interval 250ms
+	}
+	`)))
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	p := &Proxy{
+		Next:      httpserver.EmptyNext, // prevents panic in some cases when test fails
+		Upstreams: su,
+	}
+
+	// middle is required to simulate closable downstream request body
+	middle := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		_, err = p.ServeHTTP(w, r)
+		if err != nil {
+			t.Error(err)
+		}
+	}))
+	defer middle.Close()
+
+	testcase := "test content"
+	r, err := http.NewRequest("POST", middle.URL, bytes.NewBufferString(testcase))
+	if err != nil {
+		t.Fatal(err)
+	}
+	resp, err := http.DefaultTransport.RoundTrip(r)
+	if err != nil {
+		t.Fatal(err)
+	}
+	b, err := ioutil.ReadAll(resp.Body)
+	resp.Body.Close()
+	if err != nil {
+		t.Fatal(err)
+	}
+	if string(b) != testcase {
+		t.Fatalf("string(b) = %s, want %s", string(b), testcase)
+	}
+}
+
 func newFakeUpstream(name string, insecure bool) *fakeUpstream {
 	uri, _ := url.Parse(name)
 	u := &fakeUpstream{
-- 
2.30.9