Commit 34a9138d authored by Jacob Vosmaer's avatar Jacob Vosmaer

Merge branch 'master' into destdir

parents 4bfdcba1 4464eb25
test/data
test/scratch
gitlab-workhorse gitlab-workhorse
test/public testdata/data
testdata/scratch
testdata/public
gitlab-zip-cat
gitlab-zip-metadata
...@@ -2,6 +2,21 @@ ...@@ -2,6 +2,21 @@
Formerly known as 'gitlab-git-http-server'. Formerly known as 'gitlab-git-http-server'.
0.6.1
Add support for generating zip artifacts metadata and serving single
files from zip archives.
Gitlab-workhorse now consists of multiple executables. We also fixed a
routing bug introduced by the 0.6.0 refactor that broke relative URL
support.
0.6.0
Overhauled the source code organization; no user-facing changes
(intended). The application code is now split into Go 'packages'
(modules). As of 0.6.0 gitlab-workhorse requires Go 1.5 or newer.
0.5.4 0.5.4
Fix /api/v3/projects routing bug introduced in 0.5.2-0.5.3. Fix /api/v3/projects routing bug introduced in 0.5.2-0.5.3.
......
PREFIX=/usr/local PREFIX=/usr/local
VERSION=$(shell git describe)-$(shell date -u +%Y%m%d.%H%M%S) VERSION=$(shell git describe)-$(shell date -u +%Y%m%d.%H%M%S)
GOBUILD=go build -ldflags "-X main.Version=${VERSION}"
gitlab-workhorse: $(wildcard *.go) all: gitlab-zip-cat gitlab-zip-metadata gitlab-workhorse
go build -ldflags "-X main.Version=${VERSION}" -o gitlab-workhorse
install: gitlab-workhorse gitlab-zip-cat: $(shell find cmd/gitlab-zip-cat/ -name '*.go')
${GOBUILD} -o $@ ./cmd/$@
gitlab-zip-metadata: $(shell find cmd/gitlab-zip-metadata/ -name '*.go')
${GOBUILD} -o $@ ./cmd/$@
gitlab-workhorse: $(shell find . -name '*.go')
${GOBUILD} -o $@
install: gitlab-workhorse gitlab-zip-cat gitlab-zip-metadata
mkdir -p $(DESTDIR)${PREFIX}/bin/ mkdir -p $(DESTDIR)${PREFIX}/bin/
install gitlab-workhorse ${DESTDIR}${PREFIX}/bin/ install gitlab-workhorse gitlab-zip-cat gitlab-zip-metadata ${DESTDIR}${PREFIX}/bin/
.PHONY: test .PHONY: test
test: test/data/group/test.git clean-workhorse gitlab-workhorse test: testdata/data/group/test.git clean-workhorse all
go fmt | awk '{ print "Please run go fmt"; exit 1 }' go fmt ./... | awk '{ print } END { if (NR > 0) { print "Please run go fmt"; exit 1 } }'
go test support/path go test ./...
@echo SUCCESS
coverage: test/data/group/test.git coverage: testdata/data/group/test.git
go test -cover -coverprofile=test.coverage go test -cover -coverprofile=test.coverage
go tool cover -html=test.coverage -o coverage.html go tool cover -html=test.coverage -o coverage.html
rm -f test.coverage rm -f test.coverage
test/data/group/test.git: test/data testdata/data/group/test.git: testdata/data
git clone --bare https://gitlab.com/gitlab-org/gitlab-test.git test/data/group/test.git git clone --bare https://gitlab.com/gitlab-org/gitlab-test.git $@
test/data: testdata/data:
mkdir -p test/data mkdir -p $@
.PHONY: clean .PHONY: clean
clean: clean-workhorse clean: clean-workhorse
rm -rf test/data test/scratch rm -rf testdata/data testdata/scratch
.PHONY: clean-workhorse .PHONY: clean-workhorse
clean-workhorse: clean-workhorse:
rm -f gitlab-workhorse rm -f gitlab-workhorse gitlab-zip-cat gitlab-zip-metadata
package main
func artifactsAuthorizeHandler(handleFunc serviceHandleFunc) serviceHandleFunc {
return preAuthorizeHandler(handleFunc, "/authorize")
}
package main package main
import ( import (
"./internal/api"
"./internal/helper"
"./internal/testhelper"
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
...@@ -8,14 +11,14 @@ import ( ...@@ -8,14 +11,14 @@ import (
"testing" "testing"
) )
func okHandler(w http.ResponseWriter, r *gitRequest) { func okHandler(w http.ResponseWriter, _ *http.Request, _ *api.Response) {
w.WriteHeader(201) w.WriteHeader(201)
fmt.Fprint(w, "{\"status\":\"ok\"}") fmt.Fprint(w, "{\"status\":\"ok\"}")
} }
func runPreAuthorizeHandler(t *testing.T, suffix string, url *regexp.Regexp, authorizationResponse interface{}, returnCode, expectedCode int) *httptest.ResponseRecorder { func runPreAuthorizeHandler(t *testing.T, suffix string, url *regexp.Regexp, apiResponse interface{}, returnCode, expectedCode int) *httptest.ResponseRecorder {
// Prepare test server and backend // Prepare test server and backend
ts := testAuthServer(url, returnCode, authorizationResponse) ts := testAuthServer(url, returnCode, apiResponse)
defer ts.Close() defer ts.Close()
// Create http request // Create http request
...@@ -23,15 +26,11 @@ func runPreAuthorizeHandler(t *testing.T, suffix string, url *regexp.Regexp, aut ...@@ -23,15 +26,11 @@ func runPreAuthorizeHandler(t *testing.T, suffix string, url *regexp.Regexp, aut
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
a := api.NewAPI(helper.URLMustParse(ts.URL), "123", nil)
request := gitRequest{
Request: httpRequest,
u: newUpstream(ts.URL, nil),
}
response := httptest.NewRecorder() response := httptest.NewRecorder()
preAuthorizeHandler(okHandler, suffix)(response, &request) a.PreAuthorizeHandler(okHandler, suffix).ServeHTTP(response, httpRequest)
assertResponseCode(t, response, expectedCode) testhelper.AssertResponseCode(t, response, expectedCode)
return response return response
} }
...@@ -39,7 +38,7 @@ func TestPreAuthorizeHappyPath(t *testing.T) { ...@@ -39,7 +38,7 @@ func TestPreAuthorizeHappyPath(t *testing.T) {
runPreAuthorizeHandler( runPreAuthorizeHandler(
t, "/authorize", t, "/authorize",
regexp.MustCompile(`/authorize\z`), regexp.MustCompile(`/authorize\z`),
&authorizationResponse{}, &api.Response{},
200, 201) 200, 201)
} }
...@@ -47,7 +46,7 @@ func TestPreAuthorizeSuffix(t *testing.T) { ...@@ -47,7 +46,7 @@ func TestPreAuthorizeSuffix(t *testing.T) {
runPreAuthorizeHandler( runPreAuthorizeHandler(
t, "/different-authorize", t, "/different-authorize",
regexp.MustCompile(`/authorize\z`), regexp.MustCompile(`/authorize\z`),
&authorizationResponse{}, &api.Response{},
200, 404) 200, 404)
} }
......
package main
import (
"../../internal/zipartifacts"
"archive/zip"
"flag"
"fmt"
"io"
"os"
)
const progName = "gitlab-zip-cat"
var Version = "unknown"
var printVersion = flag.Bool("version", false, "Print version and exit")
func main() {
flag.Parse()
version := fmt.Sprintf("%s %s", progName, Version)
if *printVersion {
fmt.Println(version)
os.Exit(0)
}
if len(os.Args) != 3 {
fmt.Fprintf(os.Stderr, "Usage: %s FILE.ZIP ENTRY", progName)
os.Exit(1)
}
archiveFileName := os.Args[1]
fileName, err := zipartifacts.DecodeFileEntry(os.Args[2])
if err != nil {
fatalError(fmt.Errorf("decode entry %q: %v", os.Args[2], err))
}
archive, err := zip.OpenReader(archiveFileName)
if err != nil {
notFoundError(fmt.Errorf("open %q: %v", archiveFileName, err))
}
defer archive.Close()
file := findFileInZip(fileName, &archive.Reader)
if file == nil {
notFoundError(fmt.Errorf("find %q in %q: not found", fileName, archiveFileName))
}
// Start decompressing the file
reader, err := file.Open()
if err != nil {
fatalError(fmt.Errorf("open %q in %q: %v", fileName, archiveFileName, err))
}
defer reader.Close()
if _, err := fmt.Printf("%d\n", file.UncompressedSize64); err != nil {
fatalError(fmt.Errorf("write file size: %v", err))
}
if _, err := io.Copy(os.Stdout, reader); err != nil {
fatalError(fmt.Errorf("write %q from %q to stdout: %v", fileName, archiveFileName, err))
}
}
func findFileInZip(fileName string, archive *zip.Reader) *zip.File {
for _, file := range archive.File {
if file.Name == fileName {
return file
}
}
return nil
}
func printError(err error) {
fmt.Fprintf(os.Stderr, "%s: %v", progName, err)
}
func fatalError(err error) {
printError(err)
os.Exit(1)
}
func notFoundError(err error) {
printError(err)
os.Exit(zipartifacts.StatusEntryNotFound)
}
package main
import (
"../../internal/zipartifacts"
"flag"
"fmt"
"os"
)
const progName = "gitlab-zip-metadata"
var Version = "unknown"
var printVersion = flag.Bool("version", false, "Print version and exit")
func main() {
flag.Parse()
version := fmt.Sprintf("%s %s", progName, Version)
if *printVersion {
fmt.Println(version)
os.Exit(0)
}
if len(os.Args) != 2 {
fmt.Fprintf(os.Stderr, "Usage: %s FILE.ZIP", progName)
os.Exit(1)
}
if err := zipartifacts.GenerateZipMetadataFromFile(os.Args[1], os.Stdout); err != nil {
fmt.Fprintf(os.Stderr, "%s: %v\n", progName, err)
if err == os.ErrInvalid {
os.Exit(zipartifacts.StatusNotZip)
}
os.Exit(1)
}
}
package main
import "net/http"
func handleDevelopmentMode(developmentMode *bool, handler serviceHandleFunc) serviceHandleFunc {
return func(w http.ResponseWriter, r *gitRequest) {
if !*developmentMode {
http.NotFound(w, r.Request)
return
}
handler(w, r)
}
}
package main package api
import ( import (
"../badgateway"
"../helper"
"../proxy"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"net/http" "net/http"
"net/url"
"strings" "strings"
) )
func (u *upstream) newUpstreamRequest(r *http.Request, body io.Reader, suffix string) (*http.Request, error) { type API struct {
url := u.authBackend + "/" + strings.TrimPrefix(r.URL.RequestURI(), u.relativeURLRoot) + suffix Client *http.Client
authReq, err := http.NewRequest(r.Method, url, body) URL *url.URL
if err != nil { Version string
return nil, err }
func NewAPI(myURL *url.URL, version string, roundTripper *badgateway.RoundTripper) *API {
if roundTripper == nil {
roundTripper = badgateway.NewRoundTripper("", 0)
} }
// Forward all headers from our client to the auth backend. This includes return &API{
// HTTP Basic authentication credentials (the 'Authorization' header). Client: &http.Client{Transport: roundTripper},
for k, v := range r.Header { URL: myURL,
authReq.Header[k] = v Version: version,
}
}
type HandleFunc func(http.ResponseWriter, *http.Request, *Response)
type Response struct {
// GL_ID is an environment variable used by gitlab-shell hooks during 'git
// push' and 'git pull'
GL_ID string
// RepoPath is the full path on disk to the Git repository the request is
// about
RepoPath string
// ArchivePath is the full path where we should find/create a cached copy
// of a requested archive
ArchivePath string
// ArchivePrefix is used to put extracted archive contents in a
// subdirectory
ArchivePrefix string
// CommitId is used do prevent race conditions between the 'time of check'
// in the GitLab Rails app and the 'time of use' in gitlab-workhorse.
CommitId string
// StoreLFSPath is provided by the GitLab Rails application
// to mark where the tmp file should be placed
StoreLFSPath string
// LFS object id
LfsOid string
// LFS object size
LfsSize int64
// TmpPath is the path where we should store temporary files
// This is set by authorization middleware
TempPath string
// Archive is the path where the artifacts archive is stored
Archive string `json:"archive"`
// Entry is a filename inside the archive point to file that needs to be extracted
Entry string `json:"entry"`
}
// singleJoiningSlash is taken from reverseproxy.go:NewSingleHostReverseProxy
func singleJoiningSlash(a, b string) string {
aslash := strings.HasSuffix(a, "/")
bslash := strings.HasPrefix(b, "/")
switch {
case aslash && bslash:
return a + b[1:]
case !aslash && !bslash:
return a + "/" + b
}
return a + b
}
// rebaseUrl is taken from reverseproxy.go:NewSingleHostReverseProxy
func rebaseUrl(url *url.URL, onto *url.URL, suffix string) *url.URL {
newUrl := *url
newUrl.Scheme = onto.Scheme
newUrl.Host = onto.Host
if suffix != "" {
newUrl.Path = singleJoiningSlash(url.Path, suffix)
}
if onto.RawQuery == "" || newUrl.RawQuery == "" {
newUrl.RawQuery = onto.RawQuery + newUrl.RawQuery
} else {
newUrl.RawQuery = onto.RawQuery + "&" + newUrl.RawQuery
}
return &newUrl
}
func (api *API) newRequest(r *http.Request, body io.Reader, suffix string) (*http.Request, error) {
authReq := &http.Request{
Method: r.Method,
URL: rebaseUrl(r.URL, api.URL, suffix),
Header: proxy.HeaderClone(r.Header),
}
if body != nil {
authReq.Body = ioutil.NopCloser(body)
} }
// Clean some headers when issuing a new request without body // Clean some headers when issuing a new request without body
...@@ -46,22 +129,22 @@ func (u *upstream) newUpstreamRequest(r *http.Request, body io.Reader, suffix st ...@@ -46,22 +129,22 @@ func (u *upstream) newUpstreamRequest(r *http.Request, body io.Reader, suffix st
authReq.Host = r.Host authReq.Host = r.Host
// Set a custom header for the request. This can be used in some // Set a custom header for the request. This can be used in some
// configurations (Passenger) to solve auth request routing problems. // configurations (Passenger) to solve auth request routing problems.
authReq.Header.Set("Gitlab-Workhorse", Version) authReq.Header.Set("Gitlab-Workhorse", api.Version)
return authReq, nil return authReq, nil
} }
func preAuthorizeHandler(handleFunc serviceHandleFunc, suffix string) serviceHandleFunc { func (api *API) PreAuthorizeHandler(h HandleFunc, suffix string) http.Handler {
return func(w http.ResponseWriter, r *gitRequest) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
authReq, err := r.u.newUpstreamRequest(r.Request, nil, suffix) authReq, err := api.newRequest(r, nil, suffix)
if err != nil { if err != nil {
fail500(w, fmt.Errorf("preAuthorizeHandler: newUpstreamRequest: %v", err)) helper.Fail500(w, fmt.Errorf("preAuthorizeHandler: newUpstreamRequest: %v", err))
return return
} }
authResponse, err := r.u.httpClient.Do(authReq) authResponse, err := api.Client.Do(authReq)
if err != nil { if err != nil {
fail500(w, fmt.Errorf("preAuthorizeHandler: do %v: %v", authReq.URL.Path, err)) helper.Fail500(w, fmt.Errorf("preAuthorizeHandler: do %v: %v", authReq.URL.Path, err))
return return
} }
defer authResponse.Body.Close() defer authResponse.Body.Close()
...@@ -85,11 +168,12 @@ func preAuthorizeHandler(handleFunc serviceHandleFunc, suffix string) serviceHan ...@@ -85,11 +168,12 @@ func preAuthorizeHandler(handleFunc serviceHandleFunc, suffix string) serviceHan
return return
} }
a := &Response{}
// The auth backend validated the client request and told us additional // The auth backend validated the client request and told us additional
// request metadata. We must extract this information from the auth // request metadata. We must extract this information from the auth
// response body. // response body.
if err := json.NewDecoder(authResponse.Body).Decode(&r.authorizationResponse); err != nil { if err := json.NewDecoder(authResponse.Body).Decode(a); err != nil {
fail500(w, fmt.Errorf("preAuthorizeHandler: decode authorization response: %v", err)) helper.Fail500(w, fmt.Errorf("preAuthorizeHandler: decode authorization response: %v", err))
return return
} }
// Don't hog a TCP connection in CLOSE_WAIT, we can already close it now // Don't hog a TCP connection in CLOSE_WAIT, we can already close it now
...@@ -104,6 +188,6 @@ func preAuthorizeHandler(handleFunc serviceHandleFunc, suffix string) serviceHan ...@@ -104,6 +188,6 @@ func preAuthorizeHandler(handleFunc serviceHandleFunc, suffix string) serviceHan
} }
} }
handleFunc(w, r) h(w, r, a)
} })
} }
package artifacts
import (
"../api"
"../helper"
"../zipartifacts"
"bufio"
"errors"
"fmt"
"io"
"mime"
"net/http"
"os"
"os/exec"
"path/filepath"
"strings"
"syscall"
)
func detectFileContentType(fileName string) string {
contentType := mime.TypeByExtension(filepath.Ext(fileName))
if contentType == "" {
contentType = "application/octet-stream"
}
return contentType
}
func unpackFileFromZip(archiveFileName, encodedFilename string, headers http.Header, output io.Writer) error {
fileName, err := zipartifacts.DecodeFileEntry(encodedFilename)
if err != nil {
return err
}
catFile := exec.Command("gitlab-zip-cat", archiveFileName, encodedFilename)
catFile.Stderr = os.Stderr
catFile.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
stdout, err := catFile.StdoutPipe()
if err != nil {
return fmt.Errorf("create gitlab-zip-cat stdout pipe: %v", err)
}
if err := catFile.Start(); err != nil {
return fmt.Errorf("start %v: %v", catFile.Args, err)
}
defer helper.CleanUpProcessGroup(catFile)
basename := filepath.Base(fileName)
reader := bufio.NewReader(stdout)
contentLength, err := reader.ReadString('\n')
if err != nil {
if catFileErr := waitCatFile(catFile); catFileErr != nil {
return catFileErr
}
return fmt.Errorf("read content-length: %v", err)
}
contentLength = strings.TrimSuffix(contentLength, "\n")
// Write http headers about the file
headers.Set("Content-Length", contentLength)
headers.Set("Content-Type", detectFileContentType(fileName))
headers.Set("Content-Disposition", "attachment; filename=\""+escapeQuotes(basename)+"\"")
// Copy file body to client
if _, err := io.Copy(output, reader); err != nil {
return fmt.Errorf("copy %v stdout: %v", catFile.Args, err)
}
return waitCatFile(catFile)
}
func waitCatFile(cmd *exec.Cmd) error {
err := cmd.Wait()
if err == nil {
return nil
}
if st, ok := helper.ExitStatus(err); ok && st == zipartifacts.StatusEntryNotFound {
return os.ErrNotExist
}
return fmt.Errorf("wait for %v to finish: %v", cmd.Args, err)
}
// Artifacts downloader doesn't support ranges when downloading a single file
func DownloadArtifact(myAPI *api.API) http.Handler {
return myAPI.PreAuthorizeHandler(func(w http.ResponseWriter, r *http.Request, a *api.Response) {
if a.Archive == "" || a.Entry == "" {
helper.Fail500(w, errors.New("DownloadArtifact: Archive or Path is empty"))
return
}
err := unpackFileFromZip(a.Archive, a.Entry, w.Header(), w)
if os.IsNotExist(err) {
http.NotFound(w, r)
return
} else if err != nil {
helper.Fail500(w, fmt.Errorf("DownloadArtifact: %v", err))
}
}, "")
}
package artifacts
import (
"../api"
"../helper"
"../testhelper"
"archive/zip"
"encoding/base64"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"testing"
)
func testArtifactDownloadServer(t *testing.T, archive string, entry string) *httptest.Server {
mux := http.NewServeMux()
mux.HandleFunc("/url/path", func(w http.ResponseWriter, r *http.Request) {
if r.Method != "GET" {
t.Fatal("Expected GET request")
}
w.Header().Set("Content-Type", "application/json")
data, err := json.Marshal(&api.Response{
Archive: archive,
Entry: base64.StdEncoding.EncodeToString([]byte(entry)),
})
if err != nil {
t.Fatal(err)
}
w.Write(data)
})
return testhelper.TestServerWithHandler(nil, mux.ServeHTTP)
}
func testDownloadArtifact(t *testing.T, ts *httptest.Server) *httptest.ResponseRecorder {
httpRequest, err := http.NewRequest("GET", ts.URL+"/url/path", nil)
if err != nil {
t.Fatal(err)
}
response := httptest.NewRecorder()
apiClient := api.NewAPI(helper.URLMustParse(ts.URL), "123", nil)
DownloadArtifact(apiClient).ServeHTTP(response, httpRequest)
return response
}
func TestDownloadingFromValidArchive(t *testing.T) {
tempFile, err := ioutil.TempFile("", "uploads")
if err != nil {
t.Fatal(err)
}
defer tempFile.Close()
defer os.Remove(tempFile.Name())
archive := zip.NewWriter(tempFile)
defer archive.Close()
fileInArchive, err := archive.Create("test.txt")
if err != nil {
t.Fatal(err)
}
fmt.Fprint(fileInArchive, "testtest")
archive.Close()
ts := testArtifactDownloadServer(t, tempFile.Name(), "test.txt")
defer ts.Close()
response := testDownloadArtifact(t, ts)
testhelper.AssertResponseCode(t, response, 200)
testhelper.AssertResponseHeader(t, response,
"Content-Type",
"text/plain; charset=utf-8")
testhelper.AssertResponseHeader(t, response,
"Content-Disposition",
"attachment; filename=\"test.txt\"")
testhelper.AssertResponseBody(t, response, "testtest")
}
func TestDownloadingNonExistingFile(t *testing.T) {
tempFile, err := ioutil.TempFile("", "uploads")
if err != nil {
t.Fatal(err)
}
defer tempFile.Close()
defer os.Remove(tempFile.Name())
archive := zip.NewWriter(tempFile)
defer archive.Close()
archive.Close()
ts := testArtifactDownloadServer(t, tempFile.Name(), "test")
defer ts.Close()
response := testDownloadArtifact(t, ts)
testhelper.AssertResponseCode(t, response, 404)
}
func TestDownloadingFromInvalidArchive(t *testing.T) {
ts := testArtifactDownloadServer(t, "path/to/non/existing/file", "test")
defer ts.Close()
response := testDownloadArtifact(t, ts)
testhelper.AssertResponseCode(t, response, 404)
}
func TestIncompleteApiResponse(t *testing.T) {
ts := testArtifactDownloadServer(t, "", "")
defer ts.Close()
response := testDownloadArtifact(t, ts)
testhelper.AssertResponseCode(t, response, 500)
}
package artifacts
import (
"../api"
"../helper"
"../upload"
"../zipartifacts"
"errors"
"fmt"
"io/ioutil"
"mime/multipart"
"net/http"
"os"
"os/exec"
"syscall"
)
type artifactsUploadProcessor struct {
TempPath string
metadataFile string
}
func (a *artifactsUploadProcessor) ProcessFile(formName, fileName string, writer *multipart.Writer) error {
// ProcessFile for artifacts requires file form-data field name to eq `file`
if formName != "file" {
return fmt.Errorf("Invalid form field: %q", formName)
}
if a.metadataFile != "" {
return fmt.Errorf("Artifacts request contains more than one file!")
}
// Create temporary file for metadata and store it's path
tempFile, err := ioutil.TempFile(a.TempPath, "metadata_")
if err != nil {
return err
}
defer tempFile.Close()
a.metadataFile = tempFile.Name()
// Generate metadata and save to file
zipMd := exec.Command("gitlab-zip-metadata", fileName)
zipMd.Stderr = os.Stderr
zipMd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
zipMd.Stdout = tempFile
if err := zipMd.Start(); err != nil {
return err
}
defer helper.CleanUpProcessGroup(zipMd)
if err := zipMd.Wait(); err != nil {
if st, ok := helper.ExitStatus(err); ok && st == zipartifacts.StatusNotZip {
return nil
}
return err
}
// Pass metadata file path to Rails
writer.WriteField("metadata.path", a.metadataFile)
writer.WriteField("metadata.name", "metadata.gz")
return nil
}
func (a *artifactsUploadProcessor) ProcessField(formName string, writer *multipart.Writer) error {
return nil
}
func (a *artifactsUploadProcessor) Cleanup() {
if a.metadataFile != "" {
os.Remove(a.metadataFile)
}
}
func UploadArtifacts(myAPI *api.API, h http.Handler) http.Handler {
return myAPI.PreAuthorizeHandler(func(w http.ResponseWriter, r *http.Request, a *api.Response) {
if a.TempPath == "" {
helper.Fail500(w, errors.New("UploadArtifacts: TempPath is empty"))
return
}
mg := &artifactsUploadProcessor{TempPath: a.TempPath}
defer mg.Cleanup()
upload.HandleFileUploads(w, r, h, a.TempPath, mg)
}, "/authorize")
}
package artifacts
import (
"../api"
"../helper"
"../proxy"
"../testhelper"
"../zipartifacts"
"archive/zip"
"bytes"
"compress/gzip"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"mime/multipart"
"net/http"
"net/http/httptest"
"os"
"testing"
)
func testArtifactsUploadServer(t *testing.T, tempPath string) *httptest.Server {
mux := http.NewServeMux()
mux.HandleFunc("/url/path/authorize", func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
t.Fatal("Expected POST request")
}
w.Header().Set("Content-Type", "application/json")
data, err := json.Marshal(&api.Response{
TempPath: tempPath,
})
if err != nil {
t.Fatal("Expected to marshal")
}
w.Write(data)
})
mux.HandleFunc("/url/path", func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
t.Fatal("Expected POST request")
}
if r.FormValue("file.path") == "" {
w.WriteHeader(501)
return
}
if r.FormValue("metadata.path") == "" {
w.WriteHeader(502)
return
}
_, err := ioutil.ReadFile(r.FormValue("file.path"))
if err != nil {
w.WriteHeader(404)
return
}
metadata, err := ioutil.ReadFile(r.FormValue("metadata.path"))
if err != nil {
w.WriteHeader(404)
return
}
gz, err := gzip.NewReader(bytes.NewReader(metadata))
if err != nil {
w.WriteHeader(405)
return
}
defer gz.Close()
metadata, err = ioutil.ReadAll(gz)
if err != nil {
w.WriteHeader(404)
return
}
if !bytes.HasPrefix(metadata, []byte(zipartifacts.MetadataHeaderPrefix+zipartifacts.MetadataHeader)) {
w.WriteHeader(400)
return
}
w.WriteHeader(200)
})
return testhelper.TestServerWithHandler(nil, mux.ServeHTTP)
}
func testUploadArtifacts(contentType string, body io.Reader, t *testing.T, ts *httptest.Server) *httptest.ResponseRecorder {
httpRequest, err := http.NewRequest("POST", ts.URL+"/url/path", body)
if err != nil {
t.Fatal(err)
}
httpRequest.Header.Set("Content-Type", contentType)
response := httptest.NewRecorder()
apiClient := api.NewAPI(helper.URLMustParse(ts.URL), "123", nil)
proxyClient := proxy.NewProxy(helper.URLMustParse(ts.URL), "123", nil)
UploadArtifacts(apiClient, proxyClient).ServeHTTP(response, httpRequest)
return response
}
func TestUploadHandlerAddingMetadata(t *testing.T) {
tempPath, err := ioutil.TempDir("", "uploads")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tempPath)
ts := testArtifactsUploadServer(t, tempPath)
defer ts.Close()
var buffer bytes.Buffer
writer := multipart.NewWriter(&buffer)
file, err := writer.CreateFormFile("file", "my.file")
if err != nil {
t.Fatal(err)
}
archive := zip.NewWriter(file)
defer archive.Close()
fileInArchive, err := archive.Create("test.file")
if err != nil {
t.Fatal(err)
}
fmt.Fprint(fileInArchive, "test")
archive.Close()
writer.Close()
response := testUploadArtifacts(writer.FormDataContentType(), &buffer, t, ts)
testhelper.AssertResponseCode(t, response, 200)
}
func TestUploadHandlerForUnsupportedArchive(t *testing.T) {
tempPath, err := ioutil.TempDir("", "uploads")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tempPath)
ts := testArtifactsUploadServer(t, tempPath)
defer ts.Close()
var buffer bytes.Buffer
writer := multipart.NewWriter(&buffer)
file, err := writer.CreateFormFile("file", "my.file")
if err != nil {
t.Fatal(err)
}
fmt.Fprint(file, "test")
writer.Close()
response := testUploadArtifacts(writer.FormDataContentType(), &buffer, t, ts)
// 502 is a custom response code from the mock server in testUploadArtifacts
testhelper.AssertResponseCode(t, response, 502)
}
func TestUploadFormProcessing(t *testing.T) {
tempPath, err := ioutil.TempDir("", "uploads")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tempPath)
ts := testArtifactsUploadServer(t, tempPath)
defer ts.Close()
var buffer bytes.Buffer
writer := multipart.NewWriter(&buffer)
file, err := writer.CreateFormFile("metadata", "my.file")
if err != nil {
t.Fatal(err)
}
fmt.Fprint(file, "test")
writer.Close()
response := testUploadArtifacts(writer.FormDataContentType(), &buffer, t, ts)
testhelper.AssertResponseCode(t, response, 500)
}
package artifacts
import "strings"
// taken from mime/multipart/writer.go
var quoteEscaper = strings.NewReplacer("\\", "\\\\", `"`, "\\\"")
func escapeQuotes(s string) string {
return quoteEscaper.Replace(s)
}
package main package badgateway
import ( import (
"../helper"
"bytes" "bytes"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net"
"net/http" "net/http"
"time"
) )
type proxyRoundTripper struct { // Values from http.DefaultTransport
transport http.RoundTripper var DefaultDialer = &net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
} }
func (p *proxyRoundTripper) RoundTrip(r *http.Request) (res *http.Response, err error) { var DefaultTransport = &http.Transport{
res, err = p.transport.RoundTrip(r) Proxy: http.ProxyFromEnvironment, // from http.DefaultTransport
Dial: DefaultDialer.Dial, // from http.DefaultTransport
TLSHandshakeTimeout: 10 * time.Second, // from http.DefaultTransport
}
type RoundTripper struct {
Transport *http.Transport
}
func NewRoundTripper(socket string, proxyHeadersTimeout time.Duration) *RoundTripper {
tr := *DefaultTransport
tr.ResponseHeaderTimeout = proxyHeadersTimeout
if socket != "" {
tr.Dial = func(_, _ string) (net.Conn, error) {
return DefaultDialer.Dial("unix", socket)
}
}
return &RoundTripper{Transport: &tr}
}
func (t *RoundTripper) RoundTrip(r *http.Request) (res *http.Response, err error) {
res, err = t.Transport.RoundTrip(r)
// httputil.ReverseProxy translates all errors from this // httputil.ReverseProxy translates all errors from this
// RoundTrip function into 500 errors. But the most likely error // RoundTrip function into 500 errors. But the most likely error
...@@ -21,7 +48,7 @@ func (p *proxyRoundTripper) RoundTrip(r *http.Request) (res *http.Response, err ...@@ -21,7 +48,7 @@ func (p *proxyRoundTripper) RoundTrip(r *http.Request) (res *http.Response, err
// instead of 500s we catch the RoundTrip error here and inject a // instead of 500s we catch the RoundTrip error here and inject a
// 502 response. // 502 response.
if err != nil { if err != nil {
logError(fmt.Errorf("proxyRoundTripper: %s %q failed with: %q", r.Method, r.RequestURI, err)) helper.LogError(fmt.Errorf("proxyRoundTripper: %s %q failed with: %q", r.Method, r.RequestURI, err))
res = &http.Response{ res = &http.Response{
StatusCode: http.StatusBadGateway, StatusCode: http.StatusBadGateway,
...@@ -40,26 +67,3 @@ func (p *proxyRoundTripper) RoundTrip(r *http.Request) (res *http.Response, err ...@@ -40,26 +67,3 @@ func (p *proxyRoundTripper) RoundTrip(r *http.Request) (res *http.Response, err
} }
return return
} }
func headerClone(h http.Header) http.Header {
h2 := make(http.Header, len(h))
for k, vv := range h {
vv2 := make([]string, len(vv))
copy(vv2, vv)
h2[k] = vv2
}
return h2
}
func proxyRequest(w http.ResponseWriter, r *gitRequest) {
// Clone request
req := *r.Request
req.Header = headerClone(r.Header)
// Set Workhorse version
req.Header.Set("Gitlab-Workhorse", Version)
rw := newSendFileResponseWriter(w, &req)
defer rw.Flush()
r.u.httpProxy.ServeHTTP(&rw, &req)
}
...@@ -2,9 +2,11 @@ ...@@ -2,9 +2,11 @@
In this file we handle 'git archive' downloads In this file we handle 'git archive' downloads
*/ */
package main package git
import ( import (
"../api"
"../helper"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
...@@ -18,7 +20,11 @@ import ( ...@@ -18,7 +20,11 @@ import (
"time" "time"
) )
func handleGetArchive(w http.ResponseWriter, r *gitRequest) { func GetArchive(a *api.API) http.Handler {
return repoPreAuthorizeHandler(a, handleGetArchive)
}
func handleGetArchive(w http.ResponseWriter, r *http.Request, a *api.Response) {
var format string var format string
urlPath := r.URL.Path urlPath := r.URL.Path
switch filepath.Base(urlPath) { switch filepath.Base(urlPath) {
...@@ -31,20 +37,20 @@ func handleGetArchive(w http.ResponseWriter, r *gitRequest) { ...@@ -31,20 +37,20 @@ func handleGetArchive(w http.ResponseWriter, r *gitRequest) {
case "archive.tar.bz2": case "archive.tar.bz2":
format = "tar.bz2" format = "tar.bz2"
default: default:
fail500(w, fmt.Errorf("handleGetArchive: invalid format: %s", urlPath)) helper.Fail500(w, fmt.Errorf("handleGetArchive: invalid format: %s", urlPath))
return return
} }
archiveFilename := path.Base(r.ArchivePath) archiveFilename := path.Base(a.ArchivePath)
if cachedArchive, err := os.Open(r.ArchivePath); err == nil { if cachedArchive, err := os.Open(a.ArchivePath); err == nil {
defer cachedArchive.Close() defer cachedArchive.Close()
log.Printf("Serving cached file %q", r.ArchivePath) log.Printf("Serving cached file %q", a.ArchivePath)
setArchiveHeaders(w, format, archiveFilename) setArchiveHeaders(w, format, archiveFilename)
// Even if somebody deleted the cachedArchive from disk since we opened // Even if somebody deleted the cachedArchive from disk since we opened
// the file, Unix file semantics guarantee we can still read from the // the file, Unix file semantics guarantee we can still read from the
// open file in this process. // open file in this process.
http.ServeContent(w, r.Request, "", time.Unix(0, 0), cachedArchive) http.ServeContent(w, r, "", time.Unix(0, 0), cachedArchive)
return return
} }
...@@ -52,9 +58,9 @@ func handleGetArchive(w http.ResponseWriter, r *gitRequest) { ...@@ -52,9 +58,9 @@ func handleGetArchive(w http.ResponseWriter, r *gitRequest) {
// safe. We create the tempfile in the same directory as the final cached // safe. We create the tempfile in the same directory as the final cached
// archive we want to create so that we can use an atomic link(2) operation // archive we want to create so that we can use an atomic link(2) operation
// to finalize the cached archive. // to finalize the cached archive.
tempFile, err := prepareArchiveTempfile(path.Dir(r.ArchivePath), archiveFilename) tempFile, err := prepareArchiveTempfile(path.Dir(a.ArchivePath), archiveFilename)
if err != nil { if err != nil {
fail500(w, fmt.Errorf("handleGetArchive: create tempfile: %v", err)) helper.Fail500(w, fmt.Errorf("handleGetArchive: create tempfile: %v", err))
return return
} }
defer tempFile.Close() defer tempFile.Close()
...@@ -62,18 +68,18 @@ func handleGetArchive(w http.ResponseWriter, r *gitRequest) { ...@@ -62,18 +68,18 @@ func handleGetArchive(w http.ResponseWriter, r *gitRequest) {
compressCmd, archiveFormat := parseArchiveFormat(format) compressCmd, archiveFormat := parseArchiveFormat(format)
archiveCmd := gitCommand("", "git", "--git-dir="+r.RepoPath, "archive", "--format="+archiveFormat, "--prefix="+r.ArchivePrefix+"/", r.CommitId) archiveCmd := gitCommand("", "git", "--git-dir="+a.RepoPath, "archive", "--format="+archiveFormat, "--prefix="+a.ArchivePrefix+"/", a.CommitId)
archiveStdout, err := archiveCmd.StdoutPipe() archiveStdout, err := archiveCmd.StdoutPipe()
if err != nil { if err != nil {
fail500(w, fmt.Errorf("handleGetArchive: archive stdout: %v", err)) helper.Fail500(w, fmt.Errorf("handleGetArchive: archive stdout: %v", err))
return return
} }
defer archiveStdout.Close() defer archiveStdout.Close()
if err := archiveCmd.Start(); err != nil { if err := archiveCmd.Start(); err != nil {
fail500(w, fmt.Errorf("handleGetArchive: start %v: %v", archiveCmd.Args, err)) helper.Fail500(w, fmt.Errorf("handleGetArchive: start %v: %v", archiveCmd.Args, err))
return return
} }
defer cleanUpProcessGroup(archiveCmd) // Ensure brute force subprocess clean-up defer helper.CleanUpProcessGroup(archiveCmd) // Ensure brute force subprocess clean-up
var stdout io.ReadCloser var stdout io.ReadCloser
if compressCmd == nil { if compressCmd == nil {
...@@ -84,16 +90,16 @@ func handleGetArchive(w http.ResponseWriter, r *gitRequest) { ...@@ -84,16 +90,16 @@ func handleGetArchive(w http.ResponseWriter, r *gitRequest) {
stdout, err = compressCmd.StdoutPipe() stdout, err = compressCmd.StdoutPipe()
if err != nil { if err != nil {
fail500(w, fmt.Errorf("handleGetArchive: compress stdout: %v", err)) helper.Fail500(w, fmt.Errorf("handleGetArchive: compress stdout: %v", err))
return return
} }
defer stdout.Close() defer stdout.Close()
if err := compressCmd.Start(); err != nil { if err := compressCmd.Start(); err != nil {
fail500(w, fmt.Errorf("handleGetArchive: start %v: %v", compressCmd.Args, err)) helper.Fail500(w, fmt.Errorf("handleGetArchive: start %v: %v", compressCmd.Args, err))
return return
} }
defer cleanUpProcessGroup(compressCmd) defer helper.CleanUpProcessGroup(compressCmd)
archiveStdout.Close() archiveStdout.Close()
} }
...@@ -105,22 +111,22 @@ func handleGetArchive(w http.ResponseWriter, r *gitRequest) { ...@@ -105,22 +111,22 @@ func handleGetArchive(w http.ResponseWriter, r *gitRequest) {
setArchiveHeaders(w, format, archiveFilename) setArchiveHeaders(w, format, archiveFilename)
w.WriteHeader(200) // Don't bother with HTTP 500 from this point on, just return w.WriteHeader(200) // Don't bother with HTTP 500 from this point on, just return
if _, err := io.Copy(w, archiveReader); err != nil { if _, err := io.Copy(w, archiveReader); err != nil {
logError(fmt.Errorf("handleGetArchive: read: %v", err)) helper.LogError(fmt.Errorf("handleGetArchive: read: %v", err))
return return
} }
if err := archiveCmd.Wait(); err != nil { if err := archiveCmd.Wait(); err != nil {
logError(fmt.Errorf("handleGetArchive: archiveCmd: %v", err)) helper.LogError(fmt.Errorf("handleGetArchive: archiveCmd: %v", err))
return return
} }
if compressCmd != nil { if compressCmd != nil {
if err := compressCmd.Wait(); err != nil { if err := compressCmd.Wait(); err != nil {
logError(fmt.Errorf("handleGetArchive: compressCmd: %v", err)) helper.LogError(fmt.Errorf("handleGetArchive: compressCmd: %v", err))
return return
} }
} }
if err := finalizeCachedArchive(tempFile, r.ArchivePath); err != nil { if err := finalizeCachedArchive(tempFile, a.ArchivePath); err != nil {
logError(fmt.Errorf("handleGetArchive: finalize cached archive: %v", err)) helper.LogError(fmt.Errorf("handleGetArchive: finalize cached archive: %v", err))
return return
} }
} }
......
package git
import (
"fmt"
"os"
"os/exec"
"syscall"
)
// Git subprocess helpers
func gitCommand(gl_id string, name string, args ...string) *exec.Cmd {
cmd := exec.Command(name, args...)
// Start the command in its own process group (nice for signalling)
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
// Explicitly set the environment for the Git command
cmd.Env = []string{
fmt.Sprintf("HOME=%s", os.Getenv("HOME")),
fmt.Sprintf("PATH=%s", os.Getenv("PATH")),
fmt.Sprintf("LD_LIBRARY_PATH=%s", os.Getenv("LD_LIBRARY_PATH")),
fmt.Sprintf("GL_ID=%s", gl_id),
}
// If we don't do something with cmd.Stderr, Git errors will be lost
cmd.Stderr = os.Stderr
return cmd
}
...@@ -2,9 +2,11 @@ ...@@ -2,9 +2,11 @@
In this file we handle the Git 'smart HTTP' protocol In this file we handle the Git 'smart HTTP' protocol
*/ */
package main package git
import ( import (
"../api"
"../helper"
"errors" "errors"
"fmt" "fmt"
"io" "io"
...@@ -16,6 +18,14 @@ import ( ...@@ -16,6 +18,14 @@ import (
"strings" "strings"
) )
func GetInfoRefs(a *api.API) http.Handler {
return repoPreAuthorizeHandler(a, handleGetInfoRefs)
}
func PostRPC(a *api.API) http.Handler {
return repoPreAuthorizeHandler(a, handlePostRPC)
}
func looksLikeRepo(p string) bool { func looksLikeRepo(p string) bool {
// If /path/to/foo.git/objects exists then let's assume it is a valid Git // If /path/to/foo.git/objects exists then let's assume it is a valid Git
// repository. // repository.
...@@ -26,23 +36,23 @@ func looksLikeRepo(p string) bool { ...@@ -26,23 +36,23 @@ func looksLikeRepo(p string) bool {
return true return true
} }
func repoPreAuthorizeHandler(handleFunc serviceHandleFunc) serviceHandleFunc { func repoPreAuthorizeHandler(myAPI *api.API, handleFunc api.HandleFunc) http.Handler {
return preAuthorizeHandler(func(w http.ResponseWriter, r *gitRequest) { return myAPI.PreAuthorizeHandler(func(w http.ResponseWriter, r *http.Request, a *api.Response) {
if r.RepoPath == "" { if a.RepoPath == "" {
fail500(w, errors.New("repoPreAuthorizeHandler: RepoPath empty")) helper.Fail500(w, errors.New("repoPreAuthorizeHandler: RepoPath empty"))
return return
} }
if !looksLikeRepo(r.RepoPath) { if !looksLikeRepo(a.RepoPath) {
http.Error(w, "Not Found", 404) http.Error(w, "Not Found", 404)
return return
} }
handleFunc(w, r) handleFunc(w, r, a)
}, "") }, "")
} }
func handleGetInfoRefs(w http.ResponseWriter, r *gitRequest) { func handleGetInfoRefs(w http.ResponseWriter, r *http.Request, a *api.Response) {
rpc := r.URL.Query().Get("service") rpc := r.URL.Query().Get("service")
if !(rpc == "git-upload-pack" || rpc == "git-receive-pack") { if !(rpc == "git-upload-pack" || rpc == "git-receive-pack") {
// The 'dumb' Git HTTP protocol is not supported // The 'dumb' Git HTTP protocol is not supported
...@@ -51,75 +61,75 @@ func handleGetInfoRefs(w http.ResponseWriter, r *gitRequest) { ...@@ -51,75 +61,75 @@ func handleGetInfoRefs(w http.ResponseWriter, r *gitRequest) {
} }
// Prepare our Git subprocess // Prepare our Git subprocess
cmd := gitCommand(r.GL_ID, "git", subCommand(rpc), "--stateless-rpc", "--advertise-refs", r.RepoPath) cmd := gitCommand(a.GL_ID, "git", subCommand(rpc), "--stateless-rpc", "--advertise-refs", a.RepoPath)
stdout, err := cmd.StdoutPipe() stdout, err := cmd.StdoutPipe()
if err != nil { if err != nil {
fail500(w, fmt.Errorf("handleGetInfoRefs: stdout: %v", err)) helper.Fail500(w, fmt.Errorf("handleGetInfoRefs: stdout: %v", err))
return return
} }
defer stdout.Close() defer stdout.Close()
if err := cmd.Start(); err != nil { if err := cmd.Start(); err != nil {
fail500(w, fmt.Errorf("handleGetInfoRefs: start %v: %v", cmd.Args, err)) helper.Fail500(w, fmt.Errorf("handleGetInfoRefs: start %v: %v", cmd.Args, err))
return return
} }
defer cleanUpProcessGroup(cmd) // Ensure brute force subprocess clean-up defer helper.CleanUpProcessGroup(cmd) // Ensure brute force subprocess clean-up
// Start writing the response // Start writing the response
w.Header().Add("Content-Type", fmt.Sprintf("application/x-%s-advertisement", rpc)) w.Header().Add("Content-Type", fmt.Sprintf("application/x-%s-advertisement", rpc))
w.Header().Add("Cache-Control", "no-cache") w.Header().Add("Cache-Control", "no-cache")
w.WriteHeader(200) // Don't bother with HTTP 500 from this point on, just return w.WriteHeader(200) // Don't bother with HTTP 500 from this point on, just return
if err := pktLine(w, fmt.Sprintf("# service=%s\n", rpc)); err != nil { if err := pktLine(w, fmt.Sprintf("# service=%s\n", rpc)); err != nil {
logError(fmt.Errorf("handleGetInfoRefs: pktLine: %v", err)) helper.LogError(fmt.Errorf("handleGetInfoRefs: pktLine: %v", err))
return return
} }
if err := pktFlush(w); err != nil { if err := pktFlush(w); err != nil {
logError(fmt.Errorf("handleGetInfoRefs: pktFlush: %v", err)) helper.LogError(fmt.Errorf("handleGetInfoRefs: pktFlush: %v", err))
return return
} }
if _, err := io.Copy(w, stdout); err != nil { if _, err := io.Copy(w, stdout); err != nil {
logError(fmt.Errorf("handleGetInfoRefs: read from %v: %v", cmd.Args, err)) helper.LogError(fmt.Errorf("handleGetInfoRefs: read from %v: %v", cmd.Args, err))
return return
} }
if err := cmd.Wait(); err != nil { if err := cmd.Wait(); err != nil {
logError(fmt.Errorf("handleGetInfoRefs: wait for %v: %v", cmd.Args, err)) helper.LogError(fmt.Errorf("handleGetInfoRefs: wait for %v: %v", cmd.Args, err))
return return
} }
} }
func handlePostRPC(w http.ResponseWriter, r *gitRequest) { func handlePostRPC(w http.ResponseWriter, r *http.Request, a *api.Response) {
var err error var err error
// Get Git action from URL // Get Git action from URL
action := filepath.Base(r.URL.Path) action := filepath.Base(r.URL.Path)
if !(action == "git-upload-pack" || action == "git-receive-pack") { if !(action == "git-upload-pack" || action == "git-receive-pack") {
// The 'dumb' Git HTTP protocol is not supported // The 'dumb' Git HTTP protocol is not supported
fail500(w, fmt.Errorf("handlePostRPC: unsupported action: %s", r.URL.Path)) helper.Fail500(w, fmt.Errorf("handlePostRPC: unsupported action: %s", r.URL.Path))
return return
} }
// Prepare our Git subprocess // Prepare our Git subprocess
cmd := gitCommand(r.GL_ID, "git", subCommand(action), "--stateless-rpc", r.RepoPath) cmd := gitCommand(a.GL_ID, "git", subCommand(action), "--stateless-rpc", a.RepoPath)
stdout, err := cmd.StdoutPipe() stdout, err := cmd.StdoutPipe()
if err != nil { if err != nil {
fail500(w, fmt.Errorf("handlePostRPC: stdout: %v", err)) helper.Fail500(w, fmt.Errorf("handlePostRPC: stdout: %v", err))
return return
} }
defer stdout.Close() defer stdout.Close()
stdin, err := cmd.StdinPipe() stdin, err := cmd.StdinPipe()
if err != nil { if err != nil {
fail500(w, fmt.Errorf("handlePostRPC: stdin: %v", err)) helper.Fail500(w, fmt.Errorf("handlePostRPC: stdin: %v", err))
return return
} }
defer stdin.Close() defer stdin.Close()
if err := cmd.Start(); err != nil { if err := cmd.Start(); err != nil {
fail500(w, fmt.Errorf("handlePostRPC: start %v: %v", cmd.Args, err)) helper.Fail500(w, fmt.Errorf("handlePostRPC: start %v: %v", cmd.Args, err))
return return
} }
defer cleanUpProcessGroup(cmd) // Ensure brute force subprocess clean-up defer helper.CleanUpProcessGroup(cmd) // Ensure brute force subprocess clean-up
// Write the client request body to Git's standard input // Write the client request body to Git's standard input
if _, err := io.Copy(stdin, r.Body); err != nil { if _, err := io.Copy(stdin, r.Body); err != nil {
fail500(w, fmt.Errorf("handlePostRPC write to %v: %v", cmd.Args, err)) helper.Fail500(w, fmt.Errorf("handlePostRPC write to %v: %v", cmd.Args, err))
return return
} }
// Signal to the Git subprocess that no more data is coming // Signal to the Git subprocess that no more data is coming
...@@ -136,11 +146,11 @@ func handlePostRPC(w http.ResponseWriter, r *gitRequest) { ...@@ -136,11 +146,11 @@ func handlePostRPC(w http.ResponseWriter, r *gitRequest) {
// This io.Copy may take a long time, both for Git push and pull. // This io.Copy may take a long time, both for Git push and pull.
if _, err := io.Copy(w, stdout); err != nil { if _, err := io.Copy(w, stdout); err != nil {
logError(fmt.Errorf("handlePostRPC read from %v: %v", cmd.Args, err)) helper.LogError(fmt.Errorf("handlePostRPC read from %v: %v", cmd.Args, err))
return return
} }
if err := cmd.Wait(); err != nil { if err := cmd.Wait(); err != nil {
logError(fmt.Errorf("handlePostRPC wait for %v: %v", cmd.Args, err)) helper.LogError(fmt.Errorf("handlePostRPC wait for %v: %v", cmd.Args, err))
return return
} }
} }
......
/* package helper
Miscellaneous helpers: logging, errors, subprocesses
*/
package main
import ( import (
"errors" "errors"
"fmt"
"log" "log"
"net/http" "net/http"
"net/url"
"os" "os"
"os/exec" "os/exec"
"path"
"syscall" "syscall"
) )
func fail500(w http.ResponseWriter, err error) { func Fail500(w http.ResponseWriter, err error) {
http.Error(w, "Internal server error", 500) http.Error(w, "Internal server error", 500)
logError(err) LogError(err)
} }
func logError(err error) { func LogError(err error) {
log.Printf("error: %v", err) log.Printf("error: %v", err)
} }
func httpError(w http.ResponseWriter, r *http.Request, error string, code int) { func SetNoCacheHeaders(header http.Header) {
if r.ProtoAtLeast(1, 1) {
// Force client to disconnect if we render request error
w.Header().Set("Connection", "close")
}
http.Error(w, error, code)
}
// Git subprocess helpers
func gitCommand(gl_id string, name string, args ...string) *exec.Cmd {
cmd := exec.Command(name, args...)
// Start the command in its own process group (nice for signalling)
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
// Explicitly set the environment for the Git command
cmd.Env = []string{
fmt.Sprintf("HOME=%s", os.Getenv("HOME")),
fmt.Sprintf("PATH=%s", os.Getenv("PATH")),
fmt.Sprintf("LD_LIBRARY_PATH=%s", os.Getenv("LD_LIBRARY_PATH")),
fmt.Sprintf("GL_ID=%s", gl_id),
}
// If we don't do something with cmd.Stderr, Git errors will be lost
cmd.Stderr = os.Stderr
return cmd
}
func cleanUpProcessGroup(cmd *exec.Cmd) {
if cmd == nil {
return
}
process := cmd.Process
if process != nil && process.Pid > 0 {
// Send SIGTERM to the process group of cmd
syscall.Kill(-process.Pid, syscall.SIGTERM)
}
// reap our child process
cmd.Wait()
}
func setNoCacheHeaders(header http.Header) {
header.Set("Cache-Control", "no-cache, no-store, max-age=0, must-revalidate") header.Set("Cache-Control", "no-cache, no-store, max-age=0, must-revalidate")
header.Set("Pragma", "no-cache") header.Set("Pragma", "no-cache")
header.Set("Expires", "Fri, 01 Jan 1990 00:00:00 GMT") header.Set("Expires", "Fri, 01 Jan 1990 00:00:00 GMT")
} }
func openFile(path string) (file *os.File, fi os.FileInfo, err error) { func OpenFile(path string) (file *os.File, fi os.FileInfo, err error) {
file, err = os.Open(path) file, err = os.Open(path)
if err != nil { if err != nil {
return return
...@@ -101,20 +55,48 @@ func openFile(path string) (file *os.File, fi os.FileInfo, err error) { ...@@ -101,20 +55,48 @@ func openFile(path string) (file *os.File, fi os.FileInfo, err error) {
return return
} }
// Borrowed from: net/http/server.go func URLMustParse(s string) *url.URL {
// Return the canonical path for p, eliminating . and .. elements. u, err := url.Parse(s)
func cleanURIPath(p string) string { if err != nil {
if p == "" { log.Fatalf("urlMustParse: %q %v", s, err)
return "/"
} }
if p[0] != '/' { return u
p = "/" + p }
func HTTPError(w http.ResponseWriter, r *http.Request, error string, code int) {
if r.ProtoAtLeast(1, 1) {
// Force client to disconnect if we render request error
w.Header().Set("Connection", "close")
}
http.Error(w, error, code)
}
func CleanUpProcessGroup(cmd *exec.Cmd) {
if cmd == nil {
return
}
process := cmd.Process
if process != nil && process.Pid > 0 {
// Send SIGTERM to the process group of cmd
syscall.Kill(-process.Pid, syscall.SIGTERM)
} }
np := path.Clean(p)
// path.Clean removes trailing slash except for root; // reap our child process
// put the trailing slash back if necessary. cmd.Wait()
if p[len(p)-1] == '/' && np != "/" { }
np += "/"
func ExitStatus(err error) (int, bool) {
exitError, ok := err.(*exec.ExitError)
if !ok {
return 0, false
} }
return np
waitStatus, ok := exitError.Sys().(syscall.WaitStatus)
if !ok {
return 0, false
}
return waitStatus.ExitStatus(), true
} }
package main package helper
import ( import (
"fmt" "fmt"
...@@ -6,25 +6,25 @@ import ( ...@@ -6,25 +6,25 @@ import (
"time" "time"
) )
type loggingResponseWriter struct { type LoggingResponseWriter struct {
rw http.ResponseWriter rw http.ResponseWriter
status int status int
written int64 written int64
started time.Time started time.Time
} }
func newLoggingResponseWriter(rw http.ResponseWriter) loggingResponseWriter { func NewLoggingResponseWriter(rw http.ResponseWriter) LoggingResponseWriter {
return loggingResponseWriter{ return LoggingResponseWriter{
rw: rw, rw: rw,
started: time.Now(), started: time.Now(),
} }
} }
func (l *loggingResponseWriter) Header() http.Header { func (l *LoggingResponseWriter) Header() http.Header {
return l.rw.Header() return l.rw.Header()
} }
func (l *loggingResponseWriter) Write(data []byte) (n int, err error) { func (l *LoggingResponseWriter) Write(data []byte) (n int, err error) {
if l.status == 0 { if l.status == 0 {
l.WriteHeader(http.StatusOK) l.WriteHeader(http.StatusOK)
} }
...@@ -33,7 +33,7 @@ func (l *loggingResponseWriter) Write(data []byte) (n int, err error) { ...@@ -33,7 +33,7 @@ func (l *loggingResponseWriter) Write(data []byte) (n int, err error) {
return return
} }
func (l *loggingResponseWriter) WriteHeader(status int) { func (l *LoggingResponseWriter) WriteHeader(status int) {
if l.status != 0 { if l.status != 0 {
return return
} }
...@@ -42,7 +42,7 @@ func (l *loggingResponseWriter) WriteHeader(status int) { ...@@ -42,7 +42,7 @@ func (l *loggingResponseWriter) WriteHeader(status int) {
l.rw.WriteHeader(status) l.rw.WriteHeader(status)
} }
func (l *loggingResponseWriter) Log(r *http.Request) { func (l *LoggingResponseWriter) Log(r *http.Request) {
duration := time.Since(l.started) duration := time.Since(l.started)
fmt.Printf("%s %s - - [%s] %q %d %d %q %q %f\n", fmt.Printf("%s %s - - [%s] %q %d %d %q %q %f\n",
r.Host, r.RemoteAddr, l.started, r.Host, r.RemoteAddr, l.started,
......
/*
In this file we handle git lfs objects downloads and uploads
*/
package lfs
import (
"../api"
"../helper"
"../proxy"
"bytes"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
"os"
"path/filepath"
)
func PutStore(a *api.API, p *proxy.Proxy) http.Handler {
return lfsAuthorizeHandler(a, handleStoreLfsObject(p))
}
func lfsAuthorizeHandler(myAPI *api.API, handleFunc api.HandleFunc) http.Handler {
return myAPI.PreAuthorizeHandler(func(w http.ResponseWriter, r *http.Request, a *api.Response) {
if a.StoreLFSPath == "" {
helper.Fail500(w, errors.New("lfsAuthorizeHandler: StoreLFSPath empty"))
return
}
if a.LfsOid == "" {
helper.Fail500(w, errors.New("lfsAuthorizeHandler: LfsOid empty"))
return
}
if err := os.MkdirAll(a.StoreLFSPath, 0700); err != nil {
helper.Fail500(w, fmt.Errorf("lfsAuthorizeHandler: mkdir StoreLFSPath: %v", err))
return
}
handleFunc(w, r, a)
}, "/authorize")
}
func handleStoreLfsObject(h http.Handler) api.HandleFunc {
return func(w http.ResponseWriter, r *http.Request, a *api.Response) {
file, err := ioutil.TempFile(a.StoreLFSPath, a.LfsOid)
if err != nil {
helper.Fail500(w, fmt.Errorf("handleStoreLfsObject: create tempfile: %v", err))
return
}
defer os.Remove(file.Name())
defer file.Close()
hash := sha256.New()
hw := io.MultiWriter(hash, file)
written, err := io.Copy(hw, r.Body)
if err != nil {
helper.Fail500(w, fmt.Errorf("handleStoreLfsObject: write tempfile: %v", err))
return
}
file.Close()
if written != a.LfsSize {
helper.Fail500(w, fmt.Errorf("handleStoreLfsObject: expected size %d, wrote %d", a.LfsSize, written))
return
}
shaStr := hex.EncodeToString(hash.Sum(nil))
if shaStr != a.LfsOid {
helper.Fail500(w, fmt.Errorf("handleStoreLfsObject: expected sha256 %s, got %s", a.LfsOid, shaStr))
return
}
// Inject header and body
r.Header.Set("X-GitLab-Lfs-Tmp", filepath.Base(file.Name()))
r.Body = ioutil.NopCloser(&bytes.Buffer{})
r.ContentLength = 0
// And proxy the request
h.ServeHTTP(w, r)
}
}
package proxy
import (
"../badgateway"
"net/http"
"net/http/httputil"
"net/url"
)
type Proxy struct {
Version string
reverseProxy *httputil.ReverseProxy
}
func NewProxy(myURL *url.URL, version string, roundTripper *badgateway.RoundTripper) *Proxy {
p := Proxy{Version: version}
u := *myURL // Make a copy of p.URL
u.Path = ""
p.reverseProxy = httputil.NewSingleHostReverseProxy(&u)
if roundTripper != nil {
p.reverseProxy.Transport = roundTripper
} else {
p.reverseProxy.Transport = badgateway.NewRoundTripper("", 0)
}
return &p
}
func HeaderClone(h http.Header) http.Header {
h2 := make(http.Header, len(h))
for k, vv := range h {
vv2 := make([]string, len(vv))
copy(vv2, vv)
h2[k] = vv2
}
return h2
}
func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Clone request
req := *r
req.Header = HeaderClone(r.Header)
// Set Workhorse version
req.Header.Set("Gitlab-Workhorse", p.Version)
rw := newSendFileResponseWriter(w, &req)
defer rw.Flush()
p.reverseProxy.ServeHTTP(&rw, &req)
}
...@@ -4,9 +4,10 @@ via the X-Sendfile mechanism. All that is needed in the Rails code is the ...@@ -4,9 +4,10 @@ via the X-Sendfile mechanism. All that is needed in the Rails code is the
'send_file' method. 'send_file' method.
*/ */
package main package proxy
import ( import (
"../helper"
"log" "log"
"net/http" "net/http"
) )
...@@ -63,7 +64,7 @@ func (s *sendFileResponseWriter) WriteHeader(status int) { ...@@ -63,7 +64,7 @@ func (s *sendFileResponseWriter) WriteHeader(status int) {
// Serve the file // Serve the file
log.Printf("Send file %q for %s %q", file, s.req.Method, s.req.RequestURI) log.Printf("Send file %q for %s %q", file, s.req.Method, s.req.RequestURI)
content, fi, err := openFile(file) content, fi, err := helper.OpenFile(file)
if err != nil { if err != nil {
http.NotFound(s.rw, s.req) http.NotFound(s.rw, s.req)
return return
......
package main package staticpages
import ( import (
"../helper"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"path/filepath" "path/filepath"
) )
func handleDeployPage(documentRoot *string, handler serviceHandleFunc) serviceHandleFunc { func (s *Static) DeployPage(handler http.Handler) http.Handler {
return func(w http.ResponseWriter, r *gitRequest) { deployPage := filepath.Join(s.DocumentRoot, "index.html")
deployPage := filepath.Join(*documentRoot, "index.html")
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
data, err := ioutil.ReadFile(deployPage) data, err := ioutil.ReadFile(deployPage)
if err != nil { if err != nil {
handler(w, r) handler.ServeHTTP(w, r)
return return
} }
setNoCacheHeaders(w.Header()) helper.SetNoCacheHeaders(w.Header())
w.Header().Set("Content-Type", "text/html; charset=utf-8") w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
w.Write(data) w.Write(data)
} })
} }
package main package staticpages
import ( import (
"../testhelper"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
...@@ -19,9 +20,10 @@ func TestIfNoDeployPageExist(t *testing.T) { ...@@ -19,9 +20,10 @@ func TestIfNoDeployPageExist(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
executed := false executed := false
handleDeployPage(&dir, func(w http.ResponseWriter, r *gitRequest) { st := &Static{dir}
st.DeployPage(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
executed = true executed = true
})(w, nil) })).ServeHTTP(w, nil)
if !executed { if !executed {
t.Error("The handler should get executed") t.Error("The handler should get executed")
} }
...@@ -40,14 +42,15 @@ func TestIfDeployPageExist(t *testing.T) { ...@@ -40,14 +42,15 @@ func TestIfDeployPageExist(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
executed := false executed := false
handleDeployPage(&dir, func(w http.ResponseWriter, r *gitRequest) { st := &Static{dir}
st.DeployPage(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
executed = true executed = true
})(w, nil) })).ServeHTTP(w, nil)
if executed { if executed {
t.Error("The handler should not get executed") t.Error("The handler should not get executed")
} }
w.Flush() w.Flush()
assertResponseCode(t, w, 200) testhelper.AssertResponseCode(t, w, 200)
assertResponseBody(t, w, deployPage) testhelper.AssertResponseBody(t, w, deployPage)
} }
package main package staticpages
import ( import (
"../helper"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"log" "log"
...@@ -12,7 +13,7 @@ type errorPageResponseWriter struct { ...@@ -12,7 +13,7 @@ type errorPageResponseWriter struct {
rw http.ResponseWriter rw http.ResponseWriter
status int status int
hijacked bool hijacked bool
path *string path string
} }
func (s *errorPageResponseWriter) Header() http.Header { func (s *errorPageResponseWriter) Header() http.Header {
...@@ -37,14 +38,14 @@ func (s *errorPageResponseWriter) WriteHeader(status int) { ...@@ -37,14 +38,14 @@ func (s *errorPageResponseWriter) WriteHeader(status int) {
s.status = status s.status = status
if 400 <= s.status && s.status <= 599 { if 400 <= s.status && s.status <= 599 {
errorPageFile := filepath.Join(*s.path, fmt.Sprintf("%d.html", s.status)) errorPageFile := filepath.Join(s.path, fmt.Sprintf("%d.html", s.status))
// check if custom error page exists, serve this page instead // check if custom error page exists, serve this page instead
if data, err := ioutil.ReadFile(errorPageFile); err == nil { if data, err := ioutil.ReadFile(errorPageFile); err == nil {
s.hijacked = true s.hijacked = true
log.Printf("ErrorPage: serving predefined error page: %d", s.status) log.Printf("ErrorPage: serving predefined error page: %d", s.status)
setNoCacheHeaders(s.rw.Header()) helper.SetNoCacheHeaders(s.rw.Header())
s.rw.Header().Set("Content-Type", "text/html; charset=utf-8") s.rw.Header().Set("Content-Type", "text/html; charset=utf-8")
s.rw.WriteHeader(s.status) s.rw.WriteHeader(s.status)
s.rw.Write(data) s.rw.Write(data)
...@@ -59,16 +60,16 @@ func (s *errorPageResponseWriter) Flush() { ...@@ -59,16 +60,16 @@ func (s *errorPageResponseWriter) Flush() {
s.WriteHeader(http.StatusOK) s.WriteHeader(http.StatusOK)
} }
func handleRailsError(documentRoot *string, enabled *bool, handler serviceHandleFunc) serviceHandleFunc { func (st *Static) ErrorPagesUnless(disabled bool, handler http.Handler) http.Handler {
if !*enabled { if disabled {
return handler return handler
} }
return func(w http.ResponseWriter, r *gitRequest) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
rw := errorPageResponseWriter{ rw := errorPageResponseWriter{
rw: w, rw: w,
path: documentRoot, path: st.DocumentRoot,
} }
defer rw.Flush() defer rw.Flush()
handler(&rw, r) handler.ServeHTTP(&rw, r)
} })
} }
package main package staticpages
import ( import (
"../testhelper"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
...@@ -21,16 +22,16 @@ func TestIfErrorPageIsPresented(t *testing.T) { ...@@ -21,16 +22,16 @@ func TestIfErrorPageIsPresented(t *testing.T) {
ioutil.WriteFile(filepath.Join(dir, "404.html"), []byte(errorPage), 0600) ioutil.WriteFile(filepath.Join(dir, "404.html"), []byte(errorPage), 0600)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
enabled := true
handleRailsError(&dir, &enabled, func(w http.ResponseWriter, r *gitRequest) {
w.WriteHeader(404) w.WriteHeader(404)
fmt.Fprint(w, "Not Found") fmt.Fprint(w, "Not Found")
})(w, nil) })
st := &Static{dir}
st.ErrorPagesUnless(false, h).ServeHTTP(w, nil)
w.Flush() w.Flush()
assertResponseCode(t, w, 404) testhelper.AssertResponseCode(t, w, 404)
assertResponseBody(t, w, errorPage) testhelper.AssertResponseBody(t, w, errorPage)
} }
func TestIfErrorPassedIfNoErrorPageIsFound(t *testing.T) { func TestIfErrorPassedIfNoErrorPageIsFound(t *testing.T) {
...@@ -42,16 +43,16 @@ func TestIfErrorPassedIfNoErrorPageIsFound(t *testing.T) { ...@@ -42,16 +43,16 @@ func TestIfErrorPassedIfNoErrorPageIsFound(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
errorResponse := "ERROR" errorResponse := "ERROR"
h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
enabled := true
handleRailsError(&dir, &enabled, func(w http.ResponseWriter, r *gitRequest) {
w.WriteHeader(404) w.WriteHeader(404)
fmt.Fprint(w, errorResponse) fmt.Fprint(w, errorResponse)
})(w, nil) })
st := &Static{dir}
st.ErrorPagesUnless(false, h).ServeHTTP(w, nil)
w.Flush() w.Flush()
assertResponseCode(t, w, 404) testhelper.AssertResponseCode(t, w, 404)
assertResponseBody(t, w, errorResponse) testhelper.AssertResponseBody(t, w, errorResponse)
} }
func TestIfErrorPageIsIgnoredInDevelopment(t *testing.T) { func TestIfErrorPageIsIgnoredInDevelopment(t *testing.T) {
...@@ -65,15 +66,14 @@ func TestIfErrorPageIsIgnoredInDevelopment(t *testing.T) { ...@@ -65,15 +66,14 @@ func TestIfErrorPageIsIgnoredInDevelopment(t *testing.T) {
ioutil.WriteFile(filepath.Join(dir, "500.html"), []byte(errorPage), 0600) ioutil.WriteFile(filepath.Join(dir, "500.html"), []byte(errorPage), 0600)
w := httptest.NewRecorder() w := httptest.NewRecorder()
enabled := false
serverError := "Interesting Server Error" serverError := "Interesting Server Error"
handleRailsError(&dir, &enabled, func(w http.ResponseWriter, r *gitRequest) { h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(500) w.WriteHeader(500)
fmt.Fprint(w, serverError) fmt.Fprint(w, serverError)
})(w, nil) })
st := &Static{dir}
st.ErrorPagesUnless(true, h).ServeHTTP(w, nil)
w.Flush() w.Flush()
testhelper.AssertResponseCode(t, w, 500)
assertResponseCode(t, w, 500) testhelper.AssertResponseBody(t, w, serverError)
assertResponseBody(t, w, serverError)
} }
package main package staticpages
import ( import (
"../helper"
"../urlprefix"
"log" "log"
"net/http" "net/http"
"os" "os"
...@@ -19,13 +21,13 @@ const ( ...@@ -19,13 +21,13 @@ const (
// BUG/QUIRK: If a client requests 'foo%2Fbar' and 'foo/bar' exists, // BUG/QUIRK: If a client requests 'foo%2Fbar' and 'foo/bar' exists,
// handleServeFile will serve foo/bar instead of passing the request // handleServeFile will serve foo/bar instead of passing the request
// upstream. // upstream.
func handleServeFile(documentRoot *string, cache CacheMode, notFoundHandler serviceHandleFunc) serviceHandleFunc { func (s *Static) ServeExisting(prefix urlprefix.Prefix, cache CacheMode, notFoundHandler http.Handler) http.Handler {
return func(w http.ResponseWriter, r *gitRequest) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
file := filepath.Join(*documentRoot, r.relativeURIPath) file := filepath.Join(s.DocumentRoot, prefix.Strip(r.URL.Path))
// The filepath.Join does Clean traversing directories up // The filepath.Join does Clean traversing directories up
if !strings.HasPrefix(file, *documentRoot) { if !strings.HasPrefix(file, s.DocumentRoot) {
fail500(w, &os.PathError{ helper.Fail500(w, &os.PathError{
Op: "open", Op: "open",
Path: file, Path: file,
Err: os.ErrInvalid, Err: os.ErrInvalid,
...@@ -39,7 +41,7 @@ func handleServeFile(documentRoot *string, cache CacheMode, notFoundHandler serv ...@@ -39,7 +41,7 @@ func handleServeFile(documentRoot *string, cache CacheMode, notFoundHandler serv
// Serve pre-gzipped assets // Serve pre-gzipped assets
if acceptEncoding := r.Header.Get("Accept-Encoding"); strings.Contains(acceptEncoding, "gzip") { if acceptEncoding := r.Header.Get("Accept-Encoding"); strings.Contains(acceptEncoding, "gzip") {
content, fi, err = openFile(file + ".gz") content, fi, err = helper.OpenFile(file + ".gz")
if err == nil { if err == nil {
w.Header().Set("Content-Encoding", "gzip") w.Header().Set("Content-Encoding", "gzip")
} }
...@@ -47,13 +49,13 @@ func handleServeFile(documentRoot *string, cache CacheMode, notFoundHandler serv ...@@ -47,13 +49,13 @@ func handleServeFile(documentRoot *string, cache CacheMode, notFoundHandler serv
// If not found, open the original file // If not found, open the original file
if content == nil || err != nil { if content == nil || err != nil {
content, fi, err = openFile(file) content, fi, err = helper.OpenFile(file)
} }
if err != nil { if err != nil {
if notFoundHandler != nil { if notFoundHandler != nil {
notFoundHandler(w, r) notFoundHandler.ServeHTTP(w, r)
} else { } else {
http.NotFound(w, r.Request) http.NotFound(w, r)
} }
return return
} }
...@@ -68,6 +70,6 @@ func handleServeFile(documentRoot *string, cache CacheMode, notFoundHandler serv ...@@ -68,6 +70,6 @@ func handleServeFile(documentRoot *string, cache CacheMode, notFoundHandler serv
} }
log.Printf("Send static file %q (%q) for %s %q", file, w.Header().Get("Content-Encoding"), r.Method, r.RequestURI) log.Printf("Send static file %q (%q) for %s %q", file, w.Header().Get("Content-Encoding"), r.Method, r.RequestURI)
http.ServeContent(w, r.Request, filepath.Base(file), fi.ModTime(), content) http.ServeContent(w, r, filepath.Base(file), fi.ModTime(), content)
} })
} }
package main package staticpages
import ( import (
"../testhelper"
"bytes" "bytes"
"compress/gzip" "compress/gzip"
"io/ioutil" "io/ioutil"
...@@ -14,14 +15,11 @@ import ( ...@@ -14,14 +15,11 @@ import (
func TestServingNonExistingFile(t *testing.T) { func TestServingNonExistingFile(t *testing.T) {
dir := "/path/to/non/existing/directory" dir := "/path/to/non/existing/directory"
httpRequest, _ := http.NewRequest("GET", "/file", nil) httpRequest, _ := http.NewRequest("GET", "/file", nil)
request := &gitRequest{
Request: httpRequest,
relativeURIPath: "/static/file",
}
w := httptest.NewRecorder() w := httptest.NewRecorder()
handleServeFile(&dir, CacheDisabled, nil)(w, request) st := &Static{dir}
assertResponseCode(t, w, 404) st.ServeExisting("/", CacheDisabled, nil).ServeHTTP(w, httpRequest)
testhelper.AssertResponseCode(t, w, 404)
} }
func TestServingDirectory(t *testing.T) { func TestServingDirectory(t *testing.T) {
...@@ -32,41 +30,31 @@ func TestServingDirectory(t *testing.T) { ...@@ -32,41 +30,31 @@ func TestServingDirectory(t *testing.T) {
defer os.RemoveAll(dir) defer os.RemoveAll(dir)
httpRequest, _ := http.NewRequest("GET", "/file", nil) httpRequest, _ := http.NewRequest("GET", "/file", nil)
request := &gitRequest{
Request: httpRequest,
relativeURIPath: "/",
}
w := httptest.NewRecorder() w := httptest.NewRecorder()
handleServeFile(&dir, CacheDisabled, nil)(w, request) st := &Static{dir}
assertResponseCode(t, w, 404) st.ServeExisting("/", CacheDisabled, nil).ServeHTTP(w, httpRequest)
testhelper.AssertResponseCode(t, w, 404)
} }
func TestServingMalformedUri(t *testing.T) { func TestServingMalformedUri(t *testing.T) {
dir := "/path/to/non/existing/directory" dir := "/path/to/non/existing/directory"
httpRequest, _ := http.NewRequest("GET", "/file", nil) httpRequest, _ := http.NewRequest("GET", "/../../../static/file", nil)
request := &gitRequest{
Request: httpRequest,
relativeURIPath: "/../../../static/file",
}
w := httptest.NewRecorder() w := httptest.NewRecorder()
handleServeFile(&dir, CacheDisabled, nil)(w, request) st := &Static{dir}
assertResponseCode(t, w, 500) st.ServeExisting("/", CacheDisabled, nil).ServeHTTP(w, httpRequest)
testhelper.AssertResponseCode(t, w, 404)
} }
func TestExecutingHandlerWhenNoFileFound(t *testing.T) { func TestExecutingHandlerWhenNoFileFound(t *testing.T) {
dir := "/path/to/non/existing/directory" dir := "/path/to/non/existing/directory"
httpRequest, _ := http.NewRequest("GET", "/file", nil) httpRequest, _ := http.NewRequest("GET", "/file", nil)
request := &gitRequest{
Request: httpRequest,
relativeURIPath: "/static/file",
}
executed := false executed := false
handleServeFile(&dir, CacheDisabled, func(w http.ResponseWriter, r *gitRequest) { st := &Static{dir}
executed = (r == request) st.ServeExisting("/", CacheDisabled, http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
})(nil, request) executed = (r == httpRequest)
})).ServeHTTP(nil, httpRequest)
if !executed { if !executed {
t.Error("The handler should get executed") t.Error("The handler should get executed")
} }
...@@ -80,17 +68,14 @@ func TestServingTheActualFile(t *testing.T) { ...@@ -80,17 +68,14 @@ func TestServingTheActualFile(t *testing.T) {
defer os.RemoveAll(dir) defer os.RemoveAll(dir)
httpRequest, _ := http.NewRequest("GET", "/file", nil) httpRequest, _ := http.NewRequest("GET", "/file", nil)
request := &gitRequest{
Request: httpRequest,
relativeURIPath: "/file",
}
fileContent := "STATIC" fileContent := "STATIC"
ioutil.WriteFile(filepath.Join(dir, "file"), []byte(fileContent), 0600) ioutil.WriteFile(filepath.Join(dir, "file"), []byte(fileContent), 0600)
w := httptest.NewRecorder() w := httptest.NewRecorder()
handleServeFile(&dir, CacheDisabled, nil)(w, request) st := &Static{dir}
assertResponseCode(t, w, 200) st.ServeExisting("/", CacheDisabled, nil).ServeHTTP(w, httpRequest)
testhelper.AssertResponseCode(t, w, 200)
if w.Body.String() != fileContent { if w.Body.String() != fileContent {
t.Error("We should serve the file: ", w.Body.String()) t.Error("We should serve the file: ", w.Body.String())
} }
...@@ -104,10 +89,6 @@ func testServingThePregzippedFile(t *testing.T, enableGzip bool) { ...@@ -104,10 +89,6 @@ func testServingThePregzippedFile(t *testing.T, enableGzip bool) {
defer os.RemoveAll(dir) defer os.RemoveAll(dir)
httpRequest, _ := http.NewRequest("GET", "/file", nil) httpRequest, _ := http.NewRequest("GET", "/file", nil)
request := &gitRequest{
Request: httpRequest,
relativeURIPath: "/file",
}
if enableGzip { if enableGzip {
httpRequest.Header.Set("Accept-Encoding", "gzip, deflate") httpRequest.Header.Set("Accept-Encoding", "gzip, deflate")
...@@ -124,16 +105,17 @@ func testServingThePregzippedFile(t *testing.T, enableGzip bool) { ...@@ -124,16 +105,17 @@ func testServingThePregzippedFile(t *testing.T, enableGzip bool) {
ioutil.WriteFile(filepath.Join(dir, "file"), []byte(fileContent), 0600) ioutil.WriteFile(filepath.Join(dir, "file"), []byte(fileContent), 0600)
w := httptest.NewRecorder() w := httptest.NewRecorder()
handleServeFile(&dir, CacheDisabled, nil)(w, request) st := &Static{dir}
assertResponseCode(t, w, 200) st.ServeExisting("/", CacheDisabled, nil).ServeHTTP(w, httpRequest)
testhelper.AssertResponseCode(t, w, 200)
if enableGzip { if enableGzip {
assertResponseHeader(t, w, "Content-Encoding", "gzip") testhelper.AssertResponseHeader(t, w, "Content-Encoding", "gzip")
if bytes.Compare(w.Body.Bytes(), fileGzipContent.Bytes()) != 0 { if bytes.Compare(w.Body.Bytes(), fileGzipContent.Bytes()) != 0 {
t.Error("We should serve the pregzipped file") t.Error("We should serve the pregzipped file")
} }
} else { } else {
assertResponseCode(t, w, 200) testhelper.AssertResponseCode(t, w, 200)
assertResponseHeader(t, w, "Content-Encoding", "") testhelper.AssertResponseHeader(t, w, "Content-Encoding", "")
if w.Body.String() != fileContent { if w.Body.String() != fileContent {
t.Error("We should serve the file: ", w.Body.String()) t.Error("We should serve the file: ", w.Body.String())
} }
......
package staticpages
type Static struct {
DocumentRoot string
}
package main package testhelper
import ( import (
"log"
"net/http"
"net/http/httptest" "net/http/httptest"
"regexp"
"testing" "testing"
) )
func assertResponseCode(t *testing.T, response *httptest.ResponseRecorder, expectedCode int) { func AssertResponseCode(t *testing.T, response *httptest.ResponseRecorder, expectedCode int) {
if response.Code != expectedCode { if response.Code != expectedCode {
t.Fatalf("for HTTP request expected to get %d, got %d instead", expectedCode, response.Code) t.Fatalf("for HTTP request expected to get %d, got %d instead", expectedCode, response.Code)
} }
} }
func assertResponseBody(t *testing.T, response *httptest.ResponseRecorder, expectedBody string) { func AssertResponseBody(t *testing.T, response *httptest.ResponseRecorder, expectedBody string) {
if response.Body.String() != expectedBody { if response.Body.String() != expectedBody {
t.Fatalf("for HTTP request expected to receive %q, got %q instead as body", expectedBody, response.Body.String()) t.Fatalf("for HTTP request expected to receive %q, got %q instead as body", expectedBody, response.Body.String())
} }
} }
func assertResponseHeader(t *testing.T, response *httptest.ResponseRecorder, header string, expectedValue string) { func AssertResponseHeader(t *testing.T, response *httptest.ResponseRecorder, header string, expectedValue string) {
if response.Header().Get(header) != expectedValue { if response.Header().Get(header) != expectedValue {
t.Fatalf("for HTTP request expected to receive the header %q with %q, got %q", header, expectedValue, response.Header().Get(header)) t.Fatalf("for HTTP request expected to receive the header %q with %q, got %q", header, expectedValue, response.Header().Get(header))
} }
} }
func TestServerWithHandler(url *regexp.Regexp, handler http.HandlerFunc) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if url != nil && !url.MatchString(r.URL.Path) {
log.Println("UPSTREAM", r.Method, r.URL, "DENY")
w.WriteHeader(404)
return
}
if version := r.Header.Get("Gitlab-Workhorse"); version == "" {
log.Println("UPSTREAM", r.Method, r.URL, "DENY")
w.WriteHeader(403)
return
}
handler(w, r)
}))
}
package main package upload
import ( import (
"../helper"
"bytes" "bytes"
"errors"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
...@@ -11,7 +11,12 @@ import ( ...@@ -11,7 +11,12 @@ import (
"os" "os"
) )
func rewriteFormFilesFromMultipart(r *gitRequest, writer *multipart.Writer) (cleanup func(), err error) { type MultipartFormProcessor interface {
ProcessFile(formName, fileName string, writer *multipart.Writer) error
ProcessField(formName string, writer *multipart.Writer) error
}
func rewriteFormFilesFromMultipart(r *http.Request, writer *multipart.Writer, tempPath string, filter MultipartFormProcessor) (cleanup func(), err error) {
// Create multipart reader // Create multipart reader
reader, err := r.MultipartReader() reader, err := r.MultipartReader()
if err != nil { if err != nil {
...@@ -47,12 +52,12 @@ func rewriteFormFilesFromMultipart(r *gitRequest, writer *multipart.Writer) (cle ...@@ -47,12 +52,12 @@ func rewriteFormFilesFromMultipart(r *gitRequest, writer *multipart.Writer) (cle
// Copy form field // Copy form field
if filename := p.FileName(); filename != "" { if filename := p.FileName(); filename != "" {
// Create temporary directory where the uploaded file will be stored // Create temporary directory where the uploaded file will be stored
if err := os.MkdirAll(r.TempPath, 0700); err != nil { if err := os.MkdirAll(tempPath, 0700); err != nil {
return cleanup, err return cleanup, err
} }
// Create temporary file in path returned by Authorization filter // Create temporary file in path returned by Authorization filter
file, err := ioutil.TempFile(r.TempPath, "upload_") file, err := ioutil.TempFile(tempPath, "upload_")
if err != nil { if err != nil {
return cleanup, err return cleanup, err
} }
...@@ -64,10 +69,15 @@ func rewriteFormFilesFromMultipart(r *gitRequest, writer *multipart.Writer) (cle ...@@ -64,10 +69,15 @@ func rewriteFormFilesFromMultipart(r *gitRequest, writer *multipart.Writer) (cle
files = append(files, file.Name()) files = append(files, file.Name())
_, err = io.Copy(file, p) _, err = io.Copy(file, p)
file.Close()
if err != nil { if err != nil {
return cleanup, err return cleanup, err
} }
file.Close()
if err := filter.ProcessFile(name, file.Name(), writer); err != nil {
return cleanup, err
}
} else { } else {
np, err := writer.CreatePart(p.Header) np, err := writer.CreatePart(p.Header)
if err != nil { if err != nil {
...@@ -78,14 +88,18 @@ func rewriteFormFilesFromMultipart(r *gitRequest, writer *multipart.Writer) (cle ...@@ -78,14 +88,18 @@ func rewriteFormFilesFromMultipart(r *gitRequest, writer *multipart.Writer) (cle
if err != nil { if err != nil {
return cleanup, err return cleanup, err
} }
if err := filter.ProcessField(name, writer); err != nil {
return cleanup, err
}
} }
} }
return cleanup, nil return cleanup, nil
} }
func handleFileUploads(w http.ResponseWriter, r *gitRequest) { func HandleFileUploads(w http.ResponseWriter, r *http.Request, h http.Handler, tempPath string, filter MultipartFormProcessor) {
if r.TempPath == "" { if tempPath == "" {
fail500(w, errors.New("handleFileUploads: TempPath empty")) helper.Fail500(w, fmt.Errorf("handleFileUploads: temporary path not defined"))
return return
} }
...@@ -94,12 +108,12 @@ func handleFileUploads(w http.ResponseWriter, r *gitRequest) { ...@@ -94,12 +108,12 @@ func handleFileUploads(w http.ResponseWriter, r *gitRequest) {
defer writer.Close() defer writer.Close()
// Rewrite multipart form data // Rewrite multipart form data
cleanup, err := rewriteFormFilesFromMultipart(r, writer) cleanup, err := rewriteFormFilesFromMultipart(r, writer, tempPath, filter)
if err != nil { if err != nil {
if err == http.ErrNotMultipart { if err == http.ErrNotMultipart {
proxyRequest(w, r) h.ServeHTTP(w, r)
} else { } else {
fail500(w, fmt.Errorf("handleFileUploads: extract files from multipart: %v", err)) helper.Fail500(w, fmt.Errorf("handleFileUploads: extract files from multipart: %v", err))
} }
return return
} }
...@@ -117,5 +131,5 @@ func handleFileUploads(w http.ResponseWriter, r *gitRequest) { ...@@ -117,5 +131,5 @@ func handleFileUploads(w http.ResponseWriter, r *gitRequest) {
r.Header.Set("Content-Type", writer.FormDataContentType()) r.Header.Set("Content-Type", writer.FormDataContentType())
// Proxy the request // Proxy the request
proxyRequest(w, r) h.ServeHTTP(w, r)
} }
package main package upload
import ( import (
"../helper"
"../proxy"
"../testhelper"
"bytes" "bytes"
"errors"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
...@@ -14,19 +18,34 @@ import ( ...@@ -14,19 +18,34 @@ import (
"testing" "testing"
) )
var nilHandler = http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})
type testFormProcessor struct {
}
func (a *testFormProcessor) ProcessFile(formName, fileName string, writer *multipart.Writer) error {
if formName != "file" && fileName != "my.file" {
return errors.New("illegal file")
}
return nil
}
func (a *testFormProcessor) ProcessField(formName string, writer *multipart.Writer) error {
if formName != "token" {
return errors.New("illegal field")
}
return nil
}
func TestUploadTempPathRequirement(t *testing.T) { func TestUploadTempPathRequirement(t *testing.T) {
response := httptest.NewRecorder() response := httptest.NewRecorder()
request := gitRequest{ request := &http.Request{}
authorizationResponse: authorizationResponse{ HandleFileUploads(response, request, nilHandler, "", nil)
TempPath: "", testhelper.AssertResponseCode(t, response, 500)
},
}
handleFileUploads(response, &request)
assertResponseCode(t, response, 500)
} }
func TestUploadHandlerForwardingRawData(t *testing.T) { func TestUploadHandlerForwardingRawData(t *testing.T) {
ts := testServerWithHandler(regexp.MustCompile(`/url/path\z`), func(w http.ResponseWriter, r *http.Request) { ts := testhelper.TestServerWithHandler(regexp.MustCompile(`/url/path\z`), func(w http.ResponseWriter, r *http.Request) {
if r.Method != "PATCH" { if r.Method != "PATCH" {
t.Fatal("Expected PATCH request") t.Fatal("Expected PATCH request")
} }
...@@ -40,6 +59,7 @@ func TestUploadHandlerForwardingRawData(t *testing.T) { ...@@ -40,6 +59,7 @@ func TestUploadHandlerForwardingRawData(t *testing.T) {
w.WriteHeader(202) w.WriteHeader(202)
fmt.Fprint(w, "RESPONSE") fmt.Fprint(w, "RESPONSE")
}) })
defer ts.Close()
httpRequest, err := http.NewRequest("PATCH", ts.URL+"/url/path", bytes.NewBufferString("REQUEST")) httpRequest, err := http.NewRequest("PATCH", ts.URL+"/url/path", bytes.NewBufferString("REQUEST"))
if err != nil { if err != nil {
...@@ -53,15 +73,10 @@ func TestUploadHandlerForwardingRawData(t *testing.T) { ...@@ -53,15 +73,10 @@ func TestUploadHandlerForwardingRawData(t *testing.T) {
defer os.RemoveAll(tempPath) defer os.RemoveAll(tempPath)
response := httptest.NewRecorder() response := httptest.NewRecorder()
request := gitRequest{
Request: httpRequest, handler := proxy.NewProxy(helper.URLMustParse(ts.URL), "123", nil)
u: newUpstream(ts.URL, nil), HandleFileUploads(response, httpRequest, handler, tempPath, nil)
authorizationResponse: authorizationResponse{ testhelper.AssertResponseCode(t, response, 202)
TempPath: tempPath,
},
}
handleFileUploads(response, &request)
assertResponseCode(t, response, 202)
if response.Body.String() != "RESPONSE" { if response.Body.String() != "RESPONSE" {
t.Fatal("Expected RESPONSE in response body") t.Fatal("Expected RESPONSE in response body")
} }
...@@ -76,7 +91,7 @@ func TestUploadHandlerRewritingMultiPartData(t *testing.T) { ...@@ -76,7 +91,7 @@ func TestUploadHandlerRewritingMultiPartData(t *testing.T) {
} }
defer os.RemoveAll(tempPath) defer os.RemoveAll(tempPath)
ts := testServerWithHandler(regexp.MustCompile(`/url/path\z`), func(w http.ResponseWriter, r *http.Request) { ts := testhelper.TestServerWithHandler(regexp.MustCompile(`/url/path\z`), func(w http.ResponseWriter, r *http.Request) {
if r.Method != "PUT" { if r.Method != "PUT" {
t.Fatal("Expected PUT request") t.Fatal("Expected PUT request")
} }
...@@ -131,19 +146,65 @@ func TestUploadHandlerRewritingMultiPartData(t *testing.T) { ...@@ -131,19 +146,65 @@ func TestUploadHandlerRewritingMultiPartData(t *testing.T) {
httpRequest.Body = ioutil.NopCloser(&buffer) httpRequest.Body = ioutil.NopCloser(&buffer)
httpRequest.ContentLength = int64(buffer.Len()) httpRequest.ContentLength = int64(buffer.Len())
httpRequest.Header.Set("Content-Type", writer.FormDataContentType()) httpRequest.Header.Set("Content-Type", writer.FormDataContentType())
response := httptest.NewRecorder() response := httptest.NewRecorder()
request := gitRequest{
Request: httpRequest, handler := proxy.NewProxy(helper.URLMustParse(ts.URL), "123", nil)
u: newUpstream(ts.URL, nil), HandleFileUploads(response, httpRequest, handler, tempPath, &testFormProcessor{})
authorizationResponse: authorizationResponse{ testhelper.AssertResponseCode(t, response, 202)
TempPath: tempPath,
},
}
handleFileUploads(response, &request)
assertResponseCode(t, response, 202)
if _, err := os.Stat(filePath); !os.IsNotExist(err) { if _, err := os.Stat(filePath); !os.IsNotExist(err) {
t.Fatal("expected the file to be deleted") t.Fatal("expected the file to be deleted")
} }
} }
func TestUploadProcessingField(t *testing.T) {
tempPath, err := ioutil.TempDir("", "uploads")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tempPath)
var buffer bytes.Buffer
writer := multipart.NewWriter(&buffer)
writer.WriteField("token2", "test")
writer.Close()
httpRequest, err := http.NewRequest("PUT", "/url/path", &buffer)
if err != nil {
t.Fatal(err)
}
httpRequest.Header.Set("Content-Type", writer.FormDataContentType())
response := httptest.NewRecorder()
HandleFileUploads(response, httpRequest, nilHandler, tempPath, &testFormProcessor{})
testhelper.AssertResponseCode(t, response, 500)
}
func TestUploadProcessingFile(t *testing.T) {
tempPath, err := ioutil.TempDir("", "uploads")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tempPath)
var buffer bytes.Buffer
writer := multipart.NewWriter(&buffer)
file, err := writer.CreateFormFile("file2", "my.file")
if err != nil {
t.Fatal(err)
}
fmt.Fprint(file, "test")
writer.Close()
httpRequest, err := http.NewRequest("PUT", "/url/path", &buffer)
if err != nil {
t.Fatal(err)
}
httpRequest.Header.Set("Content-Type", writer.FormDataContentType())
response := httptest.NewRecorder()
HandleFileUploads(response, httpRequest, nilHandler, tempPath, &testFormProcessor{})
testhelper.AssertResponseCode(t, response, 500)
}
package main package upstream
import ( import (
"../testhelper"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
...@@ -13,9 +14,9 @@ func TestDevelopmentModeEnabled(t *testing.T) { ...@@ -13,9 +14,9 @@ func TestDevelopmentModeEnabled(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
executed := false executed := false
handleDevelopmentMode(&developmentMode, func(w http.ResponseWriter, r *gitRequest) { NotFoundUnless(developmentMode, http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
executed = true executed = true
})(w, &gitRequest{Request: r}) })).ServeHTTP(w, r)
if !executed { if !executed {
t.Error("The handler should get executed") t.Error("The handler should get executed")
} }
...@@ -28,11 +29,11 @@ func TestDevelopmentModeDisabled(t *testing.T) { ...@@ -28,11 +29,11 @@ func TestDevelopmentModeDisabled(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
executed := false executed := false
handleDevelopmentMode(&developmentMode, func(w http.ResponseWriter, r *gitRequest) { NotFoundUnless(developmentMode, http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
executed = true executed = true
})(w, &gitRequest{Request: r}) })).ServeHTTP(w, r)
if executed { if executed {
t.Error("The handler should not get executed") t.Error("The handler should not get executed")
} }
assertResponseCode(t, w, 404) testhelper.AssertResponseCode(t, w, 404)
} }
package main package upstream
import ( import (
"../helper"
"compress/gzip" "compress/gzip"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
) )
func contentEncodingHandler(handleFunc serviceHandleFunc) serviceHandleFunc { func contentEncodingHandler(h http.Handler) http.Handler {
return func(w http.ResponseWriter, r *gitRequest) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var body io.ReadCloser var body io.ReadCloser
var err error var err error
...@@ -24,7 +25,7 @@ func contentEncodingHandler(handleFunc serviceHandleFunc) serviceHandleFunc { ...@@ -24,7 +25,7 @@ func contentEncodingHandler(handleFunc serviceHandleFunc) serviceHandleFunc {
} }
if err != nil { if err != nil {
fail500(w, fmt.Errorf("contentEncodingHandler: %v", err)) helper.Fail500(w, fmt.Errorf("contentEncodingHandler: %v", err))
return return
} }
defer body.Close() defer body.Close()
...@@ -32,6 +33,6 @@ func contentEncodingHandler(handleFunc serviceHandleFunc) serviceHandleFunc { ...@@ -32,6 +33,6 @@ func contentEncodingHandler(handleFunc serviceHandleFunc) serviceHandleFunc {
r.Body = body r.Body = body
r.Header.Del("Content-Encoding") r.Header.Del("Content-Encoding")
handleFunc(w, r) h.ServeHTTP(w, r)
} })
} }
package main package upstream
import ( import (
"../testhelper"
"bytes" "bytes"
"compress/gzip" "compress/gzip"
"fmt" "fmt"
...@@ -27,17 +28,16 @@ func TestGzipEncoding(t *testing.T) { ...@@ -27,17 +28,16 @@ func TestGzipEncoding(t *testing.T) {
} }
req.Header.Set("Content-Encoding", "gzip") req.Header.Set("Content-Encoding", "gzip")
request := gitRequest{Request: req} contentEncodingHandler(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
contentEncodingHandler(func(w http.ResponseWriter, r *gitRequest) {
if _, ok := r.Body.(*gzip.Reader); !ok { if _, ok := r.Body.(*gzip.Reader); !ok {
t.Fatal("Expected gzip reader for body, but it's:", reflect.TypeOf(r.Body)) t.Fatal("Expected gzip reader for body, but it's:", reflect.TypeOf(r.Body))
} }
if r.Header.Get("Content-Encoding") != "" { if r.Header.Get("Content-Encoding") != "" {
t.Fatal("Content-Encoding should be deleted") t.Fatal("Content-Encoding should be deleted")
} }
})(resp, &request) })).ServeHTTP(resp, req)
assertResponseCode(t, resp, 200) testhelper.AssertResponseCode(t, resp, 200)
} }
func TestNoEncoding(t *testing.T) { func TestNoEncoding(t *testing.T) {
...@@ -52,17 +52,16 @@ func TestNoEncoding(t *testing.T) { ...@@ -52,17 +52,16 @@ func TestNoEncoding(t *testing.T) {
} }
req.Header.Set("Content-Encoding", "") req.Header.Set("Content-Encoding", "")
request := gitRequest{Request: req} contentEncodingHandler(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
contentEncodingHandler(func(w http.ResponseWriter, r *gitRequest) {
if r.Body != body { if r.Body != body {
t.Fatal("Expected the same body") t.Fatal("Expected the same body")
} }
if r.Header.Get("Content-Encoding") != "" { if r.Header.Get("Content-Encoding") != "" {
t.Fatal("Content-Encoding should be deleted") t.Fatal("Content-Encoding should be deleted")
} }
})(resp, &request) })).ServeHTTP(resp, req)
assertResponseCode(t, resp, 200) testhelper.AssertResponseCode(t, resp, 200)
} }
func TestInvalidEncoding(t *testing.T) { func TestInvalidEncoding(t *testing.T) {
...@@ -74,10 +73,9 @@ func TestInvalidEncoding(t *testing.T) { ...@@ -74,10 +73,9 @@ func TestInvalidEncoding(t *testing.T) {
} }
req.Header.Set("Content-Encoding", "application/unknown") req.Header.Set("Content-Encoding", "application/unknown")
request := gitRequest{Request: req} contentEncodingHandler(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
contentEncodingHandler(func(w http.ResponseWriter, r *gitRequest) {
t.Fatal("it shouldn't be executed") t.Fatal("it shouldn't be executed")
})(resp, &request) })).ServeHTTP(resp, req)
assertResponseCode(t, resp, 500) testhelper.AssertResponseCode(t, resp, 500)
} }
package upstream
import "net/http"
func NotFoundUnless(pass bool, handler http.Handler) http.Handler {
if pass {
return handler
} else {
return http.HandlerFunc(http.NotFound)
}
}
package upstream
import (
apipkg "../api"
"../artifacts"
"../git"
"../lfs"
proxypkg "../proxy"
"../staticpages"
"net/http"
"regexp"
)
type route struct {
method string
regex *regexp.Regexp
handler http.Handler
}
const projectPattern = `^/[^/]+/[^/]+/`
const gitProjectPattern = `^/[^/]+/[^/]+\.git/`
const apiPattern = `^/api/`
// A project ID in an API request is either a number or two strings 'namespace/project'
const projectsAPIPattern = `^/api/v3/projects/((\d+)|([^/]+/[^/]+))/`
const ciAPIPattern = `^/ci/api/`
// Routing table
// We match against URI not containing the relativeUrlRoot:
// see upstream.ServeHTTP
func (u *Upstream) configureRoutes() {
api := apipkg.NewAPI(
u.Backend,
u.Version,
u.RoundTripper,
)
static := &staticpages.Static{u.DocumentRoot}
proxy := proxypkg.NewProxy(
u.Backend,
u.Version,
u.RoundTripper,
)
u.Routes = []route{
// Git Clone
route{"GET", regexp.MustCompile(gitProjectPattern + `info/refs\z`), git.GetInfoRefs(api)},
route{"POST", regexp.MustCompile(gitProjectPattern + `git-upload-pack\z`), contentEncodingHandler(git.PostRPC(api))},
route{"POST", regexp.MustCompile(gitProjectPattern + `git-receive-pack\z`), contentEncodingHandler(git.PostRPC(api))},
route{"PUT", regexp.MustCompile(gitProjectPattern + `gitlab-lfs/objects/([0-9a-f]{64})/([0-9]+)\z`), lfs.PutStore(api, proxy)},
// Repository Archive
route{"GET", regexp.MustCompile(projectPattern + `repository/archive\z`), git.GetArchive(api)},
route{"GET", regexp.MustCompile(projectPattern + `repository/archive.zip\z`), git.GetArchive(api)},
route{"GET", regexp.MustCompile(projectPattern + `repository/archive.tar\z`), git.GetArchive(api)},
route{"GET", regexp.MustCompile(projectPattern + `repository/archive.tar.gz\z`), git.GetArchive(api)},
route{"GET", regexp.MustCompile(projectPattern + `repository/archive.tar.bz2\z`), git.GetArchive(api)},
// Repository Archive API
route{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive\z`), git.GetArchive(api)},
route{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive.zip\z`), git.GetArchive(api)},
route{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive.tar\z`), git.GetArchive(api)},
route{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive.tar.gz\z`), git.GetArchive(api)},
route{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive.tar.bz2\z`), git.GetArchive(api)},
// CI Artifacts
route{"GET", regexp.MustCompile(projectPattern + `builds/[0-9]+/artifacts/file/`), contentEncodingHandler(artifacts.DownloadArtifact(api))},
route{"POST", regexp.MustCompile(ciAPIPattern + `v1/builds/[0-9]+/artifacts\z`), contentEncodingHandler(artifacts.UploadArtifacts(api, proxy))},
// Explicitly proxy API requests
route{"", regexp.MustCompile(apiPattern), proxy},
route{"", regexp.MustCompile(ciAPIPattern), proxy},
// Serve assets
route{"", regexp.MustCompile(`^/assets/`),
static.ServeExisting(u.URLPrefix, staticpages.CacheExpireMax,
NotFoundUnless(u.DevelopmentMode,
proxy,
),
),
},
// For legacy reasons, user uploads are stored under the document root.
// To prevent anybody who knows/guesses the URL of a user-uploaded file
// from downloading it we make sure requests to /uploads/ do _not_ pass
// through static.ServeExisting.
route{"", regexp.MustCompile(`^/uploads/`), static.ErrorPagesUnless(u.DevelopmentMode, proxy)},
// Serve static files or forward the requests
route{"", nil,
static.ServeExisting(u.URLPrefix, staticpages.CacheDisabled,
static.DeployPage(
static.ErrorPagesUnless(u.DevelopmentMode,
proxy,
),
),
),
},
}
}
/*
The upstream type implements http.Handler.
In this file we handle request routing and interaction with the authBackend.
*/
package upstream
import (
"../badgateway"
"../helper"
"../urlprefix"
"fmt"
"net/http"
"net/url"
"strings"
"time"
)
var DefaultBackend = helper.URLMustParse("http://localhost:8080")
type Upstream struct {
Backend *url.URL
Version string
DocumentRoot string
DevelopmentMode bool
URLPrefix urlprefix.Prefix
Routes []route
RoundTripper *badgateway.RoundTripper
}
func NewUpstream(backend *url.URL, socket string, version string, documentRoot string, developmentMode bool, proxyHeadersTimeout time.Duration) *Upstream {
up := Upstream{
Backend: backend,
Version: version,
DocumentRoot: documentRoot,
DevelopmentMode: developmentMode,
RoundTripper: badgateway.NewRoundTripper(socket, proxyHeadersTimeout),
}
if backend == nil {
up.Backend = DefaultBackend
}
up.configureURLPrefix()
up.configureRoutes()
return &up
}
func (u *Upstream) configureURLPrefix() {
relativeURLRoot := u.Backend.Path
if !strings.HasSuffix(relativeURLRoot, "/") {
relativeURLRoot += "/"
}
u.URLPrefix = urlprefix.Prefix(relativeURLRoot)
}
func (u *Upstream) ServeHTTP(ow http.ResponseWriter, r *http.Request) {
w := helper.NewLoggingResponseWriter(ow)
defer w.Log(r)
// Drop WebSocket connection and CONNECT method
if r.RequestURI == "*" {
helper.HTTPError(&w, r, "Connection upgrade not allowed", http.StatusBadRequest)
return
}
// Disallow connect
if r.Method == "CONNECT" {
helper.HTTPError(&w, r, "CONNECT not allowed", http.StatusBadRequest)
return
}
// Check URL Root
URIPath := urlprefix.CleanURIPath(r.URL.Path)
prefix := u.URLPrefix
if !prefix.Match(URIPath) {
helper.HTTPError(&w, r, fmt.Sprintf("Not found %q", URIPath), http.StatusNotFound)
return
}
// Look for a matching Git service
var ro route
foundService := false
for _, ro = range u.Routes {
if ro.method != "" && r.Method != ro.method {
continue
}
if ro.regex == nil || ro.regex.MatchString(prefix.Strip(URIPath)) {
foundService = true
break
}
}
if !foundService {
// The protocol spec in git/Documentation/technical/http-protocol.txt
// says we must return 403 if no matching service is found.
helper.HTTPError(&w, r, "Forbidden", http.StatusForbidden)
return
}
ro.handler.ServeHTTP(&w, r)
}
package urlprefix
import (
"path"
"strings"
)
type Prefix string
func (p Prefix) Strip(path string) string {
return CleanURIPath(strings.TrimPrefix(path, string(p)))
}
func (p Prefix) Match(path string) bool {
pre := string(p)
return strings.HasPrefix(path, pre) || path+"/" == pre
}
// Borrowed from: net/http/server.go
// Return the canonical path for p, eliminating . and .. elements.
func CleanURIPath(p string) string {
if p == "" {
return "/"
}
if p[0] != '/' {
p = "/" + p
}
np := path.Clean(p)
// path.Clean removes trailing slash except for root;
// put the trailing slash back if necessary.
if p[len(p)-1] == '/' && np != "/" {
np += "/"
}
return np
}
package zipartifacts
// These are exit codes used by subprocesses in cmd/gitlab-zip-xxx
const (
StatusNotZip = 10 + iota
StatusEntryNotFound
)
package zipartifacts
import (
"encoding/base64"
)
func DecodeFileEntry(entry string) (string, error) {
decoded, err := base64.StdEncoding.DecodeString(entry)
if err != nil {
return "", err
}
return string(decoded), nil
}
package zipartifacts
import (
"archive/zip"
"compress/gzip"
"encoding/binary"
"encoding/json"
"io"
"os"
"strconv"
)
type metadata struct {
Modified int64 `json:"modified"`
Mode string `json:"mode"`
CRC uint32 `json:"crc,omitempty"`
Size uint64 `json:"size,omitempty"`
Zipped uint64 `json:"zipped,omitempty"`
Comment string `json:"comment,omitempty"`
}
const MetadataHeaderPrefix = "\x00\x00\x00&" // length of string below, encoded properly
const MetadataHeader = "GitLab Build Artifacts Metadata 0.0.2\n"
func newMetadata(file *zip.File) metadata {
return metadata{
Modified: file.ModTime().Unix(),
Mode: strconv.FormatUint(uint64(file.Mode().Perm()), 8),
CRC: file.CRC32,
Size: file.UncompressedSize64,
Zipped: file.CompressedSize64,
Comment: file.Comment,
}
}
func (m metadata) writeEncoded(output io.Writer) error {
j, err := json.Marshal(m)
if err != nil {
return err
}
j = append(j, byte('\n'))
return writeBytes(output, j)
}
func writeZipEntryMetadata(output io.Writer, entry *zip.File) error {
err := writeString(output, entry.Name)
if err != nil {
return err
}
err = newMetadata(entry).writeEncoded(output)
if err != nil {
return err
}
return nil
}
func generateZipMetadata(output io.Writer, archive *zip.Reader) error {
err := writeString(output, MetadataHeader)
if err != nil {
return err
}
// Write empty error string
err = writeString(output, "{}")
if err != nil {
return err
}
// Write all files
for _, entry := range archive.File {
err = writeZipEntryMetadata(output, entry)
if err != nil {
return err
}
}
return nil
}
func GenerateZipMetadataFromFile(fileName string, w io.Writer) error {
archive, err := zip.OpenReader(fileName)
if err != nil {
// Ignore non-zip archives
return os.ErrInvalid
}
defer archive.Close()
gz := gzip.NewWriter(w)
defer gz.Close()
return generateZipMetadata(gz, &archive.Reader)
}
func writeBytes(output io.Writer, data []byte) error {
err := binary.Write(output, binary.BigEndian, uint32(len(data)))
if err == nil {
_, err = output.Write(data)
}
return err
}
func writeString(output io.Writer, str string) error {
return writeBytes(output, []byte(str))
}
/*
In this file we handle git lfs objects downloads and uploads
*/
package main
import (
"bytes"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
"os"
"path/filepath"
)
func lfsAuthorizeHandler(handleFunc serviceHandleFunc) serviceHandleFunc {
return preAuthorizeHandler(func(w http.ResponseWriter, r *gitRequest) {
if r.StoreLFSPath == "" {
fail500(w, errors.New("lfsAuthorizeHandler: StoreLFSPath empty"))
return
}
if r.LfsOid == "" {
fail500(w, errors.New("lfsAuthorizeHandler: LfsOid empty"))
return
}
if err := os.MkdirAll(r.StoreLFSPath, 0700); err != nil {
fail500(w, fmt.Errorf("lfsAuthorizeHandler: mkdir StoreLFSPath: %v", err))
return
}
handleFunc(w, r)
}, "/authorize")
}
func handleStoreLfsObject(w http.ResponseWriter, r *gitRequest) {
file, err := ioutil.TempFile(r.StoreLFSPath, r.LfsOid)
if err != nil {
fail500(w, fmt.Errorf("handleStoreLfsObject: create tempfile: %v", err))
return
}
defer os.Remove(file.Name())
defer file.Close()
hash := sha256.New()
hw := io.MultiWriter(hash, file)
written, err := io.Copy(hw, r.Body)
if err != nil {
fail500(w, fmt.Errorf("handleStoreLfsObject: write tempfile: %v", err))
return
}
file.Close()
if written != r.LfsSize {
fail500(w, fmt.Errorf("handleStoreLfsObject: expected size %d, wrote %d", r.LfsSize, written))
return
}
shaStr := hex.EncodeToString(hash.Sum(nil))
if shaStr != r.LfsOid {
fail500(w, fmt.Errorf("handleStoreLfsObject: expected sha256 %s, got %s", r.LfsOid, shaStr))
return
}
// Inject header and body
r.Header.Set("X-GitLab-Lfs-Tmp", filepath.Base(file.Name()))
r.Body = ioutil.NopCloser(&bytes.Buffer{})
r.ContentLength = 0
// And proxy the request
proxyRequest(w, r)
}
...@@ -14,6 +14,7 @@ In this file we start the web server and hand off to the upstream type. ...@@ -14,6 +14,7 @@ In this file we start the web server and hand off to the upstream type.
package main package main
import ( import (
"./internal/upstream"
"flag" "flag"
"fmt" "fmt"
"log" "log"
...@@ -21,7 +22,6 @@ import ( ...@@ -21,7 +22,6 @@ import (
"net/http" "net/http"
_ "net/http/pprof" _ "net/http/pprof"
"os" "os"
"regexp"
"syscall" "syscall"
"time" "time"
) )
...@@ -33,95 +33,13 @@ var printVersion = flag.Bool("version", false, "Print version and exit") ...@@ -33,95 +33,13 @@ var printVersion = flag.Bool("version", false, "Print version and exit")
var listenAddr = flag.String("listenAddr", "localhost:8181", "Listen address for HTTP server") var listenAddr = flag.String("listenAddr", "localhost:8181", "Listen address for HTTP server")
var listenNetwork = flag.String("listenNetwork", "tcp", "Listen 'network' (tcp, tcp4, tcp6, unix)") var listenNetwork = flag.String("listenNetwork", "tcp", "Listen 'network' (tcp, tcp4, tcp6, unix)")
var listenUmask = flag.Int("listenUmask", 022, "Umask for Unix socket, default: 022") var listenUmask = flag.Int("listenUmask", 022, "Umask for Unix socket, default: 022")
var authBackend = flag.String("authBackend", "http://localhost:8080", "Authentication/authorization backend") var authBackend = URLFlag("authBackend", upstream.DefaultBackend, "Authentication/authorization backend")
var authSocket = flag.String("authSocket", "", "Optional: Unix domain socket to dial authBackend at") var authSocket = flag.String("authSocket", "", "Optional: Unix domain socket to dial authBackend at")
var pprofListenAddr = flag.String("pprofListenAddr", "", "pprof listening address, e.g. 'localhost:6060'") var pprofListenAddr = flag.String("pprofListenAddr", "", "pprof listening address, e.g. 'localhost:6060'")
var documentRoot = flag.String("documentRoot", "public", "Path to static files content") var documentRoot = flag.String("documentRoot", "public", "Path to static files content")
var responseHeadersTimeout = flag.Duration("proxyHeadersTimeout", time.Minute, "How long to wait for response headers when proxying the request") var proxyHeadersTimeout = flag.Duration("proxyHeadersTimeout", time.Minute, "How long to wait for response headers when proxying the request")
var developmentMode = flag.Bool("developmentMode", false, "Allow to serve assets from Rails app") var developmentMode = flag.Bool("developmentMode", false, "Allow to serve assets from Rails app")
type httpRoute struct {
method string
regex *regexp.Regexp
handleFunc serviceHandleFunc
}
const projectPattern = `^/[^/]+/[^/]+/`
const gitProjectPattern = `^/[^/]+/[^/]+\.git/`
const apiPattern = `^/api/`
// A project ID in an API request is either a number or two strings 'namespace/project'
const projectsAPIPattern = `^/api/v3/projects/((\d+)|([^/]+/[^/]+))/`
const ciAPIPattern = `^/ci/api/`
// Routing table
// We match against URI not containing the relativeUrlRoot:
// see upstream.ServeHTTP
var httpRoutes = [...]httpRoute{
// Git Clone
httpRoute{"GET", regexp.MustCompile(gitProjectPattern + `info/refs\z`), repoPreAuthorizeHandler(handleGetInfoRefs)},
httpRoute{"POST", regexp.MustCompile(gitProjectPattern + `git-upload-pack\z`), repoPreAuthorizeHandler(contentEncodingHandler(handlePostRPC))},
httpRoute{"POST", regexp.MustCompile(gitProjectPattern + `git-receive-pack\z`), repoPreAuthorizeHandler(contentEncodingHandler(handlePostRPC))},
httpRoute{"PUT", regexp.MustCompile(gitProjectPattern + `gitlab-lfs/objects/([0-9a-f]{64})/([0-9]+)\z`), lfsAuthorizeHandler(handleStoreLfsObject)},
// Repository Archive
httpRoute{"GET", regexp.MustCompile(projectPattern + `repository/archive\z`), repoPreAuthorizeHandler(handleGetArchive)},
httpRoute{"GET", regexp.MustCompile(projectPattern + `repository/archive.zip\z`), repoPreAuthorizeHandler(handleGetArchive)},
httpRoute{"GET", regexp.MustCompile(projectPattern + `repository/archive.tar\z`), repoPreAuthorizeHandler(handleGetArchive)},
httpRoute{"GET", regexp.MustCompile(projectPattern + `repository/archive.tar.gz\z`), repoPreAuthorizeHandler(handleGetArchive)},
httpRoute{"GET", regexp.MustCompile(projectPattern + `repository/archive.tar.bz2\z`), repoPreAuthorizeHandler(handleGetArchive)},
// Repository Archive API
httpRoute{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive\z`), repoPreAuthorizeHandler(handleGetArchive)},
httpRoute{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive.zip\z`), repoPreAuthorizeHandler(handleGetArchive)},
httpRoute{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive.tar\z`), repoPreAuthorizeHandler(handleGetArchive)},
httpRoute{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive.tar.gz\z`), repoPreAuthorizeHandler(handleGetArchive)},
httpRoute{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive.tar.bz2\z`), repoPreAuthorizeHandler(handleGetArchive)},
// CI Artifacts API
httpRoute{"POST", regexp.MustCompile(ciAPIPattern + `v1/builds/[0-9]+/artifacts\z`), artifactsAuthorizeHandler(contentEncodingHandler(handleFileUploads))},
// Explicitly proxy API requests
httpRoute{"", regexp.MustCompile(apiPattern), proxyRequest},
httpRoute{"", regexp.MustCompile(ciAPIPattern), proxyRequest},
// Serve assets
httpRoute{"", regexp.MustCompile(`^/assets/`),
handleServeFile(documentRoot, CacheExpireMax,
handleDevelopmentMode(developmentMode,
handleDeployPage(documentRoot,
handleRailsError(documentRoot, developmentMode,
proxyRequest,
),
),
),
),
},
// For legacy reasons, user uploads are stored under the document root.
// To prevent anybody who knows/guesses the URL of a user-uploaded file
// from downloading it we make sure requests to /uploads/ do _not_ pass
// through handleServeFile.
httpRoute{"", regexp.MustCompile(`^/uploads/`),
handleRailsError(documentRoot, developmentMode,
proxyRequest,
),
},
// Serve static files or forward the requests
httpRoute{"", nil,
handleServeFile(documentRoot, CacheDisabled,
handleDeployPage(documentRoot,
handleRailsError(documentRoot, developmentMode,
proxyRequest,
),
),
),
},
}
func main() { func main() {
flag.Usage = func() { flag.Usage = func() {
fmt.Fprintf(os.Stderr, "Usage of %s:\n", os.Args[0]) fmt.Fprintf(os.Stderr, "Usage of %s:\n", os.Args[0])
...@@ -153,23 +71,6 @@ func main() { ...@@ -153,23 +71,6 @@ func main() {
log.Fatal(err) log.Fatal(err)
} }
// Create Proxy Transport
authTransport := http.DefaultTransport
if *authSocket != "" {
dialer := &net.Dialer{
// The values below are taken from http.DefaultTransport
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}
authTransport = &http.Transport{
Dial: func(_, _ string) (net.Conn, error) {
return dialer.Dial("unix", *authSocket)
},
ResponseHeaderTimeout: *responseHeadersTimeout,
}
}
proxyTransport := &proxyRoundTripper{transport: authTransport}
// The profiler will only be activated by HTTP requests. HTTP // The profiler will only be activated by HTTP requests. HTTP
// requests can only reach the profiler if we start a listener. So by // requests can only reach the profiler if we start a listener. So by
// having no profiler HTTP listener by default, the profiler is // having no profiler HTTP listener by default, the profiler is
...@@ -180,6 +81,14 @@ func main() { ...@@ -180,6 +81,14 @@ func main() {
}() }()
} }
upstream := newUpstream(*authBackend, proxyTransport) up := upstream.NewUpstream(
log.Fatal(http.Serve(listener, upstream)) *authBackend,
*authSocket,
Version,
*documentRoot,
*developmentMode,
*proxyHeadersTimeout,
)
log.Fatal(http.Serve(listener, up))
} }
package main package main
import ( import (
"./internal/api"
"./internal/helper"
"./internal/testhelper"
"./internal/upstream"
"bytes" "bytes"
"encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"log" "log"
"mime/multipart"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os" "os"
...@@ -18,9 +24,9 @@ import ( ...@@ -18,9 +24,9 @@ import (
"time" "time"
) )
const scratchDir = "test/scratch" const scratchDir = "testdata/scratch"
const testRepoRoot = "test/data" const testRepoRoot = "testdata/data"
const testDocumentRoot = "test/public" const testDocumentRoot = "testdata/public"
const testRepo = "group/test.git" const testRepo = "group/test.git"
const testProject = "group/test" const testProject = "group/test"
...@@ -325,7 +331,7 @@ func TestAllowedStaticFile(t *testing.T) { ...@@ -325,7 +331,7 @@ func TestAllowedStaticFile(t *testing.T) {
} }
proxied := false proxied := false
ts := testServerWithHandler(regexp.MustCompile(`.`), func(w http.ResponseWriter, r *http.Request) { ts := testhelper.TestServerWithHandler(regexp.MustCompile(`.`), func(w http.ResponseWriter, r *http.Request) {
proxied = true proxied = true
w.WriteHeader(404) w.WriteHeader(404)
}) })
...@@ -339,25 +345,56 @@ func TestAllowedStaticFile(t *testing.T) { ...@@ -339,25 +345,56 @@ func TestAllowedStaticFile(t *testing.T) {
} { } {
resp, err := http.Get(ws.URL + resource) resp, err := http.Get(ws.URL + resource)
if err != nil { if err != nil {
t.Fatal(err) t.Error(err)
} }
defer resp.Body.Close() defer resp.Body.Close()
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
if _, err := io.Copy(buf, resp.Body); err != nil { if _, err := io.Copy(buf, resp.Body); err != nil {
t.Fatal(err) t.Error(err)
} }
if buf.String() != content { if buf.String() != content {
t.Fatalf("GET %q: Expected %q, got %q", resource, content, buf.String()) t.Errorf("GET %q: Expected %q, got %q", resource, content, buf.String())
} }
if resp.StatusCode != 200 { if resp.StatusCode != 200 {
t.Fatalf("GET %q: expected 200, got %d", resource, resp.StatusCode) t.Errorf("GET %q: expected 200, got %d", resource, resp.StatusCode)
} }
if proxied { if proxied {
t.Fatalf("GET %q: should not have made it to backend", resource) t.Errorf("GET %q: should not have made it to backend", resource)
} }
} }
} }
func TestStaticFileRelativeURL(t *testing.T) {
content := "PUBLIC"
if err := setupStaticFile("static.txt", content); err != nil {
t.Fatalf("create public/static.txt: %v", err)
}
ts := testhelper.TestServerWithHandler(regexp.MustCompile(`.`), http.HandlerFunc(http.NotFound))
defer ts.Close()
backendURLString := ts.URL + "/my-relative-url"
log.Print(backendURLString)
ws := startWorkhorseServer(backendURLString)
defer ws.Close()
resource := "/my-relative-url/static.txt"
resp, err := http.Get(ws.URL + resource)
if err != nil {
t.Error(err)
}
defer resp.Body.Close()
buf := &bytes.Buffer{}
if _, err := io.Copy(buf, resp.Body); err != nil {
t.Error(err)
}
if buf.String() != content {
t.Errorf("GET %q: Expected %q, got %q", resource, content, buf.String())
}
if resp.StatusCode != 200 {
t.Errorf("GET %q: expected 200, got %d", resource, resp.StatusCode)
}
}
func TestAllowedPublicUploadsFile(t *testing.T) { func TestAllowedPublicUploadsFile(t *testing.T) {
content := "PRIVATE but allowed" content := "PRIVATE but allowed"
if err := setupStaticFile("uploads/static file.txt", content); err != nil { if err := setupStaticFile("uploads/static file.txt", content); err != nil {
...@@ -365,7 +402,7 @@ func TestAllowedPublicUploadsFile(t *testing.T) { ...@@ -365,7 +402,7 @@ func TestAllowedPublicUploadsFile(t *testing.T) {
} }
proxied := false proxied := false
ts := testServerWithHandler(regexp.MustCompile(`.`), func(w http.ResponseWriter, r *http.Request) { ts := testhelper.TestServerWithHandler(regexp.MustCompile(`.`), func(w http.ResponseWriter, r *http.Request) {
proxied = true proxied = true
w.Header().Add("X-Sendfile", *documentRoot+r.URL.Path) w.Header().Add("X-Sendfile", *documentRoot+r.URL.Path)
w.WriteHeader(200) w.WriteHeader(200)
...@@ -406,7 +443,7 @@ func TestDeniedPublicUploadsFile(t *testing.T) { ...@@ -406,7 +443,7 @@ func TestDeniedPublicUploadsFile(t *testing.T) {
} }
proxied := false proxied := false
ts := testServerWithHandler(regexp.MustCompile(`.`), func(w http.ResponseWriter, _ *http.Request) { ts := testhelper.TestServerWithHandler(regexp.MustCompile(`.`), func(w http.ResponseWriter, _ *http.Request) {
proxied = true proxied = true
w.WriteHeader(404) w.WriteHeader(404)
}) })
...@@ -439,6 +476,85 @@ func TestDeniedPublicUploadsFile(t *testing.T) { ...@@ -439,6 +476,85 @@ func TestDeniedPublicUploadsFile(t *testing.T) {
} }
} }
func TestArtifactsUpload(t *testing.T) {
reqBody := &bytes.Buffer{}
writer := multipart.NewWriter(reqBody)
file, err := writer.CreateFormFile("file", "my.file")
if err != nil {
t.Fatal(err)
}
fmt.Fprint(file, "SHOULD BE ON DISK, NOT IN MULTIPART")
writer.Close()
ts := testhelper.TestServerWithHandler(regexp.MustCompile(`.`), func(w http.ResponseWriter, r *http.Request) {
if strings.HasSuffix(r.URL.Path, "/authorize") {
if _, err := fmt.Fprintf(w, `{"TempPath":"%s"}`, scratchDir); err != nil {
t.Fatal(err)
}
return
}
err := r.ParseMultipartForm(100000)
if err != nil {
t.Fatal(err)
}
nValues := 2 // filename + path for just the upload (no metadata because we are not POSTing a valid zip file)
if len(r.MultipartForm.Value) != nValues {
t.Errorf("Expected to receive exactly %d values", nValues)
}
if len(r.MultipartForm.File) != 0 {
t.Error("Expected to not receive any files")
}
w.WriteHeader(200)
})
defer ts.Close()
ws := startWorkhorseServer(ts.URL)
defer ws.Close()
resource := `/ci/api/v1/builds/123/artifacts`
resp, err := http.Post(ws.URL+resource, writer.FormDataContentType(), reqBody)
if err != nil {
t.Error(err)
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
t.Errorf("GET %q: expected 200, got %d", resource, resp.StatusCode)
}
}
func TestArtifactsGetSingleFile(t *testing.T) {
// We manually created this zip file in the gitlab-workhorse Git repository
archivePath := `testdata/artifacts-archive.zip`
fileName := "myfile"
fileContents := "MY FILE"
resourcePath := `/namespace/project/builds/123/artifacts/file/` + fileName
ts := testhelper.TestServerWithHandler(regexp.MustCompile(`\A`+resourcePath+`\z`), func(w http.ResponseWriter, r *http.Request) {
encodedFilename := base64.StdEncoding.EncodeToString([]byte(fileName))
if _, err := fmt.Fprintf(w, `{"Archive":"%s","Entry":"%s"}`, archivePath, encodedFilename); err != nil {
t.Fatal(err)
}
return
})
defer ts.Close()
ws := startWorkhorseServer(ts.URL)
defer ws.Close()
resp, err := http.Get(ws.URL + resourcePath)
if err != nil {
t.Error(err)
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
t.Errorf("GET %q: expected 200, got %d", resourcePath, resp.StatusCode)
}
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
if string(body) != fileContents {
t.Fatalf("Expected file contents %q, got %q", fileContents, body)
}
}
func setupStaticFile(fpath, content string) error { func setupStaticFile(fpath, content string) error {
cwd, err := os.Getwd() cwd, err := os.Getwd()
if err != nil { if err != nil {
...@@ -476,26 +592,8 @@ func newBranch() string { ...@@ -476,26 +592,8 @@ func newBranch() string {
return fmt.Sprintf("branch-%d", time.Now().UnixNano()) return fmt.Sprintf("branch-%d", time.Now().UnixNano())
} }
func testServerWithHandler(url *regexp.Regexp, handler http.HandlerFunc) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if url != nil && !url.MatchString(r.URL.Path) {
log.Println("UPSTREAM", r.Method, r.URL, "DENY")
w.WriteHeader(404)
return
}
if version := r.Header.Get("Gitlab-Workhorse"); version == "" {
log.Println("UPSTREAM", r.Method, r.URL, "DENY")
w.WriteHeader(403)
return
}
handler(w, r)
}))
}
func testAuthServer(url *regexp.Regexp, code int, body interface{}) *httptest.Server { func testAuthServer(url *regexp.Regexp, code int, body interface{}) *httptest.Server {
return testServerWithHandler(url, func(w http.ResponseWriter, r *http.Request) { return testhelper.TestServerWithHandler(url, func(w http.ResponseWriter, r *http.Request) {
// Write pure string // Write pure string
if data, ok := body.(string); ok { if data, ok := body.(string); ok {
log.Println("UPSTREAM", r.Method, r.URL, code) log.Println("UPSTREAM", r.Method, r.URL, code)
...@@ -520,7 +618,15 @@ func testAuthServer(url *regexp.Regexp, code int, body interface{}) *httptest.Se ...@@ -520,7 +618,15 @@ func testAuthServer(url *regexp.Regexp, code int, body interface{}) *httptest.Se
} }
func startWorkhorseServer(authBackend string) *httptest.Server { func startWorkhorseServer(authBackend string) *httptest.Server {
return httptest.NewServer(newUpstream(authBackend, nil)) u := upstream.NewUpstream(
helper.URLMustParse(authBackend),
"",
"123",
testDocumentRoot,
false,
0,
)
return httptest.NewServer(u)
} }
func runOrFail(t *testing.T, cmd *exec.Cmd) { func runOrFail(t *testing.T, cmd *exec.Cmd) {
...@@ -532,7 +638,7 @@ func runOrFail(t *testing.T, cmd *exec.Cmd) { ...@@ -532,7 +638,7 @@ func runOrFail(t *testing.T, cmd *exec.Cmd) {
} }
func gitOkBody(t *testing.T) interface{} { func gitOkBody(t *testing.T) interface{} {
return &authorizationResponse{ return &api.Response{
GL_ID: "user-123", GL_ID: "user-123",
RepoPath: repoPath(t), RepoPath: repoPath(t),
} }
...@@ -545,7 +651,7 @@ func archiveOkBody(t *testing.T, archiveName string) interface{} { ...@@ -545,7 +651,7 @@ func archiveOkBody(t *testing.T, archiveName string) interface{} {
} }
archivePath := path.Join(cwd, cacheDir, archiveName) archivePath := path.Join(cwd, cacheDir, archiveName)
return &authorizationResponse{ return &api.Response{
RepoPath: repoPath(t), RepoPath: repoPath(t),
ArchivePath: archivePath, ArchivePath: archivePath,
CommitId: "c7fbe50c7c7419d9701eebe64b1fdacc3df5b9dd", CommitId: "c7fbe50c7c7419d9701eebe64b1fdacc3df5b9dd",
......
package main package main
import ( import (
"./internal/badgateway"
"./internal/helper"
"./internal/proxy"
"./internal/testhelper"
"bytes" "bytes"
"fmt" "fmt"
"io" "io"
...@@ -12,8 +16,12 @@ import ( ...@@ -12,8 +16,12 @@ import (
"time" "time"
) )
func newProxy(url string, rt *badgateway.RoundTripper) *proxy.Proxy {
return proxy.NewProxy(helper.URLMustParse(url), "123", rt)
}
func TestProxyRequest(t *testing.T) { func TestProxyRequest(t *testing.T) {
ts := testServerWithHandler(regexp.MustCompile(`/url/path\z`), func(w http.ResponseWriter, r *http.Request) { ts := testhelper.TestServerWithHandler(regexp.MustCompile(`/url/path\z`), func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" { if r.Method != "POST" {
t.Fatal("Expected POST request") t.Fatal("Expected POST request")
} }
...@@ -39,15 +47,10 @@ func TestProxyRequest(t *testing.T) { ...@@ -39,15 +47,10 @@ func TestProxyRequest(t *testing.T) {
} }
httpRequest.Header.Set("Custom-Header", "test") httpRequest.Header.Set("Custom-Header", "test")
request := gitRequest{
Request: httpRequest,
u: newUpstream(ts.URL, nil),
}
w := httptest.NewRecorder() w := httptest.NewRecorder()
proxyRequest(w, &request) newProxy(ts.URL, nil).ServeHTTP(w, httpRequest)
assertResponseCode(t, w, 202) testhelper.AssertResponseCode(t, w, 202)
assertResponseBody(t, w, "RESPONSE") testhelper.AssertResponseBody(t, w, "RESPONSE")
if w.Header().Get("Custom-Response-Header") != "test" { if w.Header().Get("Custom-Response-Header") != "test" {
t.Fatal("Expected custom response header") t.Fatal("Expected custom response header")
...@@ -61,23 +64,14 @@ func TestProxyError(t *testing.T) { ...@@ -61,23 +64,14 @@ func TestProxyError(t *testing.T) {
} }
httpRequest.Header.Set("Custom-Header", "test") httpRequest.Header.Set("Custom-Header", "test")
transport := proxyRoundTripper{
transport: http.DefaultTransport,
}
request := gitRequest{
Request: httpRequest,
u: newUpstream("http://localhost:655575/", &transport),
}
w := httptest.NewRecorder() w := httptest.NewRecorder()
proxyRequest(w, &request) newProxy("http://localhost:655575/", nil).ServeHTTP(w, httpRequest)
assertResponseCode(t, w, 502) testhelper.AssertResponseCode(t, w, 502)
assertResponseBody(t, w, "dial tcp: invalid port 655575") testhelper.AssertResponseBody(t, w, "dial tcp: invalid port 655575")
} }
func TestProxyReadTimeout(t *testing.T) { func TestProxyReadTimeout(t *testing.T) {
ts := testServerWithHandler(nil, func(w http.ResponseWriter, r *http.Request) { ts := testhelper.TestServerWithHandler(nil, func(w http.ResponseWriter, r *http.Request) {
time.Sleep(time.Minute) time.Sleep(time.Minute)
}) })
...@@ -86,8 +80,8 @@ func TestProxyReadTimeout(t *testing.T) { ...@@ -86,8 +80,8 @@ func TestProxyReadTimeout(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
transport := &proxyRoundTripper{ rt := &badgateway.RoundTripper{
transport: &http.Transport{ Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment, Proxy: http.ProxyFromEnvironment,
Dial: (&net.Dialer{ Dial: (&net.Dialer{
Timeout: 30 * time.Second, Timeout: 30 * time.Second,
...@@ -98,19 +92,15 @@ func TestProxyReadTimeout(t *testing.T) { ...@@ -98,19 +92,15 @@ func TestProxyReadTimeout(t *testing.T) {
}, },
} }
request := gitRequest{ p := newProxy(ts.URL, rt)
Request: httpRequest,
u: newUpstream(ts.URL, transport),
}
w := httptest.NewRecorder() w := httptest.NewRecorder()
proxyRequest(w, &request) p.ServeHTTP(w, httpRequest)
assertResponseCode(t, w, 502) testhelper.AssertResponseCode(t, w, 502)
assertResponseBody(t, w, "net/http: timeout awaiting response headers") testhelper.AssertResponseBody(t, w, "net/http: timeout awaiting response headers")
} }
func TestProxyHandlerTimeout(t *testing.T) { func TestProxyHandlerTimeout(t *testing.T) {
ts := testServerWithHandler(nil, ts := testhelper.TestServerWithHandler(nil,
http.TimeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { http.TimeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(time.Second) time.Sleep(time.Second)
}), time.Millisecond, "Request took too long").ServeHTTP, }), time.Millisecond, "Request took too long").ServeHTTP,
...@@ -121,17 +111,8 @@ func TestProxyHandlerTimeout(t *testing.T) { ...@@ -121,17 +111,8 @@ func TestProxyHandlerTimeout(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
transport := &proxyRoundTripper{
transport: http.DefaultTransport,
}
request := gitRequest{
Request: httpRequest,
u: newUpstream(ts.URL, transport),
}
w := httptest.NewRecorder() w := httptest.NewRecorder()
proxyRequest(w, &request) newProxy(ts.URL, nil).ServeHTTP(w, httpRequest)
assertResponseCode(t, w, 503) testhelper.AssertResponseCode(t, w, 503)
assertResponseBody(t, w, "Request took too long") testhelper.AssertResponseBody(t, w, "Request took too long")
} }
#!/bin/sh
exec env PATH=$(pwd):${PATH} "$@"
/*
The upstream type implements http.Handler.
In this file we handle request routing and interaction with the authBackend.
*/
package main
import (
"fmt"
"log"
"net/http"
"net/http/httputil"
"net/url"
"strings"
)
type serviceHandleFunc func(w http.ResponseWriter, r *gitRequest)
type upstream struct {
httpClient *http.Client
httpProxy *httputil.ReverseProxy
authBackend string
relativeURLRoot string
}
type authorizationResponse struct {
// GL_ID is an environment variable used by gitlab-shell hooks during 'git
// push' and 'git pull'
GL_ID string
// RepoPath is the full path on disk to the Git repository the request is
// about
RepoPath string
// ArchivePath is the full path where we should find/create a cached copy
// of a requested archive
ArchivePath string
// ArchivePrefix is used to put extracted archive contents in a
// subdirectory
ArchivePrefix string
// CommitId is used do prevent race conditions between the 'time of check'
// in the GitLab Rails app and the 'time of use' in gitlab-workhorse.
CommitId string
// StoreLFSPath is provided by the GitLab Rails application
// to mark where the tmp file should be placed
StoreLFSPath string
// LFS object id
LfsOid string
// LFS object size
LfsSize int64
// TmpPath is the path where we should store temporary files
// This is set by authorization middleware
TempPath string
}
// A gitRequest is an *http.Request decorated with attributes returned by the
// GitLab Rails application.
type gitRequest struct {
*http.Request
authorizationResponse
u *upstream
// This field contains the URL.Path stripped from RelativeUrlRoot
relativeURIPath string
}
func newUpstream(authBackend string, authTransport http.RoundTripper) *upstream {
gitlabURL, err := url.Parse(authBackend)
if err != nil {
log.Fatalln(err)
}
relativeURLRoot := gitlabURL.Path
if !strings.HasSuffix(relativeURLRoot, "/") {
relativeURLRoot += "/"
}
// If the relative URL is '/foobar' and we tell httputil.ReverseProxy to proxy
// to 'http://example.com/foobar' then we get a redirect loop, so we clear the
// Path field here.
gitlabURL.Path = ""
up := &upstream{
authBackend: authBackend,
httpClient: &http.Client{Transport: authTransport},
httpProxy: httputil.NewSingleHostReverseProxy(gitlabURL),
relativeURLRoot: relativeURLRoot,
}
up.httpProxy.Transport = authTransport
return up
}
func (u *upstream) ServeHTTP(ow http.ResponseWriter, r *http.Request) {
var g httpRoute
w := newLoggingResponseWriter(ow)
defer w.Log(r)
// Drop WebSocket connection and CONNECT method
if r.RequestURI == "*" {
httpError(&w, r, "Connection upgrade not allowed", http.StatusBadRequest)
return
}
// Disallow connect
if r.Method == "CONNECT" {
httpError(&w, r, "CONNECT not allowed", http.StatusBadRequest)
return
}
// Check URL Root
URIPath := cleanURIPath(r.URL.Path)
if !strings.HasPrefix(URIPath, u.relativeURLRoot) && URIPath+"/" != u.relativeURLRoot {
httpError(&w, r, fmt.Sprintf("Not found %q", URIPath), http.StatusNotFound)
return
}
// Strip prefix and add "/"
// To match against non-relative URL
// Making it simpler for our matcher
relativeURIPath := cleanURIPath(strings.TrimPrefix(URIPath, u.relativeURLRoot))
// Look for a matching Git service
foundService := false
for _, g = range httpRoutes {
if g.method != "" && r.Method != g.method {
continue
}
if g.regex == nil || g.regex.MatchString(relativeURIPath) {
foundService = true
break
}
}
if !foundService {
// The protocol spec in git/Documentation/technical/http-protocol.txt
// says we must return 403 if no matching service is found.
httpError(&w, r, "Forbidden", http.StatusForbidden)
return
}
request := gitRequest{
Request: r,
relativeURIPath: relativeURIPath,
u: u,
}
g.handleFunc(&w, &request)
}
package main
import (
"flag"
"net/url"
)
type urlFlag struct {
*url.URL
}
func (u *urlFlag) Set(s string) error {
myURL, err := url.Parse(s)
if err != nil {
return err
}
u.URL = myURL
return nil
}
func URLFlag(name string, value *url.URL, usage string) **url.URL {
f := &urlFlag{value}
flag.CommandLine.Var(f, name, usage)
return &f.URL
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment