Commit 98de336a authored by Tanmay Chaudhry's avatar Tanmay Chaudhry Committed by Matt Holt

proxy: Enabled configurable timeout (#2070)

* Enabled configurable Timeout for the proxy directive

* Added Test for reverse for proxy timeout

* Removed Duplication in proxy constructors

* Remove indirection from multiple constructors and refactor into one

* Fix inconsistent error message and refactor dialer initialization
parent 9fe2ef41
......@@ -58,6 +58,10 @@ type Upstream interface {
// Gets the number of upstream hosts.
GetHostCount() int
// Gets how long to wait before timing out
// the request
GetTimeout() time.Duration
// Stops the upstream from proxying requests to shutdown goroutines cleanly.
Stop() error
}
......@@ -187,7 +191,11 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
if nameURL, err := url.Parse(host.Name); err == nil {
outreq.Host = nameURL.Host
if proxy == nil {
proxy = NewSingleHostReverseProxy(nameURL, host.WithoutPathPrefix, http.DefaultMaxIdleConnsPerHost)
proxy = NewSingleHostReverseProxy(nameURL,
host.WithoutPathPrefix,
http.DefaultMaxIdleConnsPerHost,
upstream.GetTimeout(),
)
}
// use upstream credentials by default
......
This diff is collapsed.
......@@ -94,6 +94,10 @@ type ReverseProxy struct {
// If zero, no periodic flushing is done.
FlushInterval time.Duration
// dialer is used when values from the
// defaultDialer need to be overridden per Proxy
dialer *net.Dialer
srvResolver srvResolver
}
......@@ -103,13 +107,13 @@ type ReverseProxy struct {
// What we need is just the path, so if "unix:/var/run/www.socket"
// was the proxy directive, the parsed hostName would be
// "unix:///var/run/www.socket", hence the ambiguous trimming.
func socketDial(hostName string) func(network, addr string) (conn net.Conn, err error) {
func socketDial(hostName string, timeout time.Duration) func(network, addr string) (conn net.Conn, err error) {
return func(network, addr string) (conn net.Conn, err error) {
return net.Dial("unix", hostName[len("unix://"):])
return net.DialTimeout("unix", hostName[len("unix://"):], timeout)
}
}
func (rp *ReverseProxy) srvDialerFunc(locator string) func(network, addr string) (conn net.Conn, err error) {
func (rp *ReverseProxy) srvDialerFunc(locator string, timeout time.Duration) func(network, addr string) (conn net.Conn, err error) {
service := locator
if strings.HasPrefix(locator, "srv://") {
service = locator[6:]
......@@ -122,7 +126,7 @@ func (rp *ReverseProxy) srvDialerFunc(locator string) func(network, addr string)
if err != nil {
return nil, err
}
return net.Dial("tcp", fmt.Sprintf("%s:%d", addrs[0].Target, addrs[0].Port))
return net.DialTimeout("tcp", fmt.Sprintf("%s:%d", addrs[0].Target, addrs[0].Port), timeout)
}
}
......@@ -144,7 +148,7 @@ func singleJoiningSlash(a, b string) string {
// the target request will be for /base/dir.
// Without logic: target's path is "/", incoming is "/api/messages",
// without is "/api", then the target request will be for /messages.
func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *ReverseProxy {
func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int, timeout time.Duration) *ReverseProxy {
targetQuery := target.RawQuery
director := func(req *http.Request) {
if target.Scheme == "unix" {
......@@ -226,15 +230,21 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *
}
}
dialer := *defaultDialer
if timeout != defaultDialer.Timeout {
dialer.Timeout = timeout
}
rp := &ReverseProxy{
Director: director,
FlushInterval: 250 * time.Millisecond, // flushing good for streaming & server-sent events
srvResolver: net.DefaultResolver,
dialer: &dialer,
}
if target.Scheme == "unix" {
rp.Transport = &http.Transport{
Dial: socketDial(target.String()),
Dial: socketDial(target.String(), timeout),
}
} else if target.Scheme == "quic" {
rp.Transport = &h2quic.RoundTripper{
......@@ -244,9 +254,9 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *
},
}
} else if keepalive != http.DefaultMaxIdleConnsPerHost || strings.HasPrefix(target.Scheme, "srv") {
dialFunc := defaultDialer.Dial
dialFunc := rp.dialer.Dial
if strings.HasPrefix(target.Scheme, "srv") {
dialFunc = rp.srvDialerFunc(target.String())
dialFunc = rp.srvDialerFunc(target.String(), timeout)
}
transport := &http.Transport{
......@@ -275,7 +285,7 @@ func (rp *ReverseProxy) UseInsecureTransport() {
if rp.Transport == nil {
transport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
Dial: defaultDialer.Dial,
Dial: rp.dialer.Dial,
TLSHandshakeTimeout: defaultCryptoHandshakeTimeout,
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
......@@ -306,7 +316,9 @@ func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request,
if requestIsWebsocket(outreq) {
transport = newConnHijackerTransport(transport)
} else if transport == nil {
transport = http.DefaultTransport
transport = &http.Transport{
Dial: rp.dialer.Dial,
}
}
rp.Director(outreq)
......@@ -361,7 +373,7 @@ func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request,
}
bufferPool.Put(hj.Replay)
} else {
backendConn, err = net.Dial("tcp", outreq.URL.Host)
backendConn, err = net.DialTimeout("tcp", outreq.URL.Host, rp.dialer.Timeout)
if err != nil {
return err
}
......
......@@ -21,6 +21,7 @@ import (
"net/url"
"strconv"
"testing"
"time"
)
const (
......@@ -66,7 +67,7 @@ func TestSingleSRVHostReverseProxy(t *testing.T) {
}
port := uint16(pp)
rp := NewSingleHostReverseProxy(target, "", http.DefaultMaxIdleConnsPerHost)
rp := NewSingleHostReverseProxy(target, "", http.DefaultMaxIdleConnsPerHost, 30*time.Second)
rp.srvResolver = testResolver{
result: []*net.SRV{
{Target: upstream.Hostname(), Port: port, Priority: 1, Weight: 1},
......
......@@ -49,6 +49,7 @@ type staticUpstream struct {
Hosts HostPool
Policy Policy
KeepAlive int
Timeout time.Duration
FailTimeout time.Duration
TryDuration time.Duration
TryInterval time.Duration
......@@ -92,6 +93,7 @@ func NewStaticUpstreams(c caddyfile.Dispenser, host string) ([]Upstream, error)
TryInterval: 250 * time.Millisecond,
MaxConns: 0,
KeepAlive: http.DefaultMaxIdleConnsPerHost,
Timeout: 30 * time.Second,
resolver: net.DefaultResolver,
}
......@@ -225,7 +227,7 @@ func (u *staticUpstream) NewHost(host string) (*UpstreamHost, error) {
return nil, err
}
uh.ReverseProxy = NewSingleHostReverseProxy(baseURL, uh.WithoutPathPrefix, u.KeepAlive)
uh.ReverseProxy = NewSingleHostReverseProxy(baseURL, uh.WithoutPathPrefix, u.KeepAlive, u.Timeout)
if u.insecureSkipVerify {
uh.ReverseProxy.UseInsecureTransport()
}
......@@ -464,6 +466,15 @@ func parseBlock(c *caddyfile.Dispenser, u *staticUpstream, hasSrv bool) error {
return c.ArgErr()
}
u.KeepAlive = n
case "timeout":
if !c.NextArg() {
return c.ArgErr()
}
dur, err := time.ParseDuration(c.Val())
if err != nil {
return c.Errf("unable to parse timeout duration '%s'", c.Val())
}
u.Timeout = dur
default:
return c.Errf("unknown property '%s'", c.Val())
}
......@@ -619,6 +630,11 @@ func (u *staticUpstream) GetTryInterval() time.Duration {
return u.TryInterval
}
// GetTimeout returns u.Timeout.
func (u *staticUpstream) GetTimeout() time.Duration {
return u.Timeout
}
func (u *staticUpstream) GetHostCount() int {
return len(u.Hosts)
}
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment