Commit 98de336a authored by Tanmay Chaudhry's avatar Tanmay Chaudhry Committed by Matt Holt

proxy: Enabled configurable timeout (#2070)

* Enabled configurable Timeout for the proxy directive

* Added Test for reverse for proxy timeout

* Removed Duplication in proxy constructors

* Remove indirection from multiple constructors and refactor into one

* Fix inconsistent error message and refactor dialer initialization
parent 9fe2ef41
......@@ -58,6 +58,10 @@ type Upstream interface {
// Gets the number of upstream hosts.
GetHostCount() int
// Gets how long to wait before timing out
// the request
GetTimeout() time.Duration
// Stops the upstream from proxying requests to shutdown goroutines cleanly.
Stop() error
}
......@@ -187,7 +191,11 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
if nameURL, err := url.Parse(host.Name); err == nil {
outreq.Host = nameURL.Host
if proxy == nil {
proxy = NewSingleHostReverseProxy(nameURL, host.WithoutPathPrefix, http.DefaultMaxIdleConnsPerHost)
proxy = NewSingleHostReverseProxy(nameURL,
host.WithoutPathPrefix,
http.DefaultMaxIdleConnsPerHost,
upstream.GetTimeout(),
)
}
// use upstream credentials by default
......
......@@ -122,7 +122,7 @@ func TestReverseProxy(t *testing.T) {
// set up proxy
p := &Proxy{
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
Upstreams: []Upstream{newFakeUpstream(backend.URL, false)},
Upstreams: []Upstream{newFakeUpstream(backend.URL, false, 30*time.Second)},
}
// Create the fake request body.
......@@ -202,7 +202,7 @@ func TestReverseProxyInsecureSkipVerify(t *testing.T) {
// set up proxy
p := &Proxy{
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
Upstreams: []Upstream{newFakeUpstream(backend.URL, true)},
Upstreams: []Upstream{newFakeUpstream(backend.URL, true, 30*time.Second)},
}
// create request and response recorder
......@@ -287,6 +287,31 @@ func TestReverseProxyMaxConnLimit(t *testing.T) {
jobs.Wait()
}
func TestReverseProxyTimeout(t *testing.T) {
timeout := 2 * time.Second
errorMargin := 100 * time.Millisecond
log.SetOutput(ioutil.Discard)
defer log.SetOutput(os.Stderr)
// set up proxy
p := &Proxy{
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
Upstreams: []Upstream{newFakeUpstream("https://8.8.8.8", true, timeout)},
}
// create request and response recorder
r := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
start := time.Now()
p.ServeHTTP(w, r)
took := time.Since(start)
if took > timeout+errorMargin {
t.Errorf("Expected timeout ~ %v but got %v", timeout, took)
}
}
func TestWebSocketReverseProxyNonHijackerPanic(t *testing.T) {
// Capture the expected panic
defer func() {
......@@ -301,7 +326,7 @@ func TestWebSocketReverseProxyNonHijackerPanic(t *testing.T) {
defer wsNop.Close()
// Get proxy to use for the test
p := newWebSocketTestProxy(wsNop.URL, false)
p := newWebSocketTestProxy(wsNop.URL, false, 30*time.Second)
// Create client request
r := httptest.NewRequest("GET", "/", nil)
......@@ -331,7 +356,7 @@ func TestWebSocketReverseProxyBackendShutDown(t *testing.T) {
}()
// Get proxy to use for the test
p := newWebSocketTestProxy(backend.URL, false)
p := newWebSocketTestProxy(backend.URL, false, 30*time.Second)
backendProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
p.ServeHTTP(w, r)
}))
......@@ -360,7 +385,7 @@ func TestWebSocketReverseProxyServeHTTPHandler(t *testing.T) {
defer wsNop.Close()
// Get proxy to use for the test
p := newWebSocketTestProxy(wsNop.URL, false)
p := newWebSocketTestProxy(wsNop.URL, false, 30*time.Second)
// Create client request
r := httptest.NewRequest("GET", "/", nil)
......@@ -407,7 +432,7 @@ func TestWebSocketReverseProxyFromWSClient(t *testing.T) {
defer wsEcho.Close()
// Get proxy to use for the test
p := newWebSocketTestProxy(wsEcho.URL, false)
p := newWebSocketTestProxy(wsEcho.URL, false, 30*time.Second)
// This is a full end-end test, so the proxy handler
// has to be part of a server listening on a port. Our
......@@ -452,7 +477,7 @@ func TestWebSocketReverseProxyFromWSSClient(t *testing.T) {
}))
defer wsEcho.Close()
p := newWebSocketTestProxy(wsEcho.URL, true)
p := newWebSocketTestProxy(wsEcho.URL, true, 30*time.Second)
echoProxy := newTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
p.ServeHTTP(w, r)
......@@ -528,7 +553,7 @@ func TestUnixSocketProxy(t *testing.T) {
defer ts.Close()
url := strings.Replace(ts.URL, "http://", "unix:", 1)
p := newWebSocketTestProxy(url, false)
p := newWebSocketTestProxy(url, false, 30*time.Second)
echoProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
p.ServeHTTP(w, r)
......@@ -686,7 +711,7 @@ func TestUpstreamHeadersUpdate(t *testing.T) {
}))
defer backend.Close()
upstream := newFakeUpstream(backend.URL, false)
upstream := newFakeUpstream(backend.URL, false, 30*time.Second)
upstream.host.UpstreamHeaders = http.Header{
"Connection": {"{>Connection}"},
"Upgrade": {"{>Upgrade}"},
......@@ -753,7 +778,7 @@ func TestDownstreamHeadersUpdate(t *testing.T) {
}))
defer backend.Close()
upstream := newFakeUpstream(backend.URL, false)
upstream := newFakeUpstream(backend.URL, false, 30*time.Second)
upstream.host.DownstreamHeaders = http.Header{
"+Merge-Me": {"Merge-Value"},
"+Add-Me": {"Add-Value"},
......@@ -893,7 +918,7 @@ func TestHostSimpleProxyNoHeaderForward(t *testing.T) {
// set up proxy
p := &Proxy{
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
Upstreams: []Upstream{newFakeUpstream(backend.URL, false)},
Upstreams: []Upstream{newFakeUpstream(backend.URL, false, 30*time.Second)},
}
r := httptest.NewRequest("GET", "/", nil)
......@@ -982,7 +1007,7 @@ func TestHostHeaderReplacedUsingForward(t *testing.T) {
}))
defer backend.Close()
upstream := newFakeUpstream(backend.URL, false)
upstream := newFakeUpstream(backend.URL, false, 30*time.Second)
proxyHostHeader := "test2.com"
upstream.host.UpstreamHeaders = http.Header{"Host": []string{proxyHostHeader}}
// set up proxy
......@@ -1044,7 +1069,7 @@ func basicAuthTestcase(t *testing.T, upstreamUser, clientUser *url.Userinfo) {
p := &Proxy{
Next: httpserver.EmptyNext,
Upstreams: []Upstream{newFakeUpstream(backURL.String(), false)},
Upstreams: []Upstream{newFakeUpstream(backURL.String(), false, 30*time.Second)},
}
r, err := http.NewRequest("GET", "/foo", nil)
if err != nil {
......@@ -1179,7 +1204,7 @@ func TestProxyDirectorURL(t *testing.T) {
continue
}
NewSingleHostReverseProxy(targetURL, c.without, 0).Director(req)
NewSingleHostReverseProxy(targetURL, c.without, 0, 30*time.Second).Director(req)
if expect, got := c.expectURL, req.URL.String(); expect != got {
t.Errorf("case %d url not equal: expect %q, but got %q",
i, expect, got)
......@@ -1326,7 +1351,7 @@ func TestCancelRequest(t *testing.T) {
// set up proxy
p := &Proxy{
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
Upstreams: []Upstream{newFakeUpstream(backend.URL, false)},
Upstreams: []Upstream{newFakeUpstream(backend.URL, false, 30*time.Second)},
}
// setup request with cancel ctx
......@@ -1375,14 +1400,15 @@ func (r *noopReader) Read(b []byte) (int, error) {
return n, nil
}
func newFakeUpstream(name string, insecure bool) *fakeUpstream {
func newFakeUpstream(name string, insecure bool, timeout time.Duration) *fakeUpstream {
uri, _ := url.Parse(name)
u := &fakeUpstream{
name: name,
from: "/",
name: name,
from: "/",
timeout: timeout,
host: &UpstreamHost{
Name: name,
ReverseProxy: NewSingleHostReverseProxy(uri, "", http.DefaultMaxIdleConnsPerHost),
ReverseProxy: NewSingleHostReverseProxy(uri, "", http.DefaultMaxIdleConnsPerHost, timeout),
},
}
if insecure {
......@@ -1396,6 +1422,7 @@ type fakeUpstream struct {
host *UpstreamHost
from string
without string
timeout time.Duration
}
func (u *fakeUpstream) From() string {
......@@ -1410,7 +1437,7 @@ func (u *fakeUpstream) Select(r *http.Request) *UpstreamHost {
}
u.host = &UpstreamHost{
Name: u.name,
ReverseProxy: NewSingleHostReverseProxy(uri, u.without, http.DefaultMaxIdleConnsPerHost),
ReverseProxy: NewSingleHostReverseProxy(uri, u.without, http.DefaultMaxIdleConnsPerHost, u.GetTimeout()),
}
}
return u.host
......@@ -1419,6 +1446,7 @@ func (u *fakeUpstream) Select(r *http.Request) *UpstreamHost {
func (u *fakeUpstream) AllowedPath(requestPath string) bool { return true }
func (u *fakeUpstream) GetTryDuration() time.Duration { return 1 * time.Second }
func (u *fakeUpstream) GetTryInterval() time.Duration { return 250 * time.Millisecond }
func (u *fakeUpstream) GetTimeout() time.Duration { return u.timeout }
func (u *fakeUpstream) GetHostCount() int { return 1 }
func (u *fakeUpstream) Stop() error { return nil }
......@@ -1426,13 +1454,14 @@ func (u *fakeUpstream) Stop() error { return nil }
// redirect to the specified backendAddr. The function
// also sets up the rules/environment for testing WebSocket
// proxy.
func newWebSocketTestProxy(backendAddr string, insecure bool) *Proxy {
func newWebSocketTestProxy(backendAddr string, insecure bool, timeout time.Duration) *Proxy {
return &Proxy{
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
Upstreams: []Upstream{&fakeWsUpstream{
name: backendAddr,
without: "",
insecure: insecure,
timeout: timeout,
}},
}
}
......@@ -1440,7 +1469,7 @@ func newWebSocketTestProxy(backendAddr string, insecure bool) *Proxy {
func newPrefixedWebSocketTestProxy(backendAddr string, prefix string) *Proxy {
return &Proxy{
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
Upstreams: []Upstream{&fakeWsUpstream{name: backendAddr, without: prefix}},
Upstreams: []Upstream{&fakeWsUpstream{name: backendAddr, without: prefix, timeout: 30 * time.Second}},
}
}
......@@ -1448,6 +1477,7 @@ type fakeWsUpstream struct {
name string
without string
insecure bool
timeout time.Duration
}
func (u *fakeWsUpstream) From() string {
......@@ -1458,7 +1488,7 @@ func (u *fakeWsUpstream) Select(r *http.Request) *UpstreamHost {
uri, _ := url.Parse(u.name)
host := &UpstreamHost{
Name: u.name,
ReverseProxy: NewSingleHostReverseProxy(uri, u.without, http.DefaultMaxIdleConnsPerHost),
ReverseProxy: NewSingleHostReverseProxy(uri, u.without, http.DefaultMaxIdleConnsPerHost, u.GetTimeout()),
UpstreamHeaders: http.Header{
"Connection": {"{>Connection}"},
"Upgrade": {"{>Upgrade}"}},
......@@ -1472,6 +1502,7 @@ func (u *fakeWsUpstream) Select(r *http.Request) *UpstreamHost {
func (u *fakeWsUpstream) AllowedPath(requestPath string) bool { return true }
func (u *fakeWsUpstream) GetTryDuration() time.Duration { return 1 * time.Second }
func (u *fakeWsUpstream) GetTryInterval() time.Duration { return 250 * time.Millisecond }
func (u *fakeWsUpstream) GetTimeout() time.Duration { return u.timeout }
func (u *fakeWsUpstream) GetHostCount() int { return 1 }
func (u *fakeWsUpstream) Stop() error { return nil }
......@@ -1517,7 +1548,7 @@ func BenchmarkProxy(b *testing.B) {
}))
defer backend.Close()
upstream := newFakeUpstream(backend.URL, false)
upstream := newFakeUpstream(backend.URL, false, 30*time.Second)
upstream.host.UpstreamHeaders = http.Header{
"Hostname": {"{hostname}"},
"Host": {"{host}"},
......@@ -1560,7 +1591,7 @@ func TestChunkedWebSocketReverseProxy(t *testing.T) {
defer wsNop.Close()
// Get proxy to use for the test
p := newWebSocketTestProxy(wsNop.URL, false)
p := newWebSocketTestProxy(wsNop.URL, false, 30*time.Second)
// Create client request
r := httptest.NewRequest("GET", "/", nil)
......
......@@ -94,6 +94,10 @@ type ReverseProxy struct {
// If zero, no periodic flushing is done.
FlushInterval time.Duration
// dialer is used when values from the
// defaultDialer need to be overridden per Proxy
dialer *net.Dialer
srvResolver srvResolver
}
......@@ -103,13 +107,13 @@ type ReverseProxy struct {
// What we need is just the path, so if "unix:/var/run/www.socket"
// was the proxy directive, the parsed hostName would be
// "unix:///var/run/www.socket", hence the ambiguous trimming.
func socketDial(hostName string) func(network, addr string) (conn net.Conn, err error) {
func socketDial(hostName string, timeout time.Duration) func(network, addr string) (conn net.Conn, err error) {
return func(network, addr string) (conn net.Conn, err error) {
return net.Dial("unix", hostName[len("unix://"):])
return net.DialTimeout("unix", hostName[len("unix://"):], timeout)
}
}
func (rp *ReverseProxy) srvDialerFunc(locator string) func(network, addr string) (conn net.Conn, err error) {
func (rp *ReverseProxy) srvDialerFunc(locator string, timeout time.Duration) func(network, addr string) (conn net.Conn, err error) {
service := locator
if strings.HasPrefix(locator, "srv://") {
service = locator[6:]
......@@ -122,7 +126,7 @@ func (rp *ReverseProxy) srvDialerFunc(locator string) func(network, addr string)
if err != nil {
return nil, err
}
return net.Dial("tcp", fmt.Sprintf("%s:%d", addrs[0].Target, addrs[0].Port))
return net.DialTimeout("tcp", fmt.Sprintf("%s:%d", addrs[0].Target, addrs[0].Port), timeout)
}
}
......@@ -144,7 +148,7 @@ func singleJoiningSlash(a, b string) string {
// the target request will be for /base/dir.
// Without logic: target's path is "/", incoming is "/api/messages",
// without is "/api", then the target request will be for /messages.
func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *ReverseProxy {
func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int, timeout time.Duration) *ReverseProxy {
targetQuery := target.RawQuery
director := func(req *http.Request) {
if target.Scheme == "unix" {
......@@ -226,15 +230,21 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *
}
}
dialer := *defaultDialer
if timeout != defaultDialer.Timeout {
dialer.Timeout = timeout
}
rp := &ReverseProxy{
Director: director,
FlushInterval: 250 * time.Millisecond, // flushing good for streaming & server-sent events
srvResolver: net.DefaultResolver,
dialer: &dialer,
}
if target.Scheme == "unix" {
rp.Transport = &http.Transport{
Dial: socketDial(target.String()),
Dial: socketDial(target.String(), timeout),
}
} else if target.Scheme == "quic" {
rp.Transport = &h2quic.RoundTripper{
......@@ -244,9 +254,9 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *
},
}
} else if keepalive != http.DefaultMaxIdleConnsPerHost || strings.HasPrefix(target.Scheme, "srv") {
dialFunc := defaultDialer.Dial
dialFunc := rp.dialer.Dial
if strings.HasPrefix(target.Scheme, "srv") {
dialFunc = rp.srvDialerFunc(target.String())
dialFunc = rp.srvDialerFunc(target.String(), timeout)
}
transport := &http.Transport{
......@@ -275,7 +285,7 @@ func (rp *ReverseProxy) UseInsecureTransport() {
if rp.Transport == nil {
transport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
Dial: defaultDialer.Dial,
Dial: rp.dialer.Dial,
TLSHandshakeTimeout: defaultCryptoHandshakeTimeout,
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
......@@ -306,7 +316,9 @@ func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request,
if requestIsWebsocket(outreq) {
transport = newConnHijackerTransport(transport)
} else if transport == nil {
transport = http.DefaultTransport
transport = &http.Transport{
Dial: rp.dialer.Dial,
}
}
rp.Director(outreq)
......@@ -361,7 +373,7 @@ func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request,
}
bufferPool.Put(hj.Replay)
} else {
backendConn, err = net.Dial("tcp", outreq.URL.Host)
backendConn, err = net.DialTimeout("tcp", outreq.URL.Host, rp.dialer.Timeout)
if err != nil {
return err
}
......
......@@ -21,6 +21,7 @@ import (
"net/url"
"strconv"
"testing"
"time"
)
const (
......@@ -66,7 +67,7 @@ func TestSingleSRVHostReverseProxy(t *testing.T) {
}
port := uint16(pp)
rp := NewSingleHostReverseProxy(target, "", http.DefaultMaxIdleConnsPerHost)
rp := NewSingleHostReverseProxy(target, "", http.DefaultMaxIdleConnsPerHost, 30*time.Second)
rp.srvResolver = testResolver{
result: []*net.SRV{
{Target: upstream.Hostname(), Port: port, Priority: 1, Weight: 1},
......
......@@ -49,6 +49,7 @@ type staticUpstream struct {
Hosts HostPool
Policy Policy
KeepAlive int
Timeout time.Duration
FailTimeout time.Duration
TryDuration time.Duration
TryInterval time.Duration
......@@ -92,6 +93,7 @@ func NewStaticUpstreams(c caddyfile.Dispenser, host string) ([]Upstream, error)
TryInterval: 250 * time.Millisecond,
MaxConns: 0,
KeepAlive: http.DefaultMaxIdleConnsPerHost,
Timeout: 30 * time.Second,
resolver: net.DefaultResolver,
}
......@@ -225,7 +227,7 @@ func (u *staticUpstream) NewHost(host string) (*UpstreamHost, error) {
return nil, err
}
uh.ReverseProxy = NewSingleHostReverseProxy(baseURL, uh.WithoutPathPrefix, u.KeepAlive)
uh.ReverseProxy = NewSingleHostReverseProxy(baseURL, uh.WithoutPathPrefix, u.KeepAlive, u.Timeout)
if u.insecureSkipVerify {
uh.ReverseProxy.UseInsecureTransport()
}
......@@ -464,6 +466,15 @@ func parseBlock(c *caddyfile.Dispenser, u *staticUpstream, hasSrv bool) error {
return c.ArgErr()
}
u.KeepAlive = n
case "timeout":
if !c.NextArg() {
return c.ArgErr()
}
dur, err := time.ParseDuration(c.Val())
if err != nil {
return c.Errf("unable to parse timeout duration '%s'", c.Val())
}
u.Timeout = dur
default:
return c.Errf("unknown property '%s'", c.Val())
}
......@@ -619,6 +630,11 @@ func (u *staticUpstream) GetTryInterval() time.Duration {
return u.TryInterval
}
// GetTimeout returns u.Timeout.
func (u *staticUpstream) GetTimeout() time.Duration {
return u.Timeout
}
func (u *staticUpstream) GetHostCount() int {
return len(u.Hosts)
}
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment