Commit e2234497 authored by William Bezuidenhout's avatar William Bezuidenhout Committed by Matt Holt

proxy: Add, remove, or replace upstream and downstream headers (closes #666) (PR #788)

* Overwrite proxy headers based on directive

Headers of the request sent by the proxy upstream can now be modified in
the following way:

Prefix header with `+`: Header will be added if it doesn't exist
otherwise, the values will be merge
Prefix header with `-': Header will be removed
No prefix: Header will be replaced with given value

* Add missing formating directive reported by go vet

* Overwrite up/down stream proxy headers

Add Up/DownStreamHeaders to UpstreamHost

Split `proxy_header` option in `proxy` directive into `header_upstream`
and `header_downstream`. By splitting into two, it makes it clear in
what direction the given headers must be applied.

`proxy_header` can still be used (to maintain backward compatability)
but its assumed to be `header_upstream`

Response headers received by the reverse proxy from the upstream host
are updated according the `header_downstream` rules.
The update occurs through a func given to the reverse proxy, which is
applied once a response is received.

Headers (for upstream and downstream) can now be modified in
the following way:

Prefix header with `+`: Header will be added if it doesn't exist
otherwise, the values will be merge
Prefix header with `-': Header will be removed
No prefix: Header will be replaced with given value

Updated branch with changes from master

* minor refactor to make intent clearer

* Make Up/Down stream headers naming consistent

* Fix error descriptions to be more clear

* Fix lint issue
parent 96425f0f
...@@ -43,7 +43,8 @@ type UpstreamHost struct { ...@@ -43,7 +43,8 @@ type UpstreamHost struct {
Fails int32 Fails int32
FailTimeout time.Duration FailTimeout time.Duration
Unhealthy bool Unhealthy bool
ExtraHeaders http.Header UpstreamHeaders http.Header
DownstreamHeaders http.Header
CheckDown UpstreamHostDownFunc CheckDown UpstreamHostDownFunc
WithoutPathPrefix string WithoutPathPrefix string
MaxConns int64 MaxConns int64
...@@ -99,26 +100,33 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { ...@@ -99,26 +100,33 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
} }
outreq.Host = host.Name outreq.Host = host.Name
if host.ExtraHeaders != nil { if host.UpstreamHeaders != nil {
extraHeaders := make(http.Header)
if replacer == nil { if replacer == nil {
rHost := r.Host rHost := r.Host
replacer = middleware.NewReplacer(r, nil, "") replacer = middleware.NewReplacer(r, nil, "")
outreq.Host = rHost outreq.Host = rHost
} }
for header, values := range host.ExtraHeaders { if v, ok := host.UpstreamHeaders["Host"]; ok {
for _, value := range values { r.Host = replacer.Replace(v[len(v)-1])
extraHeaders.Add(header, replacer.Replace(value))
if header == "Host" {
outreq.Host = replacer.Replace(value)
}
}
} }
for k, v := range extraHeaders { // Modify headers for request that will be sent to the upstream host
upHeaders := createHeadersByRules(host.UpstreamHeaders, r.Header, replacer)
for k, v := range upHeaders {
outreq.Header[k] = v outreq.Header[k] = v
} }
} }
var downHeaderUpdateFn respUpdateFn
if host.DownstreamHeaders != nil {
if replacer == nil {
rHost := r.Host
replacer = middleware.NewReplacer(r, nil, "")
outreq.Host = rHost
}
//Creates a function that is used to update headers the response received by the reverse proxy
downHeaderUpdateFn = createRespHeaderUpdateFn(host.DownstreamHeaders, replacer)
}
proxy := host.ReverseProxy proxy := host.ReverseProxy
if baseURL, err := url.Parse(host.Name); err == nil { if baseURL, err := url.Parse(host.Name); err == nil {
r.Host = baseURL.Host r.Host = baseURL.Host
...@@ -130,7 +138,7 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { ...@@ -130,7 +138,7 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
} }
atomic.AddInt64(&host.Conns, 1) atomic.AddInt64(&host.Conns, 1)
backendErr := proxy.ServeHTTP(w, outreq) backendErr := proxy.ServeHTTP(w, outreq, downHeaderUpdateFn)
atomic.AddInt64(&host.Conns, -1) atomic.AddInt64(&host.Conns, -1)
if backendErr == nil { if backendErr == nil {
return 0, nil return 0, nil
...@@ -182,3 +190,48 @@ func createUpstreamRequest(r *http.Request) *http.Request { ...@@ -182,3 +190,48 @@ func createUpstreamRequest(r *http.Request) *http.Request {
return outreq return outreq
} }
func createRespHeaderUpdateFn(rules http.Header, replacer middleware.Replacer) respUpdateFn {
return func(resp *http.Response) {
newHeaders := createHeadersByRules(rules, resp.Header, replacer)
for h, v := range newHeaders {
resp.Header[h] = v
}
}
}
func createHeadersByRules(rules http.Header, base http.Header, repl middleware.Replacer) http.Header {
newHeaders := make(http.Header)
for header, values := range rules {
if strings.HasPrefix(header, "+") {
header = strings.TrimLeft(header, "+")
add(newHeaders, header, base[header])
applyEach(values, repl.Replace)
add(newHeaders, header, values)
} else if strings.HasPrefix(header, "-") {
base.Del(strings.TrimLeft(header, "-"))
} else if _, ok := base[header]; ok {
applyEach(values, repl.Replace)
for _, v := range values {
newHeaders.Set(header, v)
}
} else {
applyEach(values, repl.Replace)
add(newHeaders, header, values)
add(newHeaders, header, base[header])
}
}
return newHeaders
}
func applyEach(values []string, mapFn func(string) string) {
for i, v := range values {
values[i] = mapFn(v)
}
}
func add(base http.Header, header string, values []string) {
for _, v := range values {
base.Add(header, v)
}
}
...@@ -348,6 +348,141 @@ func TestUnixSocketProxyPaths(t *testing.T) { ...@@ -348,6 +348,141 @@ func TestUnixSocketProxyPaths(t *testing.T) {
} }
} }
func TestUpstreamHeadersUpdate(t *testing.T) {
log.SetOutput(ioutil.Discard)
defer log.SetOutput(os.Stderr)
var actualHeaders http.Header
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("Hello, client"))
actualHeaders = r.Header
}))
defer backend.Close()
upstream := newFakeUpstream(backend.URL, false)
upstream.host.UpstreamHeaders = http.Header{
"Connection": {"{>Connection}"},
"Upgrade": {"{>Upgrade}"},
"+Merge-Me": {"Merge-Value"},
"+Add-Me": {"Add-Value"},
"-Remove-Me": {""},
"Replace-Me": {"{hostname}"},
}
// set up proxy
p := &Proxy{
Upstreams: []Upstream{upstream},
}
// create request and response recorder
r, err := http.NewRequest("GET", "/", nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
w := httptest.NewRecorder()
//add initial headers
r.Header.Add("Merge-Me", "Initial")
r.Header.Add("Remove-Me", "Remove-Value")
r.Header.Add("Replace-Me", "Replace-Value")
p.ServeHTTP(w, r)
replacer := middleware.NewReplacer(r, nil, "")
headerKey := "Merge-Me"
values, ok := actualHeaders[headerKey]
if !ok {
t.Errorf("Request sent to upstream backend does not contain expected %v header. Expected header to be added", headerKey)
} else if len(values) < 2 && (values[0] != "Initial" || values[1] != replacer.Replace("{hostname}")) {
t.Errorf("Values for proxy header `+Merge-Me` should be merged. Got %v", values)
}
headerKey = "Add-Me"
if _, ok := actualHeaders[headerKey]; !ok {
t.Errorf("Request sent to upstream backend does not contain expected %v header", headerKey)
}
headerKey = "Remove-Me"
if _, ok := actualHeaders[headerKey]; ok {
t.Errorf("Request sent to upstream backend should not contain %v header", headerKey)
}
headerKey = "Replace-Me"
headerValue := replacer.Replace("{hostname}")
value, ok := actualHeaders[headerKey]
if !ok {
t.Errorf("Request sent to upstream backend should not remove %v header", headerKey)
} else if len(value) > 0 && headerValue != value[0] {
t.Errorf("Request sent to upstream backend should replace value of %v header with %v. Instead value was %v", headerKey, headerValue, value)
}
}
func TestDownstreamHeadersUpdate(t *testing.T) {
log.SetOutput(ioutil.Discard)
defer log.SetOutput(os.Stderr)
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Add("Merge-Me", "Initial")
w.Header().Add("Remove-Me", "Remove-Value")
w.Header().Add("Replace-Me", "Replace-Value")
w.Write([]byte("Hello, client"))
}))
defer backend.Close()
upstream := newFakeUpstream(backend.URL, false)
upstream.host.DownstreamHeaders = http.Header{
"+Merge-Me": {"Merge-Value"},
"+Add-Me": {"Add-Value"},
"-Remove-Me": {""},
"Replace-Me": {"{hostname}"},
}
// set up proxy
p := &Proxy{
Upstreams: []Upstream{upstream},
}
// create request and response recorder
r, err := http.NewRequest("GET", "/", nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
w := httptest.NewRecorder()
p.ServeHTTP(w, r)
replacer := middleware.NewReplacer(r, nil, "")
actualHeaders := w.Header()
headerKey := "Merge-Me"
values, ok := actualHeaders[headerKey]
if !ok {
t.Errorf("Downstream response does not contain expected %v header. Expected header should be added", headerKey)
} else if len(values) < 2 && (values[0] != "Initial" || values[1] != replacer.Replace("{hostname}")) {
t.Errorf("Values for header `+Merge-Me` should be merged. Got %v", values)
}
headerKey = "Add-Me"
if _, ok := actualHeaders[headerKey]; !ok {
t.Errorf("Downstream response does not contain expected %v header", headerKey)
}
headerKey = "Remove-Me"
if _, ok := actualHeaders[headerKey]; ok {
t.Errorf("Downstream response should not contain %v header received from upstream", headerKey)
}
headerKey = "Replace-Me"
headerValue := replacer.Replace("{hostname}")
value, ok := actualHeaders[headerKey]
if !ok {
t.Errorf("Downstream response should contain %v header and not remove it", headerKey)
} else if len(value) > 0 && headerValue != value[0] {
t.Errorf("Downstream response should have header %v with value %v. Instead value was %v", headerKey, headerValue, value)
}
}
func newFakeUpstream(name string, insecure bool) *fakeUpstream { func newFakeUpstream(name string, insecure bool) *fakeUpstream {
uri, _ := url.Parse(name) uri, _ := url.Parse(name)
u := &fakeUpstream{ u := &fakeUpstream{
...@@ -410,7 +545,7 @@ func (u *fakeWsUpstream) Select() *UpstreamHost { ...@@ -410,7 +545,7 @@ func (u *fakeWsUpstream) Select() *UpstreamHost {
return &UpstreamHost{ return &UpstreamHost{
Name: u.name, Name: u.name,
ReverseProxy: NewSingleHostReverseProxy(uri, u.without), ReverseProxy: NewSingleHostReverseProxy(uri, u.without),
ExtraHeaders: http.Header{ UpstreamHeaders: http.Header{
"Connection": {"{>Connection}"}, "Connection": {"{>Connection}"},
"Upgrade": {"{>Upgrade}"}}, "Upgrade": {"{>Upgrade}"}},
} }
......
...@@ -154,7 +154,9 @@ var InsecureTransport http.RoundTripper = &http.Transport{ ...@@ -154,7 +154,9 @@ var InsecureTransport http.RoundTripper = &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
} }
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request) error { type respUpdateFn func(resp *http.Response)
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request, respUpdateFn respUpdateFn) error {
transport := p.Transport transport := p.Transport
if transport == nil { if transport == nil {
transport = http.DefaultTransport transport = http.DefaultTransport
...@@ -169,6 +171,8 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request) e ...@@ -169,6 +171,8 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request) e
res, err := transport.RoundTrip(outreq) res, err := transport.RoundTrip(outreq)
if err != nil { if err != nil {
return err return err
} else if respUpdateFn != nil {
respUpdateFn(res)
} }
if res.StatusCode == http.StatusSwitchingProtocols && strings.ToLower(res.Header.Get("Upgrade")) == "websocket" { if res.StatusCode == http.StatusSwitchingProtocols && strings.ToLower(res.Header.Get("Upgrade")) == "websocket" {
......
...@@ -20,7 +20,8 @@ var ( ...@@ -20,7 +20,8 @@ var (
type staticUpstream struct { type staticUpstream struct {
from string from string
proxyHeaders http.Header upstreamHeaders http.Header
downstreamHeaders http.Header
Hosts HostPool Hosts HostPool
Policy Policy Policy Policy
insecureSkipVerify bool insecureSkipVerify bool
...@@ -42,13 +43,14 @@ func NewStaticUpstreams(c parse.Dispenser) ([]Upstream, error) { ...@@ -42,13 +43,14 @@ func NewStaticUpstreams(c parse.Dispenser) ([]Upstream, error) {
var upstreams []Upstream var upstreams []Upstream
for c.Next() { for c.Next() {
upstream := &staticUpstream{ upstream := &staticUpstream{
from: "", from: "",
proxyHeaders: make(http.Header), upstreamHeaders: make(http.Header),
Hosts: nil, downstreamHeaders: make(http.Header),
Policy: &Random{}, Hosts: nil,
FailTimeout: 10 * time.Second, Policy: &Random{},
MaxFails: 1, FailTimeout: 10 * time.Second,
MaxConns: 0, MaxFails: 1,
MaxConns: 0,
} }
if !c.Args(&upstream.from) { if !c.Args(&upstream.from) {
...@@ -97,12 +99,13 @@ func (u *staticUpstream) NewHost(host string) (*UpstreamHost, error) { ...@@ -97,12 +99,13 @@ func (u *staticUpstream) NewHost(host string) (*UpstreamHost, error) {
host = "http://" + host host = "http://" + host
} }
uh := &UpstreamHost{ uh := &UpstreamHost{
Name: host, Name: host,
Conns: 0, Conns: 0,
Fails: 0, Fails: 0,
FailTimeout: u.FailTimeout, FailTimeout: u.FailTimeout,
Unhealthy: false, Unhealthy: false,
ExtraHeaders: u.proxyHeaders, UpstreamHeaders: u.upstreamHeaders,
DownstreamHeaders: u.downstreamHeaders,
CheckDown: func(u *staticUpstream) UpstreamHostDownFunc { CheckDown: func(u *staticUpstream) UpstreamHostDownFunc {
return func(uh *UpstreamHost) bool { return func(uh *UpstreamHost) bool {
if uh.Unhealthy { if uh.Unhealthy {
...@@ -182,15 +185,23 @@ func parseBlock(c *parse.Dispenser, u *staticUpstream) error { ...@@ -182,15 +185,23 @@ func parseBlock(c *parse.Dispenser, u *staticUpstream) error {
} }
u.HealthCheck.Interval = dur u.HealthCheck.Interval = dur
} }
case "header_upstream":
fallthrough
case "proxy_header": case "proxy_header":
var header, value string var header, value string
if !c.Args(&header, &value) { if !c.Args(&header, &value) {
return c.ArgErr() return c.ArgErr()
} }
u.proxyHeaders.Add(header, value) u.upstreamHeaders.Add(header, value)
case "header_downstream":
var header, value string
if !c.Args(&header, &value) {
return c.ArgErr()
}
u.downstreamHeaders.Add(header, value)
case "websocket": case "websocket":
u.proxyHeaders.Add("Connection", "{>Connection}") u.upstreamHeaders.Add("Connection", "{>Connection}")
u.proxyHeaders.Add("Upgrade", "{>Upgrade}") u.upstreamHeaders.Add("Upgrade", "{>Upgrade}")
case "without": case "without":
if !c.NextArg() { if !c.NextArg() {
return c.ArgErr() return c.ArgErr()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment