Commit 5d1752d1 authored by Jacob Vosmaer's avatar Jacob Vosmaer Committed by Nick Thomas

Add support for Gitaly feature flags

parent efd1d567
......@@ -52,11 +52,11 @@ func realGitalyOkBody(t *testing.T) *api.Response {
}
func ensureGitalyRepository(t *testing.T, apiResponse *api.Response) error {
namespace, err := gitaly.NewNamespaceClient(apiResponse.GitalyServer)
ctx, namespace, err := gitaly.NewNamespaceClient(context.Background(), apiResponse.GitalyServer)
if err != nil {
return err
}
repository, err := gitaly.NewRepositoryClient(apiResponse.GitalyServer)
ctx, repository, err := gitaly.NewRepositoryClient(ctx, apiResponse.GitalyServer)
if err != nil {
return err
}
......@@ -66,7 +66,7 @@ func ensureGitalyRepository(t *testing.T, apiResponse *api.Response) error {
StorageName: apiResponse.Repository.StorageName,
Name: apiResponse.Repository.RelativePath,
}
_, err = namespace.RemoveNamespace(context.Background(), rmNsReq)
_, err = namespace.RemoveNamespace(ctx, rmNsReq)
if err != nil {
return err
}
......@@ -76,7 +76,7 @@ func ensureGitalyRepository(t *testing.T, apiResponse *api.Response) error {
Url: "https://gitlab.com/gitlab-org/gitlab-test.git",
}
_, err = repository.CreateRepositoryFromURL(context.Background(), createReq)
_, err = repository.CreateRepositoryFromURL(ctx, createReq)
return err
}
......
......@@ -64,6 +64,23 @@ func TestGetInfoRefsProxiedToGitalySuccessfully(t *testing.T) {
apiResponse := gitOkBody(t)
apiResponse.GitalyServer.Address = gitalyAddress
goodMetadata := map[string]string{
"gitaly-feature-foobar": "true",
"gitaly-feature-bazqux": "false",
}
badMetadata := map[string]string{
"bad-metadata": "is blocked",
}
features := make(map[string]string)
for k, v := range goodMetadata {
features[k] = v
}
for k, v := range badMetadata {
features[k] = v
}
apiResponse.GitalyServer.Features = features
testCases := []struct {
showAllRefs bool
gitRpc string
......@@ -106,6 +123,18 @@ func TestGetInfoRefsProxiedToGitalySuccessfully(t *testing.T) {
require.Equal(t, tc.gitRpc, bodySplit[1])
require.Equal(t, string(testhelper.GitalyInfoRefsResponseMock), bodySplit[2], "GET %q: response body", resource)
md := gitalyServer.LastIncomingMetadata
for k, v := range goodMetadata {
actual := md[k]
require.Len(t, actual, 1, "number of metadata values for %v", k)
require.Equal(t, v, actual[0], "value for %v", k)
}
for k := range badMetadata {
actual := md[k]
require.Empty(t, actual, "metadata for bad key %v", k)
}
})
}
}
......
......@@ -134,7 +134,7 @@ func (a *archive) Inject(w http.ResponseWriter, r *http.Request, sendData string
func handleArchiveWithGitaly(r *http.Request, params archiveParams, format gitalypb.GetArchiveRequest_Format) (io.Reader, error) {
var request *gitalypb.GetArchiveRequest
c, err := gitaly.NewRepositoryClient(params.GitalyServer)
ctx, c, err := gitaly.NewRepositoryClient(r.Context(), params.GitalyServer)
if err != nil {
return nil, err
}
......@@ -154,7 +154,7 @@ func handleArchiveWithGitaly(r *http.Request, params archiveParams, format gital
}
}
return c.ArchiveReader(r.Context(), request)
return c.ArchiveReader(ctx, request)
}
func setArchiveHeaders(w http.ResponseWriter, format gitalypb.GetArchiveRequest_Format, archiveFilename string) {
......
......@@ -26,13 +26,13 @@ func (b *blob) Inject(w http.ResponseWriter, r *http.Request, sendData string) {
return
}
blobClient, err := gitaly.NewBlobClient(params.GitalyServer)
ctx, blobClient, err := gitaly.NewBlobClient(r.Context(), params.GitalyServer)
if err != nil {
helper.Fail500(w, r, fmt.Errorf("blob.GetBlob: %v", err))
return
}
if err := blobClient.SendBlob(r.Context(), w, &params.GetBlobRequest); err != nil {
if err := blobClient.SendBlob(ctx, w, &params.GetBlobRequest); err != nil {
helper.Fail500(w, r, fmt.Errorf("blob.GetBlob: %v", err))
return
}
......
......@@ -32,13 +32,13 @@ func (d *diff) Inject(w http.ResponseWriter, r *http.Request, sendData string) {
return
}
diffClient, err := gitaly.NewDiffClient(params.GitalyServer)
ctx, diffClient, err := gitaly.NewDiffClient(r.Context(), params.GitalyServer)
if err != nil {
helper.Fail500(w, r, fmt.Errorf("diff.RawDiff: %v", err))
return
}
if err := diffClient.SendRawDiff(r.Context(), w, request); err != nil {
if err := diffClient.SendRawDiff(ctx, w, request); err != nil {
helper.LogError(
r,
&copyError{fmt.Errorf("diff.RawDiff: request=%v, err=%v", request, err)},
......
......@@ -32,13 +32,13 @@ func (p *patch) Inject(w http.ResponseWriter, r *http.Request, sendData string)
return
}
diffClient, err := gitaly.NewDiffClient(params.GitalyServer)
ctx, diffClient, err := gitaly.NewDiffClient(r.Context(), params.GitalyServer)
if err != nil {
helper.Fail500(w, r, fmt.Errorf("diff.RawPatch: %v", err))
return
}
if err := diffClient.SendRawPatch(r.Context(), w, request); err != nil {
if err := diffClient.SendRawPatch(ctx, w, request); err != nil {
helper.LogError(
r,
&copyError{fmt.Errorf("diff.RawPatch: request=%v, err=%v", request, err)},
......
......@@ -46,7 +46,7 @@ func handleGetInfoRefs(rw http.ResponseWriter, r *http.Request, a *api.Response)
}
func handleGetInfoRefsWithGitaly(ctx context.Context, responseWriter *HttpResponseWriter, a *api.Response, rpc, gitProtocol, encoding string) error {
smarthttp, err := gitaly.NewSmartHTTPClient(a.GitalyServer)
ctx, smarthttp, err := gitaly.NewSmartHTTPClient(ctx, a.GitalyServer)
if err != nil {
return fmt.Errorf("GetInfoRefsHandler: %v", err)
}
......
......@@ -20,12 +20,12 @@ func handleReceivePack(w *HttpResponseWriter, r *http.Request, a *api.Response)
gitProtocol := r.Header.Get("Git-Protocol")
smarthttp, err := gitaly.NewSmartHTTPClient(a.GitalyServer)
ctx, smarthttp, err := gitaly.NewSmartHTTPClient(r.Context(), a.GitalyServer)
if err != nil {
return fmt.Errorf("smarthttp.ReceivePack: %v", err)
}
if err := smarthttp.ReceivePack(r.Context(), &a.Repository, a.GL_ID, a.GL_USERNAME, a.GL_REPOSITORY, a.GitConfigOptions, cr, cw, gitProtocol); err != nil {
if err := smarthttp.ReceivePack(ctx, &a.Repository, a.GL_ID, a.GL_USERNAME, a.GL_REPOSITORY, a.GitConfigOptions, cr, cw, gitProtocol); err != nil {
return fmt.Errorf("smarthttp.ReceivePack: %v", err)
}
......
......@@ -39,13 +39,13 @@ func (s *snapshot) Inject(w http.ResponseWriter, r *http.Request, sendData strin
return
}
c, err := gitaly.NewRepositoryClient(params.GitalyServer)
ctx, c, err := gitaly.NewRepositoryClient(r.Context(), params.GitalyServer)
if err != nil {
helper.Fail500(w, r, fmt.Errorf("SendSnapshot: gitaly.NewRepositoryClient: %v", err))
return
}
reader, err := c.SnapshotReader(r.Context(), request)
reader, err := c.SnapshotReader(ctx, request)
if err != nil {
helper.Fail500(w, r, fmt.Errorf("SendSnapshot: client.SnapshotReader: %v", err))
return
......
......@@ -33,7 +33,7 @@ func handleUploadPack(w *HttpResponseWriter, r *http.Request, a *api.Response) e
}
func handleUploadPackWithGitaly(ctx context.Context, a *api.Response, clientRequest io.Reader, clientResponse io.Writer, gitProtocol string) error {
smarthttp, err := gitaly.NewSmartHTTPClient(a.GitalyServer)
ctx, smarthttp, err := gitaly.NewSmartHTTPClient(ctx, a.GitalyServer)
if err != nil {
return fmt.Errorf("smarthttp.UploadPack: %v", err)
}
......
package gitaly
import (
"context"
"strings"
"sync"
......@@ -13,6 +14,7 @@ import (
gitalyclient "gitlab.com/gitlab-org/gitaly/client"
"gitlab.com/gitlab-org/gitaly/proto/go/gitalypb"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
grpccorrelation "gitlab.com/gitlab-org/labkit/correlation/grpc"
grpctracing "gitlab.com/gitlab-org/labkit/tracing/grpc"
......@@ -21,17 +23,24 @@ import (
type Server struct {
Address string `json:"address"`
Token string `json:"token"`
Features map[string]string `json:"features"`
}
type cacheKey struct{ address, token string }
func (server Server) cacheKey() cacheKey {
return cacheKey{address: server.Address, token: server.Token}
}
type connectionsCache struct {
sync.RWMutex
connections map[Server]*grpc.ClientConn
connections map[cacheKey]*grpc.ClientConn
}
var (
jsonUnMarshaler = jsonpb.Unmarshaler{AllowUnknownFields: true}
cache = connectionsCache{
connections: make(map[Server]*grpc.ClientConn),
connections: make(map[cacheKey]*grpc.ClientConn),
}
connectionsTotal = prometheus.NewCounterVec(
......@@ -47,55 +56,69 @@ func init() {
prometheus.MustRegister(connectionsTotal)
}
func NewSmartHTTPClient(server Server) (*SmartHTTPClient, error) {
func withOutgoingMetadata(ctx context.Context, features map[string]string) context.Context {
md := metadata.New(nil)
for k, v := range features {
if !strings.HasPrefix(k, "gitaly-feature-") {
continue
}
md.Append(k, v)
}
return metadata.NewOutgoingContext(ctx, md)
}
func NewSmartHTTPClient(ctx context.Context, server Server) (context.Context, *SmartHTTPClient, error) {
conn, err := getOrCreateConnection(server)
if err != nil {
return nil, err
return nil, nil, err
}
grpcClient := gitalypb.NewSmartHTTPServiceClient(conn)
return &SmartHTTPClient{grpcClient}, nil
return withOutgoingMetadata(ctx, server.Features), &SmartHTTPClient{grpcClient}, nil
}
func NewBlobClient(server Server) (*BlobClient, error) {
func NewBlobClient(ctx context.Context, server Server) (context.Context, *BlobClient, error) {
conn, err := getOrCreateConnection(server)
if err != nil {
return nil, err
return nil, nil, err
}
grpcClient := gitalypb.NewBlobServiceClient(conn)
return &BlobClient{grpcClient}, nil
return withOutgoingMetadata(ctx, server.Features), &BlobClient{grpcClient}, nil
}
func NewRepositoryClient(server Server) (*RepositoryClient, error) {
func NewRepositoryClient(ctx context.Context, server Server) (context.Context, *RepositoryClient, error) {
conn, err := getOrCreateConnection(server)
if err != nil {
return nil, err
return nil, nil, err
}
grpcClient := gitalypb.NewRepositoryServiceClient(conn)
return &RepositoryClient{grpcClient}, nil
return withOutgoingMetadata(ctx, server.Features), &RepositoryClient{grpcClient}, nil
}
// NewNamespaceClient is only used by the Gitaly integration tests at present
func NewNamespaceClient(server Server) (*NamespaceClient, error) {
func NewNamespaceClient(ctx context.Context, server Server) (context.Context, *NamespaceClient, error) {
conn, err := getOrCreateConnection(server)
if err != nil {
return nil, err
return nil, nil, err
}
grpcClient := gitalypb.NewNamespaceServiceClient(conn)
return &NamespaceClient{grpcClient}, nil
return withOutgoingMetadata(ctx, server.Features), &NamespaceClient{grpcClient}, nil
}
func NewDiffClient(server Server) (*DiffClient, error) {
func NewDiffClient(ctx context.Context, server Server) (context.Context, *DiffClient, error) {
conn, err := getOrCreateConnection(server)
if err != nil {
return nil, err
return nil, nil, err
}
grpcClient := gitalypb.NewDiffServiceClient(conn)
return &DiffClient{grpcClient}, nil
return withOutgoingMetadata(ctx, server.Features), &DiffClient{grpcClient}, nil
}
func getOrCreateConnection(server Server) (*grpc.ClientConn, error) {
key := server.cacheKey()
cache.RLock()
conn := cache.connections[server]
conn := cache.connections[key]
cache.RUnlock()
if conn != nil {
......@@ -105,7 +128,7 @@ func getOrCreateConnection(server Server) (*grpc.ClientConn, error) {
cache.Lock()
defer cache.Unlock()
if conn := cache.connections[server]; conn != nil {
if conn := cache.connections[key]; conn != nil {
return conn, nil
}
......@@ -114,7 +137,7 @@ func getOrCreateConnection(server Server) (*grpc.ClientConn, error) {
return nil, err
}
cache.connections[server] = conn
cache.connections[key] = conn
return conn, nil
}
......
package gitaly
import (
"context"
"testing"
"github.com/stretchr/testify/require"
"google.golang.org/grpc/metadata"
)
func TestNewSmartHTTPClient(t *testing.T) {
ctx, _, err := NewSmartHTTPClient(context.Background(), serverFixture())
require.NoError(t, err)
testOutgoingMetadata(t, ctx)
}
func TestNewBlobClient(t *testing.T) {
ctx, _, err := NewBlobClient(context.Background(), serverFixture())
require.NoError(t, err)
testOutgoingMetadata(t, ctx)
}
func TestNewRepositoryClient(t *testing.T) {
ctx, _, err := NewRepositoryClient(context.Background(), serverFixture())
require.NoError(t, err)
testOutgoingMetadata(t, ctx)
}
func TestNewNamespaceClient(t *testing.T) {
ctx, _, err := NewNamespaceClient(context.Background(), serverFixture())
require.NoError(t, err)
testOutgoingMetadata(t, ctx)
}
func TestNewDiffClient(t *testing.T) {
ctx, _, err := NewDiffClient(context.Background(), serverFixture())
require.NoError(t, err)
testOutgoingMetadata(t, ctx)
}
func testOutgoingMetadata(t *testing.T, ctx context.Context) {
md, ok := metadata.FromOutgoingContext(ctx)
require.True(t, ok, "get metadata from context")
for k, v := range allowedFeatures() {
actual := md[k]
require.Len(t, actual, 1, "expect one value for %v", k)
require.Equal(t, v, actual[0], "value for %v", k)
}
for k := range badFeatureMetadata() {
require.Empty(t, md[k], "value for bad key %v", k)
}
}
func serverFixture() Server {
features := make(map[string]string)
for k, v := range allowedFeatures() {
features[k] = v
}
for k, v := range badFeatureMetadata() {
features[k] = v
}
return Server{Address: "tcp://localhost:123", Features: features}
}
func allowedFeatures() map[string]string {
return map[string]string{
"gitaly-feature-foo": "bar",
"gitaly-feature-qux": "baz",
}
}
func badFeatureMetadata() map[string]string {
return map[string]string{
"bad-metadata-1": "bad-value-1",
"bad-metadata-2": "bad-value-2",
}
}
......@@ -14,12 +14,14 @@ import (
"gitlab.com/gitlab-org/labkit/log"
"golang.org/x/net/context"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)
type GitalyTestServer struct {
finalMessageCode codes.Code
sync.WaitGroup
LastIncomingMetadata metadata.MD
}
var (
......@@ -71,6 +73,11 @@ func (s *GitalyTestServer) InfoRefsUploadPack(in *gitalypb.InfoRefsRequest, stre
GitalyInfoRefsResponseMock,
}, "\000"))
s.LastIncomingMetadata = nil
if md, ok := metadata.FromIncomingContext(stream.Context()); ok {
s.LastIncomingMetadata = md
}
return s.sendInfoRefs(stream, data)
}
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment