Commit c972ea39 authored by Mateusz Gajewski's avatar Mateusz Gajewski Committed by Matt Holt

Fastcgi upstreams (#1264)

* Make fastcgi load balanceable too

* Address one more corner case - invalid configuration fastcgi /

* After review fixes

* Simplify conditions

* Error message

* New fastcgi syntax

* golint will be happy

* Change syntax
parent 12fd3499
package fastcgi package fastcgi
import "sync" import (
"errors"
"sync"
"sync/atomic"
)
type dialer interface { type dialer interface {
Dial() (*FCGIClient, error) Dial() (Client, error)
Close(*FCGIClient) error Close(Client) error
} }
// basicDialer is a basic dialer that wraps default fcgi functions. // basicDialer is a basic dialer that wraps default fcgi functions.
...@@ -12,8 +16,8 @@ type basicDialer struct { ...@@ -12,8 +16,8 @@ type basicDialer struct {
network, address string network, address string
} }
func (b basicDialer) Dial() (*FCGIClient, error) { return Dial(b.network, b.address) } func (b basicDialer) Dial() (Client, error) { return Dial(b.network, b.address) }
func (b basicDialer) Close(c *FCGIClient) error { return c.Close() } func (b basicDialer) Close(c Client) error { return c.Close() }
// persistentDialer keeps a pool of fcgi connections. // persistentDialer keeps a pool of fcgi connections.
// connections are not closed after use, rather added back to the pool for reuse. // connections are not closed after use, rather added back to the pool for reuse.
...@@ -21,11 +25,11 @@ type persistentDialer struct { ...@@ -21,11 +25,11 @@ type persistentDialer struct {
size int size int
network string network string
address string address string
pool []*FCGIClient pool []Client
sync.Mutex sync.Mutex
} }
func (p *persistentDialer) Dial() (*FCGIClient, error) { func (p *persistentDialer) Dial() (Client, error) {
p.Lock() p.Lock()
// connection is available, return first one. // connection is available, return first one.
if len(p.pool) > 0 { if len(p.pool) > 0 {
...@@ -42,7 +46,7 @@ func (p *persistentDialer) Dial() (*FCGIClient, error) { ...@@ -42,7 +46,7 @@ func (p *persistentDialer) Dial() (*FCGIClient, error) {
return Dial(p.network, p.address) return Dial(p.network, p.address)
} }
func (p *persistentDialer) Close(client *FCGIClient) error { func (p *persistentDialer) Close(client Client) error {
p.Lock() p.Lock()
if len(p.pool) < p.size { if len(p.pool) < p.size {
// pool is not full yet, add connection for reuse // pool is not full yet, add connection for reuse
...@@ -57,3 +61,35 @@ func (p *persistentDialer) Close(client *FCGIClient) error { ...@@ -57,3 +61,35 @@ func (p *persistentDialer) Close(client *FCGIClient) error {
// otherwise, close the connection. // otherwise, close the connection.
return client.Close() return client.Close()
} }
type loadBalancingDialer struct {
dialers []dialer
current int64
}
func (m *loadBalancingDialer) Dial() (Client, error) {
nextDialerIndex := atomic.AddInt64(&m.current, 1) % int64(len(m.dialers))
currentDialer := m.dialers[nextDialerIndex]
client, err := currentDialer.Dial()
if err != nil {
return nil, err
}
return &dialerAwareClient{Client: client, dialer: currentDialer}, nil
}
func (m *loadBalancingDialer) Close(c Client) error {
// Close the client according to dialer behaviour
if da, ok := c.(*dialerAwareClient); ok {
return da.dialer.Close(c)
}
return errors.New("Cannot close client")
}
type dialerAwareClient struct {
Client
dialer dialer
}
package fastcgi
import (
"errors"
"testing"
)
func TestLoadbalancingDialer(t *testing.T) {
// given
runs := 100
mockDialer1 := new(mockDialer)
mockDialer2 := new(mockDialer)
dialer := &loadBalancingDialer{dialers: []dialer{mockDialer1, mockDialer2}}
// when
for i := 0; i < runs; i++ {
client, err := dialer.Dial()
dialer.Close(client)
if err != nil {
t.Errorf("Expected error to be nil")
}
}
// then
if mockDialer1.dialCalled != mockDialer2.dialCalled && mockDialer1.dialCalled != 50 {
t.Errorf("Expected dialer to call Dial() on multiple backend dialers %d times [actual: %d, %d]", 50, mockDialer1.dialCalled, mockDialer2.dialCalled)
}
if mockDialer1.closeCalled != mockDialer2.closeCalled && mockDialer1.closeCalled != 50 {
t.Errorf("Expected dialer to call Close() on multiple backend dialers %d times [actual: %d, %d]", 50, mockDialer1.closeCalled, mockDialer2.closeCalled)
}
}
func TestLoadBalancingDialerShouldReturnDialerAwareClient(t *testing.T) {
// given
mockDialer1 := new(mockDialer)
dialer := &loadBalancingDialer{dialers: []dialer{mockDialer1}}
// when
client, err := dialer.Dial()
// then
if err != nil {
t.Errorf("Expected error to be nil")
}
if awareClient, ok := client.(*dialerAwareClient); !ok {
t.Error("Expected dialer to wrap client")
} else {
if awareClient.dialer != mockDialer1 {
t.Error("Expected wrapped client to have reference to dialer")
}
}
}
func TestLoadBalancingDialerShouldUnderlyingReturnDialerError(t *testing.T) {
// given
mockDialer1 := new(errorReturningDialer)
dialer := &loadBalancingDialer{dialers: []dialer{mockDialer1}}
// when
_, err := dialer.Dial()
// then
if err.Error() != "Error during dial" {
t.Errorf("Expected 'Error during dial', got: '%s'", err.Error())
}
}
func TestLoadBalancingDialerShouldCloseClient(t *testing.T) {
// given
mockDialer1 := new(mockDialer)
mockDialer2 := new(mockDialer)
dialer := &loadBalancingDialer{dialers: []dialer{mockDialer1, mockDialer2}}
client, _ := dialer.Dial()
// when
err := dialer.Close(client)
// then
if err != nil {
t.Error("Expected error not to occur")
}
// load balancing starts from index 1
if mockDialer2.client != client {
t.Errorf("Expected Close() to be called on referenced dialer")
}
}
type mockDialer struct {
dialCalled int
closeCalled int
client Client
}
type mockClient struct {
Client
}
func (m *mockDialer) Dial() (Client, error) {
m.dialCalled++
return mockClient{Client: &FCGIClient{}}, nil
}
func (m *mockDialer) Close(c Client) error {
m.client = c
m.closeCalled++
return nil
}
type errorReturningDialer struct {
client Client
}
func (m *errorReturningDialer) Dial() (Client, error) {
return mockClient{Client: &FCGIClient{}}, errors.New("Error during dial")
}
func (m *errorReturningDialer) Close(c Client) error {
m.client = c
return errors.New("Error during close")
}
...@@ -111,9 +111,9 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) ...@@ -111,9 +111,9 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error)
defer rule.dialer.Close(fcgiBackend) defer rule.dialer.Close(fcgiBackend)
// Log any stderr output from upstream // Log any stderr output from upstream
if fcgiBackend.stderr.Len() != 0 { if stderr := fcgiBackend.StdErr(); stderr.Len() != 0 {
// Remove trailing newline, error logger already does this. // Remove trailing newline, error logger already does this.
err = LogError(strings.TrimSuffix(fcgiBackend.stderr.String(), "\n")) err = LogError(strings.TrimSuffix(stderr.String(), "\n"))
} }
// Normally we would return the status code if it is an error status (>= 400), // Normally we would return the status code if it is an error status (>= 400),
......
...@@ -106,6 +106,16 @@ const ( ...@@ -106,6 +106,16 @@ const (
maxPad = 255 maxPad = 255
) )
// Client interface
type Client interface {
Get(pair map[string]string) (response *http.Response, err error)
Head(pair map[string]string) (response *http.Response, err error)
Options(pairs map[string]string) (response *http.Response, err error)
Post(pairs map[string]string, method string, bodyType string, body io.Reader, contentLength int) (response *http.Response, err error)
Close() error
StdErr() bytes.Buffer
}
type header struct { type header struct {
Version uint8 Version uint8
Type uint8 Type uint8
...@@ -197,22 +207,29 @@ func (c *FCGIClient) Close() error { ...@@ -197,22 +207,29 @@ func (c *FCGIClient) Close() error {
return c.rwc.Close() return c.rwc.Close()
} }
func (c *FCGIClient) writeRecord(recType uint8, content []byte) (err error) { func (c *FCGIClient) writeRecord(recType uint8, content []byte) error {
c.mutex.Lock() c.mutex.Lock()
defer c.mutex.Unlock() defer c.mutex.Unlock()
c.buf.Reset() c.buf.Reset()
c.h.init(recType, c.reqID, len(content)) c.h.init(recType, c.reqID, len(content))
if err := binary.Write(&c.buf, binary.BigEndian, c.h); err != nil { if err := binary.Write(&c.buf, binary.BigEndian, c.h); err != nil {
return err return err
} }
if _, err := c.buf.Write(content); err != nil { if _, err := c.buf.Write(content); err != nil {
return err return err
} }
if _, err := c.buf.Write(pad[:c.h.PaddingLength]); err != nil { if _, err := c.buf.Write(pad[:c.h.PaddingLength]); err != nil {
return err return err
} }
_, err = c.rwc.Write(c.buf.Bytes())
if _, err := c.rwc.Write(c.buf.Bytes()); err != nil {
return err return err
}
return nil
} }
func (c *FCGIClient) writeBeginRequest(role uint16, flags uint8) error { func (c *FCGIClient) writeBeginRequest(role uint16, flags uint8) error {
...@@ -360,6 +377,11 @@ func (w *streamReader) Read(p []byte) (n int, err error) { ...@@ -360,6 +377,11 @@ func (w *streamReader) Read(p []byte) (n int, err error) {
return return
} }
// StdErr returns stderr stream
func (c *FCGIClient) StdErr() bytes.Buffer {
return c.stderr
}
// Do made the request and returns a io.Reader that translates the data read // Do made the request and returns a io.Reader that translates the data read
// from fcgi responder out of fcgi packet before returning it. // from fcgi responder out of fcgi packet before returning it.
func (c *FCGIClient) Do(p map[string]string, req io.Reader) (r io.Reader, err error) { func (c *FCGIClient) Do(p map[string]string, req io.Reader) (r io.Reader, err error) {
......
...@@ -5,6 +5,7 @@ import ( ...@@ -5,6 +5,7 @@ import (
"net/http" "net/http"
"path/filepath" "path/filepath"
"strconv" "strconv"
"strings"
"github.com/mholt/caddy" "github.com/mholt/caddy"
"github.com/mholt/caddy/caddyhttp/httpserver" "github.com/mholt/caddy/caddyhttp/httpserver"
...@@ -55,26 +56,21 @@ func fastcgiParse(c *caddy.Controller) ([]Rule, error) { ...@@ -55,26 +56,21 @@ func fastcgiParse(c *caddy.Controller) ([]Rule, error) {
args := c.RemainingArgs() args := c.RemainingArgs()
switch len(args) { if len(args) < 2 || len(args) > 3 {
case 0:
return rules, c.ArgErr() return rules, c.ArgErr()
case 1: }
rule.Path = "/"
rule.Address = args[0]
case 2:
rule.Path = args[0]
rule.Address = args[1]
case 3:
rule.Path = args[0] rule.Path = args[0]
rule.Address = args[1] upstreams := []string{args[1]}
err := fastcgiPreset(args[2], &rule)
if err != nil { if len(args) == 3 {
return rules, c.Err("Invalid fastcgi rule preset '" + args[2] + "'") if err := fastcgiPreset(args[2], &rule); err != nil {
return rules, err
} }
} }
network, address := parseAddress(rule.Address) var dialers []dialer
rule.dialer = basicDialer{network: network, address: address} var poolSize = -1
for c.NextBlock() { for c.NextBlock() {
switch c.Val() { switch c.Val() {
...@@ -94,6 +90,15 @@ func fastcgiParse(c *caddy.Controller) ([]Rule, error) { ...@@ -94,6 +90,15 @@ func fastcgiParse(c *caddy.Controller) ([]Rule, error) {
return rules, c.ArgErr() return rules, c.ArgErr()
} }
rule.IndexFiles = args rule.IndexFiles = args
case "upstream":
args := c.RemainingArgs()
if len(args) != 1 {
return rules, c.ArgErr()
}
upstreams = append(upstreams, args[0])
case "env": case "env":
envArgs := c.RemainingArgs() envArgs := c.RemainingArgs()
if len(envArgs) < 2 { if len(envArgs) < 2 {
...@@ -106,6 +111,7 @@ func fastcgiParse(c *caddy.Controller) ([]Rule, error) { ...@@ -106,6 +111,7 @@ func fastcgiParse(c *caddy.Controller) ([]Rule, error) {
return rules, c.ArgErr() return rules, c.ArgErr()
} }
rule.IgnoredSubPaths = ignoredPaths rule.IgnoredSubPaths = ignoredPaths
case "pool": case "pool":
if !c.NextArg() { if !c.NextArg() {
return rules, c.ArgErr() return rules, c.ArgErr()
...@@ -115,13 +121,24 @@ func fastcgiParse(c *caddy.Controller) ([]Rule, error) { ...@@ -115,13 +121,24 @@ func fastcgiParse(c *caddy.Controller) ([]Rule, error) {
return rules, err return rules, err
} }
if pool >= 0 { if pool >= 0 {
rule.dialer = &persistentDialer{size: pool, network: network, address: address} poolSize = pool
} else { } else {
return rules, c.Errf("positive integer expected, found %d", pool) return rules, c.Errf("positive integer expected, found %d", pool)
} }
} }
} }
for _, rawAddress := range upstreams {
network, address := parseAddress(rawAddress)
if poolSize >= 0 {
dialers = append(dialers, &persistentDialer{size: poolSize, network: network, address: address})
} else {
dialers = append(dialers, basicDialer{network: network, address: address})
}
}
rule.dialer = &loadBalancingDialer{dialers: dialers}
rule.Address = strings.Join(upstreams, ",")
rules = append(rules, rule) rules = append(rules, rule)
} }
......
...@@ -76,9 +76,31 @@ func TestFastcgiParse(t *testing.T) { ...@@ -76,9 +76,31 @@ func TestFastcgiParse(t *testing.T) {
Address: "127.0.0.1:9000", Address: "127.0.0.1:9000",
Ext: ".php", Ext: ".php",
SplitPath: ".php", SplitPath: ".php",
dialer: basicDialer{network: "tcp", address: "127.0.0.1:9000"}, dialer: &loadBalancingDialer{dialers: []dialer{basicDialer{network: "tcp", address: "127.0.0.1:9000"}}},
IndexFiles: []string{"index.php"}, IndexFiles: []string{"index.php"},
}}}, }}},
{`fastcgi /blog 127.0.0.1:9000 php {
upstream 127.0.0.1:9001
}`,
false, []Rule{{
Path: "/blog",
Address: "127.0.0.1:9000,127.0.0.1:9001",
Ext: ".php",
SplitPath: ".php",
dialer: &loadBalancingDialer{dialers: []dialer{basicDialer{network: "tcp", address: "127.0.0.1:9000"}, basicDialer{network: "tcp", address: "127.0.0.1:9001"}}},
IndexFiles: []string{"index.php"},
}}},
{`fastcgi /blog 127.0.0.1:9000 {
upstream 127.0.0.1:9001
}`,
false, []Rule{{
Path: "/blog",
Address: "127.0.0.1:9000,127.0.0.1:9001",
Ext: "",
SplitPath: "",
dialer: &loadBalancingDialer{dialers: []dialer{basicDialer{network: "tcp", address: "127.0.0.1:9000"}, basicDialer{network: "tcp", address: "127.0.0.1:9001"}}},
IndexFiles: []string{},
}}},
{`fastcgi / ` + defaultAddress + ` { {`fastcgi / ` + defaultAddress + ` {
split .html split .html
}`, }`,
...@@ -87,7 +109,7 @@ func TestFastcgiParse(t *testing.T) { ...@@ -87,7 +109,7 @@ func TestFastcgiParse(t *testing.T) {
Address: defaultAddress, Address: defaultAddress,
Ext: "", Ext: "",
SplitPath: ".html", SplitPath: ".html",
dialer: basicDialer{network: network, address: address}, dialer: &loadBalancingDialer{dialers: []dialer{basicDialer{network: network, address: address}}},
IndexFiles: []string{}, IndexFiles: []string{},
}}}, }}},
{`fastcgi / ` + defaultAddress + ` { {`fastcgi / ` + defaultAddress + ` {
...@@ -99,7 +121,7 @@ func TestFastcgiParse(t *testing.T) { ...@@ -99,7 +121,7 @@ func TestFastcgiParse(t *testing.T) {
Address: "127.0.0.1:9001", Address: "127.0.0.1:9001",
Ext: "", Ext: "",
SplitPath: ".html", SplitPath: ".html",
dialer: basicDialer{network: network, address: address}, dialer: &loadBalancingDialer{dialers: []dialer{basicDialer{network: network, address: address}}},
IndexFiles: []string{}, IndexFiles: []string{},
IgnoredSubPaths: []string{"/admin", "/user"}, IgnoredSubPaths: []string{"/admin", "/user"},
}}}, }}},
...@@ -111,18 +133,19 @@ func TestFastcgiParse(t *testing.T) { ...@@ -111,18 +133,19 @@ func TestFastcgiParse(t *testing.T) {
Address: defaultAddress, Address: defaultAddress,
Ext: "", Ext: "",
SplitPath: "", SplitPath: "",
dialer: &persistentDialer{size: 0, network: network, address: address}, dialer: &loadBalancingDialer{dialers: []dialer{&persistentDialer{size: 0, network: network, address: address}}},
IndexFiles: []string{}, IndexFiles: []string{},
}}}, }}},
{`fastcgi / ` + defaultAddress + ` { {`fastcgi / 127.0.0.1:8080 {
upstream 127.0.0.1:9000
pool 5 pool 5
}`, }`,
false, []Rule{{ false, []Rule{{
Path: "/", Path: "/",
Address: defaultAddress, Address: "127.0.0.1:8080,127.0.0.1:9000",
Ext: "", Ext: "",
SplitPath: "", SplitPath: "",
dialer: &persistentDialer{size: 5, network: network, address: address}, dialer: &loadBalancingDialer{dialers: []dialer{&persistentDialer{size: 5, network: "tcp", address: "127.0.0.1:8080"}, &persistentDialer{size: 5, network: "tcp", address: "127.0.0.1:9000"}}},
IndexFiles: []string{}, IndexFiles: []string{},
}}}, }}},
{`fastcgi / ` + defaultAddress + ` { {`fastcgi / ` + defaultAddress + ` {
...@@ -133,9 +156,14 @@ func TestFastcgiParse(t *testing.T) { ...@@ -133,9 +156,14 @@ func TestFastcgiParse(t *testing.T) {
Address: defaultAddress, Address: defaultAddress,
Ext: "", Ext: "",
SplitPath: ".php", SplitPath: ".php",
dialer: basicDialer{network: network, address: address}, dialer: &loadBalancingDialer{dialers: []dialer{basicDialer{network: network, address: address}}},
IndexFiles: []string{}, IndexFiles: []string{},
}}}, }}},
{`fastcgi / {
}`,
true, []Rule{},
},
} }
for i, test := range tests { for i, test := range tests {
actualFastcgiConfigs, err := fastcgiParse(caddy.NewTestController("http", test.inputFastcgiConfig)) actualFastcgiConfigs, err := fastcgiParse(caddy.NewTestController("http", test.inputFastcgiConfig))
...@@ -175,20 +203,7 @@ func TestFastcgiParse(t *testing.T) { ...@@ -175,20 +203,7 @@ func TestFastcgiParse(t *testing.T) {
t.Errorf("Test %d expected %dth FastCGI dialer to be of type %T, but got %T", t.Errorf("Test %d expected %dth FastCGI dialer to be of type %T, but got %T",
i, j, test.expectedFastcgiConfig[j].dialer, actualFastcgiConfig.dialer) i, j, test.expectedFastcgiConfig[j].dialer, actualFastcgiConfig.dialer)
} else { } else {
equal := true if !areDialersEqual(actualFastcgiConfig.dialer, test.expectedFastcgiConfig[j].dialer, t) {
switch actual := actualFastcgiConfig.dialer.(type) {
case basicDialer:
equal = actualFastcgiConfig.dialer == test.expectedFastcgiConfig[j].dialer
case *persistentDialer:
if expected, ok := test.expectedFastcgiConfig[j].dialer.(*persistentDialer); ok {
equal = actual.Equals(expected)
} else {
equal = false
}
default:
t.Errorf("Unkonw dialer type %T", actualFastcgiConfig.dialer)
}
if !equal {
t.Errorf("Test %d expected %dth FastCGI dialer to be %v, but got %v", t.Errorf("Test %d expected %dth FastCGI dialer to be %v, but got %v",
i, j, test.expectedFastcgiConfig[j].dialer, actualFastcgiConfig.dialer) i, j, test.expectedFastcgiConfig[j].dialer, actualFastcgiConfig.dialer)
} }
...@@ -205,5 +220,31 @@ func TestFastcgiParse(t *testing.T) { ...@@ -205,5 +220,31 @@ func TestFastcgiParse(t *testing.T) {
} }
} }
} }
}
func areDialersEqual(current, expected dialer, t *testing.T) bool {
switch actual := current.(type) {
case *loadBalancingDialer:
if expected, ok := expected.(*loadBalancingDialer); ok {
for i := 0; i < len(actual.dialers); i++ {
if !areDialersEqual(actual.dialers[i], expected.dialers[i], t) {
return false
}
}
return true
}
case basicDialer:
return current == expected
case *persistentDialer:
if expected, ok := expected.(*persistentDialer); ok {
return actual.Equals(expected)
}
default:
t.Errorf("Unknown dialer type %T", current)
}
return false
} }
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment