Commit 866610fa authored by Nick Thomas's avatar Nick Thomas

Merge branch 'proxy-actioncable-ws-route' into 'master'

Proxy ActionCable websocket connection

See merge request gitlab-org/gitlab-workhorse!454
parents 73d42cb5 99d40e17
...@@ -48,6 +48,10 @@ Options: ...@@ -48,6 +48,10 @@ Options:
Authentication/authorization backend (default "http://localhost:8080") Authentication/authorization backend (default "http://localhost:8080")
-authSocket string -authSocket string
Optional: Unix domain socket to dial authBackend at Optional: Unix domain socket to dial authBackend at
-cableBackend string
Optional: ActionCable backend (default authBackend)
-cableSocket string
Optional: Unix domain socket to dial cableBackend at (default authSocket)
-config string -config string
TOML file to load config from TOML file to load config from
-developmentMode -developmentMode
...@@ -164,6 +168,8 @@ In table form: ...@@ -164,6 +168,8 @@ In table form:
|`http://localhost:3000`|`/path/to/socket`|`/path/to/socket`|`/`| |`http://localhost:3000`|`/path/to/socket`|`/path/to/socket`|`/`|
|`http://localhost:3000/gitlab`|`/path/to/socket`|`/path/to/socket`|`/gitlab`| |`http://localhost:3000/gitlab`|`/path/to/socket`|`/path/to/socket`|`/gitlab`|
The same applies to `cableBackend` and `cableSocket`.
## Installation ## Installation
To install gitlab-workhorse you need [Go 1.8 or To install gitlab-workhorse you need [Go 1.8 or
......
package main
import (
"net/http"
"net/http/httptest"
"net/url"
"regexp"
"testing"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/require"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/config"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper"
)
const cablePath = "/-/cable"
func TestSingleBackend(t *testing.T) {
cableServerConns, cableBackendServer := startCableServer()
defer cableBackendServer.Close()
config := newUpstreamWithCableConfig(cableBackendServer.URL, "")
workhorse := startWorkhorseServerWithConfig(config)
defer workhorse.Close()
cableURL := websocketURL(workhorse.URL, cablePath)
client, _, err := dialWebsocket(cableURL, nil)
require.NoError(t, err)
defer client.Close()
server := (<-cableServerConns).conn
defer server.Close()
require.NoError(t, say(client, "hello"))
assertReadMessage(t, server, websocket.TextMessage, "hello")
require.NoError(t, say(server, "world"))
assertReadMessage(t, client, websocket.TextMessage, "world")
}
func TestSeparateCableBackend(t *testing.T) {
authBackendServer := testhelper.TestServerWithHandler(regexp.MustCompile(`.`), http.HandlerFunc(http.NotFound))
defer authBackendServer.Close()
cableServerConns, cableBackendServer := startCableServer()
defer cableBackendServer.Close()
config := newUpstreamWithCableConfig(authBackendServer.URL, cableBackendServer.URL)
workhorse := startWorkhorseServerWithConfig(config)
defer workhorse.Close()
cableURL := websocketURL(workhorse.URL, cablePath)
client, _, err := dialWebsocket(cableURL, nil)
require.NoError(t, err)
defer client.Close()
server := (<-cableServerConns).conn
defer server.Close()
require.NoError(t, say(client, "hello"))
assertReadMessage(t, server, websocket.TextMessage, "hello")
require.NoError(t, say(server, "world"))
assertReadMessage(t, client, websocket.TextMessage, "world")
}
func startCableServer() (chan connWithReq, *httptest.Server) {
upgrader := &websocket.Upgrader{}
connCh := make(chan connWithReq, 1)
server := testhelper.TestServerWithHandler(regexp.MustCompile(cablePath), webSocketHandler(upgrader, connCh))
return connCh, server
}
func newUpstreamWithCableConfig(authBackend string, cableBackend string) *config.Config {
var cableBackendURL *url.URL
if cableBackend != "" {
cableBackendURL = helper.URLMustParse(cableBackend)
}
return &config.Config{
Version: "123",
DocumentRoot: testDocumentRoot,
Backend: helper.URLMustParse(authBackend),
CableBackend: cableBackendURL,
}
}
...@@ -171,7 +171,13 @@ func startWebsocketServer(subprotocols ...string) (chan connWithReq, *httptest.S ...@@ -171,7 +171,13 @@ func startWebsocketServer(subprotocols ...string) (chan connWithReq, *httptest.S
upgrader := &websocket.Upgrader{Subprotocols: subprotocols} upgrader := &websocket.Upgrader{Subprotocols: subprotocols}
connCh := make(chan connWithReq, 1) connCh := make(chan connWithReq, 1)
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server := httptest.NewTLSServer(webSocketHandler(upgrader, connCh))
return connCh, server
}
func webSocketHandler(upgrader *websocket.Upgrader, connCh chan connWithReq) http.HandlerFunc {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
logEntry := log.WithFields(log.Fields{ logEntry := log.WithFields(log.Fields{
"method": r.Method, "method": r.Method,
"url": r.URL, "url": r.URL,
...@@ -186,9 +192,7 @@ func startWebsocketServer(subprotocols ...string) (chan connWithReq, *httptest.S ...@@ -186,9 +192,7 @@ func startWebsocketServer(subprotocols ...string) (chan connWithReq, *httptest.S
} }
connCh <- connWithReq{conn, r} connCh <- connWithReq{conn, r}
// The connection has been hijacked so it's OK to end here // The connection has been hijacked so it's OK to end here
})) })
return connCh, server
} }
func channelOkBody(remote *httptest.Server, header http.Header, subprotocols ...string) *api.Response { func channelOkBody(remote *httptest.Server, header http.Header, subprotocols ...string) *api.Response {
......
...@@ -43,10 +43,12 @@ type RedisConfig struct { ...@@ -43,10 +43,12 @@ type RedisConfig struct {
type Config struct { type Config struct {
Redis *RedisConfig `toml:"redis"` Redis *RedisConfig `toml:"redis"`
Backend *url.URL `toml:"-"` Backend *url.URL `toml:"-"`
CableBackend *url.URL `toml:"-"`
Version string `toml:"-"` Version string `toml:"-"`
DocumentRoot string `toml:"-"` DocumentRoot string `toml:"-"`
DevelopmentMode bool `toml:"-"` DevelopmentMode bool `toml:"-"`
Socket string `toml:"-"` Socket string `toml:"-"`
CableSocket string `toml:"-"`
ProxyHeadersTimeout time.Duration `toml:"-"` ProxyHeadersTimeout time.Duration `toml:"-"`
APILimit uint `toml:"-"` APILimit uint `toml:"-"`
APIQueueLimit uint `toml:"-"` APIQueueLimit uint `toml:"-"`
......
...@@ -162,6 +162,7 @@ func (u *upstream) configureRoutes() { ...@@ -162,6 +162,7 @@ func (u *upstream) configureRoutes() {
static := &staticpages.Static{DocumentRoot: u.DocumentRoot} static := &staticpages.Static{DocumentRoot: u.DocumentRoot}
proxy := buildProxy(u.Backend, u.Version, u.RoundTripper) proxy := buildProxy(u.Backend, u.Version, u.RoundTripper)
cableProxy := proxypkg.NewProxy(u.CableBackend, u.Version, u.CableRoundTripper)
signingTripper := secret.NewRoundTripper(u.RoundTripper, u.Version) signingTripper := secret.NewRoundTripper(u.RoundTripper, u.Version)
signingProxy := buildProxy(u.Backend, u.Version, signingTripper) signingProxy := buildProxy(u.Backend, u.Version, signingTripper)
...@@ -191,6 +192,9 @@ func (u *upstream) configureRoutes() { ...@@ -191,6 +192,9 @@ func (u *upstream) configureRoutes() {
route("POST", apiPattern+`v4/jobs/[0-9]+/artifacts\z`, contentEncodingHandler(artifacts.UploadArtifacts(api, signingProxy))), route("POST", apiPattern+`v4/jobs/[0-9]+/artifacts\z`, contentEncodingHandler(artifacts.UploadArtifacts(api, signingProxy))),
route("POST", ciAPIPattern+`v1/builds/[0-9]+/artifacts\z`, contentEncodingHandler(artifacts.UploadArtifacts(api, signingProxy))), route("POST", ciAPIPattern+`v1/builds/[0-9]+/artifacts\z`, contentEncodingHandler(artifacts.UploadArtifacts(api, signingProxy))),
// ActionCable websocket
wsRoute(`^/-/cable\z`, cableProxy),
// Terminal websocket // Terminal websocket
wsRoute(projectPattern+`-/environments/[0-9]+/terminal.ws\z`, channel.Handler(api)), wsRoute(projectPattern+`-/environments/[0-9]+/terminal.ws\z`, channel.Handler(api)),
wsRoute(projectPattern+`-/jobs/[0-9]+/terminal.ws\z`, channel.Handler(api)), wsRoute(projectPattern+`-/jobs/[0-9]+/terminal.ws\z`, channel.Handler(api)),
......
...@@ -32,9 +32,10 @@ var ( ...@@ -32,9 +32,10 @@ var (
type upstream struct { type upstream struct {
config.Config config.Config
URLPrefix urlprefix.Prefix URLPrefix urlprefix.Prefix
Routes []routeEntry Routes []routeEntry
RoundTripper http.RoundTripper RoundTripper http.RoundTripper
CableRoundTripper http.RoundTripper
} }
func NewUpstream(cfg config.Config, accessLogger *logrus.Logger) http.Handler { func NewUpstream(cfg config.Config, accessLogger *logrus.Logger) http.Handler {
...@@ -44,7 +45,14 @@ func NewUpstream(cfg config.Config, accessLogger *logrus.Logger) http.Handler { ...@@ -44,7 +45,14 @@ func NewUpstream(cfg config.Config, accessLogger *logrus.Logger) http.Handler {
if up.Backend == nil { if up.Backend == nil {
up.Backend = DefaultBackend up.Backend = DefaultBackend
} }
if up.CableBackend == nil {
up.CableBackend = up.Backend
}
if up.CableSocket == "" {
up.CableSocket = up.Socket
}
up.RoundTripper = roundtripper.NewBackendRoundTripper(up.Backend, up.Socket, up.ProxyHeadersTimeout, cfg.DevelopmentMode) up.RoundTripper = roundtripper.NewBackendRoundTripper(up.Backend, up.Socket, up.ProxyHeadersTimeout, cfg.DevelopmentMode)
up.CableRoundTripper = roundtripper.NewBackendRoundTripper(up.CableBackend, up.CableSocket, up.ProxyHeadersTimeout, cfg.DevelopmentMode)
up.configureURLPrefix() up.configureURLPrefix()
up.configureRoutes() up.configureRoutes()
......
...@@ -46,6 +46,8 @@ var listenNetwork = flag.String("listenNetwork", "tcp", "Listen 'network' (tcp, ...@@ -46,6 +46,8 @@ var listenNetwork = flag.String("listenNetwork", "tcp", "Listen 'network' (tcp,
var listenUmask = flag.Int("listenUmask", 0, "Umask for Unix socket") var listenUmask = flag.Int("listenUmask", 0, "Umask for Unix socket")
var authBackend = flag.String("authBackend", upstream.DefaultBackend.String(), "Authentication/authorization backend") var authBackend = flag.String("authBackend", upstream.DefaultBackend.String(), "Authentication/authorization backend")
var authSocket = flag.String("authSocket", "", "Optional: Unix domain socket to dial authBackend at") var authSocket = flag.String("authSocket", "", "Optional: Unix domain socket to dial authBackend at")
var cableBackend = flag.String("cableBackend", upstream.DefaultBackend.String(), "ActionCable backend")
var cableSocket = flag.String("cableSocket", "", "Optional: Unix domain socket to dial cableBackend at")
var pprofListenAddr = flag.String("pprofListenAddr", "", "pprof listening address, e.g. 'localhost:6060'") var pprofListenAddr = flag.String("pprofListenAddr", "", "pprof listening address, e.g. 'localhost:6060'")
var documentRoot = flag.String("documentRoot", "public", "Path to static files content") var documentRoot = flag.String("documentRoot", "public", "Path to static files content")
var proxyHeadersTimeout = flag.Duration("proxyHeadersTimeout", 5*time.Minute, "How long to wait for response headers when proxying the request") var proxyHeadersTimeout = flag.Duration("proxyHeadersTimeout", 5*time.Minute, "How long to wait for response headers when proxying the request")
...@@ -91,6 +93,11 @@ func main() { ...@@ -91,6 +93,11 @@ func main() {
log.WithError(err).Fatal("Invalid authBackend") log.WithError(err).Fatal("Invalid authBackend")
} }
cableBackendURL, err := parseAuthBackend(*cableBackend)
if err != nil {
log.WithError(err).Fatal("Invalid cableBackend")
}
log.WithField("version", Version).WithField("build_time", BuildTime).Print("Starting") log.WithField("version", Version).WithField("build_time", BuildTime).Print("Starting")
// Good housekeeping for Unix sockets: unlink before binding // Good housekeeping for Unix sockets: unlink before binding
...@@ -137,7 +144,9 @@ func main() { ...@@ -137,7 +144,9 @@ func main() {
secret.SetPath(*secretPath) secret.SetPath(*secretPath)
cfg := config.Config{ cfg := config.Config{
Backend: backendURL, Backend: backendURL,
CableBackend: cableBackendURL,
Socket: *authSocket, Socket: *authSocket,
CableSocket: *cableSocket,
Version: Version, Version: Version,
DocumentRoot: *documentRoot, DocumentRoot: *documentRoot,
DevelopmentMode: *developmentMode, DevelopmentMode: *developmentMode,
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment