Commit 866610fa authored by Nick Thomas's avatar Nick Thomas

Merge branch 'proxy-actioncable-ws-route' into 'master'

Proxy ActionCable websocket connection

See merge request gitlab-org/gitlab-workhorse!454
parents 73d42cb5 99d40e17
......@@ -48,6 +48,10 @@ Options:
Authentication/authorization backend (default "http://localhost:8080")
-authSocket string
Optional: Unix domain socket to dial authBackend at
-cableBackend string
Optional: ActionCable backend (default authBackend)
-cableSocket string
Optional: Unix domain socket to dial cableBackend at (default authSocket)
-config string
TOML file to load config from
-developmentMode
......@@ -164,6 +168,8 @@ In table form:
|`http://localhost:3000`|`/path/to/socket`|`/path/to/socket`|`/`|
|`http://localhost:3000/gitlab`|`/path/to/socket`|`/path/to/socket`|`/gitlab`|
The same applies to `cableBackend` and `cableSocket`.
## Installation
To install gitlab-workhorse you need [Go 1.8 or
......
package main
import (
"net/http"
"net/http/httptest"
"net/url"
"regexp"
"testing"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/require"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/config"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper"
)
const cablePath = "/-/cable"
func TestSingleBackend(t *testing.T) {
cableServerConns, cableBackendServer := startCableServer()
defer cableBackendServer.Close()
config := newUpstreamWithCableConfig(cableBackendServer.URL, "")
workhorse := startWorkhorseServerWithConfig(config)
defer workhorse.Close()
cableURL := websocketURL(workhorse.URL, cablePath)
client, _, err := dialWebsocket(cableURL, nil)
require.NoError(t, err)
defer client.Close()
server := (<-cableServerConns).conn
defer server.Close()
require.NoError(t, say(client, "hello"))
assertReadMessage(t, server, websocket.TextMessage, "hello")
require.NoError(t, say(server, "world"))
assertReadMessage(t, client, websocket.TextMessage, "world")
}
func TestSeparateCableBackend(t *testing.T) {
authBackendServer := testhelper.TestServerWithHandler(regexp.MustCompile(`.`), http.HandlerFunc(http.NotFound))
defer authBackendServer.Close()
cableServerConns, cableBackendServer := startCableServer()
defer cableBackendServer.Close()
config := newUpstreamWithCableConfig(authBackendServer.URL, cableBackendServer.URL)
workhorse := startWorkhorseServerWithConfig(config)
defer workhorse.Close()
cableURL := websocketURL(workhorse.URL, cablePath)
client, _, err := dialWebsocket(cableURL, nil)
require.NoError(t, err)
defer client.Close()
server := (<-cableServerConns).conn
defer server.Close()
require.NoError(t, say(client, "hello"))
assertReadMessage(t, server, websocket.TextMessage, "hello")
require.NoError(t, say(server, "world"))
assertReadMessage(t, client, websocket.TextMessage, "world")
}
func startCableServer() (chan connWithReq, *httptest.Server) {
upgrader := &websocket.Upgrader{}
connCh := make(chan connWithReq, 1)
server := testhelper.TestServerWithHandler(regexp.MustCompile(cablePath), webSocketHandler(upgrader, connCh))
return connCh, server
}
func newUpstreamWithCableConfig(authBackend string, cableBackend string) *config.Config {
var cableBackendURL *url.URL
if cableBackend != "" {
cableBackendURL = helper.URLMustParse(cableBackend)
}
return &config.Config{
Version: "123",
DocumentRoot: testDocumentRoot,
Backend: helper.URLMustParse(authBackend),
CableBackend: cableBackendURL,
}
}
......@@ -171,7 +171,13 @@ func startWebsocketServer(subprotocols ...string) (chan connWithReq, *httptest.S
upgrader := &websocket.Upgrader{Subprotocols: subprotocols}
connCh := make(chan connWithReq, 1)
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
server := httptest.NewTLSServer(webSocketHandler(upgrader, connCh))
return connCh, server
}
func webSocketHandler(upgrader *websocket.Upgrader, connCh chan connWithReq) http.HandlerFunc {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
logEntry := log.WithFields(log.Fields{
"method": r.Method,
"url": r.URL,
......@@ -186,9 +192,7 @@ func startWebsocketServer(subprotocols ...string) (chan connWithReq, *httptest.S
}
connCh <- connWithReq{conn, r}
// The connection has been hijacked so it's OK to end here
}))
return connCh, server
})
}
func channelOkBody(remote *httptest.Server, header http.Header, subprotocols ...string) *api.Response {
......
......@@ -43,10 +43,12 @@ type RedisConfig struct {
type Config struct {
Redis *RedisConfig `toml:"redis"`
Backend *url.URL `toml:"-"`
CableBackend *url.URL `toml:"-"`
Version string `toml:"-"`
DocumentRoot string `toml:"-"`
DevelopmentMode bool `toml:"-"`
Socket string `toml:"-"`
CableSocket string `toml:"-"`
ProxyHeadersTimeout time.Duration `toml:"-"`
APILimit uint `toml:"-"`
APIQueueLimit uint `toml:"-"`
......
......@@ -162,6 +162,7 @@ func (u *upstream) configureRoutes() {
static := &staticpages.Static{DocumentRoot: u.DocumentRoot}
proxy := buildProxy(u.Backend, u.Version, u.RoundTripper)
cableProxy := proxypkg.NewProxy(u.CableBackend, u.Version, u.CableRoundTripper)
signingTripper := secret.NewRoundTripper(u.RoundTripper, u.Version)
signingProxy := buildProxy(u.Backend, u.Version, signingTripper)
......@@ -191,6 +192,9 @@ func (u *upstream) configureRoutes() {
route("POST", apiPattern+`v4/jobs/[0-9]+/artifacts\z`, contentEncodingHandler(artifacts.UploadArtifacts(api, signingProxy))),
route("POST", ciAPIPattern+`v1/builds/[0-9]+/artifacts\z`, contentEncodingHandler(artifacts.UploadArtifacts(api, signingProxy))),
// ActionCable websocket
wsRoute(`^/-/cable\z`, cableProxy),
// Terminal websocket
wsRoute(projectPattern+`-/environments/[0-9]+/terminal.ws\z`, channel.Handler(api)),
wsRoute(projectPattern+`-/jobs/[0-9]+/terminal.ws\z`, channel.Handler(api)),
......
......@@ -32,9 +32,10 @@ var (
type upstream struct {
config.Config
URLPrefix urlprefix.Prefix
Routes []routeEntry
RoundTripper http.RoundTripper
URLPrefix urlprefix.Prefix
Routes []routeEntry
RoundTripper http.RoundTripper
CableRoundTripper http.RoundTripper
}
func NewUpstream(cfg config.Config, accessLogger *logrus.Logger) http.Handler {
......@@ -44,7 +45,14 @@ func NewUpstream(cfg config.Config, accessLogger *logrus.Logger) http.Handler {
if up.Backend == nil {
up.Backend = DefaultBackend
}
if up.CableBackend == nil {
up.CableBackend = up.Backend
}
if up.CableSocket == "" {
up.CableSocket = up.Socket
}
up.RoundTripper = roundtripper.NewBackendRoundTripper(up.Backend, up.Socket, up.ProxyHeadersTimeout, cfg.DevelopmentMode)
up.CableRoundTripper = roundtripper.NewBackendRoundTripper(up.CableBackend, up.CableSocket, up.ProxyHeadersTimeout, cfg.DevelopmentMode)
up.configureURLPrefix()
up.configureRoutes()
......
......@@ -46,6 +46,8 @@ var listenNetwork = flag.String("listenNetwork", "tcp", "Listen 'network' (tcp,
var listenUmask = flag.Int("listenUmask", 0, "Umask for Unix socket")
var authBackend = flag.String("authBackend", upstream.DefaultBackend.String(), "Authentication/authorization backend")
var authSocket = flag.String("authSocket", "", "Optional: Unix domain socket to dial authBackend at")
var cableBackend = flag.String("cableBackend", upstream.DefaultBackend.String(), "ActionCable backend")
var cableSocket = flag.String("cableSocket", "", "Optional: Unix domain socket to dial cableBackend at")
var pprofListenAddr = flag.String("pprofListenAddr", "", "pprof listening address, e.g. 'localhost:6060'")
var documentRoot = flag.String("documentRoot", "public", "Path to static files content")
var proxyHeadersTimeout = flag.Duration("proxyHeadersTimeout", 5*time.Minute, "How long to wait for response headers when proxying the request")
......@@ -91,6 +93,11 @@ func main() {
log.WithError(err).Fatal("Invalid authBackend")
}
cableBackendURL, err := parseAuthBackend(*cableBackend)
if err != nil {
log.WithError(err).Fatal("Invalid cableBackend")
}
log.WithField("version", Version).WithField("build_time", BuildTime).Print("Starting")
// Good housekeeping for Unix sockets: unlink before binding
......@@ -137,7 +144,9 @@ func main() {
secret.SetPath(*secretPath)
cfg := config.Config{
Backend: backendURL,
CableBackend: cableBackendURL,
Socket: *authSocket,
CableSocket: *cableSocket,
Version: Version,
DocumentRoot: *documentRoot,
DevelopmentMode: *developmentMode,
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment