Commit 47b78714 authored by comp500's avatar comp500 Committed by Matt Holt

proxy: Change headers using regex (#2144)

* Add upstream header replacements (TODO: tests, docs)

* Add tests, fix a few bugs

* Add more tests and comments

* Refactor header_upstream to use a fallthrough; return regex errors
parent fda7350a
...@@ -94,6 +94,8 @@ type UpstreamHost struct { ...@@ -94,6 +94,8 @@ type UpstreamHost struct {
// is healthy and any non-zero value indicates unhealthy. // is healthy and any non-zero value indicates unhealthy.
Unhealthy int32 Unhealthy int32
HealthCheckResult atomic.Value HealthCheckResult atomic.Value
UpstreamHeaderReplacements headerReplacements
DownstreamHeaderReplacements headerReplacements
} }
// Down checks whether the upstream host is down or not. // Down checks whether the upstream host is down or not.
...@@ -220,7 +222,7 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { ...@@ -220,7 +222,7 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
// set headers for request going upstream // set headers for request going upstream
if host.UpstreamHeaders != nil { if host.UpstreamHeaders != nil {
// modify headers for request that will be sent to the upstream host // modify headers for request that will be sent to the upstream host
mutateHeadersByRules(outreq.Header, host.UpstreamHeaders, replacer) mutateHeadersByRules(outreq.Header, host.UpstreamHeaders, replacer, host.UpstreamHeaderReplacements)
if hostHeaders, ok := outreq.Header["Host"]; ok && len(hostHeaders) > 0 { if hostHeaders, ok := outreq.Header["Host"]; ok && len(hostHeaders) > 0 {
outreq.Host = hostHeaders[len(hostHeaders)-1] outreq.Host = hostHeaders[len(hostHeaders)-1]
} }
...@@ -230,7 +232,7 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { ...@@ -230,7 +232,7 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
// headers coming back downstream // headers coming back downstream
var downHeaderUpdateFn respUpdateFn var downHeaderUpdateFn respUpdateFn
if host.DownstreamHeaders != nil { if host.DownstreamHeaders != nil {
downHeaderUpdateFn = createRespHeaderUpdateFn(host.DownstreamHeaders, replacer) downHeaderUpdateFn = createRespHeaderUpdateFn(host.DownstreamHeaders, replacer, host.DownstreamHeaderReplacements)
} }
// Before we retry the request we have to make sure // Before we retry the request we have to make sure
...@@ -376,13 +378,13 @@ func createUpstreamRequest(rw http.ResponseWriter, r *http.Request) (*http.Reque ...@@ -376,13 +378,13 @@ func createUpstreamRequest(rw http.ResponseWriter, r *http.Request) (*http.Reque
return outreq, cancel return outreq, cancel
} }
func createRespHeaderUpdateFn(rules http.Header, replacer httpserver.Replacer) respUpdateFn { func createRespHeaderUpdateFn(rules http.Header, replacer httpserver.Replacer, replacements headerReplacements) respUpdateFn {
return func(resp *http.Response) { return func(resp *http.Response) {
mutateHeadersByRules(resp.Header, rules, replacer) mutateHeadersByRules(resp.Header, rules, replacer, replacements)
} }
} }
func mutateHeadersByRules(headers, rules http.Header, repl httpserver.Replacer) { func mutateHeadersByRules(headers, rules http.Header, repl httpserver.Replacer, replacements headerReplacements) {
for ruleField, ruleValues := range rules { for ruleField, ruleValues := range rules {
if strings.HasPrefix(ruleField, "+") { if strings.HasPrefix(ruleField, "+") {
for _, ruleValue := range ruleValues { for _, ruleValue := range ruleValues {
...@@ -400,6 +402,19 @@ func mutateHeadersByRules(headers, rules http.Header, repl httpserver.Replacer) ...@@ -400,6 +402,19 @@ func mutateHeadersByRules(headers, rules http.Header, repl httpserver.Replacer)
} }
} }
} }
for ruleField, ruleValues := range replacements {
for _, ruleValue := range ruleValues {
// Replace variables in replacement string
replacement := repl.Replace(ruleValue.to)
original := headers.Get(ruleField)
if len(replacement) > 0 && len(original) > 0 {
// Replace matches in original string with replacement string
replaced := ruleValue.regexp.ReplaceAllString(original, replacement)
headers.Set(ruleField, replaced)
}
}
}
} }
const CustomStatusContextCancelled = 499 const CustomStatusContextCancelled = 499
...@@ -31,6 +31,7 @@ import ( ...@@ -31,6 +31,7 @@ import (
"path" "path"
"path/filepath" "path/filepath"
"reflect" "reflect"
"regexp"
"runtime" "runtime"
"strings" "strings"
"sync" "sync"
...@@ -724,6 +725,14 @@ func TestUpstreamHeadersUpdate(t *testing.T) { ...@@ -724,6 +725,14 @@ func TestUpstreamHeadersUpdate(t *testing.T) {
"Clear-Me": {""}, "Clear-Me": {""},
"Host": {"{>Host}"}, "Host": {"{>Host}"},
} }
regex1, _ := regexp.Compile("was originally")
regex2, _ := regexp.Compile("this")
regex3, _ := regexp.Compile("bad")
upstream.host.UpstreamHeaderReplacements = headerReplacements{
"Regex-Me": {headerReplacement{regex1, "am now"}, headerReplacement{regex2, "that"}},
"Regexreplace-Me": {headerReplacement{regex3, "{hostname}"}},
}
// set up proxy // set up proxy
p := &Proxy{ p := &Proxy{
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
...@@ -740,6 +749,8 @@ func TestUpstreamHeadersUpdate(t *testing.T) { ...@@ -740,6 +749,8 @@ func TestUpstreamHeadersUpdate(t *testing.T) {
r.Header.Add("Remove-Me", "Remove-Value") r.Header.Add("Remove-Me", "Remove-Value")
r.Header.Add("Replace-Me", "Replace-Value") r.Header.Add("Replace-Me", "Replace-Value")
r.Header.Add("Host", expectHost) r.Header.Add("Host", expectHost)
r.Header.Add("Regex-Me", "I was originally this")
r.Header.Add("Regexreplace-Me", "The host is bad")
p.ServeHTTP(w, r) p.ServeHTTP(w, r)
...@@ -752,6 +763,8 @@ func TestUpstreamHeadersUpdate(t *testing.T) { ...@@ -752,6 +763,8 @@ func TestUpstreamHeadersUpdate(t *testing.T) {
"Remove-Me": nil, "Remove-Me": nil,
"Replace-Me": {replacer.Replace("{hostname}")}, "Replace-Me": {replacer.Replace("{hostname}")},
"Clear-Me": nil, "Clear-Me": nil,
"Regex-Me": {"I am now that"},
"Regexreplace-Me": {"The host is " + replacer.Replace("{hostname}")},
} { } {
if got := actualHeaders[headerKey]; !reflect.DeepEqual(got, expect) { if got := actualHeaders[headerKey]; !reflect.DeepEqual(got, expect) {
t.Errorf("Upstream request does not contain expected %v header: expect %v, but got %v", t.Errorf("Upstream request does not contain expected %v header: expect %v, but got %v",
...@@ -775,6 +788,8 @@ func TestDownstreamHeadersUpdate(t *testing.T) { ...@@ -775,6 +788,8 @@ func TestDownstreamHeadersUpdate(t *testing.T) {
w.Header().Add("Replace-Me", "Replace-Value") w.Header().Add("Replace-Me", "Replace-Value")
w.Header().Add("Content-Type", "text/html") w.Header().Add("Content-Type", "text/html")
w.Header().Add("Overwrite-Me", "Overwrite-Value") w.Header().Add("Overwrite-Me", "Overwrite-Value")
w.Header().Add("Regex-Me", "I was originally this")
w.Header().Add("Regexreplace-Me", "The host is bad")
w.Write([]byte("Hello, client")) w.Write([]byte("Hello, client"))
})) }))
defer backend.Close() defer backend.Close()
...@@ -786,6 +801,13 @@ func TestDownstreamHeadersUpdate(t *testing.T) { ...@@ -786,6 +801,13 @@ func TestDownstreamHeadersUpdate(t *testing.T) {
"-Remove-Me": {""}, "-Remove-Me": {""},
"Replace-Me": {"{hostname}"}, "Replace-Me": {"{hostname}"},
} }
regex1, _ := regexp.Compile("was originally")
regex2, _ := regexp.Compile("this")
regex3, _ := regexp.Compile("bad")
upstream.host.DownstreamHeaderReplacements = headerReplacements{
"Regex-Me": {headerReplacement{regex1, "am now"}, headerReplacement{regex2, "that"}},
"Regexreplace-Me": {headerReplacement{regex3, "{hostname}"}},
}
// set up proxy // set up proxy
p := &Proxy{ p := &Proxy{
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
...@@ -812,6 +834,8 @@ func TestDownstreamHeadersUpdate(t *testing.T) { ...@@ -812,6 +834,8 @@ func TestDownstreamHeadersUpdate(t *testing.T) {
"Replace-Me": {replacer.Replace("{hostname}")}, "Replace-Me": {replacer.Replace("{hostname}")},
"Content-Type": {"text/css"}, "Content-Type": {"text/css"},
"Overwrite-Me": {"Overwrite-Value"}, "Overwrite-Me": {"Overwrite-Value"},
"Regex-Me": {"I am now that"},
"Regexreplace-Me": {"The host is " + replacer.Replace("{hostname}")},
} { } {
if got := actualHeaders[headerKey]; !reflect.DeepEqual(got, expect) { if got := actualHeaders[headerKey]; !reflect.DeepEqual(got, expect) {
t.Errorf("Downstream response does not contain expected %s header: expect %v, but got %v", t.Errorf("Downstream response does not contain expected %s header: expect %v, but got %v",
......
...@@ -23,8 +23,10 @@ import ( ...@@ -23,8 +23,10 @@ import (
"io/ioutil" "io/ioutil"
"net" "net"
"net/http" "net/http"
"net/textproto"
"net/url" "net/url"
"path" "path"
"regexp"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
...@@ -71,12 +73,33 @@ type staticUpstream struct { ...@@ -71,12 +73,33 @@ type staticUpstream struct {
MaxFails int32 MaxFails int32
resolver srvResolver resolver srvResolver
CaCertPool *x509.CertPool CaCertPool *x509.CertPool
upstreamHeaderReplacements headerReplacements
downstreamHeaderReplacements headerReplacements
} }
type srvResolver interface { type srvResolver interface {
LookupSRV(context.Context, string, string, string) (string, []*net.SRV, error) LookupSRV(context.Context, string, string, string) (string, []*net.SRV, error)
} }
// headerReplacement stores a compiled regex matcher and a string replacer, for replacement rules
type headerReplacement struct {
regexp *regexp.Regexp
to string
}
// headerReplacements stores a mapping of canonical MIME header to headerReplacement
// Implements a subset of http.Header functions, to allow convenient addition and deletion of rules
type headerReplacements map[string][]headerReplacement
func (h headerReplacements) Add(key string, value headerReplacement) {
key = textproto.CanonicalMIMEHeaderKey(key)
h[key] = append(h[key], value)
}
func (h headerReplacements) Del(key string) {
delete(h, textproto.CanonicalMIMEHeaderKey(key))
}
// NewStaticUpstreams parses the configuration input and sets up // NewStaticUpstreams parses the configuration input and sets up
// static upstreams for the proxy middleware. The host string parameter, // static upstreams for the proxy middleware. The host string parameter,
// if not empty, is used for setting the upstream Host header for the // if not empty, is used for setting the upstream Host header for the
...@@ -98,6 +121,8 @@ func NewStaticUpstreams(c caddyfile.Dispenser, host string) ([]Upstream, error) ...@@ -98,6 +121,8 @@ func NewStaticUpstreams(c caddyfile.Dispenser, host string) ([]Upstream, error)
KeepAlive: http.DefaultMaxIdleConnsPerHost, KeepAlive: http.DefaultMaxIdleConnsPerHost,
Timeout: 30 * time.Second, Timeout: 30 * time.Second,
resolver: net.DefaultResolver, resolver: net.DefaultResolver,
upstreamHeaderReplacements: make(headerReplacements),
downstreamHeaderReplacements: make(headerReplacements),
} }
if !c.Args(&upstream.from) { if !c.Args(&upstream.from) {
...@@ -223,6 +248,8 @@ func (u *staticUpstream) NewHost(host string) (*UpstreamHost, error) { ...@@ -223,6 +248,8 @@ func (u *staticUpstream) NewHost(host string) (*UpstreamHost, error) {
WithoutPathPrefix: u.WithoutPathPrefix, WithoutPathPrefix: u.WithoutPathPrefix,
MaxConns: u.MaxConns, MaxConns: u.MaxConns,
HealthCheckResult: atomic.Value{}, HealthCheckResult: atomic.Value{},
UpstreamHeaderReplacements: u.upstreamHeaderReplacements,
DownstreamHeaderReplacements: u.downstreamHeaderReplacements,
} }
baseURL, err := url.Parse(uh.Name) baseURL, err := url.Parse(uh.Name)
...@@ -302,6 +329,8 @@ func parseUpstream(u string) ([]string, error) { ...@@ -302,6 +329,8 @@ func parseUpstream(u string) ([]string, error) {
} }
func parseBlock(c *caddyfile.Dispenser, u *staticUpstream, hasSrv bool) error { func parseBlock(c *caddyfile.Dispenser, u *staticUpstream, hasSrv bool) error {
var isUpstream bool
switch c.Val() { switch c.Val() {
case "policy": case "policy":
if !c.NextArg() { if !c.NextArg() {
...@@ -431,23 +460,37 @@ func parseBlock(c *caddyfile.Dispenser, u *staticUpstream, hasSrv bool) error { ...@@ -431,23 +460,37 @@ func parseBlock(c *caddyfile.Dispenser, u *staticUpstream, hasSrv bool) error {
} }
u.HealthCheck.ContentString = c.Val() u.HealthCheck.ContentString = c.Val()
case "header_upstream": case "header_upstream":
var header, value string isUpstream = true
if !c.Args(&header, &value) { fallthrough
// When removing a header, the value can be optional. case "header_downstream":
if !strings.HasPrefix(header, "-") { var header, value, replaced string
if c.Args(&header, &value, &replaced) {
// Don't allow - or + in replacements
if strings.HasPrefix(header, "-") || strings.HasPrefix(header, "+") {
return c.ArgErr() return c.ArgErr()
} }
r, err := regexp.Compile(value)
if err != nil {
return err
} }
u.upstreamHeaders.Add(header, value) if isUpstream {
case "header_downstream": u.upstreamHeaderReplacements.Add(header, headerReplacement{r, replaced})
var header, value string } else {
if !c.Args(&header, &value) { u.downstreamHeaderReplacements.Add(header, headerReplacement{r, replaced})
}
} else {
if len(value) == 0 {
// When removing a header, the value can be optional. // When removing a header, the value can be optional.
if !strings.HasPrefix(header, "-") { if !strings.HasPrefix(header, "-") {
return c.ArgErr() return c.ArgErr()
} }
} }
if isUpstream {
u.upstreamHeaders.Add(header, value)
} else {
u.downstreamHeaders.Add(header, value) u.downstreamHeaders.Add(header, value)
}
}
case "transparent": case "transparent":
// Note: X-Forwarded-For header is always being appended for proxy connections // Note: X-Forwarded-For header is always being appended for proxy connections
// See implementation of createUpstreamRequest in proxy.go // See implementation of createUpstreamRequest in proxy.go
......
...@@ -386,6 +386,61 @@ func TestParseBlockTransparent(t *testing.T) { ...@@ -386,6 +386,61 @@ func TestParseBlockTransparent(t *testing.T) {
} }
} }
func TestParseBlockRegex(t *testing.T) {
// tests for regex replacement of headers
r, _ := http.NewRequest("GET", "/", nil)
tests := []struct {
config string
}{
// Test #1: transparent preset with replacement of Host
{"proxy / localhost:8080 {\n transparent \nheader_upstream Host (.*) NewHost \n}"},
// Test #2: transparent preset with replacement of another param
{"proxy / localhost:8080 {\n transparent \nheader_upstream X-Test Tester \nheader_upstream X-Test Test Host \n}"},
// Test #3: transparent preset with multiple params
{"proxy / localhost:8080 {\n transparent \nheader_upstream X-Test Tester \nheader_upstream X-Test Test Host \nheader_upstream X-Test er ing \n}"},
}
for i, test := range tests {
upstreams, err := NewStaticUpstreams(caddyfile.NewDispenser("Testfile", strings.NewReader(test.config)), "")
if err != nil {
t.Errorf("Expected no error. Got: %s", err.Error())
}
for _, upstream := range upstreams {
headers := upstream.Select(r).UpstreamHeaderReplacements
switch i {
case 0:
if host, ok := headers["Host"]; !ok || host[0].to != "NewHost" {
t.Errorf("Test %d: Incorrect Host replacement: %v", i+1, host[0])
}
case 1:
if v, ok := headers["X-Test"]; !ok {
t.Errorf("Test %d: Incorrect X-Test replacement", i+1)
} else {
if v[0].to != "Host" {
t.Errorf("Test %d: Incorrect X-Test replacement: %v", i+1, v[0])
}
}
case 2:
if v, ok := headers["X-Test"]; !ok {
t.Errorf("Test %d: Incorrect X-Test replacement", i+1)
} else {
if v[0].to != "Host" {
t.Errorf("Test %d: Incorrect X-Test replacement: %v", i+1, v[0])
}
if v[1].to != "ing" {
t.Errorf("Test %d: Incorrect X-Test replacement: %v", i+1, v[1])
}
}
default:
t.Error("Testing error")
}
}
}
}
func TestHealthSetUp(t *testing.T) { func TestHealthSetUp(t *testing.T) {
// tests for insecure skip verify // tests for insecure skip verify
tests := []struct { tests := []struct {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment