Commit c5664cdf authored by Jacob Vosmaer's avatar Jacob Vosmaer

Get rid of upstream.New

parent ef04d680
...@@ -9,7 +9,6 @@ import ( ...@@ -9,7 +9,6 @@ import (
"net/http/httptest" "net/http/httptest"
"regexp" "regexp"
"testing" "testing"
"time"
) )
func okHandler(w http.ResponseWriter, _ *http.Request, _ *api.Response) { func okHandler(w http.ResponseWriter, _ *http.Request, _ *api.Response) {
...@@ -27,7 +26,7 @@ func runPreAuthorizeHandler(t *testing.T, suffix string, url *regexp.Regexp, api ...@@ -27,7 +26,7 @@ func runPreAuthorizeHandler(t *testing.T, suffix string, url *regexp.Regexp, api
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
api := upstream.New(helper.URLMustParse(ts.URL), "", "123", time.Second).API api := (&upstream.Upstream{Backend: helper.URLMustParse(ts.URL), Version: "123"}).API()
response := httptest.NewRecorder() response := httptest.NewRecorder()
api.PreAuthorizeHandler(okHandler, suffix)(response, httpRequest) api.PreAuthorizeHandler(okHandler, suffix)(response, httpRequest)
......
...@@ -32,15 +32,11 @@ func (p *Proxy) configureReverseProxy() { ...@@ -32,15 +32,11 @@ func (p *Proxy) configureReverseProxy() {
} }
type RoundTripper struct { type RoundTripper struct {
transport http.RoundTripper Transport http.RoundTripper
}
func NewRoundTripper(transport http.RoundTripper) *RoundTripper {
return &RoundTripper{transport: transport}
} }
func (rt *RoundTripper) RoundTrip(r *http.Request) (res *http.Response, err error) { func (rt *RoundTripper) RoundTrip(r *http.Request) (res *http.Response, err error) {
res, err = rt.transport.RoundTrip(r) res, err = rt.Transport.RoundTrip(r)
// httputil.ReverseProxy translates all errors from this // httputil.ReverseProxy translates all errors from this
// RoundTrip function into 500 errors. But the most likely error // RoundTrip function into 500 errors. But the most likely error
......
...@@ -27,42 +27,47 @@ const ciAPIPattern = `^/ci/api/` ...@@ -27,42 +27,47 @@ const ciAPIPattern = `^/ci/api/`
// We match against URI not containing the relativeUrlRoot: // We match against URI not containing the relativeUrlRoot:
// see upstream.ServeHTTP // see upstream.ServeHTTP
func (u *Upstream) compileRoutes() { func (u *Upstream) Routes() []route {
u.configureRoutesOnce.Do(u.configureRoutes)
return u.routes
}
func (u *Upstream) configureRoutes() {
u.routes = []route{ u.routes = []route{
// Git Clone // Git Clone
route{"GET", regexp.MustCompile(gitProjectPattern + `info/refs\z`), git.GetInfoRefs(u.API)}, route{"GET", regexp.MustCompile(gitProjectPattern + `info/refs\z`), git.GetInfoRefs(u.API())},
route{"POST", regexp.MustCompile(gitProjectPattern + `git-upload-pack\z`), contentEncodingHandler(git.PostRPC(u.API))}, route{"POST", regexp.MustCompile(gitProjectPattern + `git-upload-pack\z`), contentEncodingHandler(git.PostRPC(u.API()))},
route{"POST", regexp.MustCompile(gitProjectPattern + `git-receive-pack\z`), contentEncodingHandler(git.PostRPC(u.API))}, route{"POST", regexp.MustCompile(gitProjectPattern + `git-receive-pack\z`), contentEncodingHandler(git.PostRPC(u.API()))},
route{"PUT", regexp.MustCompile(gitProjectPattern + `gitlab-lfs/objects/([0-9a-f]{64})/([0-9]+)\z`), lfs.PutStore(u.API, u.Proxy)}, route{"PUT", regexp.MustCompile(gitProjectPattern + `gitlab-lfs/objects/([0-9a-f]{64})/([0-9]+)\z`), lfs.PutStore(u.API(), u.Proxy())},
// Repository Archive // Repository Archive
route{"GET", regexp.MustCompile(projectPattern + `repository/archive\z`), git.GetArchive(u.API)}, route{"GET", regexp.MustCompile(projectPattern + `repository/archive\z`), git.GetArchive(u.API())},
route{"GET", regexp.MustCompile(projectPattern + `repository/archive.zip\z`), git.GetArchive(u.API)}, route{"GET", regexp.MustCompile(projectPattern + `repository/archive.zip\z`), git.GetArchive(u.API())},
route{"GET", regexp.MustCompile(projectPattern + `repository/archive.tar\z`), git.GetArchive(u.API)}, route{"GET", regexp.MustCompile(projectPattern + `repository/archive.tar\z`), git.GetArchive(u.API())},
route{"GET", regexp.MustCompile(projectPattern + `repository/archive.tar.gz\z`), git.GetArchive(u.API)}, route{"GET", regexp.MustCompile(projectPattern + `repository/archive.tar.gz\z`), git.GetArchive(u.API())},
route{"GET", regexp.MustCompile(projectPattern + `repository/archive.tar.bz2\z`), git.GetArchive(u.API)}, route{"GET", regexp.MustCompile(projectPattern + `repository/archive.tar.bz2\z`), git.GetArchive(u.API())},
// Repository Archive API // Repository Archive API
route{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive\z`), git.GetArchive(u.API)}, route{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive\z`), git.GetArchive(u.API())},
route{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive.zip\z`), git.GetArchive(u.API)}, route{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive.zip\z`), git.GetArchive(u.API())},
route{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive.tar\z`), git.GetArchive(u.API)}, route{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive.tar\z`), git.GetArchive(u.API())},
route{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive.tar.gz\z`), git.GetArchive(u.API)}, route{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive.tar.gz\z`), git.GetArchive(u.API())},
route{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive.tar.bz2\z`), git.GetArchive(u.API)}, route{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive.tar.bz2\z`), git.GetArchive(u.API())},
// CI Artifacts API // CI Artifacts API
route{"POST", regexp.MustCompile(ciAPIPattern + `v1/builds/[0-9]+/artifacts\z`), contentEncodingHandler(upload.Artifacts(u.API, u.Proxy))}, route{"POST", regexp.MustCompile(ciAPIPattern + `v1/builds/[0-9]+/artifacts\z`), contentEncodingHandler(upload.Artifacts(u.API(), u.Proxy()))},
// Explicitly u.Proxy API requests // Explicitly proxy API requests
route{"", regexp.MustCompile(apiPattern), u.Proxy}, route{"", regexp.MustCompile(apiPattern), u.Proxy()},
route{"", regexp.MustCompile(ciAPIPattern), u.Proxy}, route{"", regexp.MustCompile(ciAPIPattern), u.Proxy()},
// Serve assets // Serve assets
route{"", regexp.MustCompile(`^/assets/`), route{"", regexp.MustCompile(`^/assets/`),
handleServeFile(u.DocumentRoot, u.urlPrefix, CacheExpireMax, handleServeFile(u.DocumentRoot, u.URLPrefix(), CacheExpireMax,
handleDevelopmentMode(u.DevelopmentMode, handleDevelopmentMode(u.DevelopmentMode,
handleDeployPage(u.DocumentRoot, handleDeployPage(u.DocumentRoot,
errorpage.Inject(u.DocumentRoot, errorpage.Inject(u.DocumentRoot,
u.Proxy, u.Proxy(),
), ),
), ),
), ),
...@@ -71,10 +76,10 @@ func (u *Upstream) compileRoutes() { ...@@ -71,10 +76,10 @@ func (u *Upstream) compileRoutes() {
// Serve static files or forward the requests // Serve static files or forward the requests
route{"", nil, route{"", nil,
handleServeFile(u.DocumentRoot, u.urlPrefix, CacheDisabled, handleServeFile(u.DocumentRoot, u.URLPrefix(), CacheDisabled,
handleDeployPage(u.DocumentRoot, handleDeployPage(u.DocumentRoot,
errorpage.Inject(u.DocumentRoot, errorpage.Inject(u.DocumentRoot,
u.Proxy, u.Proxy(),
), ),
), ),
), ),
......
package upstream
import (
"../proxy"
"net"
"net/http"
"time"
)
// Values from http.DefaultTransport
var DefaultDialer = &net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}
var DefaultTransport = &http.Transport{
Proxy: http.ProxyFromEnvironment, // from http.DefaultTransport
Dial: DefaultDialer.Dial, // from http.DefaultTransport
ResponseHeaderTimeout: time.Minute, // custom
TLSHandshakeTimeout: 10 * time.Second, // from http.DefaultTransport
}
func (u *Upstream) Transport() http.RoundTripper {
u.configureTransportOnce.Do(u.configureTransport)
return u.transport
}
func (u *Upstream) configureTransport() {
t := *DefaultTransport
if u.ResponseHeaderTimeout != 0 {
t.ResponseHeaderTimeout = u.ResponseHeaderTimeout
}
if u.Socket != "" {
t.Dial = func(_, _ string) (net.Conn, error) {
return DefaultDialer.Dial("unix", u.Socket)
}
}
u.transport = &proxy.RoundTripper{&t}
}
...@@ -8,61 +8,62 @@ package upstream ...@@ -8,61 +8,62 @@ package upstream
import ( import (
"../api" "../api"
"../helper"
"../proxy" "../proxy"
"fmt" "fmt"
"net"
"net/http" "net/http"
"net/url" "net/url"
"path" "path"
"strings" "sync"
"time" "time"
) )
var DefaultBackend = helper.URLMustParse("http://localhost:8080")
type Upstream struct { type Upstream struct {
Backend *url.URL
Version string Version string
API *api.API Socket string
Proxy *proxy.Proxy
DocumentRoot string DocumentRoot string
DevelopmentMode bool DevelopmentMode bool
ResponseHeadersTimeout time.Duration ResponseHeaderTimeout time.Duration
_api *api.API
configureAPIOnce sync.Once
_proxy *proxy.Proxy
configureProxyOnce sync.Once
urlPrefix urlPrefix urlPrefix urlPrefix
configureURLPrefixOnce sync.Once
routes []route routes []route
configureRoutesOnce sync.Once
transport http.RoundTripper
configureTransportOnce sync.Once
} }
func New(authBackend *url.URL, authSocket string, version string, responseHeadersTimeout time.Duration) *Upstream { func (u *Upstream) Proxy() *proxy.Proxy {
relativeURLRoot := authBackend.Path u.configureProxyOnce.Do(u.configureProxy)
if !strings.HasSuffix(relativeURLRoot, "/") { return u._proxy
relativeURLRoot += "/" }
}
// Create Proxy Transport func (u *Upstream) configureProxy() {
authTransport := http.DefaultTransport u._proxy = &proxy.Proxy{URL: u.Backend, Transport: u.Transport(), Version: u.Version}
if authSocket != "" { }
dialer := &net.Dialer{
// The values below are taken from http.DefaultTransport func (u *Upstream) API() *api.API {
Timeout: 30 * time.Second, u.configureAPIOnce.Do(u.configureAPI)
KeepAlive: 30 * time.Second, return u._api
} }
authTransport = &http.Transport{
Dial: func(_, _ string) (net.Conn, error) { func (u *Upstream) configureAPI() {
return dialer.Dial("unix", authSocket) u._api = &api.API{
}, Client: &http.Client{Transport: u.Transport()},
ResponseHeaderTimeout: responseHeadersTimeout, URL: u.Backend,
} Version: u.Version,
}
proxyTransport := proxy.NewRoundTripper(authTransport)
up := &Upstream{
API: &api.API{
Client: &http.Client{Transport: proxyTransport},
URL: authBackend,
Version: version,
},
Proxy: &proxy.Proxy{URL: authBackend, Transport: proxyTransport, Version: version},
urlPrefix: urlPrefix(relativeURLRoot),
} }
up.compileRoutes()
return up
} }
func (u *Upstream) ServeHTTP(ow http.ResponseWriter, r *http.Request) { func (u *Upstream) ServeHTTP(ow http.ResponseWriter, r *http.Request) {
...@@ -83,7 +84,7 @@ func (u *Upstream) ServeHTTP(ow http.ResponseWriter, r *http.Request) { ...@@ -83,7 +84,7 @@ func (u *Upstream) ServeHTTP(ow http.ResponseWriter, r *http.Request) {
// Check URL Root // Check URL Root
URIPath := cleanURIPath(r.URL.Path) URIPath := cleanURIPath(r.URL.Path)
prefix := u.urlPrefix prefix := u.URLPrefix()
if !prefix.match(URIPath) { if !prefix.match(URIPath) {
httpError(&w, r, fmt.Sprintf("Not found %q", URIPath), http.StatusNotFound) httpError(&w, r, fmt.Sprintf("Not found %q", URIPath), http.StatusNotFound)
return return
...@@ -92,7 +93,7 @@ func (u *Upstream) ServeHTTP(ow http.ResponseWriter, r *http.Request) { ...@@ -92,7 +93,7 @@ func (u *Upstream) ServeHTTP(ow http.ResponseWriter, r *http.Request) {
// Look for a matching Git service // Look for a matching Git service
var ro route var ro route
foundService := false foundService := false
for _, ro = range u.routes { for _, ro = range u.Routes() {
if ro.method != "" && r.Method != ro.method { if ro.method != "" && r.Method != ro.method {
continue continue
} }
......
...@@ -14,3 +14,19 @@ func (p urlPrefix) match(path string) bool { ...@@ -14,3 +14,19 @@ func (p urlPrefix) match(path string) bool {
pre := string(p) pre := string(p)
return strings.HasPrefix(path, pre) || path+"/" == pre return strings.HasPrefix(path, pre) || path+"/" == pre
} }
func (u *Upstream) URLPrefix() urlPrefix {
u.configureURLPrefixOnce.Do(u.configureURLPrefix)
return u.urlPrefix
}
func (u *Upstream) configureURLPrefix() {
if u.Backend == nil {
u.Backend = DefaultBackend
}
relativeURLRoot := u.Backend.Path
if !strings.HasSuffix(relativeURLRoot, "/") {
relativeURLRoot += "/"
}
u.urlPrefix = urlPrefix(relativeURLRoot)
}
...@@ -23,7 +23,6 @@ import ( ...@@ -23,7 +23,6 @@ import (
_ "net/http/pprof" _ "net/http/pprof"
"os" "os"
"syscall" "syscall"
"time"
) )
// Current version of GitLab Workhorse // Current version of GitLab Workhorse
...@@ -33,11 +32,11 @@ var printVersion = flag.Bool("version", false, "Print version and exit") ...@@ -33,11 +32,11 @@ var printVersion = flag.Bool("version", false, "Print version and exit")
var listenAddr = flag.String("listenAddr", "localhost:8181", "Listen address for HTTP server") var listenAddr = flag.String("listenAddr", "localhost:8181", "Listen address for HTTP server")
var listenNetwork = flag.String("listenNetwork", "tcp", "Listen 'network' (tcp, tcp4, tcp6, unix)") var listenNetwork = flag.String("listenNetwork", "tcp", "Listen 'network' (tcp, tcp4, tcp6, unix)")
var listenUmask = flag.Int("listenUmask", 022, "Umask for Unix socket, default: 022") var listenUmask = flag.Int("listenUmask", 022, "Umask for Unix socket, default: 022")
var authBackend = URLFlag("authBackend", "http://localhost:8080", "Authentication/authorization backend") var authBackend = URLFlag("authBackend", upstream.DefaultBackend, "Authentication/authorization backend")
var authSocket = flag.String("authSocket", "", "Optional: Unix domain socket to dial authBackend at") var authSocket = flag.String("authSocket", "", "Optional: Unix domain socket to dial authBackend at")
var pprofListenAddr = flag.String("pprofListenAddr", "", "pprof listening address, e.g. 'localhost:6060'") var pprofListenAddr = flag.String("pprofListenAddr", "", "pprof listening address, e.g. 'localhost:6060'")
var documentRoot = flag.String("documentRoot", "public", "Path to static files content") var documentRoot = flag.String("documentRoot", "public", "Path to static files content")
var responseHeadersTimeout = flag.Duration("proxyHeadersTimeout", time.Minute, "How long to wait for response headers when proxying the request") var responseHeadersTimeout = flag.Duration("proxyHeadersTimeout", upstream.DefaultTransport.ResponseHeaderTimeout, "How long to wait for response headers when proxying the request")
var developmentMode = flag.Bool("developmentMode", false, "Allow to serve assets from Rails app") var developmentMode = flag.Bool("developmentMode", false, "Allow to serve assets from Rails app")
func main() { func main() {
...@@ -81,9 +80,14 @@ func main() { ...@@ -81,9 +80,14 @@ func main() {
}() }()
} }
up := upstream.New(authBackend, *authSocket, Version, *responseHeadersTimeout) up := &upstream.Upstream{
up.DocumentRoot = *documentRoot Backend: authBackend,
up.DevelopmentMode = *developmentMode Socket: *authSocket,
Version: Version,
ResponseHeaderTimeout: *responseHeadersTimeout,
DocumentRoot: *documentRoot,
DevelopmentMode: *developmentMode,
}
log.Fatal(http.Serve(listener, up)) log.Fatal(http.Serve(listener, up))
} }
...@@ -311,7 +311,7 @@ func testAuthServer(url *regexp.Regexp, code int, body interface{}) *httptest.Se ...@@ -311,7 +311,7 @@ func testAuthServer(url *regexp.Regexp, code int, body interface{}) *httptest.Se
} }
func startWorkhorseServer(authBackend string) *httptest.Server { func startWorkhorseServer(authBackend string) *httptest.Server {
u := upstream.New(helper.URLMustParse(authBackend), "", "123", time.Second) u := &upstream.Upstream{Backend: helper.URLMustParse(authBackend), Version: "123"}
return httptest.NewServer(u) return httptest.NewServer(u)
} }
......
...@@ -16,7 +16,7 @@ import ( ...@@ -16,7 +16,7 @@ import (
) )
func newUpstream(url string) *upstream.Upstream { func newUpstream(url string) *upstream.Upstream {
return upstream.New(helper.URLMustParse(url), "", "123", time.Second) return &upstream.Upstream{Backend: helper.URLMustParse(url), Version: "123"}
} }
func TestProxyRequest(t *testing.T) { func TestProxyRequest(t *testing.T) {
...@@ -48,7 +48,7 @@ func TestProxyRequest(t *testing.T) { ...@@ -48,7 +48,7 @@ func TestProxyRequest(t *testing.T) {
u := newUpstream(ts.URL) u := newUpstream(ts.URL)
w := httptest.NewRecorder() w := httptest.NewRecorder()
u.Proxy.ServeHTTP(w, httpRequest) u.Proxy().ServeHTTP(w, httpRequest)
helper.AssertResponseCode(t, w, 202) helper.AssertResponseCode(t, w, 202)
helper.AssertResponseBody(t, w, "RESPONSE") helper.AssertResponseBody(t, w, "RESPONSE")
...@@ -66,7 +66,7 @@ func TestProxyError(t *testing.T) { ...@@ -66,7 +66,7 @@ func TestProxyError(t *testing.T) {
u := newUpstream("http://localhost:655575/") u := newUpstream("http://localhost:655575/")
w := httptest.NewRecorder() w := httptest.NewRecorder()
u.Proxy.ServeHTTP(w, httpRequest) u.Proxy().ServeHTTP(w, httpRequest)
helper.AssertResponseCode(t, w, 502) helper.AssertResponseCode(t, w, 502)
helper.AssertResponseBody(t, w, "dial tcp: invalid port 655575") helper.AssertResponseBody(t, w, "dial tcp: invalid port 655575")
} }
...@@ -81,7 +81,7 @@ func TestProxyReadTimeout(t *testing.T) { ...@@ -81,7 +81,7 @@ func TestProxyReadTimeout(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
transport := proxy.NewRoundTripper( transport := &proxy.RoundTripper{
&http.Transport{ &http.Transport{
Proxy: http.ProxyFromEnvironment, Proxy: http.ProxyFromEnvironment,
Dial: (&net.Dialer{ Dial: (&net.Dialer{
...@@ -91,7 +91,7 @@ func TestProxyReadTimeout(t *testing.T) { ...@@ -91,7 +91,7 @@ func TestProxyReadTimeout(t *testing.T) {
TLSHandshakeTimeout: 10 * time.Second, TLSHandshakeTimeout: 10 * time.Second,
ResponseHeaderTimeout: time.Millisecond, ResponseHeaderTimeout: time.Millisecond,
}, },
) }
p := &proxy.Proxy{URL: helper.URLMustParse(ts.URL), Transport: transport, Version: "123"} p := &proxy.Proxy{URL: helper.URLMustParse(ts.URL), Transport: transport, Version: "123"}
...@@ -116,7 +116,7 @@ func TestProxyHandlerTimeout(t *testing.T) { ...@@ -116,7 +116,7 @@ func TestProxyHandlerTimeout(t *testing.T) {
u := newUpstream(ts.URL) u := newUpstream(ts.URL)
w := httptest.NewRecorder() w := httptest.NewRecorder()
u.Proxy.ServeHTTP(w, httpRequest) u.Proxy().ServeHTTP(w, httpRequest)
helper.AssertResponseCode(t, w, 503) helper.AssertResponseCode(t, w, 503)
helper.AssertResponseBody(t, w, "Request took too long") helper.AssertResponseBody(t, w, "Request took too long")
} }
...@@ -2,7 +2,6 @@ package main ...@@ -2,7 +2,6 @@ package main
import ( import (
"flag" "flag"
"log"
"net/url" "net/url"
) )
...@@ -17,12 +16,8 @@ func (u *urlFlag) Set(s string) error { ...@@ -17,12 +16,8 @@ func (u *urlFlag) Set(s string) error {
return nil return nil
} }
func URLFlag(name string, value string, usage string) *url.URL { func URLFlag(name string, value *url.URL, usage string) *url.URL {
u, err := url.Parse(value) f := urlFlag{value}
if err != nil {
log.Fatalf("URLFlag: invalid default: %q %v", value, err)
}
f := urlFlag{u}
flag.CommandLine.Var(&f, name, usage) flag.CommandLine.Var(&f, name, usage)
return f.URL return f.URL
} }
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment