Commit 5a6b7656 authored by ericdreeves's avatar ericdreeves Committed by Matt Holt

Add connect_timeout and read_timeout to fastcgi. (#1257)

parent 8acf0432
...@@ -4,6 +4,7 @@ import ( ...@@ -4,6 +4,7 @@ import (
"errors" "errors"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time"
) )
type dialer interface { type dialer interface {
...@@ -13,10 +14,12 @@ type dialer interface { ...@@ -13,10 +14,12 @@ type dialer interface {
// basicDialer is a basic dialer that wraps default fcgi functions. // basicDialer is a basic dialer that wraps default fcgi functions.
type basicDialer struct { type basicDialer struct {
network, address string network string
address string
timeout time.Duration
} }
func (b basicDialer) Dial() (Client, error) { return Dial(b.network, b.address) } func (b basicDialer) Dial() (Client, error) { return Dial(b.network, b.address, b.timeout) }
func (b basicDialer) Close(c Client) error { return c.Close() } func (b basicDialer) Close(c Client) error { return c.Close() }
// persistentDialer keeps a pool of fcgi connections. // persistentDialer keeps a pool of fcgi connections.
...@@ -25,6 +28,7 @@ type persistentDialer struct { ...@@ -25,6 +28,7 @@ type persistentDialer struct {
size int size int
network string network string
address string address string
timeout time.Duration
pool []Client pool []Client
sync.Mutex sync.Mutex
} }
...@@ -43,7 +47,7 @@ func (p *persistentDialer) Dial() (Client, error) { ...@@ -43,7 +47,7 @@ func (p *persistentDialer) Dial() (Client, error) {
p.Unlock() p.Unlock()
// no connection available, create new one // no connection available, create new one
return Dial(p.network, p.address) return Dial(p.network, p.address, p.timeout)
} }
func (p *persistentDialer) Close(client Client) error { func (p *persistentDialer) Close(client Client) error {
......
...@@ -12,6 +12,7 @@ import ( ...@@ -12,6 +12,7 @@ import (
"path/filepath" "path/filepath"
"strconv" "strconv"
"strings" "strings"
"time"
"github.com/mholt/caddy/caddyhttp/httpserver" "github.com/mholt/caddy/caddyhttp/httpserver"
) )
...@@ -81,6 +82,7 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) ...@@ -81,6 +82,7 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error)
if err != nil { if err != nil {
return http.StatusBadGateway, err return http.StatusBadGateway, err
} }
fcgiBackend.SetReadTimeout(rule.ReadTimeout)
var resp *http.Response var resp *http.Response
contentLength, _ := strconv.Atoi(r.Header.Get("Content-Length")) contentLength, _ := strconv.Atoi(r.Header.Get("Content-Length"))
...@@ -301,6 +303,9 @@ type Rule struct { ...@@ -301,6 +303,9 @@ type Rule struct {
// Ignored paths // Ignored paths
IgnoredSubPaths []string IgnoredSubPaths []string
// The duration used to set a deadline when reading from the FastCGI server.
ReadTimeout time.Duration
// FCGI dialer // FCGI dialer
dialer dialer dialer dialer
} }
......
...@@ -10,6 +10,7 @@ import ( ...@@ -10,6 +10,7 @@ import (
"strings" "strings"
"sync" "sync"
"testing" "testing"
"time"
) )
func TestServeHTTP(t *testing.T) { func TestServeHTTP(t *testing.T) {
...@@ -29,7 +30,13 @@ func TestServeHTTP(t *testing.T) { ...@@ -29,7 +30,13 @@ func TestServeHTTP(t *testing.T) {
network, address := parseAddress(listener.Addr().String()) network, address := parseAddress(listener.Addr().String())
handler := Handler{ handler := Handler{
Next: nil, Next: nil,
Rules: []Rule{{Path: "/", Address: listener.Addr().String(), dialer: basicDialer{network, address}}}, Rules: []Rule{
{
Path: "/",
Address: listener.Addr().String(),
dialer: basicDialer{network: network, address: address},
},
},
} }
r, err := http.NewRequest("GET", "/", nil) r, err := http.NewRequest("GET", "/", nil)
if err != nil { if err != nil {
...@@ -318,3 +325,39 @@ func TestBuildEnv(t *testing.T) { ...@@ -318,3 +325,39 @@ func TestBuildEnv(t *testing.T) {
} }
} }
func TestReadTimeout(t *testing.T) {
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("Unable to create listener for test: %v", err)
}
defer listener.Close()
go fcgi.Serve(listener, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(time.Second * 1)
}))
network, address := parseAddress(listener.Addr().String())
handler := Handler{
Next: nil,
Rules: []Rule{
{
Path: "/",
Address: listener.Addr().String(),
dialer: basicDialer{network: network, address: address},
ReadTimeout: time.Millisecond * 100,
},
},
}
r, err := http.NewRequest("GET", "/", nil)
if err != nil {
t.Fatalf("Unable to create request: %v", err)
}
w := httptest.NewRecorder()
_, err = handler.ServeHTTP(w, r)
if err == nil {
t.Error("Expected i/o timeout error but had none")
} else if err, ok := err.(net.Error); !ok || !err.Timeout() {
t.Errorf("Expected i/o timeout error, got: '%s'", err.Error())
}
}
...@@ -15,6 +15,7 @@ import ( ...@@ -15,6 +15,7 @@ import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"mime/multipart" "mime/multipart"
...@@ -28,6 +29,7 @@ import ( ...@@ -28,6 +29,7 @@ import (
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"time"
) )
// FCGIListenSockFileno describes listen socket file number. // FCGIListenSockFileno describes listen socket file number.
...@@ -114,6 +116,8 @@ type Client interface { ...@@ -114,6 +116,8 @@ type Client interface {
Post(pairs map[string]string, method string, bodyType string, body io.Reader, contentLength int) (response *http.Response, err error) Post(pairs map[string]string, method string, bodyType string, body io.Reader, contentLength int) (response *http.Response, err error)
Close() error Close() error
StdErr() bytes.Buffer StdErr() bytes.Buffer
ReadTimeout() time.Duration
SetReadTimeout(time.Duration) error
} }
type header struct { type header struct {
...@@ -176,6 +180,7 @@ type FCGIClient struct { ...@@ -176,6 +180,7 @@ type FCGIClient struct {
stderr bytes.Buffer stderr bytes.Buffer
keepAlive bool keepAlive bool
reqID uint16 reqID uint16
readTimeout time.Duration
} }
// DialWithDialer connects to the fcgi responder at the specified network address, using custom net.Dialer. // DialWithDialer connects to the fcgi responder at the specified network address, using custom net.Dialer.
...@@ -198,8 +203,8 @@ func DialWithDialer(network, address string, dialer net.Dialer) (fcgi *FCGIClien ...@@ -198,8 +203,8 @@ func DialWithDialer(network, address string, dialer net.Dialer) (fcgi *FCGIClien
// Dial connects to the fcgi responder at the specified network address, using default net.Dialer. // Dial connects to the fcgi responder at the specified network address, using default net.Dialer.
// See func net.Dial for a description of the network and address parameters. // See func net.Dial for a description of the network and address parameters.
func Dial(network, address string) (fcgi *FCGIClient, err error) { func Dial(network string, address string, timeout time.Duration) (fcgi *FCGIClient, err error) {
return DialWithDialer(network, address, net.Dialer{}) return DialWithDialer(network, address, net.Dialer{Timeout: timeout})
} }
// Close closes fcgi connnection. // Close closes fcgi connnection.
...@@ -350,6 +355,15 @@ func (w *streamReader) Read(p []byte) (n int, err error) { ...@@ -350,6 +355,15 @@ func (w *streamReader) Read(p []byte) (n int, err error) {
for { for {
rec := &record{} rec := &record{}
var buf []byte var buf []byte
if readTimeout := w.c.ReadTimeout(); readTimeout > 0 {
conn, ok := w.c.rwc.(net.Conn)
if ok {
conn.SetReadDeadline(time.Now().Add(readTimeout))
} else {
err = fmt.Errorf("Could not set Client ReadTimeout")
return
}
}
buf, err = rec.read(w.c.rwc) buf, err = rec.read(w.c.rwc)
if err == errInvalidHeaderVersion { if err == errInvalidHeaderVersion {
continue continue
...@@ -559,6 +573,17 @@ func (c *FCGIClient) PostFile(p map[string]string, data url.Values, file map[str ...@@ -559,6 +573,17 @@ func (c *FCGIClient) PostFile(p map[string]string, data url.Values, file map[str
return c.Post(p, "POST", bodyType, buf, buf.Len()) return c.Post(p, "POST", bodyType, buf, buf.Len())
} }
// ReadTimeout returns the read timeout for future calls that read from the
// fcgi responder.
func (c *FCGIClient) ReadTimeout() time.Duration { return c.readTimeout }
// SetReadTimeout sets the read timeout for future calls that read from the
// fcgi responder. A zero value for t means no timeout will be set.
func (c *FCGIClient) SetReadTimeout(t time.Duration) error {
c.readTimeout = t
return nil
}
// Checks whether chunked is part of the encodings stack // Checks whether chunked is part of the encodings stack
func chunked(te []string) bool { return len(te) > 0 && te[0] == "chunked" } func chunked(te []string) bool { return len(te) > 0 && te[0] == "chunked" }
......
...@@ -103,7 +103,7 @@ func (s FastCGIServer) ServeHTTP(resp http.ResponseWriter, req *http.Request) { ...@@ -103,7 +103,7 @@ func (s FastCGIServer) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
} }
func sendFcgi(reqType int, fcgiParams map[string]string, data []byte, posts map[string]string, files map[string]string) (content []byte) { func sendFcgi(reqType int, fcgiParams map[string]string, data []byte, posts map[string]string, files map[string]string) (content []byte) {
fcgi, err := Dial("tcp", ipPort) fcgi, err := Dial("tcp", ipPort, 0)
if err != nil { if err != nil {
log.Println("err:", err) log.Println("err:", err)
return return
......
...@@ -6,6 +6,7 @@ import ( ...@@ -6,6 +6,7 @@ import (
"path/filepath" "path/filepath"
"strconv" "strconv"
"strings" "strings"
"time"
"github.com/mholt/caddy" "github.com/mholt/caddy"
"github.com/mholt/caddy/caddyhttp/httpserver" "github.com/mholt/caddy/caddyhttp/httpserver"
...@@ -69,6 +70,9 @@ func fastcgiParse(c *caddy.Controller) ([]Rule, error) { ...@@ -69,6 +70,9 @@ func fastcgiParse(c *caddy.Controller) ([]Rule, error) {
} }
} }
var err error
var pool int
var timeout time.Duration
var dialers []dialer var dialers []dialer
var poolSize = -1 var poolSize = -1
...@@ -116,7 +120,7 @@ func fastcgiParse(c *caddy.Controller) ([]Rule, error) { ...@@ -116,7 +120,7 @@ func fastcgiParse(c *caddy.Controller) ([]Rule, error) {
if !c.NextArg() { if !c.NextArg() {
return rules, c.ArgErr() return rules, c.ArgErr()
} }
pool, err := strconv.Atoi(c.Val()) pool, err = strconv.Atoi(c.Val())
if err != nil { if err != nil {
return rules, err return rules, err
} }
...@@ -125,15 +129,32 @@ func fastcgiParse(c *caddy.Controller) ([]Rule, error) { ...@@ -125,15 +129,32 @@ func fastcgiParse(c *caddy.Controller) ([]Rule, error) {
} else { } else {
return rules, c.Errf("positive integer expected, found %d", pool) return rules, c.Errf("positive integer expected, found %d", pool)
} }
case "connect_timeout":
if !c.NextArg() {
return rules, c.ArgErr()
}
timeout, err = time.ParseDuration(c.Val())
if err != nil {
return rules, err
}
case "read_timeout":
if !c.NextArg() {
return rules, c.ArgErr()
}
readTimeout, err := time.ParseDuration(c.Val())
if err != nil {
return rules, err
}
rule.ReadTimeout = readTimeout
} }
} }
for _, rawAddress := range upstreams { for _, rawAddress := range upstreams {
network, address := parseAddress(rawAddress) network, address := parseAddress(rawAddress)
if poolSize >= 0 { if poolSize >= 0 {
dialers = append(dialers, &persistentDialer{size: poolSize, network: network, address: address}) dialers = append(dialers, &persistentDialer{size: poolSize, network: network, address: address, timeout: timeout})
} else { } else {
dialers = append(dialers, basicDialer{network: network, address: address}) dialers = append(dialers, basicDialer{network: network, address: address, timeout: timeout})
} }
} }
......
...@@ -4,6 +4,7 @@ import ( ...@@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"reflect" "reflect"
"testing" "testing"
"time"
"github.com/mholt/caddy" "github.com/mholt/caddy"
"github.com/mholt/caddy/caddyhttp/httpserver" "github.com/mholt/caddy/caddyhttp/httpserver"
...@@ -159,6 +160,29 @@ func TestFastcgiParse(t *testing.T) { ...@@ -159,6 +160,29 @@ func TestFastcgiParse(t *testing.T) {
dialer: &loadBalancingDialer{dialers: []dialer{basicDialer{network: network, address: address}}}, dialer: &loadBalancingDialer{dialers: []dialer{basicDialer{network: network, address: address}}},
IndexFiles: []string{}, IndexFiles: []string{},
}}}, }}},
{`fastcgi / ` + defaultAddress + ` {
connect_timeout 5s
}`,
false, []Rule{{
Path: "/",
Address: defaultAddress,
Ext: "",
SplitPath: "",
dialer: &loadBalancingDialer{dialers: []dialer{basicDialer{network: network, address: address, timeout: 5 * time.Second}}},
IndexFiles: []string{},
}}},
{`fastcgi / ` + defaultAddress + ` {
read_timeout 5s
}`,
false, []Rule{{
Path: "/",
Address: defaultAddress,
Ext: "",
SplitPath: "",
dialer: &loadBalancingDialer{dialers: []dialer{basicDialer{network: network, address: address}}},
IndexFiles: []string{},
ReadTimeout: 5 * time.Second,
}}},
{`fastcgi / { {`fastcgi / {
}`, }`,
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment