package main import ( "bytes" "fmt" "io" "mime/multipart" "net/http" "net/http/httptest" "regexp" "strings" "testing" "github.com/dgrijalva/jwt-go" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "gitlab.com/gitlab-org/gitlab-workhorse/internal/api" "gitlab.com/gitlab-org/gitlab-workhorse/internal/secret" "gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper" "gitlab.com/gitlab-org/gitlab-workhorse/internal/upload" ) type uploadArtifactsFunction func(url, contentType string, body io.Reader) (*http.Response, string, error) func uploadArtifactsV1(url, contentType string, body io.Reader) (*http.Response, string, error) { resource := `/ci/api/v1/builds/123/artifacts` resp, err := http.Post(url+resource, contentType, body) return resp, resource, err } func uploadArtifactsV4(url, contentType string, body io.Reader) (*http.Response, string, error) { resource := `/api/v4/jobs/123/artifacts` resp, err := http.Post(url+resource, contentType, body) return resp, resource, err } func testArtifactsUpload(t *testing.T, uploadArtifacts uploadArtifactsFunction) { reqBody, contentType, err := multipartBodyWithFile() require.NoError(t, err) ts := uploadTestServer(t, nil) defer ts.Close() ws := startWorkhorseServer(ts.URL) defer ws.Close() resp, resource, err := uploadArtifacts(ws.URL, contentType, reqBody) assert.NoError(t, err) defer resp.Body.Close() assert.Equal(t, 200, resp.StatusCode, "GET %q: expected 200, got %d", resource, resp.StatusCode) } func TestArtifactsUpload(t *testing.T) { testArtifactsUpload(t, uploadArtifactsV1) testArtifactsUpload(t, uploadArtifactsV4) } func uploadTestServer(t *testing.T, extraTests func(r *http.Request)) *httptest.Server { return testhelper.TestServerWithHandler(regexp.MustCompile(`.`), func(w http.ResponseWriter, r *http.Request) { if strings.HasSuffix(r.URL.Path, "/authorize") { w.Header().Set("Content-Type", api.ResponseContentType) if _, err := fmt.Fprintf(w, `{"TempPath":"%s"}`, scratchDir); err != nil { t.Fatal(err) } return } err := r.ParseMultipartForm(100000) if err != nil { t.Fatal(err) } nValues := 7 // file name, path, size, md5, sha1, sha256, sha512 for just the upload (no metadata because we are not POSTing a valid zip file) if len(r.MultipartForm.Value) != nValues { t.Errorf("Expected to receive exactly %d values", nValues) } if len(r.MultipartForm.File) != 0 { t.Error("Expected to not receive any files") } if extraTests != nil { extraTests(r) } w.WriteHeader(200) }) } func TestAcceleratedUpload(t *testing.T) { reqBody, contentType, err := multipartBodyWithFile() if err != nil { t.Fatal(err) } ts := uploadTestServer(t, func(r *http.Request) { jwtToken, err := jwt.Parse(r.Header.Get(upload.RewrittenFieldsHeader), func(token *jwt.Token) (interface{}, error) { // Don't forget to validate the alg is what you expect: if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"]) } testhelper.ConfigureSecret() secretBytes, err := secret.Bytes() if err != nil { return nil, fmt.Errorf("read secret from file: %v", err) } return secretBytes, nil }) if err != nil { t.Fatal(err) } rewrittenFields := jwtToken.Claims.(jwt.MapClaims)["rewritten_fields"].(map[string]interface{}) if len(rewrittenFields) != 1 || len(rewrittenFields["file"].(string)) == 0 { t.Fatalf("Unexpected rewritten_fields value: %v", rewrittenFields) } }) defer ts.Close() ws := startWorkhorseServer(ts.URL) defer ws.Close() resource := `/example` resp, err := http.Post(ws.URL+resource, contentType, reqBody) if err != nil { t.Error(err) } defer resp.Body.Close() if resp.StatusCode != 200 { t.Errorf("GET %q: expected 200, got %d", resource, resp.StatusCode) } } func multipartBodyWithFile() (io.Reader, string, error) { result := &bytes.Buffer{} writer := multipart.NewWriter(result) file, err := writer.CreateFormFile("file", "my.file") if err != nil { return nil, "", err } fmt.Fprint(file, "SHOULD BE ON DISK, NOT IN MULTIPART") return result, writer.FormDataContentType(), writer.Close() } func TestBlockingRewrittenFieldsHeader(t *testing.T) { canary := "untrusted header passed by user" testCases := []struct { desc string contentType string body io.Reader present bool }{ {"multipart with file", "", nil, true}, // placeholder {"no multipart", "text/plain", nil, false}, } if b, c, err := multipartBodyWithFile(); err == nil { testCases[0].contentType = c testCases[0].body = b } else { t.Fatal(err) } for _, tc := range testCases { ts := testhelper.TestServerWithHandler(regexp.MustCompile(`.`), func(w http.ResponseWriter, r *http.Request) { h := upload.RewrittenFieldsHeader if _, ok := r.Header[h]; ok != tc.present { t.Errorf("Expectation of presence (%v) violated", tc.present) } if r.Header.Get(h) == canary { t.Errorf("Found canary %q in header %q", canary, h) } }) defer ts.Close() ws := startWorkhorseServer(ts.URL) defer ws.Close() req, err := http.NewRequest("POST", ws.URL+"/something", tc.body) if err != nil { t.Fatal(err) } req.Header.Set("Content-Type", tc.contentType) req.Header.Set(upload.RewrittenFieldsHeader, canary) resp, err := http.DefaultClient.Do(req) if err != nil { t.Error(err) } defer resp.Body.Close() if resp.StatusCode != 200 { t.Errorf("%s: expected HTTP 200, got %d", tc.desc, resp.StatusCode) } } }