Commit bd748961 authored by Jacob Vosmaer's avatar Jacob Vosmaer

Move Proxy into an internal package

parent 3dd96bd4
......@@ -5,6 +5,7 @@ In this file we handle 'git archive' downloads
package main
import (
"./internal/helper"
"fmt"
"io"
"io/ioutil"
......@@ -103,22 +104,22 @@ func handleGetArchive(w http.ResponseWriter, r *http.Request, a *apiResponse) {
setArchiveHeaders(w, format, archiveFilename)
w.WriteHeader(200) // Don't bother with HTTP 500 from this point on, just return
if _, err := io.Copy(w, archiveReader); err != nil {
logError(fmt.Errorf("handleGetArchive: read: %v", err))
helper.LogError(fmt.Errorf("handleGetArchive: read: %v", err))
return
}
if err := archiveCmd.Wait(); err != nil {
logError(fmt.Errorf("handleGetArchive: archiveCmd: %v", err))
helper.LogError(fmt.Errorf("handleGetArchive: archiveCmd: %v", err))
return
}
if compressCmd != nil {
if err := compressCmd.Wait(); err != nil {
logError(fmt.Errorf("handleGetArchive: compressCmd: %v", err))
helper.LogError(fmt.Errorf("handleGetArchive: compressCmd: %v", err))
return
}
}
if err := finalizeCachedArchive(tempFile, a.ArchivePath); err != nil {
logError(fmt.Errorf("handleGetArchive: finalize cached archive: %v", err))
helper.LogError(fmt.Errorf("handleGetArchive: finalize cached archive: %v", err))
return
}
}
......
package main
import (
"./internal/proxy"
"encoding/json"
"fmt"
"io"
......@@ -15,7 +16,7 @@ func (api *API) newUpstreamRequest(r *http.Request, body io.Reader, suffix strin
authReq := &http.Request{
Method: r.Method,
URL: &url,
Header: headerClone(r.Header),
Header: proxy.HeaderClone(r.Header),
}
if body != nil {
authReq.Body = ioutil.NopCloser(body)
......
......@@ -23,7 +23,7 @@ func runPreAuthorizeHandler(t *testing.T, suffix string, url *regexp.Regexp, api
if err != nil {
t.Fatal(err)
}
api := newUpstream(ts.URL, nil).API
api := newUpstream(ts.URL, "").API
response := httptest.NewRecorder()
api.preAuthorizeHandler(okHandler, suffix)(response, httpRequest)
......
......@@ -5,6 +5,7 @@ In this file we handle the Git 'smart HTTP' protocol
package main
import (
"./internal/helper"
"errors"
"fmt"
"io"
......@@ -69,19 +70,19 @@ func handleGetInfoRefs(w http.ResponseWriter, r *http.Request, a *apiResponse) {
w.Header().Add("Cache-Control", "no-cache")
w.WriteHeader(200) // Don't bother with HTTP 500 from this point on, just return
if err := pktLine(w, fmt.Sprintf("# service=%s\n", rpc)); err != nil {
logError(fmt.Errorf("handleGetInfoRefs: pktLine: %v", err))
helper.LogError(fmt.Errorf("handleGetInfoRefs: pktLine: %v", err))
return
}
if err := pktFlush(w); err != nil {
logError(fmt.Errorf("handleGetInfoRefs: pktFlush: %v", err))
helper.LogError(fmt.Errorf("handleGetInfoRefs: pktFlush: %v", err))
return
}
if _, err := io.Copy(w, stdout); err != nil {
logError(fmt.Errorf("handleGetInfoRefs: read from %v: %v", cmd.Args, err))
helper.LogError(fmt.Errorf("handleGetInfoRefs: read from %v: %v", cmd.Args, err))
return
}
if err := cmd.Wait(); err != nil {
logError(fmt.Errorf("handleGetInfoRefs: wait for %v: %v", cmd.Args, err))
helper.LogError(fmt.Errorf("handleGetInfoRefs: wait for %v: %v", cmd.Args, err))
return
}
}
......@@ -136,11 +137,11 @@ func handlePostRPC(w http.ResponseWriter, r *http.Request, a *apiResponse) {
// This io.Copy may take a long time, both for Git push and pull.
if _, err := io.Copy(w, stdout); err != nil {
logError(fmt.Errorf("handlePostRPC read from %v: %v", cmd.Args, err))
helper.LogError(fmt.Errorf("handlePostRPC read from %v: %v", cmd.Args, err))
return
}
if err := cmd.Wait(); err != nil {
logError(fmt.Errorf("handlePostRPC wait for %v: %v", cmd.Args, err))
helper.LogError(fmt.Errorf("handlePostRPC wait for %v: %v", cmd.Args, err))
return
}
}
......
......@@ -5,9 +5,8 @@ Miscellaneous helpers: logging, errors, subprocesses
package main
import (
"errors"
"./internal/helper"
"fmt"
"log"
"net/http"
"os"
"os/exec"
......@@ -17,11 +16,7 @@ import (
func fail500(w http.ResponseWriter, err error) {
http.Error(w, "Internal server error", 500)
logError(err)
}
func logError(err error) {
log.Printf("error: %v", err)
helper.LogError(err)
}
func httpError(w http.ResponseWriter, r *http.Request, error string, code int) {
......@@ -71,36 +66,6 @@ func setNoCacheHeaders(header http.Header) {
header.Set("Expires", "Fri, 01 Jan 1990 00:00:00 GMT")
}
func openFile(path string) (file *os.File, fi os.FileInfo, err error) {
file, err = os.Open(path)
if err != nil {
return
}
defer func() {
if err != nil {
file.Close()
}
}()
fi, err = file.Stat()
if err != nil {
return
}
// The os.Open can also open directories
if fi.IsDir() {
err = &os.PathError{
Op: "open",
Path: path,
Err: errors.New("path is directory"),
}
return
}
return
}
// Borrowed from: net/http/server.go
// Return the canonical path for p, eliminating . and .. elements.
func cleanURIPath(p string) string {
......
package helper
import (
"errors"
"log"
"os"
)
func LogError(err error) {
log.Printf("error: %v", err)
}
func OpenFile(path string) (file *os.File, fi os.FileInfo, err error) {
file, err = os.Open(path)
if err != nil {
return
}
defer func() {
if err != nil {
file.Close()
}
}()
fi, err = file.Stat()
if err != nil {
return
}
// The os.Open can also open directories
if fi.IsDir() {
err = &os.PathError{
Op: "open",
Path: path,
Err: errors.New("path is directory"),
}
return
}
return
}
package main
package proxy
import (
"../helper"
"bytes"
"fmt"
"io/ioutil"
"net/http"
"net/http/httputil"
"net/url"
)
type proxyRoundTripper struct {
type Proxy struct {
reverseProxy *httputil.ReverseProxy
version string
}
func NewProxy(url *url.URL, transport http.RoundTripper, version string) *Proxy {
// Modify a copy of url
proxyURL := *url
proxyURL.Path = ""
p := Proxy{reverseProxy: httputil.NewSingleHostReverseProxy(&proxyURL), version: version}
p.reverseProxy.Transport = transport
return &p
}
type RoundTripper struct {
transport http.RoundTripper
}
func (p *proxyRoundTripper) RoundTrip(r *http.Request) (res *http.Response, err error) {
res, err = p.transport.RoundTrip(r)
func NewRoundTripper(transport http.RoundTripper) *RoundTripper {
return &RoundTripper{transport: transport}
}
func (rt *RoundTripper) RoundTrip(r *http.Request) (res *http.Response, err error) {
res, err = rt.transport.RoundTrip(r)
// httputil.ReverseProxy translates all errors from this
// RoundTrip function into 500 errors. But the most likely error
......@@ -21,7 +42,7 @@ func (p *proxyRoundTripper) RoundTrip(r *http.Request) (res *http.Response, err
// instead of 500s we catch the RoundTrip error here and inject a
// 502 response.
if err != nil {
logError(fmt.Errorf("proxyRoundTripper: %s %q failed with: %q", r.Method, r.RequestURI, err))
helper.LogError(fmt.Errorf("proxyRoundTripper: %s %q failed with: %q", r.Method, r.RequestURI, err))
res = &http.Response{
StatusCode: http.StatusBadGateway,
......@@ -41,7 +62,7 @@ func (p *proxyRoundTripper) RoundTrip(r *http.Request) (res *http.Response, err
return
}
func headerClone(h http.Header) http.Header {
func HeaderClone(h http.Header) http.Header {
h2 := make(http.Header, len(h))
for k, vv := range h {
vv2 := make([]string, len(vv))
......@@ -54,12 +75,12 @@ func headerClone(h http.Header) http.Header {
func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Clone request
req := *r
req.Header = headerClone(r.Header)
req.Header = HeaderClone(r.Header)
// Set Workhorse version
req.Header.Set("Gitlab-Workhorse", Version)
req.Header.Set("Gitlab-Workhorse", p.version)
rw := newSendFileResponseWriter(w, &req)
defer rw.Flush()
p.ReverseProxy.ServeHTTP(&rw, &req)
p.reverseProxy.ServeHTTP(&rw, &req)
}
......@@ -4,9 +4,10 @@ via the X-Sendfile mechanism. All that is needed in the Rails code is the
'send_file' method.
*/
package main
package proxy
import (
"../helper"
"log"
"net/http"
)
......@@ -63,7 +64,7 @@ func (s *sendFileResponseWriter) WriteHeader(status int) {
// Serve the file
log.Printf("Send file %q for %s %q", file, s.req.Method, s.req.RequestURI)
content, fi, err := openFile(file)
content, fi, err := helper.OpenFile(file)
if err != nil {
http.NotFound(s.rw, s.req)
return
......
......@@ -153,23 +153,6 @@ func main() {
log.Fatal(err)
}
// Create Proxy Transport
authTransport := http.DefaultTransport
if *authSocket != "" {
dialer := &net.Dialer{
// The values below are taken from http.DefaultTransport
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}
authTransport = &http.Transport{
Dial: func(_, _ string) (net.Conn, error) {
return dialer.Dial("unix", *authSocket)
},
ResponseHeaderTimeout: *responseHeadersTimeout,
}
}
proxyTransport := &proxyRoundTripper{transport: authTransport}
// The profiler will only be activated by HTTP requests. HTTP
// requests can only reach the profiler if we start a listener. So by
// having no profiler HTTP listener by default, the profiler is
......@@ -180,7 +163,7 @@ func main() {
}()
}
upstream := newUpstream(*authBackend, proxyTransport)
upstream := newUpstream(*authBackend, *authSocket)
compileRoutes(upstream)
log.Fatal(http.Serve(listener, upstream))
}
......@@ -326,7 +326,7 @@ func testAuthServer(url *regexp.Regexp, code int, body interface{}) *httptest.Se
}
func startWorkhorseServer(authBackend string) *httptest.Server {
u := newUpstream(authBackend, nil)
u := newUpstream(authBackend, "")
compileRoutes(u)
return httptest.NewServer(u)
}
......
package main
import (
"./internal/proxy"
"bytes"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"net/url"
"regexp"
"testing"
"time"
......@@ -39,7 +41,7 @@ func TestProxyRequest(t *testing.T) {
}
httpRequest.Header.Set("Custom-Header", "test")
u := newUpstream(ts.URL, nil)
u := newUpstream(ts.URL, "")
w := httptest.NewRecorder()
u.Proxy.ServeHTTP(w, httpRequest)
assertResponseCode(t, w, 202)
......@@ -57,11 +59,7 @@ func TestProxyError(t *testing.T) {
}
httpRequest.Header.Set("Custom-Header", "test")
transport := proxyRoundTripper{
transport: http.DefaultTransport,
}
u := newUpstream("http://localhost:655575/", &transport)
u := newUpstream("http://localhost:655575/", "")
w := httptest.NewRecorder()
u.Proxy.ServeHTTP(w, httpRequest)
assertResponseCode(t, w, 502)
......@@ -78,8 +76,8 @@ func TestProxyReadTimeout(t *testing.T) {
t.Fatal(err)
}
transport := &proxyRoundTripper{
transport: &http.Transport{
transport := proxy.NewRoundTripper(
&http.Transport{
Proxy: http.ProxyFromEnvironment,
Dial: (&net.Dialer{
Timeout: 30 * time.Second,
......@@ -88,9 +86,14 @@ func TestProxyReadTimeout(t *testing.T) {
TLSHandshakeTimeout: 10 * time.Second,
ResponseHeaderTimeout: time.Millisecond,
},
}
)
u := newUpstream(ts.URL, transport)
u := newUpstream(ts.URL, "")
url, err := url.Parse(ts.URL)
if err != nil {
t.Fatal(err)
}
u.Proxy = proxy.NewProxy(url, transport, "123")
w := httptest.NewRecorder()
u.Proxy.ServeHTTP(w, httpRequest)
......@@ -110,11 +113,7 @@ func TestProxyHandlerTimeout(t *testing.T) {
t.Fatal(err)
}
transport := &proxyRoundTripper{
transport: http.DefaultTransport,
}
u := newUpstream(ts.URL, transport)
u := newUpstream(ts.URL, "")
w := httptest.NewRecorder()
u.Proxy.ServeHTTP(w, httpRequest)
......
package main
import (
"./internal/helper"
"log"
"net/http"
"os"
......@@ -36,7 +37,7 @@ func (u *upstream) handleServeFile(documentRoot *string, cache CacheMode, notFou
// Serve pre-gzipped assets
if acceptEncoding := r.Header.Get("Accept-Encoding"); strings.Contains(acceptEncoding, "gzip") {
content, fi, err = openFile(file + ".gz")
content, fi, err = helper.OpenFile(file + ".gz")
if err == nil {
w.Header().Set("Content-Encoding", "gzip")
}
......@@ -44,7 +45,7 @@ func (u *upstream) handleServeFile(documentRoot *string, cache CacheMode, notFou
// If not found, open the original file
if content == nil || err != nil {
content, fi, err = openFile(file)
content, fi, err = helper.OpenFile(file)
}
if err != nil {
if notFoundHandler != nil {
......
......@@ -11,7 +11,7 @@ import (
"testing"
)
var dummyUpstream = newUpstream("http://localhost", nil)
var dummyUpstream = newUpstream("http://localhost", "")
func TestServingNonExistingFile(t *testing.T) {
dir := "/path/to/non/existing/directory"
......
......@@ -51,7 +51,7 @@ func TestUploadHandlerForwardingRawData(t *testing.T) {
response := httptest.NewRecorder()
httpRequest.Header.Set(tempPathHeader, tempPath)
u := newUpstream(ts.URL, nil)
u := newUpstream(ts.URL, "")
handleFileUploads(u.Proxy).ServeHTTP(response, httpRequest)
assertResponseCode(t, response, 202)
......@@ -126,7 +126,7 @@ func TestUploadHandlerRewritingMultiPartData(t *testing.T) {
httpRequest.Header.Set("Content-Type", writer.FormDataContentType())
httpRequest.Header.Set(tempPathHeader, tempPath)
response := httptest.NewRecorder()
u := newUpstream(ts.URL, nil)
u := newUpstream(ts.URL, "")
handleFileUploads(u.Proxy).ServeHTTP(response, httpRequest)
assertResponseCode(t, response, 202)
......
......@@ -7,12 +7,14 @@ In this file we handle request routing and interaction with the authBackend.
package main
import (
"./internal/proxy"
"fmt"
"log"
"net"
"net/http"
"net/http/httputil"
"net/url"
"strings"
"time"
)
type serviceHandleFunc func(http.ResponseWriter, *http.Request, *apiResponse)
......@@ -24,15 +26,11 @@ type API struct {
type upstream struct {
API *API
Proxy *Proxy
Proxy *proxy.Proxy
authBackend string
relativeURLRoot string
}
type Proxy struct {
ReverseProxy *httputil.ReverseProxy
}
type apiResponse struct {
// GL_ID is an environment variable used by gitlab-shell hooks during 'git
// push' and 'git pull'
......@@ -61,16 +59,7 @@ type apiResponse struct {
TempPath string
}
func newProxy(url *url.URL, transport http.RoundTripper) *Proxy {
// Modify a copy of url
proxyURL := *url
proxyURL.Path = ""
proxy := Proxy{ReverseProxy: httputil.NewSingleHostReverseProxy(&proxyURL)}
proxy.ReverseProxy.Transport = transport
return &proxy
}
func newUpstream(authBackend string, authTransport http.RoundTripper) *upstream {
func newUpstream(authBackend string, authSocket string) *upstream {
parsedURL, err := url.Parse(authBackend)
if err != nil {
log.Fatalln(err)
......@@ -81,10 +70,27 @@ func newUpstream(authBackend string, authTransport http.RoundTripper) *upstream
relativeURLRoot += "/"
}
// Create Proxy Transport
authTransport := http.DefaultTransport
if authSocket != "" {
dialer := &net.Dialer{
// The values below are taken from http.DefaultTransport
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}
authTransport = &http.Transport{
Dial: func(_, _ string) (net.Conn, error) {
return dialer.Dial("unix", authSocket)
},
ResponseHeaderTimeout: *responseHeadersTimeout,
}
}
proxyTransport := proxy.NewRoundTripper(authTransport)
up := &upstream{
authBackend: authBackend,
API: &API{Client: &http.Client{Transport: authTransport}, URL: parsedURL},
Proxy: newProxy(parsedURL, authTransport),
API: &API{Client: &http.Client{Transport: proxyTransport}, URL: parsedURL},
Proxy: proxy.NewProxy(parsedURL, proxyTransport, Version),
relativeURLRoot: relativeURLRoot,
}
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment