Commit b4567c63 authored by Mitchell Hashimoto's avatar Mitchell Hashimoto

packer/plugin: use new RPC API

parent ce2304c9
...@@ -10,7 +10,6 @@ import ( ...@@ -10,7 +10,6 @@ import (
"io/ioutil" "io/ioutil"
"log" "log"
"net" "net"
"net/rpc"
"os" "os"
"os/exec" "os/exec"
"strings" "strings"
...@@ -130,56 +129,56 @@ func (c *Client) Exited() bool { ...@@ -130,56 +129,56 @@ func (c *Client) Exited() bool {
// Returns a builder implementation that is communicating over this // Returns a builder implementation that is communicating over this
// client. If the client hasn't been started, this will start it. // client. If the client hasn't been started, this will start it.
func (c *Client) Builder() (packer.Builder, error) { func (c *Client) Builder() (packer.Builder, error) {
client, err := c.rpcClient() client, err := c.packrpcClient()
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &cmdBuilder{packrpc.Builder(client), c}, nil return &cmdBuilder{client.Builder(), c}, nil
} }
// Returns a command implementation that is communicating over this // Returns a command implementation that is communicating over this
// client. If the client hasn't been started, this will start it. // client. If the client hasn't been started, this will start it.
func (c *Client) Command() (packer.Command, error) { func (c *Client) Command() (packer.Command, error) {
client, err := c.rpcClient() client, err := c.packrpcClient()
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &cmdCommand{packrpc.Command(client), c}, nil return &cmdCommand{client.Command(), c}, nil
} }
// Returns a hook implementation that is communicating over this // Returns a hook implementation that is communicating over this
// client. If the client hasn't been started, this will start it. // client. If the client hasn't been started, this will start it.
func (c *Client) Hook() (packer.Hook, error) { func (c *Client) Hook() (packer.Hook, error) {
client, err := c.rpcClient() client, err := c.packrpcClient()
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &cmdHook{packrpc.Hook(client), c}, nil return &cmdHook{client.Hook(), c}, nil
} }
// Returns a post-processor implementation that is communicating over // Returns a post-processor implementation that is communicating over
// this client. If the client hasn't been started, this will start it. // this client. If the client hasn't been started, this will start it.
func (c *Client) PostProcessor() (packer.PostProcessor, error) { func (c *Client) PostProcessor() (packer.PostProcessor, error) {
client, err := c.rpcClient() client, err := c.packrpcClient()
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &cmdPostProcessor{packrpc.PostProcessor(client), c}, nil return &cmdPostProcessor{client.PostProcessor(), c}, nil
} }
// Returns a provisioner implementation that is communicating over this // Returns a provisioner implementation that is communicating over this
// client. If the client hasn't been started, this will start it. // client. If the client hasn't been started, this will start it.
func (c *Client) Provisioner() (packer.Provisioner, error) { func (c *Client) Provisioner() (packer.Provisioner, error) {
client, err := c.rpcClient() client, err := c.packrpcClient()
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &cmdProvisioner{packrpc.Provisioner(client), c}, nil return &cmdProvisioner{client.Provisioner(), c}, nil
} }
// End the executing subprocess (if it is running) and perform any cleanup // End the executing subprocess (if it is running) and perform any cleanup
...@@ -361,7 +360,7 @@ func (c *Client) logStderr(r io.Reader) { ...@@ -361,7 +360,7 @@ func (c *Client) logStderr(r io.Reader) {
close(c.doneLogging) close(c.doneLogging)
} }
func (c *Client) rpcClient() (*rpc.Client, error) { func (c *Client) packrpcClient() (*packrpc.Client, error) {
address, err := c.Start() address, err := c.Start()
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -376,5 +375,11 @@ func (c *Client) rpcClient() (*rpc.Client, error) { ...@@ -376,5 +375,11 @@ func (c *Client) rpcClient() (*rpc.Client, error) {
tcpConn := conn.(*net.TCPConn) tcpConn := conn.(*net.TCPConn)
tcpConn.SetKeepAlive(true) tcpConn.SetKeepAlive(true)
return rpc.NewClient(tcpConn), nil client, err := packrpc.NewClient(tcpConn)
if err != nil {
tcpConn.Close()
return nil, err
}
return client, nil
} }
...@@ -14,7 +14,6 @@ import ( ...@@ -14,7 +14,6 @@ import (
packrpc "github.com/mitchellh/packer/packer/rpc" packrpc "github.com/mitchellh/packer/packer/rpc"
"log" "log"
"net" "net"
"net/rpc"
"os" "os"
"os/signal" "os/signal"
"runtime" "runtime"
...@@ -35,13 +34,14 @@ const MagicCookieValue = "d602bf8f470bc67ca7faa0386276bbdd4330efaf76d1a219cb4d69 ...@@ -35,13 +34,14 @@ const MagicCookieValue = "d602bf8f470bc67ca7faa0386276bbdd4330efaf76d1a219cb4d69
// know how to speak it. // know how to speak it.
const APIVersion = "1" const APIVersion = "1"
// This serves a single RPC connection on the given RPC server on // Server waits for a connection to this plugin and returns a Packer
// a random port. // RPC server that you can use to register components and serve them.
func serve(server *rpc.Server) (err error) { func Server() (*packrpc.Server, error) {
log.Printf("Plugin build against Packer '%s'", packer.GitCommit) log.Printf("Plugin build against Packer '%s'", packer.GitCommit)
if os.Getenv(MagicCookieKey) != MagicCookieValue { if os.Getenv(MagicCookieKey) != MagicCookieValue {
return errors.New("Please do not execute plugins directly. Packer will execute these for you.") return nil, errors.New(
"Please do not execute plugins directly. Packer will execute these for you.")
} }
// If there is no explicit number of Go threads to use, then set it // If there is no explicit number of Go threads to use, then set it
...@@ -51,12 +51,12 @@ func serve(server *rpc.Server) (err error) { ...@@ -51,12 +51,12 @@ func serve(server *rpc.Server) (err error) {
minPort, err := strconv.ParseInt(os.Getenv("PACKER_PLUGIN_MIN_PORT"), 10, 32) minPort, err := strconv.ParseInt(os.Getenv("PACKER_PLUGIN_MIN_PORT"), 10, 32)
if err != nil { if err != nil {
return return nil, err
} }
maxPort, err := strconv.ParseInt(os.Getenv("PACKER_PLUGIN_MAX_PORT"), 10, 32) maxPort, err := strconv.ParseInt(os.Getenv("PACKER_PLUGIN_MAX_PORT"), 10, 32)
if err != nil { if err != nil {
return return nil, err
} }
log.Printf("Plugin minimum port: %d\n", minPort) log.Printf("Plugin minimum port: %d\n", minPort)
...@@ -77,7 +77,6 @@ func serve(server *rpc.Server) (err error) { ...@@ -77,7 +77,6 @@ func serve(server *rpc.Server) (err error) {
break break
} }
defer listener.Close() defer listener.Close()
// Output the address to stdout // Output the address to stdout
...@@ -90,13 +89,12 @@ func serve(server *rpc.Server) (err error) { ...@@ -90,13 +89,12 @@ func serve(server *rpc.Server) (err error) {
conn, err := listener.Accept() conn, err := listener.Accept()
if err != nil { if err != nil {
log.Printf("Error accepting connection: %s\n", err.Error()) log.Printf("Error accepting connection: %s\n", err.Error())
return return nil, err
} }
// Serve a single connection // Serve a single connection
log.Println("Serving a plugin connection...") log.Println("Serving a plugin connection...")
server.ServeConn(conn) return packrpc.NewServer(conn), nil
return
} }
// Registers a signal handler to swallow and count interrupts so that the // Registers a signal handler to swallow and count interrupts so that the
...@@ -115,76 +113,6 @@ func countInterrupts() { ...@@ -115,76 +113,6 @@ func countInterrupts() {
}() }()
} }
// Serves a builder from a plugin.
func ServeBuilder(builder packer.Builder) {
log.Println("Preparing to serve a builder plugin...")
server := rpc.NewServer()
packrpc.RegisterBuilder(server, builder)
countInterrupts()
if err := serve(server); err != nil {
log.Printf("ERROR: %s", err)
os.Exit(1)
}
}
// Serves a command from a plugin.
func ServeCommand(command packer.Command) {
log.Println("Preparing to serve a command plugin...")
server := rpc.NewServer()
packrpc.RegisterCommand(server, command)
countInterrupts()
if err := serve(server); err != nil {
log.Printf("ERROR: %s", err)
os.Exit(1)
}
}
// Serves a hook from a plugin.
func ServeHook(hook packer.Hook) {
log.Println("Preparing to serve a hook plugin...")
server := rpc.NewServer()
packrpc.RegisterHook(server, hook)
countInterrupts()
if err := serve(server); err != nil {
log.Printf("ERROR: %s", err)
os.Exit(1)
}
}
// Serves a post-processor from a plugin.
func ServePostProcessor(p packer.PostProcessor) {
log.Println("Preparing to serve a post-processor plugin...")
server := rpc.NewServer()
packrpc.RegisterPostProcessor(server, p)
countInterrupts()
if err := serve(server); err != nil {
log.Printf("ERROR: %s", err)
os.Exit(1)
}
}
// Serves a provisioner from a plugin.
func ServeProvisioner(p packer.Provisioner) {
log.Println("Preparing to serve a provisioner plugin...")
server := rpc.NewServer()
packrpc.RegisterProvisioner(server, p)
countInterrupts()
if err := serve(server); err != nil {
log.Printf("ERROR: %s", err)
os.Exit(1)
}
}
// Tests whether or not the plugin was interrupted or not. // Tests whether or not the plugin was interrupted or not.
func Interrupted() bool { func Interrupted() bool {
return atomic.LoadInt32(&Interrupts) > 0 return atomic.LoadInt32(&Interrupts) > 0
......
...@@ -54,20 +54,50 @@ func TestHelperProcess(*testing.T) { ...@@ -54,20 +54,50 @@ func TestHelperProcess(*testing.T) {
fmt.Printf("%s1|:1234\n", APIVersion) fmt.Printf("%s1|:1234\n", APIVersion)
<-make(chan int) <-make(chan int)
case "builder": case "builder":
ServeBuilder(new(packer.MockBuilder)) server, err := Server()
if err != nil {
log.Printf("[ERR] %s", err)
os.Exit(1)
}
server.RegisterBuilder(new(packer.MockBuilder))
server.Serve()
case "command": case "command":
ServeCommand(new(helperCommand)) server, err := Server()
if err != nil {
log.Printf("[ERR] %s", err)
os.Exit(1)
}
server.RegisterCommand(new(helperCommand))
server.Serve()
case "hook": case "hook":
ServeHook(new(packer.MockHook)) server, err := Server()
if err != nil {
log.Printf("[ERR] %s", err)
os.Exit(1)
}
server.RegisterHook(new(packer.MockHook))
server.Serve()
case "invalid-rpc-address": case "invalid-rpc-address":
fmt.Println("lolinvalid") fmt.Println("lolinvalid")
case "mock": case "mock":
fmt.Printf("%s|:1234\n", APIVersion) fmt.Printf("%s|:1234\n", APIVersion)
<-make(chan int) <-make(chan int)
case "post-processor": case "post-processor":
ServePostProcessor(new(helperPostProcessor)) server, err := Server()
if err != nil {
log.Printf("[ERR] %s", err)
os.Exit(1)
}
server.RegisterPostProcessor(new(helperPostProcessor))
server.Serve()
case "provisioner": case "provisioner":
ServeProvisioner(new(packer.MockProvisioner)) server, err := Server()
if err != nil {
log.Printf("[ERR] %s", err)
os.Exit(1)
}
server.RegisterProvisioner(new(packer.MockProvisioner))
server.Serve()
case "start-timeout": case "start-timeout":
time.Sleep(1 * time.Minute) time.Sleep(1 * time.Minute)
os.Exit(1) os.Exit(1)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment