Commit ceb8984b authored by Kirill Smelkov's avatar Kirill Smelkov

go/neo/neonet: Factor-out decode of packet header into pktHeadDecode

Preparatory step before we introduce support for msgpack encoding.
parent dfec9278
......@@ -788,10 +788,17 @@ func (nl *NodeLink) serveRecv() {
// NOTE if nl.peerLink was just closed by tx->shutdown we'll get ErrNetClosing
pkt, err := nl.recvPkt()
//fmt.Printf("\n%p recvPkt -> %v, %v\n", nl, pkt, err)
if err != nil {
// pkt.ConnId -> Conn
var connId uint32
if err == nil {
connId, _, _, err = pktDecodeHead(pkt)
}
// on IO error framing over peerLink becomes broken
// so we shut down node link and all connections over it.
// Same if we cannot decode packet header.
if err != nil {
nl.errMu.Lock()
nl.errRecv = err
nl.errMu.Unlock()
......@@ -800,8 +807,7 @@ func (nl *NodeLink) serveRecv() {
return
}
// pkt.ConnId -> Conn
connId := packed.Ntoh32(pkt.Header().ConnId)
accept := false
nl.connMu.Lock()
......@@ -1034,12 +1040,14 @@ func (c *Conn) sendPkt(pkt *pktBuf) error {
func (c *Conn) sendPkt2(pkt *pktBuf) error {
// connId must be set to one associated with this connection
if pkt.Header().ConnId != packed.Hton32(c.connId) {
connID, _, _, err := pktDecodeHead(pkt)
if err != nil {
panic(fmt.Sprintf("Conn.sendPkt: bad packet: %s", err))
}
if connID != c.connId {
panic("Conn.sendPkt: connId wrong")
}
var err error
select {
case <-c.txdown:
return c.errSendShutdown()
......@@ -1120,11 +1128,17 @@ func (nl *NodeLink) serveSend() {
// sendPktDirect sends raw packet with appropriate connection ID directly via link.
func (c *Conn) sendPktDirect(pkt *pktBuf) error {
// set pkt connId associated with this connection
pkt.Header().ConnId = packed.Hton32(c.connId)
// connId must be set to one associated with this connection
connID, _, _, err := pktDecodeHead(pkt)
if err != nil {
panic(fmt.Sprintf("Conn.sendPkt: bad packet: %s", err))
}
if connID != c.connId {
panic("Conn.sendPkt: connId wrong")
}
// NOTE if n.peerLink was just closed by rx->shutdown we'll get ErrNetClosing
err := c.link.sendPkt(pkt)
err = c.link.sendPkt(pkt)
//fmt.Printf("sendPkt -> %v\n", err)
// on IO error framing over peerLink becomes broken
......@@ -1319,7 +1333,22 @@ func pktEncode(connId uint32, msg proto.Msg) *pktBuf {
return buf
}
// TODO msgUnpack
// pktDecodeHead decodes header of a packet.
func pktDecodeHead(pkt *pktBuf) (connID uint32, msgCode uint16, payload []byte, err error) {
if len(pkt.data) < proto.PktHeaderLen {
return 0, 0, nil, fmt.Errorf("packet too short")
}
pkth := pkt.Header()
connID = packed.Ntoh32(pkth.ConnId)
msgCode = packed.Ntoh16(pkth.MsgCode)
msgLen := packed.Ntoh32(pkth.MsgLen)
payload = pkt.Payload()
if len(payload) != int(msgLen) {
return 0, 0, nil, fmt.Errorf("len(payload) != msgLen")
}
return
}
// Recv receives message from the connection.
func (c *Conn) Recv() (proto.Msg, error) {
......@@ -1330,8 +1359,11 @@ func (c *Conn) Recv() (proto.Msg, error) {
defer pkt.Free()
// decode packet
pkth := pkt.Header()
msgCode := packed.Ntoh16(pkth.MsgCode)
_, msgCode, payload, err := pktDecodeHead(pkt)
if err != nil {
return nil, err
}
msgType := proto.MsgType(msgCode)
if msgType == nil {
err := fmt.Errorf("invalid msgCode (%d)", msgCode)
......@@ -1344,7 +1376,7 @@ func (c *Conn) Recv() (proto.Msg, error) {
// msg := reflect.NewAt(msgType, bufAlloc(msgType.Size())
_, err = msg.NEOMsgDecode(pkt.Payload())
_, err = msg.NEOMsgDecode(payload)
if err != nil {
return nil, c.err("decode", err) // XXX "decode:" is already in ErrDecodeOverflow
}
......@@ -1388,12 +1420,15 @@ func (c *Conn) Expect(msgv ...proto.Msg) (which int, err error) {
}
defer pkt.Free()
pkth := pkt.Header()
msgCode := packed.Ntoh16(pkth.MsgCode)
// decode packet
_, msgCode, payload, err := pktDecodeHead(pkt)
if err != nil {
return -1, err
}
for i, msg := range msgv {
if msg.NEOMsgCode() == msgCode {
_, err := msg.NEOMsgDecode(pkt.Payload())
_, err := msg.NEOMsgDecode(payload)
if err != nil {
return -1, c.err("decode", err)
}
......
......@@ -136,20 +136,20 @@ func (c *Conn) mkpkt(msgcode uint16, payload []byte) *pktBuf {
// Verify pktBuf is as expected.
func (t *T) xverifyPkt(pkt *pktBuf, connid uint32, msgcode uint16, payload []byte) {
errv := xerr.Errorv{}
h := pkt.Header()
pktConnID, pktMsgCode, pktPayload, err := pktDecodeHead(pkt)
exc.Raiseif(err)
// TODO include caller location
if packed.Ntoh32(h.ConnId) != connid {
errv.Appendf("header: unexpected connid %v (want %v)", packed.Ntoh32(h.ConnId), connid)
}
if packed.Ntoh16(h.MsgCode) != msgcode {
errv.Appendf("header: unexpected msgcode %v (want %v)", packed.Ntoh16(h.MsgCode), msgcode)
if pktConnID != connid {
errv.Appendf("header: unexpected connid %v (want %v)", pktConnID, connid)
}
if packed.Ntoh32(h.MsgLen) != uint32(len(payload)) {
errv.Appendf("header: unexpected msglen %v (want %v)", packed.Ntoh32(h.MsgLen), len(payload))
if pktMsgCode != msgcode {
errv.Appendf("header: unexpected msgcode %v (want %v)", pktMsgCode, msgcode)
}
if !bytes.Equal(pkt.Payload(), payload) {
if !bytes.Equal(pktPayload, payload) {
errv.Appendf("payload differ:\n%s",
pretty.Compare(string(payload), string(pkt.Payload())))
pretty.Compare(string(payload), string(pktPayload)))
}
exc.Raiseif(errv.Err())
......@@ -614,8 +614,9 @@ func _TestNodeLink(t *T) {
gox(wg, func(_ context.Context) {
pkt := xrecvPkt(c)
n := packed.Ntoh16(pkt.Header().MsgCode)
x := replyOrder[n]
_, msgCode, _, err := pktDecodeHead(pkt)
exc.Raiseif(err)
x := replyOrder[msgCode]
// wait before it is our turn & echo pkt back
<-x.start
......
// Copyright (C) 2016-2018 Nexedi SA and Contributors.
// Copyright (C) 2016-2021 Nexedi SA and Contributors.
// Kirill Smelkov <kirr@nexedi.com>
//
// This program is free software: you can Use, Study, Modify and Redistribute
......@@ -28,7 +28,6 @@ import (
"lab.nexedi.com/kirr/go123/xbytes"
"lab.nexedi.com/kirr/neo/go/internal/packed"
"lab.nexedi.com/kirr/neo/go/neo/proto"
)
......@@ -80,32 +79,29 @@ func (pkt *pktBuf) Free() {
// String dumps a packet in human-readable form.
func (pkt *pktBuf) String() string {
if len(pkt.data) < proto.PktHeaderLen {
return fmt.Sprintf("(! < PktHeaderLen) % x", pkt.data)
connID, msgCode, payload, err := pktDecodeHead(pkt)
if err != nil {
return fmt.Sprintf("(%s) % x", err, pkt.data)
}
h := pkt.Header()
s := fmt.Sprintf(".%d", packed.Ntoh32(h.ConnId))
s := fmt.Sprintf(".%d", connID)
msgCode := packed.Ntoh16(h.MsgCode)
msgLen := packed.Ntoh32(h.MsgLen)
data := pkt.Payload()
msgType := proto.MsgType(msgCode)
if msgType == nil {
s += fmt.Sprintf(" ? (%d) #%d [%d]: % x", msgCode, msgLen, len(data), data)
s += fmt.Sprintf(" ? (%d) [%d]: % x", msgCode, len(payload), payload)
return s
}
// XXX dup wrt Conn.Recv
msg := reflect.New(msgType).Interface().(proto.Msg)
n, err := msg.NEOMsgDecode(data)
n, err := msg.NEOMsgDecode(payload)
if err != nil {
s += fmt.Sprintf(" (%s) %v; #%d [%d]: % x", msgType.Name(), err, msgLen, len(data), data)
s += fmt.Sprintf(" (%s) %v; [%d]: % x", msgType.Name(), err, len(payload), payload)
} else {
s += fmt.Sprintf(" %s %v", msgType.Name(), msg) // XXX or %+v better?
if n < len(data) {
tail := data[n:]
if n < len(payload) {
tail := payload[n:]
s += fmt.Sprintf(" ; [%d]tail: % x", len(tail), tail)
}
}
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment