Commit e0f10c2b authored by Abiola Ibrahim's avatar Abiola Ibrahim

Gzip: Accept MIME types.

parent 01aca02e
...@@ -28,7 +28,8 @@ func gzipParse(c *Controller) ([]gzip.Config, error) { ...@@ -28,7 +28,8 @@ func gzipParse(c *Controller) ([]gzip.Config, error) {
config := gzip.Config{} config := gzip.Config{}
pathFilter := gzip.PathFilter{make(gzip.Set)} pathFilter := gzip.PathFilter{make(gzip.Set)}
extFilter := gzip.DefaultExtFilter() mimeFilter := gzip.MIMEFilter{make(gzip.Set)}
extFilter := gzip.ExtFilter{make(gzip.Set)}
// no extra args expected // no extra args expected
if len(c.RemainingArgs()) > 0 { if len(c.RemainingArgs()) > 0 {
...@@ -37,6 +38,17 @@ func gzipParse(c *Controller) ([]gzip.Config, error) { ...@@ -37,6 +38,17 @@ func gzipParse(c *Controller) ([]gzip.Config, error) {
for c.NextBlock() { for c.NextBlock() {
switch c.Val() { switch c.Val() {
case "mimes":
mimes := c.RemainingArgs()
if len(mimes) == 0 {
return configs, c.ArgErr()
}
for _, m := range mimes {
if !gzip.ValidMIME(m) {
return configs, fmt.Errorf("Invalid MIME %v.", m)
}
mimeFilter.Types.Add(m)
}
case "ext": case "ext":
exts := c.RemainingArgs() exts := c.RemainingArgs()
if len(exts) == 0 { if len(exts) == 0 {
...@@ -74,8 +86,25 @@ func gzipParse(c *Controller) ([]gzip.Config, error) { ...@@ -74,8 +86,25 @@ func gzipParse(c *Controller) ([]gzip.Config, error) {
} }
} }
// put pathFilter in front to filter with path first config.Filters = []gzip.Filter{}
config.Filters = []gzip.Filter{pathFilter, extFilter}
// if ignored paths are specified, put in front to filter with path first
if len(pathFilter.IgnoredPaths) > 0 {
config.Filters = []gzip.Filter{pathFilter}
}
// if mime types are specified, use it and ignore extensions
if len(mimeFilter.Types) > 0 {
config.Filters = append(config.Filters, mimeFilter)
// if extensions are specified, use it
} else if len(extFilter.Exts) > 0 {
config.Filters = append(config.Filters, extFilter)
// neither is specified, use default mime types
} else {
config.Filters = append(config.Filters, gzip.DefaultMIMEFilter())
}
configs = append(configs, config) configs = append(configs, config)
} }
......
...@@ -59,14 +59,35 @@ func TestGzip(t *testing.T) { ...@@ -59,14 +59,35 @@ func TestGzip(t *testing.T) {
level 3 level 3
} }
`, false}, `, false},
{`gzip { mimes text/html
}`, false},
{`gzip { mimes text/html application/json
}`, false},
{`gzip { mimes text/html application/
}`, true},
{`gzip { mimes text/html /json
}`, true},
{`gzip { mimes /json text/html
}`, true},
{`gzip { not /file
ext .html
level 1
mimes text/html text/plain
}
gzip { not /file1
ext .htm
level 3
mimes text/html text/css
}
`, false},
} }
for i, test := range tests { for i, test := range tests {
c := newTestController(test.input) c := newTestController(test.input)
_, err := gzipParse(c) _, err := gzipParse(c)
if test.shouldErr && err == nil { if test.shouldErr && err == nil {
t.Errorf("Text %v: Expected error but found nil", i) t.Errorf("Test %v: Expected error but found nil", i)
} else if !test.shouldErr && err != nil { } else if !test.shouldErr && err != nil {
t.Errorf("Text %v: Expected no error but found error: ", i, err) t.Errorf("Test %v: Expected no error but found error: %v", i, err)
} }
} }
} }
...@@ -3,13 +3,14 @@ package gzip ...@@ -3,13 +3,14 @@ package gzip
import ( import (
"net/http" "net/http"
"path" "path"
"strings"
"github.com/mholt/caddy/middleware" "github.com/mholt/caddy/middleware"
) )
// Filter determines if a request should be gzipped. // Filter determines if a request should be gzipped.
type Filter interface { type Filter interface {
// ShouldCompress tells if compression gzip compression // ShouldCompress tells if gzip compression
// should be done on the request. // should be done on the request.
ShouldCompress(*http.Request) bool ShouldCompress(*http.Request) bool
} }
...@@ -20,24 +21,12 @@ type ExtFilter struct { ...@@ -20,24 +21,12 @@ type ExtFilter struct {
Exts Set Exts Set
} }
// textExts is a list of extensions for text related files.
var textExts = []string{
".html", ".htm", ".css", ".json", ".php", ".js", ".txt", ".md", ".xml",
}
// extWildCard is the wildcard for extensions. // extWildCard is the wildcard for extensions.
const extWildCard = "*" const extWildCard = "*"
// DefaultExtFilter creates a default ExtFilter with // ShouldCompress checks if the request file extension matches any
// file extensions for text types. // of the registered extensions. It returns true if the extension is
func DefaultExtFilter() ExtFilter { // found and false otherwise.
e := ExtFilter{make(Set)}
for _, ext := range textExts {
e.Exts.Add(ext)
}
return e
}
func (e ExtFilter) ShouldCompress(r *http.Request) bool { func (e ExtFilter) ShouldCompress(r *http.Request) bool {
ext := path.Ext(r.URL.Path) ext := path.Ext(r.URL.Path)
return e.Exts.Contains(extWildCard) || e.Exts.Contains(ext) return e.Exts.Contains(extWildCard) || e.Exts.Contains(ext)
...@@ -50,7 +39,7 @@ type PathFilter struct { ...@@ -50,7 +39,7 @@ type PathFilter struct {
} }
// ShouldCompress checks if the request path matches any of the // ShouldCompress checks if the request path matches any of the
// registered paths to ignore. If returns false if an ignored path // registered paths to ignore. It returns false if an ignored path
// is found and true otherwise. // is found and true otherwise.
func (p PathFilter) ShouldCompress(r *http.Request) bool { func (p PathFilter) ShouldCompress(r *http.Request) bool {
return !p.IgnoredPaths.ContainsFunc(func(value string) bool { return !p.IgnoredPaths.ContainsFunc(func(value string) bool {
...@@ -58,6 +47,39 @@ func (p PathFilter) ShouldCompress(r *http.Request) bool { ...@@ -58,6 +47,39 @@ func (p PathFilter) ShouldCompress(r *http.Request) bool {
}) })
} }
// MIMEFilter is Filter for request content types.
type MIMEFilter struct {
// Types is the MIME types to accept.
Types Set
}
// defaultMIMETypes is the list of default MIME types to use.
var defaultMIMETypes = []string{
"text/plain", "text/html", "text/css", "application/json", "application/javascript",
"text/x-markdown", "text/xml", "application/xml",
}
// DefaultMIMEFilter creates a MIMEFilter with default types.
func DefaultMIMEFilter() MIMEFilter {
m := MIMEFilter{Types: make(Set)}
for _, mime := range defaultMIMETypes {
m.Types.Add(mime)
}
return m
}
// ShouldCompress checks if the content type of the request
// matches any of the registered ones. It returns true if
// found and false otherwise.
func (m MIMEFilter) ShouldCompress(r *http.Request) bool {
return m.Types.Contains(r.Header.Get("Content-Type"))
}
func ValidMIME(mime string) bool {
s := strings.Split(mime, "/")
return len(s) == 2 && strings.TrimSpace(s[0]) != "" && strings.TrimSpace(s[1]) != ""
}
// Set stores distinct strings. // Set stores distinct strings.
type Set map[string]struct{} type Set map[string]struct{}
......
...@@ -47,13 +47,13 @@ func TestSet(t *testing.T) { ...@@ -47,13 +47,13 @@ func TestSet(t *testing.T) {
} }
func TestExtFilter(t *testing.T) { func TestExtFilter(t *testing.T) {
var filter Filter = DefaultExtFilter() var filter Filter = ExtFilter{make(Set)}
_ = filter.(ExtFilter) for _, e := range []string{".txt", ".html", ".css", ".md"} {
for i, e := range textExts { filter.(ExtFilter).Exts.Add(e)
r := urlRequest("file" + e) }
if !filter.ShouldCompress(r) { r := urlRequest("file.txt")
t.Errorf("Test %v: Should be valid filter", i) if !filter.ShouldCompress(r) {
} t.Errorf("Should be valid filter")
} }
var exts = []string{ var exts = []string{
".html", ".css", ".md", ".html", ".css", ".md",
...@@ -100,6 +100,32 @@ func TestPathFilter(t *testing.T) { ...@@ -100,6 +100,32 @@ func TestPathFilter(t *testing.T) {
} }
} }
func TestMIMEFilter(t *testing.T) {
var filter Filter = DefaultMIMEFilter()
_ = filter.(MIMEFilter)
var mimes = []string{
"text/html", "text/css", "application/json",
}
for i, m := range mimes {
r := urlRequest("file" + m)
r.Header.Set("Content-Type", m)
if !filter.ShouldCompress(r) {
t.Errorf("Test %v: Should be valid filter", i)
}
}
mimes = []string{
"image/jpeg", "image/png",
}
filter = DefaultMIMEFilter()
for i, m := range mimes {
r := urlRequest("file" + m)
r.Header.Set("Content-Type", m)
if filter.ShouldCompress(r) {
t.Errorf("Test %v: Should not be valid filter", i)
}
}
}
func urlRequest(url string) *http.Request { func urlRequest(url string) *http.Request {
r, _ := http.NewRequest("GET", url, nil) r, _ := http.NewRequest("GET", url, nil)
return r return r
......
...@@ -16,13 +16,20 @@ func TestGzipHandler(t *testing.T) { ...@@ -16,13 +16,20 @@ func TestGzipHandler(t *testing.T) {
for _, p := range badPaths { for _, p := range badPaths {
pathFilter.IgnoredPaths.Add(p) pathFilter.IgnoredPaths.Add(p)
} }
extFilter := ExtFilter{make(Set)}
for _, e := range []string{".txt", ".html", ".css", ".md"} {
extFilter.Exts.Add(e)
}
gz := Gzip{Configs: []Config{ gz := Gzip{Configs: []Config{
Config{Filters: []Filter{DefaultExtFilter(), pathFilter}}, Config{Filters: []Filter{pathFilter, extFilter}},
}} }}
w := httptest.NewRecorder() w := httptest.NewRecorder()
gz.Next = nextFunc(true) gz.Next = nextFunc(true)
for _, e := range textExts { var exts = []string{
".html", ".css", ".md",
}
for _, e := range exts {
url := "/file" + e url := "/file" + e
r, err := http.NewRequest("GET", url, nil) r, err := http.NewRequest("GET", url, nil)
if err != nil { if err != nil {
...@@ -38,7 +45,7 @@ func TestGzipHandler(t *testing.T) { ...@@ -38,7 +45,7 @@ func TestGzipHandler(t *testing.T) {
w = httptest.NewRecorder() w = httptest.NewRecorder()
gz.Next = nextFunc(false) gz.Next = nextFunc(false)
for _, p := range badPaths { for _, p := range badPaths {
for _, e := range textExts { for _, e := range exts {
url := p + "/file" + e url := p + "/file" + e
r, err := http.NewRequest("GET", url, nil) r, err := http.NewRequest("GET", url, nil)
if err != nil { if err != nil {
...@@ -54,7 +61,7 @@ func TestGzipHandler(t *testing.T) { ...@@ -54,7 +61,7 @@ func TestGzipHandler(t *testing.T) {
w = httptest.NewRecorder() w = httptest.NewRecorder()
gz.Next = nextFunc(false) gz.Next = nextFunc(false)
exts := []string{ exts = []string{
".htm1", ".abc", ".mdx", ".htm1", ".abc", ".mdx",
} }
for _, e := range exts { for _, e := range exts {
...@@ -70,6 +77,45 @@ func TestGzipHandler(t *testing.T) { ...@@ -70,6 +77,45 @@ func TestGzipHandler(t *testing.T) {
} }
} }
gz.Configs[0].Filters[1] = DefaultMIMEFilter()
w = httptest.NewRecorder()
gz.Next = nextFunc(true)
var mimes = []string{
"text/html", "text/css", "application/json",
}
for _, m := range mimes {
url := "/file"
r, err := http.NewRequest("GET", url, nil)
if err != nil {
t.Error(err)
}
r.Header.Set("Content-Type", m)
r.Header.Set("Accept-Encoding", "gzip")
_, err = gz.ServeHTTP(w, r)
if err != nil {
t.Error(err)
}
}
w = httptest.NewRecorder()
gz.Next = nextFunc(false)
mimes = []string{
"image/jpeg", "image/png",
}
for _, m := range mimes {
url := "/file"
r, err := http.NewRequest("GET", url, nil)
if err != nil {
t.Error(err)
}
r.Header.Set("Content-Type", m)
r.Header.Set("Accept-Encoding", "gzip")
_, err = gz.ServeHTTP(w, r)
if err != nil {
t.Error(err)
}
}
} }
func nextFunc(shouldGzip bool) middleware.Handler { func nextFunc(shouldGzip bool) middleware.Handler {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment