Commit 98de336a authored by Tanmay Chaudhry's avatar Tanmay Chaudhry Committed by Matt Holt

proxy: Enabled configurable timeout (#2070)

* Enabled configurable Timeout for the proxy directive

* Added Test for reverse for proxy timeout

* Removed Duplication in proxy constructors

* Remove indirection from multiple constructors and refactor into one

* Fix inconsistent error message and refactor dialer initialization
parent 9fe2ef41
...@@ -58,6 +58,10 @@ type Upstream interface { ...@@ -58,6 +58,10 @@ type Upstream interface {
// Gets the number of upstream hosts. // Gets the number of upstream hosts.
GetHostCount() int GetHostCount() int
// Gets how long to wait before timing out
// the request
GetTimeout() time.Duration
// Stops the upstream from proxying requests to shutdown goroutines cleanly. // Stops the upstream from proxying requests to shutdown goroutines cleanly.
Stop() error Stop() error
} }
...@@ -187,7 +191,11 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { ...@@ -187,7 +191,11 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
if nameURL, err := url.Parse(host.Name); err == nil { if nameURL, err := url.Parse(host.Name); err == nil {
outreq.Host = nameURL.Host outreq.Host = nameURL.Host
if proxy == nil { if proxy == nil {
proxy = NewSingleHostReverseProxy(nameURL, host.WithoutPathPrefix, http.DefaultMaxIdleConnsPerHost) proxy = NewSingleHostReverseProxy(nameURL,
host.WithoutPathPrefix,
http.DefaultMaxIdleConnsPerHost,
upstream.GetTimeout(),
)
} }
// use upstream credentials by default // use upstream credentials by default
......
This diff is collapsed.
...@@ -94,6 +94,10 @@ type ReverseProxy struct { ...@@ -94,6 +94,10 @@ type ReverseProxy struct {
// If zero, no periodic flushing is done. // If zero, no periodic flushing is done.
FlushInterval time.Duration FlushInterval time.Duration
// dialer is used when values from the
// defaultDialer need to be overridden per Proxy
dialer *net.Dialer
srvResolver srvResolver srvResolver srvResolver
} }
...@@ -103,13 +107,13 @@ type ReverseProxy struct { ...@@ -103,13 +107,13 @@ type ReverseProxy struct {
// What we need is just the path, so if "unix:/var/run/www.socket" // What we need is just the path, so if "unix:/var/run/www.socket"
// was the proxy directive, the parsed hostName would be // was the proxy directive, the parsed hostName would be
// "unix:///var/run/www.socket", hence the ambiguous trimming. // "unix:///var/run/www.socket", hence the ambiguous trimming.
func socketDial(hostName string) func(network, addr string) (conn net.Conn, err error) { func socketDial(hostName string, timeout time.Duration) func(network, addr string) (conn net.Conn, err error) {
return func(network, addr string) (conn net.Conn, err error) { return func(network, addr string) (conn net.Conn, err error) {
return net.Dial("unix", hostName[len("unix://"):]) return net.DialTimeout("unix", hostName[len("unix://"):], timeout)
} }
} }
func (rp *ReverseProxy) srvDialerFunc(locator string) func(network, addr string) (conn net.Conn, err error) { func (rp *ReverseProxy) srvDialerFunc(locator string, timeout time.Duration) func(network, addr string) (conn net.Conn, err error) {
service := locator service := locator
if strings.HasPrefix(locator, "srv://") { if strings.HasPrefix(locator, "srv://") {
service = locator[6:] service = locator[6:]
...@@ -122,7 +126,7 @@ func (rp *ReverseProxy) srvDialerFunc(locator string) func(network, addr string) ...@@ -122,7 +126,7 @@ func (rp *ReverseProxy) srvDialerFunc(locator string) func(network, addr string)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return net.Dial("tcp", fmt.Sprintf("%s:%d", addrs[0].Target, addrs[0].Port)) return net.DialTimeout("tcp", fmt.Sprintf("%s:%d", addrs[0].Target, addrs[0].Port), timeout)
} }
} }
...@@ -144,7 +148,7 @@ func singleJoiningSlash(a, b string) string { ...@@ -144,7 +148,7 @@ func singleJoiningSlash(a, b string) string {
// the target request will be for /base/dir. // the target request will be for /base/dir.
// Without logic: target's path is "/", incoming is "/api/messages", // Without logic: target's path is "/", incoming is "/api/messages",
// without is "/api", then the target request will be for /messages. // without is "/api", then the target request will be for /messages.
func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *ReverseProxy { func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int, timeout time.Duration) *ReverseProxy {
targetQuery := target.RawQuery targetQuery := target.RawQuery
director := func(req *http.Request) { director := func(req *http.Request) {
if target.Scheme == "unix" { if target.Scheme == "unix" {
...@@ -226,15 +230,21 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) * ...@@ -226,15 +230,21 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *
} }
} }
dialer := *defaultDialer
if timeout != defaultDialer.Timeout {
dialer.Timeout = timeout
}
rp := &ReverseProxy{ rp := &ReverseProxy{
Director: director, Director: director,
FlushInterval: 250 * time.Millisecond, // flushing good for streaming & server-sent events FlushInterval: 250 * time.Millisecond, // flushing good for streaming & server-sent events
srvResolver: net.DefaultResolver, srvResolver: net.DefaultResolver,
dialer: &dialer,
} }
if target.Scheme == "unix" { if target.Scheme == "unix" {
rp.Transport = &http.Transport{ rp.Transport = &http.Transport{
Dial: socketDial(target.String()), Dial: socketDial(target.String(), timeout),
} }
} else if target.Scheme == "quic" { } else if target.Scheme == "quic" {
rp.Transport = &h2quic.RoundTripper{ rp.Transport = &h2quic.RoundTripper{
...@@ -244,9 +254,9 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) * ...@@ -244,9 +254,9 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *
}, },
} }
} else if keepalive != http.DefaultMaxIdleConnsPerHost || strings.HasPrefix(target.Scheme, "srv") { } else if keepalive != http.DefaultMaxIdleConnsPerHost || strings.HasPrefix(target.Scheme, "srv") {
dialFunc := defaultDialer.Dial dialFunc := rp.dialer.Dial
if strings.HasPrefix(target.Scheme, "srv") { if strings.HasPrefix(target.Scheme, "srv") {
dialFunc = rp.srvDialerFunc(target.String()) dialFunc = rp.srvDialerFunc(target.String(), timeout)
} }
transport := &http.Transport{ transport := &http.Transport{
...@@ -275,7 +285,7 @@ func (rp *ReverseProxy) UseInsecureTransport() { ...@@ -275,7 +285,7 @@ func (rp *ReverseProxy) UseInsecureTransport() {
if rp.Transport == nil { if rp.Transport == nil {
transport := &http.Transport{ transport := &http.Transport{
Proxy: http.ProxyFromEnvironment, Proxy: http.ProxyFromEnvironment,
Dial: defaultDialer.Dial, Dial: rp.dialer.Dial,
TLSHandshakeTimeout: defaultCryptoHandshakeTimeout, TLSHandshakeTimeout: defaultCryptoHandshakeTimeout,
TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
} }
...@@ -306,7 +316,9 @@ func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request, ...@@ -306,7 +316,9 @@ func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request,
if requestIsWebsocket(outreq) { if requestIsWebsocket(outreq) {
transport = newConnHijackerTransport(transport) transport = newConnHijackerTransport(transport)
} else if transport == nil { } else if transport == nil {
transport = http.DefaultTransport transport = &http.Transport{
Dial: rp.dialer.Dial,
}
} }
rp.Director(outreq) rp.Director(outreq)
...@@ -361,7 +373,7 @@ func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request, ...@@ -361,7 +373,7 @@ func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request,
} }
bufferPool.Put(hj.Replay) bufferPool.Put(hj.Replay)
} else { } else {
backendConn, err = net.Dial("tcp", outreq.URL.Host) backendConn, err = net.DialTimeout("tcp", outreq.URL.Host, rp.dialer.Timeout)
if err != nil { if err != nil {
return err return err
} }
......
...@@ -21,6 +21,7 @@ import ( ...@@ -21,6 +21,7 @@ import (
"net/url" "net/url"
"strconv" "strconv"
"testing" "testing"
"time"
) )
const ( const (
...@@ -66,7 +67,7 @@ func TestSingleSRVHostReverseProxy(t *testing.T) { ...@@ -66,7 +67,7 @@ func TestSingleSRVHostReverseProxy(t *testing.T) {
} }
port := uint16(pp) port := uint16(pp)
rp := NewSingleHostReverseProxy(target, "", http.DefaultMaxIdleConnsPerHost) rp := NewSingleHostReverseProxy(target, "", http.DefaultMaxIdleConnsPerHost, 30*time.Second)
rp.srvResolver = testResolver{ rp.srvResolver = testResolver{
result: []*net.SRV{ result: []*net.SRV{
{Target: upstream.Hostname(), Port: port, Priority: 1, Weight: 1}, {Target: upstream.Hostname(), Port: port, Priority: 1, Weight: 1},
......
...@@ -49,6 +49,7 @@ type staticUpstream struct { ...@@ -49,6 +49,7 @@ type staticUpstream struct {
Hosts HostPool Hosts HostPool
Policy Policy Policy Policy
KeepAlive int KeepAlive int
Timeout time.Duration
FailTimeout time.Duration FailTimeout time.Duration
TryDuration time.Duration TryDuration time.Duration
TryInterval time.Duration TryInterval time.Duration
...@@ -92,6 +93,7 @@ func NewStaticUpstreams(c caddyfile.Dispenser, host string) ([]Upstream, error) ...@@ -92,6 +93,7 @@ func NewStaticUpstreams(c caddyfile.Dispenser, host string) ([]Upstream, error)
TryInterval: 250 * time.Millisecond, TryInterval: 250 * time.Millisecond,
MaxConns: 0, MaxConns: 0,
KeepAlive: http.DefaultMaxIdleConnsPerHost, KeepAlive: http.DefaultMaxIdleConnsPerHost,
Timeout: 30 * time.Second,
resolver: net.DefaultResolver, resolver: net.DefaultResolver,
} }
...@@ -225,7 +227,7 @@ func (u *staticUpstream) NewHost(host string) (*UpstreamHost, error) { ...@@ -225,7 +227,7 @@ func (u *staticUpstream) NewHost(host string) (*UpstreamHost, error) {
return nil, err return nil, err
} }
uh.ReverseProxy = NewSingleHostReverseProxy(baseURL, uh.WithoutPathPrefix, u.KeepAlive) uh.ReverseProxy = NewSingleHostReverseProxy(baseURL, uh.WithoutPathPrefix, u.KeepAlive, u.Timeout)
if u.insecureSkipVerify { if u.insecureSkipVerify {
uh.ReverseProxy.UseInsecureTransport() uh.ReverseProxy.UseInsecureTransport()
} }
...@@ -464,6 +466,15 @@ func parseBlock(c *caddyfile.Dispenser, u *staticUpstream, hasSrv bool) error { ...@@ -464,6 +466,15 @@ func parseBlock(c *caddyfile.Dispenser, u *staticUpstream, hasSrv bool) error {
return c.ArgErr() return c.ArgErr()
} }
u.KeepAlive = n u.KeepAlive = n
case "timeout":
if !c.NextArg() {
return c.ArgErr()
}
dur, err := time.ParseDuration(c.Val())
if err != nil {
return c.Errf("unable to parse timeout duration '%s'", c.Val())
}
u.Timeout = dur
default: default:
return c.Errf("unknown property '%s'", c.Val()) return c.Errf("unknown property '%s'", c.Val())
} }
...@@ -619,6 +630,11 @@ func (u *staticUpstream) GetTryInterval() time.Duration { ...@@ -619,6 +630,11 @@ func (u *staticUpstream) GetTryInterval() time.Duration {
return u.TryInterval return u.TryInterval
} }
// GetTimeout returns u.Timeout.
func (u *staticUpstream) GetTimeout() time.Duration {
return u.Timeout
}
func (u *staticUpstream) GetHostCount() int { func (u *staticUpstream) GetHostCount() int {
return len(u.Hosts) return len(u.Hosts)
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment