Commit 88d3dcae authored by Kris Hamoud's avatar Kris Hamoud

added ip_hash load balancing

updated tests

fixed comment format

fixed formatting, minor logic fix

added newline to EOF

updated logic, fixed tests

added comment

updated formatting

updated test output

fixed typo
parent 72af3f82
package proxy package proxy
import ( import (
"hash/fnv"
"math" "math"
"math/rand" "math/rand"
"net"
"net/http"
"sync" "sync"
) )
...@@ -11,20 +14,21 @@ type HostPool []*UpstreamHost ...@@ -11,20 +14,21 @@ type HostPool []*UpstreamHost
// Policy decides how a host will be selected from a pool. // Policy decides how a host will be selected from a pool.
type Policy interface { type Policy interface {
Select(pool HostPool) *UpstreamHost Select(pool HostPool, r *http.Request) *UpstreamHost
} }
func init() { func init() {
RegisterPolicy("random", func() Policy { return &Random{} }) RegisterPolicy("random", func() Policy { return &Random{} })
RegisterPolicy("least_conn", func() Policy { return &LeastConn{} }) RegisterPolicy("least_conn", func() Policy { return &LeastConn{} })
RegisterPolicy("round_robin", func() Policy { return &RoundRobin{} }) RegisterPolicy("round_robin", func() Policy { return &RoundRobin{} })
RegisterPolicy("ip_hash", func() Policy { return &IPHash{} })
} }
// Random is a policy that selects up hosts from a pool at random. // Random is a policy that selects up hosts from a pool at random.
type Random struct{} type Random struct{}
// Select selects an up host at random from the specified pool. // Select selects an up host at random from the specified pool.
func (r *Random) Select(pool HostPool) *UpstreamHost { func (r *Random) Select(pool HostPool, request *http.Request) *UpstreamHost {
// Because the number of available hosts isn't known // Because the number of available hosts isn't known
// up front, the host is selected via reservoir sampling // up front, the host is selected via reservoir sampling
...@@ -53,7 +57,7 @@ type LeastConn struct{} ...@@ -53,7 +57,7 @@ type LeastConn struct{}
// Select selects the up host with the least number of connections in the // Select selects the up host with the least number of connections in the
// pool. If more than one host has the same least number of connections, // pool. If more than one host has the same least number of connections,
// one of the hosts is chosen at random. // one of the hosts is chosen at random.
func (r *LeastConn) Select(pool HostPool) *UpstreamHost { func (r *LeastConn) Select(pool HostPool, request *http.Request) *UpstreamHost {
var bestHost *UpstreamHost var bestHost *UpstreamHost
count := 0 count := 0
leastConn := int64(math.MaxInt64) leastConn := int64(math.MaxInt64)
...@@ -86,7 +90,7 @@ type RoundRobin struct { ...@@ -86,7 +90,7 @@ type RoundRobin struct {
} }
// Select selects an up host from the pool using a round robin ordering scheme. // Select selects an up host from the pool using a round robin ordering scheme.
func (r *RoundRobin) Select(pool HostPool) *UpstreamHost { func (r *RoundRobin) Select(pool HostPool, request *http.Request) *UpstreamHost {
poolLen := uint32(len(pool)) poolLen := uint32(len(pool))
r.mutex.Lock() r.mutex.Lock()
defer r.mutex.Unlock() defer r.mutex.Unlock()
...@@ -100,3 +104,35 @@ func (r *RoundRobin) Select(pool HostPool) *UpstreamHost { ...@@ -100,3 +104,35 @@ func (r *RoundRobin) Select(pool HostPool) *UpstreamHost {
} }
return nil return nil
} }
// IPHash is a policy that selects hosts based on hashing the request ip
type IPHash struct{}
func hash(s string) uint32 {
h := fnv.New32a()
h.Write([]byte(s))
return h.Sum32()
}
// Select selects an up host from the pool using a round robin ordering scheme.
func (r *IPHash) Select(pool HostPool, request *http.Request) *UpstreamHost {
poolLen := uint32(len(pool))
clientIP, _, err := net.SplitHostPort(request.RemoteAddr)
if err != nil {
clientIP = request.RemoteAddr
}
hash := hash(clientIP)
for {
if poolLen == 0 {
break
}
index := hash % poolLen
host := pool[index]
if host.Available() {
return host
}
pool = append(pool[:index], pool[index+1:]...)
poolLen--
}
return nil
}
...@@ -21,7 +21,7 @@ func TestMain(m *testing.M) { ...@@ -21,7 +21,7 @@ func TestMain(m *testing.M) {
type customPolicy struct{} type customPolicy struct{}
func (r *customPolicy) Select(pool HostPool) *UpstreamHost { func (r *customPolicy) Select(pool HostPool, request *http.Request) *UpstreamHost {
return pool[0] return pool[0]
} }
...@@ -43,37 +43,39 @@ func testPool() HostPool { ...@@ -43,37 +43,39 @@ func testPool() HostPool {
func TestRoundRobinPolicy(t *testing.T) { func TestRoundRobinPolicy(t *testing.T) {
pool := testPool() pool := testPool()
rrPolicy := &RoundRobin{} rrPolicy := &RoundRobin{}
h := rrPolicy.Select(pool) request, _ := http.NewRequest("GET", "/", nil)
h := rrPolicy.Select(pool, request)
// First selected host is 1, because counter starts at 0 // First selected host is 1, because counter starts at 0
// and increments before host is selected // and increments before host is selected
if h != pool[1] { if h != pool[1] {
t.Error("Expected first round robin host to be second host in the pool.") t.Error("Expected first round robin host to be second host in the pool.")
} }
h = rrPolicy.Select(pool) h = rrPolicy.Select(pool, request)
if h != pool[2] { if h != pool[2] {
t.Error("Expected second round robin host to be third host in the pool.") t.Error("Expected second round robin host to be third host in the pool.")
} }
h = rrPolicy.Select(pool) h = rrPolicy.Select(pool, request)
if h != pool[0] { if h != pool[0] {
t.Error("Expected third round robin host to be first host in the pool.") t.Error("Expected third round robin host to be first host in the pool.")
} }
// mark host as down // mark host as down
pool[1].Unhealthy = true pool[1].Unhealthy = true
h = rrPolicy.Select(pool) h = rrPolicy.Select(pool, request)
if h != pool[2] { if h != pool[2] {
t.Error("Expected to skip down host.") t.Error("Expected to skip down host.")
} }
// mark host as up // mark host as up
pool[1].Unhealthy = false pool[1].Unhealthy = false
h = rrPolicy.Select(pool) h = rrPolicy.Select(pool, request)
if h == pool[2] { if h == pool[2] {
t.Error("Expected to balance evenly among healthy hosts") t.Error("Expected to balance evenly among healthy hosts")
} }
// mark host as full // mark host as full
pool[1].Conns = 1 pool[1].Conns = 1
pool[1].MaxConns = 1 pool[1].MaxConns = 1
h = rrPolicy.Select(pool) h = rrPolicy.Select(pool, request)
if h != pool[2] { if h != pool[2] {
t.Error("Expected to skip full host.") t.Error("Expected to skip full host.")
} }
...@@ -82,14 +84,16 @@ func TestRoundRobinPolicy(t *testing.T) { ...@@ -82,14 +84,16 @@ func TestRoundRobinPolicy(t *testing.T) {
func TestLeastConnPolicy(t *testing.T) { func TestLeastConnPolicy(t *testing.T) {
pool := testPool() pool := testPool()
lcPolicy := &LeastConn{} lcPolicy := &LeastConn{}
request, _ := http.NewRequest("GET", "/", nil)
pool[0].Conns = 10 pool[0].Conns = 10
pool[1].Conns = 10 pool[1].Conns = 10
h := lcPolicy.Select(pool) h := lcPolicy.Select(pool, request)
if h != pool[2] { if h != pool[2] {
t.Error("Expected least connection host to be third host.") t.Error("Expected least connection host to be third host.")
} }
pool[2].Conns = 100 pool[2].Conns = 100
h = lcPolicy.Select(pool) h = lcPolicy.Select(pool, request)
if h != pool[0] && h != pool[1] { if h != pool[0] && h != pool[1] {
t.Error("Expected least connection host to be first or second host.") t.Error("Expected least connection host to be first or second host.")
} }
...@@ -98,8 +102,127 @@ func TestLeastConnPolicy(t *testing.T) { ...@@ -98,8 +102,127 @@ func TestLeastConnPolicy(t *testing.T) {
func TestCustomPolicy(t *testing.T) { func TestCustomPolicy(t *testing.T) {
pool := testPool() pool := testPool()
customPolicy := &customPolicy{} customPolicy := &customPolicy{}
h := customPolicy.Select(pool) request, _ := http.NewRequest("GET", "/", nil)
h := customPolicy.Select(pool, request)
if h != pool[0] { if h != pool[0] {
t.Error("Expected custom policy host to be the first host.") t.Error("Expected custom policy host to be the first host.")
} }
} }
func TestIPHashPolicy(t *testing.T) {
pool := testPool()
ipHash := &IPHash{}
request, _ := http.NewRequest("GET", "/", nil)
// We should be able to predict where every request is routed.
request.RemoteAddr = "172.0.0.1:80"
h := ipHash.Select(pool, request)
if h != pool[1] {
t.Error("Expected ip hash policy host to be the second host.")
}
request.RemoteAddr = "172.0.0.2:80"
h = ipHash.Select(pool, request)
if h != pool[1] {
t.Error("Expected ip hash policy host to be the second host.")
}
request.RemoteAddr = "172.0.0.3:80"
h = ipHash.Select(pool, request)
if h != pool[2] {
t.Error("Expected ip hash policy host to be the third host.")
}
request.RemoteAddr = "172.0.0.4:80"
h = ipHash.Select(pool, request)
if h != pool[1] {
t.Error("Expected ip hash policy host to be the second host.")
}
// we should get the same results without a port
request.RemoteAddr = "172.0.0.1"
h = ipHash.Select(pool, request)
if h != pool[1] {
t.Error("Expected ip hash policy host to be the second host.")
}
request.RemoteAddr = "172.0.0.2"
h = ipHash.Select(pool, request)
if h != pool[1] {
t.Error("Expected ip hash policy host to be the second host.")
}
request.RemoteAddr = "172.0.0.3"
h = ipHash.Select(pool, request)
if h != pool[2] {
t.Error("Expected ip hash policy host to be the third host.")
}
request.RemoteAddr = "172.0.0.4"
h = ipHash.Select(pool, request)
if h != pool[1] {
t.Error("Expected ip hash policy host to be the second host.")
}
// we should get a healthy host if the original host is unhealthy and a
// healthy host is available
request.RemoteAddr = "172.0.0.1"
pool[1].Unhealthy = true
h = ipHash.Select(pool, request)
if h != pool[0] {
t.Error("Expected ip hash policy host to be the first host.")
}
request.RemoteAddr = "172.0.0.2"
h = ipHash.Select(pool, request)
if h != pool[1] {
t.Error("Expected ip hash policy host to be the second host.")
}
pool[1].Unhealthy = false
request.RemoteAddr = "172.0.0.3"
pool[2].Unhealthy = true
h = ipHash.Select(pool, request)
if h != pool[0] {
t.Error("Expected ip hash policy host to be the first host.")
}
request.RemoteAddr = "172.0.0.4"
h = ipHash.Select(pool, request)
if h != pool[0] {
t.Error("Expected ip hash policy host to be the first host.")
}
// We should be able to resize the host pool and still be able to predict
// where a request will be routed with the same IP's used above
pool = []*UpstreamHost{
{
Name: workableServer.URL, // this should resolve (healthcheck test)
},
{
Name: "http://localhost:99998", // this shouldn't
},
}
pool = HostPool(pool)
request.RemoteAddr = "172.0.0.1:80"
h = ipHash.Select(pool, request)
if h != pool[0] {
t.Error("Expected ip hash policy host to be the first host.")
}
request.RemoteAddr = "172.0.0.2:80"
h = ipHash.Select(pool, request)
if h != pool[1] {
t.Error("Expected ip hash policy host to be the second host.")
}
request.RemoteAddr = "172.0.0.3:80"
h = ipHash.Select(pool, request)
if h != pool[0] {
t.Error("Expected ip hash policy host to be the first host.")
}
request.RemoteAddr = "172.0.0.4:80"
h = ipHash.Select(pool, request)
if h != pool[1] {
t.Error("Expected ip hash policy host to be the second host.")
}
// We should get nil when there are no healthy hosts
pool[0].Unhealthy = true
pool[1].Unhealthy = true
h = ipHash.Select(pool, request)
if h != nil {
t.Error("Expected ip hash policy host to be nil.")
}
}
...@@ -27,7 +27,7 @@ type Upstream interface { ...@@ -27,7 +27,7 @@ type Upstream interface {
// The path this upstream host should be routed on // The path this upstream host should be routed on
From() string From() string
// Selects an upstream host to be routed to. // Selects an upstream host to be routed to.
Select() *UpstreamHost Select(*http.Request) *UpstreamHost
// Checks if subpath is not an ignored path // Checks if subpath is not an ignored path
AllowedPath(string) bool AllowedPath(string) bool
} }
...@@ -93,7 +93,7 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { ...@@ -93,7 +93,7 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
// hosts until timeout (or until we get a nil host). // hosts until timeout (or until we get a nil host).
start := time.Now() start := time.Now()
for time.Now().Sub(start) < tryDuration { for time.Now().Sub(start) < tryDuration {
host := upstream.Select() host := upstream.Select(r)
if host == nil { if host == nil {
return http.StatusBadGateway, errUnreachable return http.StatusBadGateway, errUnreachable
} }
......
...@@ -736,7 +736,7 @@ func (u *fakeUpstream) From() string { ...@@ -736,7 +736,7 @@ func (u *fakeUpstream) From() string {
return u.from return u.from
} }
func (u *fakeUpstream) Select() *UpstreamHost { func (u *fakeUpstream) Select(r *http.Request) *UpstreamHost {
if u.host == nil { if u.host == nil {
uri, err := url.Parse(u.name) uri, err := url.Parse(u.name)
if err != nil { if err != nil {
...@@ -781,7 +781,7 @@ func (u *fakeWsUpstream) From() string { ...@@ -781,7 +781,7 @@ func (u *fakeWsUpstream) From() string {
return "/" return "/"
} }
func (u *fakeWsUpstream) Select() *UpstreamHost { func (u *fakeWsUpstream) Select(r *http.Request) *UpstreamHost {
uri, _ := url.Parse(u.name) uri, _ := url.Parse(u.name)
return &UpstreamHost{ return &UpstreamHost{
Name: u.name, Name: u.name,
......
...@@ -346,7 +346,7 @@ func (u *staticUpstream) HealthCheckWorker(stop chan struct{}) { ...@@ -346,7 +346,7 @@ func (u *staticUpstream) HealthCheckWorker(stop chan struct{}) {
} }
} }
func (u *staticUpstream) Select() *UpstreamHost { func (u *staticUpstream) Select(r *http.Request) *UpstreamHost {
pool := u.Hosts pool := u.Hosts
if len(pool) == 1 { if len(pool) == 1 {
if !pool[0].Available() { if !pool[0].Available() {
...@@ -364,11 +364,10 @@ func (u *staticUpstream) Select() *UpstreamHost { ...@@ -364,11 +364,10 @@ func (u *staticUpstream) Select() *UpstreamHost {
if allUnavailable { if allUnavailable {
return nil return nil
} }
if u.Policy == nil { if u.Policy == nil {
return (&Random{}).Select(pool) return (&Random{}).Select(pool, r)
} }
return u.Policy.Select(pool) return u.Policy.Select(pool, r)
} }
func (u *staticUpstream) AllowedPath(requestPath string) bool { func (u *staticUpstream) AllowedPath(requestPath string) bool {
......
package proxy package proxy
import ( import (
"github.com/mholt/caddy/caddyfile"
"net/http"
"strings" "strings"
"testing" "testing"
"time" "time"
"github.com/mholt/caddy/caddyfile"
) )
func TestNewHost(t *testing.T) { func TestNewHost(t *testing.T) {
...@@ -72,14 +72,15 @@ func TestSelect(t *testing.T) { ...@@ -72,14 +72,15 @@ func TestSelect(t *testing.T) {
FailTimeout: 10 * time.Second, FailTimeout: 10 * time.Second,
MaxFails: 1, MaxFails: 1,
} }
r, _ := http.NewRequest("GET", "/", nil)
upstream.Hosts[0].Unhealthy = true upstream.Hosts[0].Unhealthy = true
upstream.Hosts[1].Unhealthy = true upstream.Hosts[1].Unhealthy = true
upstream.Hosts[2].Unhealthy = true upstream.Hosts[2].Unhealthy = true
if h := upstream.Select(); h != nil { if h := upstream.Select(r); h != nil {
t.Error("Expected select to return nil as all host are down") t.Error("Expected select to return nil as all host are down")
} }
upstream.Hosts[2].Unhealthy = false upstream.Hosts[2].Unhealthy = false
if h := upstream.Select(); h == nil { if h := upstream.Select(r); h == nil {
t.Error("Expected select to not return nil") t.Error("Expected select to not return nil")
} }
upstream.Hosts[0].Conns = 1 upstream.Hosts[0].Conns = 1
...@@ -88,11 +89,11 @@ func TestSelect(t *testing.T) { ...@@ -88,11 +89,11 @@ func TestSelect(t *testing.T) {
upstream.Hosts[1].MaxConns = 1 upstream.Hosts[1].MaxConns = 1
upstream.Hosts[2].Conns = 1 upstream.Hosts[2].Conns = 1
upstream.Hosts[2].MaxConns = 1 upstream.Hosts[2].MaxConns = 1
if h := upstream.Select(); h != nil { if h := upstream.Select(r); h != nil {
t.Error("Expected select to return nil as all hosts are full") t.Error("Expected select to return nil as all hosts are full")
} }
upstream.Hosts[2].Conns = 0 upstream.Hosts[2].Conns = 0
if h := upstream.Select(); h == nil { if h := upstream.Select(r); h == nil {
t.Error("Expected select to not return nil") t.Error("Expected select to not return nil")
} }
} }
...@@ -188,6 +189,7 @@ func TestParseBlockHealthCheck(t *testing.T) { ...@@ -188,6 +189,7 @@ func TestParseBlockHealthCheck(t *testing.T) {
} }
func TestParseBlock(t *testing.T) { func TestParseBlock(t *testing.T) {
r, _ := http.NewRequest("GET", "/", nil)
tests := []struct { tests := []struct {
config string config string
}{ }{
...@@ -207,7 +209,7 @@ func TestParseBlock(t *testing.T) { ...@@ -207,7 +209,7 @@ func TestParseBlock(t *testing.T) {
t.Error("Expected no error. Got:", err.Error()) t.Error("Expected no error. Got:", err.Error())
} }
for _, upstream := range upstreams { for _, upstream := range upstreams {
headers := upstream.Select().UpstreamHeaders headers := upstream.Select(r).UpstreamHeaders
if _, ok := headers["Host"]; !ok { if _, ok := headers["Host"]; !ok {
t.Errorf("Test %d: Could not find the Host header", i+1) t.Errorf("Test %d: Could not find the Host header", i+1)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment