Commit 5f32f9b1 authored by Matt Holt's avatar Matt Holt

Merge pull request #40 from ChannelMeter/proxy-middleware

Proxy Middleware: Add support for multiple backends, load balancing & healthchecks
parents 290cf829 264e5b79
package proxy
import (
"math/rand"
"sync/atomic"
)
type HostPool []*UpstreamHost
// Policy decides how a host will be selected from a pool.
type Policy interface {
Select(pool HostPool) *UpstreamHost
}
// The random policy randomly selected an up host from the pool.
type Random struct{}
func (r *Random) Select(pool HostPool) *UpstreamHost {
// instead of just generating a random index
// this is done to prevent selecting a down host
var randHost *UpstreamHost
count := 0
for _, host := range pool {
if host.Down() {
continue
}
count++
if count == 1 {
randHost = host
} else {
r := rand.Int() % count
if r == (count - 1) {
randHost = host
}
}
}
return randHost
}
// The least_conn policy selects a host with the least connections.
// If multiple hosts have the least amount of connections, one is randomly
// chosen.
type LeastConn struct{}
func (r *LeastConn) Select(pool HostPool) *UpstreamHost {
var bestHost *UpstreamHost
count := 0
leastConn := int64(1<<63 - 1)
for _, host := range pool {
if host.Down() {
continue
}
hostConns := host.Conns
if hostConns < leastConn {
bestHost = host
leastConn = hostConns
count = 1
} else if hostConns == leastConn {
// randomly select host among hosts with least connections
count++
if count == 1 {
bestHost = host
} else {
r := rand.Int() % count
if r == (count - 1) {
bestHost = host
}
}
}
}
return bestHost
}
// The round_robin policy selects a host based on round robin ordering.
type RoundRobin struct {
Robin uint32
}
func (r *RoundRobin) Select(pool HostPool) *UpstreamHost {
poolLen := uint32(len(pool))
selection := atomic.AddUint32(&r.Robin, 1) % poolLen
host := pool[selection]
// if the currently selected host is down, just ffwd to up host
for i := uint32(1); host.Down() && i < poolLen; i++ {
host = pool[(selection+i)%poolLen]
}
if host.Down() {
return nil
}
return host
}
package proxy
import (
"testing"
)
func testPool() HostPool {
pool := []*UpstreamHost{
&UpstreamHost{
Name: "http://google.com", // this should resolve (healthcheck test)
},
&UpstreamHost{
Name: "http://shouldnot.resolve", // this shouldn't
},
&UpstreamHost{
Name: "http://C",
},
}
return HostPool(pool)
}
func TestRoundRobinPolicy(t *testing.T) {
pool := testPool()
rrPolicy := &RoundRobin{}
h := rrPolicy.Select(pool)
// First selected host is 1, because counter starts at 0
// and increments before host is selected
if h != pool[1] {
t.Error("Expected first round robin host to be second host in the pool.")
}
h = rrPolicy.Select(pool)
if h != pool[2] {
t.Error("Expected second round robin host to be third host in the pool.")
}
// mark host as down
pool[0].Unhealthy = true
h = rrPolicy.Select(pool)
if h != pool[1] {
t.Error("Expected third round robin host to be first host in the pool.")
}
}
func TestLeastConnPolicy(t *testing.T) {
pool := testPool()
lcPolicy := &LeastConn{}
pool[0].Conns = 10
pool[1].Conns = 10
h := lcPolicy.Select(pool)
if h != pool[2] {
t.Error("Expected least connection host to be third host.")
}
pool[2].Conns = 100
h = lcPolicy.Select(pool)
if h != pool[0] && h != pool[1] {
t.Error("Expected least connection host to be first or second host.")
}
}
...@@ -2,52 +2,119 @@ ...@@ -2,52 +2,119 @@
package proxy package proxy
import ( import (
"errors"
"github.com/mholt/caddy/middleware"
"net/http" "net/http"
"net/http/httputil"
"net/url" "net/url"
"strings" "sync/atomic"
"time"
"github.com/mholt/caddy/middleware"
) )
var errUnreachable = errors.New("Unreachable backend")
// Proxy represents a middleware instance that can proxy requests. // Proxy represents a middleware instance that can proxy requests.
type Proxy struct { type Proxy struct {
Next middleware.Handler Next middleware.Handler
Rules []Rule Upstreams []Upstream
}
// An upstream manages a pool of proxy upstream hosts. Select should return a
// suitable upstream host, or nil if no such hosts are available.
type Upstream interface {
// The path this upstream host should be routed on
From() string
// Selects an upstream host to be routed to.
Select() *UpstreamHost
}
type UpstreamHostDownFunc func(*UpstreamHost) bool
// An UpstreamHost represents a single proxy upstream
type UpstreamHost struct {
// The hostname of this upstream host
Name string
ReverseProxy *ReverseProxy
Conns int64
Fails int32
FailTimeout time.Duration
Unhealthy bool
ExtraHeaders http.Header
CheckDown UpstreamHostDownFunc
}
func (uh *UpstreamHost) Down() bool {
if uh.CheckDown == nil {
// Default settings
return uh.Unhealthy || uh.Fails > 0
}
return uh.CheckDown(uh)
} }
// ServeHTTP satisfies the middleware.Handler interface. // ServeHTTP satisfies the middleware.Handler interface.
func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
for _, rule := range p.Rules { for _, upstream := range p.Upstreams {
if middleware.Path(r.URL.Path).Matches(rule.From) { if middleware.Path(r.URL.Path).Matches(upstream.From()) {
var base string var replacer middleware.Replacer
start := time.Now()
requestHost := r.Host
if strings.HasPrefix(rule.To, "http") { // includes https // Since Select() should give us "up" hosts, keep retrying
// destination includes a scheme! no need to guess // hosts until timeout (or until we get a nil host).
base = rule.To for time.Now().Sub(start) < (60 * time.Second) {
} else { host := upstream.Select()
// no scheme specified; assume same as request if host == nil {
var scheme string return http.StatusBadGateway, errUnreachable
if r.TLS == nil {
scheme = "http"
} else {
scheme = "https"
}
base = scheme + "://" + rule.To
} }
proxy := host.ReverseProxy
r.Host = host.Name
baseUrl, err := url.Parse(base) if baseUrl, err := url.Parse(host.Name); err == nil {
if err != nil { r.Host = baseUrl.Host
if proxy == nil {
proxy = NewSingleHostReverseProxy(baseUrl)
}
} else if proxy == nil {
return http.StatusInternalServerError, err return http.StatusInternalServerError, err
} }
r.Host = baseUrl.Host var extraHeaders http.Header
if host.ExtraHeaders != nil {
extraHeaders = make(http.Header)
if replacer == nil {
rHost := r.Host
r.Host = requestHost
replacer = middleware.NewReplacer(r, nil)
r.Host = rHost
}
for header, values := range host.ExtraHeaders {
for _, value := range values {
extraHeaders.Add(header,
replacer.Replace(value))
if header == "Host" {
r.Host = replacer.Replace(value)
}
}
}
}
// TODO: Construct this before; not during every request, if possible atomic.AddInt64(&host.Conns, 1)
proxy := httputil.NewSingleHostReverseProxy(baseUrl) backendErr := proxy.ServeHTTP(w, r, extraHeaders)
proxy.ServeHTTP(w, r) atomic.AddInt64(&host.Conns, -1)
if backendErr == nil {
return 0, nil return 0, nil
} }
timeout := host.FailTimeout
if timeout == 0 {
timeout = 10 * time.Second
}
atomic.AddInt32(&host.Fails, 1)
go func(host *UpstreamHost, timeout time.Duration) {
time.Sleep(timeout)
atomic.AddInt32(&host.Fails, -1)
}(host, timeout)
}
return http.StatusBadGateway, errUnreachable
}
} }
return p.Next.ServeHTTP(w, r) return p.Next.ServeHTTP(w, r)
...@@ -55,30 +122,11 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { ...@@ -55,30 +122,11 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
// New creates a new instance of proxy middleware. // New creates a new instance of proxy middleware.
func New(c middleware.Controller) (middleware.Middleware, error) { func New(c middleware.Controller) (middleware.Middleware, error) {
rules, err := parse(c) if upstreams, err := newStaticUpstreams(c); err == nil {
if err != nil {
return nil, err
}
return func(next middleware.Handler) middleware.Handler { return func(next middleware.Handler) middleware.Handler {
return Proxy{Next: next, Rules: rules} return Proxy{Next: next, Upstreams: upstreams}
}, nil }, nil
} } else {
return nil, err
func parse(c middleware.Controller) ([]Rule, error) {
var rules []Rule
for c.Next() {
var rule Rule
if !c.Args(&rule.From, &rule.To) {
return rules, c.ArgErr()
}
rules = append(rules, rule)
} }
return rules, nil
}
type Rule struct {
From, To string
} }
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// HTTP reverse proxy handler
package proxy
import (
"io"
"net"
"net/http"
"net/url"
"strings"
"sync"
"time"
)
// onExitFlushLoop is a callback set by tests to detect the state of the
// flushLoop() goroutine.
var onExitFlushLoop func()
// ReverseProxy is an HTTP Handler that takes an incoming request and
// sends it to another server, proxying the response back to the
// client.
type ReverseProxy struct {
// Director must be a function which modifies
// the request into a new request to be sent
// using Transport. Its response is then copied
// back to the original client unmodified.
Director func(*http.Request)
// The transport used to perform proxy requests.
// If nil, http.DefaultTransport is used.
Transport http.RoundTripper
// FlushInterval specifies the flush interval
// to flush to the client while copying the
// response body.
// If zero, no periodic flushing is done.
FlushInterval time.Duration
}
func singleJoiningSlash(a, b string) string {
aslash := strings.HasSuffix(a, "/")
bslash := strings.HasPrefix(b, "/")
switch {
case aslash && bslash:
return a + b[1:]
case !aslash && !bslash:
return a + "/" + b
}
return a + b
}
// NewSingleHostReverseProxy returns a new ReverseProxy that rewrites
// URLs to the scheme, host, and base path provided in target. If the
// target's path is "/base" and the incoming request was for "/dir",
// the target request will be for /base/dir.
func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy {
targetQuery := target.RawQuery
director := func(req *http.Request) {
req.URL.Scheme = target.Scheme
req.URL.Host = target.Host
req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path)
if targetQuery == "" || req.URL.RawQuery == "" {
req.URL.RawQuery = targetQuery + req.URL.RawQuery
} else {
req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
}
}
return &ReverseProxy{Director: director}
}
func copyHeader(dst, src http.Header) {
for k, vv := range src {
for _, v := range vv {
dst.Add(k, v)
}
}
}
// Hop-by-hop headers. These are removed when sent to the backend.
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
var hopHeaders = []string{
"Connection",
"Keep-Alive",
"Proxy-Authenticate",
"Proxy-Authorization",
"Te", // canonicalized version of "TE"
"Trailers",
"Transfer-Encoding",
"Upgrade",
}
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request, extraHeaders http.Header) error {
transport := p.Transport
if transport == nil {
transport = http.DefaultTransport
}
outreq := new(http.Request)
*outreq = *req // includes shallow copies of maps, but okay
p.Director(outreq)
outreq.Proto = "HTTP/1.1"
outreq.ProtoMajor = 1
outreq.ProtoMinor = 1
outreq.Close = false
// Remove hop-by-hop headers to the backend. Especially
// important is "Connection" because we want a persistent
// connection, regardless of what the client sent to us. This
// is modifying the same underlying map from req (shallow
// copied above) so we only copy it if necessary.
copiedHeaders := false
for _, h := range hopHeaders {
if outreq.Header.Get(h) != "" {
if !copiedHeaders {
outreq.Header = make(http.Header)
copyHeader(outreq.Header, req.Header)
copiedHeaders = true
}
outreq.Header.Del(h)
}
}
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
// If we aren't the first proxy retain prior
// X-Forwarded-For information as a comma+space
// separated list and fold multiple headers into one.
if prior, ok := outreq.Header["X-Forwarded-For"]; ok {
clientIP = strings.Join(prior, ", ") + ", " + clientIP
}
outreq.Header.Set("X-Forwarded-For", clientIP)
}
if extraHeaders != nil {
for k, v := range extraHeaders {
outreq.Header[k] = v
}
}
res, err := transport.RoundTrip(outreq)
if err != nil {
return err
}
defer res.Body.Close()
for _, h := range hopHeaders {
res.Header.Del(h)
}
copyHeader(rw.Header(), res.Header)
rw.WriteHeader(res.StatusCode)
p.copyResponse(rw, res.Body)
return nil
}
func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) {
if p.FlushInterval != 0 {
if wf, ok := dst.(writeFlusher); ok {
mlw := &maxLatencyWriter{
dst: wf,
latency: p.FlushInterval,
done: make(chan bool),
}
go mlw.flushLoop()
defer mlw.stop()
dst = mlw
}
}
io.Copy(dst, src)
}
type writeFlusher interface {
io.Writer
http.Flusher
}
type maxLatencyWriter struct {
dst writeFlusher
latency time.Duration
lk sync.Mutex // protects Write + Flush
done chan bool
}
func (m *maxLatencyWriter) Write(p []byte) (int, error) {
m.lk.Lock()
defer m.lk.Unlock()
return m.dst.Write(p)
}
func (m *maxLatencyWriter) flushLoop() {
t := time.NewTicker(m.latency)
defer t.Stop()
for {
select {
case <-m.done:
if onExitFlushLoop != nil {
onExitFlushLoop()
}
return
case <-t.C:
m.lk.Lock()
m.dst.Flush()
m.lk.Unlock()
}
}
}
func (m *maxLatencyWriter) stop() { m.done <- true }
package proxy
import (
"github.com/mholt/caddy/middleware"
"io"
"io/ioutil"
"net/http"
"net/url"
"strconv"
"strings"
"time"
)
type staticUpstream struct {
from string
Hosts HostPool
Policy Policy
FailTimeout time.Duration
MaxFails int32
HealthCheck struct {
Path string
Interval time.Duration
}
}
func newStaticUpstreams(c middleware.Controller) ([]Upstream, error) {
var upstreams []Upstream
for c.Next() {
upstream := &staticUpstream{
from: "",
Hosts: nil,
Policy: &Random{},
FailTimeout: 10 * time.Second,
MaxFails: 1,
}
var proxyHeaders http.Header
if !c.Args(&upstream.from) {
return upstreams, c.ArgErr()
}
to := c.RemainingArgs()
if len(to) == 0 {
return upstreams, c.ArgErr()
}
for c.NextBlock() {
switch c.Val() {
case "policy":
if !c.NextArg() {
return upstreams, c.ArgErr()
}
switch c.Val() {
case "random":
upstream.Policy = &Random{}
case "round_robin":
upstream.Policy = &RoundRobin{}
case "least_conn":
upstream.Policy = &LeastConn{}
default:
return upstreams, c.ArgErr()
}
case "fail_timeout":
if !c.NextArg() {
return upstreams, c.ArgErr()
}
if dur, err := time.ParseDuration(c.Val()); err == nil {
upstream.FailTimeout = dur
} else {
return upstreams, err
}
case "max_fails":
if !c.NextArg() {
return upstreams, c.ArgErr()
}
if n, err := strconv.Atoi(c.Val()); err == nil {
upstream.MaxFails = int32(n)
} else {
return upstreams, err
}
case "health_check":
if !c.NextArg() {
return upstreams, c.ArgErr()
}
upstream.HealthCheck.Path = c.Val()
upstream.HealthCheck.Interval = 30 * time.Second
if c.NextArg() {
if dur, err := time.ParseDuration(c.Val()); err == nil {
upstream.HealthCheck.Interval = dur
} else {
return upstreams, err
}
}
case "proxy_header":
var header, value string
if !c.Args(&header, &value) {
return upstreams, c.ArgErr()
}
if proxyHeaders == nil {
proxyHeaders = make(map[string][]string)
}
proxyHeaders.Add(header, value)
}
}
upstream.Hosts = make([]*UpstreamHost, len(to))
for i, host := range to {
if !strings.HasPrefix(host, "http") {
host = "http://" + host
}
uh := &UpstreamHost{
Name: host,
Conns: 0,
Fails: 0,
FailTimeout: upstream.FailTimeout,
Unhealthy: false,
ExtraHeaders: proxyHeaders,
CheckDown: func(upstream *staticUpstream) UpstreamHostDownFunc {
return func(uh *UpstreamHost) bool {
if uh.Unhealthy {
return true
}
if uh.Fails >= upstream.MaxFails &&
upstream.MaxFails != 0 {
return true
}
return false
}
}(upstream),
}
if baseUrl, err := url.Parse(uh.Name); err == nil {
uh.ReverseProxy = NewSingleHostReverseProxy(baseUrl)
} else {
return upstreams, err
}
upstream.Hosts[i] = uh
}
if upstream.HealthCheck.Path != "" {
go upstream.healthCheckWorker(nil)
}
upstreams = append(upstreams, upstream)
}
return upstreams, nil
}
func (u *staticUpstream) healthCheck() {
for _, host := range u.Hosts {
hostUrl := host.Name + u.HealthCheck.Path
if r, err := http.Get(hostUrl); err == nil {
io.Copy(ioutil.Discard, r.Body)
r.Body.Close()
host.Unhealthy = r.StatusCode < 200 || r.StatusCode >= 400
} else {
host.Unhealthy = true
}
}
}
func (u *staticUpstream) healthCheckWorker(stop chan struct{}) {
ticker := time.NewTicker(u.HealthCheck.Interval)
u.healthCheck()
for {
select {
case <-ticker.C:
u.healthCheck()
case <-stop:
// TODO: the library should provide a stop channel and global
// waitgroup to allow goroutines started by plugins a chance
// to clean themselves up.
}
}
}
func (u *staticUpstream) From() string {
return u.from
}
func (u *staticUpstream) Select() *UpstreamHost {
pool := u.Hosts
if len(pool) == 1 {
if pool[0].Down() {
return nil
}
return pool[0]
}
allDown := true
for _, host := range pool {
if !host.Down() {
allDown = false
break
}
}
if allDown {
return nil
}
if u.Policy == nil {
return (&Random{}).Select(pool)
} else {
return u.Policy.Select(pool)
}
}
package proxy
import (
"testing"
"time"
)
func TestHealthCheck(t *testing.T) {
upstream := &staticUpstream{
from: "",
Hosts: testPool(),
Policy: &Random{},
FailTimeout: 10 * time.Second,
MaxFails: 1,
}
upstream.healthCheck()
if upstream.Hosts[0].Down() {
t.Error("Expected first host in testpool to not fail healthcheck.")
}
if !upstream.Hosts[1].Down() {
t.Error("Expected second host in testpool to fail healthcheck.")
}
}
func TestSelect(t *testing.T) {
upstream := &staticUpstream{
from: "",
Hosts: testPool()[:3],
Policy: &Random{},
FailTimeout: 10 * time.Second,
MaxFails: 1,
}
upstream.Hosts[0].Unhealthy = true
upstream.Hosts[1].Unhealthy = true
upstream.Hosts[2].Unhealthy = true
if h := upstream.Select(); h != nil {
t.Error("Expected select to return nil as all host are down")
}
upstream.Hosts[2].Unhealthy = false
if h := upstream.Select(); h == nil {
t.Error("Expected select to not return nil")
}
}
...@@ -8,17 +8,21 @@ import ( ...@@ -8,17 +8,21 @@ import (
"time" "time"
) )
// replacer is a type which can replace placeholder // Replacer is a type which can replace placeholder
// substrings in a string with actual values from a // substrings in a string with actual values from a
// http.Request and responseRecorder. Always use // http.Request and responseRecorder. Always use
// NewReplacer to get one of these. // NewReplacer to get one of these.
type Replacer interface {
Replace(string) string
}
type replacer map[string]string type replacer map[string]string
// NewReplacer makes a new replacer based on r and rr. // NewReplacer makes a new replacer based on r and rr.
// Do not create a new replacer until r and rr have all // Do not create a new replacer until r and rr have all
// the needed values, because this function copies those // the needed values, because this function copies those
// values into the replacer. // values into the replacer.
func NewReplacer(r *http.Request, rr *responseRecorder) replacer { func NewReplacer(r *http.Request, rr *responseRecorder) Replacer {
rep := replacer{ rep := replacer{
"{method}": r.Method, "{method}": r.Method,
"{scheme}": func() string { "{scheme}": func() string {
...@@ -33,6 +37,9 @@ func NewReplacer(r *http.Request, rr *responseRecorder) replacer { ...@@ -33,6 +37,9 @@ func NewReplacer(r *http.Request, rr *responseRecorder) replacer {
"{fragment}": r.URL.Fragment, "{fragment}": r.URL.Fragment,
"{proto}": r.Proto, "{proto}": r.Proto,
"{remote}": func() string { "{remote}": func() string {
if fwdFor := r.Header.Get("X-Forwarded-For"); fwdFor != "" {
return fwdFor
}
host, _, err := net.SplitHostPort(r.RemoteAddr) host, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil { if err != nil {
return r.RemoteAddr return r.RemoteAddr
...@@ -50,9 +57,11 @@ func NewReplacer(r *http.Request, rr *responseRecorder) replacer { ...@@ -50,9 +57,11 @@ func NewReplacer(r *http.Request, rr *responseRecorder) replacer {
"{when}": func() string { "{when}": func() string {
return time.Now().Format(timeFormat) return time.Now().Format(timeFormat)
}(), }(),
"{status}": strconv.Itoa(rr.status), }
"{size}": strconv.Itoa(rr.size), if rr != nil {
"{latency}": time.Since(rr.start).String(), rep["{status}"] = strconv.Itoa(rr.status)
rep["{size}"] = strconv.Itoa(rr.size)
rep["{latency}"] = time.Since(rr.start).String()
} }
// Header placeholders // Header placeholders
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment