Commit 59bf71c2 authored by Angel Santiago's avatar Angel Santiago Committed by Matt Holt

proxy: Cleanly shutdown health checks on restart (#1524)

* Add a shutdown function and context to staticUpstream so that running goroutines can be cancelled. Add a GetShutdownFunc to Upstream interface to expose the shutdown function to the caddy Controller for performing it on restarts.

* Make fakeUpstream implement new Upstream methods.

Implement new Upstream method for fakeWSUpstream as well.

* Rename GetShutdownFunc to Stop(). Add a waitgroup to the staticUpstream for controlling individual object's goroutines. Add the Stop function to OnRestart and OnShutdown. Add tests for checking to see if healthchecks continue hitting a backend server after stop has been called.

* Go back to using a stop channel since the context adds no additional benefit.
Only register stop function for onShutdown since it's called as part of restart.

* Remove assignment to atomic value

* Incrementing WaitGroup outside of goroutine to avoid race condition. Loading atomic values in test.

* Linting: change counter to just use the default zero value instead of setting it

* Clarify Stop method comments, add comments to stop channel and waitgroup and remove out of date comment about handling stopping the proxy. Stop the ticker when the stop signal is sent
parent 464ade1d
...@@ -42,6 +42,9 @@ type Upstream interface { ...@@ -42,6 +42,9 @@ type Upstream interface {
// Gets the number of upstream hosts. // Gets the number of upstream hosts.
GetHostCount() int GetHostCount() int
// Stops the upstream from proxying requests to shutdown goroutines cleanly.
Stop() error
} }
// UpstreamHostDownFunc can be used to customize how Down behaves. // UpstreamHostDownFunc can be used to customize how Down behaves.
......
...@@ -1216,6 +1216,7 @@ func (u *fakeUpstream) AllowedPath(requestPath string) bool { return true } ...@@ -1216,6 +1216,7 @@ func (u *fakeUpstream) AllowedPath(requestPath string) bool { return true }
func (u *fakeUpstream) GetTryDuration() time.Duration { return 1 * time.Second } func (u *fakeUpstream) GetTryDuration() time.Duration { return 1 * time.Second }
func (u *fakeUpstream) GetTryInterval() time.Duration { return 250 * time.Millisecond } func (u *fakeUpstream) GetTryInterval() time.Duration { return 250 * time.Millisecond }
func (u *fakeUpstream) GetHostCount() int { return 1 } func (u *fakeUpstream) GetHostCount() int { return 1 }
func (u *fakeUpstream) Stop() error { return nil }
// newWebSocketTestProxy returns a test proxy that will // newWebSocketTestProxy returns a test proxy that will
// redirect to the specified backendAddr. The function // redirect to the specified backendAddr. The function
...@@ -1268,6 +1269,7 @@ func (u *fakeWsUpstream) AllowedPath(requestPath string) bool { return true } ...@@ -1268,6 +1269,7 @@ func (u *fakeWsUpstream) AllowedPath(requestPath string) bool { return true }
func (u *fakeWsUpstream) GetTryDuration() time.Duration { return 1 * time.Second } func (u *fakeWsUpstream) GetTryDuration() time.Duration { return 1 * time.Second }
func (u *fakeWsUpstream) GetTryInterval() time.Duration { return 250 * time.Millisecond } func (u *fakeWsUpstream) GetTryInterval() time.Duration { return 250 * time.Millisecond }
func (u *fakeWsUpstream) GetHostCount() int { return 1 } func (u *fakeWsUpstream) GetHostCount() int { return 1 }
func (u *fakeWsUpstream) Stop() error { return nil }
// recorderHijacker is a ResponseRecorder that can // recorderHijacker is a ResponseRecorder that can
// be hijacked. // be hijacked.
......
...@@ -21,5 +21,11 @@ func setup(c *caddy.Controller) error { ...@@ -21,5 +21,11 @@ func setup(c *caddy.Controller) error {
httpserver.GetConfig(c).AddMiddleware(func(next httpserver.Handler) httpserver.Handler { httpserver.GetConfig(c).AddMiddleware(func(next httpserver.Handler) httpserver.Handler {
return Proxy{Next: next, Upstreams: upstreams} return Proxy{Next: next, Upstreams: upstreams}
}) })
// Register shutdown handlers.
for _, upstream := range upstreams {
c.OnShutdown(upstream.Stop)
}
return nil return nil
} }
...@@ -9,6 +9,7 @@ import ( ...@@ -9,6 +9,7 @@ import (
"path" "path"
"strconv" "strconv"
"strings" "strings"
"sync"
"sync/atomic" "sync/atomic"
"time" "time"
...@@ -24,6 +25,8 @@ type staticUpstream struct { ...@@ -24,6 +25,8 @@ type staticUpstream struct {
from string from string
upstreamHeaders http.Header upstreamHeaders http.Header
downstreamHeaders http.Header downstreamHeaders http.Header
stop chan struct{} // Signals running goroutines to stop.
wg sync.WaitGroup // Used to wait for running goroutines to stop.
Hosts HostPool Hosts HostPool
Policy Policy Policy Policy
KeepAlive int KeepAlive int
...@@ -48,8 +51,10 @@ type staticUpstream struct { ...@@ -48,8 +51,10 @@ type staticUpstream struct {
func NewStaticUpstreams(c caddyfile.Dispenser) ([]Upstream, error) { func NewStaticUpstreams(c caddyfile.Dispenser) ([]Upstream, error) {
var upstreams []Upstream var upstreams []Upstream
for c.Next() { for c.Next() {
upstream := &staticUpstream{ upstream := &staticUpstream{
from: "", from: "",
stop: make(chan struct{}),
upstreamHeaders: make(http.Header), upstreamHeaders: make(http.Header),
downstreamHeaders: make(http.Header), downstreamHeaders: make(http.Header),
Hosts: nil, Hosts: nil,
...@@ -108,7 +113,11 @@ func NewStaticUpstreams(c caddyfile.Dispenser) ([]Upstream, error) { ...@@ -108,7 +113,11 @@ func NewStaticUpstreams(c caddyfile.Dispenser) ([]Upstream, error) {
upstream.HealthCheck.Client = http.Client{ upstream.HealthCheck.Client = http.Client{
Timeout: upstream.HealthCheck.Timeout, Timeout: upstream.HealthCheck.Timeout,
} }
go upstream.HealthCheckWorker(nil) upstream.wg.Add(1)
go func() {
defer upstream.wg.Done()
upstream.HealthCheckWorker(upstream.stop)
}()
} }
upstreams = append(upstreams, upstream) upstreams = append(upstreams, upstream)
} }
...@@ -380,9 +389,8 @@ func (u *staticUpstream) HealthCheckWorker(stop chan struct{}) { ...@@ -380,9 +389,8 @@ func (u *staticUpstream) HealthCheckWorker(stop chan struct{}) {
case <-ticker.C: case <-ticker.C:
u.healthCheck() u.healthCheck()
case <-stop: case <-stop:
// TODO: the library should provide a stop channel and global ticker.Stop()
// waitgroup to allow goroutines started by plugins a chance return
// to clean themselves up.
} }
} }
} }
...@@ -434,6 +442,14 @@ func (u *staticUpstream) GetHostCount() int { ...@@ -434,6 +442,14 @@ func (u *staticUpstream) GetHostCount() int {
return len(u.Hosts) return len(u.Hosts)
} }
// Stop sends a signal to all goroutines started by this staticUpstream to exit
// and waits for them to finish before returning.
func (u *staticUpstream) Stop() error {
close(u.stop)
u.wg.Wait()
return nil
}
// RegisterPolicy adds a custom policy to the proxy. // RegisterPolicy adds a custom policy to the proxy.
func RegisterPolicy(name string, policy func() Policy) { func RegisterPolicy(name string, policy func() Policy) {
supportedPolicies[name] = policy supportedPolicies[name] = policy
......
package proxy package proxy
import ( import (
"fmt"
"net/http" "net/http"
"net/http/httptest"
"strings" "strings"
"sync/atomic"
"testing" "testing"
"time" "time"
...@@ -189,6 +192,75 @@ func TestParseBlockHealthCheck(t *testing.T) { ...@@ -189,6 +192,75 @@ func TestParseBlockHealthCheck(t *testing.T) {
} }
} }
func TestStop(t *testing.T) {
config := "proxy / %s {\n health_check /healthcheck \nhealth_check_interval %dms \n}"
tests := []struct {
name string
intervalInMilliseconds int
numHealthcheckIntervals int
}{
{
"No Healthchecks After Stop - 5ms, 1 intervals",
5,
1,
},
{
"No Healthchecks After Stop - 5ms, 2 intervals",
5,
2,
},
{
"No Healthchecks After Stop - 5ms, 3 intervals",
5,
3,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
// Set up proxy.
var counter int64
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r.Body.Close()
atomic.AddInt64(&counter, 1)
}))
defer backend.Close()
upstreams, err := NewStaticUpstreams(caddyfile.NewDispenser("Testfile", strings.NewReader(fmt.Sprintf(config, backend.URL, test.intervalInMilliseconds))))
if err != nil {
t.Error("Expected no error. Got:", err.Error())
}
// Give some time for healthchecks to hit the server.
time.Sleep(time.Duration(test.intervalInMilliseconds*test.numHealthcheckIntervals) * time.Millisecond)
for _, upstream := range upstreams {
if err := upstream.Stop(); err != nil {
t.Error("Expected no error stopping upstream. Got: ", err.Error())
}
}
counterValueAfterShutdown := atomic.LoadInt64(&counter)
// Give some time to see if healthchecks are still hitting the server.
time.Sleep(time.Duration(test.intervalInMilliseconds*test.numHealthcheckIntervals) * time.Millisecond)
if counterValueAfterShutdown == 0 {
t.Error("Expected healthchecks to hit test server. Got no healthchecks.")
}
counterValueAfterWaiting := atomic.LoadInt64(&counter)
if counterValueAfterWaiting != counterValueAfterShutdown {
t.Errorf("Expected no more healthchecks after shutdown. Got: %d healthchecks after shutdown", counterValueAfterWaiting-counterValueAfterShutdown)
}
})
}
}
func TestParseBlock(t *testing.T) { func TestParseBlock(t *testing.T) {
r, _ := http.NewRequest("GET", "/", nil) r, _ := http.NewRequest("GET", "/", nil)
tests := []struct { tests := []struct {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment