Commit 6490ff62 authored by Matthew Holt's avatar Matthew Holt

Adjust proxy headers properly (fixes #916)

parent 57710e8b
......@@ -84,7 +84,7 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
}
// this replacer is used to fill in header field values
var replacer httpserver.Replacer
replacer := httpserver.NewReplacer(r, nil, "")
// outreq is the request that makes a roundtrip to the backend
outreq := createUpstreamRequest(r)
......@@ -119,16 +119,10 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
// set headers for request going upstream
if host.UpstreamHeaders != nil {
if replacer == nil {
replacer = httpserver.NewReplacer(r, nil, "")
}
if v, ok := host.UpstreamHeaders["Host"]; ok {
outreq.Host = replacer.Replace(v[len(v)-1])
}
// modify headers for request that will be sent to the upstream host
upHeaders := createHeadersByRules(host.UpstreamHeaders, r.Header, replacer)
for k, v := range upHeaders {
outreq.Header[k] = v
mutateHeadersByRules(outreq.Header, host.UpstreamHeaders, replacer)
if hostHeaders, ok := outreq.Header["Host"]; ok && len(hostHeaders) > 0 {
outreq.Host = hostHeaders[len(hostHeaders)-1]
}
}
......@@ -136,9 +130,6 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
// headers coming back downstream
var downHeaderUpdateFn respUpdateFn
if host.DownstreamHeaders != nil {
if replacer == nil {
replacer = httpserver.NewReplacer(r, nil, "")
}
downHeaderUpdateFn = createRespHeaderUpdateFn(host.DownstreamHeaders, replacer)
}
......@@ -185,6 +176,8 @@ func (p Proxy) match(r *http.Request) Upstream {
// createUpstremRequest shallow-copies r into a new request
// that can be sent upstream.
//
// Derived from reverseproxy.go in the standard Go httputil package.
func createUpstreamRequest(r *http.Request) *http.Request {
outreq := new(http.Request)
*outreq = *r // includes shallow copies of maps, but okay
......@@ -199,10 +192,14 @@ func createUpstreamRequest(r *http.Request) *http.Request {
// connection, regardless of what the client sent to us. This
// is modifying the same underlying map from r (shallow
// copied above) so we only copy it if necessary.
var copiedHeaders bool
for _, h := range hopHeaders {
if outreq.Header.Get(h) != "" {
if !copiedHeaders {
outreq.Header = make(http.Header)
copyHeader(outreq.Header, r.Header)
copiedHeaders = true
}
outreq.Header.Del(h)
}
}
......@@ -222,45 +219,20 @@ func createUpstreamRequest(r *http.Request) *http.Request {
func createRespHeaderUpdateFn(rules http.Header, replacer httpserver.Replacer) respUpdateFn {
return func(resp *http.Response) {
newHeaders := createHeadersByRules(rules, resp.Header, replacer)
for h, v := range newHeaders {
resp.Header[h] = v
}
mutateHeadersByRules(resp.Header, rules, replacer)
}
}
func createHeadersByRules(rules http.Header, base http.Header, repl httpserver.Replacer) http.Header {
newHeaders := make(http.Header)
for header, values := range rules {
if strings.HasPrefix(header, "+") {
header = strings.TrimLeft(header, "+")
add(newHeaders, header, base[header])
applyEach(values, repl.Replace)
add(newHeaders, header, values)
} else if strings.HasPrefix(header, "-") {
base.Del(strings.TrimLeft(header, "-"))
} else if _, ok := base[header]; ok {
applyEach(values, repl.Replace)
for _, v := range values {
newHeaders.Set(header, v)
}
} else {
applyEach(values, repl.Replace)
add(newHeaders, header, values)
add(newHeaders, header, base[header])
}
func mutateHeadersByRules(headers, rules http.Header, repl httpserver.Replacer) {
for ruleField, ruleValues := range rules {
if strings.HasPrefix(ruleField, "+") {
for _, ruleValue := range ruleValues {
headers.Add(strings.TrimPrefix(ruleField, "+"), repl.Replace(ruleValue))
}
return newHeaders
}
func applyEach(values []string, mapFn func(string) string) {
for i, v := range values {
values[i] = mapFn(v)
} else if strings.HasPrefix(ruleField, "-") {
headers.Del(strings.TrimPrefix(ruleField, "-"))
} else if len(ruleValues) > 0 {
headers.Set(ruleField, repl.Replace(ruleValues[len(ruleValues)-1]))
}
}
func add(base http.Header, header string, values []string) {
for _, v := range values {
base.Add(header, v)
}
}
......@@ -177,10 +177,11 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request, r
res, err := transport.RoundTrip(outreq)
if err != nil {
return err
} else if respUpdateFn != nil {
respUpdateFn(res)
}
if respUpdateFn != nil {
respUpdateFn(res)
}
if res.StatusCode == http.StatusSwitchingProtocols && strings.ToLower(res.Header.Get("Upgrade")) == "websocket" {
res.Body.Close()
hj, ok := rw.(http.Hijacker)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment