Commit f7003bee authored by Volodymyr Galkin's avatar Volodymyr Galkin

Add support for default (wildcard) error page

parent 68be4a91
...@@ -25,6 +25,7 @@ func init() { ...@@ -25,6 +25,7 @@ func init() {
// ErrorHandler handles HTTP errors (and errors from other middleware). // ErrorHandler handles HTTP errors (and errors from other middleware).
type ErrorHandler struct { type ErrorHandler struct {
Next httpserver.Handler Next httpserver.Handler
GenericErrorPage string // default error page filename
ErrorPages map[int]string // map of status code to filename ErrorPages map[int]string // map of status code to filename
LogFile string LogFile string
Log *log.Logger Log *log.Logger
...@@ -63,7 +64,7 @@ func (h ErrorHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, er ...@@ -63,7 +64,7 @@ func (h ErrorHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, er
// message is written instead, and the extra error is logged. // message is written instead, and the extra error is logged.
func (h ErrorHandler) errorPage(w http.ResponseWriter, r *http.Request, code int) { func (h ErrorHandler) errorPage(w http.ResponseWriter, r *http.Request, code int) {
// See if an error page for this status code was specified // See if an error page for this status code was specified
if pagePath, ok := h.ErrorPages[code]; ok { if pagePath, ok := h.findErrorPage(code); ok {
// Try to open it // Try to open it
errorPage, err := os.Open(pagePath) errorPage, err := os.Open(pagePath)
if err != nil { if err != nil {
...@@ -94,6 +95,18 @@ func (h ErrorHandler) errorPage(w http.ResponseWriter, r *http.Request, code int ...@@ -94,6 +95,18 @@ func (h ErrorHandler) errorPage(w http.ResponseWriter, r *http.Request, code int
httpserver.DefaultErrorFunc(w, r, code) httpserver.DefaultErrorFunc(w, r, code)
} }
func (h ErrorHandler) findErrorPage(code int) (string, bool) {
if pagePath, ok := h.ErrorPages[code]; ok {
return pagePath, true
}
if h.GenericErrorPage != "" {
return h.GenericErrorPage, true
}
return "", false
}
func (h ErrorHandler) recovery(w http.ResponseWriter, r *http.Request) { func (h ErrorHandler) recovery(w http.ResponseWriter, r *http.Request) {
rec := recover() rec := recover()
if rec == nil { if rec == nil {
......
...@@ -18,19 +18,13 @@ import ( ...@@ -18,19 +18,13 @@ import (
func TestErrors(t *testing.T) { func TestErrors(t *testing.T) {
// create a temporary page // create a temporary page
path := filepath.Join(os.TempDir(), "errors_test.html")
f, err := os.Create(path)
if err != nil {
t.Fatal(err)
}
defer os.Remove(path)
const content = "This is a error page" const content = "This is a error page"
_, err = f.WriteString(content)
path, err := createErrorPageFile("errors_test.html", content)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
f.Close() defer os.Remove(path)
buf := bytes.Buffer{} buf := bytes.Buffer{}
em := ErrorHandler{ em := ErrorHandler{
...@@ -157,6 +151,87 @@ func TestVisibleErrorWithPanic(t *testing.T) { ...@@ -157,6 +151,87 @@ func TestVisibleErrorWithPanic(t *testing.T) {
} }
} }
func TestGenericErrorPage(t *testing.T) {
// create temporary generic error page
const genericErrorContent = "This is a generic error page"
genericErrorPagePath, err := createErrorPageFile("generic_error_test.html", genericErrorContent)
if err != nil {
t.Fatal(err)
}
defer os.Remove(genericErrorPagePath)
// create temporary error page
const notFoundErrorContent = "This is a error page"
notFoundErrorPagePath, err := createErrorPageFile("not_found.html", notFoundErrorContent)
if err != nil {
t.Fatal(err)
}
defer os.Remove(notFoundErrorPagePath)
buf := bytes.Buffer{}
em := ErrorHandler{
GenericErrorPage: genericErrorPagePath,
ErrorPages: map[int]string{
http.StatusNotFound: notFoundErrorPagePath,
},
Log: log.New(&buf, "", 0),
}
tests := []struct {
next httpserver.Handler
expectedCode int
expectedBody string
expectedLog string
expectedErr error
}{
{
next: genErrorHandler(http.StatusNotFound, nil, ""),
expectedCode: 0,
expectedBody: notFoundErrorContent,
expectedLog: "",
expectedErr: nil,
},
{
next: genErrorHandler(http.StatusInternalServerError, nil, ""),
expectedCode: 0,
expectedBody: genericErrorContent,
expectedLog: "",
expectedErr: nil,
},
}
req, err := http.NewRequest("GET", "/", nil)
if err != nil {
t.Fatal(err)
}
for i, test := range tests {
em.Next = test.next
buf.Reset()
rec := httptest.NewRecorder()
code, err := em.ServeHTTP(rec, req)
if err != test.expectedErr {
t.Errorf("Test %d: Expected error %v, but got %v",
i, test.expectedErr, err)
}
if code != test.expectedCode {
t.Errorf("Test %d: Expected status code %d, but got %d",
i, test.expectedCode, code)
}
if body := rec.Body.String(); body != test.expectedBody {
t.Errorf("Test %d: Expected body %q, but got %q",
i, test.expectedBody, body)
}
if log := buf.String(); !strings.Contains(log, test.expectedLog) {
t.Errorf("Test %d: Expected log %q, but got %q",
i, test.expectedLog, log)
}
}
}
func genErrorHandler(status int, err error, body string) httpserver.Handler { func genErrorHandler(status int, err error, body string) httpserver.Handler {
return httpserver.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) { return httpserver.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) {
if len(body) > 0 { if len(body) > 0 {
...@@ -166,3 +241,19 @@ func genErrorHandler(status int, err error, body string) httpserver.Handler { ...@@ -166,3 +241,19 @@ func genErrorHandler(status int, err error, body string) httpserver.Handler {
return status, err return status, err
}) })
} }
func createErrorPageFile(name string, content string) (string, error) {
errorPageFilePath := filepath.Join(os.TempDir(), name)
f, err := os.Create(errorPageFilePath)
if err != nil {
return "", err
}
_, err = f.WriteString(content)
if err != nil {
return "", err
}
f.Close()
return errorPageFilePath, nil
}
...@@ -122,13 +122,17 @@ func errorsParse(c *caddy.Controller) (*ErrorHandler, error) { ...@@ -122,13 +122,17 @@ func errorsParse(c *caddy.Controller) (*ErrorHandler, error) {
} }
f.Close() f.Close()
if what == "*" {
handler.GenericErrorPage = where
} else {
whatInt, err := strconv.Atoi(what) whatInt, err := strconv.Atoi(what)
if err != nil { if err != nil {
return hadBlock, c.Err("Expecting a numeric status code, got '" + what + "'") return hadBlock, c.Err("Expecting a numeric status code or '*', got '" + what + "'")
} }
handler.ErrorPages[whatInt] = where handler.ErrorPages[whatInt] = where
} }
} }
}
return hadBlock, nil return hadBlock, nil
} }
......
...@@ -103,6 +103,18 @@ func TestErrorsParse(t *testing.T) { ...@@ -103,6 +103,18 @@ func TestErrorsParse(t *testing.T) {
LocalTime: true, LocalTime: true,
}, },
}}, }},
{`errors { log errors.txt
* generic_error.html
404 404.html
503 503.html
}`, false, ErrorHandler{
LogFile: "errors.txt",
GenericErrorPage: "generic_error.html",
ErrorPages: map[int]string{
404: "404.html",
503: "503.html",
},
}},
} }
for i, test := range tests { for i, test := range tests {
actualErrorsRule, err := errorsParse(caddy.NewTestController("http", test.inputErrorsRules)) actualErrorsRule, err := errorsParse(caddy.NewTestController("http", test.inputErrorsRules))
...@@ -150,6 +162,10 @@ func TestErrorsParse(t *testing.T) { ...@@ -150,6 +162,10 @@ func TestErrorsParse(t *testing.T) {
i, test.expectedErrorHandler.LogRoller.LocalTime, actualErrorsRule.LogRoller.LocalTime) i, test.expectedErrorHandler.LogRoller.LocalTime, actualErrorsRule.LogRoller.LocalTime)
} }
} }
if actualErrorsRule.GenericErrorPage != test.expectedErrorHandler.GenericErrorPage {
t.Fatalf("Test %d expected GenericErrorPage to be %v, but got %v",
i, test.expectedErrorHandler.GenericErrorPage, actualErrorsRule.GenericErrorPage)
}
} }
} }
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment