Commit 50cfb678 authored by Mitchell Hashimoto's avatar Mitchell Hashimoto

packer/rpc: muxconn is a lot more sane, acts like bsd socket

parent 36a47f5b
...@@ -2,7 +2,6 @@ package rpc ...@@ -2,7 +2,6 @@ package rpc
import ( import (
"github.com/mitchellh/packer/packer" "github.com/mitchellh/packer/packer"
"net/rpc"
"reflect" "reflect"
"testing" "testing"
) )
...@@ -31,19 +30,13 @@ func (testArtifact) Destroy() error { ...@@ -31,19 +30,13 @@ func (testArtifact) Destroy() error {
func TestArtifactRPC(t *testing.T) { func TestArtifactRPC(t *testing.T) {
// Create the interface to test // Create the interface to test
a := new(testArtifact) a := new(packer.MockArtifact)
// Start the server // Start the server
server := rpc.NewServer() server := NewServer()
RegisterArtifact(server, a) server.RegisterArtifact(a)
address := serveSingleConn(server) client := testClient(t, server)
aClient := client.Artifact()
// Create the client over RPC and run some methods to verify it works
client, err := rpc.Dial("tcp", address)
if err != nil {
t.Fatalf("err: %s", err)
}
aClient := Artifact(client)
// Test // Test
if aClient.BuilderId() != "bid" { if aClient.BuilderId() != "bid" {
......
package rpc
import (
"github.com/mitchellh/packer/packer"
"io"
"net/rpc"
)
// Client is the client end that communicates with a Packer RPC server.
// Establishing a connection is up to the user, the Client can just
// communicate over any ReadWriteCloser.
type Client struct {
mux *MuxConn
client *rpc.Client
}
func NewClient(rwc io.ReadWriteCloser) (*Client, error) {
// Create the MuxConn around the RWC and get the client to server stream.
// This is the primary stream that we use to communicate with the
// remote RPC server. On the remote side Server.ServeConn also listens
// on this stream ID.
mux := NewMuxConn(rwc)
stream, err := mux.Dial(0)
if err != nil {
return nil, err
}
return &Client{
mux: mux,
client: rpc.NewClient(stream),
}, nil
}
func (c *Client) Artifact() packer.Artifact {
return &artifact{
client: c.client,
}
}
package rpc
import (
"testing"
)
func testClient(t *testing.T, server *Server) *Client {
return nil
}
...@@ -6,6 +6,7 @@ import ( ...@@ -6,6 +6,7 @@ import (
"io" "io"
"log" "log"
"sync" "sync"
"time"
) )
// MuxConn is a connection that can be used bi-directionally for RPC. Normally, // MuxConn is a connection that can be used bi-directionally for RPC. Normally,
...@@ -20,15 +21,24 @@ import ( ...@@ -20,15 +21,24 @@ import (
// we decided to cut a lot of corners and make this easily usable for Packer. // we decided to cut a lot of corners and make this easily usable for Packer.
type MuxConn struct { type MuxConn struct {
rwc io.ReadWriteCloser rwc io.ReadWriteCloser
streams map[byte]io.WriteCloser streams map[byte]*Stream
mu sync.RWMutex mu sync.RWMutex
wlock sync.Mutex wlock sync.Mutex
} }
type muxPacketType byte
const (
muxPacketSyn muxPacketType = iota
muxPacketAck
muxPacketFin
muxPacketData
)
func NewMuxConn(rwc io.ReadWriteCloser) *MuxConn { func NewMuxConn(rwc io.ReadWriteCloser) *MuxConn {
m := &MuxConn{ m := &MuxConn{
rwc: rwc, rwc: rwc,
streams: make(map[byte]io.WriteCloser), streams: make(map[byte]*Stream),
} }
go m.loop() go m.loop()
...@@ -46,56 +56,140 @@ func (m *MuxConn) Close() error { ...@@ -46,56 +56,140 @@ func (m *MuxConn) Close() error {
for _, w := range m.streams { for _, w := range m.streams {
w.Close() w.Close()
} }
m.streams = make(map[byte]io.WriteCloser) m.streams = make(map[byte]*Stream)
return m.rwc.Close() return m.rwc.Close()
} }
// Stream returns a io.ReadWriteCloser that will only read/write to the // Accept accepts a multiplexed connection with the given ID. This
// given stream ID. No handshake is done so if the remote end does not // will block until a request is made to connect.
// have a stream open with the same ID, then the messages will simply func (m *MuxConn) Accept(id byte) (io.ReadWriteCloser, error) {
// be dropped. stream, err := m.openStream(id)
// if err != nil {
// This is one of those cases where we cut corners. Since Packer only does return nil, err
// local connections, we can assume that both ends are ready at a certain }
// point. In a real muxer, we'd probably want a handshake here.
func (m *MuxConn) Stream(id byte) (io.ReadWriteCloser, error) { // If the stream isn't closed, then it is already open somehow
stream.mu.Lock()
if stream.state != streamStateSynRecv && stream.state != streamStateClosed {
stream.mu.Unlock()
return nil, fmt.Errorf("Stream already open in bad state: %d", stream.state)
}
if stream.state == streamStateSynRecv {
// Fast track establishing since we already got the syn
stream.setState(streamStateEstablished)
stream.mu.Unlock()
}
if stream.state != streamStateEstablished {
// Go into the listening state
stream.setState(streamStateListen)
stream.mu.Unlock()
// Wait for the connection to establish
ACCEPT_ESTABLISH_LOOP:
for {
time.Sleep(50 * time.Millisecond)
stream.mu.Lock()
switch stream.state {
case streamStateListen:
stream.mu.Unlock()
case streamStateEstablished:
stream.mu.Unlock()
break ACCEPT_ESTABLISH_LOOP
default:
defer stream.mu.Unlock()
return nil, fmt.Errorf("Stream went to bad state: %d", stream.state)
}
}
}
// Send the ack down
if _, err := m.write(stream.id, muxPacketAck, nil); err != nil {
return nil, err
}
return stream, nil
}
// Dial opens a connection to the remote end using the given stream ID.
// An Accept on the remote end will only work with if the IDs match.
func (m *MuxConn) Dial(id byte) (io.ReadWriteCloser, error) {
stream, err := m.openStream(id)
if err != nil {
return nil, err
}
// If the stream isn't closed, then it is already open somehow
stream.mu.Lock()
if stream.state != streamStateClosed {
stream.mu.Unlock()
return nil, fmt.Errorf("Stream already open in bad state: %d", stream.state)
}
// Open a connection
if _, err := m.write(stream.id, muxPacketSyn, nil); err != nil {
return nil, err
}
stream.setState(streamStateSynSent)
stream.mu.Unlock()
for {
time.Sleep(50 * time.Millisecond)
stream.mu.Lock()
switch stream.state {
case streamStateSynSent:
stream.mu.Unlock()
case streamStateEstablished:
stream.mu.Unlock()
return stream, nil
default:
defer stream.mu.Unlock()
return nil, fmt.Errorf("Stream went to bad state: %d", stream.state)
}
}
}
func (m *MuxConn) openStream(id byte) (*Stream, error) {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock()
if _, ok := m.streams[id]; ok { if stream, ok := m.streams[id]; ok {
m.mu.Unlock() return stream, nil
return nil, fmt.Errorf("Stream %d already exists", id)
} }
// Create the stream object and channel where data will be sent to // Create the stream object and channel where data will be sent to
dataR, dataW := io.Pipe() dataR, dataW := io.Pipe()
// Set the data channel so we can write to it. // Set the data channel so we can write to it.
m.streams[id] = dataW
// Unlock the lock so that the reader can access the stream writer.
m.mu.Unlock()
stream := &Stream{ stream := &Stream{
id: id, id: id,
mux: m, mux: m,
reader: dataR, reader: dataR,
writer: dataW,
} }
stream.setState(streamStateClosed)
return stream, nil m.streams[id] = stream
return m.streams[id], nil
} }
func (m *MuxConn) loop() { func (m *MuxConn) loop() {
defer m.Close() defer m.Close()
for {
var id byte var id byte
var packetType muxPacketType
var length int32 var length int32
for {
if err := binary.Read(m.rwc, binary.BigEndian, &id); err != nil { if err := binary.Read(m.rwc, binary.BigEndian, &id); err != nil {
log.Printf("[ERR] Error reading stream ID: %s", err) log.Printf("[ERR] Error reading stream ID: %s", err)
return return
} }
if err := binary.Read(m.rwc, binary.BigEndian, &packetType); err != nil {
log.Printf("[ERR] Error reading packet type: %s", err)
return
}
if err := binary.Read(m.rwc, binary.BigEndian, &length); err != nil { if err := binary.Read(m.rwc, binary.BigEndian, &length); err != nil {
log.Printf("[ERR] Error reading length: %s", err) log.Printf("[ERR] Error reading length: %s", err)
return return
...@@ -103,32 +197,76 @@ func (m *MuxConn) loop() { ...@@ -103,32 +197,76 @@ func (m *MuxConn) loop() {
// TODO(mitchellh): probably would be better to re-use a buffer... // TODO(mitchellh): probably would be better to re-use a buffer...
data := make([]byte, length) data := make([]byte, length)
if length > 0 {
if _, err := m.rwc.Read(data); err != nil { if _, err := m.rwc.Read(data); err != nil {
log.Printf("[ERR] Error reading data: %s", err) log.Printf("[ERR] Error reading data: %s", err)
return return
} }
}
stream, err := m.openStream(id)
if err != nil {
log.Printf("[ERR] Error opening stream %d: %s", id, err)
return
}
switch packetType {
case muxPacketAck:
stream.mu.Lock()
if stream.state == streamStateSynSent {
stream.setState(streamStateEstablished)
} else {
log.Printf("[ERR] Ack received for stream in state: %d", stream.state)
}
stream.mu.Unlock()
case muxPacketSyn:
stream.mu.Lock()
switch stream.state {
case streamStateClosed:
stream.setState(streamStateSynRecv)
case streamStateListen:
stream.setState(streamStateEstablished)
default:
log.Printf("[ERR] Syn received for stream in state: %d", stream.state)
}
stream.mu.Unlock()
case muxPacketFin:
stream.mu.Lock()
stream.setState(streamStateClosed)
stream.writer.Close()
stream.mu.Unlock()
m.mu.RLock() m.mu.Lock()
w, ok := m.streams[id] delete(m.streams, stream.id)
if ok { m.mu.Unlock()
// Note that if this blocks, it'll block the whole read loop. case muxPacketData:
// Danger here... not sure how to handle it though. stream.mu.Lock()
w.Write(data) if stream.state == streamStateEstablished {
stream.writer.Write(data)
} else {
log.Printf("[ERR] Data received for stream in state: %d", stream.state)
}
stream.mu.Unlock()
} }
m.mu.RUnlock()
} }
} }
func (m *MuxConn) write(id byte, p []byte) (int, error) { func (m *MuxConn) write(id byte, dataType muxPacketType, p []byte) (int, error) {
m.wlock.Lock() m.wlock.Lock()
defer m.wlock.Unlock() defer m.wlock.Unlock()
if err := binary.Write(m.rwc, binary.BigEndian, id); err != nil { if err := binary.Write(m.rwc, binary.BigEndian, id); err != nil {
return 0, err return 0, err
} }
if err := binary.Write(m.rwc, binary.BigEndian, byte(dataType)); err != nil {
return 0, err
}
if err := binary.Write(m.rwc, binary.BigEndian, int32(len(p))); err != nil { if err := binary.Write(m.rwc, binary.BigEndian, int32(len(p))); err != nil {
return 0, err return 0, err
} }
if len(p) == 0 {
return 0, nil
}
return m.rwc.Write(p) return m.rwc.Write(p)
} }
...@@ -137,10 +275,37 @@ type Stream struct { ...@@ -137,10 +275,37 @@ type Stream struct {
id byte id byte
mux *MuxConn mux *MuxConn
reader io.Reader reader io.Reader
writer io.WriteCloser
state streamState
stateUpdated time.Time
mu sync.Mutex
} }
type streamState byte
const (
streamStateClosed streamState = iota
streamStateListen
streamStateSynRecv
streamStateSynSent
streamStateEstablished
streamStateFinWait
)
func (s *Stream) Close() error { func (s *Stream) Close() error {
// Not functional yet, does it ever have to be? s.mu.Lock()
defer s.mu.Unlock()
if s.state != streamStateEstablished {
return fmt.Errorf("Stream in bad state: %d", s.state)
}
if _, err := s.mux.write(s.id, muxPacketFin, nil); err != nil {
return err
}
s.setState(streamStateClosed)
s.writer.Close()
return nil return nil
} }
...@@ -149,5 +314,10 @@ func (s *Stream) Read(p []byte) (int, error) { ...@@ -149,5 +314,10 @@ func (s *Stream) Read(p []byte) (int, error) {
} }
func (s *Stream) Write(p []byte) (int, error) { func (s *Stream) Write(p []byte) (int, error) {
return s.mux.write(s.id, p) return s.mux.write(s.id, muxPacketData, p)
}
func (s *Stream) setState(state streamState) {
s.state = state
s.stateUpdated = time.Now().UTC()
} }
...@@ -56,24 +56,21 @@ func TestMuxConn(t *testing.T) { ...@@ -56,24 +56,21 @@ func TestMuxConn(t *testing.T) {
// When the server is done // When the server is done
doneCh := make(chan struct{}) doneCh := make(chan struct{})
readyCh := make(chan struct{})
// The server side // The server side
go func() { go func() {
defer close(doneCh) defer close(doneCh)
s0, err := server.Stream(0) s0, err := server.Accept(0)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
s1, err := server.Stream(1) s1, err := server.Dial(1)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
close(readyCh)
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(2) wg.Add(2)
...@@ -96,19 +93,16 @@ func TestMuxConn(t *testing.T) { ...@@ -96,19 +93,16 @@ func TestMuxConn(t *testing.T) {
wg.Wait() wg.Wait()
}() }()
s0, err := client.Stream(0) s0, err := client.Dial(0)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
s1, err := client.Stream(1) s1, err := client.Accept(1)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
// Wait for the server to be ready
<-readyCh
if _, err := s0.Write([]byte("hello")); err != nil { if _, err := s0.Write([]byte("hello")); err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
...@@ -124,8 +118,9 @@ func TestMuxConn_clientClosesStreams(t *testing.T) { ...@@ -124,8 +118,9 @@ func TestMuxConn_clientClosesStreams(t *testing.T) {
client, server := testMux(t) client, server := testMux(t)
defer client.Close() defer client.Close()
defer server.Close() defer server.Close()
go server.Accept(0)
s0, err := client.Stream(0) s0, err := client.Dial(0)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
...@@ -146,8 +141,9 @@ func TestMuxConn_serverClosesStreams(t *testing.T) { ...@@ -146,8 +141,9 @@ func TestMuxConn_serverClosesStreams(t *testing.T) {
client, server := testMux(t) client, server := testMux(t)
defer client.Close() defer client.Close()
defer server.Close() defer server.Close()
go server.Accept(0)
s0, err := client.Stream(0) s0, err := client.Dial(0)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
......
package rpc package rpc
import ( import (
"fmt"
"github.com/mitchellh/packer/packer" "github.com/mitchellh/packer/packer"
"net/rpc" "net/rpc"
"sync/atomic"
) )
// This keeps track of the endpoint ID to use when registering artifacts.
var endpointId uint64 = 0
// Registers the appropriate endpoint on an RPC server to serve an // Registers the appropriate endpoint on an RPC server to serve an
// Artifact. // Artifact.
func RegisterArtifact(s *rpc.Server, a packer.Artifact) { func RegisterArtifact(s *rpc.Server, a packer.Artifact) {
...@@ -82,10 +77,6 @@ func RegisterUi(s *rpc.Server, ui packer.Ui) { ...@@ -82,10 +77,6 @@ func RegisterUi(s *rpc.Server, ui packer.Ui) {
// The endpoint name is returned. // The endpoint name is returned.
func registerComponent(s *rpc.Server, name string, rcvr interface{}, id bool) string { func registerComponent(s *rpc.Server, name string, rcvr interface{}, id bool) string {
endpoint := name endpoint := name
if id {
fmt.Sprintf("%s.%d", endpoint, atomic.AddUint64(&endpointId, 1))
}
s.RegisterName(endpoint, rcvr) s.RegisterName(endpoint, rcvr)
return endpoint return endpoint
} }
......
package rpc
import (
"fmt"
"github.com/mitchellh/packer/packer"
"io"
"log"
"net/rpc"
"sync/atomic"
)
// Server represents an RPC server for Packer. This must be paired on
// the other side with a Client.
type Server struct {
endpointId uint64
rpcServer *rpc.Server
}
// NewServer returns a new Packer RPC server.
func NewServer() *Server {
return &Server{
endpointId: 0,
rpcServer: rpc.NewServer(),
}
}
func (s *Server) RegisterArtifact(a packer.Artifact) {
s.registerComponent("Artifact", &ArtifactServer{a}, false)
}
// ServeConn serves a single connection over the RPC server. It is up
// to the caller to obtain a proper io.ReadWriteCloser.
func (s *Server) ServeConn(conn io.ReadWriteCloser) {
mux := NewMuxConn(conn)
defer mux.Close()
// Get stream ID 0, which we always use as the stream for serving
// our RPC server on.
stream, err := mux.Accept(0)
if err != nil {
log.Printf("[ERR] Error retrieving stream for serving: %s", err)
return
}
s.rpcServer.ServeConn(stream)
}
// registerComponent registers a single Packer RPC component onto
// the RPC server. If id is true, then a unique ID number will be appended
// onto the end of the endpoint.
//
// The endpoint name is returned.
func (s *Server) registerComponent(name string, rcvr interface{}, id bool) string {
endpoint := name
if id {
fmt.Sprintf("%s.%d", endpoint, atomic.AddUint64(&s.endpointId, 1))
}
s.rpcServer.RegisterName(endpoint, rcvr)
return endpoint
}
...@@ -10,7 +10,6 @@ import ( ...@@ -10,7 +10,6 @@ import (
// over an RPC connection. // over an RPC connection.
type Ui struct { type Ui struct {
client *rpc.Client client *rpc.Client
endpoint string
} }
// UiServer wraps a packer.Ui implementation and makes it exportable // UiServer wraps a packer.Ui implementation and makes it exportable
...@@ -26,12 +25,12 @@ type UiMachineArgs struct { ...@@ -26,12 +25,12 @@ type UiMachineArgs struct {
} }
func (u *Ui) Ask(query string) (result string, err error) { func (u *Ui) Ask(query string) (result string, err error) {
err = u.client.Call(u.endpoint+".Ask", query, &result) err = u.client.Call("Ui.Ask", query, &result)
return return
} }
func (u *Ui) Error(message string) { func (u *Ui) Error(message string) {
if err := u.client.Call(u.endpoint+".Error", message, new(interface{})); err != nil { if err := u.client.Call("Ui.Error", message, new(interface{})); err != nil {
log.Printf("Error in Ui RPC call: %s", err) log.Printf("Error in Ui RPC call: %s", err)
} }
} }
...@@ -42,19 +41,19 @@ func (u *Ui) Machine(t string, args ...string) { ...@@ -42,19 +41,19 @@ func (u *Ui) Machine(t string, args ...string) {
Args: args, Args: args,
} }
if err := u.client.Call(u.endpoint+".Machine", rpcArgs, new(interface{})); err != nil { if err := u.client.Call("Ui.Machine", rpcArgs, new(interface{})); err != nil {
log.Printf("Error in Ui RPC call: %s", err) log.Printf("Error in Ui RPC call: %s", err)
} }
} }
func (u *Ui) Message(message string) { func (u *Ui) Message(message string) {
if err := u.client.Call(u.endpoint+".Message", message, new(interface{})); err != nil { if err := u.client.Call("Ui.Message", message, new(interface{})); err != nil {
log.Printf("Error in Ui RPC call: %s", err) log.Printf("Error in Ui RPC call: %s", err)
} }
} }
func (u *Ui) Say(message string) { func (u *Ui) Say(message string) {
if err := u.client.Call(u.endpoint+".Say", message, new(interface{})); err != nil { if err := u.client.Call("Ui.Say", message, new(interface{})); err != nil {
log.Printf("Error in Ui RPC call: %s", err) log.Printf("Error in Ui RPC call: %s", err)
} }
} }
......
...@@ -62,7 +62,7 @@ func TestUiRPC(t *testing.T) { ...@@ -62,7 +62,7 @@ func TestUiRPC(t *testing.T) {
panic(err) panic(err)
} }
uiClient := &Ui{client: client, endpoint: "Ui"} uiClient := &Ui{client: client}
// Basic error and say tests // Basic error and say tests
result, err := uiClient.Ask("query") result, err := uiClient.Ask("query")
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment