Commit 23631cfa authored by Abiola Ibrahim's avatar Abiola Ibrahim

Fix deleted Content-Length header bug.

parent 8631f339
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
package gzip package gzip
import ( import (
"bytes"
"compress/gzip" "compress/gzip"
"fmt" "fmt"
"io" "io"
...@@ -47,9 +48,13 @@ outer: ...@@ -47,9 +48,13 @@ outer:
// Delete this header so gzipping is not repeated later in the chain // Delete this header so gzipping is not repeated later in the chain
r.Header.Del("Accept-Encoding") r.Header.Del("Accept-Encoding")
w.Header().Set("Content-Encoding", "gzip") // gzipWriter modifies underlying writer at init,
w.Header().Set("Vary", "Accept-Encoding") // use a buffer instead to leave ResponseWriter in
gzipWriter, err := newWriter(c, w) // original form.
var buf = &bytes.Buffer{}
defer buf.Reset()
gzipWriter, err := newWriter(c, buf)
if err != nil { if err != nil {
// should not happen // should not happen
return http.StatusInternalServerError, err return http.StatusInternalServerError, err
...@@ -60,6 +65,8 @@ outer: ...@@ -60,6 +65,8 @@ outer:
var rw http.ResponseWriter var rw http.ResponseWriter
// if no response filter is used // if no response filter is used
if len(c.ResponseFilters) == 0 { if len(c.ResponseFilters) == 0 {
// replace buffer with ResponseWriter
gzipWriter.Reset(w)
rw = gz rw = gz
} else { } else {
// wrap gzip writer with ResponseFilterWriter // wrap gzip writer with ResponseFilterWriter
...@@ -88,7 +95,7 @@ outer: ...@@ -88,7 +95,7 @@ outer:
// newWriter create a new Gzip Writer based on the compression level. // newWriter create a new Gzip Writer based on the compression level.
// If the level is valid (i.e. between 1 and 9), it uses the level. // If the level is valid (i.e. between 1 and 9), it uses the level.
// Otherwise, it uses default compression level. // Otherwise, it uses default compression level.
func newWriter(c Config, w http.ResponseWriter) (*gzip.Writer, error) { func newWriter(c Config, w io.Writer) (*gzip.Writer, error) {
if c.Level >= gzip.BestSpeed && c.Level <= gzip.BestCompression { if c.Level >= gzip.BestSpeed && c.Level <= gzip.BestCompression {
return gzip.NewWriterLevel(w, c.Level) return gzip.NewWriterLevel(w, c.Level)
} }
...@@ -108,6 +115,8 @@ type gzipResponseWriter struct { ...@@ -108,6 +115,8 @@ type gzipResponseWriter struct {
// be wrong because it doesn't know it's being gzipped. // be wrong because it doesn't know it's being gzipped.
func (w gzipResponseWriter) WriteHeader(code int) { func (w gzipResponseWriter) WriteHeader(code int) {
w.Header().Del("Content-Length") w.Header().Del("Content-Length")
w.Header().Set("Content-Encoding", "gzip")
w.Header().Set("Vary", "Accept-Encoding")
w.ResponseWriter.WriteHeader(code) w.ResponseWriter.WriteHeader(code)
} }
......
...@@ -80,6 +80,8 @@ func TestGzipHandler(t *testing.T) { ...@@ -80,6 +80,8 @@ func TestGzipHandler(t *testing.T) {
func nextFunc(shouldGzip bool) middleware.Handler { func nextFunc(shouldGzip bool) middleware.Handler {
return middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) { return middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) {
w.WriteHeader(200)
w.Write([]byte("test"))
if shouldGzip { if shouldGzip {
if r.Header.Get("Accept-Encoding") != "" { if r.Header.Get("Accept-Encoding") != "" {
return 0, fmt.Errorf("Accept-Encoding header not expected") return 0, fmt.Errorf("Accept-Encoding header not expected")
......
package gzip package gzip
import ( import (
"compress/gzip"
"net/http" "net/http"
"strconv" "strconv"
) )
...@@ -29,7 +30,6 @@ func (l LengthFilter) ShouldCompress(w http.ResponseWriter) bool { ...@@ -29,7 +30,6 @@ func (l LengthFilter) ShouldCompress(w http.ResponseWriter) bool {
// uncompressed data otherwise. // uncompressed data otherwise.
type ResponseFilterWriter struct { type ResponseFilterWriter struct {
filters []ResponseFilter filters []ResponseFilter
validated bool
shouldCompress bool shouldCompress bool
gzipResponseWriter gzipResponseWriter
} }
...@@ -40,11 +40,9 @@ func NewResponseFilterWriter(filters []ResponseFilter, gz gzipResponseWriter) *R ...@@ -40,11 +40,9 @@ func NewResponseFilterWriter(filters []ResponseFilter, gz gzipResponseWriter) *R
} }
// Write wraps underlying Write method and compresses if filters // Write wraps underlying Write method and compresses if filters
// are satisfied // are satisfied.
func (r *ResponseFilterWriter) Write(b []byte) (int, error) { func (r *ResponseFilterWriter) WriteHeader(code int) {
// One time validation to determine if compression should // Determine if compression should be used or not.
// be used or not.
if !r.validated {
r.shouldCompress = true r.shouldCompress = true
for _, filter := range r.filters { for _, filter := range r.filters {
if !filter.ShouldCompress(r) { if !filter.ShouldCompress(r) {
...@@ -52,9 +50,23 @@ func (r *ResponseFilterWriter) Write(b []byte) (int, error) { ...@@ -52,9 +50,23 @@ func (r *ResponseFilterWriter) Write(b []byte) (int, error) {
break break
} }
} }
r.validated = true
if r.shouldCompress {
// replace buffer with ResponseWriter
if gzWriter, ok := r.gzipResponseWriter.Writer.(*gzip.Writer); ok {
gzWriter.Reset(r.ResponseWriter)
}
// use gzip WriteHeader to include and delete
// necessary headers
r.gzipResponseWriter.WriteHeader(code)
} else {
r.ResponseWriter.WriteHeader(code)
} }
}
// Write wraps underlying Write method and compresses if filters
// are satisfied
func (r *ResponseFilterWriter) Write(b []byte) (int, error) {
if r.shouldCompress { if r.shouldCompress {
return r.gzipResponseWriter.Write(b) return r.gzipResponseWriter.Write(b)
} }
......
...@@ -3,8 +3,11 @@ package gzip ...@@ -3,8 +3,11 @@ package gzip
import ( import (
"compress/gzip" "compress/gzip"
"fmt" "fmt"
"net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"github.com/mholt/caddy/middleware"
) )
func TestLengthFilter(t *testing.T) { func TestLengthFilter(t *testing.T) {
...@@ -30,7 +33,8 @@ func TestLengthFilter(t *testing.T) { ...@@ -30,7 +33,8 @@ func TestLengthFilter(t *testing.T) {
for j, filter := range filters { for j, filter := range filters {
r := httptest.NewRecorder() r := httptest.NewRecorder()
r.Header().Set("Content-Length", fmt.Sprint(ts.length)) r.Header().Set("Content-Length", fmt.Sprint(ts.length))
if filter.ShouldCompress(r) != ts.shouldCompress[j] { wWriter := NewResponseFilterWriter([]ResponseFilter{filter}, gzipResponseWriter{gzip.NewWriter(r), r})
if filter.ShouldCompress(wWriter) != ts.shouldCompress[j] {
t.Errorf("Test %v: Expected %v found %v", i, ts.shouldCompress[j], filter.ShouldCompress(r)) t.Errorf("Test %v: Expected %v found %v", i, ts.shouldCompress[j], filter.ShouldCompress(r))
} }
} }
...@@ -47,16 +51,32 @@ func TestResponseFilterWriter(t *testing.T) { ...@@ -47,16 +51,32 @@ func TestResponseFilterWriter(t *testing.T) {
{"Hello \t\t\nfrom gzip", true}, {"Hello \t\t\nfrom gzip", true},
{"Hello gzip\n", false}, {"Hello gzip\n", false},
} }
filters := []ResponseFilter{ filters := []ResponseFilter{
LengthFilter(15), LengthFilter(15),
} }
server := Gzip{Configs: []Config{
{ResponseFilters: filters},
}}
for i, ts := range tests { for i, ts := range tests {
w := httptest.NewRecorder() server.Next = middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) {
w.Header().Set("Content-Length", fmt.Sprint(len(ts.body))) w.Header().Set("Content-Length", fmt.Sprint(len(ts.body)))
gz := gzipResponseWriter{gzip.NewWriter(w), w} w.WriteHeader(200)
rw := NewResponseFilterWriter(filters, gz) w.Write([]byte(ts.body))
rw.Write([]byte(ts.body)) return 200, nil
})
r := urlRequest("/")
r.Header.Set("Accept-Encoding", "gzip")
w := httptest.NewRecorder()
server.ServeHTTP(w, r)
resp := w.Body.String() resp := w.Body.String()
if !ts.shouldCompress { if !ts.shouldCompress {
if resp != ts.body { if resp != ts.body {
t.Errorf("Test %v: No compression expected, found %v", i, resp) t.Errorf("Test %v: No compression expected, found %v", i, resp)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment