Commit a8c9d478 authored by Matt Holt's avatar Matt Holt

Merge pull request #404 from abiosoft/master

gzip: support for min_length.
parents aba0ae35 b65ddbc7
...@@ -27,9 +27,13 @@ func gzipParse(c *Controller) ([]gzip.Config, error) { ...@@ -27,9 +27,13 @@ func gzipParse(c *Controller) ([]gzip.Config, error) {
for c.Next() { for c.Next() {
config := gzip.Config{} config := gzip.Config{}
// Request Filters
pathFilter := gzip.PathFilter{IgnoredPaths: make(gzip.Set)} pathFilter := gzip.PathFilter{IgnoredPaths: make(gzip.Set)}
extFilter := gzip.ExtFilter{Exts: make(gzip.Set)} extFilter := gzip.ExtFilter{Exts: make(gzip.Set)}
// Response Filters
lengthFilter := gzip.LengthFilter(0)
// No extra args expected // No extra args expected
if len(c.RemainingArgs()) > 0 { if len(c.RemainingArgs()) > 0 {
return configs, c.ArgErr() return configs, c.ArgErr()
...@@ -68,24 +72,42 @@ func gzipParse(c *Controller) ([]gzip.Config, error) { ...@@ -68,24 +72,42 @@ func gzipParse(c *Controller) ([]gzip.Config, error) {
} }
level, _ := strconv.Atoi(c.Val()) level, _ := strconv.Atoi(c.Val())
config.Level = level config.Level = level
case "min_length":
if !c.NextArg() {
return configs, c.ArgErr()
}
length, err := strconv.ParseInt(c.Val(), 10, 64)
if err != nil {
return configs, err
} else if length == 0 {
return configs, fmt.Errorf(`gzip: min_length must be greater than 0`)
}
lengthFilter = gzip.LengthFilter(length)
default: default:
return configs, c.ArgErr() return configs, c.ArgErr()
} }
} }
config.Filters = []gzip.Filter{} // Request Filters
config.RequestFilters = []gzip.RequestFilter{}
// If ignored paths are specified, put in front to filter with path first // If ignored paths are specified, put in front to filter with path first
if len(pathFilter.IgnoredPaths) > 0 { if len(pathFilter.IgnoredPaths) > 0 {
config.Filters = []gzip.Filter{pathFilter} config.RequestFilters = []gzip.RequestFilter{pathFilter}
} }
// Then, if extensions are specified, use those to filter. // Then, if extensions are specified, use those to filter.
// Otherwise, use default extensions filter. // Otherwise, use default extensions filter.
if len(extFilter.Exts) > 0 { if len(extFilter.Exts) > 0 {
config.Filters = append(config.Filters, extFilter) config.RequestFilters = append(config.RequestFilters, extFilter)
} else { } else {
config.Filters = append(config.Filters, gzip.DefaultExtFilter()) config.RequestFilters = append(config.RequestFilters, gzip.DefaultExtFilter())
}
// Response Filters
// If min_length is specified, use it.
if int64(lengthFilter) != 0 {
config.ResponseFilters = append(config.ResponseFilters, lengthFilter)
} }
configs = append(configs, config) configs = append(configs, config)
......
...@@ -73,6 +73,18 @@ func TestGzip(t *testing.T) { ...@@ -73,6 +73,18 @@ func TestGzip(t *testing.T) {
level 1 level 1
} }
`, false}, `, false},
{`gzip { not /file
ext *
level 1
min_length ab
}
`, true},
{`gzip { not /file
ext *
level 1
min_length 1000
}
`, false},
} }
for i, test := range tests { for i, test := range tests {
c := NewTestController(test.input) c := NewTestController(test.input)
......
...@@ -6,6 +6,7 @@ import ( ...@@ -6,6 +6,7 @@ import (
"compress/gzip" "compress/gzip"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"net/http" "net/http"
"strings" "strings"
...@@ -23,7 +24,8 @@ type Gzip struct { ...@@ -23,7 +24,8 @@ type Gzip struct {
// Config holds the configuration for Gzip middleware // Config holds the configuration for Gzip middleware
type Config struct { type Config struct {
Filters []Filter // Filters to use RequestFilters []RequestFilter
ResponseFilters []ResponseFilter
Level int // Compression level Level int // Compression level
} }
...@@ -36,8 +38,8 @@ func (g Gzip) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { ...@@ -36,8 +38,8 @@ func (g Gzip) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
outer: outer:
for _, c := range g.Configs { for _, c := range g.Configs {
// Check filters to determine if gzipping is permitted for this request // Check request filters to determine if gzipping is permitted for this request
for _, filter := range c.Filters { for _, filter := range c.RequestFilters {
if !filter.ShouldCompress(r) { if !filter.ShouldCompress(r) {
continue outer continue outer
} }
...@@ -46,9 +48,10 @@ outer: ...@@ -46,9 +48,10 @@ outer:
// Delete this header so gzipping is not repeated later in the chain // Delete this header so gzipping is not repeated later in the chain
r.Header.Del("Accept-Encoding") r.Header.Del("Accept-Encoding")
w.Header().Set("Content-Encoding", "gzip") // gzipWriter modifies underlying writer at init,
w.Header().Set("Vary", "Accept-Encoding") // use a discard writer instead to leave ResponseWriter in
gzipWriter, err := newWriter(c, w) // original form.
gzipWriter, err := newWriter(c, ioutil.Discard)
if err != nil { if err != nil {
// should not happen // should not happen
return http.StatusInternalServerError, err return http.StatusInternalServerError, err
...@@ -56,8 +59,19 @@ outer: ...@@ -56,8 +59,19 @@ outer:
defer gzipWriter.Close() defer gzipWriter.Close()
gz := gzipResponseWriter{Writer: gzipWriter, ResponseWriter: w} gz := gzipResponseWriter{Writer: gzipWriter, ResponseWriter: w}
var rw http.ResponseWriter
// if no response filter is used
if len(c.ResponseFilters) == 0 {
// replace discard writer with ResponseWriter
gzipWriter.Reset(w)
rw = gz
} else {
// wrap gzip writer with ResponseFilterWriter
rw = NewResponseFilterWriter(c.ResponseFilters, gz)
}
// Any response in forward middleware will now be compressed // Any response in forward middleware will now be compressed
status, err := g.Next.ServeHTTP(gz, r) status, err := g.Next.ServeHTTP(rw, r)
// If there was an error that remained unhandled, we need // If there was an error that remained unhandled, we need
// to send something back before gzipWriter gets closed at // to send something back before gzipWriter gets closed at
...@@ -78,7 +92,7 @@ outer: ...@@ -78,7 +92,7 @@ outer:
// newWriter create a new Gzip Writer based on the compression level. // newWriter create a new Gzip Writer based on the compression level.
// If the level is valid (i.e. between 1 and 9), it uses the level. // If the level is valid (i.e. between 1 and 9), it uses the level.
// Otherwise, it uses default compression level. // Otherwise, it uses default compression level.
func newWriter(c Config, w http.ResponseWriter) (*gzip.Writer, error) { func newWriter(c Config, w io.Writer) (*gzip.Writer, error) {
if c.Level >= gzip.BestSpeed && c.Level <= gzip.BestCompression { if c.Level >= gzip.BestSpeed && c.Level <= gzip.BestCompression {
return gzip.NewWriterLevel(w, c.Level) return gzip.NewWriterLevel(w, c.Level)
} }
...@@ -98,6 +112,8 @@ type gzipResponseWriter struct { ...@@ -98,6 +112,8 @@ type gzipResponseWriter struct {
// be wrong because it doesn't know it's being gzipped. // be wrong because it doesn't know it's being gzipped.
func (w gzipResponseWriter) WriteHeader(code int) { func (w gzipResponseWriter) WriteHeader(code int) {
w.Header().Del("Content-Length") w.Header().Del("Content-Length")
w.Header().Set("Content-Encoding", "gzip")
w.Header().Set("Vary", "Accept-Encoding")
w.ResponseWriter.WriteHeader(code) w.ResponseWriter.WriteHeader(code)
} }
......
...@@ -21,7 +21,7 @@ func TestGzipHandler(t *testing.T) { ...@@ -21,7 +21,7 @@ func TestGzipHandler(t *testing.T) {
extFilter.Exts.Add(e) extFilter.Exts.Add(e)
} }
gz := Gzip{Configs: []Config{ gz := Gzip{Configs: []Config{
{Filters: []Filter{pathFilter, extFilter}}, {RequestFilters: []RequestFilter{pathFilter, extFilter}},
}} }}
w := httptest.NewRecorder() w := httptest.NewRecorder()
...@@ -80,6 +80,8 @@ func TestGzipHandler(t *testing.T) { ...@@ -80,6 +80,8 @@ func TestGzipHandler(t *testing.T) {
func nextFunc(shouldGzip bool) middleware.Handler { func nextFunc(shouldGzip bool) middleware.Handler {
return middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) { return middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) {
w.WriteHeader(200)
w.Write([]byte("test"))
if shouldGzip { if shouldGzip {
if r.Header.Get("Accept-Encoding") != "" { if r.Header.Get("Accept-Encoding") != "" {
return 0, fmt.Errorf("Accept-Encoding header not expected") return 0, fmt.Errorf("Accept-Encoding header not expected")
......
...@@ -7,8 +7,8 @@ import ( ...@@ -7,8 +7,8 @@ import (
"github.com/mholt/caddy/middleware" "github.com/mholt/caddy/middleware"
) )
// Filter determines if a request should be gzipped. // RequestFilter determines if a request should be gzipped.
type Filter interface { type RequestFilter interface {
// ShouldCompress tells if gzip compression // ShouldCompress tells if gzip compression
// should be done on the request. // should be done on the request.
ShouldCompress(*http.Request) bool ShouldCompress(*http.Request) bool
...@@ -26,7 +26,7 @@ func DefaultExtFilter() ExtFilter { ...@@ -26,7 +26,7 @@ func DefaultExtFilter() ExtFilter {
return m return m
} }
// ExtFilter is Filter for file name extensions. // ExtFilter is RequestFilter for file name extensions.
type ExtFilter struct { type ExtFilter struct {
// Exts is the file name extensions to accept // Exts is the file name extensions to accept
Exts Set Exts Set
...@@ -43,7 +43,7 @@ func (e ExtFilter) ShouldCompress(r *http.Request) bool { ...@@ -43,7 +43,7 @@ func (e ExtFilter) ShouldCompress(r *http.Request) bool {
return e.Exts.Contains(ExtWildCard) || e.Exts.Contains(ext) return e.Exts.Contains(ExtWildCard) || e.Exts.Contains(ext)
} }
// PathFilter is Filter for request path. // PathFilter is RequestFilter for request path.
type PathFilter struct { type PathFilter struct {
// IgnoredPaths is the paths to ignore // IgnoredPaths is the paths to ignore
IgnoredPaths Set IgnoredPaths Set
......
...@@ -47,7 +47,7 @@ func TestSet(t *testing.T) { ...@@ -47,7 +47,7 @@ func TestSet(t *testing.T) {
} }
func TestExtFilter(t *testing.T) { func TestExtFilter(t *testing.T) {
var filter Filter = ExtFilter{make(Set)} var filter RequestFilter = ExtFilter{make(Set)}
for _, e := range []string{".txt", ".html", ".css", ".md"} { for _, e := range []string{".txt", ".html", ".css", ".md"} {
filter.(ExtFilter).Exts.Add(e) filter.(ExtFilter).Exts.Add(e)
} }
...@@ -86,7 +86,7 @@ func TestPathFilter(t *testing.T) { ...@@ -86,7 +86,7 @@ func TestPathFilter(t *testing.T) {
paths := []string{ paths := []string{
"/a", "/b", "/c", "/de", "/a", "/b", "/c", "/de",
} }
var filter Filter = PathFilter{make(Set)} var filter RequestFilter = PathFilter{make(Set)}
for _, p := range paths { for _, p := range paths {
filter.(PathFilter).IgnoredPaths.Add(p) filter.(PathFilter).IgnoredPaths.Add(p)
} }
......
package gzip
import (
"compress/gzip"
"net/http"
"strconv"
)
// ResponseFilter determines if the response should be gzipped.
type ResponseFilter interface {
ShouldCompress(http.ResponseWriter) bool
}
// LengthFilter is ResponseFilter for minimum content length.
type LengthFilter int64
// ShouldCompress returns if content length is greater than or
// equals to minimum length.
func (l LengthFilter) ShouldCompress(w http.ResponseWriter) bool {
contentLength := w.Header().Get("Content-Length")
length, err := strconv.ParseInt(contentLength, 10, 64)
if err != nil || length == 0 {
return false
}
return l != 0 && int64(l) <= length
}
// ResponseFilterWriter validates ResponseFilters. It writes
// gzip compressed data if ResponseFilters are satisfied or
// uncompressed data otherwise.
type ResponseFilterWriter struct {
filters []ResponseFilter
shouldCompress bool
gzipResponseWriter
}
// NewResponseFilterWriter creates and initializes a new ResponseFilterWriter.
func NewResponseFilterWriter(filters []ResponseFilter, gz gzipResponseWriter) *ResponseFilterWriter {
return &ResponseFilterWriter{filters: filters, gzipResponseWriter: gz}
}
// Write wraps underlying Write method and compresses if filters
// are satisfied.
func (r *ResponseFilterWriter) WriteHeader(code int) {
// Determine if compression should be used or not.
r.shouldCompress = true
for _, filter := range r.filters {
if !filter.ShouldCompress(r) {
r.shouldCompress = false
break
}
}
if r.shouldCompress {
// replace discard writer with ResponseWriter
if gzWriter, ok := r.gzipResponseWriter.Writer.(*gzip.Writer); ok {
gzWriter.Reset(r.ResponseWriter)
}
// use gzip WriteHeader to include and delete
// necessary headers
r.gzipResponseWriter.WriteHeader(code)
} else {
r.ResponseWriter.WriteHeader(code)
}
}
// Write wraps underlying Write method and compresses if filters
// are satisfied
func (r *ResponseFilterWriter) Write(b []byte) (int, error) {
if r.shouldCompress {
return r.gzipResponseWriter.Write(b)
}
return r.ResponseWriter.Write(b)
}
package gzip
import (
"compress/gzip"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"github.com/mholt/caddy/middleware"
)
func TestLengthFilter(t *testing.T) {
var filters []ResponseFilter = []ResponseFilter{
LengthFilter(100),
LengthFilter(1000),
LengthFilter(0),
}
var tests = []struct {
length int64
shouldCompress [3]bool
}{
{20, [3]bool{false, false, false}},
{50, [3]bool{false, false, false}},
{100, [3]bool{true, false, false}},
{500, [3]bool{true, false, false}},
{1000, [3]bool{true, true, false}},
{1500, [3]bool{true, true, false}},
}
for i, ts := range tests {
for j, filter := range filters {
r := httptest.NewRecorder()
r.Header().Set("Content-Length", fmt.Sprint(ts.length))
wWriter := NewResponseFilterWriter([]ResponseFilter{filter}, gzipResponseWriter{gzip.NewWriter(r), r})
if filter.ShouldCompress(wWriter) != ts.shouldCompress[j] {
t.Errorf("Test %v: Expected %v found %v", i, ts.shouldCompress[j], filter.ShouldCompress(r))
}
}
}
}
func TestResponseFilterWriter(t *testing.T) {
tests := []struct {
body string
shouldCompress bool
}{
{"Hello\t\t\t\n", false},
{"Hello the \t\t\t world is\n\n\n great", true},
{"Hello \t\t\nfrom gzip", true},
{"Hello gzip\n", false},
}
filters := []ResponseFilter{
LengthFilter(15),
}
server := Gzip{Configs: []Config{
{ResponseFilters: filters},
}}
for i, ts := range tests {
server.Next = middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) {
w.Header().Set("Content-Length", fmt.Sprint(len(ts.body)))
w.WriteHeader(200)
w.Write([]byte(ts.body))
return 200, nil
})
r := urlRequest("/")
r.Header.Set("Accept-Encoding", "gzip")
w := httptest.NewRecorder()
server.ServeHTTP(w, r)
resp := w.Body.String()
if !ts.shouldCompress {
if resp != ts.body {
t.Errorf("Test %v: No compression expected, found %v", i, resp)
}
} else {
if resp == ts.body {
t.Errorf("Test %v: Compression expected, found %v", i, resp)
}
}
}
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment