Commit 9ffa0b8e authored by Mitchell Hashimoto's avatar Mitchell Hashimoto

packer/rpc: no more muxconn

parent 062e86e2
...@@ -9,14 +9,14 @@ import ( ...@@ -9,14 +9,14 @@ import (
// over an RPC connection. // over an RPC connection.
type build struct { type build struct {
client *rpc.Client client *rpc.Client
mux *MuxConn mux *muxBroker
} }
// BuildServer wraps a packer.Build implementation and makes it exportable // BuildServer wraps a packer.Build implementation and makes it exportable
// as part of a Golang RPC server. // as part of a Golang RPC server.
type BuildServer struct { type BuildServer struct {
build packer.Build build packer.Build
mux *MuxConn mux *muxBroker
} }
type BuildPrepareResponse struct { type BuildPrepareResponse struct {
......
...@@ -10,14 +10,14 @@ import ( ...@@ -10,14 +10,14 @@ import (
// over an RPC connection. // over an RPC connection.
type builder struct { type builder struct {
client *rpc.Client client *rpc.Client
mux *MuxConn mux *muxBroker
} }
// BuilderServer wraps a packer.Builder implementation and makes it exportable // BuilderServer wraps a packer.Builder implementation and makes it exportable
// as part of a Golang RPC server. // as part of a Golang RPC server.
type BuilderServer struct { type BuilderServer struct {
builder packer.Builder builder packer.Builder
mux *MuxConn mux *muxBroker
} }
type BuilderPrepareArgs struct { type BuilderPrepareArgs struct {
......
...@@ -12,22 +12,29 @@ import ( ...@@ -12,22 +12,29 @@ import (
// Establishing a connection is up to the user, the Client can just // Establishing a connection is up to the user, the Client can just
// communicate over any ReadWriteCloser. // communicate over any ReadWriteCloser.
type Client struct { type Client struct {
mux *MuxConn mux *muxBroker
client *rpc.Client client *rpc.Client
closeMux bool closeMux bool
} }
func NewClient(rwc io.ReadWriteCloser) (*Client, error) { func NewClient(rwc io.ReadWriteCloser) (*Client, error) {
result, err := newClientWithMux(NewMuxConn(rwc), 0) mux, err := newMuxBrokerClient(rwc)
if err != nil { if err != nil {
return nil, err return nil, err
} }
go mux.Run()
result, err := newClientWithMux(mux, 0)
if err != nil {
mux.Close()
return nil, err
}
result.closeMux = true result.closeMux = true
return result, err return result, err
} }
func newClientWithMux(mux *MuxConn, streamId uint32) (*Client, error) { func newClientWithMux(mux *muxBroker, streamId uint32) (*Client, error) {
clientConn, err := mux.Dial(streamId) clientConn, err := mux.Dial(streamId)
if err != nil { if err != nil {
return nil, err return nil, err
......
...@@ -9,14 +9,14 @@ import ( ...@@ -9,14 +9,14 @@ import (
// command is actually executed over an RPC connection. // command is actually executed over an RPC connection.
type command struct { type command struct {
client *rpc.Client client *rpc.Client
mux *MuxConn mux *muxBroker
} }
// A CommandServer wraps a packer.Command and makes it exportable as part // A CommandServer wraps a packer.Command and makes it exportable as part
// of a Golang RPC server. // of a Golang RPC server.
type CommandServer struct { type CommandServer struct {
command packer.Command command packer.Command
mux *MuxConn mux *muxBroker
} }
type CommandRunArgs struct { type CommandRunArgs struct {
......
...@@ -12,14 +12,14 @@ import ( ...@@ -12,14 +12,14 @@ import (
// executed over an RPC connection. // executed over an RPC connection.
type communicator struct { type communicator struct {
client *rpc.Client client *rpc.Client
mux *MuxConn mux *muxBroker
} }
// CommunicatorServer wraps a packer.Communicator implementation and makes // CommunicatorServer wraps a packer.Communicator implementation and makes
// it exportable as part of a Golang RPC server. // it exportable as part of a Golang RPC server.
type CommunicatorServer struct { type CommunicatorServer struct {
c packer.Communicator c packer.Communicator
mux *MuxConn mux *muxBroker
} }
type CommandFinished struct { type CommandFinished struct {
...@@ -252,7 +252,7 @@ func (c *CommunicatorServer) Download(args *CommunicatorDownloadArgs, reply *int ...@@ -252,7 +252,7 @@ func (c *CommunicatorServer) Download(args *CommunicatorDownloadArgs, reply *int
return return
} }
func serveSingleCopy(name string, mux *MuxConn, id uint32, dst io.Writer, src io.Reader) { func serveSingleCopy(name string, mux *muxBroker, id uint32, dst io.Writer, src io.Reader) {
conn, err := mux.Accept(id) conn, err := mux.Accept(id)
if err != nil { if err != nil {
log.Printf("[ERR] '%s' accept error: %s", name, err) log.Printf("[ERR] '%s' accept error: %s", name, err)
......
...@@ -10,14 +10,14 @@ import ( ...@@ -10,14 +10,14 @@ import (
// where the actual environment is executed over an RPC connection. // where the actual environment is executed over an RPC connection.
type Environment struct { type Environment struct {
client *rpc.Client client *rpc.Client
mux *MuxConn mux *muxBroker
} }
// A EnvironmentServer wraps a packer.Environment and makes it exportable // A EnvironmentServer wraps a packer.Environment and makes it exportable
// as part of a Golang RPC server. // as part of a Golang RPC server.
type EnvironmentServer struct { type EnvironmentServer struct {
env packer.Environment env packer.Environment
mux *MuxConn mux *muxBroker
} }
type EnvironmentCliArgs struct { type EnvironmentCliArgs struct {
......
...@@ -10,14 +10,14 @@ import ( ...@@ -10,14 +10,14 @@ import (
// over an RPC connection. // over an RPC connection.
type hook struct { type hook struct {
client *rpc.Client client *rpc.Client
mux *MuxConn mux *muxBroker
} }
// HookServer wraps a packer.Hook implementation and makes it exportable // HookServer wraps a packer.Hook implementation and makes it exportable
// as part of a Golang RPC server. // as part of a Golang RPC server.
type HookServer struct { type HookServer struct {
hook packer.Hook hook packer.Hook
mux *MuxConn mux *muxBroker
} }
type HookRunArgs struct { type HookRunArgs struct {
......
...@@ -3,8 +3,10 @@ package rpc ...@@ -3,8 +3,10 @@ package rpc
import ( import (
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"io"
"net" "net"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/hashicorp/yamux" "github.com/hashicorp/yamux"
...@@ -16,6 +18,7 @@ import ( ...@@ -16,6 +18,7 @@ import (
// or accept a connection from, and the broker handles the details of // or accept a connection from, and the broker handles the details of
// holding these channels open while they're being negotiated. // holding these channels open while they're being negotiated.
type muxBroker struct { type muxBroker struct {
nextId uint32
session *yamux.Session session *yamux.Session
streams map[uint32]*muxBrokerPending streams map[uint32]*muxBrokerPending
...@@ -34,6 +37,24 @@ func newMuxBroker(s *yamux.Session) *muxBroker { ...@@ -34,6 +37,24 @@ func newMuxBroker(s *yamux.Session) *muxBroker {
} }
} }
func newMuxBrokerClient(rwc io.ReadWriteCloser) (*muxBroker, error) {
s, err := yamux.Client(rwc, nil)
if err != nil {
return nil, err
}
return newMuxBroker(s), nil
}
func newMuxBrokerServer(rwc io.ReadWriteCloser) (*muxBroker, error) {
s, err := yamux.Server(rwc, nil)
if err != nil {
return nil, err
}
return newMuxBroker(s), nil
}
// Accept accepts a connection by ID. // Accept accepts a connection by ID.
// //
// This should not be called multiple times with the same ID at one time. // This should not be called multiple times with the same ID at one time.
...@@ -60,6 +81,11 @@ func (m *muxBroker) Accept(id uint32) (net.Conn, error) { ...@@ -60,6 +81,11 @@ func (m *muxBroker) Accept(id uint32) (net.Conn, error) {
return c, nil return c, nil
} }
// Close closes the connection and all sub-connections.
func (m *muxBroker) Close() error {
return m.session.Close()
}
// Dial opens a connection by ID. // Dial opens a connection by ID.
func (m *muxBroker) Dial(id uint32) (net.Conn, error) { func (m *muxBroker) Dial(id uint32) (net.Conn, error) {
// Open the stream // Open the stream
...@@ -88,6 +114,11 @@ func (m *muxBroker) Dial(id uint32) (net.Conn, error) { ...@@ -88,6 +114,11 @@ func (m *muxBroker) Dial(id uint32) (net.Conn, error) {
return stream, nil return stream, nil
} }
// NextId returns a unique ID to use next.
func (m *muxBroker) NextId() uint32 {
return atomic.AddUint32(&m.nextId, 1)
}
// Run starts the brokering and should be executed in a goroutine, since it // Run starts the brokering and should be executed in a goroutine, since it
// blocks forever, or until the session closes. // blocks forever, or until the session closes.
func (m *muxBroker) Run() { func (m *muxBroker) Run() {
......
This diff is collapsed.
package rpc
import (
"io"
"net"
"sync"
"testing"
)
func readStream(t *testing.T, s io.Reader) string {
var data [1024]byte
n, err := s.Read(data[:])
if err != nil {
t.Fatalf("err: %s", err)
}
return string(data[0:n])
}
func testMux(t *testing.T) (client *MuxConn, server *MuxConn) {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("err: %s", err)
}
// Server side
doneCh := make(chan struct{})
go func() {
defer close(doneCh)
conn, err := l.Accept()
l.Close()
if err != nil {
t.Fatalf("err: %s", err)
}
server = NewMuxConn(conn)
}()
// Client side
conn, err := net.Dial("tcp", l.Addr().String())
if err != nil {
t.Fatalf("err: %s", err)
}
client = NewMuxConn(conn)
// Wait for the server
<-doneCh
return
}
func TestMuxConn(t *testing.T) {
client, server := testMux(t)
defer client.Close()
defer server.Close()
// When the server is done
doneCh := make(chan struct{})
// The server side
go func() {
defer close(doneCh)
s0, err := server.Accept(0)
if err != nil {
t.Fatalf("err: %s", err)
}
s1, err := server.Dial(1)
if err != nil {
t.Fatalf("err: %s", err)
}
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
defer s1.Close()
data := readStream(t, s1)
if data != "another" {
t.Fatalf("bad: %#v", data)
}
}()
go func() {
defer wg.Done()
defer s0.Close()
data := readStream(t, s0)
if data != "hello" {
t.Fatalf("bad: %#v", data)
}
}()
wg.Wait()
}()
s0, err := client.Dial(0)
if err != nil {
t.Fatalf("err: %s", err)
}
s1, err := client.Accept(1)
if err != nil {
t.Fatalf("err: %s", err)
}
if _, err := s0.Write([]byte("hello")); err != nil {
t.Fatalf("err: %s", err)
}
if _, err := s1.Write([]byte("another")); err != nil {
t.Fatalf("err: %s", err)
}
s0.Close()
s1.Close()
// Wait for the server to be done
<-doneCh
}
func TestMuxConn_lotsOfData(t *testing.T) {
client, server := testMux(t)
defer client.Close()
defer server.Close()
// When the server is done
doneCh := make(chan struct{})
// The server side
go func() {
defer close(doneCh)
s0, err := server.Accept(0)
if err != nil {
t.Fatalf("err: %s", err)
}
var data [1024]byte
for {
n, err := s0.Read(data[:])
if err == io.EOF {
break
}
dataString := string(data[0:n])
if dataString != "hello" {
t.Fatalf("bad: %#v", dataString)
}
}
s0.Close()
}()
s0, err := client.Dial(0)
if err != nil {
t.Fatalf("err: %s", err)
}
for i := 0; i < 4096*4; i++ {
if _, err := s0.Write([]byte("hello")); err != nil {
t.Fatalf("err: %s", err)
}
}
if err := s0.Close(); err != nil {
t.Fatalf("err: %s", err)
}
// Wait for the server to be done
<-doneCh
}
// This tests that even when the client end is closed, data can be
// read from the server.
func TestMuxConn_clientCloseRead(t *testing.T) {
client, server := testMux(t)
defer client.Close()
defer server.Close()
// This channel will be closed when we close
waitCh := make(chan struct{})
go func() {
conn, err := server.Accept(0)
if err != nil {
t.Fatalf("err: %s", err)
}
<-waitCh
_, err = conn.Write([]byte("foo"))
if err != nil {
t.Fatalf("err: %s", err)
}
conn.Close()
}()
s0, err := client.Dial(0)
if err != nil {
t.Fatalf("err: %s", err)
}
if err := s0.Close(); err != nil {
t.Fatalf("bad: %s", err)
}
// Close this to continue on on the server-side
close(waitCh)
var data [1024]byte
n, err := s0.Read(data[:])
if string(data[:n]) != "foo" {
t.Fatalf("bad: %#v", string(data[:n]))
}
}
func TestMuxConn_socketClose(t *testing.T) {
client, server := testMux(t)
defer client.Close()
defer server.Close()
go func() {
_, err := server.Accept(0)
if err != nil {
t.Fatalf("err: %s", err)
}
server.rwc.Close()
}()
s0, err := client.Dial(0)
if err != nil {
t.Fatalf("err: %s", err)
}
var data [1024]byte
_, err = s0.Read(data[:])
if err != io.EOF {
t.Fatalf("err: %s", err)
}
}
func TestMuxConn_clientClosesStreams(t *testing.T) {
client, server := testMux(t)
defer client.Close()
defer server.Close()
go func() {
conn, err := server.Accept(0)
if err != nil {
t.Fatalf("err: %s", err)
}
conn.Close()
}()
s0, err := client.Dial(0)
if err != nil {
t.Fatalf("err: %s", err)
}
var data [1024]byte
_, err = s0.Read(data[:])
if err != io.EOF {
t.Fatalf("err: %s", err)
}
}
func TestMuxConn_serverClosesStreams(t *testing.T) {
client, server := testMux(t)
defer client.Close()
defer server.Close()
go server.Accept(0)
s0, err := client.Dial(0)
if err != nil {
t.Fatalf("err: %s", err)
}
if err := server.Close(); err != nil {
t.Fatalf("err: %s", err)
}
// This should block forever since we never write onto this stream.
var data [1024]byte
_, err = s0.Read(data[:])
if err != io.EOF {
t.Fatalf("err: %s", err)
}
}
func TestMuxConnNextId(t *testing.T) {
client, server := testMux(t)
defer client.Close()
defer server.Close()
a := client.NextId()
b := client.NextId()
if a != 1 || b != 2 {
t.Fatalf("IDs should increment")
}
a = server.NextId()
b = server.NextId()
if a != 1 || b != 2 {
t.Fatalf("IDs should increment: %d %d", a, b)
}
}
...@@ -9,14 +9,14 @@ import ( ...@@ -9,14 +9,14 @@ import (
// executed over an RPC connection. // executed over an RPC connection.
type postProcessor struct { type postProcessor struct {
client *rpc.Client client *rpc.Client
mux *MuxConn mux *muxBroker
} }
// PostProcessorServer wraps a packer.PostProcessor implementation and makes it // PostProcessorServer wraps a packer.PostProcessor implementation and makes it
// exportable as part of a Golang RPC server. // exportable as part of a Golang RPC server.
type PostProcessorServer struct { type PostProcessorServer struct {
client *rpc.Client client *rpc.Client
mux *MuxConn mux *muxBroker
p packer.PostProcessor p packer.PostProcessor
} }
......
...@@ -10,14 +10,14 @@ import ( ...@@ -10,14 +10,14 @@ import (
// executed over an RPC connection. // executed over an RPC connection.
type provisioner struct { type provisioner struct {
client *rpc.Client client *rpc.Client
mux *MuxConn mux *muxBroker
} }
// ProvisionerServer wraps a packer.Provisioner implementation and makes it // ProvisionerServer wraps a packer.Provisioner implementation and makes it
// exportable as part of a Golang RPC server. // exportable as part of a Golang RPC server.
type ProvisionerServer struct { type ProvisionerServer struct {
p packer.Provisioner p packer.Provisioner
mux *MuxConn mux *muxBroker
} }
type ProvisionerPrepareArgs struct { type ProvisionerPrepareArgs struct {
......
...@@ -29,7 +29,7 @@ const ( ...@@ -29,7 +29,7 @@ const (
// Server represents an RPC server for Packer. This must be paired on // Server represents an RPC server for Packer. This must be paired on
// the other side with a Client. // the other side with a Client.
type Server struct { type Server struct {
mux *MuxConn mux *muxBroker
streamId uint32 streamId uint32
server *rpc.Server server *rpc.Server
closeMux bool closeMux bool
...@@ -37,12 +37,14 @@ type Server struct { ...@@ -37,12 +37,14 @@ type Server struct {
// NewServer returns a new Packer RPC server. // NewServer returns a new Packer RPC server.
func NewServer(conn io.ReadWriteCloser) *Server { func NewServer(conn io.ReadWriteCloser) *Server {
result := newServerWithMux(NewMuxConn(conn), 0) mux, _ := newMuxBrokerServer(conn)
result := newServerWithMux(mux, 0)
result.closeMux = true result.closeMux = true
go mux.Run()
return result return result
} }
func newServerWithMux(mux *MuxConn, streamId uint32) *Server { func newServerWithMux(mux *muxBroker, streamId uint32) *Server {
return &Server{ return &Server{
mux: mux, mux: mux,
streamId: streamId, streamId: streamId,
...@@ -140,11 +142,11 @@ func (s *Server) Serve() { ...@@ -140,11 +142,11 @@ func (s *Server) Serve() {
// Accept a connection on stream ID 0, which is always used for // Accept a connection on stream ID 0, which is always used for
// normal client to server connections. // normal client to server connections.
stream, err := s.mux.Accept(s.streamId) stream, err := s.mux.Accept(s.streamId)
defer stream.Close()
if err != nil { if err != nil {
log.Printf("[ERR] Error retrieving stream for serving: %s", err) log.Printf("[ERR] Error retrieving stream for serving: %s", err)
return return
} }
defer stream.Close()
var h codec.MsgpackHandle var h codec.MsgpackHandle
rpcCodec := codec.GoRpc.ServerCodec(stream, &h) rpcCodec := codec.GoRpc.ServerCodec(stream, &h)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment