Commit 3a415008 authored by Mitchell Hashimoto's avatar Mitchell Hashimoto

packer/rpc: more robust communicator connection cleanup

parent 4c5d6170
...@@ -95,6 +95,7 @@ func (c *communicator) Start(cmd *packer.RemoteCmd) (err error) { ...@@ -95,6 +95,7 @@ func (c *communicator) Start(cmd *packer.RemoteCmd) (err error) {
return return
} }
log.Printf("[INFO] RPC client: Communicator ended with: %d", finished.ExitStatus)
cmd.SetExited(finished.ExitStatus) cmd.SetExited(finished.ExitStatus)
}() }()
...@@ -146,17 +147,28 @@ func (c *communicator) Download(path string, w io.Writer) (err error) { ...@@ -146,17 +147,28 @@ func (c *communicator) Download(path string, w io.Writer) (err error) {
return return
} }
func (c *CommunicatorServer) Start(args *CommunicatorStartArgs, reply *interface{}) (err error) { func (c *CommunicatorServer) Start(args *CommunicatorStartArgs, reply *interface{}) (error) {
// Build the RemoteCmd on this side so that it all pipes over // Build the RemoteCmd on this side so that it all pipes over
// to the remote side. // to the remote side.
var cmd packer.RemoteCmd var cmd packer.RemoteCmd
cmd.Command = args.Command cmd.Command = args.Command
// Create a channel to signal we're done so that we can close
// our stdin/stdout/stderr streams
toClose := make([]io.Closer, 0) toClose := make([]io.Closer, 0)
doneCh := make(chan struct{})
go func() {
<-doneCh
for _, conn := range toClose {
defer conn.Close()
}
}()
if args.StdinStreamId > 0 { if args.StdinStreamId > 0 {
conn, err := c.mux.Dial(args.StdinStreamId) conn, err := c.mux.Dial(args.StdinStreamId)
if err != nil { if err != nil {
return err close(doneCh)
return NewBasicError(err)
} }
toClose = append(toClose, conn) toClose = append(toClose, conn)
...@@ -166,7 +178,8 @@ func (c *CommunicatorServer) Start(args *CommunicatorStartArgs, reply *interface ...@@ -166,7 +178,8 @@ func (c *CommunicatorServer) Start(args *CommunicatorStartArgs, reply *interface
if args.StdoutStreamId > 0 { if args.StdoutStreamId > 0 {
conn, err := c.mux.Dial(args.StdoutStreamId) conn, err := c.mux.Dial(args.StdoutStreamId)
if err != nil { if err != nil {
return err close(doneCh)
return NewBasicError(err)
} }
toClose = append(toClose, conn) toClose = append(toClose, conn)
...@@ -176,38 +189,42 @@ func (c *CommunicatorServer) Start(args *CommunicatorStartArgs, reply *interface ...@@ -176,38 +189,42 @@ func (c *CommunicatorServer) Start(args *CommunicatorStartArgs, reply *interface
if args.StderrStreamId > 0 { if args.StderrStreamId > 0 {
conn, err := c.mux.Dial(args.StderrStreamId) conn, err := c.mux.Dial(args.StderrStreamId)
if err != nil { if err != nil {
return err close(doneCh)
return NewBasicError(err)
} }
toClose = append(toClose, conn) toClose = append(toClose, conn)
cmd.Stderr = conn cmd.Stderr = conn
} }
// Connect to the response address so we can write our result to it // Connect to the response address so we can write our result to it
// when ready. // when ready.
responseC, err := c.mux.Dial(args.ResponseStreamId) responseC, err := c.mux.Dial(args.ResponseStreamId)
if err != nil { if err != nil {
return err close(doneCh)
return NewBasicError(err)
} }
responseWriter := gob.NewEncoder(responseC) responseWriter := gob.NewEncoder(responseC)
// Start the actual command // Start the actual command
err = c.c.Start(&cmd) err = c.c.Start(&cmd)
if err != nil {
close(doneCh)
return NewBasicError(err)
}
// Start a goroutine to spin and wait for the process to actual // Start a goroutine to spin and wait for the process to actual
// exit. When it does, report it back to caller... // exit. When it does, report it back to caller...
go func() { go func() {
defer close(doneCh)
defer responseC.Close() defer responseC.Close()
for _, conn := range toClose {
defer conn.Close()
}
cmd.Wait() cmd.Wait()
log.Printf("[INFO] RPC endpoint: Communicator ended with: %d", cmd.ExitStatus)
responseWriter.Encode(&CommandFinished{cmd.ExitStatus}) responseWriter.Encode(&CommandFinished{cmd.ExitStatus})
}() }()
return return nil
} }
func (c *CommunicatorServer) Upload(args *CommunicatorUploadArgs, reply *interface{}) (err error) { func (c *CommunicatorServer) Upload(args *CommunicatorUploadArgs, reply *interface{}) (err error) {
......
...@@ -197,7 +197,7 @@ func (m *MuxConn) openStream(id uint32) (*Stream, error) { ...@@ -197,7 +197,7 @@ func (m *MuxConn) openStream(id uint32) (*Stream, error) {
// Create the stream object and channel where data will be sent to // Create the stream object and channel where data will be sent to
dataR, dataW := io.Pipe() dataR, dataW := io.Pipe()
writeCh := make(chan []byte, 10) writeCh := make(chan []byte, 256)
// Set the data channel so we can write to it. // Set the data channel so we can write to it.
stream := &Stream{ stream := &Stream{
...@@ -315,7 +315,7 @@ func (m *MuxConn) loop() { ...@@ -315,7 +315,7 @@ func (m *MuxConn) loop() {
select { select {
case stream.writeCh <- data: case stream.writeCh <- data:
default: default:
log.Printf("[ERR] Failed to write data, buffer full: %d", id) panic(fmt.Sprintf("Failed to write data, buffer full for stream %d", id))
} }
} else { } else {
log.Printf("[ERR] Data received for stream in state: %d", stream.state) log.Printf("[ERR] Data received for stream in state: %d", stream.state)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment