Commit 9718a465 authored by Mitchell Hashimoto's avatar Mitchell Hashimoto

communicator/ssh: have a Connection func so we can re-establish

[GH-152]
parent db644c91
...@@ -9,13 +9,12 @@ import ( ...@@ -9,13 +9,12 @@ import (
"github.com/mitchellh/packer/communicator/ssh" "github.com/mitchellh/packer/communicator/ssh"
"github.com/mitchellh/packer/packer" "github.com/mitchellh/packer/packer"
"log" "log"
"net"
"time" "time"
) )
type stepConnectSSH struct { type stepConnectSSH struct {
cancel bool cancel bool
conn net.Conn comm packer.Communicator
} }
func (s *stepConnectSSH) Run(state map[string]interface{}) multistep.StepAction { func (s *stepConnectSSH) Run(state map[string]interface{}) multistep.StepAction {
...@@ -45,6 +44,7 @@ WaitLoop: ...@@ -45,6 +44,7 @@ WaitLoop:
return multistep.ActionHalt return multistep.ActionHalt
} }
s.comm = comm
state["communicator"] = comm state["communicator"] = comm
break WaitLoop break WaitLoop
case <-timeout: case <-timeout:
...@@ -63,9 +63,9 @@ WaitLoop: ...@@ -63,9 +63,9 @@ WaitLoop:
} }
func (s *stepConnectSSH) Cleanup(map[string]interface{}) { func (s *stepConnectSSH) Cleanup(map[string]interface{}) {
if s.conn != nil { if s.comm != nil {
s.conn.Close() // Close it TODO
s.conn = nil s.comm = nil
} }
} }
...@@ -85,14 +85,13 @@ func (s *stepConnectSSH) waitForSSH(state map[string]interface{}) (packer.Commun ...@@ -85,14 +85,13 @@ func (s *stepConnectSSH) waitForSSH(state map[string]interface{}) (packer.Commun
return nil, fmt.Errorf("Error setting up SSH config: %s", err) return nil, fmt.Errorf("Error setting up SSH config: %s", err)
} }
// Create the function that will be used to create the connection
connFunc := ssh.ConnectFunc(
"tcp", fmt.Sprintf("%s:%d", instance.DNSName, config.SSHPort))
ui.Say("Waiting for SSH to become available...") ui.Say("Waiting for SSH to become available...")
var comm packer.Communicator var comm packer.Communicator
var nc net.Conn
for { for {
if nc != nil {
nc.Close()
}
time.Sleep(5 * time.Second) time.Sleep(5 * time.Second)
if s.cancel { if s.cancel {
...@@ -100,28 +99,29 @@ func (s *stepConnectSSH) waitForSSH(state map[string]interface{}) (packer.Commun ...@@ -100,28 +99,29 @@ func (s *stepConnectSSH) waitForSSH(state map[string]interface{}) (packer.Commun
return nil, errors.New("SSH wait cancelled") return nil, errors.New("SSH wait cancelled")
} }
// Attempt to connect to SSH port // First just attempt a normal TCP connection that we close right
log.Printf( // away. We just test this in order to wait for the TCP port to be ready.
"Opening TCP conn for SSH to %s:%d", nc, err := connFunc()
instance.DNSName, config.SSHPort)
nc, err := net.Dial("tcp",
fmt.Sprintf("%s:%d", instance.DNSName, config.SSHPort))
if err != nil { if err != nil {
log.Printf("TCP connection to SSH ip/port failed: %s", err) log.Printf("TCP connection to SSH ip/port failed: %s", err)
continue continue
} }
nc.Close()
// Build the actual SSH client configuration // Build the configuration to connect to SSH
sshConfig := &gossh.ClientConfig{ config := &ssh.Config{
Connection: connFunc,
SSHConfig: &gossh.ClientConfig{
User: config.SSHUsername, User: config.SSHUsername,
Auth: []gossh.ClientAuth{ Auth: []gossh.ClientAuth{
gossh.ClientAuthKeyring(keyring), gossh.ClientAuthKeyring(keyring),
}, },
},
} }
sshConnectSuccess := make(chan bool, 1) sshConnectSuccess := make(chan bool, 1)
go func() { go func() {
comm, err = ssh.New(nc, sshConfig) comm, err = ssh.New(config)
if err != nil { if err != nil {
log.Printf("SSH connection fail: %s", err) log.Printf("SSH connection fail: %s", err)
sshConnectSuccess <- false sshConnectSuccess <- false
...@@ -145,7 +145,5 @@ func (s *stepConnectSSH) waitForSSH(state map[string]interface{}) (packer.Commun ...@@ -145,7 +145,5 @@ func (s *stepConnectSSH) waitForSSH(state map[string]interface{}) (packer.Commun
break break
} }
// Store the connection so we can close it later
s.conn = nc
return comm, nil return comm, nil
} }
...@@ -8,12 +8,11 @@ import ( ...@@ -8,12 +8,11 @@ import (
"github.com/mitchellh/packer/communicator/ssh" "github.com/mitchellh/packer/communicator/ssh"
"github.com/mitchellh/packer/packer" "github.com/mitchellh/packer/packer"
"log" "log"
"net"
"time" "time"
) )
type stepConnectSSH struct { type stepConnectSSH struct {
conn net.Conn comm packer.Communicator
} }
func (s *stepConnectSSH) Run(state map[string]interface{}) multistep.StepAction { func (s *stepConnectSSH) Run(state map[string]interface{}) multistep.StepAction {
...@@ -33,12 +32,17 @@ func (s *stepConnectSSH) Run(state map[string]interface{}) multistep.StepAction ...@@ -33,12 +32,17 @@ func (s *stepConnectSSH) Run(state map[string]interface{}) multistep.StepAction
return multistep.ActionHalt return multistep.ActionHalt
} }
connFunc := ssh.ConnectFunc("tcp", fmt.Sprintf("%s:%d", ipAddress, config.SSHPort))
// Build the actual SSH client configuration // Build the actual SSH client configuration
sshConfig := &gossh.ClientConfig{ sshConfig := &ssh.Config{
Connection: connFunc,
SSHConfig: &gossh.ClientConfig{
User: config.SSHUsername, User: config.SSHUsername,
Auth: []gossh.ClientAuth{ Auth: []gossh.ClientAuth{
gossh.ClientAuthKeyring(keyring), gossh.ClientAuthKeyring(keyring),
}, },
},
} }
// Start trying to connect to SSH // Start trying to connect to SSH
...@@ -50,8 +54,6 @@ func (s *stepConnectSSH) Run(state map[string]interface{}) multistep.StepAction ...@@ -50,8 +54,6 @@ func (s *stepConnectSSH) Run(state map[string]interface{}) multistep.StepAction
var comm packer.Communicator var comm packer.Communicator
go func() { go func() {
var err error
ui.Say("Connecting to the droplet via SSH...") ui.Say("Connecting to the droplet via SSH...")
attempts := 0 attempts := 0
handshakeAttempts := 0 handshakeAttempts := 0
...@@ -62,17 +64,19 @@ func (s *stepConnectSSH) Run(state map[string]interface{}) multistep.StepAction ...@@ -62,17 +64,19 @@ func (s *stepConnectSSH) Run(state map[string]interface{}) multistep.StepAction
default: default:
} }
// A brief sleep so we're not being overly zealous attempting
// to connect to the instance.
time.Sleep(500 * time.Millisecond)
attempts += 1 attempts += 1
log.Printf( nc, err := connFunc()
"Opening TCP conn for SSH to %s:%d (attempt %d)", if err != nil {
ipAddress, config.SSHPort, attempts) continue
s.conn, err = net.DialTimeout( }
"tcp", nc.Close()
fmt.Sprintf("%s:%d", ipAddress, config.SSHPort),
10*time.Second)
if err == nil {
log.Println("TCP connection made. Attempting SSH handshake.") log.Println("TCP connection made. Attempting SSH handshake.")
comm, err = ssh.New(s.conn, sshConfig) comm, err = ssh.New(sshConfig)
if err == nil { if err == nil {
log.Println("Connected to SSH!") log.Println("Connected to SSH!")
break break
...@@ -87,11 +91,6 @@ func (s *stepConnectSSH) Run(state map[string]interface{}) multistep.StepAction ...@@ -87,11 +91,6 @@ func (s *stepConnectSSH) Run(state map[string]interface{}) multistep.StepAction
} }
} }
// A brief sleep so we're not being overly zealous attempting
// to connect to the instance.
time.Sleep(500 * time.Millisecond)
}
connected <- nil connected <- nil
}() }()
...@@ -125,13 +124,15 @@ ConnectWaitLoop: ...@@ -125,13 +124,15 @@ ConnectWaitLoop:
} }
// Set the communicator on the state bag so it can be used later // Set the communicator on the state bag so it can be used later
s.comm = comm
state["communicator"] = comm state["communicator"] = comm
return multistep.ActionContinue return multistep.ActionContinue
} }
func (s *stepConnectSSH) Cleanup(map[string]interface{}) { func (s *stepConnectSSH) Cleanup(map[string]interface{}) {
if s.conn != nil { if s.comm != nil {
s.conn.Close() // TODO: close
s.comm = nil
} }
} }
...@@ -8,7 +8,6 @@ import ( ...@@ -8,7 +8,6 @@ import (
"github.com/mitchellh/packer/communicator/ssh" "github.com/mitchellh/packer/communicator/ssh"
"github.com/mitchellh/packer/packer" "github.com/mitchellh/packer/packer"
"log" "log"
"net"
"time" "time"
) )
...@@ -24,7 +23,7 @@ import ( ...@@ -24,7 +23,7 @@ import (
// communicator packer.Communicator // communicator packer.Communicator
type stepWaitForSSH struct { type stepWaitForSSH struct {
cancel bool cancel bool
conn net.Conn comm packer.Communicator
} }
func (s *stepWaitForSSH) Run(state map[string]interface{}) multistep.StepAction { func (s *stepWaitForSSH) Run(state map[string]interface{}) multistep.StepAction {
...@@ -54,6 +53,7 @@ WaitLoop: ...@@ -54,6 +53,7 @@ WaitLoop:
return multistep.ActionHalt return multistep.ActionHalt
} }
s.comm = comm
state["communicator"] = comm state["communicator"] = comm
break WaitLoop break WaitLoop
case <-timeout: case <-timeout:
...@@ -72,9 +72,9 @@ WaitLoop: ...@@ -72,9 +72,9 @@ WaitLoop:
} }
func (s *stepWaitForSSH) Cleanup(map[string]interface{}) { func (s *stepWaitForSSH) Cleanup(map[string]interface{}) {
if s.conn != nil { if s.comm != nil {
s.conn.Close() // TODO: close
s.conn = nil s.comm = nil
} }
} }
...@@ -85,14 +85,11 @@ func (s *stepWaitForSSH) waitForSSH(state map[string]interface{}) (packer.Commun ...@@ -85,14 +85,11 @@ func (s *stepWaitForSSH) waitForSSH(state map[string]interface{}) (packer.Commun
ui := state["ui"].(packer.Ui) ui := state["ui"].(packer.Ui)
sshHostPort := state["sshHostPort"].(uint) sshHostPort := state["sshHostPort"].(uint)
connFunc := ssh.ConnectFunc("tcp", fmt.Sprintf("127.0.0.1:%d", sshHostPort))
ui.Say("Waiting for SSH to become available...") ui.Say("Waiting for SSH to become available...")
var comm packer.Communicator var comm packer.Communicator
var nc net.Conn
for { for {
if nc != nil {
nc.Close()
}
time.Sleep(5 * time.Second) time.Sleep(5 * time.Second)
if s.cancel { if s.cancel {
...@@ -101,25 +98,29 @@ func (s *stepWaitForSSH) waitForSSH(state map[string]interface{}) (packer.Commun ...@@ -101,25 +98,29 @@ func (s *stepWaitForSSH) waitForSSH(state map[string]interface{}) (packer.Commun
} }
// Attempt to connect to SSH port // Attempt to connect to SSH port
nc, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", sshHostPort)) nc, err := connFunc()
if err != nil { if err != nil {
log.Printf("TCP connection to SSH ip/port failed: %s", err) log.Printf("TCP connection to SSH ip/port failed: %s", err)
continue continue
} }
nc.Close()
// Then we attempt to connect via SSH // Then we attempt to connect via SSH
sshConfig := &gossh.ClientConfig{ config := &ssh.Config{
Connection: connFunc,
SSHConfig: &gossh.ClientConfig{
User: config.SSHUser, User: config.SSHUser,
Auth: []gossh.ClientAuth{ Auth: []gossh.ClientAuth{
gossh.ClientAuthPassword(ssh.Password(config.SSHPassword)), gossh.ClientAuthPassword(ssh.Password(config.SSHPassword)),
gossh.ClientAuthKeyboardInteractive( gossh.ClientAuthKeyboardInteractive(
ssh.PasswordKeyboardInteractive(config.SSHPassword)), ssh.PasswordKeyboardInteractive(config.SSHPassword)),
}, },
},
} }
sshConnectSuccess := make(chan bool, 1) sshConnectSuccess := make(chan bool, 1)
go func() { go func() {
comm, err = ssh.New(nc, sshConfig) comm, err = ssh.New(config)
if err != nil { if err != nil {
log.Printf("SSH connection fail: %s", err) log.Printf("SSH connection fail: %s", err)
sshConnectSuccess <- false sshConnectSuccess <- false
...@@ -143,7 +144,5 @@ func (s *stepWaitForSSH) waitForSSH(state map[string]interface{}) (packer.Commun ...@@ -143,7 +144,5 @@ func (s *stepWaitForSSH) waitForSSH(state map[string]interface{}) (packer.Commun
break break
} }
// Store the connection so we can close it later
s.conn = nc
return comm, nil return comm, nil
} }
...@@ -9,7 +9,6 @@ import ( ...@@ -9,7 +9,6 @@ import (
"github.com/mitchellh/packer/packer" "github.com/mitchellh/packer/packer"
"io/ioutil" "io/ioutil"
"log" "log"
"net"
"os" "os"
"time" "time"
) )
...@@ -26,7 +25,7 @@ import ( ...@@ -26,7 +25,7 @@ import (
// communicator packer.Communicator // communicator packer.Communicator
type stepWaitForSSH struct { type stepWaitForSSH struct {
cancel bool cancel bool
conn net.Conn comm packer.Communicator
} }
func (s *stepWaitForSSH) Run(state map[string]interface{}) multistep.StepAction { func (s *stepWaitForSSH) Run(state map[string]interface{}) multistep.StepAction {
...@@ -56,6 +55,7 @@ WaitLoop: ...@@ -56,6 +55,7 @@ WaitLoop:
return multistep.ActionHalt return multistep.ActionHalt
} }
s.comm = comm
state["communicator"] = comm state["communicator"] = comm
break WaitLoop break WaitLoop
case <-timeout: case <-timeout:
...@@ -74,9 +74,9 @@ WaitLoop: ...@@ -74,9 +74,9 @@ WaitLoop:
} }
func (s *stepWaitForSSH) Cleanup(map[string]interface{}) { func (s *stepWaitForSSH) Cleanup(map[string]interface{}) {
if s.conn != nil { if s.comm != nil {
s.conn.Close() // TODO: close
s.conn = nil s.comm = nil
} }
} }
...@@ -117,12 +117,7 @@ func (s *stepWaitForSSH) waitForSSH(state map[string]interface{}) (packer.Commun ...@@ -117,12 +117,7 @@ func (s *stepWaitForSSH) waitForSSH(state map[string]interface{}) (packer.Commun
ui.Say("Waiting for SSH to become available...") ui.Say("Waiting for SSH to become available...")
var comm packer.Communicator var comm packer.Communicator
var nc net.Conn
for { for {
if nc != nil {
nc.Close()
}
time.Sleep(5 * time.Second) time.Sleep(5 * time.Second)
if s.cancel { if s.cancel {
...@@ -146,23 +141,28 @@ func (s *stepWaitForSSH) waitForSSH(state map[string]interface{}) (packer.Commun ...@@ -146,23 +141,28 @@ func (s *stepWaitForSSH) waitForSSH(state map[string]interface{}) (packer.Commun
log.Printf("Detected IP: %s", ip) log.Printf("Detected IP: %s", ip)
// Attempt to connect to SSH port // Attempt to connect to SSH port
nc, err = net.Dial("tcp", fmt.Sprintf("%s:%d", ip, config.SSHPort)) connFunc := ssh.ConnectFunc("tcp", fmt.Sprintf("%s:%d", ip, config.SSHPort))
nc, err := connFunc()
if err != nil { if err != nil {
log.Printf("TCP connection to SSH ip/port failed: %s", err) log.Printf("TCP connection to SSH ip/port failed: %s", err)
continue continue
} }
nc.Close()
// Then we attempt to connect via SSH // Then we attempt to connect via SSH
sshConfig := &gossh.ClientConfig{ config := &ssh.Config{
Connection: connFunc,
SSHConfig: &gossh.ClientConfig{
User: config.SSHUser, User: config.SSHUser,
Auth: []gossh.ClientAuth{ Auth: []gossh.ClientAuth{
gossh.ClientAuthPassword(ssh.Password(config.SSHPassword)), gossh.ClientAuthPassword(ssh.Password(config.SSHPassword)),
gossh.ClientAuthKeyboardInteractive( gossh.ClientAuthKeyboardInteractive(
ssh.PasswordKeyboardInteractive(config.SSHPassword)), ssh.PasswordKeyboardInteractive(config.SSHPassword)),
}, },
},
} }
comm, err = ssh.New(nc, sshConfig) comm, err = ssh.New(config)
if err != nil { if err != nil {
log.Printf("SSH handshake err: %s", err) log.Printf("SSH handshake err: %s", err)
...@@ -179,7 +179,5 @@ func (s *stepWaitForSSH) waitForSSH(state map[string]interface{}) (packer.Commun ...@@ -179,7 +179,5 @@ func (s *stepWaitForSSH) waitForSSH(state map[string]interface{}) (packer.Commun
break break
} }
// Store the connection so we can close it later
s.conn = nc
return comm, nil return comm, nil
} }
...@@ -14,13 +14,34 @@ import ( ...@@ -14,13 +14,34 @@ import (
type comm struct { type comm struct {
client *ssh.ClientConn client *ssh.ClientConn
config *Config
conn net.Conn
}
// Config is the structure used to configure the SSH communicator.
type Config struct {
// The configuration of the Go SSH connection
SSHConfig *ssh.ClientConfig
// Connection returns a new connection. The current connection
// in use will be closed as part of the Close method, or in the
// case an error occurs.
Connection func() (net.Conn, error)
} }
// Creates a new packer.Communicator implementation over SSH. This takes // Creates a new packer.Communicator implementation over SSH. This takes
// an already existing TCP connection and SSH configuration. // an already existing TCP connection and SSH configuration.
func New(c net.Conn, config *ssh.ClientConfig) (result *comm, err error) { func New(config *Config) (result *comm, err error) {
client, err := ssh.Client(c, config) // Establish an initial connection and connect
result = &comm{client} result = &comm{
config: config,
}
if err = result.reconnect(); err != nil {
result = nil
return
}
return return
} }
...@@ -168,3 +189,17 @@ func (c *comm) Upload(path string, input io.Reader) error { ...@@ -168,3 +189,17 @@ func (c *comm) Upload(path string, input io.Reader) error {
func (c *comm) Download(string, io.Writer) error { func (c *comm) Download(string, io.Writer) error {
panic("not implemented yet") panic("not implemented yet")
} }
func (c *comm) reconnect() (err error) {
if c.conn != nil {
c.conn.Close()
}
c.conn, err = c.config.Connection()
if err != nil {
return
}
c.client, err = ssh.Client(c.conn, c.config.SSHConfig)
return
}
...@@ -115,12 +115,20 @@ func TestNew_Invalid(t *testing.T) { ...@@ -115,12 +115,20 @@ func TestNew_Invalid(t *testing.T) {
}, },
} }
conn := func() (net.Conn, error) {
conn, err := net.Dial("tcp", newMockLineServer(t)) conn, err := net.Dial("tcp", newMockLineServer(t))
if err != nil { if err != nil {
t.Fatalf("unable to dial to remote side: %s", err) t.Fatalf("unable to dial to remote side: %s", err)
} }
return conn, err
}
config := &Config{
Connection: conn,
SSHConfig: clientConfig,
}
_, err = New(conn, clientConfig) _, err := New(config)
if err == nil { if err == nil {
t.Fatal("should have had an error connecting") t.Fatal("should have had an error connecting")
} }
...@@ -134,12 +142,20 @@ func TestStart(t *testing.T) { ...@@ -134,12 +142,20 @@ func TestStart(t *testing.T) {
}, },
} }
conn := func() (net.Conn, error) {
conn, err := net.Dial("tcp", newMockLineServer(t)) conn, err := net.Dial("tcp", newMockLineServer(t))
if err != nil { if err != nil {
t.Fatalf("unable to dial to remote side: %s", err) t.Fatalf("unable to dial to remote side: %s", err)
} }
return conn, err
}
config := &Config{
Connection: conn,
SSHConfig: clientConfig,
}
client, err := New(conn, clientConfig) client, err := New(config)
if err != nil { if err != nil {
t.Fatalf("error connecting to SSH: %s", err) t.Fatalf("error connecting to SSH: %s", err)
} }
......
package ssh
import (
"log"
"net"
)
// ConnectFunc is a convenience method for returning a function
// that just uses net.Dial to communicate with the remote end that
// is suitable for use with the SSH communicator configuration.
func ConnectFunc(network, addr string) func() (net.Conn, error) {
return func() (net.Conn, error) {
log.Printf("Opening conn for SSH to %s %s", network, addr)
return net.Dial(network, addr)
}
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment