Commit 2959b8ac authored by Jacob Vosmaer's avatar Jacob Vosmaer Committed by Nick Thomas

Use more require in internal/{upload,upstream}

parent e3246cbe
...@@ -4,7 +4,6 @@ import ( ...@@ -4,7 +4,6 @@ import (
"bytes" "bytes"
"context" "context"
"fmt" "fmt"
"io"
"io/ioutil" "io/ioutil"
"mime/multipart" "mime/multipart"
"net/http" "net/http"
...@@ -53,9 +52,7 @@ func (a *testFormProcessor) Name() string { ...@@ -53,9 +52,7 @@ func (a *testFormProcessor) Name() string {
func TestUploadTempPathRequirement(t *testing.T) { func TestUploadTempPathRequirement(t *testing.T) {
response := httptest.NewRecorder() response := httptest.NewRecorder()
request, err := http.NewRequest("", "", nil) request, err := http.NewRequest("", "", nil)
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
apiResponse := &api.Response{} apiResponse := &api.Response{}
preparer := &DefaultPreparer{} preparer := &DefaultPreparer{}
opts, _, err := preparer.Prepare(apiResponse) opts, _, err := preparer.Prepare(apiResponse)
...@@ -67,15 +64,11 @@ func TestUploadTempPathRequirement(t *testing.T) { ...@@ -67,15 +64,11 @@ func TestUploadTempPathRequirement(t *testing.T) {
func TestUploadHandlerForwardingRawData(t *testing.T) { func TestUploadHandlerForwardingRawData(t *testing.T) {
ts := testhelper.TestServerWithHandler(regexp.MustCompile(`/url/path\z`), func(w http.ResponseWriter, r *http.Request) { ts := testhelper.TestServerWithHandler(regexp.MustCompile(`/url/path\z`), func(w http.ResponseWriter, r *http.Request) {
if r.Method != "PATCH" { require.Equal(t, "PATCH", r.Method, "method")
t.Fatal("Expected PATCH request")
}
var body bytes.Buffer body, err := ioutil.ReadAll(r.Body)
io.Copy(&body, r.Body) require.NoError(t, err)
if body.String() != "REQUEST" { require.Equal(t, "REQUEST", string(body), "request body")
t.Fatal("Expected REQUEST in request body")
}
w.WriteHeader(202) w.WriteHeader(202)
fmt.Fprint(w, "RESPONSE") fmt.Fprint(w, "RESPONSE")
...@@ -83,14 +76,10 @@ func TestUploadHandlerForwardingRawData(t *testing.T) { ...@@ -83,14 +76,10 @@ func TestUploadHandlerForwardingRawData(t *testing.T) {
defer ts.Close() defer ts.Close()
httpRequest, err := http.NewRequest("PATCH", ts.URL+"/url/path", bytes.NewBufferString("REQUEST")) httpRequest, err := http.NewRequest("PATCH", ts.URL+"/url/path", bytes.NewBufferString("REQUEST"))
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
tempPath, err := ioutil.TempDir("", "uploads") tempPath, err := ioutil.TempDir("", "uploads")
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
defer os.RemoveAll(tempPath) defer os.RemoveAll(tempPath)
response := httptest.NewRecorder() response := httptest.NewRecorder()
...@@ -104,58 +93,30 @@ func TestUploadHandlerForwardingRawData(t *testing.T) { ...@@ -104,58 +93,30 @@ func TestUploadHandlerForwardingRawData(t *testing.T) {
HandleFileUploads(response, httpRequest, handler, apiResponse, nil, opts) HandleFileUploads(response, httpRequest, handler, apiResponse, nil, opts)
testhelper.RequireResponseCode(t, response, 202) testhelper.RequireResponseCode(t, response, 202)
if response.Body.String() != "RESPONSE" { require.Equal(t, "RESPONSE", response.Body.String(), "response body")
t.Fatal("Expected RESPONSE in response body")
}
} }
func TestUploadHandlerRewritingMultiPartData(t *testing.T) { func TestUploadHandlerRewritingMultiPartData(t *testing.T) {
var filePath string var filePath string
tempPath, err := ioutil.TempDir("", "uploads") tempPath, err := ioutil.TempDir("", "uploads")
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
defer os.RemoveAll(tempPath) defer os.RemoveAll(tempPath)
ts := testhelper.TestServerWithHandler(regexp.MustCompile(`/url/path\z`), func(w http.ResponseWriter, r *http.Request) { ts := testhelper.TestServerWithHandler(regexp.MustCompile(`/url/path\z`), func(w http.ResponseWriter, r *http.Request) {
if r.Method != "PUT" { require.Equal(t, "PUT", r.Method, "method")
t.Fatal("Expected PUT request") require.NoError(t, r.ParseMultipartForm(100000))
}
err := r.ParseMultipartForm(100000)
if err != nil {
t.Fatal(err)
}
if len(r.MultipartForm.File) != 0 {
t.Error("Expected to not receive any files")
}
if r.FormValue("token") != "test" {
t.Error("Expected to receive token")
}
if r.FormValue("file.name") != "my.file" { require.Empty(t, r.MultipartForm.File, "Expected to not receive any files")
t.Error("Expected to receive a filename") require.Equal(t, "test", r.FormValue("token"), "Expected to receive token")
} require.Equal(t, "my.file", r.FormValue("file.name"), "Expected to receive a filename")
filePath = r.FormValue("file.path") filePath = r.FormValue("file.path")
if !strings.HasPrefix(filePath, tempPath) { require.True(t, strings.HasPrefix(filePath, tempPath), "Expected to the file to be in tempPath")
t.Error("Expected to the file to be in tempPath")
}
if r.FormValue("file.remote_url") != "" { require.Empty(t, r.FormValue("file.remote_url"), "Expected to receive empty remote_url")
t.Error("Expected to receive empty remote_url") require.Empty(t, r.FormValue("file.remote_id"), "Expected to receive empty remote_id")
} require.Equal(t, "4", r.FormValue("file.size"), "Expected to receive the file size")
if r.FormValue("file.remote_id") != "" {
t.Error("Expected to receive empty remote_id")
}
if r.FormValue("file.size") != "4" {
t.Error("Expected to receive the file size")
}
hashes := map[string]string{ hashes := map[string]string{
"md5": "098f6bcd4621d373cade4e832627b4f6", "md5": "098f6bcd4621d373cade4e832627b4f6",
...@@ -165,14 +126,10 @@ func TestUploadHandlerRewritingMultiPartData(t *testing.T) { ...@@ -165,14 +126,10 @@ func TestUploadHandlerRewritingMultiPartData(t *testing.T) {
} }
for algo, hash := range hashes { for algo, hash := range hashes {
if r.FormValue("file."+algo) != hash { require.Equal(t, hash, r.FormValue("file."+algo), "file hash %s", algo)
t.Errorf("Expected to receive file %s hash", algo)
}
} }
if valueCnt := len(r.MultipartForm.Value); valueCnt != 11 { require.Len(t, r.MultipartForm.Value, 11, "multipart form values")
t.Fatal("Expected to receive exactly 11 values but got", valueCnt)
}
w.WriteHeader(202) w.WriteHeader(202)
fmt.Fprint(w, "RESPONSE") fmt.Fprint(w, "RESPONSE")
...@@ -183,16 +140,12 @@ func TestUploadHandlerRewritingMultiPartData(t *testing.T) { ...@@ -183,16 +140,12 @@ func TestUploadHandlerRewritingMultiPartData(t *testing.T) {
writer := multipart.NewWriter(&buffer) writer := multipart.NewWriter(&buffer)
writer.WriteField("token", "test") writer.WriteField("token", "test")
file, err := writer.CreateFormFile("file", "my.file") file, err := writer.CreateFormFile("file", "my.file")
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
fmt.Fprint(file, "test") fmt.Fprint(file, "test")
writer.Close() writer.Close()
httpRequest, err := http.NewRequest("PUT", ts.URL+"/url/path", nil) httpRequest, err := http.NewRequest("PUT", ts.URL+"/url/path", nil)
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
httpRequest = httpRequest.WithContext(ctx) httpRequest = httpRequest.WithContext(ctx)
...@@ -219,9 +172,7 @@ func TestUploadHandlerDetectingInjectedMultiPartData(t *testing.T) { ...@@ -219,9 +172,7 @@ func TestUploadHandlerDetectingInjectedMultiPartData(t *testing.T) {
var filePath string var filePath string
tempPath, err := ioutil.TempDir("", "uploads") tempPath, err := ioutil.TempDir("", "uploads")
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
defer os.RemoveAll(tempPath) defer os.RemoveAll(tempPath)
tests := []struct { tests := []struct {
...@@ -249,9 +200,7 @@ func TestUploadHandlerDetectingInjectedMultiPartData(t *testing.T) { ...@@ -249,9 +200,7 @@ func TestUploadHandlerDetectingInjectedMultiPartData(t *testing.T) {
for _, test := range tests { for _, test := range tests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
ts := testhelper.TestServerWithHandler(regexp.MustCompile(`/url/path\z`), func(w http.ResponseWriter, r *http.Request) { ts := testhelper.TestServerWithHandler(regexp.MustCompile(`/url/path\z`), func(w http.ResponseWriter, r *http.Request) {
if r.Method != "PUT" { require.Equal(t, "PUT", r.Method, "method")
t.Fatal("Expected PUT request")
}
w.WriteHeader(202) w.WriteHeader(202)
fmt.Fprint(w, "RESPONSE") fmt.Fprint(w, "RESPONSE")
...@@ -261,18 +210,14 @@ func TestUploadHandlerDetectingInjectedMultiPartData(t *testing.T) { ...@@ -261,18 +210,14 @@ func TestUploadHandlerDetectingInjectedMultiPartData(t *testing.T) {
writer := multipart.NewWriter(&buffer) writer := multipart.NewWriter(&buffer)
file, err := writer.CreateFormFile("file", "my.file") file, err := writer.CreateFormFile("file", "my.file")
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
fmt.Fprint(file, "test") fmt.Fprint(file, "test")
writer.WriteField(test.field, "value") writer.WriteField(test.field, "value")
writer.Close() writer.Close()
httpRequest, err := http.NewRequest("PUT", ts.URL+"/url/path", &buffer) httpRequest, err := http.NewRequest("PUT", ts.URL+"/url/path", &buffer)
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
httpRequest = httpRequest.WithContext(ctx) httpRequest = httpRequest.WithContext(ctx)
...@@ -296,9 +241,7 @@ func TestUploadHandlerDetectingInjectedMultiPartData(t *testing.T) { ...@@ -296,9 +241,7 @@ func TestUploadHandlerDetectingInjectedMultiPartData(t *testing.T) {
func TestUploadProcessingField(t *testing.T) { func TestUploadProcessingField(t *testing.T) {
tempPath, err := ioutil.TempDir("", "uploads") tempPath, err := ioutil.TempDir("", "uploads")
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
defer os.RemoveAll(tempPath) defer os.RemoveAll(tempPath)
var buffer bytes.Buffer var buffer bytes.Buffer
...@@ -308,9 +251,7 @@ func TestUploadProcessingField(t *testing.T) { ...@@ -308,9 +251,7 @@ func TestUploadProcessingField(t *testing.T) {
writer.Close() writer.Close()
httpRequest, err := http.NewRequest("PUT", "/url/path", &buffer) httpRequest, err := http.NewRequest("PUT", "/url/path", &buffer)
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
httpRequest.Header.Set("Content-Type", writer.FormDataContentType()) httpRequest.Header.Set("Content-Type", writer.FormDataContentType())
response := httptest.NewRecorder() response := httptest.NewRecorder()
...@@ -326,9 +267,7 @@ func TestUploadProcessingField(t *testing.T) { ...@@ -326,9 +267,7 @@ func TestUploadProcessingField(t *testing.T) {
func TestUploadProcessingFile(t *testing.T) { func TestUploadProcessingFile(t *testing.T) {
tempPath, err := ioutil.TempDir("", "uploads") tempPath, err := ioutil.TempDir("", "uploads")
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
defer os.RemoveAll(tempPath) defer os.RemoveAll(tempPath)
_, testServer := test.StartObjectStore() _, testServer := test.StartObjectStore()
...@@ -362,16 +301,12 @@ func TestUploadProcessingFile(t *testing.T) { ...@@ -362,16 +301,12 @@ func TestUploadProcessingFile(t *testing.T) {
var buffer bytes.Buffer var buffer bytes.Buffer
writer := multipart.NewWriter(&buffer) writer := multipart.NewWriter(&buffer)
file, err := writer.CreateFormFile("file", "my.file") file, err := writer.CreateFormFile("file", "my.file")
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
fmt.Fprint(file, "test") fmt.Fprint(file, "test")
writer.Close() writer.Close()
httpRequest, err := http.NewRequest("PUT", "/url/path", &buffer) httpRequest, err := http.NewRequest("PUT", "/url/path", &buffer)
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
httpRequest.Header.Set("Content-Type", writer.FormDataContentType()) httpRequest.Header.Set("Content-Type", writer.FormDataContentType())
response := httptest.NewRecorder() response := httptest.NewRecorder()
...@@ -392,9 +327,7 @@ func TestInvalidFileNames(t *testing.T) { ...@@ -392,9 +327,7 @@ func TestInvalidFileNames(t *testing.T) {
testhelper.ConfigureSecret() testhelper.ConfigureSecret()
tempPath, err := ioutil.TempDir("", "uploads") tempPath, err := ioutil.TempDir("", "uploads")
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
defer os.RemoveAll(tempPath) defer os.RemoveAll(tempPath)
for _, testCase := range []struct { for _, testCase := range []struct {
...@@ -411,16 +344,12 @@ func TestInvalidFileNames(t *testing.T) { ...@@ -411,16 +344,12 @@ func TestInvalidFileNames(t *testing.T) {
writer := multipart.NewWriter(buffer) writer := multipart.NewWriter(buffer)
file, err := writer.CreateFormFile("file", testCase.filename) file, err := writer.CreateFormFile("file", testCase.filename)
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
fmt.Fprint(file, "test") fmt.Fprint(file, "test")
writer.Close() writer.Close()
httpRequest, err := http.NewRequest("POST", "/example", buffer) httpRequest, err := http.NewRequest("POST", "/example", buffer)
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
httpRequest.Header.Set("Content-Type", writer.FormDataContentType()) httpRequest.Header.Set("Content-Type", writer.FormDataContentType())
response := httptest.NewRecorder() response := httptest.NewRecorder()
...@@ -548,7 +477,5 @@ func waitUntilDeleted(t *testing.T, path string) { ...@@ -548,7 +477,5 @@ func waitUntilDeleted(t *testing.T, path string) {
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
} }
if !os.IsNotExist(err) { require.True(t, os.IsNotExist(err), "expected the file to be deleted")
t.Fatal("expected the file to be deleted")
}
} }
...@@ -6,6 +6,8 @@ import ( ...@@ -6,6 +6,8 @@ import (
"testing" "testing"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper" "gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper"
"github.com/stretchr/testify/require"
) )
func TestDevelopmentModeEnabled(t *testing.T) { func TestDevelopmentModeEnabled(t *testing.T) {
...@@ -18,9 +20,8 @@ func TestDevelopmentModeEnabled(t *testing.T) { ...@@ -18,9 +20,8 @@ func TestDevelopmentModeEnabled(t *testing.T) {
NotFoundUnless(developmentMode, http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { NotFoundUnless(developmentMode, http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
executed = true executed = true
})).ServeHTTP(w, r) })).ServeHTTP(w, r)
if !executed {
t.Error("The handler should get executed") require.True(t, executed, "The handler should get executed")
}
} }
func TestDevelopmentModeDisabled(t *testing.T) { func TestDevelopmentModeDisabled(t *testing.T) {
...@@ -33,8 +34,8 @@ func TestDevelopmentModeDisabled(t *testing.T) { ...@@ -33,8 +34,8 @@ func TestDevelopmentModeDisabled(t *testing.T) {
NotFoundUnless(developmentMode, http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { NotFoundUnless(developmentMode, http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
executed = true executed = true
})).ServeHTTP(w, r) })).ServeHTTP(w, r)
if executed {
t.Error("The handler should not get executed") require.False(t, executed, "The handler should not get executed")
}
testhelper.RequireResponseCode(t, w, 404) testhelper.RequireResponseCode(t, w, 404)
} }
...@@ -7,10 +7,11 @@ import ( ...@@ -7,10 +7,11 @@ import (
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"reflect"
"testing" "testing"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper" "gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper"
"github.com/stretchr/testify/require"
) )
func TestGzipEncoding(t *testing.T) { func TestGzipEncoding(t *testing.T) {
...@@ -24,18 +25,12 @@ func TestGzipEncoding(t *testing.T) { ...@@ -24,18 +25,12 @@ func TestGzipEncoding(t *testing.T) {
body := ioutil.NopCloser(&b) body := ioutil.NopCloser(&b)
req, err := http.NewRequest("POST", "http://address/test", body) req, err := http.NewRequest("POST", "http://address/test", body)
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
req.Header.Set("Content-Encoding", "gzip") req.Header.Set("Content-Encoding", "gzip")
contentEncodingHandler(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { contentEncodingHandler(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
if _, ok := r.Body.(*gzip.Reader); !ok { require.IsType(t, &gzip.Reader{}, r.Body, "body type")
t.Fatal("Expected gzip reader for body, but it's:", reflect.TypeOf(r.Body)) require.Empty(t, r.Header.Get("Content-Encoding"), "Content-Encoding should be deleted")
}
if r.Header.Get("Content-Encoding") != "" {
t.Fatal("Content-Encoding should be deleted")
}
})).ServeHTTP(resp, req) })).ServeHTTP(resp, req)
testhelper.RequireResponseCode(t, resp, 200) testhelper.RequireResponseCode(t, resp, 200)
...@@ -48,18 +43,12 @@ func TestNoEncoding(t *testing.T) { ...@@ -48,18 +43,12 @@ func TestNoEncoding(t *testing.T) {
body := ioutil.NopCloser(&b) body := ioutil.NopCloser(&b)
req, err := http.NewRequest("POST", "http://address/test", body) req, err := http.NewRequest("POST", "http://address/test", body)
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
req.Header.Set("Content-Encoding", "") req.Header.Set("Content-Encoding", "")
contentEncodingHandler(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { contentEncodingHandler(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
if r.Body != body { require.Equal(t, body, r.Body, "Expected the same body")
t.Fatal("Expected the same body") require.Empty(t, r.Header.Get("Content-Encoding"), "Content-Encoding should be deleted")
}
if r.Header.Get("Content-Encoding") != "" {
t.Fatal("Content-Encoding should be deleted")
}
})).ServeHTTP(resp, req) })).ServeHTTP(resp, req)
testhelper.RequireResponseCode(t, resp, 200) testhelper.RequireResponseCode(t, resp, 200)
...@@ -69,9 +58,7 @@ func TestInvalidEncoding(t *testing.T) { ...@@ -69,9 +58,7 @@ func TestInvalidEncoding(t *testing.T) {
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
req, err := http.NewRequest("POST", "http://address/test", nil) req, err := http.NewRequest("POST", "http://address/test", nil)
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
req.Header.Set("Content-Encoding", "application/unknown") req.Header.Set("Content-Encoding", "application/unknown")
contentEncodingHandler(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { contentEncodingHandler(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
......
package roundtripper package roundtripper
import ( import (
"strconv"
"testing" "testing"
"github.com/stretchr/testify/require"
) )
func TestMustParseAddress(t *testing.T) { func TestMustParseAddress(t *testing.T) {
...@@ -10,26 +13,27 @@ func TestMustParseAddress(t *testing.T) { ...@@ -10,26 +13,27 @@ func TestMustParseAddress(t *testing.T) {
{"[::1]:23", "http", "::1:23"}, {"[::1]:23", "http", "::1:23"},
{"4.5.6.7", "http", "4.5.6.7:http"}, {"4.5.6.7", "http", "4.5.6.7:http"},
} }
for _, example := range successExamples { for i, example := range successExamples {
result := mustParseAddress(example.address, example.scheme) t.Run(strconv.Itoa(i), func(t *testing.T) {
if example.expected != result { require.Equal(t, example.expected, mustParseAddress(example.address, example.scheme))
t.Errorf("expected %q, got %q", example.expected, result) })
}
} }
}
func TestMustParseAddressPanic(t *testing.T) {
panicExamples := []struct{ address, scheme string }{ panicExamples := []struct{ address, scheme string }{
{"1.2.3.4", ""}, {"1.2.3.4", ""},
{"1.2.3.4", "https"}, {"1.2.3.4", "https"},
} }
for _, panicExample := range panicExamples { for i, panicExample := range panicExamples {
func() { t.Run(strconv.Itoa(i), func(t *testing.T) {
defer func() { defer func() {
if r := recover(); r == nil { if r := recover(); r == nil {
t.Errorf("expected panic for %v but none occurred", panicExample) t.Fatal("expected panic")
} }
}() }()
t.Log(mustParseAddress(panicExample.address, panicExample.scheme)) mustParseAddress(panicExample.address, panicExample.scheme)
}() })
} }
} }
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment