Commit bee2c423 authored by Mateusz Nowotyński's avatar Mateusz Nowotyński

Add support for Gitaly feature flags

gitaly#1846

Gitaly does not store or remember feature flags. We must pass them in as
metadata on each request. This MR adds plumbing to pass Gitaly feature flags
supplied by gitlab-rails to the Gitaly server as gRPC request metadata.
Signed-off-by: default avatarMateusz Nowotyński <maxmati4@gmail.com>
parent 7d5229db
...@@ -18,6 +18,7 @@ func (c *Command) performGitalyCall(response *accessverifier.Response) error { ...@@ -18,6 +18,7 @@ func (c *Command) performGitalyCall(response *accessverifier.Response) error {
ServiceName: string(commandargs.ReceivePack), ServiceName: string(commandargs.ReceivePack),
Address: response.Gitaly.Address, Address: response.Gitaly.Address,
Token: response.Gitaly.Token, Token: response.Gitaly.Token,
Features: response.Gitaly.Features,
} }
request := &pb.SSHReceivePackRequest{ request := &pb.SSHReceivePackRequest{
......
...@@ -14,7 +14,7 @@ import ( ...@@ -14,7 +14,7 @@ import (
) )
func TestReceivePack(t *testing.T) { func TestReceivePack(t *testing.T) {
gitalyAddress, cleanup := testserver.StartGitalyServer(t) gitalyAddress, _, cleanup := testserver.StartGitalyServer(t)
defer cleanup() defer cleanup()
requests := requesthandlers.BuildAllowedWithGitalyHandlers(t, gitalyAddress) requests := requesthandlers.BuildAllowedWithGitalyHandlers(t, gitalyAddress)
......
...@@ -18,6 +18,7 @@ func (c *Command) performGitalyCall(response *accessverifier.Response) error { ...@@ -18,6 +18,7 @@ func (c *Command) performGitalyCall(response *accessverifier.Response) error {
ServiceName: string(commandargs.UploadArchive), ServiceName: string(commandargs.UploadArchive),
Address: response.Gitaly.Address, Address: response.Gitaly.Address,
Token: response.Gitaly.Token, Token: response.Gitaly.Token,
Features: response.Gitaly.Features,
} }
request := &pb.SSHUploadArchiveRequest{Repository: &response.Gitaly.Repo} request := &pb.SSHUploadArchiveRequest{Repository: &response.Gitaly.Repo}
......
...@@ -14,7 +14,7 @@ import ( ...@@ -14,7 +14,7 @@ import (
) )
func TestUploadPack(t *testing.T) { func TestUploadPack(t *testing.T) {
gitalyAddress, cleanup := testserver.StartGitalyServer(t) gitalyAddress, _, cleanup := testserver.StartGitalyServer(t)
defer cleanup() defer cleanup()
requests := requesthandlers.BuildAllowedWithGitalyHandlers(t, gitalyAddress) requests := requesthandlers.BuildAllowedWithGitalyHandlers(t, gitalyAddress)
......
...@@ -18,6 +18,7 @@ func (c *Command) performGitalyCall(response *accessverifier.Response) error { ...@@ -18,6 +18,7 @@ func (c *Command) performGitalyCall(response *accessverifier.Response) error {
ServiceName: string(commandargs.UploadPack), ServiceName: string(commandargs.UploadPack),
Address: response.Gitaly.Address, Address: response.Gitaly.Address,
Token: response.Gitaly.Token, Token: response.Gitaly.Token,
Features: response.Gitaly.Features,
} }
request := &pb.SSHUploadPackRequest{ request := &pb.SSHUploadPackRequest{
......
...@@ -4,6 +4,7 @@ import ( ...@@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"testing" "testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs" "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs"
...@@ -14,7 +15,7 @@ import ( ...@@ -14,7 +15,7 @@ import (
) )
func TestUploadPack(t *testing.T) { func TestUploadPack(t *testing.T) {
gitalyAddress, cleanup := testserver.StartGitalyServer(t) gitalyAddress, testServer, cleanup := testserver.StartGitalyServer(t)
defer cleanup() defer cleanup()
requests := requesthandlers.BuildAllowedWithGitalyHandlers(t, gitalyAddress) requests := requesthandlers.BuildAllowedWithGitalyHandlers(t, gitalyAddress)
...@@ -37,4 +38,14 @@ func TestUploadPack(t *testing.T) { ...@@ -37,4 +38,14 @@ func TestUploadPack(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, "UploadPack: "+repo, output.String()) require.Equal(t, "UploadPack: "+repo, output.String())
for k, v := range map[string]string{
"gitaly-feature-cache_invalidator": "true",
"gitaly-feature-inforef_uploadpack_cache": "false",
} {
actual := testServer.ReceivedMD[k]
assert.Len(t, actual, 1)
assert.Equal(t, v, actual[0])
}
assert.Empty(t, testServer.ReceivedMD["some-other-ff"])
} }
...@@ -31,9 +31,10 @@ type Request struct { ...@@ -31,9 +31,10 @@ type Request struct {
} }
type Gitaly struct { type Gitaly struct {
Repo pb.Repository `json:"repository"` Repo pb.Repository `json:"repository"`
Address string `json:"address"` Address string `json:"address"`
Token string `json:"token"` Token string `json:"token"`
Features map[string]string `json:"features"`
} }
type CustomPayloadData struct { type CustomPayloadData struct {
......
...@@ -10,49 +10,56 @@ import ( ...@@ -10,49 +10,56 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/metadata"
pb "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb" pb "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb"
) )
type testGitalyServer struct{} type TestGitalyServer struct{ ReceivedMD metadata.MD }
func (s *testGitalyServer) SSHReceivePack(stream pb.SSHService_SSHReceivePackServer) error { func (s *TestGitalyServer) SSHReceivePack(stream pb.SSHService_SSHReceivePackServer) error {
req, err := stream.Recv() req, err := stream.Recv()
if err != nil { if err != nil {
return err return err
} }
s.ReceivedMD, _ = metadata.FromIncomingContext(stream.Context())
response := []byte("ReceivePack: " + req.GlId + " " + req.Repository.GlRepository) response := []byte("ReceivePack: " + req.GlId + " " + req.Repository.GlRepository)
stream.Send(&pb.SSHReceivePackResponse{Stdout: response}) stream.Send(&pb.SSHReceivePackResponse{Stdout: response})
return nil return nil
} }
func (s *testGitalyServer) SSHUploadPack(stream pb.SSHService_SSHUploadPackServer) error { func (s *TestGitalyServer) SSHUploadPack(stream pb.SSHService_SSHUploadPackServer) error {
req, err := stream.Recv() req, err := stream.Recv()
if err != nil { if err != nil {
return err return err
} }
s.ReceivedMD, _ = metadata.FromIncomingContext(stream.Context())
response := []byte("UploadPack: " + req.Repository.GlRepository) response := []byte("UploadPack: " + req.Repository.GlRepository)
stream.Send(&pb.SSHUploadPackResponse{Stdout: response}) stream.Send(&pb.SSHUploadPackResponse{Stdout: response})
return nil return nil
} }
func (s *testGitalyServer) SSHUploadArchive(stream pb.SSHService_SSHUploadArchiveServer) error { func (s *TestGitalyServer) SSHUploadArchive(stream pb.SSHService_SSHUploadArchiveServer) error {
req, err := stream.Recv() req, err := stream.Recv()
if err != nil { if err != nil {
return err return err
} }
s.ReceivedMD, _ = metadata.FromIncomingContext(stream.Context())
response := []byte("UploadArchive: " + req.Repository.GlRepository) response := []byte("UploadArchive: " + req.Repository.GlRepository)
stream.Send(&pb.SSHUploadArchiveResponse{Stdout: response}) stream.Send(&pb.SSHUploadArchiveResponse{Stdout: response})
return nil return nil
} }
func StartGitalyServer(t *testing.T) (string, func()) { func StartGitalyServer(t *testing.T) (string, *TestGitalyServer, func()) {
tempDir, _ := ioutil.TempDir("", "gitlab-shell-test-api") tempDir, _ := ioutil.TempDir("", "gitlab-shell-test-api")
gitalySocketPath := path.Join(tempDir, "gitaly.sock") gitalySocketPath := path.Join(tempDir, "gitaly.sock")
...@@ -64,7 +71,8 @@ func StartGitalyServer(t *testing.T) (string, func()) { ...@@ -64,7 +71,8 @@ func StartGitalyServer(t *testing.T) (string, func()) {
listener, err := net.Listen("unix", gitalySocketPath) listener, err := net.Listen("unix", gitalySocketPath)
require.NoError(t, err) require.NoError(t, err)
pb.RegisterSSHServiceServer(server, &testGitalyServer{}) testServer := TestGitalyServer{}
pb.RegisterSSHServiceServer(server, &testServer)
go server.Serve(listener) go server.Serve(listener)
...@@ -74,5 +82,5 @@ func StartGitalyServer(t *testing.T) (string, func()) { ...@@ -74,5 +82,5 @@ func StartGitalyServer(t *testing.T) (string, func()) {
os.RemoveAll(tempDir) os.RemoveAll(tempDir)
} }
return gitalySocketUrl, cleanup return gitalySocketUrl, &testServer, cleanup
} }
...@@ -4,6 +4,7 @@ import ( ...@@ -4,6 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
"os" "os"
"strings"
"gitlab.com/gitlab-org/gitaly/auth" "gitlab.com/gitlab-org/gitaly/auth"
"gitlab.com/gitlab-org/gitaly/client" "gitlab.com/gitlab-org/gitaly/client"
...@@ -11,6 +12,7 @@ import ( ...@@ -11,6 +12,7 @@ import (
"gitlab.com/gitlab-org/gitlab-shell/go/internal/config" "gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
"gitlab.com/gitlab-org/labkit/tracing" "gitlab.com/gitlab-org/labkit/tracing"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/metadata"
) )
// GitalyHandlerFunc implementations are responsible for making // GitalyHandlerFunc implementations are responsible for making
...@@ -29,6 +31,7 @@ type GitalyCommand struct { ...@@ -29,6 +31,7 @@ type GitalyCommand struct {
ServiceName string ServiceName string
Address string Address string
Token string Token string
Features map[string]string
} }
// RunGitalyCommand provides a bootstrap for Gitaly commands executed // RunGitalyCommand provides a bootstrap for Gitaly commands executed
...@@ -48,6 +51,18 @@ func (gc *GitalyCommand) RunGitalyCommand(handler GitalyHandlerFunc) error { ...@@ -48,6 +51,18 @@ func (gc *GitalyCommand) RunGitalyCommand(handler GitalyHandlerFunc) error {
return err return err
} }
func withOutgoingMetadata(ctx context.Context, features map[string]string) context.Context {
md := metadata.New(nil)
for k, v := range features {
if !strings.HasPrefix(k, "gitaly-feature-") {
continue
}
md.Append(k, v)
}
return metadata.NewOutgoingContext(ctx, md)
}
func getConn(gc *GitalyCommand) (*GitalyConn, error) { func getConn(gc *GitalyCommand) (*GitalyConn, error) {
if gc.Address == "" { if gc.Address == "" {
return nil, fmt.Errorf("no gitaly_address given") return nil, fmt.Errorf("no gitaly_address given")
...@@ -80,6 +95,7 @@ func getConn(gc *GitalyCommand) (*GitalyConn, error) { ...@@ -80,6 +95,7 @@ func getConn(gc *GitalyCommand) (*GitalyConn, error) {
) )
ctx, finished := tracing.ExtractFromEnv(context.Background()) ctx, finished := tracing.ExtractFromEnv(context.Background())
ctx = withOutgoingMetadata(ctx, gc.Features)
conn, err := client.Dial(gc.Address, connOpts) conn, err := client.Dial(gc.Address, connOpts)
if err != nil { if err != nil {
......
...@@ -5,8 +5,10 @@ import ( ...@@ -5,8 +5,10 @@ import (
"errors" "errors"
"testing" "testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/metadata"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/config" "gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
) )
...@@ -40,3 +42,45 @@ func TestMissingGitalyAddress(t *testing.T) { ...@@ -40,3 +42,45 @@ func TestMissingGitalyAddress(t *testing.T) {
err := cmd.RunGitalyCommand(makeHandler(t, nil)) err := cmd.RunGitalyCommand(makeHandler(t, nil))
require.EqualError(t, err, "no gitaly_address given") require.EqualError(t, err, "no gitaly_address given")
} }
func TestGetConnMetadata(t *testing.T) {
tests := []struct {
name string
gc *GitalyCommand
want map[string]string
}{
{
name: "gitaly_feature_flags",
gc: &GitalyCommand{
Config: &config.Config{},
Address: "tcp://localhost:9999",
Features: map[string]string{
"gitaly-feature-cache_invalidator": "true",
"other-ff": "true",
"gitaly-feature-inforef_uploadpack_cache": "false",
},
},
want: map[string]string{
"gitaly-feature-cache_invalidator": "true",
"gitaly-feature-inforef_uploadpack_cache": "false",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
conn, err := getConn(tt.gc)
require.NoError(t, err)
md, exists := metadata.FromOutgoingContext(conn.ctx)
require.True(t, exists)
require.Equal(t, len(tt.want), md.Len())
for k, v := range tt.want {
values := md.Get(k)
assert.Equal(t, 1, len(values))
assert.Equal(t, v, values[0])
}
})
}
}
...@@ -47,6 +47,11 @@ func BuildAllowedWithGitalyHandlers(t *testing.T, gitalyAddress string) []testse ...@@ -47,6 +47,11 @@ func BuildAllowedWithGitalyHandlers(t *testing.T, gitalyAddress string) []testse
}, },
"address": gitalyAddress, "address": gitalyAddress,
"token": "token", "token": "token",
"features": map[string]string{
"gitaly-feature-cache_invalidator": "true",
"gitaly-feature-inforef_uploadpack_cache": "false",
"some-other-ff": "true",
},
}, },
} }
require.NoError(t, json.NewEncoder(w).Encode(body)) require.NoError(t, json.NewEncoder(w).Encode(body))
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment