Commit d33256f1 authored by Matthew Holt's avatar Matthew Holt

Refactor: Middleware chain uses Handler instead of HandlerFunc

parent db2cd9e9
...@@ -20,7 +20,7 @@ import ( ...@@ -20,7 +20,7 @@ import (
// Browse is an http.Handler that can show a file listing when // Browse is an http.Handler that can show a file listing when
// directories in the given paths are specified. // directories in the given paths are specified.
type Browse struct { type Browse struct {
Next middleware.HandlerFunc Next middleware.Handler
Root string Root string
Configs []BrowseConfig Configs []BrowseConfig
} }
...@@ -83,9 +83,9 @@ func New(c middleware.Controller) (middleware.Middleware, error) { ...@@ -83,9 +83,9 @@ func New(c middleware.Controller) (middleware.Middleware, error) {
Configs: configs, Configs: configs,
} }
return func(next middleware.HandlerFunc) middleware.HandlerFunc { return func(next middleware.Handler) middleware.Handler {
browse.Next = next browse.Next = next
return browse.ServeHTTP return browse
}, nil }, nil
} }
...@@ -95,11 +95,11 @@ func (b Browse) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { ...@@ -95,11 +95,11 @@ func (b Browse) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
info, err := os.Stat(filename) info, err := os.Stat(filename)
if err != nil { if err != nil {
return b.Next(w, r) return b.Next.ServeHTTP(w, r)
} }
if !info.IsDir() { if !info.IsDir() {
return b.Next(w, r) return b.Next.ServeHTTP(w, r)
} }
// See if there's a browse configuration to match the path // See if there's a browse configuration to match the path
...@@ -192,7 +192,7 @@ func (b Browse) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { ...@@ -192,7 +192,7 @@ func (b Browse) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
} }
// Didn't qualify; pass-thru // Didn't qualify; pass-thru
return b.Next(w, r) return b.Next.ServeHTTP(w, r)
} }
// parse returns a list of browsing configurations // parse returns a list of browsing configurations
......
...@@ -39,15 +39,15 @@ func New(c middleware.Controller) (middleware.Middleware, error) { ...@@ -39,15 +39,15 @@ func New(c middleware.Controller) (middleware.Middleware, error) {
return nil return nil
}) })
return func(next middleware.HandlerFunc) middleware.HandlerFunc { return func(next middleware.Handler) middleware.Handler {
handler.Next = next handler.Next = next
return handler.ServeHTTP return handler
}, nil }, nil
} }
// ErrorHandler handles HTTP errors (or errors from other middleware). // ErrorHandler handles HTTP errors (or errors from other middleware).
type ErrorHandler struct { type ErrorHandler struct {
Next middleware.HandlerFunc Next middleware.Handler
ErrorPages map[int]string // map of status code to filename ErrorPages map[int]string // map of status code to filename
LogFile string LogFile string
Log *log.Logger Log *log.Logger
...@@ -61,7 +61,7 @@ func (h ErrorHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, er ...@@ -61,7 +61,7 @@ func (h ErrorHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, er
} }
}() }()
status, err := h.Next(w, r) status, err := h.Next.ServeHTTP(w, r)
if err != nil { if err != nil {
h.Log.Printf("[ERROR %d %s] %v", status, r.URL.Path, err) h.Log.Printf("[ERROR %d %s] %v", status, r.URL.Path, err)
......
...@@ -24,12 +24,12 @@ func New(c middleware.Controller) (middleware.Middleware, error) { ...@@ -24,12 +24,12 @@ func New(c middleware.Controller) (middleware.Middleware, error) {
return nil, err return nil, err
} }
return func(next middleware.HandlerFunc) middleware.HandlerFunc { return func(next middleware.Handler) middleware.Handler {
return Ext{ return Ext{
Next: next, Next: next,
Extensions: extensions, Extensions: extensions,
Root: root, Root: root,
}.ServeHTTP }
}, nil }, nil
} }
...@@ -37,7 +37,7 @@ func New(c middleware.Controller) (middleware.Middleware, error) { ...@@ -37,7 +37,7 @@ func New(c middleware.Controller) (middleware.Middleware, error) {
// It tries extensions in the order listed in Extensions. // It tries extensions in the order listed in Extensions.
type Ext struct { type Ext struct {
// Next handler in the chain // Next handler in the chain
Next middleware.HandlerFunc Next middleware.Handler
// Path to ther root of the site // Path to ther root of the site
Root string Root string
...@@ -57,7 +57,7 @@ func (e Ext) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { ...@@ -57,7 +57,7 @@ func (e Ext) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
} }
} }
} }
return e.Next(w, r) return e.Next.ServeHTTP(w, r)
} }
// parse sets up an instance of extension middleware // parse sets up an instance of extension middleware
......
...@@ -26,8 +26,8 @@ func New(c middleware.Controller) (middleware.Middleware, error) { ...@@ -26,8 +26,8 @@ func New(c middleware.Controller) (middleware.Middleware, error) {
rules = append(rules, rule) rules = append(rules, rule)
} }
return func(next middleware.HandlerFunc) middleware.HandlerFunc { return func(next middleware.Handler) middleware.Handler {
return func(w http.ResponseWriter, r *http.Request) (int, error) { return middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) {
servedFcgi := false servedFcgi := false
for _, rule := range rules { for _, rule := range rules {
...@@ -97,11 +97,11 @@ func New(c middleware.Controller) (middleware.Middleware, error) { ...@@ -97,11 +97,11 @@ func New(c middleware.Controller) (middleware.Middleware, error) {
} }
if !servedFcgi { if !servedFcgi {
return next(w, r) return next.ServeHTTP(w, r)
} }
return 0, nil return 0, nil
} })
}, nil }, nil
} }
......
...@@ -11,29 +11,28 @@ import ( ...@@ -11,29 +11,28 @@ import (
"github.com/mholt/caddy/middleware" "github.com/mholt/caddy/middleware"
) )
// Gzip is a http.Handler middleware type which gzips HTTP responses. // Gzip is a middleware type which gzips HTTP responses.
type Gzip struct { type Gzip struct {
Next middleware.HandlerFunc Next middleware.Handler
} }
// New creates a new gzip middleware instance. // New creates a new gzip middleware instance.
func New(c middleware.Controller) (middleware.Middleware, error) { func New(c middleware.Controller) (middleware.Middleware, error) {
return func(next middleware.HandlerFunc) middleware.HandlerFunc { return func(next middleware.Handler) middleware.Handler {
gz := Gzip{Next: next} return Gzip{Next: next}
return gz.ServeHTTP
}, nil }, nil
} }
// ServeHTTP serves a gzipped response if the client supports it. // ServeHTTP serves a gzipped response if the client supports it.
func (g Gzip) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { func (g Gzip) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
if !strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") { if !strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") {
return g.Next(w, r) return g.Next.ServeHTTP(w, r)
} }
w.Header().Set("Content-Encoding", "gzip") w.Header().Set("Content-Encoding", "gzip")
gzipWriter := gzip.NewWriter(w) gzipWriter := gzip.NewWriter(w)
defer gzipWriter.Close() defer gzipWriter.Close()
gz := gzipResponseWriter{Writer: gzipWriter, ResponseWriter: w} gz := gzipResponseWriter{Writer: gzipWriter, ResponseWriter: w}
return g.Next(gz, r) return g.Next.ServeHTTP(gz, r)
} }
// gzipResponeWriter wraps the underlying Write method // gzipResponeWriter wraps the underlying Write method
......
...@@ -12,7 +12,7 @@ import ( ...@@ -12,7 +12,7 @@ import (
// Headers is middleware that adds headers to the responses // Headers is middleware that adds headers to the responses
// for requests matching a certain path. // for requests matching a certain path.
type Headers struct { type Headers struct {
Next middleware.HandlerFunc Next middleware.Handler
Rules []HeaderRule Rules []HeaderRule
} }
...@@ -26,7 +26,7 @@ func (h Headers) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) ...@@ -26,7 +26,7 @@ func (h Headers) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error)
} }
} }
} }
return h.Next(w, r) return h.Next.ServeHTTP(w, r)
} }
type ( type (
......
...@@ -10,7 +10,7 @@ func New(c middleware.Controller) (middleware.Middleware, error) { ...@@ -10,7 +10,7 @@ func New(c middleware.Controller) (middleware.Middleware, error) {
return nil, err return nil, err
} }
return func(next middleware.HandlerFunc) middleware.HandlerFunc { return func(next middleware.Handler) middleware.Handler {
return Headers{Next: next, Rules: rules}.ServeHTTP return Headers{Next: next, Rules: rules}
}, nil }, nil
} }
...@@ -39,8 +39,8 @@ func New(c middleware.Controller) (middleware.Middleware, error) { ...@@ -39,8 +39,8 @@ func New(c middleware.Controller) (middleware.Middleware, error) {
return nil return nil
}) })
return func(next middleware.HandlerFunc) middleware.HandlerFunc { return func(next middleware.Handler) middleware.Handler {
return Logger{Next: next, Rules: rules}.ServeHTTP return Logger{Next: next, Rules: rules}
}, nil }, nil
} }
...@@ -48,13 +48,13 @@ func (l Logger) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { ...@@ -48,13 +48,13 @@ func (l Logger) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
for _, rule := range l.Rules { for _, rule := range l.Rules {
if middleware.Path(r.URL.Path).Matches(rule.PathScope) { if middleware.Path(r.URL.Path).Matches(rule.PathScope) {
responseRecorder := middleware.NewResponseRecorder(w) responseRecorder := middleware.NewResponseRecorder(w)
status, err := l.Next(responseRecorder, r) status, err := l.Next.ServeHTTP(responseRecorder, r)
rep := middleware.NewReplacer(r, responseRecorder) rep := middleware.NewReplacer(r, responseRecorder)
rule.Log.Println(rep.Replace(rule.Format)) rule.Log.Println(rep.Replace(rule.Format))
return status, err return status, err
} }
} }
return l.Next(w, r) return l.Next.ServeHTTP(w, r)
} }
func parse(c middleware.Controller) ([]LogRule, error) { func parse(c middleware.Controller) ([]LogRule, error) {
...@@ -103,7 +103,7 @@ func parse(c middleware.Controller) ([]LogRule, error) { ...@@ -103,7 +103,7 @@ func parse(c middleware.Controller) ([]LogRule, error) {
} }
type Logger struct { type Logger struct {
Next middleware.HandlerFunc Next middleware.Handler
Rules []LogRule Rules []LogRule
} }
......
...@@ -21,7 +21,7 @@ type Markdown struct { ...@@ -21,7 +21,7 @@ type Markdown struct {
Root string Root string
// Next HTTP handler in the chain // Next HTTP handler in the chain
Next middleware.HandlerFunc Next middleware.Handler
// The list of markdown configurations // The list of markdown configurations
Configs []MarkdownConfig Configs []MarkdownConfig
...@@ -58,9 +58,9 @@ func New(c middleware.Controller) (middleware.Middleware, error) { ...@@ -58,9 +58,9 @@ func New(c middleware.Controller) (middleware.Middleware, error) {
Configs: mdconfigs, Configs: mdconfigs,
} }
return func(next middleware.HandlerFunc) middleware.HandlerFunc { return func(next middleware.Handler) middleware.Handler {
md.Next = next md.Next = next
return md.ServeHTTP return md
}, nil }, nil
} }
...@@ -122,7 +122,7 @@ func (md Markdown) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error ...@@ -122,7 +122,7 @@ func (md Markdown) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error
} }
// Didn't qualify to serve as markdown; pass-thru // Didn't qualify to serve as markdown; pass-thru
return md.Next(w, r) return md.Next.ServeHTTP(w, r)
} }
// parse creates new instances of Markdown middleware. // parse creates new instances of Markdown middleware.
......
...@@ -6,26 +6,29 @@ import "net/http" ...@@ -6,26 +6,29 @@ import "net/http"
type ( type (
// Generator represents the outer layer of a middleware that // Generator represents the outer layer of a middleware that
// parses tokens to configure the middleware instance. // parses tokens to configure the middleware instance.
//
// Note: This type will be moved into a different package in the future.
Generator func(Controller) (Middleware, error) Generator func(Controller) (Middleware, error)
// Middleware is the middle layer which represents the traditional // Middleware is the middle layer which represents the traditional
// idea of middleware: it is passed the next HandlerFunc in the chain // idea of middleware: it chains one Handler to the next by being
// and returns the inner layer, which is the actual Handler. // passed the next Handler in the chain.
Middleware func(HandlerFunc) HandlerFunc //
// Note: This type will be moved into a different package in the future.
Middleware func(Handler) Handler
// HandlerFunc is like http.HandlerFunc except it returns a status code // Handler is like http.Handler except ServeHTTP returns a status code
// and an error. It is the inner-most layer which serves individual // and an error. The status code is for the client's benefit; the error
// requests. The status code is for the client's benefit; the error
// value is for the server's benefit. The status code will be sent to // value is for the server's benefit. The status code will be sent to
// the client while the error value will be logged privately. Sometimes, // the client while the error value will be logged privately. Sometimes,
// an error status code (4xx or 5xx) may be returned with a nil error // an error status code (4xx or 5xx) may be returned with a nil error
// when there is no reason to log the error on the server. // when there is no reason to log the error on the server.
// //
// If a HandlerFunc returns an error (status >= 400), it should NOT // If a HandlerFunc returns an error (status >= 400), it should NOT
// write to the response. This philosophy makes middleware.HandlerFunc // write to the response. This philosophy makes middleware.Handler
// different from http.HandlerFunc: error handling should happen at // different from http.Handler: error handling should happen at the
// the application layer or in dedicated error-handling middleware // application layer or in dedicated error-handling middleware only
// only, rather than with an "every middleware for itself" paradigm. // rather than with an "every middleware for itself" paradigm.
// //
// The application or error-handling middleware should incorporate logic // The application or error-handling middleware should incorporate logic
// to ensure that the client always gets a proper response according to // to ensure that the client always gets a proper response according to
...@@ -38,10 +41,6 @@ type ( ...@@ -38,10 +41,6 @@ type (
// response for a status code >= 400. When ANY handler writes to the // response for a status code >= 400. When ANY handler writes to the
// response, it should return a status code < 400 to signal others to // response, it should return a status code < 400 to signal others to
// NOT write to the response again, which would be erroneous. // NOT write to the response again, which would be erroneous.
HandlerFunc func(http.ResponseWriter, *http.Request) (int, error)
// Handler is like http.Handler except ServeHTTP returns a status code
// and an error. See HandlerFunc documentation for more information.
Handler interface { Handler interface {
ServeHTTP(http.ResponseWriter, *http.Request) (int, error) ServeHTTP(http.ResponseWriter, *http.Request) (int, error)
} }
...@@ -127,3 +126,13 @@ type ( ...@@ -127,3 +126,13 @@ type (
Err(string) error Err(string) error
} }
) )
// HandlerFunc is a convenience type like http.HandlerFunc, except
// ServeHTTP returns a status code and an error. See Handler
// documentation for more information.
type HandlerFunc func(http.ResponseWriter, *http.Request) (int, error)
// ServeHTTP implements the Handler interface.
func (f HandlerFunc) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
return f(w, r)
}
...@@ -25,8 +25,8 @@ func New(c middleware.Controller) (middleware.Middleware, error) { ...@@ -25,8 +25,8 @@ func New(c middleware.Controller) (middleware.Middleware, error) {
rules = append(rules, rule) rules = append(rules, rule)
} }
return func(next middleware.HandlerFunc) middleware.HandlerFunc { return func(next middleware.Handler) middleware.Handler {
return func(w http.ResponseWriter, r *http.Request) (int, error) { return middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) {
for _, rule := range rules { for _, rule := range rules {
if middleware.Path(r.URL.Path).Matches(rule.from) { if middleware.Path(r.URL.Path).Matches(rule.from) {
...@@ -59,8 +59,8 @@ func New(c middleware.Controller) (middleware.Middleware, error) { ...@@ -59,8 +59,8 @@ func New(c middleware.Controller) (middleware.Middleware, error) {
} }
} }
return next(w, r) return next.ServeHTTP(w, r)
} })
}, nil }, nil
} }
......
...@@ -38,8 +38,8 @@ func New(c middleware.Controller) (middleware.Middleware, error) { ...@@ -38,8 +38,8 @@ func New(c middleware.Controller) (middleware.Middleware, error) {
} }
} }
return func(next middleware.HandlerFunc) middleware.HandlerFunc { return func(next middleware.Handler) middleware.Handler {
return func(w http.ResponseWriter, r *http.Request) (int, error) { return middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) {
for _, rule := range redirects { for _, rule := range redirects {
if middleware.Path(r.URL.Path).Matches(rule.From) { if middleware.Path(r.URL.Path).Matches(rule.From) {
if rule.From == "/" { if rule.From == "/" {
...@@ -51,8 +51,8 @@ func New(c middleware.Controller) (middleware.Middleware, error) { ...@@ -51,8 +51,8 @@ func New(c middleware.Controller) (middleware.Middleware, error) {
return 0, nil return 0, nil
} }
} }
return next(w, r) return next.ServeHTTP(w, r)
} })
}, nil }, nil
} }
......
...@@ -29,16 +29,16 @@ func New(c middleware.Controller) (middleware.Middleware, error) { ...@@ -29,16 +29,16 @@ func New(c middleware.Controller) (middleware.Middleware, error) {
} }
// TODO: Why can't we just return an http.Handler here instead? // TODO: Why can't we just return an http.Handler here instead?
return func(next middleware.HandlerFunc) middleware.HandlerFunc { return func(next middleware.Handler) middleware.Handler {
return func(w http.ResponseWriter, r *http.Request) (int, error) { return middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) {
for _, rule := range rewrites { for _, rule := range rewrites {
if r.URL.Path == rule.From { if r.URL.Path == rule.From {
r.URL.Path = rule.To r.URL.Path = rule.To
break break
} }
} }
return next(w, r) return next.ServeHTTP(w, r)
} })
}, nil }, nil
} }
......
...@@ -16,7 +16,7 @@ type ( ...@@ -16,7 +16,7 @@ type (
// websocket endpoints. // websocket endpoints.
WebSockets struct { WebSockets struct {
// Next is the next HTTP handler in the chain for when the path doesn't match // Next is the next HTTP handler in the chain for when the path doesn't match
Next middleware.HandlerFunc Next middleware.Handler
// Sockets holds all the web socket endpoint configurations // Sockets holds all the web socket endpoint configurations
Sockets []WSConfig Sockets []WSConfig
...@@ -46,7 +46,7 @@ func (ws WebSockets) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, err ...@@ -46,7 +46,7 @@ func (ws WebSockets) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, err
} }
// Didn't match a websocket path, so pass-thru // Didn't match a websocket path, so pass-thru
return ws.Next(w, r) return ws.Next.ServeHTTP(w, r)
} }
// New constructs and configures a new websockets middleware instance. // New constructs and configures a new websockets middleware instance.
...@@ -115,8 +115,8 @@ func New(c middleware.Controller) (middleware.Middleware, error) { ...@@ -115,8 +115,8 @@ func New(c middleware.Controller) (middleware.Middleware, error) {
GatewayInterface = envGatewayInterface GatewayInterface = envGatewayInterface
ServerSoftware = envServerSoftware ServerSoftware = envServerSoftware
return func(next middleware.HandlerFunc) middleware.HandlerFunc { return func(next middleware.Handler) middleware.Handler {
return WebSockets{Next: next, Sockets: websocks}.ServeHTTP return WebSockets{Next: next, Sockets: websocks}
}, nil }, nil
} }
......
...@@ -28,7 +28,7 @@ var servers = make(map[string]*Server) ...@@ -28,7 +28,7 @@ var servers = make(map[string]*Server)
type Server struct { type Server struct {
config config.Config config config.Config
fileServer middleware.Handler fileServer middleware.Handler
stack middleware.HandlerFunc stack middleware.Handler
} }
// New creates a new Server and registers it with the list // New creates a new Server and registers it with the list
...@@ -118,8 +118,9 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { ...@@ -118,8 +118,9 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
}() }()
status, _ := s.stack(w, r) status, _ := s.stack.ServeHTTP(w, r)
// Fallback error response in case error handling wasn't chained in
if status >= 400 { if status >= 400 {
w.WriteHeader(status) w.WriteHeader(status)
fmt.Fprintf(w, "%d %s", status, http.StatusText(status)) fmt.Fprintf(w, "%d %s", status, http.StatusText(status))
...@@ -144,7 +145,7 @@ func (s *Server) buildStack() error { ...@@ -144,7 +145,7 @@ func (s *Server) buildStack() error {
// compile is an elegant alternative to nesting middleware function // compile is an elegant alternative to nesting middleware function
// calls like handler1(handler2(handler3(finalHandler))). // calls like handler1(handler2(handler3(finalHandler))).
func (s *Server) compile(layers []middleware.Middleware) { func (s *Server) compile(layers []middleware.Middleware) {
s.stack = s.fileServer.ServeHTTP // core app layer s.stack = s.fileServer // core app layer
for i := len(layers) - 1; i >= 0; i-- { for i := len(layers) - 1; i >= 0; i-- {
s.stack = layers[i](s.stack) s.stack = layers[i](s.stack)
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment