Commit f3677174 authored by Jacob Vosmaer's avatar Jacob Vosmaer

Merge branch 'smh-fix-write-before-read' into 'master'

Delay PostUploadPack response until request is fully read

Closes #258

See merge request gitlab-org/gitlab-workhorse!494
parents d952a6c8 2fe99f66
......@@ -13,6 +13,7 @@ import (
"os/exec"
"path"
"strings"
"sync"
"testing"
"time"
......@@ -190,15 +191,17 @@ func TestPostReceivePackProxiedToGitalySuccessfully(t *testing.T) {
gitProtocol := "fake Git protocol"
resource := "/gitlab-org/gitlab-test.git/git-receive-pack"
resp, body := httpPost(
resp := httpPost(
t,
ws.URL+resource,
map[string]string{
"Content-Type": "application/x-git-receive-pack-request",
"Git-Protocol": gitProtocol,
},
testhelper.GitalyReceivePackResponseMock,
bytes.NewReader(testhelper.GitalyReceivePackResponseMock),
)
defer resp.Body.Close()
body := string(testhelper.ReadAll(t, resp.Body))
split := strings.SplitN(body, "\000", 2)
require.Len(t, split, 2)
......@@ -257,6 +260,11 @@ func TestPostReceivePackProxiedToGitalyInterrupted(t *testing.T) {
}
}
// ReaderFunc is an adapter to turn a conforming function into an io.Reader.
type ReaderFunc func(b []byte) (int, error)
func (r ReaderFunc) Read(b []byte) (int, error) { return r(b) }
func TestPostUploadPackProxiedToGitalySuccessfully(t *testing.T) {
for i, tc := range []struct {
showAllRefs bool
......@@ -283,19 +291,39 @@ func TestPostUploadPackProxiedToGitalySuccessfully(t *testing.T) {
gitProtocol := "fake git protocol"
resource := "/gitlab-org/gitlab-test.git/git-upload-pack"
resp, body := httpPost(
requestReader := bytes.NewReader(testhelper.GitalyUploadPackResponseMock)
var m sync.Mutex
requestReadFinished := false
resp := httpPost(
t,
ws.URL+resource,
map[string]string{
"Content-Type": "application/x-git-upload-pack-request",
"Git-Protocol": gitProtocol,
},
testhelper.GitalyUploadPackResponseMock,
ReaderFunc(func(b []byte) (int, error) {
n, err := requestReader.Read(b)
if err != nil {
m.Lock()
requestReadFinished = true
m.Unlock()
}
return n, err
}),
)
defer resp.Body.Close()
require.Equal(t, 200, resp.StatusCode, "POST %q", resource)
testhelper.AssertResponseHeader(t, resp, "Content-Type", "application/x-git-upload-pack-result")
m.Lock()
requestFinished := requestReadFinished
m.Unlock()
if !requestFinished {
t.Fatalf("response written before request was fully read")
}
body := string(testhelper.ReadAll(t, resp.Body))
bodySplit := strings.SplitN(body, "\000", 2)
require.Len(t, bodySplit, 2)
......
......@@ -21,13 +21,8 @@ var (
func handleUploadPack(w *HttpResponseWriter, r *http.Request, a *api.Response) error {
ctx := r.Context()
// The body will consist almost entirely of 'have XXX' and 'want XXX'
// lines; these are about 50 bytes long. With a size limit of 10MiB, the
// client can send over 200,000 have/want lines.
sizeLimited := io.LimitReader(r.Body, 10*1024*1024)
// Prevent the client from holding the connection open indefinitely. A
// transfer rate of 17KiB/sec is sufficient to fill the 10MiB buffer in
// transfer rate of 17KiB/sec is sufficient to send 10MiB of data in
// ten minutes, which seems adequate. Most requests will be much smaller.
// This mitigates a use-after-check issue.
//
......@@ -36,21 +31,16 @@ func handleUploadPack(w *HttpResponseWriter, r *http.Request, a *api.Response) e
readerCtx, cancel := context.WithTimeout(ctx, uploadPackTimeout)
defer cancel()
limited := helper.NewContextReader(readerCtx, sizeLimited)
buffer, err := helper.ReadAllTempfile(limited)
if err != nil {
return fmt.Errorf("ReadAllTempfile: %v", err)
}
defer buffer.Close()
r.Body.Close()
limited := helper.NewContextReader(readerCtx, r.Body)
cr, cw := helper.NewWriteAfterReader(limited, w)
defer cw.Flush()
action := getService(r)
writePostRPCHeader(w, action)
gitProtocol := r.Header.Get("Git-Protocol")
return handleUploadPackWithGitaly(ctx, a, buffer, w, gitProtocol)
return handleUploadPackWithGitaly(ctx, a, cr, cw, gitProtocol)
}
func handleUploadPackWithGitaly(ctx context.Context, a *api.Response, clientRequest io.Reader, clientResponse io.Writer, gitProtocol string) error {
......
package git
import (
"fmt"
"io/ioutil"
"net"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"gitlab.com/gitlab-org/gitaly/proto/go/gitalypb"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/api"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/gitaly"
)
var (
......@@ -23,17 +33,54 @@ func (f *fakeReader) Read(b []byte) (int, error) {
return f.n, f.err
}
type smartHTTPServiceServer struct {
gitalypb.UnimplementedSmartHTTPServiceServer
PostUploadPackFunc func(gitalypb.SmartHTTPService_PostUploadPackServer) error
}
func (srv *smartHTTPServiceServer) PostUploadPack(s gitalypb.SmartHTTPService_PostUploadPackServer) error {
return srv.PostUploadPackFunc(s)
}
func TestUploadPackTimesOut(t *testing.T) {
uploadPackTimeout = time.Millisecond
defer func() { uploadPackTimeout = originalUploadPackTimeout }()
addr, cleanUp := startSmartHTTPServer(t, &smartHTTPServiceServer{
PostUploadPackFunc: func(stream gitalypb.SmartHTTPService_PostUploadPackServer) error {
_, err := stream.Recv() // trigger a read on the client request body
require.NoError(t, err)
return nil
},
})
defer cleanUp()
body := &fakeReader{n: 0, err: nil}
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/", body)
a := &api.Response{}
a := &api.Response{GitalyServer: gitaly.Server{Address: addr}}
err := handleUploadPack(NewHttpResponseWriter(w), r, a)
require.EqualError(t, err, "ReadAllTempfile: context deadline exceeded")
require.EqualError(t, err, "smarthttp.UploadPack: busyReader: context deadline exceeded")
}
func startSmartHTTPServer(t testing.TB, s gitalypb.SmartHTTPServiceServer) (string, func()) {
tmp, err := ioutil.TempDir("", "")
require.NoError(t, err)
socket := filepath.Join(tmp, "gitaly.sock")
ln, err := net.Listen("unix", socket)
require.NoError(t, err)
srv := grpc.NewServer()
gitalypb.RegisterSmartHTTPServiceServer(srv, s)
go func() {
require.NoError(t, srv.Serve(ln))
}()
return fmt.Sprintf("%s://%s", ln.Addr().Network(), ln.Addr().String()), func() {
srv.Stop()
assert.NoError(t, os.RemoveAll(tmp), "error removing temp dir %q", tmp)
}
}
......@@ -195,11 +195,16 @@ func (s *GitalyTestServer) PostUploadPack(stream gitalypb.SmartHTTPService_PostU
return err
}
data := []byte(strings.Join([]string{
jsonString,
}, "\000") + "\000")
if err := stream.Send(&gitalypb.PostUploadPackResponse{
Data: []byte(strings.Join([]string{jsonString}, "\000") + "\000"),
}); err != nil {
return err
}
// The body of the request starts in the second message
nSends := 0
// The body of the request starts in the second message. Gitaly streams PostUploadPack responses
// as soon as possible without reading the request completely first. We stream messages here
// directly back to the client to simulate the streaming of the actual implementation.
for {
req, err := stream.Recv()
if err != nil {
......@@ -209,12 +214,12 @@ func (s *GitalyTestServer) PostUploadPack(stream gitalypb.SmartHTTPService_PostU
break
}
data = append(data, req.GetData()...)
if err := stream.Send(&gitalypb.PostUploadPackResponse{Data: req.GetData()}); err != nil {
return err
}
nSends, _ := sendBytes(data, 100, func(p []byte) error {
return stream.Send(&gitalypb.PostUploadPackResponse{Data: p})
})
nSends++
}
if nSends <= 1 {
panic("should have sent more than one message")
......
......@@ -5,6 +5,7 @@ import (
"bytes"
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
......@@ -16,6 +17,7 @@ import (
"testing"
"github.com/dgrijalva/jwt-go"
"github.com/stretchr/testify/require"
"gitlab.com/gitlab-org/labkit/log"
......@@ -177,6 +179,14 @@ func LoadFile(t *testing.T, filePath string) string {
return string(content)
}
func ReadAll(t *testing.T, r io.Reader) []byte {
t.Helper()
b, err := ioutil.ReadAll(r)
require.NoError(t, err)
return b
}
func ParseJWT(token *jwt.Token) (interface{}, error) {
// Don't forget to validate the alg is what you expect:
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
......
......@@ -6,6 +6,7 @@ import (
"encoding/base64"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
......@@ -593,8 +594,8 @@ func httpGet(t *testing.T, url string, headers map[string]string) (*http.Respons
return resp, string(b)
}
func httpPost(t *testing.T, url string, headers map[string]string, reqBody []byte) (*http.Response, string) {
req, err := http.NewRequest("POST", url, bytes.NewReader(reqBody))
func httpPost(t *testing.T, url string, headers map[string]string, reqBody io.Reader) *http.Response {
req, err := http.NewRequest("POST", url, reqBody)
require.NoError(t, err)
for k, v := range headers {
......@@ -603,12 +604,8 @@ func httpPost(t *testing.T, url string, headers map[string]string, reqBody []byt
resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body)
require.NoError(t, err)
return resp, string(b)
return resp
}
func assertNginxResponseBuffering(t *testing.T, expected string, resp *http.Response, msgAndArgs ...interface{}) {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment