Commit 995a2ea6 authored by Matt Holt's avatar Matt Holt

Merge pull request #112 from abiosoft/master

Gzip: Added compression level, extension and path filters.
parents 6080c4fa 13db60d3
package setup package setup
import ( import (
"fmt"
"strconv"
"strings"
"github.com/mholt/caddy/middleware" "github.com/mholt/caddy/middleware"
"github.com/mholt/caddy/middleware/gzip" "github.com/mholt/caddy/middleware/gzip"
) )
// Gzip configures a new gzip middleware instance. // Gzip configures a new gzip middleware instance.
func Gzip(c *Controller) (middleware.Middleware, error) { func Gzip(c *Controller) (middleware.Middleware, error) {
configs, err := gzipParse(c)
if err != nil {
return nil, err
}
return func(next middleware.Handler) middleware.Handler { return func(next middleware.Handler) middleware.Handler {
return gzip.Gzip{Next: next} return gzip.Gzip{Next: next, Configs: configs}
}, nil }, nil
} }
func gzipParse(c *Controller) ([]gzip.Config, error) {
var configs []gzip.Config
for c.Next() {
config := gzip.Config{}
pathFilter := gzip.PathFilter{make(gzip.Set)}
extFilter := gzip.DefaultExtFilter()
// no extra args expected
if len(c.RemainingArgs()) > 0 {
return configs, c.ArgErr()
}
for c.NextBlock() {
switch c.Val() {
case "ext":
exts := c.RemainingArgs()
if len(exts) == 0 {
return configs, c.ArgErr()
}
for _, e := range exts {
if !strings.HasPrefix(e, ".") {
return configs, fmt.Errorf(`Invalid extension %v. Should start with "."`, e)
}
extFilter.Exts.Add(e)
}
case "not":
paths := c.RemainingArgs()
if len(paths) == 0 {
return configs, c.ArgErr()
}
for _, p := range paths {
if !strings.HasPrefix(p, "/") {
return configs, fmt.Errorf(`Invalid path %v. Should start with "/"`, p)
}
pathFilter.IgnoredPaths.Add(p)
// Warn user if / is used
if p == "/" {
fmt.Println("Warning: Paths ignored by gzip includes wildcard(/). No request will be gzipped.\nRemoving gzip directive from Caddyfile is preferred if this is intended.")
}
}
case "level":
if !c.NextArg() {
return configs, c.ArgErr()
}
level, _ := strconv.Atoi(c.Val())
config.Level = level
default:
return configs, c.ArgErr()
}
}
// put pathFilter in front to filter with path first
config.Filters = []gzip.Filter{pathFilter, extFilter}
configs = append(configs, config)
}
return configs, nil
}
...@@ -26,4 +26,47 @@ func TestGzip(t *testing.T) { ...@@ -26,4 +26,47 @@ func TestGzip(t *testing.T) {
if !sameNext(myHandler.Next, emptyNext) { if !sameNext(myHandler.Next, emptyNext) {
t.Error("'Next' field of handler was not set properly") t.Error("'Next' field of handler was not set properly")
} }
tests := []struct {
input string
shouldErr bool
}{
{`gzip {`, true},
{`gzip {}`, true},
{`gzip a b`, true},
{`gzip a {`, true},
{`gzip { not f } `, true},
{`gzip { not } `, true},
{`gzip { not /file
ext .html
level 1
} `, false},
{`gzip { level 9 } `, false},
{`gzip { ext } `, true},
{`gzip { ext /f
} `, true},
{`gzip { not /file
ext .html
level 1
}
gzip`, false},
{`gzip { not /file
ext .html
level 1
}
gzip { not /file1
ext .htm
level 3
}
`, false},
}
for i, test := range tests {
c := newTestController(test.input)
_, err := gzipParse(c)
if test.shouldErr && err == nil {
t.Errorf("Text %v: Expected error but found nil", i)
} else if !test.shouldErr && err != nil {
t.Errorf("Text %v: Expected no error but found error: ", i, err)
}
}
} }
package gzip
import (
"net/http"
"path"
"github.com/mholt/caddy/middleware"
)
// Filter determines if a request should be gzipped.
type Filter interface {
// ShouldCompress tells if compression gzip compression
// should be done on the request.
ShouldCompress(*http.Request) bool
}
// ExtFilter is Filter for file name extensions.
type ExtFilter struct {
// Exts is the file name extensions to accept
Exts Set
}
// textExts is a list of extensions for text related files.
var textExts = []string{
".html", ".htm", ".css", ".json", ".php", ".js", ".txt", ".md", ".xml",
}
// extWildCard is the wildcard for extensions.
const extWildCard = "*"
// DefaultExtFilter creates a default ExtFilter with
// file extensions for text types.
func DefaultExtFilter() ExtFilter {
e := ExtFilter{make(Set)}
for _, ext := range textExts {
e.Exts.Add(ext)
}
return e
}
func (e ExtFilter) ShouldCompress(r *http.Request) bool {
ext := path.Ext(r.URL.Path)
return e.Exts.Contains(extWildCard) || e.Exts.Contains(ext)
}
// PathFilter is Filter for request path.
type PathFilter struct {
// IgnoredPaths is the paths to ignore
IgnoredPaths Set
}
// ShouldCompress checks if the request path matches any of the
// registered paths to ignore. If returns false if an ignored path
// is found and true otherwise.
func (p PathFilter) ShouldCompress(r *http.Request) bool {
return !p.IgnoredPaths.ContainsFunc(func(value string) bool {
return middleware.Path(r.URL.Path).Matches(value)
})
}
// Set stores distinct strings.
type Set map[string]struct{}
// Add adds an element to the set.
func (s Set) Add(value string) {
s[value] = struct{}{}
}
// Remove removes an element from the set.
func (s Set) Remove(value string) {
delete(s, value)
}
// Contains check if the set contains value.
func (s Set) Contains(value string) bool {
_, ok := s[value]
return ok
}
// ContainsFunc is similar to Contains. It iterates all the
// elements in the set and passes each to f. It returns true
// on the first call to f that returns true and false otherwise.
func (s Set) ContainsFunc(f func(string) bool) bool {
for k, _ := range s {
if f(k) {
return true
}
}
return false
}
package gzip
import (
"net/http"
"testing"
)
func TestSet(t *testing.T) {
set := make(Set)
set.Add("a")
if len(set) != 1 {
t.Errorf("Expected 1 found %v", len(set))
}
set.Add("a")
if len(set) != 1 {
t.Errorf("Expected 1 found %v", len(set))
}
set.Add("b")
if len(set) != 2 {
t.Errorf("Expected 2 found %v", len(set))
}
if !set.Contains("a") {
t.Errorf("Set should contain a")
}
if !set.Contains("b") {
t.Errorf("Set should contain a")
}
set.Add("c")
if len(set) != 3 {
t.Errorf("Expected 3 found %v", len(set))
}
if !set.Contains("c") {
t.Errorf("Set should contain c")
}
set.Remove("a")
if len(set) != 2 {
t.Errorf("Expected 2 found %v", len(set))
}
if set.Contains("a") {
t.Errorf("Set should not contain a")
}
if !set.ContainsFunc(func(v string) bool {
return v == "c"
}) {
t.Errorf("ContainsFunc should return true")
}
}
func TestExtFilter(t *testing.T) {
var filter Filter = DefaultExtFilter()
_ = filter.(ExtFilter)
for i, e := range textExts {
r := urlRequest("file" + e)
if !filter.ShouldCompress(r) {
t.Errorf("Test %v: Should be valid filter", i)
}
}
var exts = []string{
".html", ".css", ".md",
}
for i, e := range exts {
r := urlRequest("file" + e)
if !filter.ShouldCompress(r) {
t.Errorf("Test %v: Should be valid filter", i)
}
}
exts = []string{
".htm1", ".abc", ".mdx",
}
for i, e := range exts {
r := urlRequest("file" + e)
if filter.ShouldCompress(r) {
t.Errorf("Test %v: Should not be valid filter", i)
}
}
}
func TestPathFilter(t *testing.T) {
paths := []string{
"/a", "/b", "/c", "/de",
}
var filter Filter = PathFilter{make(Set)}
for _, p := range paths {
filter.(PathFilter).IgnoredPaths.Add(p)
}
for i, p := range paths {
r := urlRequest(p)
if filter.ShouldCompress(r) {
t.Errorf("Test %v: Should not be valid filter", i)
}
}
paths = []string{
"/f", "/g", "/h", "/ed",
}
for i, p := range paths {
r := urlRequest(p)
if !filter.ShouldCompress(r) {
t.Errorf("Test %v: Should be valid filter", i)
}
}
}
func urlRequest(url string) *http.Request {
r, _ := http.NewRequest("GET", url, nil)
return r
}
...@@ -17,7 +17,14 @@ import ( ...@@ -17,7 +17,14 @@ import (
// specifies the Content-Type, otherwise some clients will assume // specifies the Content-Type, otherwise some clients will assume
// application/x-gzip and try to download a file. // application/x-gzip and try to download a file.
type Gzip struct { type Gzip struct {
Next middleware.Handler Next middleware.Handler
Configs []Config
}
// Config holds the configuration for Gzip middleware
type Config struct {
Filters []Filter // Filters to use
Level int // Compression level
} }
// ServeHTTP serves a gzipped response if the client supports it. // ServeHTTP serves a gzipped response if the client supports it.
...@@ -26,27 +33,56 @@ func (g Gzip) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { ...@@ -26,27 +33,56 @@ func (g Gzip) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
return g.Next.ServeHTTP(w, r) return g.Next.ServeHTTP(w, r)
} }
// Delete this header so gzipping isn't repeated later in the chain outer:
r.Header.Del("Accept-Encoding") for _, c := range g.Configs {
w.Header().Set("Content-Encoding", "gzip") // Check filters to determine if gzipping is permitted for this
gzipWriter := gzip.NewWriter(w) // request
defer gzipWriter.Close() for _, filter := range c.Filters {
gz := gzipResponseWriter{Writer: gzipWriter, ResponseWriter: w} if !filter.ShouldCompress(r) {
continue outer
// Any response in forward middleware will now be compressed }
status, err := g.Next.ServeHTTP(gz, r) }
// If there was an error that remained unhandled, we need // Delete this header so gzipping is not repeated later in the chain
// to send something back before gzipWriter gets closed at r.Header.Del("Accept-Encoding")
// the return of this method!
if status >= 400 { w.Header().Set("Content-Encoding", "gzip")
gz.Header().Set("Content-Type", "text/plain") // very necessary gzipWriter, err := newWriter(c, w)
gz.WriteHeader(status) if err != nil {
fmt.Fprintf(gz, "%d %s", status, http.StatusText(status)) // should not happen
return 0, err return http.StatusInternalServerError, err
}
defer gzipWriter.Close()
gz := gzipResponseWriter{Writer: gzipWriter, ResponseWriter: w}
// Any response in forward middleware will now be compressed
status, err := g.Next.ServeHTTP(gz, r)
// If there was an error that remained unhandled, we need
// to send something back before gzipWriter gets closed at
// the return of this method!
if status >= 400 {
gz.Header().Set("Content-Type", "text/plain") // very necessary
gz.WriteHeader(status)
fmt.Fprintf(gz, "%d %s", status, http.StatusText(status))
return 0, err
}
return status, err
}
// no matching filter
return g.Next.ServeHTTP(w, r)
}
// newWriter create a new Gzip Writer based on the compression level.
// If the level is valid (i.e. between 1 and 9), it uses the level.
// Otherwise, it uses default compression level.
func newWriter(c Config, w http.ResponseWriter) (*gzip.Writer, error) {
if c.Level >= gzip.BestSpeed && c.Level <= gzip.BestCompression {
return gzip.NewWriterLevel(w, c.Level)
} }
return status, err return gzip.NewWriter(w), nil
} }
// gzipResponeWriter wraps the underlying Write method // gzipResponeWriter wraps the underlying Write method
......
package gzip
import (
"fmt"
"net/http"
"net/http/httptest"
"testing"
"github.com/mholt/caddy/middleware"
)
func TestGzipHandler(t *testing.T) {
pathFilter := PathFilter{make(Set)}
badPaths := []string{"/bad", "/nogzip", "/nongzip"}
for _, p := range badPaths {
pathFilter.IgnoredPaths.Add(p)
}
gz := Gzip{Configs: []Config{
Config{Filters: []Filter{DefaultExtFilter(), pathFilter}},
}}
w := httptest.NewRecorder()
gz.Next = nextFunc(true)
for _, e := range textExts {
url := "/file" + e
r, err := http.NewRequest("GET", url, nil)
if err != nil {
t.Error(err)
}
r.Header.Set("Accept-Encoding", "gzip")
_, err = gz.ServeHTTP(w, r)
if err != nil {
t.Error(err)
}
}
w = httptest.NewRecorder()
gz.Next = nextFunc(false)
for _, p := range badPaths {
for _, e := range textExts {
url := p + "/file" + e
r, err := http.NewRequest("GET", url, nil)
if err != nil {
t.Error(err)
}
r.Header.Set("Accept-Encoding", "gzip")
_, err = gz.ServeHTTP(w, r)
if err != nil {
t.Error(err)
}
}
}
w = httptest.NewRecorder()
gz.Next = nextFunc(false)
exts := []string{
".htm1", ".abc", ".mdx",
}
for _, e := range exts {
url := "/file" + e
r, err := http.NewRequest("GET", url, nil)
if err != nil {
t.Error(err)
}
r.Header.Set("Accept-Encoding", "gzip")
_, err = gz.ServeHTTP(w, r)
if err != nil {
t.Error(err)
}
}
}
func nextFunc(shouldGzip bool) middleware.Handler {
return middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) {
if shouldGzip {
if r.Header.Get("Accept-Encoding") != "" {
return 0, fmt.Errorf("Accept-Encoding header not expected")
}
if w.Header().Get("Content-Encoding") != "gzip" {
return 0, fmt.Errorf("Content-Encoding must be gzip, found %v", r.Header.Get("Content-Encoding"))
}
if _, ok := w.(gzipResponseWriter); !ok {
return 0, fmt.Errorf("ResponseWriter should be gzipResponseWriter, found %T", w)
}
return 0, nil
}
if r.Header.Get("Accept-Encoding") == "" {
return 0, fmt.Errorf("Accept-Encoding header expected")
}
if w.Header().Get("Content-Encoding") == "gzip" {
return 0, fmt.Errorf("Content-Encoding must not be gzip, found gzip")
}
if _, ok := w.(gzipResponseWriter); ok {
return 0, fmt.Errorf("ResponseWriter should not be gzipResponseWriter")
}
return 0, nil
})
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment