Commit b513c936 authored by Nick Thomas's avatar Nick Thomas

Merge branch 'sh-fix-performance-bar-ajax-requests' into 'master'

Remove duplicate X-Request-Id response header

Closes gitlab-ce#60111

See merge request gitlab-org/gitlab-workhorse!384
parents 8b043410 d44f2fb6
...@@ -14,6 +14,8 @@ import ( ...@@ -14,6 +14,8 @@ import (
"strings" "strings"
"syscall" "syscall"
"github.com/sebest/xff"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/log" "gitlab.com/gitlab-org/gitlab-workhorse/internal/log"
) )
...@@ -157,6 +159,15 @@ func AllowResponseBuffering(w http.ResponseWriter) { ...@@ -157,6 +159,15 @@ func AllowResponseBuffering(w http.ResponseWriter) {
w.Header().Del(NginxResponseBufferHeader) w.Header().Del(NginxResponseBufferHeader)
} }
func FixRemoteAddr(r *http.Request) {
// Unix domain sockets have a remote addr of @. This will make the
// xff package lookup the X-Forwarded-For address if available.
if r.RemoteAddr == "@" {
r.RemoteAddr = "127.0.0.1:0"
}
r.RemoteAddr = xff.GetRemoteAddr(r)
}
func SetForwardedFor(newHeaders *http.Header, originalRequest *http.Request) { func SetForwardedFor(newHeaders *http.Header, originalRequest *http.Request) {
if clientIP, _, err := net.SplitHostPort(originalRequest.RemoteAddr); err == nil { if clientIP, _, err := net.SplitHostPort(originalRequest.RemoteAddr); err == nil {
var header string var header string
......
...@@ -8,8 +8,39 @@ import ( ...@@ -8,8 +8,39 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestFixRemoteAddr(t *testing.T) {
testCases := []struct {
initial string
forwarded string
expected string
}{
{initial: "@", forwarded: "", expected: "127.0.0.1:0"},
{initial: "@", forwarded: "18.245.0.1", expected: "18.245.0.1:0"},
{initial: "@", forwarded: "127.0.0.1", expected: "127.0.0.1:0"},
{initial: "@", forwarded: "192.168.0.1", expected: "127.0.0.1:0"},
{initial: "192.168.1.1:0", forwarded: "", expected: "192.168.1.1:0"},
{initial: "192.168.1.1:0", forwarded: "18.245.0.1", expected: "18.245.0.1:0"},
}
for _, tc := range testCases {
req, err := http.NewRequest("POST", "unix:///tmp/test.socket/info/refs", nil)
require.NoError(t, err)
req.RemoteAddr = tc.initial
if tc.forwarded != "" {
req.Header.Add("X-Forwarded-For", tc.forwarded)
}
FixRemoteAddr(req)
assert.Equal(t, tc.expected, req.RemoteAddr)
}
}
func TestSetForwardedForGeneratesHeader(t *testing.T) { func TestSetForwardedForGeneratesHeader(t *testing.T) {
testCases := []struct { testCases := []struct {
remoteAddr string remoteAddr string
......
...@@ -11,7 +11,7 @@ import ( ...@@ -11,7 +11,7 @@ import (
"net/http" "net/http"
"strings" "strings"
"github.com/sebest/xff" "gitlab.com/gitlab-org/labkit/correlation"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/config" "gitlab.com/gitlab-org/gitlab-workhorse/internal/config"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/helper" "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
...@@ -44,7 +44,8 @@ func NewUpstream(cfg config.Config) http.Handler { ...@@ -44,7 +44,8 @@ func NewUpstream(cfg config.Config) http.Handler {
up.RoundTripper = roundtripper.NewBackendRoundTripper(up.Backend, up.Socket, up.ProxyHeadersTimeout, cfg.DevelopmentMode) up.RoundTripper = roundtripper.NewBackendRoundTripper(up.Backend, up.Socket, up.ProxyHeadersTimeout, cfg.DevelopmentMode)
up.configureURLPrefix() up.configureURLPrefix()
up.configureRoutes() up.configureRoutes()
return &up
return correlation.InjectCorrelationID(&up)
} }
func (u *upstream) configureURLPrefix() { func (u *upstream) configureURLPrefix() {
...@@ -56,12 +57,7 @@ func (u *upstream) configureURLPrefix() { ...@@ -56,12 +57,7 @@ func (u *upstream) configureURLPrefix() {
} }
func (u *upstream) ServeHTTP(ow http.ResponseWriter, r *http.Request) { func (u *upstream) ServeHTTP(ow http.ResponseWriter, r *http.Request) {
// Unix domain sockets have a remote addr of @. This will make the helper.FixRemoteAddr(r)
// xff package lookup the X-Forwarded-For address if available.
if r.RemoteAddr == "@" {
r.RemoteAddr = "127.0.0.1:0"
}
r.RemoteAddr = xff.GetRemoteAddr(r)
w := helper.NewStatsCollectingResponseWriter(ow) w := helper.NewStatsCollectingResponseWriter(ow)
defer w.RequestFinished(r) defer w.RequestFinished(r)
......
package upstream
import (
"net/http"
"net/http/httptest"
"testing"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/config"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// This test doesn't actually listen for a connection so there may be
// some error spew that can be ignored.
func TestXForwardedForHeaders(t *testing.T) {
testCases := []struct {
initial string
forwarded string
expected string
}{
{initial: "@", forwarded: "", expected: "127.0.0.1:0"},
{initial: "@", forwarded: "18.245.0.1", expected: "18.245.0.1:0"},
{initial: "@", forwarded: "127.0.0.1", expected: "127.0.0.1:0"},
{initial: "@", forwarded: "192.168.0.1", expected: "127.0.0.1:0"},
{initial: "192.168.1.1:0", forwarded: "", expected: "192.168.1.1:0"},
{initial: "192.168.1.1:0", forwarded: "18.245.0.1", expected: "18.245.0.1:0"},
}
for _, tc := range testCases {
req, err := http.NewRequest("POST", "unix:///tmp/test.socket/info/refs", nil)
require.NoError(t, err)
req.RemoteAddr = tc.initial
if tc.forwarded != "" {
req.Header.Add("X-Forwarded-For", tc.forwarded)
}
config := config.Config{}
u := NewUpstream(config)
resp := httptest.NewRecorder()
u.ServeHTTP(resp, req)
assert.Equal(t, tc.expected, req.RemoteAddr)
}
}
...@@ -25,7 +25,6 @@ import ( ...@@ -25,7 +25,6 @@ import (
"github.com/prometheus/client_golang/prometheus/promhttp" "github.com/prometheus/client_golang/prometheus/promhttp"
"gitlab.com/gitlab-org/labkit/correlation"
"gitlab.com/gitlab-org/labkit/tracing" "gitlab.com/gitlab-org/labkit/tracing"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/config" "gitlab.com/gitlab-org/gitlab-workhorse/internal/config"
...@@ -152,12 +151,7 @@ func main() { ...@@ -152,12 +151,7 @@ func main() {
} }
} }
up := wrapRaven( up := wrapRaven(upstream.NewUpstream(cfg))
correlation.InjectCorrelationID(
upstream.NewUpstream(cfg),
correlation.WithSetResponseHeader(),
),
)
logger.Fatal(http.Serve(listener, up)) logger.Fatal(http.Serve(listener, up))
} }
...@@ -449,6 +449,26 @@ func TestAPIFalsePositivesAreProxied(t *testing.T) { ...@@ -449,6 +449,26 @@ func TestAPIFalsePositivesAreProxied(t *testing.T) {
} }
} }
func TestCorrelationIdHeader(t *testing.T) {
ts := testhelper.TestServerWithHandler(regexp.MustCompile(`.`), func(w http.ResponseWriter, r *http.Request) {
w.Header().Add("X-Request-Id", "12345678")
w.WriteHeader(200)
})
defer ts.Close()
ws := startWorkhorseServer(ts.URL)
defer ws.Close()
for _, resource := range []string{
"/api/v3/projects/123/repository/not/special",
} {
resp, _ := httpGet(t, ws.URL+resource, nil)
assert.Equal(t, 200, resp.StatusCode, "GET %q: status code", resource)
requestIds := resp.Header["X-Request-Id"]
assert.Equal(t, 1, len(requestIds), "GET %q: One X-Request-Id present", resource)
}
}
func setupStaticFile(fpath, content string) error { func setupStaticFile(fpath, content string) error {
cwd, err := os.Getwd() cwd, err := os.Getwd()
if err != nil { if err != nil {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment