Commit 98de336a authored by Tanmay Chaudhry's avatar Tanmay Chaudhry Committed by Matt Holt

proxy: Enabled configurable timeout (#2070)

* Enabled configurable Timeout for the proxy directive

* Added Test for reverse for proxy timeout

* Removed Duplication in proxy constructors

* Remove indirection from multiple constructors and refactor into one

* Fix inconsistent error message and refactor dialer initialization
parent 9fe2ef41
...@@ -58,6 +58,10 @@ type Upstream interface { ...@@ -58,6 +58,10 @@ type Upstream interface {
// Gets the number of upstream hosts. // Gets the number of upstream hosts.
GetHostCount() int GetHostCount() int
// Gets how long to wait before timing out
// the request
GetTimeout() time.Duration
// Stops the upstream from proxying requests to shutdown goroutines cleanly. // Stops the upstream from proxying requests to shutdown goroutines cleanly.
Stop() error Stop() error
} }
...@@ -187,7 +191,11 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { ...@@ -187,7 +191,11 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
if nameURL, err := url.Parse(host.Name); err == nil { if nameURL, err := url.Parse(host.Name); err == nil {
outreq.Host = nameURL.Host outreq.Host = nameURL.Host
if proxy == nil { if proxy == nil {
proxy = NewSingleHostReverseProxy(nameURL, host.WithoutPathPrefix, http.DefaultMaxIdleConnsPerHost) proxy = NewSingleHostReverseProxy(nameURL,
host.WithoutPathPrefix,
http.DefaultMaxIdleConnsPerHost,
upstream.GetTimeout(),
)
} }
// use upstream credentials by default // use upstream credentials by default
......
...@@ -122,7 +122,7 @@ func TestReverseProxy(t *testing.T) { ...@@ -122,7 +122,7 @@ func TestReverseProxy(t *testing.T) {
// set up proxy // set up proxy
p := &Proxy{ p := &Proxy{
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
Upstreams: []Upstream{newFakeUpstream(backend.URL, false)}, Upstreams: []Upstream{newFakeUpstream(backend.URL, false, 30*time.Second)},
} }
// Create the fake request body. // Create the fake request body.
...@@ -202,7 +202,7 @@ func TestReverseProxyInsecureSkipVerify(t *testing.T) { ...@@ -202,7 +202,7 @@ func TestReverseProxyInsecureSkipVerify(t *testing.T) {
// set up proxy // set up proxy
p := &Proxy{ p := &Proxy{
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
Upstreams: []Upstream{newFakeUpstream(backend.URL, true)}, Upstreams: []Upstream{newFakeUpstream(backend.URL, true, 30*time.Second)},
} }
// create request and response recorder // create request and response recorder
...@@ -287,6 +287,31 @@ func TestReverseProxyMaxConnLimit(t *testing.T) { ...@@ -287,6 +287,31 @@ func TestReverseProxyMaxConnLimit(t *testing.T) {
jobs.Wait() jobs.Wait()
} }
func TestReverseProxyTimeout(t *testing.T) {
timeout := 2 * time.Second
errorMargin := 100 * time.Millisecond
log.SetOutput(ioutil.Discard)
defer log.SetOutput(os.Stderr)
// set up proxy
p := &Proxy{
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
Upstreams: []Upstream{newFakeUpstream("https://8.8.8.8", true, timeout)},
}
// create request and response recorder
r := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
start := time.Now()
p.ServeHTTP(w, r)
took := time.Since(start)
if took > timeout+errorMargin {
t.Errorf("Expected timeout ~ %v but got %v", timeout, took)
}
}
func TestWebSocketReverseProxyNonHijackerPanic(t *testing.T) { func TestWebSocketReverseProxyNonHijackerPanic(t *testing.T) {
// Capture the expected panic // Capture the expected panic
defer func() { defer func() {
...@@ -301,7 +326,7 @@ func TestWebSocketReverseProxyNonHijackerPanic(t *testing.T) { ...@@ -301,7 +326,7 @@ func TestWebSocketReverseProxyNonHijackerPanic(t *testing.T) {
defer wsNop.Close() defer wsNop.Close()
// Get proxy to use for the test // Get proxy to use for the test
p := newWebSocketTestProxy(wsNop.URL, false) p := newWebSocketTestProxy(wsNop.URL, false, 30*time.Second)
// Create client request // Create client request
r := httptest.NewRequest("GET", "/", nil) r := httptest.NewRequest("GET", "/", nil)
...@@ -331,7 +356,7 @@ func TestWebSocketReverseProxyBackendShutDown(t *testing.T) { ...@@ -331,7 +356,7 @@ func TestWebSocketReverseProxyBackendShutDown(t *testing.T) {
}() }()
// Get proxy to use for the test // Get proxy to use for the test
p := newWebSocketTestProxy(backend.URL, false) p := newWebSocketTestProxy(backend.URL, false, 30*time.Second)
backendProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { backendProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
p.ServeHTTP(w, r) p.ServeHTTP(w, r)
})) }))
...@@ -360,7 +385,7 @@ func TestWebSocketReverseProxyServeHTTPHandler(t *testing.T) { ...@@ -360,7 +385,7 @@ func TestWebSocketReverseProxyServeHTTPHandler(t *testing.T) {
defer wsNop.Close() defer wsNop.Close()
// Get proxy to use for the test // Get proxy to use for the test
p := newWebSocketTestProxy(wsNop.URL, false) p := newWebSocketTestProxy(wsNop.URL, false, 30*time.Second)
// Create client request // Create client request
r := httptest.NewRequest("GET", "/", nil) r := httptest.NewRequest("GET", "/", nil)
...@@ -407,7 +432,7 @@ func TestWebSocketReverseProxyFromWSClient(t *testing.T) { ...@@ -407,7 +432,7 @@ func TestWebSocketReverseProxyFromWSClient(t *testing.T) {
defer wsEcho.Close() defer wsEcho.Close()
// Get proxy to use for the test // Get proxy to use for the test
p := newWebSocketTestProxy(wsEcho.URL, false) p := newWebSocketTestProxy(wsEcho.URL, false, 30*time.Second)
// This is a full end-end test, so the proxy handler // This is a full end-end test, so the proxy handler
// has to be part of a server listening on a port. Our // has to be part of a server listening on a port. Our
...@@ -452,7 +477,7 @@ func TestWebSocketReverseProxyFromWSSClient(t *testing.T) { ...@@ -452,7 +477,7 @@ func TestWebSocketReverseProxyFromWSSClient(t *testing.T) {
})) }))
defer wsEcho.Close() defer wsEcho.Close()
p := newWebSocketTestProxy(wsEcho.URL, true) p := newWebSocketTestProxy(wsEcho.URL, true, 30*time.Second)
echoProxy := newTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { echoProxy := newTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
p.ServeHTTP(w, r) p.ServeHTTP(w, r)
...@@ -528,7 +553,7 @@ func TestUnixSocketProxy(t *testing.T) { ...@@ -528,7 +553,7 @@ func TestUnixSocketProxy(t *testing.T) {
defer ts.Close() defer ts.Close()
url := strings.Replace(ts.URL, "http://", "unix:", 1) url := strings.Replace(ts.URL, "http://", "unix:", 1)
p := newWebSocketTestProxy(url, false) p := newWebSocketTestProxy(url, false, 30*time.Second)
echoProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { echoProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
p.ServeHTTP(w, r) p.ServeHTTP(w, r)
...@@ -686,7 +711,7 @@ func TestUpstreamHeadersUpdate(t *testing.T) { ...@@ -686,7 +711,7 @@ func TestUpstreamHeadersUpdate(t *testing.T) {
})) }))
defer backend.Close() defer backend.Close()
upstream := newFakeUpstream(backend.URL, false) upstream := newFakeUpstream(backend.URL, false, 30*time.Second)
upstream.host.UpstreamHeaders = http.Header{ upstream.host.UpstreamHeaders = http.Header{
"Connection": {"{>Connection}"}, "Connection": {"{>Connection}"},
"Upgrade": {"{>Upgrade}"}, "Upgrade": {"{>Upgrade}"},
...@@ -753,7 +778,7 @@ func TestDownstreamHeadersUpdate(t *testing.T) { ...@@ -753,7 +778,7 @@ func TestDownstreamHeadersUpdate(t *testing.T) {
})) }))
defer backend.Close() defer backend.Close()
upstream := newFakeUpstream(backend.URL, false) upstream := newFakeUpstream(backend.URL, false, 30*time.Second)
upstream.host.DownstreamHeaders = http.Header{ upstream.host.DownstreamHeaders = http.Header{
"+Merge-Me": {"Merge-Value"}, "+Merge-Me": {"Merge-Value"},
"+Add-Me": {"Add-Value"}, "+Add-Me": {"Add-Value"},
...@@ -893,7 +918,7 @@ func TestHostSimpleProxyNoHeaderForward(t *testing.T) { ...@@ -893,7 +918,7 @@ func TestHostSimpleProxyNoHeaderForward(t *testing.T) {
// set up proxy // set up proxy
p := &Proxy{ p := &Proxy{
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
Upstreams: []Upstream{newFakeUpstream(backend.URL, false)}, Upstreams: []Upstream{newFakeUpstream(backend.URL, false, 30*time.Second)},
} }
r := httptest.NewRequest("GET", "/", nil) r := httptest.NewRequest("GET", "/", nil)
...@@ -982,7 +1007,7 @@ func TestHostHeaderReplacedUsingForward(t *testing.T) { ...@@ -982,7 +1007,7 @@ func TestHostHeaderReplacedUsingForward(t *testing.T) {
})) }))
defer backend.Close() defer backend.Close()
upstream := newFakeUpstream(backend.URL, false) upstream := newFakeUpstream(backend.URL, false, 30*time.Second)
proxyHostHeader := "test2.com" proxyHostHeader := "test2.com"
upstream.host.UpstreamHeaders = http.Header{"Host": []string{proxyHostHeader}} upstream.host.UpstreamHeaders = http.Header{"Host": []string{proxyHostHeader}}
// set up proxy // set up proxy
...@@ -1044,7 +1069,7 @@ func basicAuthTestcase(t *testing.T, upstreamUser, clientUser *url.Userinfo) { ...@@ -1044,7 +1069,7 @@ func basicAuthTestcase(t *testing.T, upstreamUser, clientUser *url.Userinfo) {
p := &Proxy{ p := &Proxy{
Next: httpserver.EmptyNext, Next: httpserver.EmptyNext,
Upstreams: []Upstream{newFakeUpstream(backURL.String(), false)}, Upstreams: []Upstream{newFakeUpstream(backURL.String(), false, 30*time.Second)},
} }
r, err := http.NewRequest("GET", "/foo", nil) r, err := http.NewRequest("GET", "/foo", nil)
if err != nil { if err != nil {
...@@ -1179,7 +1204,7 @@ func TestProxyDirectorURL(t *testing.T) { ...@@ -1179,7 +1204,7 @@ func TestProxyDirectorURL(t *testing.T) {
continue continue
} }
NewSingleHostReverseProxy(targetURL, c.without, 0).Director(req) NewSingleHostReverseProxy(targetURL, c.without, 0, 30*time.Second).Director(req)
if expect, got := c.expectURL, req.URL.String(); expect != got { if expect, got := c.expectURL, req.URL.String(); expect != got {
t.Errorf("case %d url not equal: expect %q, but got %q", t.Errorf("case %d url not equal: expect %q, but got %q",
i, expect, got) i, expect, got)
...@@ -1326,7 +1351,7 @@ func TestCancelRequest(t *testing.T) { ...@@ -1326,7 +1351,7 @@ func TestCancelRequest(t *testing.T) {
// set up proxy // set up proxy
p := &Proxy{ p := &Proxy{
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
Upstreams: []Upstream{newFakeUpstream(backend.URL, false)}, Upstreams: []Upstream{newFakeUpstream(backend.URL, false, 30*time.Second)},
} }
// setup request with cancel ctx // setup request with cancel ctx
...@@ -1375,14 +1400,15 @@ func (r *noopReader) Read(b []byte) (int, error) { ...@@ -1375,14 +1400,15 @@ func (r *noopReader) Read(b []byte) (int, error) {
return n, nil return n, nil
} }
func newFakeUpstream(name string, insecure bool) *fakeUpstream { func newFakeUpstream(name string, insecure bool, timeout time.Duration) *fakeUpstream {
uri, _ := url.Parse(name) uri, _ := url.Parse(name)
u := &fakeUpstream{ u := &fakeUpstream{
name: name, name: name,
from: "/", from: "/",
timeout: timeout,
host: &UpstreamHost{ host: &UpstreamHost{
Name: name, Name: name,
ReverseProxy: NewSingleHostReverseProxy(uri, "", http.DefaultMaxIdleConnsPerHost), ReverseProxy: NewSingleHostReverseProxy(uri, "", http.DefaultMaxIdleConnsPerHost, timeout),
}, },
} }
if insecure { if insecure {
...@@ -1396,6 +1422,7 @@ type fakeUpstream struct { ...@@ -1396,6 +1422,7 @@ type fakeUpstream struct {
host *UpstreamHost host *UpstreamHost
from string from string
without string without string
timeout time.Duration
} }
func (u *fakeUpstream) From() string { func (u *fakeUpstream) From() string {
...@@ -1410,7 +1437,7 @@ func (u *fakeUpstream) Select(r *http.Request) *UpstreamHost { ...@@ -1410,7 +1437,7 @@ func (u *fakeUpstream) Select(r *http.Request) *UpstreamHost {
} }
u.host = &UpstreamHost{ u.host = &UpstreamHost{
Name: u.name, Name: u.name,
ReverseProxy: NewSingleHostReverseProxy(uri, u.without, http.DefaultMaxIdleConnsPerHost), ReverseProxy: NewSingleHostReverseProxy(uri, u.without, http.DefaultMaxIdleConnsPerHost, u.GetTimeout()),
} }
} }
return u.host return u.host
...@@ -1419,6 +1446,7 @@ func (u *fakeUpstream) Select(r *http.Request) *UpstreamHost { ...@@ -1419,6 +1446,7 @@ func (u *fakeUpstream) Select(r *http.Request) *UpstreamHost {
func (u *fakeUpstream) AllowedPath(requestPath string) bool { return true } func (u *fakeUpstream) AllowedPath(requestPath string) bool { return true }
func (u *fakeUpstream) GetTryDuration() time.Duration { return 1 * time.Second } func (u *fakeUpstream) GetTryDuration() time.Duration { return 1 * time.Second }
func (u *fakeUpstream) GetTryInterval() time.Duration { return 250 * time.Millisecond } func (u *fakeUpstream) GetTryInterval() time.Duration { return 250 * time.Millisecond }
func (u *fakeUpstream) GetTimeout() time.Duration { return u.timeout }
func (u *fakeUpstream) GetHostCount() int { return 1 } func (u *fakeUpstream) GetHostCount() int { return 1 }
func (u *fakeUpstream) Stop() error { return nil } func (u *fakeUpstream) Stop() error { return nil }
...@@ -1426,13 +1454,14 @@ func (u *fakeUpstream) Stop() error { return nil } ...@@ -1426,13 +1454,14 @@ func (u *fakeUpstream) Stop() error { return nil }
// redirect to the specified backendAddr. The function // redirect to the specified backendAddr. The function
// also sets up the rules/environment for testing WebSocket // also sets up the rules/environment for testing WebSocket
// proxy. // proxy.
func newWebSocketTestProxy(backendAddr string, insecure bool) *Proxy { func newWebSocketTestProxy(backendAddr string, insecure bool, timeout time.Duration) *Proxy {
return &Proxy{ return &Proxy{
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
Upstreams: []Upstream{&fakeWsUpstream{ Upstreams: []Upstream{&fakeWsUpstream{
name: backendAddr, name: backendAddr,
without: "", without: "",
insecure: insecure, insecure: insecure,
timeout: timeout,
}}, }},
} }
} }
...@@ -1440,7 +1469,7 @@ func newWebSocketTestProxy(backendAddr string, insecure bool) *Proxy { ...@@ -1440,7 +1469,7 @@ func newWebSocketTestProxy(backendAddr string, insecure bool) *Proxy {
func newPrefixedWebSocketTestProxy(backendAddr string, prefix string) *Proxy { func newPrefixedWebSocketTestProxy(backendAddr string, prefix string) *Proxy {
return &Proxy{ return &Proxy{
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
Upstreams: []Upstream{&fakeWsUpstream{name: backendAddr, without: prefix}}, Upstreams: []Upstream{&fakeWsUpstream{name: backendAddr, without: prefix, timeout: 30 * time.Second}},
} }
} }
...@@ -1448,6 +1477,7 @@ type fakeWsUpstream struct { ...@@ -1448,6 +1477,7 @@ type fakeWsUpstream struct {
name string name string
without string without string
insecure bool insecure bool
timeout time.Duration
} }
func (u *fakeWsUpstream) From() string { func (u *fakeWsUpstream) From() string {
...@@ -1458,7 +1488,7 @@ func (u *fakeWsUpstream) Select(r *http.Request) *UpstreamHost { ...@@ -1458,7 +1488,7 @@ func (u *fakeWsUpstream) Select(r *http.Request) *UpstreamHost {
uri, _ := url.Parse(u.name) uri, _ := url.Parse(u.name)
host := &UpstreamHost{ host := &UpstreamHost{
Name: u.name, Name: u.name,
ReverseProxy: NewSingleHostReverseProxy(uri, u.without, http.DefaultMaxIdleConnsPerHost), ReverseProxy: NewSingleHostReverseProxy(uri, u.without, http.DefaultMaxIdleConnsPerHost, u.GetTimeout()),
UpstreamHeaders: http.Header{ UpstreamHeaders: http.Header{
"Connection": {"{>Connection}"}, "Connection": {"{>Connection}"},
"Upgrade": {"{>Upgrade}"}}, "Upgrade": {"{>Upgrade}"}},
...@@ -1472,6 +1502,7 @@ func (u *fakeWsUpstream) Select(r *http.Request) *UpstreamHost { ...@@ -1472,6 +1502,7 @@ func (u *fakeWsUpstream) Select(r *http.Request) *UpstreamHost {
func (u *fakeWsUpstream) AllowedPath(requestPath string) bool { return true } func (u *fakeWsUpstream) AllowedPath(requestPath string) bool { return true }
func (u *fakeWsUpstream) GetTryDuration() time.Duration { return 1 * time.Second } func (u *fakeWsUpstream) GetTryDuration() time.Duration { return 1 * time.Second }
func (u *fakeWsUpstream) GetTryInterval() time.Duration { return 250 * time.Millisecond } func (u *fakeWsUpstream) GetTryInterval() time.Duration { return 250 * time.Millisecond }
func (u *fakeWsUpstream) GetTimeout() time.Duration { return u.timeout }
func (u *fakeWsUpstream) GetHostCount() int { return 1 } func (u *fakeWsUpstream) GetHostCount() int { return 1 }
func (u *fakeWsUpstream) Stop() error { return nil } func (u *fakeWsUpstream) Stop() error { return nil }
...@@ -1517,7 +1548,7 @@ func BenchmarkProxy(b *testing.B) { ...@@ -1517,7 +1548,7 @@ func BenchmarkProxy(b *testing.B) {
})) }))
defer backend.Close() defer backend.Close()
upstream := newFakeUpstream(backend.URL, false) upstream := newFakeUpstream(backend.URL, false, 30*time.Second)
upstream.host.UpstreamHeaders = http.Header{ upstream.host.UpstreamHeaders = http.Header{
"Hostname": {"{hostname}"}, "Hostname": {"{hostname}"},
"Host": {"{host}"}, "Host": {"{host}"},
...@@ -1560,7 +1591,7 @@ func TestChunkedWebSocketReverseProxy(t *testing.T) { ...@@ -1560,7 +1591,7 @@ func TestChunkedWebSocketReverseProxy(t *testing.T) {
defer wsNop.Close() defer wsNop.Close()
// Get proxy to use for the test // Get proxy to use for the test
p := newWebSocketTestProxy(wsNop.URL, false) p := newWebSocketTestProxy(wsNop.URL, false, 30*time.Second)
// Create client request // Create client request
r := httptest.NewRequest("GET", "/", nil) r := httptest.NewRequest("GET", "/", nil)
......
...@@ -94,6 +94,10 @@ type ReverseProxy struct { ...@@ -94,6 +94,10 @@ type ReverseProxy struct {
// If zero, no periodic flushing is done. // If zero, no periodic flushing is done.
FlushInterval time.Duration FlushInterval time.Duration
// dialer is used when values from the
// defaultDialer need to be overridden per Proxy
dialer *net.Dialer
srvResolver srvResolver srvResolver srvResolver
} }
...@@ -103,13 +107,13 @@ type ReverseProxy struct { ...@@ -103,13 +107,13 @@ type ReverseProxy struct {
// What we need is just the path, so if "unix:/var/run/www.socket" // What we need is just the path, so if "unix:/var/run/www.socket"
// was the proxy directive, the parsed hostName would be // was the proxy directive, the parsed hostName would be
// "unix:///var/run/www.socket", hence the ambiguous trimming. // "unix:///var/run/www.socket", hence the ambiguous trimming.
func socketDial(hostName string) func(network, addr string) (conn net.Conn, err error) { func socketDial(hostName string, timeout time.Duration) func(network, addr string) (conn net.Conn, err error) {
return func(network, addr string) (conn net.Conn, err error) { return func(network, addr string) (conn net.Conn, err error) {
return net.Dial("unix", hostName[len("unix://"):]) return net.DialTimeout("unix", hostName[len("unix://"):], timeout)
} }
} }
func (rp *ReverseProxy) srvDialerFunc(locator string) func(network, addr string) (conn net.Conn, err error) { func (rp *ReverseProxy) srvDialerFunc(locator string, timeout time.Duration) func(network, addr string) (conn net.Conn, err error) {
service := locator service := locator
if strings.HasPrefix(locator, "srv://") { if strings.HasPrefix(locator, "srv://") {
service = locator[6:] service = locator[6:]
...@@ -122,7 +126,7 @@ func (rp *ReverseProxy) srvDialerFunc(locator string) func(network, addr string) ...@@ -122,7 +126,7 @@ func (rp *ReverseProxy) srvDialerFunc(locator string) func(network, addr string)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return net.Dial("tcp", fmt.Sprintf("%s:%d", addrs[0].Target, addrs[0].Port)) return net.DialTimeout("tcp", fmt.Sprintf("%s:%d", addrs[0].Target, addrs[0].Port), timeout)
} }
} }
...@@ -144,7 +148,7 @@ func singleJoiningSlash(a, b string) string { ...@@ -144,7 +148,7 @@ func singleJoiningSlash(a, b string) string {
// the target request will be for /base/dir. // the target request will be for /base/dir.
// Without logic: target's path is "/", incoming is "/api/messages", // Without logic: target's path is "/", incoming is "/api/messages",
// without is "/api", then the target request will be for /messages. // without is "/api", then the target request will be for /messages.
func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *ReverseProxy { func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int, timeout time.Duration) *ReverseProxy {
targetQuery := target.RawQuery targetQuery := target.RawQuery
director := func(req *http.Request) { director := func(req *http.Request) {
if target.Scheme == "unix" { if target.Scheme == "unix" {
...@@ -226,15 +230,21 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) * ...@@ -226,15 +230,21 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *
} }
} }
dialer := *defaultDialer
if timeout != defaultDialer.Timeout {
dialer.Timeout = timeout
}
rp := &ReverseProxy{ rp := &ReverseProxy{
Director: director, Director: director,
FlushInterval: 250 * time.Millisecond, // flushing good for streaming & server-sent events FlushInterval: 250 * time.Millisecond, // flushing good for streaming & server-sent events
srvResolver: net.DefaultResolver, srvResolver: net.DefaultResolver,
dialer: &dialer,
} }
if target.Scheme == "unix" { if target.Scheme == "unix" {
rp.Transport = &http.Transport{ rp.Transport = &http.Transport{
Dial: socketDial(target.String()), Dial: socketDial(target.String(), timeout),
} }
} else if target.Scheme == "quic" { } else if target.Scheme == "quic" {
rp.Transport = &h2quic.RoundTripper{ rp.Transport = &h2quic.RoundTripper{
...@@ -244,9 +254,9 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) * ...@@ -244,9 +254,9 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *
}, },
} }
} else if keepalive != http.DefaultMaxIdleConnsPerHost || strings.HasPrefix(target.Scheme, "srv") { } else if keepalive != http.DefaultMaxIdleConnsPerHost || strings.HasPrefix(target.Scheme, "srv") {
dialFunc := defaultDialer.Dial dialFunc := rp.dialer.Dial
if strings.HasPrefix(target.Scheme, "srv") { if strings.HasPrefix(target.Scheme, "srv") {
dialFunc = rp.srvDialerFunc(target.String()) dialFunc = rp.srvDialerFunc(target.String(), timeout)
} }
transport := &http.Transport{ transport := &http.Transport{
...@@ -275,7 +285,7 @@ func (rp *ReverseProxy) UseInsecureTransport() { ...@@ -275,7 +285,7 @@ func (rp *ReverseProxy) UseInsecureTransport() {
if rp.Transport == nil { if rp.Transport == nil {
transport := &http.Transport{ transport := &http.Transport{
Proxy: http.ProxyFromEnvironment, Proxy: http.ProxyFromEnvironment,
Dial: defaultDialer.Dial, Dial: rp.dialer.Dial,
TLSHandshakeTimeout: defaultCryptoHandshakeTimeout, TLSHandshakeTimeout: defaultCryptoHandshakeTimeout,
TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
} }
...@@ -306,7 +316,9 @@ func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request, ...@@ -306,7 +316,9 @@ func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request,
if requestIsWebsocket(outreq) { if requestIsWebsocket(outreq) {
transport = newConnHijackerTransport(transport) transport = newConnHijackerTransport(transport)
} else if transport == nil { } else if transport == nil {
transport = http.DefaultTransport transport = &http.Transport{
Dial: rp.dialer.Dial,
}
} }
rp.Director(outreq) rp.Director(outreq)
...@@ -361,7 +373,7 @@ func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request, ...@@ -361,7 +373,7 @@ func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request,
} }
bufferPool.Put(hj.Replay) bufferPool.Put(hj.Replay)
} else { } else {
backendConn, err = net.Dial("tcp", outreq.URL.Host) backendConn, err = net.DialTimeout("tcp", outreq.URL.Host, rp.dialer.Timeout)
if err != nil { if err != nil {
return err return err
} }
......
...@@ -21,6 +21,7 @@ import ( ...@@ -21,6 +21,7 @@ import (
"net/url" "net/url"
"strconv" "strconv"
"testing" "testing"
"time"
) )
const ( const (
...@@ -66,7 +67,7 @@ func TestSingleSRVHostReverseProxy(t *testing.T) { ...@@ -66,7 +67,7 @@ func TestSingleSRVHostReverseProxy(t *testing.T) {
} }
port := uint16(pp) port := uint16(pp)
rp := NewSingleHostReverseProxy(target, "", http.DefaultMaxIdleConnsPerHost) rp := NewSingleHostReverseProxy(target, "", http.DefaultMaxIdleConnsPerHost, 30*time.Second)
rp.srvResolver = testResolver{ rp.srvResolver = testResolver{
result: []*net.SRV{ result: []*net.SRV{
{Target: upstream.Hostname(), Port: port, Priority: 1, Weight: 1}, {Target: upstream.Hostname(), Port: port, Priority: 1, Weight: 1},
......
...@@ -49,6 +49,7 @@ type staticUpstream struct { ...@@ -49,6 +49,7 @@ type staticUpstream struct {
Hosts HostPool Hosts HostPool
Policy Policy Policy Policy
KeepAlive int KeepAlive int
Timeout time.Duration
FailTimeout time.Duration FailTimeout time.Duration
TryDuration time.Duration TryDuration time.Duration
TryInterval time.Duration TryInterval time.Duration
...@@ -92,6 +93,7 @@ func NewStaticUpstreams(c caddyfile.Dispenser, host string) ([]Upstream, error) ...@@ -92,6 +93,7 @@ func NewStaticUpstreams(c caddyfile.Dispenser, host string) ([]Upstream, error)
TryInterval: 250 * time.Millisecond, TryInterval: 250 * time.Millisecond,
MaxConns: 0, MaxConns: 0,
KeepAlive: http.DefaultMaxIdleConnsPerHost, KeepAlive: http.DefaultMaxIdleConnsPerHost,
Timeout: 30 * time.Second,
resolver: net.DefaultResolver, resolver: net.DefaultResolver,
} }
...@@ -225,7 +227,7 @@ func (u *staticUpstream) NewHost(host string) (*UpstreamHost, error) { ...@@ -225,7 +227,7 @@ func (u *staticUpstream) NewHost(host string) (*UpstreamHost, error) {
return nil, err return nil, err
} }
uh.ReverseProxy = NewSingleHostReverseProxy(baseURL, uh.WithoutPathPrefix, u.KeepAlive) uh.ReverseProxy = NewSingleHostReverseProxy(baseURL, uh.WithoutPathPrefix, u.KeepAlive, u.Timeout)
if u.insecureSkipVerify { if u.insecureSkipVerify {
uh.ReverseProxy.UseInsecureTransport() uh.ReverseProxy.UseInsecureTransport()
} }
...@@ -464,6 +466,15 @@ func parseBlock(c *caddyfile.Dispenser, u *staticUpstream, hasSrv bool) error { ...@@ -464,6 +466,15 @@ func parseBlock(c *caddyfile.Dispenser, u *staticUpstream, hasSrv bool) error {
return c.ArgErr() return c.ArgErr()
} }
u.KeepAlive = n u.KeepAlive = n
case "timeout":
if !c.NextArg() {
return c.ArgErr()
}
dur, err := time.ParseDuration(c.Val())
if err != nil {
return c.Errf("unable to parse timeout duration '%s'", c.Val())
}
u.Timeout = dur
default: default:
return c.Errf("unknown property '%s'", c.Val()) return c.Errf("unknown property '%s'", c.Val())
} }
...@@ -619,6 +630,11 @@ func (u *staticUpstream) GetTryInterval() time.Duration { ...@@ -619,6 +630,11 @@ func (u *staticUpstream) GetTryInterval() time.Duration {
return u.TryInterval return u.TryInterval
} }
// GetTimeout returns u.Timeout.
func (u *staticUpstream) GetTimeout() time.Duration {
return u.Timeout
}
func (u *staticUpstream) GetHostCount() int { func (u *staticUpstream) GetHostCount() int {
return len(u.Hosts) return len(u.Hosts)
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment