Commit 524dcee9 authored by Matt Holt's avatar Matt Holt Committed by GitHub

Merge pull request #1373 from mholt/go18shutdown

Replace our old faithful gracefulListener with Go 1.8's Shutdown()
parents 58b2edd2 0cc48e84
package httpserver
import (
"net"
"sync"
"syscall"
)
// TODO: Should this be a generic graceful listener available in its own package or something?
// Also, passing in a WaitGroup is a little awkward. Why can't this listener just keep
// the waitgroup internal to itself?
// newGracefulListener returns a gracefulListener that wraps l and
// uses wg (stored in the host server) to count connections.
func newGracefulListener(l net.Listener, wg *sync.WaitGroup) *gracefulListener {
gl := &gracefulListener{Listener: l, stop: make(chan error), connWg: wg}
go func() {
<-gl.stop
gl.Lock()
gl.stopped = true
gl.Unlock()
gl.stop <- gl.Listener.Close()
}()
return gl
}
// gracefuListener is a net.Listener which can
// count the number of connections on it. Its
// methods mainly wrap net.Listener to be graceful.
type gracefulListener struct {
net.Listener
stop chan error
stopped bool
sync.Mutex // protects the stopped flag
connWg *sync.WaitGroup // pointer to the host's wg used for counting connections
}
// Accept accepts a connection.
func (gl *gracefulListener) Accept() (c net.Conn, err error) {
c, err = gl.Listener.Accept()
if err != nil {
return
}
c = gracefulConn{Conn: c, connWg: gl.connWg}
gl.connWg.Add(1)
return
}
// Close immediately closes the listener.
func (gl *gracefulListener) Close() error {
gl.Lock()
if gl.stopped {
gl.Unlock()
return syscall.EINVAL
}
gl.Unlock()
gl.stop <- nil
return <-gl.stop
}
// gracefulConn represents a connection on a
// gracefulListener so that we can keep track
// of the number of connections, thus facilitating
// a graceful shutdown.
type gracefulConn struct {
net.Conn
connWg *sync.WaitGroup // pointer to the host server's connection waitgroup
}
// Close closes c's underlying connection while updating the wg count.
func (c gracefulConn) Close() error {
err := c.Conn.Close()
if err != nil {
return err
}
// close can fail on http2 connections (as of Oct. 2015, before http2 in std lib)
// so don't decrement count unless close succeeds
c.connWg.Done()
return nil
}
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
package httpserver package httpserver
import ( import (
"context"
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"io" "io"
...@@ -27,9 +28,8 @@ type Server struct { ...@@ -27,9 +28,8 @@ type Server struct {
listener net.Listener listener net.Listener
listenerMu sync.Mutex listenerMu sync.Mutex
sites []*SiteConfig sites []*SiteConfig
connTimeout time.Duration // max time to wait for a connection before force stop connTimeout time.Duration // max time to wait for a connection before force stop
connWg sync.WaitGroup // one increment per connection tlsGovChan chan struct{} // close to stop the TLS maintenance goroutine
tlsGovChan chan struct{} // close to stop the TLS maintenance goroutine
vhosts *vhostTrie vhosts *vhostTrie
} }
...@@ -46,16 +46,6 @@ func NewServer(addr string, group []*SiteConfig) (*Server, error) { ...@@ -46,16 +46,6 @@ func NewServer(addr string, group []*SiteConfig) (*Server, error) {
connTimeout: GracefulTimeout, connTimeout: GracefulTimeout,
} }
s.Server.Handler = s // this is weird, but whatever s.Server.Handler = s // this is weird, but whatever
s.Server.ConnState = func(c net.Conn, cs http.ConnState) {
if cs == http.StateIdle {
s.listenerMu.Lock()
// server stopped, close idle connection
if s.listener == nil {
c.Close()
}
s.listenerMu.Unlock()
}
}
// Disable HTTP/2 if desired // Disable HTTP/2 if desired
if !HTTP2 { if !HTTP2 {
...@@ -68,14 +58,6 @@ func NewServer(addr string, group []*SiteConfig) (*Server, error) { ...@@ -68,14 +58,6 @@ func NewServer(addr string, group []*SiteConfig) (*Server, error) {
s.Server.Handler = s.wrapWithSvcHeaders(s.Server.Handler) s.Server.Handler = s.wrapWithSvcHeaders(s.Server.Handler)
} }
// We have to bound our wg with one increment
// to prevent a "race condition" that is hard-coded
// into sync.WaitGroup.Wait() - basically, an add
// with a positive delta must be guaranteed to
// occur before Wait() is called on the wg.
// In a way, this kind of acts as a safety barrier.
s.connWg.Add(1)
// Set up TLS configuration // Set up TLS configuration
var tlsConfigs []*caddytls.Config var tlsConfigs []*caddytls.Config
for _, site := range group { for _, site := range group {
...@@ -163,8 +145,6 @@ func (s *Server) Serve(ln net.Listener) error { ...@@ -163,8 +145,6 @@ func (s *Server) Serve(ln net.Listener) error {
ln = tcpKeepAliveListener{TCPListener: tcpLn} ln = tcpKeepAliveListener{TCPListener: tcpLn}
} }
ln = newGracefulListener(ln, &s.connWg)
s.listenerMu.Lock() s.listenerMu.Lock()
s.listener = ln s.listener = ln
s.listenerMu.Unlock() s.listenerMu.Unlock()
...@@ -306,40 +286,21 @@ func (s *Server) Address() string { ...@@ -306,40 +286,21 @@ func (s *Server) Address() string {
// Stop stops s gracefully (or forcefully after timeout) and // Stop stops s gracefully (or forcefully after timeout) and
// closes its listener. // closes its listener.
func (s *Server) Stop() (err error) { func (s *Server) Stop() error {
s.Server.SetKeepAlivesEnabled(false) ctx, cancel := context.WithTimeout(context.Background(), s.connTimeout)
defer cancel()
if runtime.GOOS != "windows" {
// force connections to close after timeout
done := make(chan struct{})
go func() {
s.connWg.Done() // decrement our initial increment used as a barrier
s.connWg.Wait()
close(done)
}()
// Wait for remaining connections to finish or
// force them all to close after timeout
select {
case <-time.After(s.connTimeout):
case <-done:
}
}
// Close the listener now; this stops the server without delay err := s.Server.Shutdown(ctx)
s.listenerMu.Lock() if err != nil {
if s.listener != nil { return err
err = s.listener.Close()
s.listener = nil
} }
s.listenerMu.Unlock()
// Closing this signals any TLS governor goroutines to exit // signal any TLS governor goroutines to exit
if s.tlsGovChan != nil { if s.tlsGovChan != nil {
close(s.tlsGovChan) close(s.tlsGovChan)
} }
return return nil
} }
// sanitizePath collapses any ./ ../ /// madness // sanitizePath collapses any ./ ../ /// madness
...@@ -439,11 +400,10 @@ func makeHTTPServer(addr string, group []*SiteConfig) *http.Server { ...@@ -439,11 +400,10 @@ func makeHTTPServer(addr string, group []*SiteConfig) *http.Server {
} }
// set the final values on the server // set the final values on the server
// TODO: ReadHeaderTimeout and IdleTimeout require Go 1.8
s.ReadTimeout = min.ReadTimeout s.ReadTimeout = min.ReadTimeout
// s.ReadHeaderTimeout = min.ReadHeaderTimeout s.ReadHeaderTimeout = min.ReadHeaderTimeout
s.WriteTimeout = min.WriteTimeout s.WriteTimeout = min.WriteTimeout
// s.IdleTimeout = min.IdleTimeout s.IdleTimeout = min.IdleTimeout
return s return s
} }
......
...@@ -100,15 +100,14 @@ func TestMakeHTTPServer(t *testing.T) { ...@@ -100,15 +100,14 @@ func TestMakeHTTPServer(t *testing.T) {
if got, want := actual.ReadTimeout, tc.expected.ReadTimeout; got != want { if got, want := actual.ReadTimeout, tc.expected.ReadTimeout; got != want {
t.Errorf("Test %d: Expected ReadTimeout=%v, but was %v", i, want, got) t.Errorf("Test %d: Expected ReadTimeout=%v, but was %v", i, want, got)
} }
// TODO: ReadHeaderTimeout and IdleTimeout require Go 1.8 if got, want := actual.ReadHeaderTimeout, tc.expected.ReadHeaderTimeout; got != want {
// if got, want := actual.ReadHeaderTimeout, tc.expected.ReadHeaderTimeout; got != want { t.Errorf("Test %d: Expected ReadHeaderTimeout=%v, but was %v", i, want, got)
// t.Errorf("Test %d: Expected ReadHeaderTimeout=%v, but was %v", i, want, got) }
// }
if got, want := actual.WriteTimeout, tc.expected.WriteTimeout; got != want { if got, want := actual.WriteTimeout, tc.expected.WriteTimeout; got != want {
t.Errorf("Test %d: Expected WriteTimeout=%v, but was %v", i, want, got) t.Errorf("Test %d: Expected WriteTimeout=%v, but was %v", i, want, got)
} }
// if got, want := actual.IdleTimeout, tc.expected.IdleTimeout; got != want { if got, want := actual.IdleTimeout, tc.expected.IdleTimeout; got != want {
// t.Errorf("Test %d: Expected IdleTimeout=%v, but was %v", i, want, got) t.Errorf("Test %d: Expected IdleTimeout=%v, but was %v", i, want, got)
// } }
} }
} }
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment