Commit db06fc75 authored by Mitchell Hashimoto's avatar Mitchell Hashimoto

packer/rpc: implement Communicator

parent 72fcb566
...@@ -51,6 +51,13 @@ func (c *Client) Cache() packer.Cache { ...@@ -51,6 +51,13 @@ func (c *Client) Cache() packer.Cache {
} }
} }
func (c *Client) Communicator() packer.Communicator {
return &communicator{
client: c.client,
mux: c.mux,
}
}
func (c *Client) PostProcessor() packer.PostProcessor { func (c *Client) PostProcessor() packer.PostProcessor {
return &postProcessor{ return &postProcessor{
client: c.client, client: c.client,
......
...@@ -2,11 +2,9 @@ package rpc ...@@ -2,11 +2,9 @@ package rpc
import ( import (
"encoding/gob" "encoding/gob"
"errors"
"github.com/mitchellh/packer/packer" "github.com/mitchellh/packer/packer"
"io" "io"
"log" "log"
"net"
"net/rpc" "net/rpc"
) )
...@@ -14,12 +12,14 @@ import ( ...@@ -14,12 +12,14 @@ import (
// executed over an RPC connection. // executed over an RPC connection.
type communicator struct { type communicator struct {
client *rpc.Client client *rpc.Client
mux *MuxConn
} }
// CommunicatorServer wraps a packer.Communicator implementation and makes // CommunicatorServer wraps a packer.Communicator implementation and makes
// it exportable as part of a Golang RPC server. // it exportable as part of a Golang RPC server.
type CommunicatorServer struct { type CommunicatorServer struct {
c packer.Communicator c packer.Communicator
mux *MuxConn
} }
type CommandFinished struct { type CommandFinished struct {
...@@ -28,20 +28,20 @@ type CommandFinished struct { ...@@ -28,20 +28,20 @@ type CommandFinished struct {
type CommunicatorStartArgs struct { type CommunicatorStartArgs struct {
Command string Command string
StdinAddress string StdinStreamId uint32
StdoutAddress string StdoutStreamId uint32
StderrAddress string StderrStreamId uint32
ResponseAddress string ResponseStreamId uint32
} }
type CommunicatorDownloadArgs struct { type CommunicatorDownloadArgs struct {
Path string Path string
WriterAddress string WriterStreamId uint32
} }
type CommunicatorUploadArgs struct { type CommunicatorUploadArgs struct {
Path string Path string
ReaderAddress string ReaderStreamId uint32
} }
type CommunicatorUploadDirArgs struct { type CommunicatorUploadDirArgs struct {
...@@ -51,7 +51,7 @@ type CommunicatorUploadDirArgs struct { ...@@ -51,7 +51,7 @@ type CommunicatorUploadDirArgs struct {
} }
func Communicator(client *rpc.Client) *communicator { func Communicator(client *rpc.Client) *communicator {
return &communicator{client} return &communicator{client: client}
} }
func (c *communicator) Start(cmd *packer.RemoteCmd) (err error) { func (c *communicator) Start(cmd *packer.RemoteCmd) (err error) {
...@@ -59,41 +59,38 @@ func (c *communicator) Start(cmd *packer.RemoteCmd) (err error) { ...@@ -59,41 +59,38 @@ func (c *communicator) Start(cmd *packer.RemoteCmd) (err error) {
args.Command = cmd.Command args.Command = cmd.Command
if cmd.Stdin != nil { if cmd.Stdin != nil {
stdinL := netListenerInRange(portRangeMin, portRangeMax) args.StdinStreamId = c.mux.NextId()
args.StdinAddress = stdinL.Addr().String() go serveSingleCopy("stdin", c.mux, args.StdinStreamId, nil, cmd.Stdin)
go serveSingleCopy("stdin", stdinL, nil, cmd.Stdin)
} }
if cmd.Stdout != nil { if cmd.Stdout != nil {
stdoutL := netListenerInRange(portRangeMin, portRangeMax) args.StdoutStreamId = c.mux.NextId()
args.StdoutAddress = stdoutL.Addr().String() go serveSingleCopy("stdout", c.mux, args.StdoutStreamId, cmd.Stdout, nil)
go serveSingleCopy("stdout", stdoutL, cmd.Stdout, nil)
} }
if cmd.Stderr != nil { if cmd.Stderr != nil {
stderrL := netListenerInRange(portRangeMin, portRangeMax) args.StderrStreamId = c.mux.NextId()
args.StderrAddress = stderrL.Addr().String() go serveSingleCopy("stderr", c.mux, args.StderrStreamId, cmd.Stderr, nil)
go serveSingleCopy("stderr", stderrL, cmd.Stderr, nil)
} }
responseL := netListenerInRange(portRangeMin, portRangeMax) responseStreamId := c.mux.NextId()
args.ResponseAddress = responseL.Addr().String() args.ResponseStreamId = responseStreamId
go func() { go func() {
defer responseL.Close() conn, err := c.mux.Accept(responseStreamId)
conn, err := responseL.Accept()
if err != nil { if err != nil {
log.Printf("[ERR] Error accepting response stream %d: %s",
responseStreamId, err)
cmd.SetExited(123) cmd.SetExited(123)
return return
} }
defer conn.Close() defer conn.Close()
decoder := gob.NewDecoder(conn)
var finished CommandFinished var finished CommandFinished
decoder := gob.NewDecoder(conn)
if err := decoder.Decode(&finished); err != nil { if err := decoder.Decode(&finished); err != nil {
log.Printf("[ERR] Error decoding response stream %d: %s",
responseStreamId, err)
cmd.SetExited(123) cmd.SetExited(123)
return return
} }
...@@ -106,23 +103,13 @@ func (c *communicator) Start(cmd *packer.RemoteCmd) (err error) { ...@@ -106,23 +103,13 @@ func (c *communicator) Start(cmd *packer.RemoteCmd) (err error) {
} }
func (c *communicator) Upload(path string, r io.Reader) (err error) { func (c *communicator) Upload(path string, r io.Reader) (err error) {
// We need to create a server that can proxy the reader data
// over because we can't simply gob encode an io.Reader
readerL := netListenerInRange(portRangeMin, portRangeMax)
if readerL == nil {
err = errors.New("couldn't allocate listener for upload reader")
return
}
// Make sure at the end of this call, we close the listener
defer readerL.Close()
// Pipe the reader through to the connection // Pipe the reader through to the connection
go serveSingleCopy("uploadReader", readerL, nil, r) streamId := c.mux.NextId()
go serveSingleCopy("uploadReader", c.mux, streamId, nil, r)
args := CommunicatorUploadArgs{ args := CommunicatorUploadArgs{
path, Path: path,
readerL.Addr().String(), ReaderStreamId: streamId,
} }
err = c.client.Call("Communicator.Upload", &args, new(interface{})) err = c.client.Call("Communicator.Upload", &args, new(interface{}))
...@@ -146,23 +133,13 @@ func (c *communicator) UploadDir(dst string, src string, exclude []string) error ...@@ -146,23 +133,13 @@ func (c *communicator) UploadDir(dst string, src string, exclude []string) error
} }
func (c *communicator) Download(path string, w io.Writer) (err error) { func (c *communicator) Download(path string, w io.Writer) (err error) {
// We need to create a server that can proxy that data downloaded
// into the writer because we can't gob encode a writer directly.
writerL := netListenerInRange(portRangeMin, portRangeMax)
if writerL == nil {
err = errors.New("couldn't allocate listener for download writer")
return
}
// Make sure we close the listener once we're done because we'll be done
defer writerL.Close()
// Serve a single connection and a single copy // Serve a single connection and a single copy
go serveSingleCopy("downloadWriter", writerL, w, nil) streamId := c.mux.NextId()
go serveSingleCopy("downloadWriter", c.mux, streamId, w, nil)
args := CommunicatorDownloadArgs{ args := CommunicatorDownloadArgs{
path, Path: path,
writerL.Addr().String(), WriterStreamId: streamId,
} }
err = c.client.Call("Communicator.Download", &args, new(interface{})) err = c.client.Call("Communicator.Download", &args, new(interface{}))
...@@ -175,40 +152,40 @@ func (c *CommunicatorServer) Start(args *CommunicatorStartArgs, reply *interface ...@@ -175,40 +152,40 @@ func (c *CommunicatorServer) Start(args *CommunicatorStartArgs, reply *interface
var cmd packer.RemoteCmd var cmd packer.RemoteCmd
cmd.Command = args.Command cmd.Command = args.Command
toClose := make([]net.Conn, 0) toClose := make([]io.Closer, 0)
if args.StdinAddress != "" { if args.StdinStreamId > 0 {
stdinC, err := tcpDial(args.StdinAddress) conn, err := c.mux.Dial(args.StdinStreamId)
if err != nil { if err != nil {
return err return err
} }
toClose = append(toClose, stdinC) toClose = append(toClose, conn)
cmd.Stdin = stdinC cmd.Stdin = conn
} }
if args.StdoutAddress != "" { if args.StdoutStreamId > 0 {
stdoutC, err := tcpDial(args.StdoutAddress) conn, err := c.mux.Dial(args.StdoutStreamId)
if err != nil { if err != nil {
return err return err
} }
toClose = append(toClose, stdoutC) toClose = append(toClose, conn)
cmd.Stdout = stdoutC cmd.Stdout = conn
} }
if args.StderrAddress != "" { if args.StderrStreamId > 0 {
stderrC, err := tcpDial(args.StderrAddress) conn, err := c.mux.Dial(args.StderrStreamId)
if err != nil { if err != nil {
return err return err
} }
toClose = append(toClose, stderrC) toClose = append(toClose, conn)
cmd.Stderr = stderrC cmd.Stderr = conn
} }
// Connect to the response address so we can write our result to it // Connect to the response address so we can write our result to it
// when ready. // when ready.
responseC, err := tcpDial(args.ResponseAddress) responseC, err := c.mux.Dial(args.ResponseStreamId)
if err != nil { if err != nil {
return err return err
} }
...@@ -234,11 +211,10 @@ func (c *CommunicatorServer) Start(args *CommunicatorStartArgs, reply *interface ...@@ -234,11 +211,10 @@ func (c *CommunicatorServer) Start(args *CommunicatorStartArgs, reply *interface
} }
func (c *CommunicatorServer) Upload(args *CommunicatorUploadArgs, reply *interface{}) (err error) { func (c *CommunicatorServer) Upload(args *CommunicatorUploadArgs, reply *interface{}) (err error) {
readerC, err := tcpDial(args.ReaderAddress) readerC, err := c.mux.Dial(args.ReaderStreamId)
if err != nil { if err != nil {
return return
} }
defer readerC.Close() defer readerC.Close()
err = c.c.Upload(args.Path, readerC) err = c.c.Upload(args.Path, readerC)
...@@ -250,21 +226,18 @@ func (c *CommunicatorServer) UploadDir(args *CommunicatorUploadDirArgs, reply *e ...@@ -250,21 +226,18 @@ func (c *CommunicatorServer) UploadDir(args *CommunicatorUploadDirArgs, reply *e
} }
func (c *CommunicatorServer) Download(args *CommunicatorDownloadArgs, reply *interface{}) (err error) { func (c *CommunicatorServer) Download(args *CommunicatorDownloadArgs, reply *interface{}) (err error) {
writerC, err := tcpDial(args.WriterAddress) writerC, err := c.mux.Dial(args.WriterStreamId)
if err != nil { if err != nil {
return return
} }
defer writerC.Close() defer writerC.Close()
err = c.c.Download(args.Path, writerC) err = c.c.Download(args.Path, writerC)
return return
} }
func serveSingleCopy(name string, l net.Listener, dst io.Writer, src io.Reader) { func serveSingleCopy(name string, mux *MuxConn, id uint32, dst io.Writer, src io.Reader) {
defer l.Close() conn, err := mux.Accept(id)
conn, err := l.Accept()
if err != nil { if err != nil {
log.Printf("'%s' accept error: %s", name, err) log.Printf("'%s' accept error: %s", name, err)
return return
......
...@@ -4,7 +4,6 @@ import ( ...@@ -4,7 +4,6 @@ import (
"bufio" "bufio"
"github.com/mitchellh/packer/packer" "github.com/mitchellh/packer/packer"
"io" "io"
"net/rpc"
"reflect" "reflect"
"testing" "testing"
) )
...@@ -14,16 +13,11 @@ func TestCommunicatorRPC(t *testing.T) { ...@@ -14,16 +13,11 @@ func TestCommunicatorRPC(t *testing.T) {
c := new(packer.MockCommunicator) c := new(packer.MockCommunicator)
// Start the server // Start the server
server := rpc.NewServer() client, server := testClientServer(t)
RegisterCommunicator(server, c) defer client.Close()
address := serveSingleConn(server) defer server.Close()
server.RegisterCommunicator(c)
// Create the client over RPC and run some methods to verify it works remote := client.Communicator()
client, err := rpc.Dial("tcp", address)
if err != nil {
t.Fatalf("err: %s", err)
}
remote := Communicator(client)
// The remote command we'll use // The remote command we'll use
stdin_r, stdin_w := io.Pipe() stdin_r, stdin_w := io.Pipe()
...@@ -42,7 +36,7 @@ func TestCommunicatorRPC(t *testing.T) { ...@@ -42,7 +36,7 @@ func TestCommunicatorRPC(t *testing.T) {
c.StartExitStatus = 42 c.StartExitStatus = 42
// Test Start // Test Start
err = remote.Start(&cmd) err := remote.Start(&cmd)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
...@@ -74,7 +68,7 @@ func TestCommunicatorRPC(t *testing.T) { ...@@ -74,7 +68,7 @@ func TestCommunicatorRPC(t *testing.T) {
stdin_w.Close() stdin_w.Close()
cmd.Wait() cmd.Wait()
if c.StartStdin != "info\n" { if c.StartStdin != "info\n" {
t.Fatalf("bad data: %s", data) t.Fatalf("bad data: %s", c.StartStdin)
} }
// Test that we can get the exit status properly // Test that we can get the exit status properly
......
...@@ -266,7 +266,7 @@ func (m *MuxConn) loop() { ...@@ -266,7 +266,7 @@ func (m *MuxConn) loop() {
return return
} }
log.Printf("[DEBUG] Stream %d received packet %d", id, packetType) //log.Printf("[DEBUG] Stream %d received packet %d", id, packetType)
switch packetType { switch packetType {
case muxPacketAck: case muxPacketAck:
stream.mu.Lock() stream.mu.Lock()
......
...@@ -38,7 +38,7 @@ func RegisterCommand(s *rpc.Server, c packer.Command) { ...@@ -38,7 +38,7 @@ func RegisterCommand(s *rpc.Server, c packer.Command) {
// Registers the appropriate endpoint on an RPC server to serve a // Registers the appropriate endpoint on an RPC server to serve a
// Packer Communicator. // Packer Communicator.
func RegisterCommunicator(s *rpc.Server, c packer.Communicator) { func RegisterCommunicator(s *rpc.Server, c packer.Communicator) {
registerComponent(s, "Communicator", &CommunicatorServer{c}, false) registerComponent(s, "Communicator", &CommunicatorServer{c: c}, false)
} }
// Registers the appropriate endpoint on an RPC server to serve a // Registers the appropriate endpoint on an RPC server to serve a
......
...@@ -14,6 +14,7 @@ var endpointId uint64 ...@@ -14,6 +14,7 @@ var endpointId uint64
const ( const (
DefaultArtifactEndpoint string = "Artifact" DefaultArtifactEndpoint string = "Artifact"
DefaultCacheEndpoint = "Cache" DefaultCacheEndpoint = "Cache"
DefaultCommunicatorEndpoint = "Communicator"
DefaultPostProcessorEndpoint = "PostProcessor" DefaultPostProcessorEndpoint = "PostProcessor"
DefaultUiEndpoint = "Ui" DefaultUiEndpoint = "Ui"
) )
...@@ -55,6 +56,13 @@ func (s *Server) RegisterCache(c packer.Cache) { ...@@ -55,6 +56,13 @@ func (s *Server) RegisterCache(c packer.Cache) {
}) })
} }
func (s *Server) RegisterCommunicator(c packer.Communicator) {
s.server.RegisterName(DefaultCommunicatorEndpoint, &CommunicatorServer{
c: c,
mux: s.mux,
})
}
func (s *Server) RegisterPostProcessor(p packer.PostProcessor) { func (s *Server) RegisterPostProcessor(p packer.PostProcessor) {
s.server.RegisterName(DefaultPostProcessorEndpoint, &PostProcessorServer{ s.server.RegisterName(DefaultPostProcessorEndpoint, &PostProcessorServer{
mux: s.mux, mux: s.mux,
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment