Commit 344cc6b4 authored by Nick Thomas's avatar Nick Thomas

Merge branch 'id-api-regular-http' into 'master'

Support calling internal api using HTTP

See merge request gitlab-org/gitlab-shell!295
parents 6e9b4dec 9d9e1617
...@@ -17,7 +17,6 @@ import ( ...@@ -17,7 +17,6 @@ import (
) )
var ( var (
testConfig = &config.Config{GitlabUrl: "http+unix://" + testserver.TestSocket}
requests = []testserver.TestRequestHandler{ requests = []testserver.TestRequestHandler{
{ {
Path: "/api/v4/internal/discover", Path: "/api/v4/internal/discover",
...@@ -46,7 +45,7 @@ var ( ...@@ -46,7 +45,7 @@ var (
) )
func TestExecute(t *testing.T) { func TestExecute(t *testing.T) {
cleanup, err := testserver.StartSocketHttpServer(requests) cleanup, url, err := testserver.StartSocketHttpServer(requests)
require.NoError(t, err) require.NoError(t, err)
defer cleanup() defer cleanup()
...@@ -79,7 +78,7 @@ func TestExecute(t *testing.T) { ...@@ -79,7 +78,7 @@ func TestExecute(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) { t.Run(tc.desc, func(t *testing.T) {
cmd := &Command{Config: testConfig, Args: tc.arguments} cmd := &Command{Config: &config.Config{GitlabUrl: url}, Args: tc.arguments}
buffer := &bytes.Buffer{} buffer := &bytes.Buffer{}
err := cmd.Execute(&readwriter.ReadWriter{Out: buffer}) err := cmd.Execute(&readwriter.ReadWriter{Out: buffer})
...@@ -91,7 +90,7 @@ func TestExecute(t *testing.T) { ...@@ -91,7 +90,7 @@ func TestExecute(t *testing.T) {
} }
func TestFailingExecute(t *testing.T) { func TestFailingExecute(t *testing.T) {
cleanup, err := testserver.StartSocketHttpServer(requests) cleanup, url, err := testserver.StartSocketHttpServer(requests)
require.NoError(t, err) require.NoError(t, err)
defer cleanup() defer cleanup()
...@@ -119,7 +118,7 @@ func TestFailingExecute(t *testing.T) { ...@@ -119,7 +118,7 @@ func TestFailingExecute(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) { t.Run(tc.desc, func(t *testing.T) {
cmd := &Command{Config: testConfig, Args: tc.arguments} cmd := &Command{Config: &config.Config{GitlabUrl: url}, Args: tc.arguments}
buffer := &bytes.Buffer{} buffer := &bytes.Buffer{}
err := cmd.Execute(&readwriter.ReadWriter{Out: buffer}) err := cmd.Execute(&readwriter.ReadWriter{Out: buffer})
......
...@@ -18,12 +18,10 @@ import ( ...@@ -18,12 +18,10 @@ import (
) )
var ( var (
testConfig *config.Config
requests []testserver.TestRequestHandler requests []testserver.TestRequestHandler
) )
func setup(t *testing.T) { func setup(t *testing.T) {
testConfig = &config.Config{GitlabUrl: "http+unix://" + testserver.TestSocket}
requests = []testserver.TestRequestHandler{ requests = []testserver.TestRequestHandler{
{ {
Path: "/api/v4/internal/two_factor_recovery_codes", Path: "/api/v4/internal/two_factor_recovery_codes",
...@@ -66,7 +64,7 @@ const ( ...@@ -66,7 +64,7 @@ const (
func TestExecute(t *testing.T) { func TestExecute(t *testing.T) {
setup(t) setup(t)
cleanup, err := testserver.StartSocketHttpServer(requests) cleanup, url, err := testserver.StartSocketHttpServer(requests)
require.NoError(t, err) require.NoError(t, err)
defer cleanup() defer cleanup()
...@@ -124,7 +122,7 @@ func TestExecute(t *testing.T) { ...@@ -124,7 +122,7 @@ func TestExecute(t *testing.T) {
output := &bytes.Buffer{} output := &bytes.Buffer{}
input := bytes.NewBufferString(tc.answer) input := bytes.NewBufferString(tc.answer)
cmd := &Command{Config: testConfig, Args: tc.arguments} cmd := &Command{Config: &config.Config{GitlabUrl: url}, Args: tc.arguments}
err := cmd.Execute(&readwriter.ReadWriter{Out: output, In: input}) err := cmd.Execute(&readwriter.ReadWriter{Out: output, In: input})
......
...@@ -22,6 +22,12 @@ type MigrationConfig struct { ...@@ -22,6 +22,12 @@ type MigrationConfig struct {
Features []string `yaml:"features"` Features []string `yaml:"features"`
} }
type HttpSettingsConfig struct {
User string `yaml:"user"`
Password string `yaml:"password"`
ReadTimeoutSeconds uint64 `yaml:"read_timeout"`
}
type Config struct { type Config struct {
RootDir string RootDir string
LogFile string `yaml:"log_file"` LogFile string `yaml:"log_file"`
...@@ -31,6 +37,8 @@ type Config struct { ...@@ -31,6 +37,8 @@ type Config struct {
GitlabTracing string `yaml:"gitlab_tracing"` GitlabTracing string `yaml:"gitlab_tracing"`
SecretFilePath string `yaml:"secret_file"` SecretFilePath string `yaml:"secret_file"`
Secret string `yaml:"secret"` Secret string `yaml:"secret"`
HttpSettings HttpSettingsConfig `yaml:"http_settings"`
HttpClient *HttpClient
} }
func New() (*Config, error) { func New() (*Config, error) {
...@@ -51,7 +59,7 @@ func (c *Config) FeatureEnabled(featureName string) bool { ...@@ -51,7 +59,7 @@ func (c *Config) FeatureEnabled(featureName string) bool {
return false return false
} }
if !strings.HasPrefix(c.GitlabUrl, "http+unix://") { if !strings.HasPrefix(c.GitlabUrl, "http+unix://") && !strings.HasPrefix(c.GitlabUrl, "http://") {
return false return false
} }
......
...@@ -30,6 +30,7 @@ func TestParseConfig(t *testing.T) { ...@@ -30,6 +30,7 @@ func TestParseConfig(t *testing.T) {
gitlabUrl string gitlabUrl string
migration MigrationConfig migration MigrationConfig
secret string secret string
httpSettings HttpSettingsConfig
}{ }{
{ {
path: path.Join(testRoot, "gitlab-shell.log"), path: path.Join(testRoot, "gitlab-shell.log"),
...@@ -86,6 +87,13 @@ func TestParseConfig(t *testing.T) { ...@@ -86,6 +87,13 @@ func TestParseConfig(t *testing.T) {
format: "text", format: "text",
secret: "an inline secret", secret: "an inline secret",
}, },
{
yaml: "http_settings:\n user: user_basic_auth\n password: password_basic_auth\n read_timeout: 500",
path: path.Join(testRoot, "gitlab-shell.log"),
format: "text",
secret: "default-secret-content",
httpSettings: HttpSettingsConfig{User: "user_basic_auth", Password: "password_basic_auth", ReadTimeoutSeconds: 500},
},
} }
for _, tc := range testCases { for _, tc := range testCases {
...@@ -101,6 +109,7 @@ func TestParseConfig(t *testing.T) { ...@@ -101,6 +109,7 @@ func TestParseConfig(t *testing.T) {
assert.Equal(t, tc.format, cfg.LogFormat) assert.Equal(t, tc.format, cfg.LogFormat)
assert.Equal(t, tc.gitlabUrl, cfg.GitlabUrl) assert.Equal(t, tc.gitlabUrl, cfg.GitlabUrl)
assert.Equal(t, tc.secret, cfg.Secret) assert.Equal(t, tc.secret, cfg.Secret)
assert.Equal(t, tc.httpSettings, cfg.HttpSettings)
}) })
} }
} }
...@@ -139,6 +148,15 @@ func TestFeatureEnabled(t *testing.T) { ...@@ -139,6 +148,15 @@ func TestFeatureEnabled(t *testing.T) {
feature: "discover", feature: "discover",
expectEnabled: false, expectEnabled: false,
}, },
{
desc: "When the protocol is http and the feature enabled",
config: &Config{
GitlabUrl: "http://localhost:3000",
Migration: MigrationConfig{Enabled: true, Features: []string{"discover"}},
},
feature: "discover",
expectEnabled: true,
},
{ {
desc: "When the protocol is not supported", desc: "When the protocol is not supported",
config: &Config{ config: &Config{
......
package config
import (
"context"
"net"
"net/http"
"strings"
"time"
)
const (
socketBaseUrl = "http://unix"
UnixSocketProtocol = "http+unix://"
HttpProtocol = "http://"
defaultReadTimeoutSeconds = 300
)
type HttpClient struct {
HttpClient *http.Client
Host string
}
func (c *Config) GetHttpClient() *HttpClient {
if c.HttpClient != nil {
return c.HttpClient
}
var transport *http.Transport
var host string
if strings.HasPrefix(c.GitlabUrl, UnixSocketProtocol) {
transport, host = c.buildSocketTransport()
} else if strings.HasPrefix(c.GitlabUrl, HttpProtocol) {
transport, host = c.buildHttpTransport()
} else {
return nil
}
httpClient := &http.Client{
Transport: transport,
Timeout: c.readTimeout(),
}
client := &HttpClient{HttpClient: httpClient, Host: host}
c.HttpClient = client
return client
}
func (c *Config) buildSocketTransport() (*http.Transport, string) {
socketPath := strings.TrimPrefix(c.GitlabUrl, UnixSocketProtocol)
transport := &http.Transport{
DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
dialer := net.Dialer{}
return dialer.DialContext(ctx, "unix", socketPath)
},
}
return transport, socketBaseUrl
}
func (c *Config) buildHttpTransport() (*http.Transport, string) {
return &http.Transport{}, c.GitlabUrl
}
func (c *Config) readTimeout() time.Duration {
timeoutSeconds := c.HttpSettings.ReadTimeoutSeconds
if timeoutSeconds == 0 {
timeoutSeconds = defaultReadTimeoutSeconds
}
return time.Duration(timeoutSeconds) * time.Second
}
package config
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestReadTimeout(t *testing.T) {
expectedSeconds := uint64(300)
config := &Config{
GitlabUrl: "http://localhost:3000",
HttpSettings: HttpSettingsConfig{ReadTimeoutSeconds: expectedSeconds},
}
client := config.GetHttpClient()
require.NotNil(t, client)
assert.Equal(t, time.Duration(expectedSeconds)*time.Second, client.HttpClient.Timeout)
}
package gitlabnet package gitlabnet
import ( import (
"bytes"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io"
"net/http" "net/http"
"strings" "strings"
...@@ -15,22 +17,24 @@ const ( ...@@ -15,22 +17,24 @@ const (
secretHeaderName = "Gitlab-Shared-Secret" secretHeaderName = "Gitlab-Shared-Secret"
) )
type GitlabClient interface {
Get(path string) (*http.Response, error)
Post(path string, data interface{}) (*http.Response, error)
}
type ErrorResponse struct { type ErrorResponse struct {
Message string `json:"message"` Message string `json:"message"`
} }
func GetClient(config *config.Config) (GitlabClient, error) { type GitlabClient struct {
url := config.GitlabUrl httpClient *http.Client
if strings.HasPrefix(url, UnixSocketProtocol) { config *config.Config
return buildSocketClient(config), nil host string
} }
func GetClient(config *config.Config) (*GitlabClient, error) {
client := config.GetHttpClient()
if client == nil {
return nil, fmt.Errorf("Unsupported protocol") return nil, fmt.Errorf("Unsupported protocol")
}
return &GitlabClient{httpClient: client.HttpClient, config: config, host: client.Host}, nil
} }
func normalizePath(path string) string { func normalizePath(path string) string {
...@@ -44,6 +48,27 @@ func normalizePath(path string) string { ...@@ -44,6 +48,27 @@ func normalizePath(path string) string {
return path return path
} }
func newRequest(method, host, path string, data interface{}) (*http.Request, error) {
path = normalizePath(path)
var jsonReader io.Reader
if data != nil {
jsonData, err := json.Marshal(data)
if err != nil {
return nil, err
}
jsonReader = bytes.NewReader(jsonData)
}
request, err := http.NewRequest(method, host+path, jsonReader)
if err != nil {
return nil, err
}
return request, nil
}
func parseError(resp *http.Response) error { func parseError(resp *http.Response) error {
if resp.StatusCode >= 200 && resp.StatusCode <= 299 { if resp.StatusCode >= 200 && resp.StatusCode <= 299 {
return nil return nil
...@@ -59,11 +84,32 @@ func parseError(resp *http.Response) error { ...@@ -59,11 +84,32 @@ func parseError(resp *http.Response) error {
} }
func doRequest(client *http.Client, config *config.Config, request *http.Request) (*http.Response, error) { func (c *GitlabClient) Get(path string) (*http.Response, error) {
encodedSecret := base64.StdEncoding.EncodeToString([]byte(config.Secret)) return c.doRequest("GET", path, nil)
}
func (c *GitlabClient) Post(path string, data interface{}) (*http.Response, error) {
return c.doRequest("POST", path, data)
}
func (c *GitlabClient) doRequest(method, path string, data interface{}) (*http.Response, error) {
request, err := newRequest(method, c.host, path, data)
if err != nil {
return nil, err
}
user, password := c.config.HttpSettings.User, c.config.HttpSettings.Password
if user != "" && password != "" {
request.SetBasicAuth(user, password)
}
encodedSecret := base64.StdEncoding.EncodeToString([]byte(c.config.Secret))
request.Header.Set(secretHeaderName, encodedSecret) request.Header.Set(secretHeaderName, encodedSecret)
response, err := client.Do(request) request.Header.Add("Content-Type", "application/json")
request.Close = true
response, err := c.httpClient.Do(request)
if err != nil { if err != nil {
return nil, fmt.Errorf("Internal API unreachable") return nil, fmt.Errorf("Internal API unreachable")
} }
......
...@@ -61,37 +61,44 @@ func TestClients(t *testing.T) { ...@@ -61,37 +61,44 @@ func TestClients(t *testing.T) {
}, },
}, },
} }
testConfig := &config.Config{GitlabUrl: "http+unix://" + testserver.TestSocket, Secret: "sssh, it's a secret"}
testCases := []struct { testCases := []struct {
desc string desc string
client GitlabClient secret string
server func([]testserver.TestRequestHandler) (func(), error) server func([]testserver.TestRequestHandler) (func(), string, error)
}{ }{
{ {
desc: "Socket client", desc: "Socket client",
client: buildSocketClient(testConfig), secret: "sssh, it's a secret",
server: testserver.StartSocketHttpServer, server: testserver.StartSocketHttpServer,
}, },
{
desc: "Http client",
secret: "sssh, it's a secret",
server: testserver.StartHttpServer,
},
} }
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) { t.Run(tc.desc, func(t *testing.T) {
cleanup, err := tc.server(requests) cleanup, url, err := tc.server(requests)
defer cleanup() defer cleanup()
require.NoError(t, err) require.NoError(t, err)
testBrokenRequest(t, tc.client) client, err := GetClient(&config.Config{GitlabUrl: url, Secret: tc.secret})
testSuccessfulGet(t, tc.client) require.NoError(t, err)
testSuccessfulPost(t, tc.client)
testMissing(t, tc.client) testBrokenRequest(t, client)
testErrorMessage(t, tc.client) testSuccessfulGet(t, client)
testAuthenticationHeader(t, tc.client) testSuccessfulPost(t, client)
testMissing(t, client)
testErrorMessage(t, client)
testAuthenticationHeader(t, client)
}) })
} }
} }
func testSuccessfulGet(t *testing.T, client GitlabClient) { func testSuccessfulGet(t *testing.T, client *GitlabClient) {
t.Run("Successful get", func(t *testing.T) { t.Run("Successful get", func(t *testing.T) {
response, err := client.Get("/hello") response, err := client.Get("/hello")
defer response.Body.Close() defer response.Body.Close()
...@@ -105,7 +112,7 @@ func testSuccessfulGet(t *testing.T, client GitlabClient) { ...@@ -105,7 +112,7 @@ func testSuccessfulGet(t *testing.T, client GitlabClient) {
}) })
} }
func testSuccessfulPost(t *testing.T, client GitlabClient) { func testSuccessfulPost(t *testing.T, client *GitlabClient) {
t.Run("Successful Post", func(t *testing.T) { t.Run("Successful Post", func(t *testing.T) {
data := map[string]string{"key": "value"} data := map[string]string{"key": "value"}
...@@ -121,7 +128,7 @@ func testSuccessfulPost(t *testing.T, client GitlabClient) { ...@@ -121,7 +128,7 @@ func testSuccessfulPost(t *testing.T, client GitlabClient) {
}) })
} }
func testMissing(t *testing.T, client GitlabClient) { func testMissing(t *testing.T, client *GitlabClient) {
t.Run("Missing error for GET", func(t *testing.T) { t.Run("Missing error for GET", func(t *testing.T) {
response, err := client.Get("/missing") response, err := client.Get("/missing")
assert.EqualError(t, err, "Internal API error (404)") assert.EqualError(t, err, "Internal API error (404)")
...@@ -135,7 +142,7 @@ func testMissing(t *testing.T, client GitlabClient) { ...@@ -135,7 +142,7 @@ func testMissing(t *testing.T, client GitlabClient) {
}) })
} }
func testErrorMessage(t *testing.T, client GitlabClient) { func testErrorMessage(t *testing.T, client *GitlabClient) {
t.Run("Error with message for GET", func(t *testing.T) { t.Run("Error with message for GET", func(t *testing.T) {
response, err := client.Get("/error") response, err := client.Get("/error")
assert.EqualError(t, err, "Don't do that") assert.EqualError(t, err, "Don't do that")
...@@ -149,7 +156,7 @@ func testErrorMessage(t *testing.T, client GitlabClient) { ...@@ -149,7 +156,7 @@ func testErrorMessage(t *testing.T, client GitlabClient) {
}) })
} }
func testBrokenRequest(t *testing.T, client GitlabClient) { func testBrokenRequest(t *testing.T, client *GitlabClient) {
t.Run("Broken request for GET", func(t *testing.T) { t.Run("Broken request for GET", func(t *testing.T) {
response, err := client.Get("/broken") response, err := client.Get("/broken")
assert.EqualError(t, err, "Internal API unreachable") assert.EqualError(t, err, "Internal API unreachable")
...@@ -163,7 +170,7 @@ func testBrokenRequest(t *testing.T, client GitlabClient) { ...@@ -163,7 +170,7 @@ func testBrokenRequest(t *testing.T, client GitlabClient) {
}) })
} }
func testAuthenticationHeader(t *testing.T, client GitlabClient) { func testAuthenticationHeader(t *testing.T, client *GitlabClient) {
t.Run("Authentication headers for GET", func(t *testing.T) { t.Run("Authentication headers for GET", func(t *testing.T) {
response, err := client.Get("/auth") response, err := client.Get("/auth")
defer response.Body.Close() defer response.Body.Close()
......
...@@ -13,7 +13,7 @@ import ( ...@@ -13,7 +13,7 @@ import (
type Client struct { type Client struct {
config *config.Config config *config.Config
client gitlabnet.GitlabClient client *gitlabnet.GitlabClient
} }
type Response struct { type Response struct {
......
...@@ -15,12 +15,10 @@ import ( ...@@ -15,12 +15,10 @@ import (
) )
var ( var (
testConfig *config.Config
requests []testserver.TestRequestHandler requests []testserver.TestRequestHandler
) )
func init() { func init() {
testConfig = &config.Config{GitlabUrl: "http+unix://" + testserver.TestSocket}
requests = []testserver.TestRequestHandler{ requests = []testserver.TestRequestHandler{
{ {
Path: "/api/v4/internal/discover", Path: "/api/v4/internal/discover",
...@@ -121,10 +119,10 @@ func TestErrorResponses(t *testing.T) { ...@@ -121,10 +119,10 @@ func TestErrorResponses(t *testing.T) {
} }
func setup(t *testing.T) (*Client, func()) { func setup(t *testing.T) (*Client, func()) {
cleanup, err := testserver.StartSocketHttpServer(requests) cleanup, url, err := testserver.StartSocketHttpServer(requests)
require.NoError(t, err) require.NoError(t, err)
client, err := NewClient(testConfig) client, err := NewClient(&config.Config{GitlabUrl: url})
require.NoError(t, err) require.NoError(t, err)
return client, cleanup return client, cleanup
......
package gitlabnet
import (
"encoding/base64"
"fmt"
"io/ioutil"
"net/http"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/testserver"
)
const (
username = "basic_auth_user"
password = "basic_auth_password"
)
func TestBasicAuthSettings(t *testing.T) {
requests := []testserver.TestRequestHandler{
{
Path: "/api/v4/internal/get_endpoint",
Handler: func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodGet, r.Method)
fmt.Fprint(w, r.Header.Get("Authorization"))
},
},
{
Path: "/api/v4/internal/post_endpoint",
Handler: func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodPost, r.Method)
fmt.Fprint(w, r.Header.Get("Authorization"))
},
},
}
config := &config.Config{HttpSettings: config.HttpSettingsConfig{User: username, Password: password}}
client, cleanup := setup(t, config, requests)
defer cleanup()
response, err := client.Get("/get_endpoint")
require.NoError(t, err)
testBasicAuthHeaders(t, response)
response, err = client.Post("/post_endpoint", nil)
require.NoError(t, err)
testBasicAuthHeaders(t, response)
}
func testBasicAuthHeaders(t *testing.T, response *http.Response) {
defer response.Body.Close()
require.NotNil(t, response)
responseBody, err := ioutil.ReadAll(response.Body)
assert.NoError(t, err)
headerParts := strings.Split(string(responseBody), " ")
assert.Equal(t, "Basic", headerParts[0])
credentials, err := base64.StdEncoding.DecodeString(headerParts[1])
require.NoError(t, err)
assert.Equal(t, username+":"+password, string(credentials))
}
func TestEmptyBasicAuthSettings(t *testing.T) {
requests := []testserver.TestRequestHandler{
{
Path: "/api/v4/internal/empty_basic_auth",
Handler: func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "", r.Header.Get("Authorization"))
},
},
}
client, cleanup := setup(t, &config.Config{}, requests)
defer cleanup()
_, err := client.Get("/empty_basic_auth")
require.NoError(t, err)
}
func setup(t *testing.T, config *config.Config, requests []testserver.TestRequestHandler) (*GitlabClient, func()) {
cleanup, url, err := testserver.StartHttpServer(requests)
require.NoError(t, err)
config.GitlabUrl = url
client, err := GetClient(config)
require.NoError(t, err)
return client, cleanup
}
package gitlabnet
import (
"bytes"
"context"
"encoding/json"
"net"
"net/http"
"strings"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
)
const (
// We need to set the base URL to something starting with HTTP, the host
// itself is ignored as we're talking over a socket.
socketBaseUrl = "http://unix"
UnixSocketProtocol = "http+unix://"
)
type GitlabSocketClient struct {
httpClient *http.Client
config *config.Config
}
func buildSocketClient(config *config.Config) *GitlabSocketClient {
path := strings.TrimPrefix(config.GitlabUrl, UnixSocketProtocol)
httpClient := &http.Client{
Transport: &http.Transport{
DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
return net.Dial("unix", path)
},
},
}
return &GitlabSocketClient{httpClient: httpClient, config: config}
}
func (c *GitlabSocketClient) Get(path string) (*http.Response, error) {
path = normalizePath(path)
request, err := http.NewRequest("GET", socketBaseUrl+path, nil)
if err != nil {
return nil, err
}
return doRequest(c.httpClient, c.config, request)
}
func (c *GitlabSocketClient) Post(path string, data interface{}) (*http.Response, error) {
path = normalizePath(path)
jsonData, err := json.Marshal(data)
if err != nil {
return nil, err
}
request, err := http.NewRequest("POST", socketBaseUrl+path, bytes.NewReader(jsonData))
request.Header.Add("Content-Type", "application/json")
if err != nil {
return nil, err
}
return doRequest(c.httpClient, c.config, request)
}
...@@ -5,6 +5,7 @@ import ( ...@@ -5,6 +5,7 @@ import (
"log" "log"
"net" "net"
"net/http" "net/http"
"net/http/httptest"
"os" "os"
"path" "path"
"path/filepath" "path/filepath"
...@@ -12,7 +13,7 @@ import ( ...@@ -12,7 +13,7 @@ import (
var ( var (
tempDir, _ = ioutil.TempDir("", "gitlab-shell-test-api") tempDir, _ = ioutil.TempDir("", "gitlab-shell-test-api")
TestSocket = path.Join(tempDir, "internal.sock") testSocket = path.Join(tempDir, "internal.sock")
) )
type TestRequestHandler struct { type TestRequestHandler struct {
...@@ -20,14 +21,14 @@ type TestRequestHandler struct { ...@@ -20,14 +21,14 @@ type TestRequestHandler struct {
Handler func(w http.ResponseWriter, r *http.Request) Handler func(w http.ResponseWriter, r *http.Request)
} }
func StartSocketHttpServer(handlers []TestRequestHandler) (func(), error) { func StartSocketHttpServer(handlers []TestRequestHandler) (func(), string, error) {
if err := os.MkdirAll(filepath.Dir(TestSocket), 0700); err != nil { if err := os.MkdirAll(filepath.Dir(testSocket), 0700); err != nil {
return nil, err return nil, "", err
} }
socketListener, err := net.Listen("unix", TestSocket) socketListener, err := net.Listen("unix", testSocket)
if err != nil { if err != nil {
return nil, err return nil, "", err
} }
server := http.Server{ server := http.Server{
...@@ -38,7 +39,15 @@ func StartSocketHttpServer(handlers []TestRequestHandler) (func(), error) { ...@@ -38,7 +39,15 @@ func StartSocketHttpServer(handlers []TestRequestHandler) (func(), error) {
} }
go server.Serve(socketListener) go server.Serve(socketListener)
return cleanupSocket, nil url := "http+unix://" + testSocket
return cleanupSocket, url, nil
}
func StartHttpServer(handlers []TestRequestHandler) (func(), string, error) {
server := httptest.NewServer(buildHandler(handlers))
return server.Close, server.URL, nil
} }
func cleanupSocket() { func cleanupSocket() {
......
...@@ -15,7 +15,7 @@ import ( ...@@ -15,7 +15,7 @@ import (
type Client struct { type Client struct {
config *config.Config config *config.Config
client gitlabnet.GitlabClient client *gitlabnet.GitlabClient
} }
type Response struct { type Response struct {
......
...@@ -17,12 +17,10 @@ import ( ...@@ -17,12 +17,10 @@ import (
) )
var ( var (
testConfig *config.Config
requests []testserver.TestRequestHandler requests []testserver.TestRequestHandler
) )
func initialize(t *testing.T) { func initialize(t *testing.T) {
testConfig = &config.Config{GitlabUrl: "http+unix://" + testserver.TestSocket}
requests = []testserver.TestRequestHandler{ requests = []testserver.TestRequestHandler{
{ {
Path: "/api/v4/internal/two_factor_recovery_codes", Path: "/api/v4/internal/two_factor_recovery_codes",
...@@ -151,10 +149,10 @@ func TestErrorResponses(t *testing.T) { ...@@ -151,10 +149,10 @@ func TestErrorResponses(t *testing.T) {
func setup(t *testing.T) (*Client, func()) { func setup(t *testing.T) (*Client, func()) {
initialize(t) initialize(t)
cleanup, err := testserver.StartSocketHttpServer(requests) cleanup, url, err := testserver.StartSocketHttpServer(requests)
require.NoError(t, err) require.NoError(t, err)
client, err := NewClient(testConfig) client, err := NewClient(&config.Config{GitlabUrl: url})
require.NoError(t, err) require.NoError(t, err)
return client, cleanup return client, cleanup
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment