Commit 285f47a7 authored by Jacob Vosmaer's avatar Jacob Vosmaer

Merge branch 'delayed-responsewriter' into 'master'

Delay HTTP headers for Git HTTP responses

Alternative to https://gitlab.com/gitlab-org/gitlab-workhorse/merge_requests/38
and https://gitlab.com/gitlab-org/gitlab-workhorse/merge_requests/40 .

See merge request !42
parents 8b9f31be 56d92dc1
// Package delay exports delay.ResponseWriter. This type implements
// http.ResponseWriter with the ability to delay setting the HTTP
// response code (with WriteHeader()) until writing the first bufferSize
// bytes. This makes it possible, up to a point, to 'change your mind'
// about the HTTP status code. The caller must call
// ResponseWriter.Flush() before returning from the handler (e.g. using
// 'defer').
package delay
import (
"bytes"
"io"
"net/http"
)
const bufferSize = 8192
type ResponseWriter struct {
writer http.ResponseWriter
status int
bufWritten int
cap int
flushed bool
buffer *bytes.Buffer
}
func NewResponseWriter(w http.ResponseWriter) *ResponseWriter {
return &ResponseWriter{
writer: w,
buffer: bytes.NewBuffer(make([]byte, 0, bufferSize)),
cap: bufferSize,
}
}
func (rw *ResponseWriter) Write(buf []byte) (int, error) {
if !rw.flushed && len(buf)+rw.bufWritten <= rw.cap {
n, err := rw.buffer.Write(buf)
rw.bufWritten += n
return n, err
}
if err := rw.Flush(); err != nil {
return 0, err
}
return rw.writer.Write(buf)
}
func (rw *ResponseWriter) Header() http.Header {
return rw.writer.Header()
}
func (rw *ResponseWriter) WriteHeader(code int) {
if rw.status != 0 {
return
}
rw.status = code
}
func (rw *ResponseWriter) Flush() error {
if rw.flushed {
return nil
}
rw.flushed = true
if rw.status == 0 {
rw.writer.WriteHeader(http.StatusOK)
} else {
rw.writer.WriteHeader(rw.status)
}
_, err := io.Copy(rw.writer, rw.buffer)
rw.buffer = nil // "Release" the buffer for GC
return err
}
package delay
import (
"fmt"
"net/http/httptest"
"strings"
"testing"
)
func TestSanity(t *testing.T) {
first, second := 200, 500
w := httptest.NewRecorder()
w.WriteHeader(first)
w.WriteHeader(second)
if code := w.Code; code != first {
t.Fatalf("Expected HTTP code %d, got %d", first, code)
}
}
func TestSmallResponse(t *testing.T) {
code := 500
body := "hello"
w := httptest.NewRecorder()
rw := NewResponseWriter(w)
fmt.Fprint(rw, body)
rw.WriteHeader(code)
rw.Flush()
if actualCode := w.Code; actualCode != code {
t.Fatalf("Expected code %d, got %d", code, actualCode)
}
if actualBody := w.Body.String(); actualBody != body {
t.Fatalf("Expected body %q, got %q", body, actualBody)
}
}
func TestLargeResponse(t *testing.T) {
code := 200
body := strings.Repeat("0123456789", bufferSize/5) // must exceed the buffer size
w := httptest.NewRecorder()
rw := NewResponseWriter(w)
fmt.Fprint(rw, body)
// Because the 'body' was too long this 500 should be ignored
rw.WriteHeader(500)
rw.Flush()
if actualCode := w.Code; actualCode != code {
t.Fatalf("Expected code %d, got %d", code, actualCode)
}
if actualBody := w.Body.String(); actualBody != body {
t.Fatalf("Expected body %q, got %q", body, actualBody)
}
}
...@@ -6,6 +6,7 @@ package git ...@@ -6,6 +6,7 @@ package git
import ( import (
"../api" "../api"
"../delay"
"../helper" "../helper"
"errors" "errors"
"fmt" "fmt"
...@@ -52,7 +53,10 @@ func repoPreAuthorizeHandler(myAPI *api.API, handleFunc api.HandleFunc) http.Han ...@@ -52,7 +53,10 @@ func repoPreAuthorizeHandler(myAPI *api.API, handleFunc api.HandleFunc) http.Han
}, "") }, "")
} }
func handleGetInfoRefs(w http.ResponseWriter, r *http.Request, a *api.Response) { func handleGetInfoRefs(_w http.ResponseWriter, r *http.Request, a *api.Response) {
w := delay.NewResponseWriter(_w)
defer w.Flush()
rpc := r.URL.Query().Get("service") rpc := r.URL.Query().Get("service")
if !(rpc == "git-upload-pack" || rpc == "git-receive-pack") { if !(rpc == "git-upload-pack" || rpc == "git-receive-pack") {
// The 'dumb' Git HTTP protocol is not supported // The 'dumb' Git HTTP protocol is not supported
...@@ -77,26 +81,28 @@ func handleGetInfoRefs(w http.ResponseWriter, r *http.Request, a *api.Response) ...@@ -77,26 +81,28 @@ func handleGetInfoRefs(w http.ResponseWriter, r *http.Request, a *api.Response)
// Start writing the response // Start writing the response
w.Header().Add("Content-Type", fmt.Sprintf("application/x-%s-advertisement", rpc)) w.Header().Add("Content-Type", fmt.Sprintf("application/x-%s-advertisement", rpc))
w.Header().Add("Cache-Control", "no-cache") w.Header().Add("Cache-Control", "no-cache")
w.WriteHeader(200) // Don't bother with HTTP 500 from this point on, just return
if err := pktLine(w, fmt.Sprintf("# service=%s\n", rpc)); err != nil { if err := pktLine(w, fmt.Sprintf("# service=%s\n", rpc)); err != nil {
helper.LogError(fmt.Errorf("handleGetInfoRefs: pktLine: %v", err)) helper.Fail500(w, fmt.Errorf("handleGetInfoRefs: pktLine: %v", err))
return return
} }
if err := pktFlush(w); err != nil { if err := pktFlush(w); err != nil {
helper.LogError(fmt.Errorf("handleGetInfoRefs: pktFlush: %v", err)) helper.Fail500(w, fmt.Errorf("handleGetInfoRefs: pktFlush: %v", err))
return return
} }
if _, err := io.Copy(w, stdout); err != nil { if _, err := io.Copy(w, stdout); err != nil {
helper.LogError(fmt.Errorf("handleGetInfoRefs: copy output of %v: %v", cmd.Args, err)) helper.Fail500(w, fmt.Errorf("handleGetInfoRefs: copy output of %v: %v", cmd.Args, err))
return return
} }
if err := cmd.Wait(); err != nil { if err := cmd.Wait(); err != nil {
helper.LogError(fmt.Errorf("handleGetInfoRefs: wait for %v: %v", cmd.Args, err)) helper.Fail500(w, fmt.Errorf("handleGetInfoRefs: wait for %v: %v", cmd.Args, err))
return return
} }
} }
func handlePostRPC(w http.ResponseWriter, r *http.Request, a *api.Response) { func handlePostRPC(_w http.ResponseWriter, r *http.Request, a *api.Response) {
w := delay.NewResponseWriter(_w)
defer w.Flush()
var err error var err error
// Get Git action from URL // Get Git action from URL
...@@ -142,15 +148,14 @@ func handlePostRPC(w http.ResponseWriter, r *http.Request, a *api.Response) { ...@@ -142,15 +148,14 @@ func handlePostRPC(w http.ResponseWriter, r *http.Request, a *api.Response) {
// Start writing the response // Start writing the response
w.Header().Add("Content-Type", fmt.Sprintf("application/x-%s-result", action)) w.Header().Add("Content-Type", fmt.Sprintf("application/x-%s-result", action))
w.Header().Add("Cache-Control", "no-cache") w.Header().Add("Cache-Control", "no-cache")
w.WriteHeader(200) // Don't bother with HTTP 500 from this point on, just return
// This io.Copy may take a long time, both for Git push and pull. // This io.Copy may take a long time, both for Git push and pull.
if _, err := io.Copy(w, stdout); err != nil { if _, err := io.Copy(w, stdout); err != nil {
helper.LogError(fmt.Errorf("handlePostRPC copy output of %v: %v", cmd.Args, err)) helper.Fail500(w, fmt.Errorf("handlePostRPC copy output of %v: %v", cmd.Args, err))
return return
} }
if err := cmd.Wait(); err != nil { if err := cmd.Wait(); err != nil {
helper.LogError(fmt.Errorf("handlePostRPC wait for %v: %v", cmd.Args, err)) helper.Fail500(w, fmt.Errorf("handlePostRPC wait for %v: %v", cmd.Args, err))
return return
} }
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment