Commit c0d54d50 authored by Kirill Smelkov's avatar Kirill Smelkov

go/neo/proto: Introduce Encoding

Encoding specifies a way to encode/decode NEO messages and packets.
Current way of how messages were encoded is called to be 'N' encoding.

This patch:

- adds proto.Encoding type
- changes MsgEncode and MsgDecode to be methods of Encoding
- renames thigs that are specific to 'N' encoding to have 'N' suffix
- changes tests to run a testcase agains vector of provided encodings.
  That vector is currently only ['N'].
parent 39545b9c
...@@ -122,7 +122,8 @@ import ( ...@@ -122,7 +122,8 @@ import (
// //
// It is safe to use NodeLink from multiple goroutines simultaneously. // It is safe to use NodeLink from multiple goroutines simultaneously.
type NodeLink struct { type NodeLink struct {
peerLink net.Conn // raw conn to peer peerLink net.Conn // raw conn to peer
enc proto.Encoding // protocol encoding in use ('N')
connMu sync.Mutex connMu sync.Mutex
connTab map[uint32]*Conn // connId -> Conn associated with connId connTab map[uint32]*Conn // connId -> Conn associated with connId
...@@ -151,7 +152,7 @@ type NodeLink struct { ...@@ -151,7 +152,7 @@ type NodeLink struct {
axclosed atomic32 // whether CloseAccept was called axclosed atomic32 // whether CloseAccept was called
closed atomic32 // whether Close was called closed atomic32 // whether Close was called
rxbuf rbuf.RingBuf // buffer for reading from peerLink rxbufN rbuf.RingBuf // buffer for reading from peerLink (N encoding)
// scheduling optimization: whenever serveRecv sends to Conn.rxq // scheduling optimization: whenever serveRecv sends to Conn.rxq
// receiving side must ack here to receive G handoff. // receiving side must ack here to receive G handoff.
...@@ -246,6 +247,8 @@ const ( ...@@ -246,6 +247,8 @@ const (
// newNodeLink makes a new NodeLink from already established net.Conn . // newNodeLink makes a new NodeLink from already established net.Conn .
// //
// On the wire messages will be encoded according to enc.
//
// Role specifies how to treat our role on the link - either as client or // Role specifies how to treat our role on the link - either as client or
// server. The difference in between client and server roles is in: // server. The difference in between client and server roles is in:
// //
...@@ -258,7 +261,7 @@ const ( ...@@ -258,7 +261,7 @@ const (
// //
// Though it is possible to wrap just-established raw connection into NodeLink, // Though it is possible to wrap just-established raw connection into NodeLink,
// users should always use Handshake which performs protocol handshaking first. // users should always use Handshake which performs protocol handshaking first.
func newNodeLink(conn net.Conn, role _LinkRole) *NodeLink { func newNodeLink(conn net.Conn, enc proto.Encoding, role _LinkRole) *NodeLink {
var nextConnId uint32 var nextConnId uint32
switch role &^ linkFlagsMask { switch role &^ linkFlagsMask {
case _LinkServer: case _LinkServer:
...@@ -271,6 +274,7 @@ func newNodeLink(conn net.Conn, role _LinkRole) *NodeLink { ...@@ -271,6 +274,7 @@ func newNodeLink(conn net.Conn, role _LinkRole) *NodeLink {
nl := &NodeLink{ nl := &NodeLink{
peerLink: conn, peerLink: conn,
enc: enc,
connTab: map[uint32]*Conn{}, connTab: map[uint32]*Conn{},
nextConnId: nextConnId, nextConnId: nextConnId,
acceptq: make(chan *Conn), // XXX +buf ? acceptq: make(chan *Conn), // XXX +buf ?
...@@ -792,7 +796,7 @@ func (nl *NodeLink) serveRecv() { ...@@ -792,7 +796,7 @@ func (nl *NodeLink) serveRecv() {
// pkt.ConnId -> Conn // pkt.ConnId -> Conn
var connId uint32 var connId uint32
if err == nil { if err == nil {
connId, _, _, err = pktDecodeHead(pkt) connId, _, _, err = pktDecodeHead(nl.enc, pkt)
} }
// on IO error framing over peerLink becomes broken // on IO error framing over peerLink becomes broken
...@@ -1040,7 +1044,7 @@ func (c *Conn) sendPkt(pkt *pktBuf) error { ...@@ -1040,7 +1044,7 @@ func (c *Conn) sendPkt(pkt *pktBuf) error {
func (c *Conn) sendPkt2(pkt *pktBuf) error { func (c *Conn) sendPkt2(pkt *pktBuf) error {
// connId must be set to one associated with this connection // connId must be set to one associated with this connection
connID, _, _, err := pktDecodeHead(pkt) connID, _, _, err := pktDecodeHead(c.link.enc, pkt)
if err != nil { if err != nil {
panic(fmt.Sprintf("Conn.sendPkt: bad packet: %s", err)) panic(fmt.Sprintf("Conn.sendPkt: bad packet: %s", err))
} }
...@@ -1129,7 +1133,7 @@ func (nl *NodeLink) serveSend() { ...@@ -1129,7 +1133,7 @@ func (nl *NodeLink) serveSend() {
// sendPktDirect sends raw packet with appropriate connection ID directly via link. // sendPktDirect sends raw packet with appropriate connection ID directly via link.
func (c *Conn) sendPktDirect(pkt *pktBuf) error { func (c *Conn) sendPktDirect(pkt *pktBuf) error {
// connId must be set to one associated with this connection // connId must be set to one associated with this connection
connID, _, _, err := pktDecodeHead(pkt) connID, _, _, err := pktDecodeHead(c.link.enc, pkt)
if err != nil { if err != nil {
panic(fmt.Sprintf("Conn.sendPkt: bad packet: %s", err)) panic(fmt.Sprintf("Conn.sendPkt: bad packet: %s", err))
} }
...@@ -1166,7 +1170,7 @@ const dumpio = false ...@@ -1166,7 +1170,7 @@ const dumpio = false
func (nl *NodeLink) sendPkt(pkt *pktBuf) error { func (nl *NodeLink) sendPkt(pkt *pktBuf) error {
if dumpio { if dumpio {
// XXX -> log // XXX -> log
fmt.Printf("%v > %v: %v\n", nl.peerLink.LocalAddr(), nl.peerLink.RemoteAddr(), pkt) fmt.Printf("%s > %s: %s\n", nl.peerLink.LocalAddr(), nl.peerLink.RemoteAddr(), pktString(nl.enc, pkt))
//defer fmt.Printf("\t-> sendPkt err: %v\n", err) //defer fmt.Printf("\t-> sendPkt err: %v\n", err)
} }
...@@ -1183,8 +1187,29 @@ var ErrPktTooBig = errors.New("packet too big") ...@@ -1183,8 +1187,29 @@ var ErrPktTooBig = errors.New("packet too big")
// rx error, if any, is returned as is and is analyzed in serveRecv // rx error, if any, is returned as is and is analyzed in serveRecv
// //
// XXX dup in ZEO. // XXX dup in ZEO.
func (nl *NodeLink) recvPkt() (*pktBuf, error) { func (nl *NodeLink) recvPkt() (pkt *pktBuf, err error) {
// FIXME if rxbuf is non-empty - first look there for header and then if switch nl.enc {
case 'N': pkt, err = nl.recvPktN()
default: panic("bug")
}
if dumpio {
// XXX -> log
s := fmt.Sprintf("%s < %s: ", nl.peerLink.LocalAddr(), nl.peerLink.RemoteAddr())
if err != nil {
s += err.Error()
} else {
s += pktString(nl.enc, pkt)
}
fmt.Println(s)
}
return pkt, err
}
func (nl *NodeLink) recvPktN() (*pktBuf, error) {
// FIXME if rxbufN is non-empty - first look there for header and then if
// we know size -> allocate pkt with that size. // we know size -> allocate pkt with that size.
pkt := pktAlloc(4096) pkt := pktAlloc(4096)
// len=4K but cap can be more since pkt is from pool - use all space to buffer reads // len=4K but cap can be more since pkt is from pool - use all space to buffer reads
...@@ -1194,35 +1219,35 @@ func (nl *NodeLink) recvPkt() (*pktBuf, error) { ...@@ -1194,35 +1219,35 @@ func (nl *NodeLink) recvPkt() (*pktBuf, error) {
n := 0 // number of pkt bytes obtained so far n := 0 // number of pkt bytes obtained so far
// next packet could be already prefetched in part by previous read // next packet could be already prefetched in part by previous read
if nl.rxbuf.Len() > 0 { if nl.rxbufN.Len() > 0 {
δn, _ := nl.rxbuf.Read(data[:proto.PktHeaderLen]) δn, _ := nl.rxbufN.Read(data[:proto.PktHeaderLenN])
n += δn n += δn
} }
// first read to read pkt header and hopefully rest of packet in 1 syscall // first read to read pkt header and hopefully rest of packet in 1 syscall
if n < proto.PktHeaderLen { if n < proto.PktHeaderLenN {
δn, err := io.ReadAtLeast(nl.peerLink, data[n:], proto.PktHeaderLen - n) δn, err := io.ReadAtLeast(nl.peerLink, data[n:], proto.PktHeaderLenN - n)
if err != nil { if err != nil {
return nil, err return nil, err
} }
n += δn n += δn
} }
pkth := pkt.Header() pkth := pkt.HeaderN()
msgLen := packed.Ntoh32(pkth.MsgLen) msgLen := packed.Ntoh32(pkth.MsgLen)
if msgLen > proto.PktMaxSize - proto.PktHeaderLen { if msgLen > proto.PktMaxSize - proto.PktHeaderLenN {
return nil, ErrPktTooBig return nil, ErrPktTooBig
} }
pktLen := int(proto.PktHeaderLen + msgLen) // whole packet length pktLen := int(proto.PktHeaderLenN + msgLen) // whole packet length
// resize data if we don't have enough room in it // resize data if we don't have enough room in it
data = xbytes.Resize(data, pktLen) data = xbytes.Resize(data, pktLen)
data = data[:cap(data)] data = data[:cap(data)]
// we might have more data already prefetched in rxbuf // we might have more data already prefetched in rxbufN
if nl.rxbuf.Len() > 0 { if nl.rxbufN.Len() > 0 {
δn, _ := nl.rxbuf.Read(data[n:pktLen]) δn, _ := nl.rxbufN.Read(data[n:pktLen])
n += δn n += δn
} }
...@@ -1235,20 +1260,15 @@ func (nl *NodeLink) recvPkt() (*pktBuf, error) { ...@@ -1235,20 +1260,15 @@ func (nl *NodeLink) recvPkt() (*pktBuf, error) {
n += δn n += δn
} }
// put overread data into rxbuf for next reader // put overread data into rxbufN for next reader
if n > pktLen { if n > pktLen {
nl.rxbuf.Write(data[pktLen:n]) nl.rxbufN.Write(data[pktLen:n])
} }
// fixup data/pkt // fixup data/pkt
data = data[:pktLen] data = data[:pktLen]
pkt.data = data pkt.data = data
if dumpio {
// XXX -> log
fmt.Printf("%v < %v: %v\n", nl.peerLink.LocalAddr(), nl.peerLink.RemoteAddr(), pkt)
}
return pkt, nil return pkt, nil
} }
...@@ -1320,29 +1340,51 @@ func (c *Conn) err(op string, e error) error { ...@@ -1320,29 +1340,51 @@ func (c *Conn) err(op string, e error) error {
// ---- exchange of messages ---- // ---- exchange of messages ----
// pktEncode allocates pktBuf and encodes msg into it. // pktEncode allocates pktBuf and encodes msg into it.
func pktEncode(connId uint32, msg proto.Msg) *pktBuf { func pktEncode(e proto.Encoding, connId uint32, msg proto.Msg) *pktBuf {
l := proto.MsgEncodedLen(msg) switch e {
buf := pktAlloc(proto.PktHeaderLen + l) case 'N': return pktEncodeN(connId, msg)
default: panic("bug")
}
}
h := buf.Header() // pktDecodeHead decodes header of a packet.
h.ConnId = packed.Hton32(connId) func pktDecodeHead(e proto.Encoding, pkt *pktBuf) (connID uint32, msgCode uint16, payload []byte, err error) {
switch e {
case 'N': connID, msgCode, payload, err = pktDecodeHeadN(pkt)
default: panic("bug")
}
if err != nil {
err = fmt.Errorf("%c: decode header: %s", e, err)
}
return connID, msgCode, payload, err
}
func pktEncodeN(connId uint32, msg proto.Msg) *pktBuf {
const enc = proto.Encoding('N')
l := enc.MsgEncodedLen(msg)
buf := pktAlloc(proto.PktHeaderLenN + l)
h := buf.HeaderN()
h.ConnId = packed.Hton32(connId)
h.MsgCode = packed.Hton16(proto.MsgCode(msg)) h.MsgCode = packed.Hton16(proto.MsgCode(msg))
h.MsgLen = packed.Hton32(uint32(l)) // XXX casting: think again h.MsgLen = packed.Hton32(uint32(l)) // XXX casting: think again
proto.MsgEncode(msg, buf.Payload()) enc.MsgEncode(msg, buf.PayloadN())
return buf return buf
} }
// pktDecodeHead decodes header of a packet. func pktDecodeHeadN(pkt *pktBuf) (connID uint32, msgCode uint16, payload []byte, err error) {
func pktDecodeHead(pkt *pktBuf) (connID uint32, msgCode uint16, payload []byte, err error) { if len(pkt.data) < proto.PktHeaderLenN {
if len(pkt.data) < proto.PktHeaderLen {
return 0, 0, nil, fmt.Errorf("packet too short") return 0, 0, nil, fmt.Errorf("packet too short")
} }
pkth := pkt.Header() pkth := pkt.HeaderN()
connID = packed.Ntoh32(pkth.ConnId) connID = packed.Ntoh32(pkth.ConnId)
msgCode = packed.Ntoh16(pkth.MsgCode) msgCode = packed.Ntoh16(pkth.MsgCode)
msgLen := packed.Ntoh32(pkth.MsgLen) msgLen := packed.Ntoh32(pkth.MsgLen)
payload = pkt.Payload() payload = pkt.PayloadN()
if len(payload) != int(msgLen) { if len(payload) != int(msgLen) {
return 0, 0, nil, fmt.Errorf("len(payload) != msgLen") return 0, 0, nil, fmt.Errorf("len(payload) != msgLen")
} }
...@@ -1359,7 +1401,7 @@ func (c *Conn) Recv() (proto.Msg, error) { ...@@ -1359,7 +1401,7 @@ func (c *Conn) Recv() (proto.Msg, error) {
defer pkt.Free() defer pkt.Free()
// decode packet // decode packet
_, msgCode, payload, err := pktDecodeHead(pkt) _, msgCode, payload, err := pktDecodeHead(c.link.enc, pkt)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -1376,7 +1418,7 @@ func (c *Conn) Recv() (proto.Msg, error) { ...@@ -1376,7 +1418,7 @@ func (c *Conn) Recv() (proto.Msg, error) {
// msg := reflect.NewAt(msgType, bufAlloc(msgType.Size()) // msg := reflect.NewAt(msgType, bufAlloc(msgType.Size())
_, err = proto.MsgDecode(msg, payload) _, err = c.link.enc.MsgDecode(msg, payload)
if err != nil { if err != nil {
return nil, c.err("decode", err) // XXX "decode:" is already in ErrDecodeOverflow return nil, c.err("decode", err) // XXX "decode:" is already in ErrDecodeOverflow
} }
...@@ -1390,14 +1432,14 @@ func (c *Conn) Recv() (proto.Msg, error) { ...@@ -1390,14 +1432,14 @@ func (c *Conn) Recv() (proto.Msg, error) {
// //
// it is ok to call sendMsg in parallel with serveSend. XXX link to sendPktDirect for rationale? // it is ok to call sendMsg in parallel with serveSend. XXX link to sendPktDirect for rationale?
func (link *NodeLink) sendMsg(connId uint32, msg proto.Msg) error { func (link *NodeLink) sendMsg(connId uint32, msg proto.Msg) error {
buf := pktEncode(connId, msg) buf := pktEncode(link.enc, connId, msg)
return link.sendPkt(buf) // XXX more context in err? (msg type) return link.sendPkt(buf) // XXX more context in err? (msg type)
// FIXME ^^^ shutdown whole link on error // FIXME ^^^ shutdown whole link on error
} }
// Send sends message over the connection. // Send sends message over the connection.
func (c *Conn) Send(msg proto.Msg) error { func (c *Conn) Send(msg proto.Msg) error {
buf := pktEncode(c.connId, msg) buf := pktEncode(c.link.enc, c.connId, msg)
return c.sendPkt(buf) // XXX more context in err? (msg type) return c.sendPkt(buf) // XXX more context in err? (msg type)
} }
...@@ -1421,14 +1463,14 @@ func (c *Conn) Expect(msgv ...proto.Msg) (which int, err error) { ...@@ -1421,14 +1463,14 @@ func (c *Conn) Expect(msgv ...proto.Msg) (which int, err error) {
defer pkt.Free() defer pkt.Free()
// decode packet // decode packet
_, msgCode, payload, err := pktDecodeHead(pkt) _, msgCode, payload, err := pktDecodeHead(c.link.enc, pkt)
if err != nil { if err != nil {
return -1, err return -1, err
} }
for i, msg := range msgv { for i, msg := range msgv {
if proto.MsgCode(msg) == msgCode { if proto.MsgCode(msg) == msgCode {
_, err := proto.MsgDecode(msg, payload) _, err := c.link.enc.MsgDecode(msg, payload)
if err != nil { if err != nil {
return -1, c.err("decode", err) return -1, c.err("decode", err)
} }
......
...@@ -22,6 +22,7 @@ package neonet ...@@ -22,6 +22,7 @@ package neonet
import ( import (
"bytes" "bytes"
"context" "context"
"fmt"
"io" "io"
"net" "net"
"reflect" "reflect"
...@@ -45,16 +46,25 @@ import ( ...@@ -45,16 +46,25 @@ import (
// T is neonet testing environment. // T is neonet testing environment.
type T struct { type T struct {
*testing.T *testing.T
enc proto.Encoding // encoding to use for messages exchange
} }
// Verify tests f for all possible environments. // Verify tests f for all possible environments.
func Verify(t *testing.T, f func(*T)) { func Verify(t *testing.T, f func(*T)) {
f(&T{t}) // for each encoding
for _, enc := range []proto.Encoding{'N'} {
t.Run(fmt.Sprintf("enc=%c", enc), func(t *testing.T) {
f(&T{t, enc})
})
}
} }
// bin returns payload for raw binary data as it would-be encoded in t. // bin returns payload for raw binary data as it would-be encoded by t.enc .
func (t *T) bin(data string) []byte { func (t *T) bin(data string) []byte {
return []byte(data) switch t.enc {
case 'N': return []byte(data)
default: panic("bug")
}
} }
...@@ -118,26 +128,32 @@ func xconnError(err error) error { ...@@ -118,26 +128,32 @@ func xconnError(err error) error {
} }
// Prepare pktBuf with content. // Prepare pktBuf with content.
func _mkpkt(connid uint32, msgcode uint16, payload []byte) *pktBuf { func _mkpkt(enc proto.Encoding, connid uint32, msgcode uint16, payload []byte) *pktBuf {
pkt := &pktBuf{make([]byte, proto.PktHeaderLen+len(payload))} switch enc {
h := pkt.Header() case 'N':
h.ConnId = packed.Hton32(connid) pkt := &pktBuf{make([]byte, proto.PktHeaderLenN+len(payload))}
h.MsgCode = packed.Hton16(msgcode) h := pkt.HeaderN()
h.MsgLen = packed.Hton32(uint32(len(payload))) h.ConnId = packed.Hton32(connid)
copy(pkt.Payload(), payload) h.MsgCode = packed.Hton16(msgcode)
return pkt h.MsgLen = packed.Hton32(uint32(len(payload)))
copy(pkt.PayloadN(), payload)
return pkt
default:
panic("bug")
}
} }
func (c *Conn) mkpkt(msgcode uint16, payload []byte) *pktBuf { func (c *Conn) mkpkt(msgcode uint16, payload []byte) *pktBuf {
// in Conn exchange connid is automatically set by Conn.sendPkt // in Conn exchange connid is automatically set by Conn.sendPkt
return _mkpkt(c.connId, msgcode, payload) return _mkpkt(c.link.enc, c.connId, msgcode, payload)
} }
// Verify pktBuf is as expected. // Verify pktBuf is as expected.
func (t *T) xverifyPkt(pkt *pktBuf, connid uint32, msgcode uint16, payload []byte) { func (t *T) xverifyPkt(pkt *pktBuf, connid uint32, msgcode uint16, payload []byte) {
errv := xerr.Errorv{} errv := xerr.Errorv{}
pktConnID, pktMsgCode, pktPayload, err := pktDecodeHead(pkt) pktConnID, pktMsgCode, pktPayload, err := pktDecodeHead(t.enc, pkt)
exc.Raiseif(err) exc.Raiseif(err)
// TODO include caller location // TODO include caller location
...@@ -157,8 +173,8 @@ func (t *T) xverifyPkt(pkt *pktBuf, connid uint32, msgcode uint16, payload []byt ...@@ -157,8 +173,8 @@ func (t *T) xverifyPkt(pkt *pktBuf, connid uint32, msgcode uint16, payload []byt
// Verify pktBuf to match expected message. // Verify pktBuf to match expected message.
func (t *T) xverifyPktMsg(pkt *pktBuf, connid uint32, msg proto.Msg) { func (t *T) xverifyPktMsg(pkt *pktBuf, connid uint32, msg proto.Msg) {
data := make([]byte, proto.MsgEncodedLen(msg)) data := make([]byte, t.enc.MsgEncodedLen(msg))
proto.MsgEncode(msg, data) t.enc.MsgEncode(msg, data)
t.xverifyPkt(pkt, connid, proto.MsgCode(msg), data) t.xverifyPkt(pkt, connid, proto.MsgCode(msg), data)
} }
...@@ -176,11 +192,11 @@ func tdelay() { ...@@ -176,11 +192,11 @@ func tdelay() {
time.Sleep(1 * time.Millisecond) time.Sleep(1 * time.Millisecond)
} }
// create NodeLinks connected via net.Pipe // create NodeLinks connected via net.Pipe; messages are encoded via t.enc.
func (t *T) _nodeLinkPipe(flags1, flags2 _LinkRole) (nl1, nl2 *NodeLink) { func (t *T) _nodeLinkPipe(flags1, flags2 _LinkRole) (nl1, nl2 *NodeLink) {
node1, node2 := net.Pipe() node1, node2 := net.Pipe()
nl1 = newNodeLink(node1, _LinkClient|flags1) nl1 = newNodeLink(node1, t.enc, _LinkClient|flags1)
nl2 = newNodeLink(node2, _LinkServer|flags2) nl2 = newNodeLink(node2, t.enc, _LinkServer|flags2)
return nl1, nl2 return nl1, nl2
} }
...@@ -289,7 +305,7 @@ func _TestNodeLink(t *T) { ...@@ -289,7 +305,7 @@ func _TestNodeLink(t *T) {
okch := make(chan int, 2) okch := make(chan int, 2)
gox(wg, func(_ context.Context) { gox(wg, func(_ context.Context) {
// send ping; wait for pong // send ping; wait for pong
pkt := _mkpkt(1, 2, b("ping")) pkt := _mkpkt(t.enc, 1, 2, b("ping"))
xsendPkt(nl1, pkt) xsendPkt(nl1, pkt)
pkt = xrecvPkt(nl1) pkt = xrecvPkt(nl1)
t.xverifyPkt(pkt, 3, 4, b("pong")) t.xverifyPkt(pkt, 3, 4, b("pong"))
...@@ -299,7 +315,7 @@ func _TestNodeLink(t *T) { ...@@ -299,7 +315,7 @@ func _TestNodeLink(t *T) {
// wait for ping; send pong // wait for ping; send pong
pkt = xrecvPkt(nl2) pkt = xrecvPkt(nl2)
t.xverifyPkt(pkt, 1, 2, b("ping")) t.xverifyPkt(pkt, 1, 2, b("ping"))
pkt = _mkpkt(3, 4, b("pong")) pkt = _mkpkt(t.enc, 3, 4, b("pong"))
xsendPkt(nl2, pkt) xsendPkt(nl2, pkt)
okch <- 2 okch <- 2
}) })
...@@ -614,7 +630,7 @@ func _TestNodeLink(t *T) { ...@@ -614,7 +630,7 @@ func _TestNodeLink(t *T) {
gox(wg, func(_ context.Context) { gox(wg, func(_ context.Context) {
pkt := xrecvPkt(c) pkt := xrecvPkt(c)
_, msgCode, _, err := pktDecodeHead(pkt) _, msgCode, _, err := pktDecodeHead(t.enc, pkt)
exc.Raiseif(err) exc.Raiseif(err)
x := replyOrder[msgCode] x := replyOrder[msgCode]
......
...@@ -81,7 +81,8 @@ func handshakeClient(ctx context.Context, conn net.Conn, version uint32) (*NodeL ...@@ -81,7 +81,8 @@ func handshakeClient(ctx context.Context, conn net.Conn, version uint32) (*NodeL
if err != nil { if err != nil {
return nil, err return nil, err
} }
return newNodeLink(conn, _LinkClient), nil enc := proto.Encoding('N')
return newNodeLink(conn, enc, _LinkClient), nil
} }
// handshakeServer implements server-side NEO protocol handshake just after raw // handshakeServer implements server-side NEO protocol handshake just after raw
...@@ -96,7 +97,8 @@ func handshakeServer(ctx context.Context, conn net.Conn, version uint32) (*NodeL ...@@ -96,7 +97,8 @@ func handshakeServer(ctx context.Context, conn net.Conn, version uint32) (*NodeL
if err != nil { if err != nil {
return nil, err return nil, err
} }
return newNodeLink(conn, _LinkServer), nil enc := proto.Encoding('N')
return newNodeLink(conn, enc, _LinkServer), nil
} }
func _handshakeClient(ctx context.Context, conn net.Conn, version uint32) (err error) { func _handshakeClient(ctx context.Context, conn net.Conn, version uint32) (err error) {
......
...@@ -38,16 +38,16 @@ type pktBuf struct { ...@@ -38,16 +38,16 @@ type pktBuf struct {
data []byte // whole packet data including all headers data []byte // whole packet data including all headers
} }
// Header returns pointer to packet header. // HeaderN returns pointer to packet header in 'N'-encoding.
func (pkt *pktBuf) Header() *proto.PktHeader { func (pkt *pktBuf) HeaderN() *proto.PktHeaderN {
// NOTE no need to check len(.data) < PktHeader: // NOTE no need to check len(.data) < PktHeaderN:
// .data is always allocated with cap >= PktHeaderLen. // .data is always allocated with cap >= PktHeaderLenN.
return (*proto.PktHeader)(unsafe.Pointer(&pkt.data[0])) return (*proto.PktHeaderN)(unsafe.Pointer(&pkt.data[0]))
} }
// Payload returns []byte representing packet payload. // PayloadN returns []byte representing packet payload in 'N'-encoding.
func (pkt *pktBuf) Payload() []byte { func (pkt *pktBuf) PayloadN() []byte {
return pkt.data[proto.PktHeaderLen:] return pkt.data[proto.PktHeaderLenN:]
} }
// ---- pktBuf freelist ---- // ---- pktBuf freelist ----
...@@ -59,11 +59,11 @@ var pktBufPool = sync.Pool{New: func() interface{} { ...@@ -59,11 +59,11 @@ var pktBufPool = sync.Pool{New: func() interface{} {
// pktAlloc allocates pktBuf with len=n. // pktAlloc allocates pktBuf with len=n.
func pktAlloc(n int) *pktBuf { func pktAlloc(n int) *pktBuf {
// make sure cap >= PktHeaderLen. // make sure cap >= PktHeaderLenN.
// see Header for why // see HeaderN for why
l := n l := n
if l < proto.PktHeaderLen { if l < proto.PktHeaderLenN {
l = proto.PktHeaderLen l = proto.PktHeaderLenN
} }
pkt := pktBufPool.Get().(*pktBuf) pkt := pktBufPool.Get().(*pktBuf)
pkt.data = xbytes.Realloc(pkt.data, l)[:n] pkt.data = xbytes.Realloc(pkt.data, l)[:n]
...@@ -78,9 +78,9 @@ func (pkt *pktBuf) Free() { ...@@ -78,9 +78,9 @@ func (pkt *pktBuf) Free() {
// ---- pktBuf dump ---- // ---- pktBuf dump ----
// String dumps a packet in human-readable form. // pktString dumps a packet in human-readable form.
func (pkt *pktBuf) String() string { func pktString(e proto.Encoding, pkt *pktBuf) string {
connID, msgCode, payload, err := pktDecodeHead(pkt) connID, msgCode, payload, err := pktDecodeHead(e, pkt)
if err != nil { if err != nil {
return fmt.Sprintf("(%s) % x", err, pkt.data) return fmt.Sprintf("(%s) % x", err, pkt.data)
} }
...@@ -95,7 +95,7 @@ func (pkt *pktBuf) String() string { ...@@ -95,7 +95,7 @@ func (pkt *pktBuf) String() string {
// XXX dup wrt Conn.Recv // XXX dup wrt Conn.Recv
msg := reflect.New(msgType).Interface().(proto.Msg) msg := reflect.New(msgType).Interface().(proto.Msg)
n, err := proto.MsgDecode(msg, payload) n, err := e.MsgDecode(msg, payload)
if err != nil { if err != nil {
s += fmt.Sprintf(" (%s) %v; [%d]: % x", msgType.Name(), err, len(payload), payload) s += fmt.Sprintf(" (%s) %v; [%d]: % x", msgType.Name(), err, len(payload), payload)
} else { } else {
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
// ID of subconnection multiplexed on top of the underlying link, carried // ID of subconnection multiplexed on top of the underlying link, carried
// message code and message data. // message code and message data.
// //
// PktHeader describes packet header structure. // PktHeaderN describes packet header structure in 'N' encoding.
// //
// Messages are represented by corresponding types that all implement Msg interface. // Messages are represented by corresponding types that all implement Msg interface.
// //
...@@ -79,8 +79,8 @@ const ( ...@@ -79,8 +79,8 @@ const (
// the high order byte 0 is different from TLS Handshake (0x16). // the high order byte 0 is different from TLS Handshake (0x16).
Version = 6 Version = 6
// length of packet header // length of packet header in 'N'-encoding
PktHeaderLen = 10 // = unsafe.Sizeof(PktHeader{}), but latter gives typed constant (uintptr) PktHeaderLenN = 10 // = unsafe.Sizeof(PktHeaderN{}), but latter gives typed constant (uintptr)
// packets larger than PktMaxSize are not allowed. // packets larger than PktMaxSize are not allowed.
// this helps to avoid out-of-memory error on packets with corrupt message len. // this helps to avoid out-of-memory error on packets with corrupt message len.
...@@ -95,12 +95,12 @@ const ( ...@@ -95,12 +95,12 @@ const (
INVALID_OID zodb.Oid = 1<<64 - 1 INVALID_OID zodb.Oid = 1<<64 - 1
) )
// PktHeader represents header of a raw packet. // PktHeaderN represents header of a raw packet in 'N'-encoding.
// //
// A packet contains connection ID and message. // A packet contains connection ID and message.
// //
//neo:proto typeonly //neo:proto typeonly
type PktHeader struct { type PktHeaderN struct {
ConnId packed.BE32 // NOTE is .msgid in py ConnId packed.BE32 // NOTE is .msgid in py
MsgCode packed.BE16 // payload message code MsgCode packed.BE16 // payload message code
MsgLen packed.BE32 // payload message length (excluding packet header) MsgLen packed.BE32 // payload message length (excluding packet header)
...@@ -114,33 +114,50 @@ type Msg interface { ...@@ -114,33 +114,50 @@ type Msg interface {
// on the wire. // on the wire.
neoMsgCode() uint16 neoMsgCode() uint16
// neoMsgEncodedLen returns how much space is needed to encode current message payload. // for encoding E:
neoMsgEncodedLen() int //
// - neoMsgEncodedLen<E> returns how much space is needed to encode current message payload via E encoding.
// neoMsgEncode encodes current message state into buf. //
// - neoMsgEncode<E> encodes current message state into buf via E encoding.
// //
// len(buf) must be >= neoMsgEncodedLen(). // len(buf) must be >= neoMsgEncodedLen<E>().
neoMsgEncode(buf []byte) //
// - neoMsgDecode<E> decodes data via E encoding into message in-place.
// N encoding (original struct-based encoding)
neoMsgEncodedLenN() int
neoMsgEncodeN(buf []byte)
neoMsgDecodeN(data []byte) (nread int, err error)
// neoMsgDecode decodes data into message in-place.
neoMsgDecode(data []byte) (nread int, err error)
} }
// MsgEncodedLen returns how much space is needed to encode msg payload. // Encoding represents messages encoding.
func MsgEncodedLen(msg Msg) int { type Encoding byte
return msg.neoMsgEncodedLen()
// MsgEncodedLen returns how much space is needed to encode msg payload via encoding e.
func (e Encoding) MsgEncodedLen(msg Msg) int {
switch e {
default: panic("bug")
case 'N': return msg.neoMsgEncodedLenN()
}
} }
// MsgEncode encodes msg state into buf. // MsgEncode encodes msg state into buf via encoding e.
// //
// len(buf) must be >= MsgEncodedLen(m). // len(buf) must be >= e.MsgEncodedLen(m).
func MsgEncode(msg Msg, buf []byte) { func (e Encoding) MsgEncode(msg Msg, buf []byte) {
msg.neoMsgEncode(buf) switch e {
default: panic("bug")
case 'N': msg.neoMsgEncodeN(buf)
}
} }
// MsgDecode decodes data into msg in-place. // MsgDecode decodes data via encoding e into msg in-place.
func MsgDecode(msg Msg, data []byte) (nread int, err error) { func (e Encoding) MsgDecode(msg Msg, data []byte) (nread int, err error) {
return msg.neoMsgDecode(data) switch e {
default: panic("bug")
case 'N': return msg.neoMsgDecodeN(data)
}
} }
...@@ -266,7 +283,7 @@ type Address struct { ...@@ -266,7 +283,7 @@ type Address struct {
} }
// NOTE if Host == "" -> Port not added to wire (see py.PAddress): // NOTE if Host == "" -> Port not added to wire (see py.PAddress):
func (a *Address) neoEncodedLen() int { func (a *Address) neoEncodedLenN() int {
l := string_neoEncodedLen(a.Host) l := string_neoEncodedLen(a.Host)
if a.Host != "" { if a.Host != "" {
l += 2 l += 2
...@@ -274,7 +291,7 @@ func (a *Address) neoEncodedLen() int { ...@@ -274,7 +291,7 @@ func (a *Address) neoEncodedLen() int {
return l return l
} }
func (a *Address) neoEncode(b []byte) int { func (a *Address) neoEncodeN(b []byte) int {
n := string_neoEncode(a.Host, b[0:]) n := string_neoEncode(a.Host, b[0:])
if a.Host != "" { if a.Host != "" {
binary.BigEndian.PutUint16(b[n:], a.Port) binary.BigEndian.PutUint16(b[n:], a.Port)
...@@ -283,7 +300,7 @@ func (a *Address) neoEncode(b []byte) int { ...@@ -283,7 +300,7 @@ func (a *Address) neoEncode(b []byte) int {
return n return n
} }
func (a *Address) neoDecode(b []byte) (uint64, bool) { func (a *Address) neoDecodeN(b []byte) (uint64, bool) {
n, ok := string_neoDecode(&a.Host, b) n, ok := string_neoDecode(&a.Host, b)
if !ok { if !ok {
return 0, false return 0, false
...@@ -312,11 +329,11 @@ type PTid uint64 ...@@ -312,11 +329,11 @@ type PTid uint64
// IdTime represents time of identification. // IdTime represents time of identification.
type IdTime float64 type IdTime float64
func (t IdTime) neoEncodedLen() int { func (t IdTime) neoEncodedLenN() int {
return 8 return 8
} }
func (t IdTime) neoEncode(b []byte) int { func (t IdTime) neoEncodeN(b []byte) int {
// use -inf as value for no data (NaN != NaN -> hard to use NaN in tests) // use -inf as value for no data (NaN != NaN -> hard to use NaN in tests)
// NOTE neo/py uses None for "no data"; we use 0 for "no data" to avoid pointer // NOTE neo/py uses None for "no data"; we use 0 for "no data" to avoid pointer
tt := float64(t) tt := float64(t)
...@@ -327,7 +344,7 @@ func (t IdTime) neoEncode(b []byte) int { ...@@ -327,7 +344,7 @@ func (t IdTime) neoEncode(b []byte) int {
return 8 return 8
} }
func (t *IdTime) neoDecode(data []byte) (uint64, bool) { func (t *IdTime) neoDecodeN(data []byte) (uint64, bool) {
if len(data) < 8 { if len(data) < 8 {
return 0, false return 0, false
} }
...@@ -1210,13 +1227,13 @@ type FlushLog struct {} ...@@ -1210,13 +1227,13 @@ type FlushLog struct {}
// ---- runtime support for protogen and custom codecs ---- // ---- runtime support for protogen and custom codecs ----
// customCodec is the interface that is implemented by types with custom encodings. // customCodecN is the interface that is implemented by types with custom N encodings.
// //
// its semantic is very similar to Msg. // its semantic is very similar to Msg.
type customCodec interface { type customCodecN interface {
neoEncodedLen() int neoEncodedLenN() int
neoEncode(buf []byte) (nwrote int) neoEncodeN(buf []byte) (nwrote int)
neoDecode(data []byte) (nread uint64, ok bool) // XXX uint64 or int here? neoDecodeN(data []byte) (nread uint64, ok bool) // XXX uint64 or int here?
} }
func byte2bool(b byte) bool { func byte2bool(b byte) bool {
......
...@@ -68,42 +68,42 @@ func u64(v uint64) string { ...@@ -68,42 +68,42 @@ func u64(v uint64) string {
return string(b[:]) return string(b[:])
} }
func TestPktHeader(t *testing.T) { func TestPktHeaderN(t *testing.T) {
// make sure PktHeader is really packed and its size matches PktHeaderLen // make sure PktHeaderN is really packed and its size matches PktHeaderLenN
if unsafe.Sizeof(PktHeader{}) != 10 { if unsafe.Sizeof(PktHeaderN{}) != 10 {
t.Fatalf("sizeof(PktHeader) = %v ; want 10", unsafe.Sizeof(PktHeader{})) t.Fatalf("sizeof(PktHeaderN) = %v ; want 10", unsafe.Sizeof(PktHeaderN{}))
} }
if unsafe.Sizeof(PktHeader{}) != PktHeaderLen { if unsafe.Sizeof(PktHeaderN{}) != PktHeaderLenN {
t.Fatalf("sizeof(PktHeader) = %v ; want %v", unsafe.Sizeof(PktHeader{}), PktHeaderLen) t.Fatalf("sizeof(PktHeaderN) = %v ; want %v", unsafe.Sizeof(PktHeaderN{}), PktHeaderLenN)
} }
} }
// test marshalling for one message type // test marshalling for one message type
func testMsgMarshal(t *testing.T, msg Msg, encoded string) { func testMsgMarshal(t *testing.T, enc Encoding, msg Msg, encoded string) {
typ := reflect.TypeOf(msg).Elem() // type of *msg typ := reflect.TypeOf(msg).Elem() // type of *msg
msg2 := reflect.New(typ).Interface().(Msg) msg2 := reflect.New(typ).Interface().(Msg)
defer func() { defer func() {
if e := recover(); e != nil { if e := recover(); e != nil {
t.Errorf("%v: panic ↓↓↓:", typ) t.Errorf("%c/%v: panic ↓↓↓:", enc, typ)
panic(e) // to show traceback panic(e) // to show traceback
} }
}() }()
// msg.encode() == expected // msg.encode() == expected
msgCode := msg.neoMsgCode() msgCode := msg.neoMsgCode()
n := MsgEncodedLen(msg) n := enc.MsgEncodedLen(msg)
msgType := MsgType(msgCode) msgType := MsgType(msgCode)
if msgType != typ { if msgType != typ {
t.Errorf("%v: msgCode = %v which corresponds to %v", typ, msgCode, msgType) t.Errorf("%c/%v: msgCode = %v which corresponds to %v", enc, typ, msgCode, msgType)
} }
if n != len(encoded) { if n != len(encoded) {
t.Errorf("%v: encodedLen = %v ; want %v", typ, n, len(encoded)) t.Errorf("%c/%v: encodedLen = %v ; want %v", enc, typ, n, len(encoded))
} }
buf := make([]byte, n) buf := make([]byte, n)
MsgEncode(msg, buf) enc.MsgEncode(msg, buf)
if string(buf) != encoded { if string(buf) != encoded {
t.Errorf("%v: encode result unexpected:", typ) t.Errorf("%c/%v: encode result unexpected:", enc, typ)
t.Errorf("\thave: %s", hexpkg.EncodeToString(buf)) t.Errorf("\thave: %s", hexpkg.EncodeToString(buf))
t.Errorf("\twant: %s", hexpkg.EncodeToString([]byte(encoded))) t.Errorf("\twant: %s", hexpkg.EncodeToString([]byte(encoded)))
} }
...@@ -112,7 +112,7 @@ func testMsgMarshal(t *testing.T, msg Msg, encoded string) { ...@@ -112,7 +112,7 @@ func testMsgMarshal(t *testing.T, msg Msg, encoded string) {
for l := len(buf) - 1; l >= 0; l-- { for l := len(buf) - 1; l >= 0; l-- {
func() { func() {
defer func() { defer func() {
subj := fmt.Sprintf("%v: encode(buf[:encodedLen-%v])", typ, len(encoded)-l) subj := fmt.Sprintf("%c/%v: encode(buf[:encodedLen-%v])", enc, typ, len(encoded)-l)
e := recover() e := recover()
if e == nil { if e == nil {
t.Errorf("%s did not panic", subj) t.Errorf("%s did not panic", subj)
...@@ -131,29 +131,29 @@ func testMsgMarshal(t *testing.T, msg Msg, encoded string) { ...@@ -131,29 +131,29 @@ func testMsgMarshal(t *testing.T, msg Msg, encoded string) {
} }
}() }()
MsgEncode(msg, buf[:l]) enc.MsgEncode(msg, buf[:l])
}() }()
} }
// msg.decode() == expected // msg.decode() == expected
data := []byte(encoded + "noise") data := []byte(encoded + "noise")
n, err := MsgDecode(msg2, data) n, err := enc.MsgDecode(msg2, data)
if err != nil { if err != nil {
t.Errorf("%v: decode error %v", typ, err) t.Errorf("%c/%v: decode error %v", enc, typ, err)
} }
if n != len(encoded) { if n != len(encoded) {
t.Errorf("%v: nread = %v ; want %v", typ, n, len(encoded)) t.Errorf("%c/%v: nread = %v ; want %v", enc, typ, n, len(encoded))
} }
if !reflect.DeepEqual(msg2, msg) { if !reflect.DeepEqual(msg2, msg) {
t.Errorf("%v: decode result unexpected: %v ; want %v", typ, msg2, msg) t.Errorf("%c/%v: decode result unexpected: %v ; want %v", enc, typ, msg2, msg)
} }
// decode must detect buffer overflow // decode must detect buffer overflow
for l := len(encoded) - 1; l >= 0; l-- { for l := len(encoded) - 1; l >= 0; l-- {
n, err = MsgDecode(msg2, data[:l]) n, err = enc.MsgDecode(msg2, data[:l])
if !(n == 0 && err == ErrDecodeOverflow) { if !(n == 0 && err == ErrDecodeOverflow) {
t.Errorf("%v: decode overflow not detected on [:%v]", typ, l) t.Errorf("%c/%v: decode overflow not detected on [:%v]", enc, typ, l)
} }
} }
...@@ -162,8 +162,8 @@ func testMsgMarshal(t *testing.T, msg Msg, encoded string) { ...@@ -162,8 +162,8 @@ func testMsgMarshal(t *testing.T, msg Msg, encoded string) {
// test encoding/decoding of messages // test encoding/decoding of messages
func TestMsgMarshal(t *testing.T) { func TestMsgMarshal(t *testing.T) {
var testv = []struct { var testv = []struct {
msg Msg msg Msg
encoded string // []byte encodedN string // []byte
}{ }{
// empty // empty
{&Ping{}, ""}, {&Ping{}, ""},
...@@ -198,6 +198,7 @@ func TestMsgMarshal(t *testing.T) { ...@@ -198,6 +198,7 @@ func TestMsgMarshal(t *testing.T) {
}, },
}, },
// N
hex("0102030405060708") + hex("0102030405060708") +
hex("00000022") + hex("00000022") +
hex("00000003") + hex("00000003") +
...@@ -219,6 +220,7 @@ func TestMsgMarshal(t *testing.T) { ...@@ -219,6 +220,7 @@ func TestMsgMarshal(t *testing.T) {
5: {4, 3, true}, 5: {4, 3, true},
}}, }},
// N
u32(4) + u32(4) +
u64(1) + u64(1) + u64(0) + hex("00") + u64(1) + u64(1) + u64(0) + hex("00") +
u64(2) + u64(7) + u64(1) + hex("01") + u64(2) + u64(7) + u64(1) + hex("01") +
...@@ -238,6 +240,7 @@ func TestMsgMarshal(t *testing.T) { ...@@ -238,6 +240,7 @@ func TestMsgMarshal(t *testing.T) {
MaxTID: 128, MaxTID: 128,
}, },
// N
u32(4) + u32(4) +
u32(1) + u32(7) + u32(1) + u32(7) +
u32(2) + u32(9) + u32(2) + u32(9) +
...@@ -248,12 +251,13 @@ func TestMsgMarshal(t *testing.T) { ...@@ -248,12 +251,13 @@ func TestMsgMarshal(t *testing.T) {
// uint32, []uint32 // uint32, []uint32
{&PartitionCorrupted{7, []NodeUUID{1, 3, 9, 4}}, {&PartitionCorrupted{7, []NodeUUID{1, 3, 9, 4}},
// N
u32(7) + u32(4) + u32(1) + u32(3) + u32(9) + u32(4), u32(7) + u32(4) + u32(1) + u32(3) + u32(9) + u32(4),
}, },
// uint32, Address, string, IdTime // uint32, Address, string, IdTime
{&RequestIdentification{CLIENT, 17, Address{"localhost", 7777}, "myname", 0.12345678, []string{"room1", "rack234"}, []uint32{3,4,5} }, {&RequestIdentification{CLIENT, 17, Address{"localhost", 7777}, "myname", 0.12345678, []string{"room1", "rack234"}, []uint32{3,4,5} },
// N
u8(2) + u32(17) + u32(9) + u8(2) + u32(17) + u32(9) +
"localhost" + u16(7777) + "localhost" + u16(7777) +
u32(6) + "myname" + u32(6) + "myname" +
...@@ -265,14 +269,17 @@ func TestMsgMarshal(t *testing.T) { ...@@ -265,14 +269,17 @@ func TestMsgMarshal(t *testing.T) {
// IdTime, empty Address, int32 // IdTime, empty Address, int32
{&NotifyNodeInformation{1504466245.926185, []NodeInfo{ {&NotifyNodeInformation{1504466245.926185, []NodeInfo{
{CLIENT, Address{}, UUID(CLIENT, 1), RUNNING, 1504466245.925599}}}, {CLIENT, Address{}, UUID(CLIENT, 1), RUNNING, 1504466245.925599}}},
// N
hex("41d66b15517b469d") + u32(1) + hex("41d66b15517b469d") + u32(1) +
u8(2) + u32(0) /* <- ø Address */ + hex("e0000001") + u8(2) + u8(2) + u32(0) /* <- ø Address */ + hex("e0000001") + u8(2) +
hex("41d66b15517b3d04"), hex("41d66b15517b3d04"),
}, },
// empty IdTime // empty IdTime
{&NotifyNodeInformation{IdTimeNone, []NodeInfo{}}, hex("ffffffffffffffff") + hex("00000000")}, {&NotifyNodeInformation{IdTimeNone, []NodeInfo{}},
// N
hex("ffffffffffffffff") + hex("00000000"),
},
// TODO we need tests for: // TODO we need tests for:
// []varsize + trailing // []varsize + trailing
...@@ -280,7 +287,7 @@ func TestMsgMarshal(t *testing.T) { ...@@ -280,7 +287,7 @@ func TestMsgMarshal(t *testing.T) {
} }
for _, tt := range testv { for _, tt := range testv {
testMsgMarshal(t, tt.msg, tt.encoded) testMsgMarshal(t, 'N', tt.msg, tt.encodedN)
} }
} }
...@@ -288,23 +295,27 @@ func TestMsgMarshal(t *testing.T) { ...@@ -288,23 +295,27 @@ func TestMsgMarshal(t *testing.T) {
// this way we additionally lightly check encode / decode overflow behaviour for all types. // this way we additionally lightly check encode / decode overflow behaviour for all types.
func TestMsgMarshalAllOverflowLightly(t *testing.T) { func TestMsgMarshalAllOverflowLightly(t *testing.T) {
for _, typ := range msgTypeRegistry { for _, typ := range msgTypeRegistry {
// zero-value for a type for _, enc := range []Encoding{'N'} {
msg := reflect.New(typ).Interface().(Msg) // zero-value for a type
l := MsgEncodedLen(msg) msg := reflect.New(typ).Interface().(Msg)
zerol := make([]byte, l) l := enc.MsgEncodedLen(msg)
// decoding will turn nil slice & map into empty allocated ones. zerol := make([]byte, l)
// we need it so that reflect.DeepEqual works for msg encode/decode comparison // decoding will turn nil slice & map into empty allocated ones.
n, err := MsgDecode(msg, zerol) // we need it so that reflect.DeepEqual works for msg encode/decode comparison
if !(n == l && err == nil) { n, err := enc.MsgDecode(msg, zerol)
t.Errorf("%v: zero-decode unexpected: %v, %v ; want %v, nil", typ, n, err, l) if !(n == l && err == nil) {
} t.Errorf("%c/%v: zero-decode unexpected: %v, %v ; want %v, nil", enc, typ, n, err, l)
}
testMsgMarshal(t, msg, string(zerol)) testMsgMarshal(t, enc, msg, string(zerol))
}
} }
} }
// Verify overflow handling on decode len checks // Verify overflow handling on decodeN len checks
func TestMsgDecodeLenOverflow(t *testing.T) { func TestMsgDecodeLenOverflowN(t *testing.T) {
enc := Encoding('N')
var testv = []struct { var testv = []struct {
msg Msg // of type to decode into msg Msg // of type to decode into
data string // []byte - tricky data to exercise decoder u32 len checks overflow data string // []byte - tricky data to exercise decoder u32 len checks overflow
...@@ -325,7 +336,7 @@ func TestMsgDecodeLenOverflow(t *testing.T) { ...@@ -325,7 +336,7 @@ func TestMsgDecodeLenOverflow(t *testing.T) {
} }
}() }()
n, err := MsgDecode(tt.msg, data) n, err := enc.MsgDecode(tt.msg, data)
if !(n == 0 && err == ErrDecodeOverflow) { if !(n == 0 && err == ErrDecodeOverflow) {
t.Errorf("%T: decode %x\nhave: %d, %v\nwant: %d, %v", tt.msg, data, t.Errorf("%T: decode %x\nhave: %d, %v\nwant: %d, %v", tt.msg, data,
n, err, 0, ErrDecodeOverflow) n, err, 0, ErrDecodeOverflow)
......
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment