Commit 51c7f313 authored by Mahmoud Rahbar Azad's avatar Mahmoud Rahbar Azad Committed by Jacob Vosmaer

added configuration option PropagateCorrelationID and corresponding flag...

added configuration option PropagateCorrelationID and corresponding flag argument propagateCorrelationID to enable correlation.InboundHandlerOption: WithPropagation() for correlation.InjectCorrelationID
parent cb7f2d58
......@@ -74,6 +74,7 @@ type Config struct {
APIQueueTimeout time.Duration `toml:"-"`
APICILongPollingDuration time.Duration `toml:"-"`
ObjectStorageCredentials *ObjectStorageCredentials `toml:"object_storage"`
PropagateCorrelationID bool `toml:"-"`
}
// LoadConfig from a file
......
......@@ -56,8 +56,13 @@ func NewUpstream(cfg config.Config, accessLogger *logrus.Logger) http.Handler {
up.configureURLPrefix()
up.configureRoutes()
var correlationOpts []correlation.InboundHandlerOption
if cfg.PropagateCorrelationID {
correlationOpts = append(correlationOpts, correlation.WithPropagation())
}
handler := log.AccessLogger(&up, log.WithAccessLogger(accessLogger))
handler = correlation.InjectCorrelationID(handler)
handler = correlation.InjectCorrelationID(handler, correlationOpts...)
return handler
}
......
......@@ -57,6 +57,7 @@ var apiLimit = flag.Uint("apiLimit", 0, "Number of API requests allowed at singl
var apiQueueLimit = flag.Uint("apiQueueLimit", 0, "Number of API requests allowed to be queued")
var apiQueueTimeout = flag.Duration("apiQueueDuration", queueing.DefaultTimeout, "Maximum queueing duration of requests")
var apiCiLongPollingDuration = flag.Duration("apiCiLongPollingDuration", 50, "Long polling duration for job requesting for runners (default 50s - enabled)")
var propagateCorrelationID = flag.Bool("propagateCorrelationID", false, "Reuse existing Correlation-ID from the incoming request header `X-Request-ID` if present")
var prometheusListenAddr = flag.String("prometheusListenAddr", "", "Prometheus listening address, e.g. 'localhost:9229'")
......@@ -155,6 +156,7 @@ func main() {
APIQueueLimit: *apiQueueLimit,
APIQueueTimeout: *apiQueueTimeout,
APICILongPollingDuration: *apiCiLongPollingDuration,
PropagateCorrelationID: *propagateCorrelationID,
}
if *configFile != "" {
......
......@@ -510,6 +510,52 @@ func TestCorrelationIdHeader(t *testing.T) {
}
}
func TestPropagateCorrelationIdHeader(t *testing.T) {
ts := testhelper.TestServerWithHandler(regexp.MustCompile(`.`), func(w http.ResponseWriter, r *http.Request) {
w.Header().Add("X-Request-Id", r.Header.Get("X-Request-Id"))
w.WriteHeader(200)
})
defer ts.Close()
testCases := []struct {
desc string
propagateCorrelationID bool
}{
{
desc: "propagateCorrelatedId is true",
propagateCorrelationID: true,
},
{
desc: "propagateCorrelatedId is false",
propagateCorrelationID: false,
},
}
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
upstreamConfig := newUpstreamConfig(ts.URL)
upstreamConfig.PropagateCorrelationID = tc.propagateCorrelationID
ws := startWorkhorseServerWithConfig(upstreamConfig)
defer ws.Close()
resource := "/api/v3/projects/123/repository/not/special"
propagatedRequestId := "Propagated-RequestId-12345678"
resp, _ := httpGet(t, ws.URL+resource, map[string]string{"X-Request-Id": propagatedRequestId})
requestIds := resp.Header["X-Request-Id"]
assert.Equal(t, 200, resp.StatusCode, "GET %q: status code", resource)
assert.Equal(t, 1, len(requestIds), "GET %q: One X-Request-Id present", resource)
if tc.propagateCorrelationID {
assert.Contains(t, requestIds, propagatedRequestId, "GET %q: Has X-Request-Id %s present", resource, propagatedRequestId)
} else {
assert.NotContains(t, requestIds, propagatedRequestId, "GET %q: X-Request-Id not propagated")
}
})
}
}
func setupStaticFile(fpath, content string) error {
cwd, err := os.Getwd()
if err != nil {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment