Commit b25baa62 authored by Mitchell Hashimoto's avatar Mitchell Hashimoto

packer/rpc: fix a blocking issue

parent 57bde34c
......@@ -400,8 +400,6 @@ func (m *MuxConn) loop() {
stream.mu.Unlock()
case muxPacketData:
unlocked := false
stream.mu.Lock()
switch stream.state {
case streamStateFinWait1:
......@@ -409,28 +407,17 @@ func (m *MuxConn) loop() {
case streamStateFinWait2:
fallthrough
case streamStateEstablished:
if len(data) > 0 {
// Get a reference to the write channel while we have
// the lock because otherwise the field might change.
// We unlock early here because the write might block
// for a long time.
writeCh := stream.writeCh
stream.mu.Unlock()
unlocked = true
// Blocked write, this provides some backpressure on
// the connection if there is a lot of data incoming.
writeCh <- data
if len(data) > 0 && stream.writeCh != nil {
//log.Printf("[TRACE] %p: Stream %d (%s) WRITE-START", m, id, from)
stream.writeCh <- data
//log.Printf("[TRACE] %p: Stream %d (%s) WRITE-END", m, id, from)
}
default:
log.Printf("[ERR] Data received for stream in state: %d", stream.state)
}
if !unlocked {
stream.mu.Unlock()
}
}
}
}
func (m *MuxConn) write(from muxPacketFrom, id uint32, dataType muxPacketType, p []byte) (int, error) {
......@@ -516,6 +503,7 @@ func newStream(from muxPacketFrom, id uint32, m *MuxConn) *Stream {
go func() {
defer dataW.Close()
drain := false
for {
data := <-writeCh
if data == nil {
......@@ -524,8 +512,14 @@ func newStream(from muxPacketFrom, id uint32, m *MuxConn) *Stream {
return
}
if drain {
// We're draining, meaning we're just waiting for the
// write channel to close.
continue
}
if _, err := dataW.Write(data); err != nil {
return
drain = true
}
}
}()
......@@ -568,7 +562,10 @@ func (s *Stream) Write(p []byte) (int, error) {
}
func (s *Stream) closeWriter() {
if s.writeCh != nil {
s.writeCh <- nil
s.writeCh = nil
}
}
func (s *Stream) setState(state streamState) {
......@@ -594,6 +591,7 @@ func (s *Stream) waitState(target streamState) error {
delete(s.stateChange, stateCh)
}()
//log.Printf("[TRACE] %p: Stream %d (%s) waiting for state: %d", s.mux, s.id, s.from, target)
state := <-stateCh
if state == target {
return nil
......
......@@ -76,6 +76,7 @@ func TestMuxConn(t *testing.T) {
go func() {
defer wg.Done()
defer s1.Close()
data := readStream(t, s1)
if data != "another" {
t.Fatalf("bad: %#v", data)
......@@ -84,6 +85,7 @@ func TestMuxConn(t *testing.T) {
go func() {
defer wg.Done()
defer s0.Close()
data := readStream(t, s0)
if data != "hello" {
t.Fatalf("bad: %#v", data)
......@@ -110,6 +112,9 @@ func TestMuxConn(t *testing.T) {
t.Fatalf("err: %s", err)
}
s0.Close()
s1.Close()
// Wait for the server to be done
<-doneCh
}
......@@ -131,18 +136,20 @@ func TestMuxConn_lotsOfData(t *testing.T) {
t.Fatalf("err: %s", err)
}
var wg sync.WaitGroup
wg.Add(1)
var data [1024]byte
for {
n, err := s0.Read(data[:])
if err == io.EOF {
break
}
go func() {
defer wg.Done()
data := readStream(t, s0)
if data != "hello" {
t.Fatalf("bad: %#v", data)
dataString := string(data[0:n])
if dataString != "hello" {
t.Fatalf("bad: %#v", dataString)
}
}
}()
wg.Wait()
s0.Close()
}()
s0, err := client.Dial(0)
......@@ -156,6 +163,10 @@ func TestMuxConn_lotsOfData(t *testing.T) {
}
}
if err := s0.Close(); err != nil {
t.Fatalf("err: %s", err)
}
// Wait for the server to be done
<-doneCh
}
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment