Commit 12fd3499 authored by Tw's avatar Tw Committed by GitHub

Merge pull request #1232 from mholt/fix-1229

proxy: record request Body for retry (fixes #1229)
parents 0cdaaba4 dd4c4d7e
package proxy
import (
"bytes"
"io"
"io/ioutil"
)
type bufferedBody struct {
*bytes.Reader
}
func (*bufferedBody) Close() error {
return nil
}
// rewind allows bufferedBody to be read again.
func (b *bufferedBody) rewind() error {
if b == nil {
return nil
}
_, err := b.Seek(0, io.SeekStart)
return err
}
// newBufferedBody returns *bufferedBody to use in place of src. Closes src
// and returns Read error on src. All content from src is buffered.
func newBufferedBody(src io.ReadCloser) (*bufferedBody, error) {
if src == nil {
return nil, nil
}
b, err := ioutil.ReadAll(src)
src.Close()
if err != nil {
return nil, err
}
return &bufferedBody{
Reader: bytes.NewReader(b),
}, nil
}
package proxy
import (
"bytes"
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
)
func TestBodyRetry(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
io.Copy(w, r.Body)
r.Body.Close()
}))
defer ts.Close()
testcase := "test content"
req, err := http.NewRequest(http.MethodPost, ts.URL, bytes.NewBufferString(testcase))
if err != nil {
t.Fatal(err)
}
body, err := newBufferedBody(req.Body)
if err != nil {
t.Fatal(err)
}
if body != nil {
req.Body = body
}
// simulate fail request
host := req.URL.Host
req.URL.Host = "example.com"
body.rewind()
_, _ = http.DefaultTransport.RoundTrip(req)
// retry request
req.URL.Host = host
body.rewind()
resp, err := http.DefaultTransport.RoundTrip(req)
if err != nil {
t.Fatal(err)
}
result, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
resp.Body.Close()
if string(result) != testcase {
t.Fatalf("result = %s, want %s", result, testcase)
}
// try one more time for body reuse
body.rewind()
resp, err = http.DefaultTransport.RoundTrip(req)
if err != nil {
t.Fatal(err)
}
result, err = ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
resp.Body.Close()
if string(result) != testcase {
t.Fatalf("result = %s, want %s", result, testcase)
}
}
...@@ -94,6 +94,15 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { ...@@ -94,6 +94,15 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
// outreq is the request that makes a roundtrip to the backend // outreq is the request that makes a roundtrip to the backend
outreq := createUpstreamRequest(r) outreq := createUpstreamRequest(r)
// record and replace outreq body
body, err := newBufferedBody(outreq.Body)
if err != nil {
return http.StatusBadRequest, errors.New("failed to read downstream request body")
}
if body != nil {
outreq.Body = body
}
// The keepRetrying function will return true if we should // The keepRetrying function will return true if we should
// loop and try to select another host, or false if we // loop and try to select another host, or false if we
// should break and stop retrying. // should break and stop retrying.
...@@ -164,6 +173,11 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { ...@@ -164,6 +173,11 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
downHeaderUpdateFn = createRespHeaderUpdateFn(host.DownstreamHeaders, replacer) downHeaderUpdateFn = createRespHeaderUpdateFn(host.DownstreamHeaders, replacer)
} }
// rewind request body to its beginning
if err := body.rewind(); err != nil {
return http.StatusInternalServerError, errors.New("unable to rewind downstream request body")
}
// tell the proxy to serve the request // tell the proxy to serve the request
atomic.AddInt64(&host.Conns, 1) atomic.AddInt64(&host.Conns, 1)
backendErr = proxy.ServeHTTP(w, outreq, downHeaderUpdateFn) backendErr = proxy.ServeHTTP(w, outreq, downHeaderUpdateFn)
......
...@@ -20,6 +20,7 @@ import ( ...@@ -20,6 +20,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/mholt/caddy/caddyfile"
"github.com/mholt/caddy/caddyhttp/httpserver" "github.com/mholt/caddy/caddyhttp/httpserver"
"golang.org/x/net/websocket" "golang.org/x/net/websocket"
...@@ -836,6 +837,63 @@ func TestProxyDirectorURL(t *testing.T) { ...@@ -836,6 +837,63 @@ func TestProxyDirectorURL(t *testing.T) {
} }
} }
func TestReverseProxyRetry(t *testing.T) {
log.SetOutput(ioutil.Discard)
defer log.SetOutput(os.Stderr)
// set up proxy
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
io.Copy(w, r.Body)
r.Body.Close()
}))
defer backend.Close()
su, err := NewStaticUpstreams(caddyfile.NewDispenser("Testfile", strings.NewReader(`
proxy / localhost:65535 localhost:65534 `+backend.URL+` {
policy round_robin
fail_timeout 5s
max_fails 1
try_duration 5s
try_interval 250ms
}
`)))
if err != nil {
t.Fatal(err)
}
p := &Proxy{
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
Upstreams: su,
}
// middle is required to simulate closable downstream request body
middle := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err = p.ServeHTTP(w, r)
if err != nil {
t.Error(err)
}
}))
defer middle.Close()
testcase := "test content"
r, err := http.NewRequest("POST", middle.URL, bytes.NewBufferString(testcase))
if err != nil {
t.Fatal(err)
}
resp, err := http.DefaultTransport.RoundTrip(r)
if err != nil {
t.Fatal(err)
}
b, err := ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
t.Fatal(err)
}
if string(b) != testcase {
t.Fatalf("string(b) = %s, want %s", string(b), testcase)
}
}
func newFakeUpstream(name string, insecure bool) *fakeUpstream { func newFakeUpstream(name string, insecure bool) *fakeUpstream {
uri, _ := url.Parse(name) uri, _ := url.Parse(name)
u := &fakeUpstream{ u := &fakeUpstream{
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment