Commit 3b33cf3d authored by Jacob Vosmaer's avatar Jacob Vosmaer Committed by Alessio Caiazza

Use testify/require in top level tests

parent 9a4a1d48
......@@ -8,6 +8,7 @@ import (
"testing"
"github.com/dgrijalva/jwt-go"
"github.com/stretchr/testify/require"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/api"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
......@@ -31,9 +32,7 @@ func runPreAuthorizeHandler(t *testing.T, ts *httptest.Server, suffix string, ur
// Create http request
httpRequest, err := http.NewRequest("GET", "/address", nil)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
parsedURL := helper.URLMustParse(ts.URL)
testhelper.ConfigureSecret()
a := api.NewAPI(parsedURL, "123", roundtripper.NewTestBackendRoundTripper(parsedURL))
......@@ -70,9 +69,8 @@ func TestPreAuthorizeJsonFailure(t *testing.T) {
func TestPreAuthorizeContentTypeFailure(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if _, err := w.Write([]byte(`{"hello":"world"}`)); err != nil {
t.Fatalf("write auth response: %v", err)
}
_, err := w.Write([]byte(`{"hello":"world"}`))
require.NoError(t, err, "write auth response")
}))
defer ts.Close()
......@@ -110,27 +108,16 @@ func TestPreAuthorizeJWT(t *testing.T) {
return secretBytes, nil
})
if err != nil {
t.Fatalf("decode token: %v", err)
}
require.NoError(t, err, "decode token")
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
t.Fatal("claims cast failed")
}
if !token.Valid {
t.Fatal("JWT token invalid")
}
if claims["iss"] != "gitlab-workhorse" {
t.Fatalf("execpted issuer gitlab-workhorse, got %q", claims["iss"])
}
require.True(t, ok, "claims cast")
require.True(t, token.Valid, "JWT token valid")
require.Equal(t, "gitlab-workhorse", claims["iss"], "JWT token issuer")
w.Header().Set("Content-Type", api.ResponseContentType)
if _, err := w.Write([]byte(`{"hello":"world"}`)); err != nil {
t.Fatalf("write auth response: %v", err)
}
_, err = w.Write([]byte(`{"hello":"world"}`))
require.NoError(t, err, "write auth response")
}))
defer ts.Close()
......
......@@ -2,9 +2,11 @@ package main
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestParseAuthBackend(t *testing.T) {
func TestParseAuthBackendFailure(t *testing.T) {
failures := []string{
"",
"ftp://localhost",
......@@ -12,11 +14,14 @@ func TestParseAuthBackend(t *testing.T) {
}
for _, example := range failures {
if _, err := parseAuthBackend(example); err == nil {
t.Errorf("error expected for %q", example)
}
t.Run(example, func(t *testing.T) {
_, err := parseAuthBackend(example)
require.Error(t, err)
})
}
}
func TestParseAuthBackend(t *testing.T) {
successes := []struct{ input, host, scheme string }{
{"http://localhost:8080", "localhost:8080", "http"},
{"localhost:3000", "localhost:3000", "http"},
......@@ -25,18 +30,12 @@ func TestParseAuthBackend(t *testing.T) {
}
for _, example := range successes {
result, err := parseAuthBackend(example.input)
if err != nil {
t.Errorf("parse %q: %v", example.input, err)
break
}
if result.Host != example.host {
t.Errorf("example %q: expected %q, got %q", example.input, example.host, result.Host)
}
t.Run(example.input, func(t *testing.T) {
result, err := parseAuthBackend(example.input)
require.NoError(t, err)
if result.Scheme != example.scheme {
t.Errorf("example %q: expected %q, got %q", example.input, example.scheme, result.Scheme)
}
require.Equal(t, example.host, result.Host, "host")
require.Equal(t, example.scheme, result.Scheme, "scheme")
})
}
}
......@@ -14,6 +14,7 @@ import (
"time"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/require"
"gitlab.com/gitlab-org/labkit/log"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/api"
......@@ -45,9 +46,7 @@ func TestChannelHappyPath(t *testing.T) {
defer close()
client, _, err := dialWebsocket(clientURL, nil, "terminal.gitlab.com")
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
server := (<-serverConns).conn
defer server.Close()
......@@ -55,14 +54,10 @@ func TestChannelHappyPath(t *testing.T) {
message := "test message"
// channel.k8s.io: server writes to channel 1, STDOUT
if err := say(server, "\x01"+message); err != nil {
t.Fatal(err)
}
require.NoError(t, say(server, "\x01"+message))
requireReadMessage(t, client, websocket.BinaryMessage, message)
if err := say(client, message); err != nil {
t.Fatal(err)
}
require.NoError(t, say(client, message))
// channel.k8s.io: client writes get put on channel 0, STDIN
requireReadMessage(t, server, websocket.BinaryMessage, "\x00"+message)
......@@ -78,14 +73,8 @@ func TestChannelBadTLS(t *testing.T) {
_, clientURL, close := wireupChannel(envTerminalPath, badCA, "channel.k8s.io")
defer close()
client, _, err := dialWebsocket(clientURL, nil, "terminal.gitlab.com")
if err != websocket.ErrBadHandshake {
t.Fatalf("Expected connection to fail ErrBadHandshake, got: %v", err)
}
if err == nil {
log.Info("TLS negotiation should have failed!")
defer client.Close()
}
_, _, err := dialWebsocket(clientURL, nil, "terminal.gitlab.com")
require.Equal(t, websocket.ErrBadHandshake, err, "unexpected error %v", err)
}
func TestChannelSessionTimeout(t *testing.T) {
......@@ -93,9 +82,7 @@ func TestChannelSessionTimeout(t *testing.T) {
defer close()
client, _, err := dialWebsocket(clientURL, nil, "terminal.gitlab.com")
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
sc := <-serverConns
defer sc.conn.Close()
......@@ -103,9 +90,7 @@ func TestChannelSessionTimeout(t *testing.T) {
client.SetReadDeadline(time.Now().Add(time.Duration(2) * time.Second))
_, _, err = client.ReadMessage()
if !websocket.IsCloseError(err, websocket.CloseAbnormalClosure) {
t.Fatalf("Client connection was not closed, got %v", err)
}
require.True(t, websocket.IsCloseError(err, websocket.CloseAbnormalClosure), "Client connection was not closed, got %v", err)
}
func TestChannelProxyForwardsHeadersFromUpstream(t *testing.T) {
......@@ -115,16 +100,12 @@ func TestChannelProxyForwardsHeadersFromUpstream(t *testing.T) {
defer close()
client, _, err := dialWebsocket(clientURL, nil, "terminal.gitlab.com")
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
defer client.Close()
sc := <-serverConns
defer sc.conn.Close()
if sc.req.Header.Get("Random-Header") != "Value" {
t.Fatal("Header specified by upstream not sent to remote")
}
require.Equal(t, "Value", sc.req.Header.Get("Random-Header"), "Header specified by upstream not sent to remote")
}
func TestChannelProxyForwardsXForwardedForFromClient(t *testing.T) {
......@@ -134,21 +115,16 @@ func TestChannelProxyForwardsXForwardedForFromClient(t *testing.T) {
hdr := make(http.Header)
hdr.Set("X-Forwarded-For", "127.0.0.2")
client, _, err := dialWebsocket(clientURL, hdr, "terminal.gitlab.com")
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
defer client.Close()
clientIP, _, err := net.SplitHostPort(client.LocalAddr().String())
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
sc := <-serverConns
defer sc.conn.Close()
if xff := sc.req.Header.Get("X-Forwarded-For"); xff != "127.0.0.2, "+clientIP {
t.Fatalf("X-Forwarded-For from client not sent to remote: %+v", xff)
}
require.Equal(t, "127.0.0.2, "+clientIP, sc.req.Header.Get("X-Forwarded-For"), "X-Forwarded-For from client not sent to remote")
}
func wireupChannel(channelPath string, modifier func(*api.Response), subprotocols ...string) (chan connWithReq, string, func()) {
......@@ -262,15 +238,8 @@ func say(conn *websocket.Conn, message string) error {
func requireReadMessage(t *testing.T, conn *websocket.Conn, expectedMessageType int, expectedData string) {
messageType, data, err := conn.ReadMessage()
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
if messageType != expectedMessageType {
t.Fatalf("Expected message, %d, got %d", expectedMessageType, messageType)
}
if string(data) != expectedData {
t.Fatalf("Message was mangled in transit. Expected %q, got %q", expectedData, string(data))
}
require.Equal(t, expectedMessageType, messageType, "message type")
require.Equal(t, expectedData, string(data), "message data")
}
......@@ -167,6 +167,11 @@ func TestGetInfoRefsProxiedToGitalyInterruptedStream(t *testing.T) {
close(done)
}()
waitDone(t, done)
}
func waitDone(t *testing.T, done chan struct{}) {
t.Helper()
select {
case <-done:
return
......@@ -252,12 +257,7 @@ func TestPostReceivePackProxiedToGitalyInterrupted(t *testing.T) {
close(done)
}()
select {
case <-done:
return
case <-time.After(10 * time.Second):
t.Fatal("time out waiting for gitaly handler to return")
}
waitDone(t, done)
}
// ReaderFunc is an adapter to turn a conforming function into an io.Reader.
......@@ -319,9 +319,7 @@ func TestPostUploadPackProxiedToGitalySuccessfully(t *testing.T) {
m.Lock()
requestFinished := requestReadFinished
m.Unlock()
if !requestFinished {
t.Fatalf("response written before request was fully read")
}
require.True(t, requestFinished, "response written before request was fully read")
body := string(testhelper.ReadAll(t, resp.Body))
bodySplit := strings.SplitN(body, "\000", 2)
......@@ -376,12 +374,7 @@ func TestPostUploadPackProxiedToGitalyInterrupted(t *testing.T) {
close(done)
}()
select {
case <-done:
return
case <-time.After(10 * time.Second):
t.Fatal("time out waiting for gitaly handler to return")
}
waitDone(t, done)
}
func TestGetDiffProxiedToGitalySuccessfully(t *testing.T) {
......@@ -447,12 +440,7 @@ func TestGetBlobProxiedToGitalyInterruptedStream(t *testing.T) {
close(done)
}()
select {
case <-done:
return
case <-time.After(10 * time.Second):
t.Fatal("time out waiting for gitaly handler to return")
}
waitDone(t, done)
}
func TestGetArchiveProxiedToGitalySuccessfully(t *testing.T) {
......@@ -521,12 +509,7 @@ func TestGetArchiveProxiedToGitalyInterruptedStream(t *testing.T) {
close(done)
}()
select {
case <-done:
return
case <-time.After(10 * time.Second):
t.Fatal("time out waiting for gitaly handler to return")
}
waitDone(t, done)
}
func TestGetDiffProxiedToGitalyInterruptedStream(t *testing.T) {
......@@ -553,12 +536,7 @@ func TestGetDiffProxiedToGitalyInterruptedStream(t *testing.T) {
close(done)
}()
select {
case <-done:
return
case <-time.After(10 * time.Second):
t.Fatal("time out waiting for gitaly handler to return")
}
waitDone(t, done)
}
func TestGetPatchProxiedToGitalyInterruptedStream(t *testing.T) {
......@@ -585,12 +563,7 @@ func TestGetPatchProxiedToGitalyInterruptedStream(t *testing.T) {
close(done)
}()
select {
case <-done:
return
case <-time.After(10 * time.Second):
t.Fatal("time out waiting for gitaly handler to return")
}
waitDone(t, done)
}
func TestGetSnapshotProxiedToGitalySuccessfully(t *testing.T) {
......@@ -634,12 +607,7 @@ func TestGetSnapshotProxiedToGitalyInterruptedStream(t *testing.T) {
close(done)
}()
select {
case <-done:
return
case <-time.After(10 * time.Second):
t.Fatal("time out waiting for gitaly handler to return")
}
waitDone(t, done)
}
func buildGetSnapshotParams(gitalyAddress string, repo *gitalypb.Repository) string {
......
......@@ -3,12 +3,11 @@ package main
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"net/http/httptest"
"regexp"
"strings"
"testing"
"time"
......@@ -33,27 +32,20 @@ func newProxy(url string, rt http.RoundTripper) *proxy.Proxy {
func TestProxyRequest(t *testing.T) {
ts := testhelper.TestServerWithHandler(regexp.MustCompile(`/url/path\z`), func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
t.Fatal("Expected POST request")
}
require.Equal(t, "POST", r.Method, "method")
require.Equal(t, "test", r.Header.Get("Custom-Header"), "custom header")
require.Equal(t, testVersion, r.Header.Get("Gitlab-Workhorse"), "version header")
if r.Header.Get("Custom-Header") != "test" {
t.Fatal("Missing custom header")
}
require.Regexp(
t,
regexp.MustCompile(`\A1`),
r.Header.Get("Gitlab-Workhorse-Proxy-Start"),
"expect Gitlab-Workhorse-Proxy-Start to start with 1",
)
if h := r.Header.Get("Gitlab-Workhorse"); h != testVersion {
t.Fatalf("Missing GitLab-Workhorse header: want %q, got %q", testVersion, h)
}
if h := r.Header.Get("Gitlab-Workhorse-Proxy-Start"); !strings.HasPrefix(h, "1") {
t.Fatalf("Expect Gitlab-Workhorse-Proxy-Start to start with 1, got %q", h)
}
var body bytes.Buffer
io.Copy(&body, r.Body)
if body.String() != "REQUEST" {
t.Fatal("Expected REQUEST in request body")
}
body, err := ioutil.ReadAll(r.Body)
require.NoError(t, err, "read body")
require.Equal(t, "REQUEST", string(body), "body contents")
w.Header().Set("Custom-Response-Header", "test")
w.WriteHeader(202)
......@@ -61,9 +53,7 @@ func TestProxyRequest(t *testing.T) {
})
httpRequest, err := http.NewRequest("POST", ts.URL+"/url/path", bytes.NewBufferString("REQUEST"))
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
httpRequest.Header.Set("Custom-Header", "test")
w := httptest.NewRecorder()
......@@ -71,16 +61,12 @@ func TestProxyRequest(t *testing.T) {
require.Equal(t, 202, w.Code)
testhelper.RequireResponseBody(t, w, "RESPONSE")
if w.Header().Get("Custom-Response-Header") != "test" {
t.Fatal("Expected custom response header")
}
require.Equal(t, "test", w.Header().Get("Custom-Response-Header"), "custom response header")
}
func TestProxyError(t *testing.T) {
httpRequest, err := http.NewRequest("POST", "/url/path", bytes.NewBufferString("REQUEST"))
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
httpRequest.Header.Set("Custom-Header", "test")
w := httptest.NewRecorder()
......@@ -95,9 +81,7 @@ func TestProxyReadTimeout(t *testing.T) {
})
httpRequest, err := http.NewRequest("POST", "http://localhost/url/path", nil)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
rt := badgateway.NewRoundTripper(false, &http.Transport{
Proxy: http.ProxyFromEnvironment,
......@@ -124,9 +108,7 @@ func TestProxyHandlerTimeout(t *testing.T) {
)
httpRequest, err := http.NewRequest("POST", "http://localhost/url/path", nil)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
w := httptest.NewRecorder()
newProxy(ts.URL, nil).ServeHTTP(w, httpRequest)
......
......@@ -73,23 +73,18 @@ func uploadTestServer(t *testing.T, extraTests func(r *http.Request)) *httptest.
expectSignedRequest(t, r)
w.Header().Set("Content-Type", api.ResponseContentType)
if _, err := fmt.Fprintf(w, `{"TempPath":"%s"}`, scratchDir); err != nil {
t.Fatal(err)
}
_, err := fmt.Fprintf(w, `{"TempPath":"%s"}`, scratchDir)
require.NoError(t, err)
return
}
err := r.ParseMultipartForm(100000)
if err != nil {
t.Fatal(err)
}
nValues := 10 // file name, path, remote_url, remote_id, size, md5, sha1, sha256, sha512, gitlab-workhorse-upload for just the upload (no metadata because we are not POSTing a valid zip file)
if len(r.MultipartForm.Value) != nValues {
t.Errorf("Expected to receive exactly %d values", nValues)
}
if len(r.MultipartForm.File) != 0 {
t.Error("Expected to not receive any files")
}
require.NoError(t, r.ParseMultipartForm(100000))
const nValues = 10 // file name, path, remote_url, remote_id, size, md5, sha1, sha256, sha512, gitlab-workhorse-upload for just the upload (no metadata because we are not POSTing a valid zip file)
require.Len(t, r.MultipartForm.Value, nValues)
require.Empty(t, r.MultipartForm.File, "multipart form files")
if extraTests != nil {
extraTests(r)
}
......@@ -202,43 +197,35 @@ func TestBlockingRewrittenFieldsHeader(t *testing.T) {
{"no multipart", "text/plain", nil, false},
}
if b, c, err := multipartBodyWithFile(); err == nil {
testCases[0].contentType = c
testCases[0].body = b
} else {
t.Fatal(err)
}
var err error
testCases[0].body, testCases[0].contentType, err = multipartBodyWithFile()
require.NoError(t, err)
for _, tc := range testCases {
ts := testhelper.TestServerWithHandler(regexp.MustCompile(`.`), func(w http.ResponseWriter, r *http.Request) {
h := upload.RewrittenFieldsHeader
if _, ok := r.Header[h]; ok != tc.present {
t.Errorf("Expectation of presence (%v) violated", tc.present)
}
if r.Header.Get(h) == canary {
t.Errorf("Found canary %q in header %q", canary, h)
key := upload.RewrittenFieldsHeader
if tc.present {
require.Contains(t, r.Header, key)
} else {
require.NotContains(t, r.Header, key)
}
require.NotEqual(t, canary, r.Header.Get(key), "Found canary %q in header %q", canary, key)
})
defer ts.Close()
ws := startWorkhorseServer(ts.URL)
defer ws.Close()
req, err := http.NewRequest("POST", ws.URL+"/something", tc.body)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
req.Header.Set("Content-Type", tc.contentType)
req.Header.Set(upload.RewrittenFieldsHeader, canary)
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Error(err)
}
require.NoError(t, err)
defer resp.Body.Close()
if resp.StatusCode != 200 {
t.Errorf("%s: expected HTTP 200, got %d", tc.desc, resp.StatusCode)
}
require.Equal(t, 200, resp.StatusCode, "status code")
}
}
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment