Commit c1082be1 authored by Patrick Bajao's avatar Patrick Bajao

Merge branch 'jv-router-escape-path-8-64' into '8-64-stable'

Use URL.EscapePath() in upstream router (8-64-stable)

See merge request gitlab-org/security/gitlab-workhorse!29
parents d2abf271 e2b0d449
---
title: Use URL.EscapePath() in upstream router
merge_request:
author:
type: security
...@@ -23,7 +23,7 @@ func TestIfNoDeployPageExist(t *testing.T) { ...@@ -23,7 +23,7 @@ func TestIfNoDeployPageExist(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
executed := false executed := false
st := &Static{dir} st := &Static{DocumentRoot: dir}
st.DeployPage(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { st.DeployPage(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
executed = true executed = true
})).ServeHTTP(w, nil) })).ServeHTTP(w, nil)
...@@ -45,7 +45,7 @@ func TestIfDeployPageExist(t *testing.T) { ...@@ -45,7 +45,7 @@ func TestIfDeployPageExist(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
executed := false executed := false
st := &Static{dir} st := &Static{DocumentRoot: dir}
st.DeployPage(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { st.DeployPage(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
executed = true executed = true
})).ServeHTTP(w, nil) })).ServeHTTP(w, nil)
......
...@@ -32,7 +32,7 @@ func TestIfErrorPageIsPresented(t *testing.T) { ...@@ -32,7 +32,7 @@ func TestIfErrorPageIsPresented(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, len(upstreamBody), n, "bytes written") require.Equal(t, len(upstreamBody), n, "bytes written")
}) })
st := &Static{dir} st := &Static{DocumentRoot: dir}
st.ErrorPagesUnless(false, ErrorFormatHTML, h).ServeHTTP(w, nil) st.ErrorPagesUnless(false, ErrorFormatHTML, h).ServeHTTP(w, nil)
w.Flush() w.Flush()
...@@ -54,7 +54,7 @@ func TestIfErrorPassedIfNoErrorPageIsFound(t *testing.T) { ...@@ -54,7 +54,7 @@ func TestIfErrorPassedIfNoErrorPageIsFound(t *testing.T) {
w.WriteHeader(404) w.WriteHeader(404)
fmt.Fprint(w, errorResponse) fmt.Fprint(w, errorResponse)
}) })
st := &Static{dir} st := &Static{DocumentRoot: dir}
st.ErrorPagesUnless(false, ErrorFormatHTML, h).ServeHTTP(w, nil) st.ErrorPagesUnless(false, ErrorFormatHTML, h).ServeHTTP(w, nil)
w.Flush() w.Flush()
...@@ -78,7 +78,7 @@ func TestIfErrorPageIsIgnoredInDevelopment(t *testing.T) { ...@@ -78,7 +78,7 @@ func TestIfErrorPageIsIgnoredInDevelopment(t *testing.T) {
w.WriteHeader(500) w.WriteHeader(500)
fmt.Fprint(w, serverError) fmt.Fprint(w, serverError)
}) })
st := &Static{dir} st := &Static{DocumentRoot: dir}
st.ErrorPagesUnless(true, ErrorFormatHTML, h).ServeHTTP(w, nil) st.ErrorPagesUnless(true, ErrorFormatHTML, h).ServeHTTP(w, nil)
w.Flush() w.Flush()
require.Equal(t, 500, w.Code) require.Equal(t, 500, w.Code)
...@@ -102,7 +102,7 @@ func TestIfErrorPageIsIgnoredIfCustomError(t *testing.T) { ...@@ -102,7 +102,7 @@ func TestIfErrorPageIsIgnoredIfCustomError(t *testing.T) {
w.WriteHeader(500) w.WriteHeader(500)
fmt.Fprint(w, serverError) fmt.Fprint(w, serverError)
}) })
st := &Static{dir} st := &Static{DocumentRoot: dir}
st.ErrorPagesUnless(false, ErrorFormatHTML, h).ServeHTTP(w, nil) st.ErrorPagesUnless(false, ErrorFormatHTML, h).ServeHTTP(w, nil)
w.Flush() w.Flush()
require.Equal(t, 500, w.Code) require.Equal(t, 500, w.Code)
...@@ -137,7 +137,7 @@ func TestErrorPageInterceptedByContentType(t *testing.T) { ...@@ -137,7 +137,7 @@ func TestErrorPageInterceptedByContentType(t *testing.T) {
w.WriteHeader(500) w.WriteHeader(500)
fmt.Fprint(w, serverError) fmt.Fprint(w, serverError)
}) })
st := &Static{dir} st := &Static{DocumentRoot: dir}
st.ErrorPagesUnless(false, ErrorFormatHTML, h).ServeHTTP(w, nil) st.ErrorPagesUnless(false, ErrorFormatHTML, h).ServeHTTP(w, nil)
w.Flush() w.Flush()
require.Equal(t, 500, w.Code) require.Equal(t, 500, w.Code)
...@@ -161,7 +161,7 @@ func TestIfErrorPageIsPresentedJSON(t *testing.T) { ...@@ -161,7 +161,7 @@ func TestIfErrorPageIsPresentedJSON(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, len(upstreamBody), n, "bytes written") require.Equal(t, len(upstreamBody), n, "bytes written")
}) })
st := &Static{""} st := &Static{}
st.ErrorPagesUnless(false, ErrorFormatJSON, h).ServeHTTP(w, nil) st.ErrorPagesUnless(false, ErrorFormatJSON, h).ServeHTTP(w, nil)
w.Flush() w.Flush()
...@@ -181,7 +181,7 @@ func TestIfErrorPageIsPresentedText(t *testing.T) { ...@@ -181,7 +181,7 @@ func TestIfErrorPageIsPresentedText(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, len(upstreamBody), n, "bytes written") require.Equal(t, len(upstreamBody), n, "bytes written")
}) })
st := &Static{""} st := &Static{}
st.ErrorPagesUnless(false, ErrorFormatText, h).ServeHTTP(w, nil) st.ErrorPagesUnless(false, ErrorFormatText, h).ServeHTTP(w, nil)
w.Flush() w.Flush()
......
package staticpages package staticpages
import ( import (
"errors"
"fmt"
"net/http" "net/http"
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
"time" "time"
"gitlab.com/gitlab-org/labkit/log"
"gitlab.com/gitlab-org/labkit/mask" "gitlab.com/gitlab-org/labkit/mask"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/helper" "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/log"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/urlprefix" "gitlab.com/gitlab-org/gitlab-workhorse/internal/urlprefix"
) )
...@@ -26,21 +28,29 @@ const ( ...@@ -26,21 +28,29 @@ const (
// upstream. // upstream.
func (s *Static) ServeExisting(prefix urlprefix.Prefix, cache CacheMode, notFoundHandler http.Handler) http.Handler { func (s *Static) ServeExisting(prefix urlprefix.Prefix, cache CacheMode, notFoundHandler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
file := filepath.Join(s.DocumentRoot, prefix.Strip(r.URL.Path)) if notFoundHandler == nil {
notFoundHandler = http.HandlerFunc(http.NotFound)
}
// We intentionally use r.URL.Path instead of r.URL.EscaptedPath() below.
// This is to make it possible to serve static files with e.g. a space
// %20 in their name.
relativePath, err := s.validatePath(prefix.Strip(r.URL.Path))
if err != nil {
log.WithRequest(r).WithError(err).Error()
notFoundHandler.ServeHTTP(w, r)
return
}
// The filepath.Join does Clean traversing directories up file := filepath.Join(s.DocumentRoot, relativePath)
if !strings.HasPrefix(file, s.DocumentRoot) { if !strings.HasPrefix(file, s.DocumentRoot) {
helper.Fail500(w, r, &os.PathError{ log.WithRequest(r).WithError(errPathTraversal).Error()
Op: "open", notFoundHandler.ServeHTTP(w, r)
Path: file,
Err: os.ErrInvalid,
})
return return
} }
var content *os.File var content *os.File
var fi os.FileInfo var fi os.FileInfo
var err error
// Serve pre-gzipped assets // Serve pre-gzipped assets
if acceptEncoding := r.Header.Get("Accept-Encoding"); strings.Contains(acceptEncoding, "gzip") { if acceptEncoding := r.Header.Get("Accept-Encoding"); strings.Contains(acceptEncoding, "gzip") {
...@@ -55,11 +65,7 @@ func (s *Static) ServeExisting(prefix urlprefix.Prefix, cache CacheMode, notFoun ...@@ -55,11 +65,7 @@ func (s *Static) ServeExisting(prefix urlprefix.Prefix, cache CacheMode, notFoun
content, fi, err = helper.OpenFile(file) content, fi, err = helper.OpenFile(file)
} }
if err != nil { if err != nil {
if notFoundHandler != nil { notFoundHandler.ServeHTTP(w, r)
notFoundHandler.ServeHTTP(w, r)
} else {
http.NotFound(w, r)
}
return return
} }
defer content.Close() defer content.Close()
...@@ -82,3 +88,17 @@ func (s *Static) ServeExisting(prefix urlprefix.Prefix, cache CacheMode, notFoun ...@@ -82,3 +88,17 @@ func (s *Static) ServeExisting(prefix urlprefix.Prefix, cache CacheMode, notFoun
http.ServeContent(w, r, filepath.Base(file), fi.ModTime(), content) http.ServeContent(w, r, filepath.Base(file), fi.ModTime(), content)
}) })
} }
var errPathTraversal = errors.New("path traversal")
func (s *Static) validatePath(filename string) (string, error) {
filename = filepath.Clean(filename)
for _, exc := range s.Exclude {
if strings.HasPrefix(filename, exc) {
return "", fmt.Errorf("file is excluded: %s", exc)
}
}
return filename, nil
}
...@@ -20,7 +20,7 @@ func TestServingNonExistingFile(t *testing.T) { ...@@ -20,7 +20,7 @@ func TestServingNonExistingFile(t *testing.T) {
httpRequest, _ := http.NewRequest("GET", "/file", nil) httpRequest, _ := http.NewRequest("GET", "/file", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
st := &Static{dir} st := &Static{DocumentRoot: dir}
st.ServeExisting("/", CacheDisabled, nil).ServeHTTP(w, httpRequest) st.ServeExisting("/", CacheDisabled, nil).ServeHTTP(w, httpRequest)
require.Equal(t, 404, w.Code) require.Equal(t, 404, w.Code)
} }
...@@ -34,7 +34,7 @@ func TestServingDirectory(t *testing.T) { ...@@ -34,7 +34,7 @@ func TestServingDirectory(t *testing.T) {
httpRequest, _ := http.NewRequest("GET", "/file", nil) httpRequest, _ := http.NewRequest("GET", "/file", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
st := &Static{dir} st := &Static{DocumentRoot: dir}
st.ServeExisting("/", CacheDisabled, nil).ServeHTTP(w, httpRequest) st.ServeExisting("/", CacheDisabled, nil).ServeHTTP(w, httpRequest)
require.Equal(t, 404, w.Code) require.Equal(t, 404, w.Code)
} }
...@@ -44,7 +44,7 @@ func TestServingMalformedUri(t *testing.T) { ...@@ -44,7 +44,7 @@ func TestServingMalformedUri(t *testing.T) {
httpRequest, _ := http.NewRequest("GET", "/../../../static/file", nil) httpRequest, _ := http.NewRequest("GET", "/../../../static/file", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
st := &Static{dir} st := &Static{DocumentRoot: dir}
st.ServeExisting("/", CacheDisabled, nil).ServeHTTP(w, httpRequest) st.ServeExisting("/", CacheDisabled, nil).ServeHTTP(w, httpRequest)
require.Equal(t, 404, w.Code) require.Equal(t, 404, w.Code)
} }
...@@ -54,7 +54,7 @@ func TestExecutingHandlerWhenNoFileFound(t *testing.T) { ...@@ -54,7 +54,7 @@ func TestExecutingHandlerWhenNoFileFound(t *testing.T) {
httpRequest, _ := http.NewRequest("GET", "/file", nil) httpRequest, _ := http.NewRequest("GET", "/file", nil)
executed := false executed := false
st := &Static{dir} st := &Static{DocumentRoot: dir}
st.ServeExisting("/", CacheDisabled, http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { st.ServeExisting("/", CacheDisabled, http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
executed = (r == httpRequest) executed = (r == httpRequest)
})).ServeHTTP(nil, httpRequest) })).ServeHTTP(nil, httpRequest)
...@@ -76,7 +76,7 @@ func TestServingTheActualFile(t *testing.T) { ...@@ -76,7 +76,7 @@ func TestServingTheActualFile(t *testing.T) {
ioutil.WriteFile(filepath.Join(dir, "file"), []byte(fileContent), 0600) ioutil.WriteFile(filepath.Join(dir, "file"), []byte(fileContent), 0600)
w := httptest.NewRecorder() w := httptest.NewRecorder()
st := &Static{dir} st := &Static{DocumentRoot: dir}
st.ServeExisting("/", CacheDisabled, nil).ServeHTTP(w, httpRequest) st.ServeExisting("/", CacheDisabled, nil).ServeHTTP(w, httpRequest)
require.Equal(t, 200, w.Code) require.Equal(t, 200, w.Code)
if w.Body.String() != fileContent { if w.Body.String() != fileContent {
...@@ -84,6 +84,40 @@ func TestServingTheActualFile(t *testing.T) { ...@@ -84,6 +84,40 @@ func TestServingTheActualFile(t *testing.T) {
} }
} }
func TestExcludedPaths(t *testing.T) {
testCases := []struct {
desc string
path string
found bool
contents string
}{
{"allowed file", "/file1", true, "contents1"},
{"path traversal is allowed", "/uploads/../file1", true, "contents1"},
{"files in /uploads/ are invisible", "/uploads/file2", false, ""},
{"cannot use path traversal to get to /uploads/", "/foobar/../uploads/file2", false, ""},
{"cannot use escaped path traversal to get to /uploads/", "/foobar%2f%2e%2e%2fuploads/file2", false, ""},
{"cannot use double escaped path traversal to get to /uploads/", "/foobar%252f%252e%252e%252fuploads/file2", false, ""},
}
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
httpRequest, err := http.NewRequest("GET", tc.path, nil)
require.NoError(t, err)
w := httptest.NewRecorder()
st := &Static{DocumentRoot: "testdata", Exclude: []string{"/uploads/"}}
st.ServeExisting("/", CacheDisabled, nil).ServeHTTP(w, httpRequest)
if tc.found {
require.Equal(t, 200, w.Code)
require.Equal(t, tc.contents, w.Body.String())
} else {
require.Equal(t, 404, w.Code)
}
})
}
}
func testServingThePregzippedFile(t *testing.T, enableGzip bool) { func testServingThePregzippedFile(t *testing.T, enableGzip bool) {
dir, err := ioutil.TempDir("", "deploy") dir, err := ioutil.TempDir("", "deploy")
if err != nil { if err != nil {
...@@ -108,7 +142,7 @@ func testServingThePregzippedFile(t *testing.T, enableGzip bool) { ...@@ -108,7 +142,7 @@ func testServingThePregzippedFile(t *testing.T, enableGzip bool) {
ioutil.WriteFile(filepath.Join(dir, "file"), []byte(fileContent), 0600) ioutil.WriteFile(filepath.Join(dir, "file"), []byte(fileContent), 0600)
w := httptest.NewRecorder() w := httptest.NewRecorder()
st := &Static{dir} st := &Static{DocumentRoot: dir}
st.ServeExisting("/", CacheDisabled, nil).ServeHTTP(w, httpRequest) st.ServeExisting("/", CacheDisabled, nil).ServeHTTP(w, httpRequest)
require.Equal(t, 200, w.Code) require.Equal(t, 200, w.Code)
if enableGzip { if enableGzip {
......
...@@ -2,4 +2,5 @@ package staticpages ...@@ -2,4 +2,5 @@ package staticpages
type Static struct { type Static struct {
DocumentRoot string DocumentRoot string
Exclude []string
} }
contents1
\ No newline at end of file
contents2
\ No newline at end of file
...@@ -62,6 +62,14 @@ const ( ...@@ -62,6 +62,14 @@ const (
importPattern = `^/import/` importPattern = `^/import/`
) )
var (
// For legacy reasons, user uploads are stored in public/uploads. To
// prevent anybody who knows/guesses the URL of a user-uploaded file
// from downloading it we configure static.ServeExisting to treat files
// under public/uploads/ as if they do not exist.
staticExclude = []string{"/uploads/"}
)
func compileRegexp(regexpStr string) *regexp.Regexp { func compileRegexp(regexpStr string) *regexp.Regexp {
if len(regexpStr) == 0 { if len(regexpStr) == 0 {
return nil return nil
...@@ -181,20 +189,20 @@ func buildProxy(backend *url.URL, version string, rt http.RoundTripper, cfg conf ...@@ -181,20 +189,20 @@ func buildProxy(backend *url.URL, version string, rt http.RoundTripper, cfg conf
// We match against URI not containing the relativeUrlRoot: // We match against URI not containing the relativeUrlRoot:
// see upstream.ServeHTTP // see upstream.ServeHTTP
func (u *upstream) configureRoutes() { func configureRoutes(u *upstream) {
api := apipkg.NewAPI( api := apipkg.NewAPI(
u.Backend, u.Backend,
u.Version, u.Version,
u.RoundTripper, u.RoundTripper,
) )
static := &staticpages.Static{DocumentRoot: u.DocumentRoot} static := &staticpages.Static{DocumentRoot: u.DocumentRoot, Exclude: staticExclude}
proxy := buildProxy(u.Backend, u.Version, u.RoundTripper, u.Config) proxy := buildProxy(u.Backend, u.Version, u.RoundTripper, u.Config)
cableProxy := proxypkg.NewProxy(u.CableBackend, u.Version, u.CableRoundTripper) cableProxy := proxypkg.NewProxy(u.CableBackend, u.Version, u.CableRoundTripper)
assetsNotFoundHandler := NotFoundUnless(u.DevelopmentMode, proxy) assetsNotFoundHandler := NotFoundUnless(u.DevelopmentMode, proxy)
if u.AltDocumentRoot != "" { if u.AltDocumentRoot != "" {
altStatic := &staticpages.Static{DocumentRoot: u.AltDocumentRoot} altStatic := &staticpages.Static{DocumentRoot: u.AltDocumentRoot, Exclude: staticExclude}
assetsNotFoundHandler = altStatic.ServeExisting( assetsNotFoundHandler = altStatic.ServeExisting(
u.URLPrefix, u.URLPrefix,
staticpages.CacheExpireMax, staticpages.CacheExpireMax,
...@@ -306,12 +314,6 @@ func (u *upstream) configureRoutes() { ...@@ -306,12 +314,6 @@ func (u *upstream) configureRoutes() {
u.route("POST", snippetUploadPattern, upload.Accelerate(api, signingProxy, preparers.uploads)), u.route("POST", snippetUploadPattern, upload.Accelerate(api, signingProxy, preparers.uploads)),
u.route("POST", userUploadPattern, upload.Accelerate(api, signingProxy, preparers.uploads)), u.route("POST", userUploadPattern, upload.Accelerate(api, signingProxy, preparers.uploads)),
// For legacy reasons, user uploads are stored under the document root.
// To prevent anybody who knows/guesses the URL of a user-uploaded file
// from downloading it we make sure requests to /uploads/ do _not_ pass
// through static.ServeExisting.
u.route("", `^/uploads/`, static.ErrorPagesUnless(u.DevelopmentMode, staticpages.ErrorFormatHTML, proxy)),
// health checks don't intercept errors and go straight to rails // health checks don't intercept errors and go straight to rails
// TODO: We should probably not return a HTML deploy page? // TODO: We should probably not return a HTML deploy page?
// https://gitlab.com/gitlab-org/gitlab-workhorse/issues/230 // https://gitlab.com/gitlab-org/gitlab-workhorse/issues/230
......
...@@ -40,6 +40,10 @@ type upstream struct { ...@@ -40,6 +40,10 @@ type upstream struct {
} }
func NewUpstream(cfg config.Config, accessLogger *logrus.Logger) http.Handler { func NewUpstream(cfg config.Config, accessLogger *logrus.Logger) http.Handler {
return newUpstream(cfg, accessLogger, configureRoutes)
}
func newUpstream(cfg config.Config, accessLogger *logrus.Logger, routesCallback func(*upstream)) http.Handler {
up := upstream{ up := upstream{
Config: cfg, Config: cfg,
accessLogger: accessLogger, accessLogger: accessLogger,
...@@ -56,7 +60,7 @@ func NewUpstream(cfg config.Config, accessLogger *logrus.Logger) http.Handler { ...@@ -56,7 +60,7 @@ func NewUpstream(cfg config.Config, accessLogger *logrus.Logger) http.Handler {
up.RoundTripper = roundtripper.NewBackendRoundTripper(up.Backend, up.Socket, up.ProxyHeadersTimeout, cfg.DevelopmentMode) up.RoundTripper = roundtripper.NewBackendRoundTripper(up.Backend, up.Socket, up.ProxyHeadersTimeout, cfg.DevelopmentMode)
up.CableRoundTripper = roundtripper.NewBackendRoundTripper(up.CableBackend, up.CableSocket, up.ProxyHeadersTimeout, cfg.DevelopmentMode) up.CableRoundTripper = roundtripper.NewBackendRoundTripper(up.CableBackend, up.CableSocket, up.ProxyHeadersTimeout, cfg.DevelopmentMode)
up.configureURLPrefix() up.configureURLPrefix()
up.configureRoutes() routesCallback(&up)
var correlationOpts []correlation.InboundHandlerOption var correlationOpts []correlation.InboundHandlerOption
if cfg.PropagateCorrelationID { if cfg.PropagateCorrelationID {
...@@ -95,7 +99,7 @@ func (u *upstream) ServeHTTP(w http.ResponseWriter, r *http.Request) { ...@@ -95,7 +99,7 @@ func (u *upstream) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
// Check URL Root // Check URL Root
URIPath := urlprefix.CleanURIPath(r.URL.Path) URIPath := urlprefix.CleanURIPath(r.URL.EscapedPath())
prefix := u.URLPrefix prefix := u.URLPrefix
if !prefix.Match(URIPath) { if !prefix.Match(URIPath) {
helper.HTTPError(w, r, fmt.Sprintf("Not found %q", URIPath), http.StatusNotFound) helper.HTTPError(w, r, fmt.Sprintf("Not found %q", URIPath), http.StatusNotFound)
......
package upstream
import (
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/config"
)
func TestRouting(t *testing.T) {
handle := func(u *upstream, regex string) routeEntry {
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
io.WriteString(w, regex)
})
return u.route("", regex, handler)
}
const (
foobar = `\A/foobar\z`
quxbaz = `\A/quxbaz\z`
main = ""
)
u := newUpstream(config.Config{}, logrus.StandardLogger(), func(u *upstream) {
u.Routes = []routeEntry{
handle(u, foobar),
handle(u, quxbaz),
handle(u, main),
}
})
ts := httptest.NewServer(u)
defer ts.Close()
testCases := []struct {
desc string
path string
route string
}{
{"main route works", "/", main},
{"foobar route works", "/foobar", foobar},
{"quxbaz route works", "/quxbaz", quxbaz},
{"path traversal works, ends up in quxbaz", "/foobar/../quxbaz", quxbaz},
{"escaped path traversal does not match any route", "/foobar%2f%2e%2e%2fquxbaz", main},
{"double escaped path traversal does not match any route", "/foobar%252f%252e%252e%252fquxbaz", main},
}
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
resp, err := http.Get(ts.URL + tc.path)
require.NoError(t, err)
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
require.NoError(t, err)
require.Equal(t, 200, resp.StatusCode, "response code")
require.Equal(t, tc.route, string(body))
})
}
}
...@@ -222,12 +222,15 @@ func TestDeniedPublicUploadsFile(t *testing.T) { ...@@ -222,12 +222,15 @@ func TestDeniedPublicUploadsFile(t *testing.T) {
for _, resource := range []string{ for _, resource := range []string{
"/uploads/static.txt", "/uploads/static.txt",
"/uploads%2Fstatic.txt", "/uploads%2Fstatic.txt",
"/foobar%2F%2E%2E%2Fuploads/static.txt",
} { } {
resp, body := httpGet(t, ws.URL+resource, nil) t.Run(resource, func(t *testing.T) {
resp, body := httpGet(t, ws.URL+resource, nil)
require.Equal(t, 404, resp.StatusCode, "GET %q: status code", resource) require.Equal(t, 404, resp.StatusCode, "GET %q: status code", resource)
require.Equal(t, "", body, "GET %q: response body", resource) require.Equal(t, "", body, "GET %q: response body", resource)
require.True(t, proxied, "GET %q: never made it to backend", resource) require.True(t, proxied, "GET %q: never made it to backend", resource)
})
} }
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment