Commit ec68a3fd authored by Mitchell Hashimoto's avatar Mitchell Hashimoto

packer/rpc: MuxConn can return next available stream ID

parent 171781c3
......@@ -18,8 +18,9 @@ import (
// are established using a subset of the TCP protocol. Only a subset is
// necessary since we assume ordering on the underlying RWC.
type MuxConn struct {
curId uint32
rwc io.ReadWriteCloser
streams map[byte]*Stream
streams map[uint32]*Stream
mu sync.RWMutex
wlock sync.Mutex
}
......@@ -36,7 +37,7 @@ const (
func NewMuxConn(rwc io.ReadWriteCloser) *MuxConn {
m := &MuxConn{
rwc: rwc,
streams: make(map[byte]*Stream),
streams: make(map[uint32]*Stream),
}
go m.loop()
......@@ -54,14 +55,14 @@ func (m *MuxConn) Close() error {
for _, w := range m.streams {
w.Close()
}
m.streams = make(map[byte]*Stream)
m.streams = make(map[uint32]*Stream)
return m.rwc.Close()
}
// Accept accepts a multiplexed connection with the given ID. This
// will block until a request is made to connect.
func (m *MuxConn) Accept(id byte) (io.ReadWriteCloser, error) {
func (m *MuxConn) Accept(id uint32) (io.ReadWriteCloser, error) {
stream, err := m.openStream(id)
if err != nil {
return nil, err
......@@ -113,7 +114,7 @@ func (m *MuxConn) Accept(id byte) (io.ReadWriteCloser, error) {
// Dial opens a connection to the remote end using the given stream ID.
// An Accept on the remote end will only work with if the IDs match.
func (m *MuxConn) Dial(id byte) (io.ReadWriteCloser, error) {
func (m *MuxConn) Dial(id uint32) (io.ReadWriteCloser, error) {
stream, err := m.openStream(id)
if err != nil {
return nil, err
......@@ -149,7 +150,22 @@ func (m *MuxConn) Dial(id byte) (io.ReadWriteCloser, error) {
}
}
func (m *MuxConn) openStream(id byte) (*Stream, error) {
// NextId returns the next available stream ID that isn't currently
// taken.
func (m *MuxConn) NextId() uint32 {
m.mu.Lock()
defer m.mu.Unlock()
for {
if _, ok := m.streams[m.curId]; !ok {
return m.curId
}
m.curId++
}
}
func (m *MuxConn) openStream(id uint32) (*Stream, error) {
m.mu.Lock()
defer m.mu.Unlock()
......@@ -176,7 +192,7 @@ func (m *MuxConn) openStream(id byte) (*Stream, error) {
func (m *MuxConn) loop() {
defer m.Close()
var id byte
var id uint32
var packetType muxPacketType
var length int32
for {
......@@ -249,7 +265,7 @@ func (m *MuxConn) loop() {
}
}
func (m *MuxConn) write(id byte, dataType muxPacketType, p []byte) (int, error) {
func (m *MuxConn) write(id uint32, dataType muxPacketType, p []byte) (int, error) {
m.wlock.Lock()
defer m.wlock.Unlock()
......@@ -270,7 +286,7 @@ func (m *MuxConn) write(id byte, dataType muxPacketType, p []byte) (int, error)
// Stream is a single stream of data and implements io.ReadWriteCloser
type Stream struct {
id byte
id uint32
mux *MuxConn
reader io.Reader
writer io.WriteCloser
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment