Commit 08ec1011 authored by Jacob Vosmaer's avatar Jacob Vosmaer

Fix backend URL parsing

parent cdcabf45
package main
import (
"fmt"
"net/url"
)
func parseAuthBackend(authBackend string) (*url.URL, error) {
backendURL, err := url.Parse(authBackend)
if err != nil {
return nil, err
}
if backendURL.Host == "" {
backendURL, err = url.Parse("http://" + authBackend)
if err != nil {
return nil, err
}
}
if backendURL.Scheme != "http" {
return nil, fmt.Errorf("invalid scheme, only 'http' is allowed: %q", authBackend)
}
if backendURL.Host == "" {
return nil, fmt.Errorf("missing host in %q", authBackend)
}
return backendURL, nil
}
...@@ -64,7 +64,7 @@ func mustParseAddress(address, scheme string) string { ...@@ -64,7 +64,7 @@ func mustParseAddress(address, scheme string) string {
} }
} }
panic("could not parse host:port from address and scheme") panic(fmt.Errorf("could not parse host:port from address %q and scheme %q", address, scheme))
} }
func (t *RoundTripper) RoundTrip(r *http.Request) (res *http.Response, err error) { func (t *RoundTripper) RoundTrip(r *http.Request) (res *http.Response, err error) {
......
...@@ -34,7 +34,7 @@ var printVersion = flag.Bool("version", false, "Print version and exit") ...@@ -34,7 +34,7 @@ var printVersion = flag.Bool("version", false, "Print version and exit")
var listenAddr = flag.String("listenAddr", "localhost:8181", "Listen address for HTTP server") var listenAddr = flag.String("listenAddr", "localhost:8181", "Listen address for HTTP server")
var listenNetwork = flag.String("listenNetwork", "tcp", "Listen 'network' (tcp, tcp4, tcp6, unix)") var listenNetwork = flag.String("listenNetwork", "tcp", "Listen 'network' (tcp, tcp4, tcp6, unix)")
var listenUmask = flag.Int("listenUmask", 0, "Umask for Unix socket") var listenUmask = flag.Int("listenUmask", 0, "Umask for Unix socket")
var authBackend = URLFlag("authBackend", upstream.DefaultBackend, "Authentication/authorization backend") var authBackend = flag.String("authBackend", upstream.DefaultBackend.String(), "Authentication/authorization backend")
var authSocket = flag.String("authSocket", "", "Optional: Unix domain socket to dial authBackend at") var authSocket = flag.String("authSocket", "", "Optional: Unix domain socket to dial authBackend at")
var pprofListenAddr = flag.String("pprofListenAddr", "", "pprof listening address, e.g. 'localhost:6060'") var pprofListenAddr = flag.String("pprofListenAddr", "", "pprof listening address, e.g. 'localhost:6060'")
var documentRoot = flag.String("documentRoot", "public", "Path to static files content") var documentRoot = flag.String("documentRoot", "public", "Path to static files content")
...@@ -55,6 +55,12 @@ func main() { ...@@ -55,6 +55,12 @@ func main() {
os.Exit(0) os.Exit(0)
} }
backendURL, err := parseAuthBackend(*authBackend)
if err != nil {
fmt.Fprintf(os.Stderr, "invalid authBackend: %v\n", err)
os.Exit(1)
}
log.Printf("Starting %s", version) log.Printf("Starting %s", version)
// Good housekeeping for Unix sockets: unlink before binding // Good housekeeping for Unix sockets: unlink before binding
...@@ -83,7 +89,7 @@ func main() { ...@@ -83,7 +89,7 @@ func main() {
} }
up := upstream.NewUpstream( up := upstream.NewUpstream(
*authBackend, backendURL,
*authSocket, *authSocket,
Version, Version,
*documentRoot, *documentRoot,
......
...@@ -736,6 +736,43 @@ func TestGetGitPatch(t *testing.T) { ...@@ -736,6 +736,43 @@ func TestGetGitPatch(t *testing.T) {
} }
} }
func TestParseAuthBackend(t *testing.T) {
failures := []string{
"",
"ftp://localhost",
"https://example.com",
}
for _, example := range failures {
if _, err := parseAuthBackend(example); err == nil {
t.Errorf("error expected for %q", example)
}
}
successes := []struct{ input, host, scheme string }{
{"http://localhost:8080", "localhost:8080", "http"},
{"localhost:3000", "localhost:3000", "http"},
{"http://localhost", "localhost", "http"},
{"localhost", "localhost", "http"},
}
for _, example := range successes {
result, err := parseAuthBackend(example.input)
if err != nil {
t.Errorf("parse %q: %v", example.input, err)
break
}
if result.Host != example.host {
t.Errorf("expected %q, got %q", example.host, result.Host)
}
if result.Scheme != example.scheme {
t.Errorf("expected %q, got %q", example.scheme, result.Scheme)
}
}
}
func setupStaticFile(fpath, content string) error { func setupStaticFile(fpath, content string) error {
cwd, err := os.Getwd() cwd, err := os.Getwd()
if err != nil { if err != nil {
......
package main
import (
"flag"
"net/url"
)
type urlFlag struct {
*url.URL
}
func (u *urlFlag) Set(s string) error {
myURL, err := url.Parse(s)
if err != nil {
return err
}
u.URL = myURL
return nil
}
func URLFlag(name string, value *url.URL, usage string) **url.URL {
f := &urlFlag{value}
flag.CommandLine.Var(f, name, usage)
return &f.URL
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment