Commit ae00414b authored by Mitchell Hashimoto's avatar Mitchell Hashimoto

packer/plugin: communicate over unix domain sockets if you can

parent 8d9a6eef
...@@ -34,7 +34,7 @@ type Client struct { ...@@ -34,7 +34,7 @@ type Client struct {
exited bool exited bool
doneLogging chan struct{} doneLogging chan struct{}
l sync.Mutex l sync.Mutex
address string address net.Addr
} }
// ClientConfig is the configuration used to initialize a new // ClientConfig is the configuration used to initialize a new
...@@ -206,11 +206,11 @@ func (c *Client) Kill() { ...@@ -206,11 +206,11 @@ func (c *Client) Kill() {
// This method is safe to call multiple times. Subsequent calls have no effect. // This method is safe to call multiple times. Subsequent calls have no effect.
// Once a client has been started once, it cannot be started again, even if // Once a client has been started once, it cannot be started again, even if
// it was killed. // it was killed.
func (c *Client) Start() (address string, err error) { func (c *Client) Start() (addr net.Addr, err error) {
c.l.Lock() c.l.Lock()
defer c.l.Unlock() defer c.l.Unlock()
if c.address != "" { if c.address != nil {
return c.address, nil return c.address, nil
} }
...@@ -320,8 +320,8 @@ func (c *Client) Start() (address string, err error) { ...@@ -320,8 +320,8 @@ func (c *Client) Start() (address string, err error) {
// Trim the line and split by "|" in order to get the parts of // Trim the line and split by "|" in order to get the parts of
// the output. // the output.
line := strings.TrimSpace(string(lineBytes)) line := strings.TrimSpace(string(lineBytes))
parts := strings.SplitN(line, "|", 2) parts := strings.SplitN(line, "|", 3)
if len(parts) < 2 { if len(parts) < 3 {
err = fmt.Errorf("Unrecognized remote plugin message: %s", line) err = fmt.Errorf("Unrecognized remote plugin message: %s", line)
return return
} }
...@@ -333,10 +333,17 @@ func (c *Client) Start() (address string, err error) { ...@@ -333,10 +333,17 @@ func (c *Client) Start() (address string, err error) {
return return
} }
c.address = parts[1] switch parts[1] {
address = c.address case "tcp":
addr, err = net.ResolveTCPAddr("tcp", parts[2])
case "unix":
addr, err = net.ResolveUnixAddr("unix", parts[2])
default:
err = fmt.Errorf("Unknown address type: %s", parts[1])
}
} }
c.address = addr
return return
} }
...@@ -361,23 +368,24 @@ func (c *Client) logStderr(r io.Reader) { ...@@ -361,23 +368,24 @@ func (c *Client) logStderr(r io.Reader) {
} }
func (c *Client) packrpcClient() (*packrpc.Client, error) { func (c *Client) packrpcClient() (*packrpc.Client, error) {
address, err := c.Start() addr, err := c.Start()
if err != nil { if err != nil {
return nil, err return nil, err
} }
conn, err := net.Dial("tcp", address) conn, err := net.Dial(addr.Network(), addr.String())
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Make sure to set keep alive so that the connection doesn't die if tcpConn, ok := conn.(*net.TCPConn); ok {
tcpConn := conn.(*net.TCPConn) // Make sure to set keep alive so that the connection doesn't die
tcpConn.SetKeepAlive(true) tcpConn.SetKeepAlive(true)
}
client, err := packrpc.NewClient(tcpConn) client, err := packrpc.NewClient(conn)
if err != nil { if err != nil {
tcpConn.Close() conn.Close()
return nil, err return nil, err
} }
......
...@@ -20,8 +20,12 @@ func TestClient(t *testing.T) { ...@@ -20,8 +20,12 @@ func TestClient(t *testing.T) {
t.Fatalf("err should be nil, got %s", err) t.Fatalf("err should be nil, got %s", err)
} }
if addr != ":1234" { if addr.Network() != "tcp" {
t.Fatalf("incorrect addr %s", addr) t.Fatalf("bad: %#v", addr)
}
if addr.String() != ":1234" {
t.Fatalf("bad: %#v", addr)
} }
// Test that it exits properly if killed // Test that it exits properly if killed
......
...@@ -51,7 +51,7 @@ func TestHelperProcess(*testing.T) { ...@@ -51,7 +51,7 @@ func TestHelperProcess(*testing.T) {
cmd, args := args[0], args[1:] cmd, args := args[0], args[1:]
switch cmd { switch cmd {
case "bad-version": case "bad-version":
fmt.Printf("%s1|:1234\n", APIVersion) fmt.Printf("%s1|tcp|:1234\n", APIVersion)
<-make(chan int) <-make(chan int)
case "builder": case "builder":
server, err := Server() server, err := Server()
...@@ -80,7 +80,7 @@ func TestHelperProcess(*testing.T) { ...@@ -80,7 +80,7 @@ func TestHelperProcess(*testing.T) {
case "invalid-rpc-address": case "invalid-rpc-address":
fmt.Println("lolinvalid") fmt.Println("lolinvalid")
case "mock": case "mock":
fmt.Printf("%s|:1234\n", APIVersion) fmt.Printf("%s|tcp|:1234\n", APIVersion)
<-make(chan int) <-make(chan int)
case "post-processor": case "post-processor":
server, err := Server() server, err := Server()
...@@ -102,11 +102,11 @@ func TestHelperProcess(*testing.T) { ...@@ -102,11 +102,11 @@ func TestHelperProcess(*testing.T) {
time.Sleep(1 * time.Minute) time.Sleep(1 * time.Minute)
os.Exit(1) os.Exit(1)
case "stderr": case "stderr":
fmt.Printf("%s|:1234\n", APIVersion) fmt.Printf("%s|tcp|:1234\n", APIVersion)
log.Println("HELLO") log.Println("HELLO")
log.Println("WORLD") log.Println("WORLD")
case "stdin": case "stdin":
fmt.Printf("%s|:1234\n", APIVersion) fmt.Printf("%s|tcp|:1234\n", APIVersion)
data := make([]byte, 5) data := make([]byte, 5)
if _, err := os.Stdin.Read(data); err != nil { if _, err := os.Stdin.Read(data); err != nil {
log.Printf("stdin read error: %s", err) log.Printf("stdin read error: %s", err)
......
...@@ -12,6 +12,7 @@ import ( ...@@ -12,6 +12,7 @@ import (
"fmt" "fmt"
"github.com/mitchellh/packer/packer" "github.com/mitchellh/packer/packer"
packrpc "github.com/mitchellh/packer/packer/rpc" packrpc "github.com/mitchellh/packer/packer/rpc"
"io/ioutil"
"log" "log"
"net" "net"
"os" "os"
...@@ -32,7 +33,7 @@ const MagicCookieValue = "d602bf8f470bc67ca7faa0386276bbdd4330efaf76d1a219cb4d69 ...@@ -32,7 +33,7 @@ const MagicCookieValue = "d602bf8f470bc67ca7faa0386276bbdd4330efaf76d1a219cb4d69
// The APIVersion is outputted along with the RPC address. The plugin // The APIVersion is outputted along with the RPC address. The plugin
// client validates this API version and will show an error if it doesn't // client validates this API version and will show an error if it doesn't
// know how to speak it. // know how to speak it.
const APIVersion = "1" const APIVersion = "2"
// Server waits for a connection to this plugin and returns a Packer // Server waits for a connection to this plugin and returns a Packer
// RPC server that you can use to register components and serve them. // RPC server that you can use to register components and serve them.
...@@ -62,23 +63,19 @@ func Server() (*packrpc.Server, error) { ...@@ -62,23 +63,19 @@ func Server() (*packrpc.Server, error) {
log.Printf("Plugin minimum port: %d\n", minPort) log.Printf("Plugin minimum port: %d\n", minPort)
log.Printf("Plugin maximum port: %d\n", maxPort) log.Printf("Plugin maximum port: %d\n", maxPort)
var address string listener, err := serverListener(minPort, maxPort)
var listener net.Listener if err != nil {
for port := minPort; port <= maxPort; port++ { return nil, err
address = fmt.Sprintf("127.0.0.1:%d", port)
listener, err = net.Listen("tcp", address)
if err != nil {
err = nil
continue
}
break
} }
defer listener.Close() defer listener.Close()
// Output the address to stdout // Output the address to stdout
log.Printf("Plugin address: %s\n", address) log.Printf("Plugin address: %s %s\n",
fmt.Printf("%s|%s\n", APIVersion, address) listener.Addr().Network(), listener.Addr().String())
fmt.Printf("%s|%s|%s\n",
APIVersion,
listener.Addr().Network(),
listener.Addr().String())
os.Stdout.Sync() os.Stdout.Sync()
// Accept a connection // Accept a connection
...@@ -105,3 +102,42 @@ func Server() (*packrpc.Server, error) { ...@@ -105,3 +102,42 @@ func Server() (*packrpc.Server, error) {
log.Println("Serving a plugin connection...") log.Println("Serving a plugin connection...")
return packrpc.NewServer(conn), nil return packrpc.NewServer(conn), nil
} }
func serverListener(minPort, maxPort int64) (net.Listener, error) {
if runtime.GOOS == "windows" {
return serverListener_tcp(minPort, maxPort)
}
return serverListener_unix()
}
func serverListener_tcp(minPort, maxPort int64) (net.Listener, error) {
for port := minPort; port <= maxPort; port++ {
address := fmt.Sprintf("127.0.0.1:%d", port)
listener, err := net.Listen("tcp", address)
if err == nil {
return listener, nil
}
}
return nil, errors.New("Couldn't bind plugin TCP listener")
}
func serverListener_unix() (net.Listener, error) {
tf, err := ioutil.TempFile("", "packer-plugin")
if err != nil {
return nil, err
}
path := tf.Name()
// Close the file and remove it because it has to not exist for
// the domain socket.
if err := tf.Close(); err != nil {
return nil, err
}
if err := os.Remove(path); err != nil {
return nil, err
}
return net.Listen("unix", path)
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment