Commit 5f860d3a authored by Leonard Hecker's avatar Leonard Hecker Committed by Matt Holt

proxy: Fixed #1502: Proxying of unannounced trailers (#1588)

parent 6bb84ba1
...@@ -44,32 +44,62 @@ func TestReverseProxy(t *testing.T) { ...@@ -44,32 +44,62 @@ func TestReverseProxy(t *testing.T) {
log.SetOutput(ioutil.Discard) log.SetOutput(ioutil.Discard)
defer log.SetOutput(os.Stderr) defer log.SetOutput(os.Stderr)
verifyHeaders := func(headers http.Header, trailers http.Header) { testHeaderValue := []string{"header-value"}
if headers.Get("X-Header") != "header-value" { testHeaders := http.Header{
t.Error("Expected header 'X-Header' to be proxied properly") "X-Header-1": testHeaderValue,
"X-Header-2": testHeaderValue,
"X-Header-3": testHeaderValue,
}
testTrailerValue := []string{"trailer-value"}
testTrailers := http.Header{
"X-Trailer-1": testTrailerValue,
"X-Trailer-2": testTrailerValue,
"X-Trailer-3": testTrailerValue,
}
verifyHeaderValues := func(actual http.Header, expected http.Header) bool {
if actual == nil {
t.Error("Expected headers")
return true
}
for k := range expected {
if expected.Get(k) != actual.Get(k) {
t.Errorf("Expected header '%s' to be proxied properly", k)
return true
}
} }
if trailers == nil { return false
t.Error("Expected to receive trailers")
} }
if trailers.Get("X-Trailer") != "trailer-value" { verifyHeadersTrailers := func(headers http.Header, trailers http.Header) {
t.Error("Expected header 'X-Trailer' to be proxied properly") if verifyHeaderValues(headers, testHeaders) || verifyHeaderValues(trailers, testTrailers) {
t.FailNow()
} }
} }
var requestReceived bool requestReceived := false
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// read the body (even if it's empty) to make Go parse trailers // read the body (even if it's empty) to make Go parse trailers
io.Copy(ioutil.Discard, r.Body) io.Copy(ioutil.Discard, r.Body)
verifyHeaders(r.Header, r.Trailer)
verifyHeadersTrailers(r.Header, r.Trailer)
requestReceived = true requestReceived = true
w.Header().Set("Trailer", "X-Trailer") // Set headers.
w.Header().Set("X-Header", "header-value") copyHeader(w.Header(), testHeaders)
// Only announce one of the trailers to test wether
// unannounced trailers are proxied correctly.
for k := range testTrailers {
w.Header().Set("Trailer", k)
break
}
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
w.Write([]byte("Hello, client")) w.Write([]byte("Hello, client"))
w.Header().Set("X-Trailer", "trailer-value")
// Set trailers.
shallowCopyTrailers(w.Header(), testTrailers, true)
})) }))
defer backend.Close() defer backend.Close()
...@@ -79,24 +109,37 @@ func TestReverseProxy(t *testing.T) { ...@@ -79,24 +109,37 @@ func TestReverseProxy(t *testing.T) {
Upstreams: []Upstream{newFakeUpstream(backend.URL, false)}, Upstreams: []Upstream{newFakeUpstream(backend.URL, false)},
} }
// create request and response recorder // Create the fake request body.
r := httptest.NewRequest("GET", "/", strings.NewReader("test")) // This will copy "trailersToSet" to r.Trailer right before it is closed and
w := httptest.NewRecorder() // thus test for us wether unannounced client trailers are proxied correctly.
body := &trailerTestStringReader{
Reader: *strings.NewReader("test"),
trailersToSet: testTrailers,
}
// Create the fake request with the above body.
r := httptest.NewRequest("GET", "/", body)
r.Trailer = make(http.Header)
body.request = r
copyHeader(r.Header, testHeaders)
r.ContentLength = -1 // force chunked encoding (required for trailers) // Only announce one of the trailers to test wether
r.Header.Set("X-Header", "header-value") // unannounced trailers are proxied correctly.
r.Trailer = map[string][]string{ for k, v := range testTrailers {
"X-Trailer": {"trailer-value"}, r.Trailer[k] = v
break
} }
w := httptest.NewRecorder()
p.ServeHTTP(w, r) p.ServeHTTP(w, r)
res := w.Result()
if !requestReceived { if !requestReceived {
t.Error("Expected backend to receive request, but it didn't") t.Error("Expected backend to receive request, but it didn't")
} }
res := w.Result() verifyHeadersTrailers(res.Header, res.Trailer)
verifyHeaders(res.Header, res.Trailer)
// Make sure {upstream} placeholder is set // Make sure {upstream} placeholder is set
r.Body = ioutil.NopCloser(strings.NewReader("test")) r.Body = ioutil.NopCloser(strings.NewReader("test"))
...@@ -112,6 +155,21 @@ func TestReverseProxy(t *testing.T) { ...@@ -112,6 +155,21 @@ func TestReverseProxy(t *testing.T) {
} }
} }
// trailerTestStringReader is used to test unannounced trailers coming
// from a client which should properly be proxied to the upstream.
type trailerTestStringReader struct {
strings.Reader
request *http.Request
trailersToSet http.Header
}
var _ io.ReadCloser = &trailerTestStringReader{}
func (r *trailerTestStringReader) Close() error {
copyHeader(r.request.Trailer, r.trailersToSet)
return nil
}
func TestReverseProxyInsecureSkipVerify(t *testing.T) { func TestReverseProxyInsecureSkipVerify(t *testing.T) {
log.SetOutput(ioutil.Discard) log.SetOutput(ioutil.Discard)
defer log.SetOutput(os.Stderr) defer log.SetOutput(os.Stderr)
......
...@@ -318,30 +318,61 @@ func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request, ...@@ -318,30 +318,61 @@ func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request,
} }
pooledIoCopy(backendConn, conn) pooledIoCopy(backendConn, conn)
} else { } else {
// NOTE:
// Closing the Body involves acquiring a mutex, which is a
// unnecessarily heavy operation, considering that this defer will
// pretty much never be executed with the Body still unclosed.
bodyOpen := true
closeBody := func() {
if bodyOpen {
res.Body.Close()
bodyOpen = false
}
}
defer closeBody()
// Copy all headers over.
// res.Header does not include the "Trailer" header,
// which means we will have to do that manually below.
copyHeader(rw.Header(), res.Header) copyHeader(rw.Header(), res.Header)
// The "Trailer" header isn't included in the Transport's response, // The "Trailer" header isn't included in res' Header map, which
// at least for *http.Transport. Build it up from Trailer. // is why we have to build one ourselves from res.Trailer.
if len(res.Trailer) > 0 { //
trailerKeys := make([]string, 0, len(res.Trailer)) // But res.Trailer does not necessarily contain all trailer keys at this
// point yet. The HTTP spec allows one to send "unannounced trailers"
// after a request and certain systems like gRPC make use of that.
announcedTrailerKeyCount := len(res.Trailer)
if announcedTrailerKeyCount > 0 {
vv := make([]string, 0, announcedTrailerKeyCount)
for k := range res.Trailer { for k := range res.Trailer {
trailerKeys = append(trailerKeys, k) vv = append(vv, k)
} }
rw.Header().Add("Trailer", strings.Join(trailerKeys, ", ")) rw.Header()["Trailer"] = vv
} }
// Now copy over the status code as well as the response body.
rw.WriteHeader(res.StatusCode) rw.WriteHeader(res.StatusCode)
if len(res.Trailer) > 0 { if announcedTrailerKeyCount > 0 {
// Force chunking if we saw a response trailer. // Force chunking if we saw a response trailer.
// This prevents net/http from calculating the length for short // This prevents net/http from calculating the length
// bodies and adding a Content-Length. // for short bodies and adding a Content-Length.
if fl, ok := rw.(http.Flusher); ok { if fl, ok := rw.(http.Flusher); ok {
fl.Flush() fl.Flush()
} }
} }
rp.copyResponse(rw, res.Body) rp.copyResponse(rw, res.Body)
res.Body.Close() // close now, instead of defer, to populate res.Trailer
copyHeader(rw.Header(), res.Trailer) // Now close the body to fully populate res.Trailer.
closeBody()
// Since Go does not remove keys from res.Trailer we
// can safely do a length comparison to check wether
// we received further, unannounced trailers.
//
// Most of the time forceSetTrailers should be false.
forceSetTrailers := len(res.Trailer) != announcedTrailerKeyCount
shallowCopyTrailers(rw.Header(), res.Trailer, forceSetTrailers)
} }
return nil return nil
...@@ -391,6 +422,22 @@ func copyHeader(dst, src http.Header) { ...@@ -391,6 +422,22 @@ func copyHeader(dst, src http.Header) {
} }
} }
// shallowCopyTrailers copies all headers from srcTrailer to dstHeader.
//
// If forceSetTrailers is set to true, the http.TrailerPrefix will be added to
// all srcTrailer key names. Otherwise the Go stdlib will ignore all keys
// which weren't listed in the Trailer map before submitting the Response.
//
// WARNING: Only a shallow copy will be created!
func shallowCopyTrailers(dstHeader, srcTrailer http.Header, forceSetTrailers bool) {
for k, vv := range srcTrailer {
if forceSetTrailers {
k = http.TrailerPrefix + k
}
dstHeader[k] = vv
}
}
// Hop-by-hop headers. These are removed when sent to the backend. // Hop-by-hop headers. These are removed when sent to the backend.
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html // http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
var hopHeaders = []string{ var hopHeaders = []string{
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment