Commit 48142b98 authored by Nick Thomas's avatar Nick Thomas

Merge branch 'id-go-refactorings' into 'master'

Refactor execution and parsing logic in Go's implementation

See merge request gitlab-org/gitlab-shell!302
parents 58d8c769 b77a3752
...@@ -29,7 +29,7 @@ func findRootDir() (string, error) { ...@@ -29,7 +29,7 @@ func findRootDir() (string, error) {
func execRuby(rootDir string, readWriter *readwriter.ReadWriter) { func execRuby(rootDir string, readWriter *readwriter.ReadWriter) {
cmd := &fallback.Command{RootDir: rootDir, Args: os.Args} cmd := &fallback.Command{RootDir: rootDir, Args: os.Args}
if err := cmd.Execute(readWriter); err != nil { if err := cmd.Execute(); err != nil {
fmt.Fprintf(readWriter.ErrOut, "Failed to exec: %v\n", err) fmt.Fprintf(readWriter.ErrOut, "Failed to exec: %v\n", err)
os.Exit(1) os.Exit(1)
} }
...@@ -56,7 +56,7 @@ func main() { ...@@ -56,7 +56,7 @@ func main() {
execRuby(rootDir, readWriter) execRuby(rootDir, readWriter)
} }
cmd, err := command.New(os.Args, config) cmd, err := command.New(os.Args, config, readWriter)
if err != nil { if err != nil {
// For now this could happen if `SSH_CONNECTION` is not set on // For now this could happen if `SSH_CONNECTION` is not set on
// the environment // the environment
...@@ -66,7 +66,7 @@ func main() { ...@@ -66,7 +66,7 @@ func main() {
// The command will write to STDOUT on execution or replace the current // The command will write to STDOUT on execution or replace the current
// process in case of the `fallback.Command` // process in case of the `fallback.Command`
if err = cmd.Execute(readWriter); err != nil { if err = cmd.Execute(); err != nil {
fmt.Fprintf(readWriter.ErrOut, "%v\n", err) fmt.Fprintf(readWriter.ErrOut, "%v\n", err)
os.Exit(1) os.Exit(1)
} }
......
...@@ -10,10 +10,10 @@ import ( ...@@ -10,10 +10,10 @@ import (
) )
type Command interface { type Command interface {
Execute(*readwriter.ReadWriter) error Execute() error
} }
func New(arguments []string, config *config.Config) (Command, error) { func New(arguments []string, config *config.Config, readWriter *readwriter.ReadWriter) (Command, error) {
args, err := commandargs.Parse(arguments) args, err := commandargs.Parse(arguments)
if err != nil { if err != nil {
...@@ -21,18 +21,18 @@ func New(arguments []string, config *config.Config) (Command, error) { ...@@ -21,18 +21,18 @@ func New(arguments []string, config *config.Config) (Command, error) {
} }
if config.FeatureEnabled(string(args.CommandType)) { if config.FeatureEnabled(string(args.CommandType)) {
return buildCommand(args, config), nil return buildCommand(args, config, readWriter), nil
} }
return &fallback.Command{RootDir: config.RootDir, Args: arguments}, nil return &fallback.Command{RootDir: config.RootDir, Args: arguments}, nil
} }
func buildCommand(args *commandargs.CommandArgs, config *config.Config) Command { func buildCommand(args *commandargs.CommandArgs, config *config.Config, readWriter *readwriter.ReadWriter) Command {
switch args.CommandType { switch args.CommandType {
case commandargs.Discover: case commandargs.Discover:
return &discover.Command{Config: config, Args: args} return &discover.Command{Config: config, Args: args, ReadWriter: readWriter}
case commandargs.TwoFactorRecover: case commandargs.TwoFactorRecover:
return &twofactorrecover.Command{Config: config, Args: args} return &twofactorrecover.Command{Config: config, Args: args, ReadWriter: readWriter}
} }
return nil return nil
......
...@@ -65,7 +65,7 @@ func TestNew(t *testing.T) { ...@@ -65,7 +65,7 @@ func TestNew(t *testing.T) {
restoreEnv := testhelper.TempEnv(tc.environment) restoreEnv := testhelper.TempEnv(tc.environment)
defer restoreEnv() defer restoreEnv()
command, err := New(tc.arguments, tc.config) command, err := New(tc.arguments, tc.config, nil)
assert.NoError(t, err) assert.NoError(t, err)
assert.IsType(t, tc.expectedType, command) assert.IsType(t, tc.expectedType, command)
...@@ -78,7 +78,7 @@ func TestFailingNew(t *testing.T) { ...@@ -78,7 +78,7 @@ func TestFailingNew(t *testing.T) {
restoreEnv := testhelper.TempEnv(map[string]string{}) restoreEnv := testhelper.TempEnv(map[string]string{})
defer restoreEnv() defer restoreEnv()
_, err := New([]string{}, &config.Config{}) _, err := New([]string{}, &config.Config{}, nil)
assert.Error(t, err, "Only ssh allowed") assert.Error(t, err, "Only ssh allowed")
}) })
......
...@@ -12,18 +12,19 @@ import ( ...@@ -12,18 +12,19 @@ import (
type Command struct { type Command struct {
Config *config.Config Config *config.Config
Args *commandargs.CommandArgs Args *commandargs.CommandArgs
ReadWriter *readwriter.ReadWriter
} }
func (c *Command) Execute(readWriter *readwriter.ReadWriter) error { func (c *Command) Execute() error {
response, err := c.getUserInfo() response, err := c.getUserInfo()
if err != nil { if err != nil {
return fmt.Errorf("Failed to get username: %v", err) return fmt.Errorf("Failed to get username: %v", err)
} }
if response.IsAnonymous() { if response.IsAnonymous() {
fmt.Fprintf(readWriter.Out, "Welcome to GitLab, Anonymous!\n") fmt.Fprintf(c.ReadWriter.Out, "Welcome to GitLab, Anonymous!\n")
} else { } else {
fmt.Fprintf(readWriter.Out, "Welcome to GitLab, @%s!\n", response.Username) fmt.Fprintf(c.ReadWriter.Out, "Welcome to GitLab, @%s!\n", response.Username)
} }
return nil return nil
......
...@@ -78,10 +78,14 @@ func TestExecute(t *testing.T) { ...@@ -78,10 +78,14 @@ func TestExecute(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) { t.Run(tc.desc, func(t *testing.T) {
cmd := &Command{Config: &config.Config{GitlabUrl: url}, Args: tc.arguments}
buffer := &bytes.Buffer{} buffer := &bytes.Buffer{}
cmd := &Command{
Config: &config.Config{GitlabUrl: url},
Args: tc.arguments,
ReadWriter: &readwriter.ReadWriter{Out: buffer},
}
err := cmd.Execute(&readwriter.ReadWriter{Out: buffer}) err := cmd.Execute()
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, tc.expectedOutput, buffer.String()) assert.Equal(t, tc.expectedOutput, buffer.String())
...@@ -118,10 +122,14 @@ func TestFailingExecute(t *testing.T) { ...@@ -118,10 +122,14 @@ func TestFailingExecute(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) { t.Run(tc.desc, func(t *testing.T) {
cmd := &Command{Config: &config.Config{GitlabUrl: url}, Args: tc.arguments}
buffer := &bytes.Buffer{} buffer := &bytes.Buffer{}
cmd := &Command{
Config: &config.Config{GitlabUrl: url},
Args: tc.arguments,
ReadWriter: &readwriter.ReadWriter{Out: buffer},
}
err := cmd.Execute(&readwriter.ReadWriter{Out: buffer}) err := cmd.Execute()
assert.Empty(t, buffer.String()) assert.Empty(t, buffer.String())
assert.EqualError(t, err, tc.expectedError) assert.EqualError(t, err, tc.expectedError)
......
...@@ -4,8 +4,6 @@ import ( ...@@ -4,8 +4,6 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"syscall" "syscall"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/command/readwriter"
) )
type Command struct { type Command struct {
...@@ -22,7 +20,7 @@ const ( ...@@ -22,7 +20,7 @@ const (
RubyProgram = "gitlab-shell-ruby" RubyProgram = "gitlab-shell-ruby"
) )
func (c *Command) Execute(*readwriter.ReadWriter) error { func (c *Command) Execute() error {
rubyCmd := filepath.Join(c.RootDir, "bin", RubyProgram) rubyCmd := filepath.Join(c.RootDir, "bin", RubyProgram)
// Ensure rubyArgs[0] is the full path to gitlab-shell-ruby // Ensure rubyArgs[0] is the full path to gitlab-shell-ruby
......
...@@ -49,7 +49,7 @@ func TestExecuteExecsCommandSuccesfully(t *testing.T) { ...@@ -49,7 +49,7 @@ func TestExecuteExecsCommandSuccesfully(t *testing.T) {
fake.Setup() fake.Setup()
defer fake.Cleanup() defer fake.Cleanup()
require.NoError(t, cmd.Execute(nil)) require.NoError(t, cmd.Execute())
require.True(t, fake.Called) require.True(t, fake.Called)
require.Equal(t, fake.Filename, "/tmp/bin/gitlab-shell-ruby") require.Equal(t, fake.Filename, "/tmp/bin/gitlab-shell-ruby")
require.Equal(t, fake.Args, []string{"/tmp/bin/gitlab-shell-ruby", "foo", "bar"}) require.Equal(t, fake.Args, []string{"/tmp/bin/gitlab-shell-ruby", "foo", "bar"})
...@@ -64,12 +64,12 @@ func TestExecuteExecsCommandOnError(t *testing.T) { ...@@ -64,12 +64,12 @@ func TestExecuteExecsCommandOnError(t *testing.T) {
fake.Setup() fake.Setup()
defer fake.Cleanup() defer fake.Cleanup()
require.Error(t, cmd.Execute(nil)) require.Error(t, cmd.Execute())
require.True(t, fake.Called) require.True(t, fake.Called)
} }
func TestExecuteGivenNonexistentCommand(t *testing.T) { func TestExecuteGivenNonexistentCommand(t *testing.T) {
cmd := &Command{RootDir: "/tmp/does/not/exist", Args: fakeArgs} cmd := &Command{RootDir: "/tmp/does/not/exist", Args: fakeArgs}
require.Error(t, cmd.Execute(nil)) require.Error(t, cmd.Execute())
} }
...@@ -13,31 +13,32 @@ import ( ...@@ -13,31 +13,32 @@ import (
type Command struct { type Command struct {
Config *config.Config Config *config.Config
Args *commandargs.CommandArgs Args *commandargs.CommandArgs
ReadWriter *readwriter.ReadWriter
} }
func (c *Command) Execute(readWriter *readwriter.ReadWriter) error { func (c *Command) Execute() error {
if c.canContinue(readWriter) { if c.canContinue() {
c.displayRecoveryCodes(readWriter) c.displayRecoveryCodes()
} else { } else {
fmt.Fprintln(readWriter.Out, "\nNew recovery codes have *not* been generated. Existing codes will remain valid.") fmt.Fprintln(c.ReadWriter.Out, "\nNew recovery codes have *not* been generated. Existing codes will remain valid.")
} }
return nil return nil
} }
func (c *Command) canContinue(readWriter *readwriter.ReadWriter) bool { func (c *Command) canContinue() bool {
question := question :=
"Are you sure you want to generate new two-factor recovery codes?\n" + "Are you sure you want to generate new two-factor recovery codes?\n" +
"Any existing recovery codes you saved will be invalidated. (yes/no)" "Any existing recovery codes you saved will be invalidated. (yes/no)"
fmt.Fprintln(readWriter.Out, question) fmt.Fprintln(c.ReadWriter.Out, question)
var answer string var answer string
fmt.Fscanln(readWriter.In, &answer) fmt.Fscanln(c.ReadWriter.In, &answer)
return answer == "yes" return answer == "yes"
} }
func (c *Command) displayRecoveryCodes(readWriter *readwriter.ReadWriter) { func (c *Command) displayRecoveryCodes() {
codes, err := c.getRecoveryCodes() codes, err := c.getRecoveryCodes()
if err == nil { if err == nil {
...@@ -47,9 +48,9 @@ func (c *Command) displayRecoveryCodes(readWriter *readwriter.ReadWriter) { ...@@ -47,9 +48,9 @@ func (c *Command) displayRecoveryCodes(readWriter *readwriter.ReadWriter) {
"\n\nDuring sign in, use one of the codes above when prompted for\n" + "\n\nDuring sign in, use one of the codes above when prompted for\n" +
"your two-factor code. Then, visit your Profile Settings and add\n" + "your two-factor code. Then, visit your Profile Settings and add\n" +
"a new device so you do not lose access to your account again.\n" "a new device so you do not lose access to your account again.\n"
fmt.Fprint(readWriter.Out, messageWithCodes) fmt.Fprint(c.ReadWriter.Out, messageWithCodes)
} else { } else {
fmt.Fprintf(readWriter.Out, "\nAn error occurred while trying to generate new recovery codes.\n%v\n", err) fmt.Fprintf(c.ReadWriter.Out, "\nAn error occurred while trying to generate new recovery codes.\n%v\n", err)
} }
} }
......
...@@ -122,9 +122,13 @@ func TestExecute(t *testing.T) { ...@@ -122,9 +122,13 @@ func TestExecute(t *testing.T) {
output := &bytes.Buffer{} output := &bytes.Buffer{}
input := bytes.NewBufferString(tc.answer) input := bytes.NewBufferString(tc.answer)
cmd := &Command{Config: &config.Config{GitlabUrl: url}, Args: tc.arguments} cmd := &Command{
Config: &config.Config{GitlabUrl: url},
Args: tc.arguments,
ReadWriter: &readwriter.ReadWriter{Out: output, In: input},
}
err := cmd.Execute(&readwriter.ReadWriter{Out: output, In: input}) err := cmd.Execute()
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, tc.expectedOutput, output.String()) assert.Equal(t, tc.expectedOutput, output.String())
......
...@@ -17,6 +17,10 @@ const ( ...@@ -17,6 +17,10 @@ const (
secretHeaderName = "Gitlab-Shared-Secret" secretHeaderName = "Gitlab-Shared-Secret"
) )
var (
ParsingError = fmt.Errorf("Parsing failed")
)
type ErrorResponse struct { type ErrorResponse struct {
Message string `json:"message"` Message string `json:"message"`
} }
...@@ -120,3 +124,11 @@ func (c *GitlabClient) doRequest(method, path string, data interface{}) (*http.R ...@@ -120,3 +124,11 @@ func (c *GitlabClient) doRequest(method, path string, data interface{}) (*http.R
return response, nil return response, nil
} }
func ParseJSON(hr *http.Response, response interface{}) error {
if err := json.NewDecoder(hr.Body).Decode(response); err != nil {
return ParsingError
}
return nil
}
package discover package discover
import ( import (
"encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"net/url" "net/url"
...@@ -32,56 +31,39 @@ func NewClient(config *config.Config) (*Client, error) { ...@@ -32,56 +31,39 @@ func NewClient(config *config.Config) (*Client, error) {
} }
func (c *Client) GetByCommandArgs(args *commandargs.CommandArgs) (*Response, error) { func (c *Client) GetByCommandArgs(args *commandargs.CommandArgs) (*Response, error) {
if args.GitlabKeyId != "" { params := url.Values{}
return c.GetByKeyId(args.GitlabKeyId) if args.GitlabUsername != "" {
} else if args.GitlabUsername != "" { params.Add("username", args.GitlabUsername)
return c.GetByUsername(args.GitlabUsername) } else if args.GitlabKeyId != "" {
params.Add("key_id", args.GitlabKeyId)
} else { } else {
// There was no 'who' information, this matches the ruby error // There was no 'who' information, this matches the ruby error
// message. // message.
return nil, fmt.Errorf("who='' is invalid") return nil, fmt.Errorf("who='' is invalid")
} }
}
func (c *Client) GetByKeyId(keyId string) (*Response, error) {
params := url.Values{}
params.Add("key_id", keyId)
return c.getResponse(params)
}
func (c *Client) GetByUsername(username string) (*Response, error) {
params := url.Values{}
params.Add("username", username)
return c.getResponse(params) return c.getResponse(params)
} }
func (c *Client) parseResponse(resp *http.Response) (*Response, error) {
parsedResponse := &Response{}
if err := json.NewDecoder(resp.Body).Decode(parsedResponse); err != nil {
return nil, err
} else {
return parsedResponse, nil
}
}
func (c *Client) getResponse(params url.Values) (*Response, error) { func (c *Client) getResponse(params url.Values) (*Response, error) {
path := "/discover?" + params.Encode() path := "/discover?" + params.Encode()
response, err := c.client.Get(path)
response, err := c.client.Get(path)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer response.Body.Close() defer response.Body.Close()
parsedResponse, err := c.parseResponse(response)
if err != nil { return parse(response)
return nil, fmt.Errorf("Parsing failed") }
func parse(hr *http.Response) (*Response, error) {
response := &Response{}
if err := gitlabnet.ParseJSON(hr, response); err != nil {
return nil, err
} }
return parsedResponse, nil return response, nil
} }
func (r *Response) IsAnonymous() bool { func (r *Response) IsAnonymous() bool {
......
...@@ -4,6 +4,7 @@ import ( ...@@ -4,6 +4,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"net/url"
"testing" "testing"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/config" "gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
...@@ -59,7 +60,9 @@ func TestGetByKeyId(t *testing.T) { ...@@ -59,7 +60,9 @@ func TestGetByKeyId(t *testing.T) {
client, cleanup := setup(t) client, cleanup := setup(t)
defer cleanup() defer cleanup()
result, err := client.GetByKeyId("1") params := url.Values{}
params.Add("key_id", "1")
result, err := client.getResponse(params)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, &Response{UserId: 2, Username: "alex-doe", Name: "Alex Doe"}, result) assert.Equal(t, &Response{UserId: 2, Username: "alex-doe", Name: "Alex Doe"}, result)
} }
...@@ -68,7 +71,9 @@ func TestGetByUsername(t *testing.T) { ...@@ -68,7 +71,9 @@ func TestGetByUsername(t *testing.T) {
client, cleanup := setup(t) client, cleanup := setup(t)
defer cleanup() defer cleanup()
result, err := client.GetByUsername("jane-doe") params := url.Values{}
params.Add("username", "jane-doe")
result, err := client.getResponse(params)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, &Response{UserId: 1, Username: "jane-doe", Name: "Jane Doe"}, result) assert.Equal(t, &Response{UserId: 1, Username: "jane-doe", Name: "Jane Doe"}, result)
} }
...@@ -77,7 +82,9 @@ func TestMissingUser(t *testing.T) { ...@@ -77,7 +82,9 @@ func TestMissingUser(t *testing.T) {
client, cleanup := setup(t) client, cleanup := setup(t)
defer cleanup() defer cleanup()
result, err := client.GetByUsername("missing") params := url.Values{}
params.Add("username", "missing")
result, err := client.getResponse(params)
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, result.IsAnonymous()) assert.True(t, result.IsAnonymous())
} }
...@@ -110,7 +117,9 @@ func TestErrorResponses(t *testing.T) { ...@@ -110,7 +117,9 @@ func TestErrorResponses(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) { t.Run(tc.desc, func(t *testing.T) {
resp, err := client.GetByUsername(tc.fakeUsername) params := url.Values{}
params.Add("username", tc.fakeUsername)
resp, err := client.getResponse(params)
assert.EqualError(t, err, tc.expectedError) assert.EqualError(t, err, tc.expectedError)
assert.Nil(t, resp) assert.Nil(t, resp)
......
package twofactorrecover package twofactorrecover
import ( import (
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"io/ioutil"
"net/http" "net/http"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs" "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs"
...@@ -46,38 +44,25 @@ func (c *Client) GetRecoveryCodes(args *commandargs.CommandArgs) ([]string, erro ...@@ -46,38 +44,25 @@ func (c *Client) GetRecoveryCodes(args *commandargs.CommandArgs) ([]string, erro
} }
response, err := c.client.Post("/two_factor_recovery_codes", requestBody) response, err := c.client.Post("/two_factor_recovery_codes", requestBody)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer response.Body.Close() defer response.Body.Close()
parsedResponse, err := c.parseResponse(response)
if err != nil {
return nil, fmt.Errorf("Parsing failed")
}
if parsedResponse.Success { return parse(response)
return parsedResponse.RecoveryCodes, nil
} else {
return nil, errors.New(parsedResponse.Message)
}
} }
func (c *Client) parseResponse(resp *http.Response) (*Response, error) { func parse(hr *http.Response) ([]string, error) {
parsedResponse := &Response{} response := &Response{}
body, err := ioutil.ReadAll(resp.Body) if err := gitlabnet.ParseJSON(hr, response); err != nil {
if err != nil {
return nil, err return nil, err
} }
if err := json.Unmarshal(body, parsedResponse); err != nil { if !response.Success {
return nil, err return nil, errors.New(response.Message)
} else {
return parsedResponse, nil
} }
return response.RecoveryCodes, nil
} }
func (c *Client) getRequestBody(args *commandargs.CommandArgs) (*RequestBody, error) { func (c *Client) getRequestBody(args *commandargs.CommandArgs) (*RequestBody, error) {
......
...@@ -114,7 +114,7 @@ describe 'bin/gitlab-shell 2fa_recovery_codes' do ...@@ -114,7 +114,7 @@ describe 'bin/gitlab-shell 2fa_recovery_codes' do
it_behaves_like 'dialog for regenerating recovery keys' it_behaves_like 'dialog for regenerating recovery keys'
end end
describe 'with go features' do describe 'with go features', :go do
before(:context) do before(:context) do
write_config( write_config(
"gitlab_url" => "http+unix://#{CGI.escape(tmp_socket_path)}", "gitlab_url" => "http+unix://#{CGI.escape(tmp_socket_path)}",
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment