Commit ddd53598 authored by Nick Thomas's avatar Nick Thomas

Merge branch 'jv-config-struct' into 'master'

Simplify config handling in main()

See merge request gitlab-org/gitlab-workhorse!634
parents ecec37a2 8662516a
#!/bin/sh #!/bin/sh
set -eu
IMPORT_RESULT=$(goimports -e -local "gitlab.com/gitlab-org/gitlab-workhorse" -l "$@") IMPORT_RESULT=$(goimports -e -local "gitlab.com/gitlab-org/gitlab-workhorse" -l "$@")
if [ -n "${IMPORT_RESULT}" ]; then if [ -n "${IMPORT_RESULT}" ]; then
......
---
title: Simplify config handling in main()
merge_request: 634
author:
type: other
package main
import (
"flag"
"io"
"io/ioutil"
"net/url"
"os"
"testing"
"time"
"github.com/stretchr/testify/require"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/config"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/queueing"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/upstream"
)
func TestConfigFile(t *testing.T) {
f, err := ioutil.TempFile("", "workhorse-config-test")
require.NoError(t, err)
defer os.Remove(f.Name())
data := `
[redis]
password = "redis password"
[object_storage]
provider = "test provider"
[image_resizer]
max_scaler_procs = 123
`
_, err = io.WriteString(f, data)
require.NoError(t, err)
require.NoError(t, f.Close())
_, cfg, err := buildConfig("test", []string{"-config", f.Name()})
require.NoError(t, err, "build config")
// These are integration tests: we want to see that each section in the
// config file ends up in the config struct. We do not test all the
// fields in each section; that should happen in the tests of the
// internal/config package.
require.Equal(t, "redis password", cfg.Redis.Password)
require.Equal(t, "test provider", cfg.ObjectStorageCredentials.Provider)
require.Equal(t, uint32(123), cfg.ImageResizerConfig.MaxScalerProcs, "image resizer max_scaler_procs")
}
func TestConfigErrorHelp(t *testing.T) {
for _, f := range []string{"-h", "-help"} {
t.Run(f, func(t *testing.T) {
_, _, err := buildConfig("test", []string{f})
require.Equal(t, alreadyPrintedError{flag.ErrHelp}, err)
})
}
}
func TestConfigError(t *testing.T) {
for _, arg := range []string{"-foobar", "foobar"} {
t.Run(arg, func(t *testing.T) {
_, _, err := buildConfig("test", []string{arg})
require.Error(t, err)
require.IsType(t, alreadyPrintedError{}, err)
})
}
}
func TestConfigDefaults(t *testing.T) {
boot, cfg, err := buildConfig("test", nil)
require.NoError(t, err, "build config")
expectedBoot := &bootConfig{
secretPath: "./.gitlab_workhorse_secret",
listenAddr: "localhost:8181",
listenNetwork: "tcp",
logFormat: "text",
}
require.Equal(t, expectedBoot, boot)
expectedCfg := &config.Config{
Backend: upstream.DefaultBackend,
CableBackend: upstream.DefaultBackend,
Version: "(unknown version)",
DocumentRoot: "public",
ProxyHeadersTimeout: 5 * time.Minute,
APIQueueTimeout: queueing.DefaultTimeout,
APICILongPollingDuration: 50 * time.Nanosecond, // TODO this is meant to be 50*time.Second but it has been wrong for ages
ImageResizerConfig: config.DefaultImageResizerConfig,
}
require.Equal(t, expectedCfg, cfg)
}
func TestConfigFlagParsing(t *testing.T) {
backendURL, err := url.Parse("http://localhost:1234")
require.NoError(t, err)
cableURL, err := url.Parse("http://localhost:5678")
require.NoError(t, err)
args := []string{
"-version",
"-secretPath", "secret path",
"-listenAddr", "listen addr",
"-listenNetwork", "listen network",
"-listenUmask", "123",
"-pprofListenAddr", "pprof listen addr",
"-prometheusListenAddr", "prometheus listen addr",
"-logFile", "log file",
"-logFormat", "log format",
"-documentRoot", "document root",
"-developmentMode",
"-authBackend", backendURL.String(),
"-authSocket", "auth socket",
"-cableBackend", cableURL.String(),
"-cableSocket", "cable socket",
"-proxyHeadersTimeout", "10m",
"-apiLimit", "234",
"-apiQueueLimit", "345",
"-apiQueueDuration", "123s",
"-apiCiLongPollingDuration", "234s",
"-propagateCorrelationID",
}
boot, cfg, err := buildConfig("test", args)
require.NoError(t, err, "build config")
expectedBoot := &bootConfig{
secretPath: "secret path",
listenAddr: "listen addr",
listenNetwork: "listen network",
listenUmask: 123,
pprofListenAddr: "pprof listen addr",
prometheusListenAddr: "prometheus listen addr",
logFile: "log file",
logFormat: "log format",
printVersion: true,
}
require.Equal(t, expectedBoot, boot)
expectedCfg := &config.Config{
DocumentRoot: "document root",
DevelopmentMode: true,
Backend: backendURL,
Socket: "auth socket",
CableBackend: cableURL,
CableSocket: "cable socket",
Version: "(unknown version)",
ProxyHeadersTimeout: 10 * time.Minute,
APILimit: 234,
APIQueueLimit: 345,
APIQueueTimeout: 123 * time.Second,
APICILongPollingDuration: 234 * time.Second,
PropagateCorrelationID: true,
ImageResizerConfig: config.DefaultImageResizerConfig,
}
require.Equal(t, expectedCfg, cfg)
}
...@@ -115,11 +115,10 @@ var DefaultImageResizerConfig = &ImageResizerConfig{ ...@@ -115,11 +115,10 @@ var DefaultImageResizerConfig = &ImageResizerConfig{
MaxFilesize: DefaultImageResizerMaxFilesize, MaxFilesize: DefaultImageResizerMaxFilesize,
} }
// LoadConfig from a file func LoadConfig(data string) (*Config, error) {
func LoadConfig(filename string) (*Config, error) {
cfg := &Config{ImageResizerConfig: DefaultImageResizerConfig} cfg := &Config{ImageResizerConfig: DefaultImageResizerConfig}
if _, err := toml.DecodeFile(filename, cfg); err != nil { if _, err := toml.Decode(data, cfg); err != nil {
return nil, err return nil, err
} }
......
package config package config
import ( import (
"io/ioutil"
"os"
"testing" "testing"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
...@@ -20,13 +18,20 @@ azure_storage_access_key = "deadbeef" ...@@ -20,13 +18,20 @@ azure_storage_access_key = "deadbeef"
func TestLoadEmptyConfig(t *testing.T) { func TestLoadEmptyConfig(t *testing.T) {
config := `` config := ``
tmpFile, cfg := loadTempConfig(t, config) cfg, err := LoadConfig(config)
defer os.Remove(tmpFile.Name()) require.NoError(t, err)
require.Nil(t, cfg.ObjectStorageCredentials) expected := Config{
ImageResizerConfig: &ImageResizerConfig{
MaxScalerProcs: DefaultImageResizerMaxScalerProcs,
MaxFilesize: DefaultImageResizerMaxFilesize,
},
}
err := cfg.RegisterGoCloudURLOpeners() require.Equal(t, expected, *cfg)
require.NoError(t, err)
require.Nil(t, cfg.ObjectStorageCredentials)
require.NoError(t, cfg.RegisterGoCloudURLOpeners())
} }
func TestLoadObjectStorageConfig(t *testing.T) { func TestLoadObjectStorageConfig(t *testing.T) {
...@@ -39,8 +44,8 @@ aws_access_key_id = "minio" ...@@ -39,8 +44,8 @@ aws_access_key_id = "minio"
aws_secret_access_key = "gdk-minio" aws_secret_access_key = "gdk-minio"
` `
tmpFile, cfg := loadTempConfig(t, config) cfg, err := LoadConfig(config)
defer os.Remove(tmpFile.Name()) require.NoError(t, err)
require.NotNil(t, cfg.ObjectStorageCredentials, "Expected object storage credentials") require.NotNil(t, cfg.ObjectStorageCredentials, "Expected object storage credentials")
...@@ -56,8 +61,8 @@ aws_secret_access_key = "gdk-minio" ...@@ -56,8 +61,8 @@ aws_secret_access_key = "gdk-minio"
} }
func TestRegisterGoCloudURLOpeners(t *testing.T) { func TestRegisterGoCloudURLOpeners(t *testing.T) {
tmpFile, cfg := loadTempConfig(t, azureConfig) cfg, err := LoadConfig(azureConfig)
defer os.Remove(tmpFile.Name()) require.NoError(t, err)
require.NotNil(t, cfg.ObjectStorageCredentials, "Expected object storage credentials") require.NotNil(t, cfg.ObjectStorageCredentials, "Expected object storage credentials")
...@@ -72,30 +77,13 @@ func TestRegisterGoCloudURLOpeners(t *testing.T) { ...@@ -72,30 +77,13 @@ func TestRegisterGoCloudURLOpeners(t *testing.T) {
require.Equal(t, expected, *cfg.ObjectStorageCredentials) require.Equal(t, expected, *cfg.ObjectStorageCredentials)
require.Nil(t, cfg.ObjectStorageConfig.URLMux) require.Nil(t, cfg.ObjectStorageConfig.URLMux)
err := cfg.RegisterGoCloudURLOpeners() require.NoError(t, cfg.RegisterGoCloudURLOpeners())
require.NoError(t, err)
require.NotNil(t, cfg.ObjectStorageConfig.URLMux) require.NotNil(t, cfg.ObjectStorageConfig.URLMux)
require.True(t, cfg.ObjectStorageConfig.URLMux.ValidBucketScheme("azblob")) require.True(t, cfg.ObjectStorageConfig.URLMux.ValidBucketScheme("azblob"))
require.Equal(t, []string{"azblob"}, cfg.ObjectStorageConfig.URLMux.BucketSchemes()) require.Equal(t, []string{"azblob"}, cfg.ObjectStorageConfig.URLMux.BucketSchemes())
} }
func TestLoadDefaultConfig(t *testing.T) {
config := ``
tmpFile, cfg := loadTempConfig(t, config)
defer os.Remove(tmpFile.Name())
expected := Config{
ImageResizerConfig: &ImageResizerConfig{
MaxScalerProcs: DefaultImageResizerMaxScalerProcs,
MaxFilesize: DefaultImageResizerMaxFilesize,
},
}
require.Equal(t, expected, *cfg)
}
func TestLoadImageResizerConfig(t *testing.T) { func TestLoadImageResizerConfig(t *testing.T) {
config := ` config := `
[image_resizer] [image_resizer]
...@@ -103,8 +91,8 @@ max_scaler_procs = 200 ...@@ -103,8 +91,8 @@ max_scaler_procs = 200
max_filesize = 350000 max_filesize = 350000
` `
tmpFile, cfg := loadTempConfig(t, config) cfg, err := LoadConfig(config)
defer os.Remove(tmpFile.Name()) require.NoError(t, err)
require.NotNil(t, cfg.ImageResizerConfig, "Expected image resizer config") require.NotNil(t, cfg.ImageResizerConfig, "Expected image resizer config")
...@@ -115,16 +103,3 @@ max_filesize = 350000 ...@@ -115,16 +103,3 @@ max_filesize = 350000
require.Equal(t, expected, *cfg.ImageResizerConfig) require.Equal(t, expected, *cfg.ImageResizerConfig)
} }
func loadTempConfig(t *testing.T, config string) (f *os.File, cfg *Config) {
tmpFile, err := ioutil.TempFile(os.TempDir(), "test-")
require.NoError(t, err)
_, err = tmpFile.Write([]byte(config))
require.NoError(t, err)
cfg, err = LoadConfig(tmpFile.Name())
require.NoError(t, err)
return tmpFile, cfg
}
...@@ -3,7 +3,6 @@ package config ...@@ -3,7 +3,6 @@ package config
import ( import (
"context" "context"
"net/url" "net/url"
"os"
"testing" "testing"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
...@@ -11,13 +10,12 @@ import ( ...@@ -11,13 +10,12 @@ import (
) )
func TestURLOpeners(t *testing.T) { func TestURLOpeners(t *testing.T) {
tmpFile, cfg := loadTempConfig(t, azureConfig) cfg, err := LoadConfig(azureConfig)
defer os.Remove(tmpFile.Name()) require.NoError(t, err)
require.NotNil(t, cfg.ObjectStorageCredentials, "Expected object storage credentials") require.NotNil(t, cfg.ObjectStorageCredentials, "Expected object storage credentials")
err := cfg.RegisterGoCloudURLOpeners() require.NoError(t, cfg.RegisterGoCloudURLOpeners())
require.NoError(t, err)
require.NotNil(t, cfg.ObjectStorageConfig.URLMux) require.NotNil(t, cfg.ObjectStorageConfig.URLMux)
tests := []struct { tests := []struct {
......
...@@ -18,26 +18,20 @@ const ( ...@@ -18,26 +18,20 @@ const (
noneLogType = "none" noneLogType = "none"
) )
type logConfiguration struct { func startLogging(file string, format string) (io.Closer, error) {
logFile string
logFormat string
}
func startLogging(config logConfiguration) (io.Closer, error) {
// Golog always goes to stderr // Golog always goes to stderr
goLog.SetOutput(os.Stderr) goLog.SetOutput(os.Stderr)
logFile := config.logFile if file == "" {
if logFile == "" { file = "stderr"
logFile = "stderr"
} }
switch config.logFormat { switch format {
case noneLogType: case noneLogType:
return logkit.Initialize(logkit.WithWriter(ioutil.Discard)) return logkit.Initialize(logkit.WithWriter(ioutil.Discard))
case jsonLogFormat: case jsonLogFormat:
return logkit.Initialize( return logkit.Initialize(
logkit.WithOutputName(logFile), logkit.WithOutputName(file),
logkit.WithFormatter("json"), logkit.WithFormatter("json"),
) )
case textLogFormat: case textLogFormat:
...@@ -48,23 +42,22 @@ func startLogging(config logConfiguration) (io.Closer, error) { ...@@ -48,23 +42,22 @@ func startLogging(config logConfiguration) (io.Closer, error) {
) )
case structuredFormat: case structuredFormat:
return logkit.Initialize( return logkit.Initialize(
logkit.WithOutputName(logFile), logkit.WithOutputName(file),
logkit.WithFormatter("color"), logkit.WithFormatter("color"),
) )
} }
return nil, fmt.Errorf("unknown logFormat: %v", config.logFormat) return nil, fmt.Errorf("unknown logFormat: %v", format)
} }
// In text format, we use a separate logger for access logs // In text format, we use a separate logger for access logs
func getAccessLogger(config logConfiguration) (*log.Logger, io.Closer, error) { func getAccessLogger(file string, format string) (*log.Logger, io.Closer, error) {
if config.logFormat != "text" { if format != "text" {
return log.StandardLogger(), ioutil.NopCloser(nil), nil return log.StandardLogger(), ioutil.NopCloser(nil), nil
} }
logFile := config.logFile if file == "" {
if logFile == "" { file = "stderr"
logFile = "stderr"
} }
accessLogger := log.New() accessLogger := log.New()
...@@ -72,7 +65,7 @@ func getAccessLogger(config logConfiguration) (*log.Logger, io.Closer, error) { ...@@ -72,7 +65,7 @@ func getAccessLogger(config logConfiguration) (*log.Logger, io.Closer, error) {
closer, err := logkit.Initialize( closer, err := logkit.Initialize(
logkit.WithLogger(accessLogger), // Configure `accessLogger` logkit.WithLogger(accessLogger), // Configure `accessLogger`
logkit.WithFormatter("combined"), // Use the combined formatter logkit.WithFormatter("combined"), // Use the combined formatter
logkit.WithOutputName(logFile), logkit.WithOutputName(file),
) )
return accessLogger, closer, err return accessLogger, closer, err
......
...@@ -16,6 +16,7 @@ package main ...@@ -16,6 +16,7 @@ package main
import ( import (
"flag" "flag"
"fmt" "fmt"
"io/ioutil"
"net" "net"
"net/http" "net/http"
_ "net/http/pprof" _ "net/http/pprof"
...@@ -36,85 +37,151 @@ import ( ...@@ -36,85 +37,151 @@ import (
// Version is the current version of GitLab Workhorse // Version is the current version of GitLab Workhorse
var Version = "(unknown version)" // Set at build time in the Makefile var Version = "(unknown version)" // Set at build time in the Makefile
// BuildTime signifies the time the binary was build // BuildTime signifies the time the binary was build
var BuildTime = "19700101.000000" // Set at build time in the Makefile var BuildTime = "19700101.000000" // Set at build time in the Makefile
var printVersion = flag.Bool("version", false, "Print version and exit") type bootConfig struct {
var configFile = flag.String("config", "", "TOML file to load config from") secretPath string
var listenAddr = flag.String("listenAddr", "localhost:8181", "Listen address for HTTP server") listenAddr string
var listenNetwork = flag.String("listenNetwork", "tcp", "Listen 'network' (tcp, tcp4, tcp6, unix)") listenNetwork string
var listenUmask = flag.Int("listenUmask", 0, "Umask for Unix socket") listenUmask int
var authBackend = flag.String("authBackend", upstream.DefaultBackend.String(), "Authentication/authorization backend") pprofListenAddr string
var authSocket = flag.String("authSocket", "", "Optional: Unix domain socket to dial authBackend at") prometheusListenAddr string
var cableBackend = flag.String("cableBackend", upstream.DefaultBackend.String(), "ActionCable backend") logFile string
var cableSocket = flag.String("cableSocket", "", "Optional: Unix domain socket to dial cableBackend at") logFormat string
var pprofListenAddr = flag.String("pprofListenAddr", "", "pprof listening address, e.g. 'localhost:6060'") printVersion bool
var documentRoot = flag.String("documentRoot", "public", "Path to static files content")
var proxyHeadersTimeout = flag.Duration("proxyHeadersTimeout", 5*time.Minute, "How long to wait for response headers when proxying the request")
var developmentMode = flag.Bool("developmentMode", false, "Allow the assets to be served from Rails app")
var secretPath = flag.String("secretPath", "./.gitlab_workhorse_secret", "File with secret key to authenticate with authBackend")
var apiLimit = flag.Uint("apiLimit", 0, "Number of API requests allowed at single time")
var apiQueueLimit = flag.Uint("apiQueueLimit", 0, "Number of API requests allowed to be queued")
var apiQueueTimeout = flag.Duration("apiQueueDuration", queueing.DefaultTimeout, "Maximum queueing duration of requests")
var apiCiLongPollingDuration = flag.Duration("apiCiLongPollingDuration", 50, "Long polling duration for job requesting for runners (default 50s - enabled)")
var propagateCorrelationID = flag.Bool("propagateCorrelationID", false, "Reuse existing Correlation-ID from the incoming request header `X-Request-ID` if present")
var prometheusListenAddr = flag.String("prometheusListenAddr", "", "Prometheus listening address, e.g. 'localhost:9229'")
var logConfig = logConfiguration{}
func init() {
flag.StringVar(&logConfig.logFile, "logFile", "", "Log file location")
flag.StringVar(&logConfig.logFormat, "logFormat", "text", "Log format to use defaults to text (text, json, structured, none)")
} }
func main() { func main() {
flag.Usage = func() {
fmt.Fprintf(os.Stderr, "Usage of %s:\n", os.Args[0]) boot, cfg, err := buildConfig(os.Args[0], os.Args[1:])
fmt.Fprintf(os.Stderr, "\n %s [OPTIONS]\n\nOptions:\n", os.Args[0]) if err == (alreadyPrintedError{flag.ErrHelp}) {
flag.PrintDefaults() os.Exit(0)
}
if err != nil {
if _, alreadyPrinted := err.(alreadyPrintedError); !alreadyPrinted {
fmt.Fprintln(os.Stderr, err)
}
os.Exit(2)
} }
flag.Parse()
if *printVersion { if boot.printVersion {
fmt.Printf("gitlab-workhorse %s-%s\n", Version, BuildTime) fmt.Printf("gitlab-workhorse %s-%s\n", Version, BuildTime)
os.Exit(0) os.Exit(0)
} }
log.WithError(run()).Fatal("shutting down") log.WithError(run(*boot, *cfg)).Fatal("shutting down")
} }
func run() error { type alreadyPrintedError struct{ error }
closer, err := startLogging(logConfig)
// buildConfig may print messages to os.Stderr if err != nil. If err is
// of type alreadyPrintedError it has already been printed.
func buildConfig(arg0 string, args []string) (*bootConfig, *config.Config, error) {
boot := &bootConfig{}
cfg := &config.Config{Version: Version}
fset := flag.NewFlagSet(arg0, flag.ContinueOnError)
fset.Usage = func() {
fmt.Fprintf(fset.Output(), "Usage of %s:\n", arg0)
fmt.Fprintf(fset.Output(), "\n %s [OPTIONS]\n\nOptions:\n", arg0)
fset.PrintDefaults()
}
configFile := fset.String("config", "", "TOML file to load config from")
fset.StringVar(&boot.secretPath, "secretPath", "./.gitlab_workhorse_secret", "File with secret key to authenticate with authBackend")
fset.StringVar(&boot.listenAddr, "listenAddr", "localhost:8181", "Listen address for HTTP server")
fset.StringVar(&boot.listenNetwork, "listenNetwork", "tcp", "Listen 'network' (tcp, tcp4, tcp6, unix)")
fset.IntVar(&boot.listenUmask, "listenUmask", 0, "Umask for Unix socket")
fset.StringVar(&boot.pprofListenAddr, "pprofListenAddr", "", "pprof listening address, e.g. 'localhost:6060'")
fset.StringVar(&boot.prometheusListenAddr, "prometheusListenAddr", "", "Prometheus listening address, e.g. 'localhost:9229'")
fset.StringVar(&boot.logFile, "logFile", "", "Log file location")
fset.StringVar(&boot.logFormat, "logFormat", "text", "Log format to use defaults to text (text, json, structured, none)")
fset.BoolVar(&boot.printVersion, "version", false, "Print version and exit")
// gitlab-rails backend
authBackend := fset.String("authBackend", upstream.DefaultBackend.String(), "Authentication/authorization backend")
fset.StringVar(&cfg.Socket, "authSocket", "", "Optional: Unix domain socket to dial authBackend at")
// actioncable backend
cableBackend := fset.String("cableBackend", upstream.DefaultBackend.String(), "ActionCable backend")
fset.StringVar(&cfg.CableSocket, "cableSocket", "", "Optional: Unix domain socket to dial cableBackend at")
fset.StringVar(&cfg.DocumentRoot, "documentRoot", "public", "Path to static files content")
fset.DurationVar(&cfg.ProxyHeadersTimeout, "proxyHeadersTimeout", 5*time.Minute, "How long to wait for response headers when proxying the request")
fset.BoolVar(&cfg.DevelopmentMode, "developmentMode", false, "Allow the assets to be served from Rails app")
fset.UintVar(&cfg.APILimit, "apiLimit", 0, "Number of API requests allowed at single time")
fset.UintVar(&cfg.APIQueueLimit, "apiQueueLimit", 0, "Number of API requests allowed to be queued")
fset.DurationVar(&cfg.APIQueueTimeout, "apiQueueDuration", queueing.DefaultTimeout, "Maximum queueing duration of requests")
fset.DurationVar(&cfg.APICILongPollingDuration, "apiCiLongPollingDuration", 50, "Long polling duration for job requesting for runners (default 50s - enabled)")
fset.BoolVar(&cfg.PropagateCorrelationID, "propagateCorrelationID", false, "Reuse existing Correlation-ID from the incoming request header `X-Request-ID` if present")
if err := fset.Parse(args); err != nil {
return nil, nil, alreadyPrintedError{err}
}
if fset.NArg() > 0 {
err := alreadyPrintedError{fmt.Errorf("unexpected arguments: %v", fset.Args())}
fmt.Fprintln(fset.Output(), err)
fset.Usage()
return nil, nil, err
}
var err error
cfg.Backend, err = parseAuthBackend(*authBackend)
if err != nil { if err != nil {
return err return nil, nil, fmt.Errorf("authBackend: %v", err)
} }
defer closer.Close()
tracing.Initialize(tracing.WithServiceName("gitlab-workhorse")) cfg.CableBackend, err = parseAuthBackend(*cableBackend)
if err != nil {
return nil, nil, fmt.Errorf("cableBackend: %v", err)
}
backendURL, err := parseAuthBackend(*authBackend) tomlData := ""
if *configFile != "" {
buf, err := ioutil.ReadFile(*configFile)
if err != nil { if err != nil {
return fmt.Errorf("authBackend: %v", err) return nil, nil, fmt.Errorf("configFile: %v", err)
} }
tomlData = string(buf)
}
cfgFromFile, err := config.LoadConfig(tomlData)
if err != nil {
return nil, nil, fmt.Errorf("configFile: %v", err)
}
cfg.Redis = cfgFromFile.Redis
cfg.ObjectStorageCredentials = cfgFromFile.ObjectStorageCredentials
cfg.ImageResizerConfig = cfgFromFile.ImageResizerConfig
return boot, cfg, nil
}
cableBackendURL, err := parseAuthBackend(*cableBackend) // run() lets us use normal Go error handling; there is no log.Fatal in run().
func run(boot bootConfig, cfg config.Config) error {
closer, err := startLogging(boot.logFile, boot.logFormat)
if err != nil { if err != nil {
return fmt.Errorf("cableBackend: %v", err) return err
} }
defer closer.Close()
tracing.Initialize(tracing.WithServiceName("gitlab-workhorse"))
log.WithField("version", Version).WithField("build_time", BuildTime).Print("Starting") log.WithField("version", Version).WithField("build_time", BuildTime).Print("Starting")
// Good housekeeping for Unix sockets: unlink before binding // Good housekeeping for Unix sockets: unlink before binding
if *listenNetwork == "unix" { if boot.listenNetwork == "unix" {
if err := os.Remove(*listenAddr); err != nil && !os.IsNotExist(err) { if err := os.Remove(boot.listenAddr); err != nil && !os.IsNotExist(err) {
return err return err
} }
} }
// Change the umask only around net.Listen() // Change the umask only around net.Listen()
oldUmask := syscall.Umask(*listenUmask) oldUmask := syscall.Umask(boot.listenUmask)
listener, err := net.Listen(*listenNetwork, *listenAddr) listener, err := net.Listen(boot.listenNetwork, boot.listenAddr)
syscall.Umask(oldUmask) syscall.Umask(oldUmask)
if err != nil { if err != nil {
return fmt.Errorf("main listener: %v", err) return fmt.Errorf("main listener: %v", err)
...@@ -126,8 +193,8 @@ func run() error { ...@@ -126,8 +193,8 @@ func run() error {
// requests can only reach the profiler if we start a listener. So by // requests can only reach the profiler if we start a listener. So by
// having no profiler HTTP listener by default, the profiler is // having no profiler HTTP listener by default, the profiler is
// effectively disabled by default. // effectively disabled by default.
if *pprofListenAddr != "" { if boot.pprofListenAddr != "" {
l, err := net.Listen("tcp", *pprofListenAddr) l, err := net.Listen("tcp", boot.pprofListenAddr)
if err != nil { if err != nil {
return fmt.Errorf("pprofListenAddr: %v", err) return fmt.Errorf("pprofListenAddr: %v", err)
} }
...@@ -137,8 +204,8 @@ func run() error { ...@@ -137,8 +204,8 @@ func run() error {
monitoringOpts := []monitoring.Option{monitoring.WithBuildInformation(Version, BuildTime)} monitoringOpts := []monitoring.Option{monitoring.WithBuildInformation(Version, BuildTime)}
if *prometheusListenAddr != "" { if boot.prometheusListenAddr != "" {
l, err := net.Listen("tcp", *prometheusListenAddr) l, err := net.Listen("tcp", boot.prometheusListenAddr)
if err != nil { if err != nil {
return fmt.Errorf("prometheusListenAddr: %v", err) return fmt.Errorf("prometheusListenAddr: %v", err)
} }
...@@ -152,46 +219,18 @@ func run() error { ...@@ -152,46 +219,18 @@ func run() error {
} }
}() }()
secret.SetPath(*secretPath) secret.SetPath(boot.secretPath)
cfg := config.Config{
Backend: backendURL,
CableBackend: cableBackendURL,
Socket: *authSocket,
CableSocket: *cableSocket,
Version: Version,
DocumentRoot: *documentRoot,
DevelopmentMode: *developmentMode,
ProxyHeadersTimeout: *proxyHeadersTimeout,
APILimit: *apiLimit,
APIQueueLimit: *apiQueueLimit,
APIQueueTimeout: *apiQueueTimeout,
APICILongPollingDuration: *apiCiLongPollingDuration,
PropagateCorrelationID: *propagateCorrelationID,
ImageResizerConfig: config.DefaultImageResizerConfig,
}
if *configFile != "" {
cfgFromFile, err := config.LoadConfig(*configFile)
if err != nil {
return fmt.Errorf("configFile: %v", err)
}
cfg.Redis = cfgFromFile.Redis
cfg.ObjectStorageCredentials = cfgFromFile.ObjectStorageCredentials
cfg.ImageResizerConfig = cfgFromFile.ImageResizerConfig
if cfg.Redis != nil { if cfg.Redis != nil {
redis.Configure(cfg.Redis, redis.DefaultDialFunc) redis.Configure(cfg.Redis, redis.DefaultDialFunc)
go redis.Process() go redis.Process()
} }
err = cfg.RegisterGoCloudURLOpeners() if err := cfg.RegisterGoCloudURLOpeners(); err != nil {
if err != nil {
return fmt.Errorf("register cloud credentials: %v", err) return fmt.Errorf("register cloud credentials: %v", err)
} }
}
accessLogger, accessCloser, err := getAccessLogger(logConfig) accessLogger, accessCloser, err := getAccessLogger(boot.logFile, boot.logFormat)
if err != nil { if err != nil {
return fmt.Errorf("configure access logger: %v", err) return fmt.Errorf("configure access logger: %v", err)
} }
......
...@@ -38,6 +38,9 @@ import ( ...@@ -38,6 +38,9 @@ import (
const scratchDir = "testdata/scratch" const scratchDir = "testdata/scratch"
const testRepoRoot = "testdata/data" const testRepoRoot = "testdata/data"
const testDocumentRoot = "testdata/public" const testDocumentRoot = "testdata/public"
var absDocumentRoot string
const testRepo = "group/test.git" const testRepo = "group/test.git"
const testProject = "group/test" const testProject = "group/test"
...@@ -183,7 +186,7 @@ func TestAllowedPublicUploadsFile(t *testing.T) { ...@@ -183,7 +186,7 @@ func TestAllowedPublicUploadsFile(t *testing.T) {
proxied := false proxied := false
ts := testhelper.TestServerWithHandler(regexp.MustCompile(`.`), func(w http.ResponseWriter, r *http.Request) { ts := testhelper.TestServerWithHandler(regexp.MustCompile(`.`), func(w http.ResponseWriter, r *http.Request) {
proxied = true proxied = true
w.Header().Add("X-Sendfile", *documentRoot+r.URL.Path) w.Header().Add("X-Sendfile", absDocumentRoot+r.URL.Path)
w.WriteHeader(200) w.WriteHeader(200)
}) })
defer ts.Close() defer ts.Close()
...@@ -577,11 +580,11 @@ func setupStaticFile(fpath, content string) error { ...@@ -577,11 +580,11 @@ func setupStaticFile(fpath, content string) error {
if err != nil { if err != nil {
return err return err
} }
*documentRoot = path.Join(cwd, testDocumentRoot) absDocumentRoot = path.Join(cwd, testDocumentRoot)
if err := os.MkdirAll(path.Join(*documentRoot, path.Dir(fpath)), 0755); err != nil { if err := os.MkdirAll(path.Join(absDocumentRoot, path.Dir(fpath)), 0755); err != nil {
return err return err
} }
staticFile := path.Join(*documentRoot, fpath) staticFile := path.Join(absDocumentRoot, fpath)
return ioutil.WriteFile(staticFile, []byte(content), 0666) return ioutil.WriteFile(staticFile, []byte(content), 0666)
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment