Commit 50cfb678 authored by Mitchell Hashimoto's avatar Mitchell Hashimoto

packer/rpc: muxconn is a lot more sane, acts like bsd socket

parent 36a47f5b
......@@ -2,7 +2,6 @@ package rpc
import (
"github.com/mitchellh/packer/packer"
"net/rpc"
"reflect"
"testing"
)
......@@ -31,19 +30,13 @@ func (testArtifact) Destroy() error {
func TestArtifactRPC(t *testing.T) {
// Create the interface to test
a := new(testArtifact)
a := new(packer.MockArtifact)
// Start the server
server := rpc.NewServer()
RegisterArtifact(server, a)
address := serveSingleConn(server)
// Create the client over RPC and run some methods to verify it works
client, err := rpc.Dial("tcp", address)
if err != nil {
t.Fatalf("err: %s", err)
}
aClient := Artifact(client)
server := NewServer()
server.RegisterArtifact(a)
client := testClient(t, server)
aClient := client.Artifact()
// Test
if aClient.BuilderId() != "bid" {
......
package rpc
import (
"github.com/mitchellh/packer/packer"
"io"
"net/rpc"
)
// Client is the client end that communicates with a Packer RPC server.
// Establishing a connection is up to the user, the Client can just
// communicate over any ReadWriteCloser.
type Client struct {
mux *MuxConn
client *rpc.Client
}
func NewClient(rwc io.ReadWriteCloser) (*Client, error) {
// Create the MuxConn around the RWC and get the client to server stream.
// This is the primary stream that we use to communicate with the
// remote RPC server. On the remote side Server.ServeConn also listens
// on this stream ID.
mux := NewMuxConn(rwc)
stream, err := mux.Dial(0)
if err != nil {
return nil, err
}
return &Client{
mux: mux,
client: rpc.NewClient(stream),
}, nil
}
func (c *Client) Artifact() packer.Artifact {
return &artifact{
client: c.client,
}
}
package rpc
import (
"testing"
)
func testClient(t *testing.T, server *Server) *Client {
return nil
}
......@@ -6,6 +6,7 @@ import (
"io"
"log"
"sync"
"time"
)
// MuxConn is a connection that can be used bi-directionally for RPC. Normally,
......@@ -20,15 +21,24 @@ import (
// we decided to cut a lot of corners and make this easily usable for Packer.
type MuxConn struct {
rwc io.ReadWriteCloser
streams map[byte]io.WriteCloser
streams map[byte]*Stream
mu sync.RWMutex
wlock sync.Mutex
}
type muxPacketType byte
const (
muxPacketSyn muxPacketType = iota
muxPacketAck
muxPacketFin
muxPacketData
)
func NewMuxConn(rwc io.ReadWriteCloser) *MuxConn {
m := &MuxConn{
rwc: rwc,
streams: make(map[byte]io.WriteCloser),
streams: make(map[byte]*Stream),
}
go m.loop()
......@@ -46,56 +56,140 @@ func (m *MuxConn) Close() error {
for _, w := range m.streams {
w.Close()
}
m.streams = make(map[byte]io.WriteCloser)
m.streams = make(map[byte]*Stream)
return m.rwc.Close()
}
// Stream returns a io.ReadWriteCloser that will only read/write to the
// given stream ID. No handshake is done so if the remote end does not
// have a stream open with the same ID, then the messages will simply
// be dropped.
//
// This is one of those cases where we cut corners. Since Packer only does
// local connections, we can assume that both ends are ready at a certain
// point. In a real muxer, we'd probably want a handshake here.
func (m *MuxConn) Stream(id byte) (io.ReadWriteCloser, error) {
// Accept accepts a multiplexed connection with the given ID. This
// will block until a request is made to connect.
func (m *MuxConn) Accept(id byte) (io.ReadWriteCloser, error) {
stream, err := m.openStream(id)
if err != nil {
return nil, err
}
// If the stream isn't closed, then it is already open somehow
stream.mu.Lock()
if stream.state != streamStateSynRecv && stream.state != streamStateClosed {
stream.mu.Unlock()
return nil, fmt.Errorf("Stream already open in bad state: %d", stream.state)
}
if stream.state == streamStateSynRecv {
// Fast track establishing since we already got the syn
stream.setState(streamStateEstablished)
stream.mu.Unlock()
}
if stream.state != streamStateEstablished {
// Go into the listening state
stream.setState(streamStateListen)
stream.mu.Unlock()
// Wait for the connection to establish
ACCEPT_ESTABLISH_LOOP:
for {
time.Sleep(50 * time.Millisecond)
stream.mu.Lock()
switch stream.state {
case streamStateListen:
stream.mu.Unlock()
case streamStateEstablished:
stream.mu.Unlock()
break ACCEPT_ESTABLISH_LOOP
default:
defer stream.mu.Unlock()
return nil, fmt.Errorf("Stream went to bad state: %d", stream.state)
}
}
}
// Send the ack down
if _, err := m.write(stream.id, muxPacketAck, nil); err != nil {
return nil, err
}
return stream, nil
}
// Dial opens a connection to the remote end using the given stream ID.
// An Accept on the remote end will only work with if the IDs match.
func (m *MuxConn) Dial(id byte) (io.ReadWriteCloser, error) {
stream, err := m.openStream(id)
if err != nil {
return nil, err
}
// If the stream isn't closed, then it is already open somehow
stream.mu.Lock()
if stream.state != streamStateClosed {
stream.mu.Unlock()
return nil, fmt.Errorf("Stream already open in bad state: %d", stream.state)
}
// Open a connection
if _, err := m.write(stream.id, muxPacketSyn, nil); err != nil {
return nil, err
}
stream.setState(streamStateSynSent)
stream.mu.Unlock()
for {
time.Sleep(50 * time.Millisecond)
stream.mu.Lock()
switch stream.state {
case streamStateSynSent:
stream.mu.Unlock()
case streamStateEstablished:
stream.mu.Unlock()
return stream, nil
default:
defer stream.mu.Unlock()
return nil, fmt.Errorf("Stream went to bad state: %d", stream.state)
}
}
}
func (m *MuxConn) openStream(id byte) (*Stream, error) {
m.mu.Lock()
defer m.mu.Unlock()
if _, ok := m.streams[id]; ok {
m.mu.Unlock()
return nil, fmt.Errorf("Stream %d already exists", id)
if stream, ok := m.streams[id]; ok {
return stream, nil
}
// Create the stream object and channel where data will be sent to
dataR, dataW := io.Pipe()
// Set the data channel so we can write to it.
m.streams[id] = dataW
// Unlock the lock so that the reader can access the stream writer.
m.mu.Unlock()
stream := &Stream{
id: id,
mux: m,
reader: dataR,
writer: dataW,
}
stream.setState(streamStateClosed)
return stream, nil
m.streams[id] = stream
return m.streams[id], nil
}
func (m *MuxConn) loop() {
defer m.Close()
var id byte
var packetType muxPacketType
var length int32
for {
var id byte
var length int32
if err := binary.Read(m.rwc, binary.BigEndian, &id); err != nil {
log.Printf("[ERR] Error reading stream ID: %s", err)
return
}
if err := binary.Read(m.rwc, binary.BigEndian, &packetType); err != nil {
log.Printf("[ERR] Error reading packet type: %s", err)
return
}
if err := binary.Read(m.rwc, binary.BigEndian, &length); err != nil {
log.Printf("[ERR] Error reading length: %s", err)
return
......@@ -103,44 +197,115 @@ func (m *MuxConn) loop() {
// TODO(mitchellh): probably would be better to re-use a buffer...
data := make([]byte, length)
if _, err := m.rwc.Read(data); err != nil {
log.Printf("[ERR] Error reading data: %s", err)
if length > 0 {
if _, err := m.rwc.Read(data); err != nil {
log.Printf("[ERR] Error reading data: %s", err)
return
}
}
stream, err := m.openStream(id)
if err != nil {
log.Printf("[ERR] Error opening stream %d: %s", id, err)
return
}
m.mu.RLock()
w, ok := m.streams[id]
if ok {
// Note that if this blocks, it'll block the whole read loop.
// Danger here... not sure how to handle it though.
w.Write(data)
switch packetType {
case muxPacketAck:
stream.mu.Lock()
if stream.state == streamStateSynSent {
stream.setState(streamStateEstablished)
} else {
log.Printf("[ERR] Ack received for stream in state: %d", stream.state)
}
stream.mu.Unlock()
case muxPacketSyn:
stream.mu.Lock()
switch stream.state {
case streamStateClosed:
stream.setState(streamStateSynRecv)
case streamStateListen:
stream.setState(streamStateEstablished)
default:
log.Printf("[ERR] Syn received for stream in state: %d", stream.state)
}
stream.mu.Unlock()
case muxPacketFin:
stream.mu.Lock()
stream.setState(streamStateClosed)
stream.writer.Close()
stream.mu.Unlock()
m.mu.Lock()
delete(m.streams, stream.id)
m.mu.Unlock()
case muxPacketData:
stream.mu.Lock()
if stream.state == streamStateEstablished {
stream.writer.Write(data)
} else {
log.Printf("[ERR] Data received for stream in state: %d", stream.state)
}
stream.mu.Unlock()
}
m.mu.RUnlock()
}
}
func (m *MuxConn) write(id byte, p []byte) (int, error) {
func (m *MuxConn) write(id byte, dataType muxPacketType, p []byte) (int, error) {
m.wlock.Lock()
defer m.wlock.Unlock()
if err := binary.Write(m.rwc, binary.BigEndian, id); err != nil {
return 0, err
}
if err := binary.Write(m.rwc, binary.BigEndian, byte(dataType)); err != nil {
return 0, err
}
if err := binary.Write(m.rwc, binary.BigEndian, int32(len(p))); err != nil {
return 0, err
}
if len(p) == 0 {
return 0, nil
}
return m.rwc.Write(p)
}
// Stream is a single stream of data and implements io.ReadWriteCloser
type Stream struct {
id byte
mux *MuxConn
reader io.Reader
id byte
mux *MuxConn
reader io.Reader
writer io.WriteCloser
state streamState
stateUpdated time.Time
mu sync.Mutex
}
type streamState byte
const (
streamStateClosed streamState = iota
streamStateListen
streamStateSynRecv
streamStateSynSent
streamStateEstablished
streamStateFinWait
)
func (s *Stream) Close() error {
// Not functional yet, does it ever have to be?
s.mu.Lock()
defer s.mu.Unlock()
if s.state != streamStateEstablished {
return fmt.Errorf("Stream in bad state: %d", s.state)
}
if _, err := s.mux.write(s.id, muxPacketFin, nil); err != nil {
return err
}
s.setState(streamStateClosed)
s.writer.Close()
return nil
}
......@@ -149,5 +314,10 @@ func (s *Stream) Read(p []byte) (int, error) {
}
func (s *Stream) Write(p []byte) (int, error) {
return s.mux.write(s.id, p)
return s.mux.write(s.id, muxPacketData, p)
}
func (s *Stream) setState(state streamState) {
s.state = state
s.stateUpdated = time.Now().UTC()
}
......@@ -56,24 +56,21 @@ func TestMuxConn(t *testing.T) {
// When the server is done
doneCh := make(chan struct{})
readyCh := make(chan struct{})
// The server side
go func() {
defer close(doneCh)
s0, err := server.Stream(0)
s0, err := server.Accept(0)
if err != nil {
t.Fatalf("err: %s", err)
}
s1, err := server.Stream(1)
s1, err := server.Dial(1)
if err != nil {
t.Fatalf("err: %s", err)
}
close(readyCh)
var wg sync.WaitGroup
wg.Add(2)
......@@ -96,19 +93,16 @@ func TestMuxConn(t *testing.T) {
wg.Wait()
}()
s0, err := client.Stream(0)
s0, err := client.Dial(0)
if err != nil {
t.Fatalf("err: %s", err)
}
s1, err := client.Stream(1)
s1, err := client.Accept(1)
if err != nil {
t.Fatalf("err: %s", err)
}
// Wait for the server to be ready
<-readyCh
if _, err := s0.Write([]byte("hello")); err != nil {
t.Fatalf("err: %s", err)
}
......@@ -124,8 +118,9 @@ func TestMuxConn_clientClosesStreams(t *testing.T) {
client, server := testMux(t)
defer client.Close()
defer server.Close()
go server.Accept(0)
s0, err := client.Stream(0)
s0, err := client.Dial(0)
if err != nil {
t.Fatalf("err: %s", err)
}
......@@ -146,8 +141,9 @@ func TestMuxConn_serverClosesStreams(t *testing.T) {
client, server := testMux(t)
defer client.Close()
defer server.Close()
go server.Accept(0)
s0, err := client.Stream(0)
s0, err := client.Dial(0)
if err != nil {
t.Fatalf("err: %s", err)
}
......
package rpc
import (
"fmt"
"github.com/mitchellh/packer/packer"
"net/rpc"
"sync/atomic"
)
// This keeps track of the endpoint ID to use when registering artifacts.
var endpointId uint64 = 0
// Registers the appropriate endpoint on an RPC server to serve an
// Artifact.
func RegisterArtifact(s *rpc.Server, a packer.Artifact) {
......@@ -82,10 +77,6 @@ func RegisterUi(s *rpc.Server, ui packer.Ui) {
// The endpoint name is returned.
func registerComponent(s *rpc.Server, name string, rcvr interface{}, id bool) string {
endpoint := name
if id {
fmt.Sprintf("%s.%d", endpoint, atomic.AddUint64(&endpointId, 1))
}
s.RegisterName(endpoint, rcvr)
return endpoint
}
......
package rpc
import (
"fmt"
"github.com/mitchellh/packer/packer"
"io"
"log"
"net/rpc"
"sync/atomic"
)
// Server represents an RPC server for Packer. This must be paired on
// the other side with a Client.
type Server struct {
endpointId uint64
rpcServer *rpc.Server
}
// NewServer returns a new Packer RPC server.
func NewServer() *Server {
return &Server{
endpointId: 0,
rpcServer: rpc.NewServer(),
}
}
func (s *Server) RegisterArtifact(a packer.Artifact) {
s.registerComponent("Artifact", &ArtifactServer{a}, false)
}
// ServeConn serves a single connection over the RPC server. It is up
// to the caller to obtain a proper io.ReadWriteCloser.
func (s *Server) ServeConn(conn io.ReadWriteCloser) {
mux := NewMuxConn(conn)
defer mux.Close()
// Get stream ID 0, which we always use as the stream for serving
// our RPC server on.
stream, err := mux.Accept(0)
if err != nil {
log.Printf("[ERR] Error retrieving stream for serving: %s", err)
return
}
s.rpcServer.ServeConn(stream)
}
// registerComponent registers a single Packer RPC component onto
// the RPC server. If id is true, then a unique ID number will be appended
// onto the end of the endpoint.
//
// The endpoint name is returned.
func (s *Server) registerComponent(name string, rcvr interface{}, id bool) string {
endpoint := name
if id {
fmt.Sprintf("%s.%d", endpoint, atomic.AddUint64(&s.endpointId, 1))
}
s.rpcServer.RegisterName(endpoint, rcvr)
return endpoint
}
......@@ -9,8 +9,7 @@ import (
// An implementation of packer.Ui where the Ui is actually executed
// over an RPC connection.
type Ui struct {
client *rpc.Client
endpoint string
client *rpc.Client
}
// UiServer wraps a packer.Ui implementation and makes it exportable
......@@ -26,12 +25,12 @@ type UiMachineArgs struct {
}
func (u *Ui) Ask(query string) (result string, err error) {
err = u.client.Call(u.endpoint+".Ask", query, &result)
err = u.client.Call("Ui.Ask", query, &result)
return
}
func (u *Ui) Error(message string) {
if err := u.client.Call(u.endpoint+".Error", message, new(interface{})); err != nil {
if err := u.client.Call("Ui.Error", message, new(interface{})); err != nil {
log.Printf("Error in Ui RPC call: %s", err)
}
}
......@@ -42,19 +41,19 @@ func (u *Ui) Machine(t string, args ...string) {
Args: args,
}
if err := u.client.Call(u.endpoint+".Machine", rpcArgs, new(interface{})); err != nil {
if err := u.client.Call("Ui.Machine", rpcArgs, new(interface{})); err != nil {
log.Printf("Error in Ui RPC call: %s", err)
}
}
func (u *Ui) Message(message string) {
if err := u.client.Call(u.endpoint+".Message", message, new(interface{})); err != nil {
if err := u.client.Call("Ui.Message", message, new(interface{})); err != nil {
log.Printf("Error in Ui RPC call: %s", err)
}
}
func (u *Ui) Say(message string) {
if err := u.client.Call(u.endpoint+".Say", message, new(interface{})); err != nil {
if err := u.client.Call("Ui.Say", message, new(interface{})); err != nil {
log.Printf("Error in Ui RPC call: %s", err)
}
}
......
......@@ -62,7 +62,7 @@ func TestUiRPC(t *testing.T) {
panic(err)
}
uiClient := &Ui{client: client, endpoint: "Ui"}
uiClient := &Ui{client: client}
// Basic error and say tests
result, err := uiClient.Ask("query")
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment