Commit 39545b9c authored by Kirill Smelkov's avatar Kirill Smelkov

go/neo/proto: Unexport Msg.NEOMsg{Encode,Decode}

And provide only single top-level entry-points to encode/decode
messages. As of now the entry points are just plain forwarding, but with
introducing of msgpack and encodings, they will take into account
through which encoding a message has to be encoded/decoded.
parent aa16f0f2
...@@ -1321,7 +1321,7 @@ func (c *Conn) err(op string, e error) error { ...@@ -1321,7 +1321,7 @@ func (c *Conn) err(op string, e error) error {
// pktEncode allocates pktBuf and encodes msg into it. // pktEncode allocates pktBuf and encodes msg into it.
func pktEncode(connId uint32, msg proto.Msg) *pktBuf { func pktEncode(connId uint32, msg proto.Msg) *pktBuf {
l := msg.NEOMsgEncodedLen() l := proto.MsgEncodedLen(msg)
buf := pktAlloc(proto.PktHeaderLen + l) buf := pktAlloc(proto.PktHeaderLen + l)
h := buf.Header() h := buf.Header()
...@@ -1329,7 +1329,7 @@ func pktEncode(connId uint32, msg proto.Msg) *pktBuf { ...@@ -1329,7 +1329,7 @@ func pktEncode(connId uint32, msg proto.Msg) *pktBuf {
h.MsgCode = packed.Hton16(proto.MsgCode(msg)) h.MsgCode = packed.Hton16(proto.MsgCode(msg))
h.MsgLen = packed.Hton32(uint32(l)) // XXX casting: think again h.MsgLen = packed.Hton32(uint32(l)) // XXX casting: think again
msg.NEOMsgEncode(buf.Payload()) proto.MsgEncode(msg, buf.Payload())
return buf return buf
} }
...@@ -1376,7 +1376,7 @@ func (c *Conn) Recv() (proto.Msg, error) { ...@@ -1376,7 +1376,7 @@ func (c *Conn) Recv() (proto.Msg, error) {
// msg := reflect.NewAt(msgType, bufAlloc(msgType.Size()) // msg := reflect.NewAt(msgType, bufAlloc(msgType.Size())
_, err = msg.NEOMsgDecode(payload) _, err = proto.MsgDecode(msg, payload)
if err != nil { if err != nil {
return nil, c.err("decode", err) // XXX "decode:" is already in ErrDecodeOverflow return nil, c.err("decode", err) // XXX "decode:" is already in ErrDecodeOverflow
} }
...@@ -1428,7 +1428,7 @@ func (c *Conn) Expect(msgv ...proto.Msg) (which int, err error) { ...@@ -1428,7 +1428,7 @@ func (c *Conn) Expect(msgv ...proto.Msg) (which int, err error) {
for i, msg := range msgv { for i, msg := range msgv {
if proto.MsgCode(msg) == msgCode { if proto.MsgCode(msg) == msgCode {
_, err := msg.NEOMsgDecode(payload) _, err := proto.MsgDecode(msg, payload)
if err != nil { if err != nil {
return -1, c.err("decode", err) return -1, c.err("decode", err)
} }
......
...@@ -157,8 +157,8 @@ func (t *T) xverifyPkt(pkt *pktBuf, connid uint32, msgcode uint16, payload []byt ...@@ -157,8 +157,8 @@ func (t *T) xverifyPkt(pkt *pktBuf, connid uint32, msgcode uint16, payload []byt
// Verify pktBuf to match expected message. // Verify pktBuf to match expected message.
func (t *T) xverifyPktMsg(pkt *pktBuf, connid uint32, msg proto.Msg) { func (t *T) xverifyPktMsg(pkt *pktBuf, connid uint32, msg proto.Msg) {
data := make([]byte, msg.NEOMsgEncodedLen()) data := make([]byte, proto.MsgEncodedLen(msg))
msg.NEOMsgEncode(data) proto.MsgEncode(msg, data)
t.xverifyPkt(pkt, connid, proto.MsgCode(msg), data) t.xverifyPkt(pkt, connid, proto.MsgCode(msg), data)
} }
......
...@@ -95,7 +95,7 @@ func (pkt *pktBuf) String() string { ...@@ -95,7 +95,7 @@ func (pkt *pktBuf) String() string {
// XXX dup wrt Conn.Recv // XXX dup wrt Conn.Recv
msg := reflect.New(msgType).Interface().(proto.Msg) msg := reflect.New(msgType).Interface().(proto.Msg)
n, err := msg.NEOMsgDecode(payload) n, err := proto.MsgDecode(msg, payload)
if err != nil { if err != nil {
s += fmt.Sprintf(" (%s) %v; [%d]: % x", msgType.Name(), err, len(payload), payload) s += fmt.Sprintf(" (%s) %v; [%d]: % x", msgType.Name(), err, len(payload), payload)
} else { } else {
......
...@@ -114,18 +114,36 @@ type Msg interface { ...@@ -114,18 +114,36 @@ type Msg interface {
// on the wire. // on the wire.
neoMsgCode() uint16 neoMsgCode() uint16
// NEOMsgEncodedLen returns how much space is needed to encode current message payload. // neoMsgEncodedLen returns how much space is needed to encode current message payload.
NEOMsgEncodedLen() int neoMsgEncodedLen() int
// NEOMsgEncode encodes current message state into buf. // neoMsgEncode encodes current message state into buf.
// //
// len(buf) must be >= neoMsgEncodedLen(). // len(buf) must be >= neoMsgEncodedLen().
NEOMsgEncode(buf []byte) neoMsgEncode(buf []byte)
// NEOMsgDecode decodes data into message in-place. // neoMsgDecode decodes data into message in-place.
NEOMsgDecode(data []byte) (nread int, err error) neoMsgDecode(data []byte) (nread int, err error)
} }
// MsgEncodedLen returns how much space is needed to encode msg payload.
func MsgEncodedLen(msg Msg) int {
return msg.neoMsgEncodedLen()
}
// MsgEncode encodes msg state into buf.
//
// len(buf) must be >= MsgEncodedLen(m).
func MsgEncode(msg Msg, buf []byte) {
msg.neoMsgEncode(buf)
}
// MsgDecode decodes data into msg in-place.
func MsgDecode(msg Msg, data []byte) (nread int, err error) {
return msg.neoMsgDecode(data)
}
// ErrDecodeOverflow is the error returned by neoMsgDecode when decoding hits buffer overflow // ErrDecodeOverflow is the error returned by neoMsgDecode when decoding hits buffer overflow
var ErrDecodeOverflow = errors.New("decode: buffer overflow") var ErrDecodeOverflow = errors.New("decode: buffer overflow")
......
...@@ -91,7 +91,7 @@ func testMsgMarshal(t *testing.T, msg Msg, encoded string) { ...@@ -91,7 +91,7 @@ func testMsgMarshal(t *testing.T, msg Msg, encoded string) {
// msg.encode() == expected // msg.encode() == expected
msgCode := msg.neoMsgCode() msgCode := msg.neoMsgCode()
n := msg.NEOMsgEncodedLen() n := MsgEncodedLen(msg)
msgType := MsgType(msgCode) msgType := MsgType(msgCode)
if msgType != typ { if msgType != typ {
t.Errorf("%v: msgCode = %v which corresponds to %v", typ, msgCode, msgType) t.Errorf("%v: msgCode = %v which corresponds to %v", typ, msgCode, msgType)
...@@ -101,7 +101,7 @@ func testMsgMarshal(t *testing.T, msg Msg, encoded string) { ...@@ -101,7 +101,7 @@ func testMsgMarshal(t *testing.T, msg Msg, encoded string) {
} }
buf := make([]byte, n) buf := make([]byte, n)
msg.NEOMsgEncode(buf) MsgEncode(msg, buf)
if string(buf) != encoded { if string(buf) != encoded {
t.Errorf("%v: encode result unexpected:", typ) t.Errorf("%v: encode result unexpected:", typ)
t.Errorf("\thave: %s", hexpkg.EncodeToString(buf)) t.Errorf("\thave: %s", hexpkg.EncodeToString(buf))
...@@ -131,13 +131,13 @@ func testMsgMarshal(t *testing.T, msg Msg, encoded string) { ...@@ -131,13 +131,13 @@ func testMsgMarshal(t *testing.T, msg Msg, encoded string) {
} }
}() }()
msg.NEOMsgEncode(buf[:l]) MsgEncode(msg, buf[:l])
}() }()
} }
// msg.decode() == expected // msg.decode() == expected
data := []byte(encoded + "noise") data := []byte(encoded + "noise")
n, err := msg2.NEOMsgDecode(data) n, err := MsgDecode(msg2, data)
if err != nil { if err != nil {
t.Errorf("%v: decode error %v", typ, err) t.Errorf("%v: decode error %v", typ, err)
} }
...@@ -151,7 +151,7 @@ func testMsgMarshal(t *testing.T, msg Msg, encoded string) { ...@@ -151,7 +151,7 @@ func testMsgMarshal(t *testing.T, msg Msg, encoded string) {
// decode must detect buffer overflow // decode must detect buffer overflow
for l := len(encoded) - 1; l >= 0; l-- { for l := len(encoded) - 1; l >= 0; l-- {
n, err = msg2.NEOMsgDecode(data[:l]) n, err = MsgDecode(msg2, data[:l])
if !(n == 0 && err == ErrDecodeOverflow) { if !(n == 0 && err == ErrDecodeOverflow) {
t.Errorf("%v: decode overflow not detected on [:%v]", typ, l) t.Errorf("%v: decode overflow not detected on [:%v]", typ, l)
} }
...@@ -290,11 +290,11 @@ func TestMsgMarshalAllOverflowLightly(t *testing.T) { ...@@ -290,11 +290,11 @@ func TestMsgMarshalAllOverflowLightly(t *testing.T) {
for _, typ := range msgTypeRegistry { for _, typ := range msgTypeRegistry {
// zero-value for a type // zero-value for a type
msg := reflect.New(typ).Interface().(Msg) msg := reflect.New(typ).Interface().(Msg)
l := msg.NEOMsgEncodedLen() l := MsgEncodedLen(msg)
zerol := make([]byte, l) zerol := make([]byte, l)
// decoding will turn nil slice & map into empty allocated ones. // decoding will turn nil slice & map into empty allocated ones.
// we need it so that reflect.DeepEqual works for msg encode/decode comparison // we need it so that reflect.DeepEqual works for msg encode/decode comparison
n, err := msg.NEOMsgDecode(zerol) n, err := MsgDecode(msg, zerol)
if !(n == l && err == nil) { if !(n == l && err == nil) {
t.Errorf("%v: zero-decode unexpected: %v, %v ; want %v, nil", typ, n, err, l) t.Errorf("%v: zero-decode unexpected: %v, %v ; want %v, nil", typ, n, err, l)
} }
...@@ -325,7 +325,7 @@ func TestMsgDecodeLenOverflow(t *testing.T) { ...@@ -325,7 +325,7 @@ func TestMsgDecodeLenOverflow(t *testing.T) {
} }
}() }()
n, err := tt.msg.NEOMsgDecode(data) n, err := MsgDecode(tt.msg, data)
if !(n == 0 && err == ErrDecodeOverflow) { if !(n == 0 && err == ErrDecodeOverflow) {
t.Errorf("%T: decode %x\nhave: %d, %v\nwant: %d, %v", tt.msg, data, t.Errorf("%T: decode %x\nhave: %d, %v\nwant: %d, %v", tt.msg, data,
n, err, 0, ErrDecodeOverflow) n, err, 0, ErrDecodeOverflow)
......
...@@ -26,9 +26,9 @@ This program generates marshalling code for message types defined in proto.go . ...@@ -26,9 +26,9 @@ This program generates marshalling code for message types defined in proto.go .
For every type 4 methods are generated in accordance with neo.Msg interface: For every type 4 methods are generated in accordance with neo.Msg interface:
neoMsgCode() uint16 neoMsgCode() uint16
NEOMsgEncodedLen() int neoMsgEncodedLen() int
NEOMsgEncode(buf []byte) neoMsgEncode(buf []byte)
NEOMsgDecode(data []byte) (nread int, err error) neoMsgDecode(data []byte) (nread int, err error)
List of message types is obtained via searching through proto.go AST - looking List of message types is obtained via searching through proto.go AST - looking
for appropriate struct declarations there. for appropriate struct declarations there.
...@@ -606,7 +606,7 @@ type sizer struct { ...@@ -606,7 +606,7 @@ type sizer struct {
// //
// when type is recursively walked, for every case code to update `data[n:]` is generated. // when type is recursively walked, for every case code to update `data[n:]` is generated.
// no overflow checks are generated as by neo.Msg interface provided data // no overflow checks are generated as by neo.Msg interface provided data
// buffer should have at least payloadLen length returned by NEOMsgEncodedLen() // buffer should have at least payloadLen length returned by neoMsgEncodedLen()
// (the size computed by sizer). // (the size computed by sizer).
// //
// the code emitted looks like: // the code emitted looks like:
...@@ -615,7 +615,7 @@ type sizer struct { ...@@ -615,7 +615,7 @@ type sizer struct {
// encode<typ2>(data[n2:], path2) // encode<typ2>(data[n2:], path2)
// ... // ...
// //
// TODO encode have to care in NEOMsgEncode to emit preamble such that bound // TODO encode have to care in neoMsgEncode to emit preamble such that bound
// checking is performed only once (currently compiler emits many of them) // checking is performed only once (currently compiler emits many of them)
type encoder struct { type encoder struct {
commonCodeGen commonCodeGen
...@@ -663,7 +663,7 @@ var _ CodeGenerator = (*decoder)(nil) ...@@ -663,7 +663,7 @@ var _ CodeGenerator = (*decoder)(nil)
func (s *sizer) generatedCode() string { func (s *sizer) generatedCode() string {
code := Buffer{} code := Buffer{}
// prologue // prologue
code.emit("func (%s *%s) NEOMsgEncodedLen() int {", s.recvName, s.typeName) code.emit("func (%s *%s) neoMsgEncodedLen() int {", s.recvName, s.typeName)
if s.varUsed["size"] { if s.varUsed["size"] {
code.emit("var %s int", s.var_("size")) code.emit("var %s int", s.var_("size"))
} }
...@@ -684,7 +684,7 @@ func (s *sizer) generatedCode() string { ...@@ -684,7 +684,7 @@ func (s *sizer) generatedCode() string {
func (e *encoder) generatedCode() string { func (e *encoder) generatedCode() string {
code := Buffer{} code := Buffer{}
// prologue // prologue
code.emit("func (%s *%s) NEOMsgEncode(data []byte) {", e.recvName, e.typeName) code.emit("func (%s *%s) neoMsgEncode(data []byte) {", e.recvName, e.typeName)
code.Write(e.buf.Bytes()) code.Write(e.buf.Bytes())
...@@ -796,7 +796,7 @@ func (d *decoder) generatedCode() string { ...@@ -796,7 +796,7 @@ func (d *decoder) generatedCode() string {
code := Buffer{} code := Buffer{}
// prologue // prologue
code.emit("func (%s *%s) NEOMsgDecode(data []byte) (int, error) {", d.recvName, d.typeName) code.emit("func (%s *%s) neoMsgDecode(data []byte) (int, error) {", d.recvName, d.typeName)
if d.varUsed["nread"] { if d.varUsed["nread"] {
code.emit("var %v uint64", d.var_("nread")) code.emit("var %v uint64", d.var_("nread"))
} }
......
This diff is collapsed.
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment