// Copyright (C) 2016-2021  Nexedi SA and Contributors.
//                          Kirill Smelkov <kirr@nexedi.com>
//
// This program is free software: you can Use, Study, Modify and Redistribute
// it under the terms of the GNU General Public License version 3, or (at your
// option) any later version, as published by the Free Software Foundation.
//
// You can also Link and Combine this program with other software covered by
// the terms of any of the Free Software licenses or any of the Open Source
// Initiative approved licenses and Convey the resulting work. Corresponding
// source of such a combination shall include the source code for all other
// software used.
//
// This program is distributed WITHOUT ANY WARRANTY; without even the implied
// warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
//
// See COPYING file for full licensing terms.
// See https://www.nexedi.com/licensing for rationale and options.

// +build ignore

/*
NEO. Protocol module. Code generator

This program generates marshalling code for message types defined in proto.go.
For every type the following methods are generated in accordance with the
neo.Msg interface (<E> stands for the 'N' and 'M' encodings):

	neoMsgCode() uint16
	neoMsgEncodedLen<E>() int
	neoMsgEncode<E>(buf []byte)
	neoMsgDecode<E>(data []byte) (nread int, err error)

The list of message types is obtained by searching through the proto.go AST,
looking for appropriate struct declarations there.

Code generation for a type is organized by recursively walking through the
type's (sub-)elements and generating specialized code for leaf items (intX,
slices, maps, ...).

The top-level generation driver is generateCodecCode(). It accepts a type
specification and something that performs the actual leaf-node code generation
(the CodeGenerator interface). For each encoding there are 3 particular
codegenerators implemented - sizer<E>, encoder<E> & decoder<E> - one for each
of the needed method functions.

The structure of the whole process is very similar to what would happen at
runtime if marshalling were reflect based, but since the decisions are made
statically with go/types no runtime time is spent on them, and thus the
generated marshallers are faster.

Format compatibility of the encodings with Python NEO (neo/lib/protocol.py) is
preserved so that the two implementations are able to communicate with each
other.

NOTE we do not try to emit very clever code - for cases where the compiler can
do a good job the work is delegated to it.

--------

Also, along the way, a registry table is generated for msgCode -> message type
lookup, needed in the packet receive codepath.
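
For illustration, given a hypothetical message

	type Ping struct { ID uint32 }

the generator would emit, for the N encoding, methods approximately like the
following (abridged sketch, shown after gofmt; the message code 1 is a
placeholder):

	func (*Ping) neoMsgCode() uint16 {
		return 1
	}

	func (p *Ping) neoMsgEncodedLenN() int {
		return 4
	}

	func (p *Ping) neoMsgEncodeN(data []byte) {
		binary.BigEndian.PutUint32(data[0:], p.ID)
	}

neoMsgDecodeN additionally carries a length check that jumps to an overflow:
label returning ErrDecodeOverflow (see the sketch at the end of this file),
and the neoMsg*M counterparts implement the msgpack-based M encoding.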
*/ package main import ( "bytes" "fmt" "go/ast" "go/format" "go/importer" "go/parser" "go/token" "go/types" "log" "os" "sort" "strings" "lab.nexedi.com/kirr/neo/go/neo/internal/msgpack" ) // parsed & typechecked input var fset = token.NewFileSet() var fileMap = map[string]*ast.File{} // fileName -> AST var pkgMap = map[string]*types.Package{} // pkgPath -> Package var typeInfo = &types.Info{ Types: make(map[ast.Expr]types.TypeAndValue), Defs: make(map[*ast.Ident]types.Object), } // complete position of something with .Pos() func pos(x interface{ Pos() token.Pos }) token.Position { return fset.Position(x.Pos()) } // get type name in context of neo package var zodbPkg *types.Package var protoPkg *types.Package func typeName(typ types.Type) string { qf := func(pkg *types.Package) string { switch pkg { case protoPkg: // same package - unqualified return "" case zodbPkg: // zodb is imported - only name return pkg.Name() default: // fully qualified otherwise return pkg.Path() } } return types.TypeString(typ, qf) } // zodb.Tid and zodb.Oid types var zodbTid types.Type var zodbOid types.Type var neo_customCodecN *types.Interface // type of neo.customCodecN var memBuf types.Type // type of mem.Buf // registry of enums var enumRegistry = map[types.Type]int{} // type -> enum type serial // Define INVALID_ID because encoding behaves different depending // on if we have an INVALID_{TID,OID} or not. // // NOTE This assumes that INVALID_OID == INVALID_TID // // XXX Duplication wrt proto.go const INVALID_ID uint64 = 1<<64 - 1 // bytes.Buffer + bell & whistles type Buffer struct { bytes.Buffer } func (b *Buffer) emit(format string, a ...interface{}) { fmt.Fprintf(b, format+"\n", a...) } // importer that takes into account our already-loaded packages type localImporter struct { types.Importer } func (li *localImporter) Import(path string) (*types.Package, error) { pkg := pkgMap[path] if pkg != nil { return pkg, nil } return li.Importer.Import(path) } // importer instance - only 1 so that for 2 top-level packages same dependent // packages are not reimported several times. // // don't use importer.Default - this importer uses only binaries for installed // packages, which might a) get stale wrt sources, or b) be completely missing. // https://github.com/golang/go/issues/11415 var localImporterObj = &localImporter{importer.For("source", nil)} func loadPkg(pkgPath string, sources ...string) *types.Package { var filev []*ast.File // parse for _, src := range sources { f, err := parser.ParseFile(fset, src, nil, parser.ParseComments) if err != nil { log.Fatalf("parse: %v", err) } fileMap[src] = f filev = append(filev, f) } //ast.Print(fset, fv[0]) //return // typecheck conf := types.Config{Importer: localImporterObj} pkg, err := conf.Check(pkgPath, fset, filev, typeInfo) if err != nil { log.Fatalf("typecheck: %v", err) } pkgMap[pkgPath] = pkg return pkg } // `//neo:proto ...` annotations type Annotation struct { typeonly bool answer bool enum bool } // parse checks doc for specific comment annotations and, if present, loads them. 
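//
// For example (illustrative; see proto.go for the real annotated
// declarations), a declaration written as
//
//	//neo:proto enum
//	type ClusterState int32
//
// sets .enum, while
//
//	//neo:proto typeonly
//	type Address struct { ... }
//
// sets .typeonly, so no marshalling methods are generated for that type.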
func (a *Annotation) parse(doc *ast.CommentGroup) { if doc == nil { return // for many types .Doc = nil if there is no comments } for _, comment := range doc.List { cpos := pos(comment) if !(cpos.Column == 1 && strings.HasPrefix(comment.Text, "//neo:proto ")) { continue } __ := strings.SplitN(comment.Text, " ", 2) arg := __[1] switch arg { case "typeonly": if a.typeonly { log.Fatalf("%v: duplicate `typeonly`", cpos) } a.typeonly = true case "answer": if a.answer { log.Fatalf("%v: duplicate `answer`", cpos) } a.answer = true case "enum": if a.enum { log.Fatalf("%v: duplicate `enum`", cpos) } a.enum = true default: log.Fatalf("%v: unknown neo:proto directive %q", cpos, arg) } } } // MsgCode represents message code in symbolic form: `serial (| answerBit)?` type MsgCode struct { msgSerial int answer bool } func (c MsgCode) String() string { s := fmt.Sprintf("%d", c.msgSerial) if c.answer { s += " | answerBit" } return s } // sort MsgCode by (serial, answer) type BySerial []MsgCode func (v BySerial) Less(i, j int) bool { return (v[i].msgSerial < v[j].msgSerial) || (v[i].msgSerial == v[j].msgSerial && !v[i].answer && v[j].answer) } func (v BySerial) Swap(i, j int) { v[i], v[j] = v[j], v[i] } func (v BySerial) Len() int { return len(v) } // ---------------------------------------- func main() { var err error log.SetFlags(0) // go through proto.go and AST'ify & typecheck it zodbPkg = loadPkg("lab.nexedi.com/kirr/neo/go/zodb", "../../zodb/zodb.go") protoPkg = loadPkg("lab.nexedi.com/kirr/neo/go/neo/proto", "proto.go") // extract neo.customCodecN cc := xlookup(protoPkg, "customCodecN") var ok bool neo_customCodecN, ok = cc.Type().Underlying().(*types.Interface) if !ok { log.Fatal("customCodecN is not interface (got %v)", cc.Type()) } // extract mem.Buf memPath := "lab.nexedi.com/kirr/go123/mem" var memPkg *types.Package for _, pkg := range zodbPkg.Imports() { if pkg.Path() == memPath { memPkg = pkg break } } if memPkg == nil { log.Fatalf("cannot find `%s` in zodb imports", memPath) } __ := memPkg.Scope().Lookup("Buf") if __ == nil { log.Fatal("cannot find `mem.Buf`") } memBuf = __.Type() // extract zodb.Tid and zodb.Oid zodbTid = xlookup(zodbPkg, "Tid").Type() zodbOid = xlookup(zodbPkg, "Oid").Type() // prologue f := fileMap["proto.go"] buf := Buffer{} buf.emit(`// Code generated by protogen.go; DO NOT EDIT. package proto // NEO. protocol messages to/from wire marshalling. import ( "encoding/binary" "math" "reflect" "sort" "github.com/tinylib/msgp/msgp" "lab.nexedi.com/kirr/neo/go/neo/internal/msgpack" "lab.nexedi.com/kirr/go123/mem" "lab.nexedi.com/kirr/neo/go/zodb" )`) msgTypeRegistry := map[MsgCode]string{} // msgCode -> typename // go over message types declaration and generate marshal code for them buf.emit("// messages marshalling\n") msgSerial := 0 enumSerial := 0 for _, decl := range f.Decls { // we look for types (which can be only under GenDecl) gendecl, ok := decl.(*ast.GenDecl) if !ok || gendecl.Tok != token.TYPE { continue } //fmt.Println(gendecl) //ast.Print(fset, gendecl) //continue // `//neo:proto ...` annotations for whole decl // (e.g. 
<here> type ( t1 struct{...}; t2 struct{...} ) declAnnotation := Annotation{} declAnnotation.parse(gendecl.Doc) for _, spec := range gendecl.Specs { typespec := spec.(*ast.TypeSpec) // must be because tok = TYPE typename := typespec.Name.Name // `//neo:proto ...` annotation for this particular type specAnnotation := declAnnotation // inheriting from decl specAnnotation.parse(typespec.Doc) // remember enum types if specAnnotation.enum { typ := typeInfo.Defs[typespec.Name].Type() enumRegistry[typ]= enumSerial enumSerial++ } // messages are only struct types without typeonly annotation if _, ok := typespec.Type.(*ast.StructType); !ok { continue } if specAnnotation.typeonly { continue } // generate code for this type to implement neo.Msg var msgCode MsgCode msgCode.answer = specAnnotation.answer || strings.HasPrefix(typename, "Answer") msgCode.msgSerial = msgSerial // increment msgSerial only by +1 when going from // request1->request2 in `Request1 Answer1 Request2`. // Unlike as it was in pre-msgpack protocol, the global // increment is still +1, only for the answer packet // itself it's the same as the request packet. This means // in the post-msgpack protocol there are gaps. if msgCode.answer && typename != "Error" { msgCode.msgSerial = msgSerial - 1 } fmt.Fprintf(&buf, "// %s. %s\n\n", msgCode, typename) buf.emit("func (*%s) neoMsgCode() uint16 {", typename) buf.emit("return %s", msgCode) buf.emit("}\n") buf.WriteString(generateCodecCode(typespec, &sizerN{})) buf.WriteString(generateCodecCode(typespec, &encoderN{})) buf.WriteString(generateCodecCode(typespec, &decoderN{})) // TODO keep all M routines separate from N for code locality ? buf.WriteString(generateCodecCode(typespec, &sizerM{})) buf.WriteString(generateCodecCode(typespec, &encoderM{})) buf.WriteString(generateCodecCode(typespec, &decoderM{})) msgTypeRegistry[msgCode] = typename msgSerial++ } } // now generate message types registry buf.emit("\n// registry of message types") buf.emit("var msgTypeRegistry = map[uint16]reflect.Type {") // ordered by msgCode msgCodeV := []MsgCode{} for msgCode := range msgTypeRegistry { msgCodeV = append(msgCodeV, msgCode) } sort.Sort(BySerial(msgCodeV)) for _, msgCode := range msgCodeV { buf.emit("%v: reflect.TypeOf(%v{}),", msgCode, msgTypeRegistry[msgCode]) } buf.emit("}") // format & output generated code code, err := format.Source(buf.Bytes()) if err != nil { panic(err) // should not happen } _, err = os.Stdout.Write(code) if err != nil { log.Fatal(err) } } // info about encode/decode of a basic fixed-size type type basicCodecN struct { wireSize int encode string decode string } var basicTypesN = map[types.BasicKind]basicCodecN{ // encode: %v %v will be `data[n:]`, value // decode: %v will be `data[n:]` (and already made sure data has more enough bytes to read) types.Bool: {1, "(%v)[0] = bool2byte(%v)", "byte2bool((%v)[0])"}, types.Int8: {1, "(%v)[0] = uint8(%v)", "int8((%v)[0])"}, types.Int16: {2, "binary.BigEndian.PutUint16(%v, uint16(%v))", "int16(binary.BigEndian.Uint16(%v))"}, types.Int32: {4, "binary.BigEndian.PutUint32(%v, uint32(%v))", "int32(binary.BigEndian.Uint32(%v))"}, types.Int64: {8, "binary.BigEndian.PutUint64(%v, uint64(%v))", "int64(binary.BigEndian.Uint64(%v))"}, types.Uint8: {1, "(%v)[0] = %v", "(%v)[0]"}, types.Uint16: {2, "binary.BigEndian.PutUint16(%v, %v)", "binary.BigEndian.Uint16(%v)"}, types.Uint32: {4, "binary.BigEndian.PutUint32(%v, %v)", "binary.BigEndian.Uint32(%v)"}, types.Uint64: {8, "binary.BigEndian.PutUint64(%v, %v)", "binary.BigEndian.Uint64(%v)"}, 
types.Float64: {8, "float64_neoEncode(%v, %v)", "float64_neoDecode(%v)"}, } // does a type have fixed wire size and, if yes, what it is? func typeSizeFixed(encoding byte, typ types.Type) (wireSize int, ok bool) { // pass typ through sizer and see if encoded size is fixed or not var size SymSize switch encoding { case 'M': s := &sizerM{} codegenType("x", typ, nil, s) size = s.size case 'N': s := &sizerN{} codegenType("x", typ, nil, s) size = s.size default: panic("bad encoding") } if !size.IsNumeric() { // no symbolic part return 0, false } return size.num, true } // interface of a codegenerator (for sizer/encoder/decoder) type CodeGenerator interface { // codegenerator generates code for this encoding encoding() byte // tell codegen it should generate code for which type & receiver name setFunc(recvName, typeName string, typ types.Type, encoding byte) // generate code to process a basic fixed type (not string) // userType is type actually used in source (for which typ is underlying), or nil // path is associated data member - e.g. p.Address.Port (to read from or write to) genBasic(path string, typ *types.Basic, userType types.Type) // generate code to process slice or map // (see genBasic for argument details) genSlice(path string, typ *types.Slice, obj types.Object) genMap(path string, typ *types.Map, obj types.Object) // particular case of array or slice with 1-byte elem // // NOTE this particular case is kept separate because for 1-byte // elements there are no byteordering issues so data can be directly // either accessed or copied. genArray1(path string, typ *types.Array) genSlice1(path string, typ types.Type) // generate code to process header of struct genStructHead(path string, typ *types.Struct, userType types.Type) // mem.Buf genBuf(path string) // get generated code. generatedCode() string } // interface for codegenerators to inject themselves into {sizer/encoder/decoder}Common. type CodeGenCustomize interface { CodeGenerator // generate code to process slice or map header genSliceHead(path string, typ *types.Slice, obj types.Object) genMapHead(path string, typ *types.Map, obj types.Object) } // X reports encoding=X type N struct{}; func (_ *N) encoding() byte { return 'N' } type M struct{}; func (_ *M) encoding() byte { return 'M' } // common part of codegenerators type commonCodeGen struct { buf Buffer // code is emitted here recvName string // receiver/type for top-level func typeName string // or empty typ types.Type enc byte // encoding variant varUsed map[string]bool // whether a variable was used } func (c *commonCodeGen) emit(format string, a ...interface{}) { c.buf.emit(format, a...) } func (c *commonCodeGen) setFunc(recvName, typeName string, typ types.Type, encoding byte) { c.recvName = recvName c.typeName = typeName c.typ = typ c.enc = encoding } // get variable for varname (and automatically mark this var as used) func (c *commonCodeGen) var_(varname string) string { if c.varUsed == nil { c.varUsed = make(map[string]bool) } c.varUsed[varname] = true return varname } // pathName returns name representing path or assignto. func (c *commonCodeGen) pathName(path string) string { // Type, p.f1.f2 -> Type.f1.f2 return strings.Join(append([]string{c.typeName}, strings.Split(path, ".")[1:]...), ".") } // symbolic size // consists of numeric & symbolic expression parts // size is num + expr1 + expr2 + ... 
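//
// For example (illustrative): sizing the N encoding of a struct with a uint32
// field and a string field S accumulates 4 + (4 + len(p.S)), i.e. num=8 and
// exprv=["len(p.S)"], which renders as "8 + len(p.S)".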
type SymSize struct { num int // numeric part of size exprv []string // symbolic part of size } func (s *SymSize) Add(n int) { s.num += n } func (s *SymSize) AddExpr(format string, a ...interface{}) { expr := fmt.Sprintf(format, a...) s.exprv = append(s.exprv, expr) } func (s *SymSize) String() string { // num + expr1 + expr2 + ... (omitting what is possible) sizev := []string{} if s.num != 0 { sizev = append(sizev, fmt.Sprintf("%v", s.num)) } exprStr := s.ExprString() if exprStr != "" { sizev = append(sizev, exprStr) } sizeStr := strings.Join(sizev, " + ") if sizeStr == "" { sizeStr = "0" } return sizeStr } // expression part of size (without numeric part) func (s *SymSize) ExprString() string { return strings.Join(s.exprv, " + ") } func (s *SymSize) IsZero() bool { return s.num == 0 && len(s.exprv) == 0 } // is it numeric only? func (s *SymSize) IsNumeric() bool { return len(s.exprv) == 0 } func (s *SymSize) Reset() { *s = SymSize{} } // decoder overflow check state type OverflowCheck struct { // accumulator for size to check at overflow check point checkSize SymSize // whether overflow was already checked for current decodings // (if yes, checkSize updates will be ignored) checked bool // stack operated by {Push,Pop}Checked checkedStk []bool // whether any 'goto overflow' has been emitted for // current decoder gotoEmitted bool } // push/pop checked state func (o *OverflowCheck) PushChecked(checked bool) { o.checkedStk = append(o.checkedStk, o.checked) o.checked = checked } func (o *OverflowCheck) PopChecked() bool { popret := o.checked l := len(o.checkedStk) o.checked = o.checkedStk[l-1] o.checkedStk = o.checkedStk[:l-1] return popret } // Add and AddExpr update .checkSize accordingly, but only if overflow was not // already marked as checked func (o *OverflowCheck) Add(n int) { if !o.checked { o.checkSize.Add(n) } } func (o *OverflowCheck) AddExpr(format string, a ...interface{}) { if !o.checked { o.checkSize.AddExpr(format, a...) } } // sizerX generates code to compute X-encoded size of a message. // // when type is recursively walked, for every case symbolic size is added appropriately. // in case when it was needed to generate loops, runtime accumulator variable is additionally used. // result is: symbolic size + (optionally) runtime accumulator. type sizerCommon struct { commonCodeGen size SymSize // currently accumulated size } type sizerN struct { sizerCommon; N } type sizerM struct { sizerCommon; M } // encoderX generates code to X-encode a message. // // when type is recursively walked, for every case code to update `data[n:]` is generated. // no overflow checks are generated as by neo.Msg interface provided data // buffer should have at least payloadLen length returned by neoMsgEncodedLenX() // (the size computed by sizerX). // // the code emitted looks like: // // encode<typ1>(data[n1:], path1) // encode<typ2>(data[n2:], path2) // ... // // TODO encode have to care in neoMsgEncodeX to emit preamble such that bound // checking is performed only once (currently compiler emits many of them) type encoderCommon struct { commonCodeGen n int // current write position in data } type encoderN struct { encoderCommon; N } type encoderM struct { encoderCommon; M } // decoderX generates code to X-decode a message. // // when type is recursively walked, for every case code to decode next item from // `data[n:]` is generated. 
// // overflow checks and nread updates are grouped and emitted so that they are // performed in the beginning of greedy fixed-wire-size blocks - checking / // updating as much as possible to do ahead in one go. // // the code emitted looks like: // // if len(data) < wireSize(typ1) + wireSize(typ2) + ... { // goto overflow // } // <assignto1> = decode<typ1>(data[n1:]) // <assignto2> = decode<typ2>(data[n2:]) // ... type decoderCommon struct { commonCodeGen // done buffer for generated code // current delayed overflow check will be inserted in between bufDone & buf bufDone Buffer n int // current read position in data. nread int // numeric part of total nread return nreadStk []int // stack to push/pop nread on loops // overflow check state and size that will be checked for overflow at // current overflow check point overflow OverflowCheck } type decoderN struct { decoderCommon; N } type decoderM struct { decoderCommon; M } var _ CodeGenerator = (*sizerN)(nil) var _ CodeGenerator = (*encoderN)(nil) var _ CodeGenerator = (*decoderN)(nil) var _ CodeGenerator = (*sizerM)(nil) var _ CodeGenerator = (*encoderM)(nil) var _ CodeGenerator = (*decoderM)(nil) func (s *sizerCommon) generatedCode() string { code := Buffer{} // prologue code.emit("func (%s *%s) neoMsgEncodedLen%c() int {", s.recvName, s.typeName, s.enc) if s.varUsed["size"] { code.emit("var %s int", s.var_("size")) } code.Write(s.buf.Bytes()) // epilogue size := s.size.String() if s.varUsed["size"] { size += " + " + s.var_("size") } code.emit("return %v", size) code.emit("}\n") return code.String() } func (e *encoderCommon) generatedCode() string { code := Buffer{} // prologue code.emit("func (%s *%s) neoMsgEncode%c(data []byte) {", e.recvName, e.typeName, e.enc) code.Write(e.buf.Bytes()) // epilogue code.emit("}\n") return code.String() } // data = data[n:] // n = 0 func (d *decoderCommon) resetPos() { if d.n != 0 { d.emit("data = data[%v:]", d.n) d.n = 0 } } // mark current place for insertion of overflow check code // // The check will be actually inserted later. // // later: because first we go forward in decode path scanning ahead as far as // we can - until first seeing variable-size encoded something, and then - // knowing fixed size would be to read - insert checking condition for // accumulated size to here-marked overflow checkpoint. // // so overflowCheck does: // 1. emit overflow checking code for previous overflow checkpoint // 2. mark current place as next overflow checkpoint to eventually emit // // it has to be inserted // - before reading a variable sized item // - in the beginning of a loop inside (via overflowCheckLoopEntry) // - right after loop exit (via overflowCheckLoopExit) func (d *decoderCommon) overflowCheck() { // nop if we know overflow was already checked if d.overflow.checked { return } //d.bufDone.emit("// overflow check point") if !d.overflow.checkSize.IsZero() { lendata := "len(data)" if !d.overflow.checkSize.IsNumeric() { // symbolic checksize has uint64 type lendata = "uint64(" + lendata + ")" } d.bufDone.emit("if %s < %v { goto overflow }", lendata, &d.overflow.checkSize) // Flag that we already committed a 'goto overflow' statement in current decoder d.overflow.gotoEmitted = true // if size for overflow check was only numeric - just // accumulate it at generation time // // otherwise accumulate into var(nread) at runtime. 
// we do not break runtime accumulation into numeric & symbolic // parts, because just above whole expression num + symbolic // was given to compiler as a whole so compiler should have it // just computed fully. // XXX recheck ^^^ is actually good with the compiler if d.overflow.checkSize.IsNumeric() { d.nread += d.overflow.checkSize.num } else { d.bufDone.emit("%v += %v", d.var_("nread"), &d.overflow.checkSize) } } d.overflow.checkSize.Reset() d.bufDone.Write(d.buf.Bytes()) d.buf.Reset() } // overflowCheck variant that should be inserted at the beginning of a loop inside func (d *decoderCommon) overflowCheckLoopEntry() { if d.overflow.checked { return } d.overflowCheck() // upon entering a loop organize new nread, because what will be statically // read inside loop should be multiplied by loop len in parent context. d.nreadStk = append(d.nreadStk, d.nread) d.nread = 0 } // overflowCheck variant that should be inserted right after loop exit func (d *decoderCommon) overflowCheckLoopExit(loopLenExpr string) { if d.overflow.checked { return } d.overflowCheck() // merge-in numeric nread updates from loop if d.nread != 0 { d.emit("%v += %v * %v", d.var_("nread"), loopLenExpr, d.nread) } l := len(d.nreadStk) d.nread = d.nreadStk[l-1] d.nreadStk = d.nreadStk[:l-1] } func (d *decoderCommon) generatedCode() string { // flush for last overflow check point d.overflowCheck() code := Buffer{} // prologue code.emit("func (%s *%s) neoMsgDecode%c(data []byte) (int, error) {", d.recvName, d.typeName, d.enc) if d.varUsed["nread"] { code.emit("var %v uint64", d.var_("nread")) } code.Write(d.bufDone.Bytes()) // epilogue retexpr := fmt.Sprintf("%v", d.nread) if d.varUsed["nread"] { // casting nread to int is ok even on 32 bit arches: // if nread would overflow 32 bits it would be caught earlier, // because on 32 bit arch len(data) is also 32 bit and in generated // code len(data) is checked first to be less than encoded message. retexpr += fmt.Sprintf(" + int(%v)", d.var_("nread")) } code.emit("return %v, nil", retexpr) // `goto overflow` is not used only for empty structs // NOTE for >0 check actual X in StdSizes{X} does not particularly matter if ((&types.StdSizes{8, 8}).Sizeof(d.typ) > 0 || d.enc != 'N') && d.overflow.gotoEmitted { code.emit("\noverflow:") code.emit("return 0, ErrDecodeOverflow") } code.emit("}\n") return code.String() } // ---- basic types ---- // N: emit code to size/encode/decode basic fixed type func (s *sizerN) genBasic(path string, typ *types.Basic, userType types.Type) { basic := basicTypesN[typ.Kind()] s.size.Add(basic.wireSize) } func (e *encoderN) genBasic(path string, typ *types.Basic, userType types.Type) { basic := basicTypesN[typ.Kind()] dataptr := fmt.Sprintf("data[%v:]", e.n) if userType != typ && userType != nil { // userType is a named type over some basic, like // type ClusterState int32 // -> need to cast path = fmt.Sprintf("%v(%v)", typeName(typ), path) } e.emit(basic.encode, dataptr, path) e.n += basic.wireSize } func (d *decoderN) genBasic(assignto string, typ *types.Basic, userType types.Type) { basic := basicTypesN[typ.Kind()] // XXX specifying :hi is not needed - it is only a workaround to help BCE. 
// see https://github.com/golang/go/issues/19126#issuecomment-358743715 dataptr := fmt.Sprintf("data[%v:%v+%d]", d.n, d.n, basic.wireSize) decoded := fmt.Sprintf(basic.decode, dataptr) if userType != typ && userType != nil { // need to cast (like in encode case) decoded = fmt.Sprintf("%v(%v)", typeName(userType), decoded) } // NOTE no space before "=" - to be able to merge with ":" // prefix and become defining assignment d.emit("%s= %s", assignto, decoded) d.n += basic.wireSize d.overflow.Add(basic.wireSize) } // M: emit code to size/encode/decode basic fixed type func (s *sizerM) genBasic(path string, typ *types.Basic, userType types.Type) { // upath casts path into basic type if needed // e.g. p.x -> int32(p.x) if p.x is custom type with underlying int32 upath := path if userType.Underlying() != userType { upath = fmt.Sprintf("%s(%s)", typ.Name(), upath) } // zodb.Tid and zodb.Oid are encoded as [8]bin or nil for INVALID_{TID_OID} if userType == zodbTid || userType == zodbOid { // INVALID_{TID,OID} must be NIL on the wire s.emit("if uint64(%s) == %v {", path, INVALID_ID) s.emit("%v += 1 // mnil", s.var_("size")) s.emit("} else {") s.emit("%v += 1+1+8 // mbin8 + 8 + [8]data", s.var_("size")) s.emit("}") return } // enums are encoded as extensions if _, isEnum := enumRegistry[userType]; isEnum { s.size.Add(1+1+1) // fixext1 enumType value return } switch typ.Kind() { case types.Bool: s.size.Add(1) // mfalse|mtrue case types.Int8: s.size.AddExpr("msgpack.Int8Size(%s)", upath) case types.Int16: s.size.AddExpr("msgpack.Int16Size(%s)", upath) case types.Int32: s.size.AddExpr("msgpack.Int32Size(%s)", upath) case types.Int64: s.size.AddExpr("msgpack.Int64Size(%s)", upath) case types.Uint8: s.size.AddExpr("msgpack.Uint8Size(%s)", upath) case types.Uint16: s.size.AddExpr("msgpack.Uint16Size(%s)", upath) case types.Uint32: s.size.AddExpr("msgpack.Uint32Size(%s)", upath) case types.Uint64: s.size.AddExpr("msgpack.Uint64Size(%s)", upath) case types.Float64: s.size.Add(1+8) // mfloat64 + <value64> } } func (e *encoderM) genBasic(path string, typ *types.Basic, userType types.Type) { // upath casts path into basic type if needed // e.g. 
p.x -> int32(p.x) if p.x is custom type with underlying int32 upath := path if userType.Underlying() != userType { upath = fmt.Sprintf("%s(%s)", typ.Name(), upath) } // zodb.Tid and zodb.Oid are encoded as [8]bin or nil if userType == zodbTid || userType == zodbOid { e.emit("if %s == %v {", path, INVALID_ID) // INVALID_{TID,OID} => e.emit("data[%v] = byte(msgpack.Nil)", e.n) // mnil e.emit("data = data[%v:]", e.n + 1) e.emit("} else {") e.emit("data[%v] = byte(msgpack.Bin8)", e.n); e.n++ e.emit("data[%v] = 8", e.n); e.n++ e.emit("binary.BigEndian.PutUint64(data[%v:], uint64(%s))", e.n, path) e.n += 8 e.resetPos() e.emit("}") return } // enums are encoded as `fixext1 enumType fixint<value>` if enum, ok := enumRegistry[userType]; ok { e.emit("data[%v] = byte(msgpack.FixExt1)", e.n); e.n++ e.emit("data[%v] = %d", e.n, enum); e.n++ e.emit("if !(0 <= %s && %s <= 0x7f) {", path, path) // mposfixint e.emit(` panic("%s: invalid %s enum value)")`, path, typeName(userType)) e.emit("}") e.emit("data[%v] = byte(%s)", e.n, path); e.n++ return } // mputint emits mput<kind>int<size>(path) mputint := func(kind string, size int) { KI := "I" // I or <Kind>i if kind != "" { KI = strings.ToUpper(kind) + "i" } e.emit("{") e.emit("n := msgpack.Put%snt%d(data[%v:], %s)", KI, size, e.n, upath) e.emit("data = data[%v+n:]", e.n) e.emit("}") e.n = 0 } switch typ.Kind() { case types.Bool: e.emit("data[%v] = byte(msgpack.Bool(%s))", e.n, path) e.n += 1 case types.Int8: mputint("", 8) case types.Int16: mputint("", 16) case types.Int32: mputint("", 32) case types.Int64: mputint("", 64) case types.Uint8: mputint("u", 8) case types.Uint16: mputint("u", 16) case types.Uint32: mputint("u", 32) case types.Uint64: mputint("u", 64) case types.Float64: // mfloat64 f64 e.emit("data[%v] = byte(msgpack.Float64)", e.n); e.n++ e.emit("float64_neoEncode(data[%v:], %s)", e.n, upath); e.n += 8 } } // decoder expects <op> func (d *decoderM) expectOp(assignto string, op string, addOverflow bool) { d.emit("if op := msgpack.Op(data[%v]); op != %s {", d.n, op); d.n++ d.emit(" return 0, mdecodeOpErr(%q, op, %s)", d.pathName(assignto), op) d.emit("}") if addOverflow { d.overflow.Add(1) } } // decoder expects mbin8 l func (d *decoderM) expectBin8Fix(assignto string, l int, addOverflow bool) { d.expectOp(assignto, "msgpack.Bin8", addOverflow) d.emit("if l := data[%v]; l != %d {", d.n, l); d.n++ d.emit(" return 0, mdecodeLen8Err(%q, l, %d)", d.pathName(assignto), l) d.emit("}") if addOverflow { d.overflow.Add(1) } } // decoder expects mfixext1 <enumType> func (d *decoderM) expectEnum(assignto string, enumType int) { d.expectOp(assignto, "msgpack.FixExt1", true) d.emit("if enumType := data[%v]; enumType != %d {", d.n, enumType); d.n++ d.emit(" return 0, mdecodeEnumTypeErr(%q, enumType, %d)", d.pathName(assignto), enumType) d.emit("}") d.overflow.Add(1) } func (d *decoderM) genBasic(assignto string, typ *types.Basic, userType types.Type) { // zodb.Tid and zodb.Oid are encoded as [8]bin or Nil if userType == zodbTid || userType == zodbOid { d.resetPos() d.overflowCheck() defer d.overflowCheck() // Size depends on if we have an INVALID_{TID,OID} // which is encoded as NIL or if we have a valid // {TID,OID} which is encdoed as Bin8Fix. 
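// Illustrative wire forms (msgpack): a valid id is sent as
//
//	0xc4 0x08 <8 big-endian bytes>	(bin8 head + length 8 + data)
//
// while INVALID_{TID,OID} is sent as the single byte 0xc0 (nil).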
d.overflow.AddExpr("msgpack.TidOrOidSize(data)") d.emit("if data[%v] == byte(msgpack.Nil) {", d.n) d.emit("%s = %s(%v)", assignto, typeName(userType), INVALID_ID) d.n += 1 d.resetPos() d.emit("} else {") d.expectBin8Fix(assignto, 8, false) d.emit("%s= %s(binary.BigEndian.Uint64(data[%v:]))", assignto, typeName(userType), d.n) d.n += 8 d.resetPos() d.emit("}") return } // enums are encoded as `fixext1 enumType fixint<value>` if enum, ok := enumRegistry[userType]; ok { d.expectEnum(assignto, enum) d.emit("{") d.emit("v := data[%v]", d.n); d.n++ d.emit("if !(0 <= v && v <= 0x7f) {") // mposfixint d.emit(" return 0, mdecodeEnumValueErr(%q, v)", d.pathName(assignto)) d.emit("}") d.emit("%s= %s(v)", assignto, typeName(userType)) d.emit("}") d.overflow.Add(1) return } // v represents basic decoded value casted to user type if needed v := "v" if userType.Underlying() != userType { v = fmt.Sprintf("%s(v)", typeName(userType)) } // mgetint emits assignto = mget<kind>int<size>() mgetint := func(kind string, size int) { // we are going to go into msgp - flush previously queued // overflow checks; put place for next overflow check after // msgp is done. d.overflowCheck() d.resetPos() defer d.overflowCheck() KI := "I" // I or <Kind>i if kind != "" { KI = strings.ToUpper(kind) + "i" } d.emit("{") d.emit("v, tail, err := msgp.Read%snt%dBytes(data)", KI, size) d.emit("if err != nil {") d.emit(" return 0, mdecodeErr(%q, err)", d.pathName(assignto)) d.emit("}") d.emit("%s= %s", assignto, v) d.emit("%v += uint64(len(data) - len(tail))", d.var_("nread")) d.emit("data = tail") d.emit("}") } // mgetfloat emits mgetfloat<size> mgetfloat := func(size int, optionalValue string) { // delving into msgp - flush/prepare next site for overflow check d.overflowCheck() d.resetPos() defer d.overflowCheck() d.emit("{") d.emit("v, tail, err := msgp.ReadFloat%dBytes(data)", size) d.emit("if err != nil {") if optionalValue != "" { // ReadFloat%dBytes returns 'ErrShortBytes' in case prefix is // correct float, but data is too short - catch this to return // 'ErrDecodeOverflow' instead of type error. 
d.emit(" err = mdecodeErr(%q, err)", d.pathName(assignto)) d.emit(" if err == ErrDecodeOverflow {") d.emit(" return 0, err") d.emit(" }") d.emit(" tail, err = msgp.ReadNilBytes(data)") d.emit(" if err != nil {") d.emit(" return 0, mdecodeErr(%q, err)", d.pathName(assignto)) d.emit(" }") d.emit(" v = %v", optionalValue) } else { d.emit(" return 0, mdecodeErr(%q, err)", d.pathName(assignto)) } d.emit("}") d.emit("%s= %s", assignto, v) d.emit("%v += uint64(len(data) - len(tail))", d.var_("nread")) d.emit("data = tail") d.emit("}") } // IdTime can be nil ('None' in py), in this case we use // infinite -1, see // https://lab.nexedi.com/kirr/neo/-/blob/1ad088c8/go/neo/proto/proto.go#L352-357 if typeName(userType) == "IdTime" { mgetfloat(64, "math.Inf(-1)") return } switch typ.Kind() { case types.Bool: d.emit("switch op := msgpack.Op(data[%v]); op {", d.n) d.emit("default: return 0, mdecodeOpErr(%q, op, msgpack.True, msgpack.False)", d.pathName(assignto)) d.emit("case msgpack.True: %s = true", assignto) d.emit("case msgpack.False: %s = false", assignto) d.emit("}") d.n++ d.overflow.Add(1) case types.Int8: mgetint("", 8) case types.Int16: mgetint("", 16) case types.Int32: mgetint("", 32) case types.Int64: mgetint("", 64) case types.Uint8: mgetint("u", 8) case types.Uint16: mgetint("u", 16) case types.Uint32: mgetint("u", 32) case types.Uint64: mgetint("u", 64) case types.Float64: mgetfloat(64, "") } } // emit code to size/encode/decode array with sizeof(elem)==1 // [len(A)]byte func (s *sizerN) genArray1(path string, typ *types.Array) { s.size.Add(int(typ.Len())) } func (e *encoderN) genArray1(path string, typ *types.Array) { e.emit("copy(data[%v:], %v[:])", e.n, path) e.n += int(typ.Len()) } func (d *decoderN) genArray1(assignto string, typ *types.Array) { typLen := int(typ.Len()) d.emit("copy(%v[:], data[%v:%v])", assignto, d.n, d.n+typLen) d.n += typLen d.overflow.Add(typLen) } // binX+lenX // [len(A)]byte func (s *sizerM) genArray1(path string, typ *types.Array) { l := int(typ.Len()) s.size.Add(msgpack.BinHeadSize(l)) s.size.Add(l) } func (e *encoderM) genArray1(path string, typ *types.Array) { l := int(typ.Len()) if l > 0xff { panic("TODO: array1 with > 255 elements") } e.emit("data[%v] = byte(msgpack.Bin8)", e.n); e.n++ e.emit("data[%v] = %d", e.n, l); e.n++ e.emit("copy(data[%v:], %v[:])", e.n, path) e.n += l } func (d *decoderM) genArray1(assignto string, typ *types.Array) { l := int(typ.Len()) if l > 0xff { panic("TODO: array1 with > 255 elements") } d.expectBin8Fix(assignto, l, true) d.emit("copy(%v[:], data[%v:%v])", assignto, d.n, d.n+l) d.n += l d.overflow.Add(l) } // emit code to size/encode/decode string or []byte // len u32 // [len]byte func (s *sizerN) genSlice1(path string, typ types.Type) { s.size.Add(4) s.size.AddExpr("len(%s)", path) } func (e *encoderN) genSlice1(path string, typ types.Type) { e.emit("{") e.emit("l := uint32(len(%s))", path) e.genBasic("l", types.Typ[types.Uint32], nil) e.emit("data = data[%v:]", e.n) e.emit("copy(data, %v)", path) e.emit("data = data[l:]") e.emit("}") e.n = 0 } func (d *decoderN) genSlice1(assignto string, typ types.Type) { d.emit("{") d.genBasic("l:", types.Typ[types.Uint32], nil) d.resetPos() d.overflowCheck() d.overflow.AddExpr("uint64(l)") switch t := typ.(type) { case *types.Basic: if t.Kind() != types.String { log.Panicf("bad basic type in slice1: %v", t) } d.emit("%v= string(data[:l])", assignto) case *types.Slice: // TODO eventually do not copy, but reference data from original d.emit("%v= make(%v, l)", assignto, typeName(typ)) 
d.emit("copy(%v, data[:l])", assignto) default: log.Panicf("bad type in slice1: %v", typ) } d.emit("data = data[l:]") d.emit("}") } // bin8+len8|bin16+len16|bin32+len32 // [len]byte func (s *sizerM) genSlice1(path string, typ types.Type) { s.size.AddExpr("msgpack.BinHeadSize(len(%s))", path) s.size.AddExpr("len(%s)", path) } func (e *encoderM) genSlice1(path string, typ types.Type) { e.emit("{") e.emit("l := len(%s)", path) e.emit("n := msgpack.PutBinHead(data[%v:], l)", e.n) e.emit("data = data[%v+n:]", e.n) e.emit("copy(data, %v)", path) e.emit("data = data[l:]") e.emit("}") e.n = 0 } func (d *decoderM) genSlice1(assignto string, typ types.Type) { // -> msgp: flush queued overflow checks; put place for next overflow // checks after msgp is done. d.overflowCheck() d.resetPos() defer d.overflowCheck() d.emit("{") d.emit("b, tail, err := msgp.ReadBytesZC(data)") d.emit("if err != nil {") d.emit(" return 0, mdecodeErr(%q, err)", d.pathName(assignto)) d.emit("}") switch t := typ.(type) { case *types.Basic: if t.Kind() != types.String { log.Panicf("bad basic type in slice1: %v", t) } d.emit("%v= string(b)", assignto) case *types.Slice: // TODO eventually do not copy, but reference data from original d.emit("%v= make(%v, len(b))", assignto, typeName(typ)) d.emit("copy(%v, b)", assignto) default: log.Panicf("bad type in slice1: %v", typ) } d.emit("%v += uint64(len(data) - len(tail))", d.var_("nread")) d.emit("data = tail") d.emit("}") } // emit code to size/encode/decode mem.Buf // same as slice1 but buffer is allocated via mem.BufAlloc func (s *sizerN) genBuf(path string) { s.genSlice1(path+".XData()", nil /* typ unused */) } func (s *sizerM) genBuf(path string) { s.genSlice1(path+".XData()", nil /* typ unused */) } func (e *encoderN) genBuf(path string) { e.genSlice1(path+".XData()", nil /* typ unused */) } func (e *encoderM) genBuf(path string) { e.genSlice1(path+".XData()", nil /* typ unused */) } func (d *decoderN) genBuf(assignto string) { d.emit("{") d.genBasic("l:", types.Typ[types.Uint32], nil) d.resetPos() d.overflowCheck() d.overflow.AddExpr("uint64(l)") // TODO eventually do not copy but reference original d.emit("%v= mem.BufAlloc(int(l))", assignto) d.emit("copy(%v.Data, data[:l])", assignto) d.emit("data = data[l:]") d.emit("}") } func (d *decoderM) genBuf(assignto string) { // -> msgp: flush queued overflow checks; put place for next overflow // checks after msgp is done. 
d.overflowCheck() d.resetPos() defer d.overflowCheck() d.emit("{") d.emit("b, tail, err := msgp.ReadBytesZC(data)") d.emit("if err != nil {") d.emit(" return 0, mdecodeErr(%q, err)", d.pathName(assignto)) d.emit("}") // TODO eventually do not copy but reference original d.emit("%v= mem.BufAlloc(len(b))", assignto) d.emit("copy(%v.Data, b)", assignto) d.emit("%v += uint64(len(data) - len(tail))", d.var_("nread")) d.emit("data = tail") d.emit("}") } // emit code to size/encode/decode slice // len u32 // [len]item func (s *sizerN) genSliceHead(path string, typ *types.Slice, obj types.Object) { s.size.Add(4) } func (s *sizerN) genSlice(path string, typ *types.Slice, obj types.Object) { s.genSliceCommon(s, path, typ, obj) } func (s *sizerCommon) genSliceCommon(xs CodeGenCustomize, path string, typ *types.Slice, obj types.Object) { xs.genSliceHead(path, typ, obj) // if size(item)==const - size update in one go elemSize, ok := typeSizeFixed(xs.encoding(), typ.Elem()) if ok { s.size.AddExpr("len(%v) * %v", path, elemSize) return } curSize := s.size s.size.Reset() s.emit("for i := 0; i < len(%v); i++ {", path) s.emit("a := &%s[i]", path) codegenType("(*a)", typ.Elem(), obj, xs) // merge-in size updates s.emit("%v += %v", s.var_("size"), s.size.ExprString()) s.emit("}") if s.size.num != 0 { curSize.AddExpr("len(%v) * %v", path, s.size.num) } s.size = curSize } func (e *encoderN) genSliceHead(path string, typ *types.Slice, obj types.Object) { e.emit("l := len(%s)", path) e.genBasic("l", types.Typ[types.Uint32], types.Typ[types.Int]) e.emit("data = data[%v:]", e.n) e.n = 0 } func (e *encoderN) genSlice(path string, typ *types.Slice, obj types.Object) { e.genSliceCommon(e, path, typ, obj) } func (e *encoderCommon) genSliceCommon(xe CodeGenCustomize, path string, typ *types.Slice, obj types.Object) { e.emit("{") xe.genSliceHead(path, typ, obj) e.emit("for i := 0; i <l; i++ {") e.emit("a := &%s[i]", path) codegenType("(*a)", typ.Elem(), obj, xe) if e.n != 0 { e.emit("data = data[%v:]", e.n) e.n = 0 } e.emit("}") e.emit("}") } // data = data[n:] // n = 0 // // XXX duplication wrt decoderCommon.resetPost func (e *encoderCommon) resetPos() { if e.n != 0 { e.emit("data = data[%v:]", e.n) e.n = 0 } } func (d *decoderN) genSliceHead(assignto string, typ *types.Slice, obj types.Object) { d.genBasic("l:", types.Typ[types.Uint32], nil) } func (d *decoderN) genSlice(assignto string, typ *types.Slice, obj types.Object) { d.genSliceCommon(d, assignto, typ, obj) } func (d *decoderCommon) genSliceCommon(xd CodeGenCustomize, assignto string, typ *types.Slice, obj types.Object) { d.emit("{") xd.genSliceHead(assignto, typ, obj) d.resetPos() // if size(item)==const - check overflow in one go elemSize, elemFixed := typeSizeFixed(xd.encoding(), typ.Elem()) if elemFixed { d.overflowCheck() d.overflow.AddExpr("uint64(l) * %v", elemSize) d.overflow.PushChecked(true) defer d.overflow.PopChecked() } d.emit("%v= make(%v, l)", assignto, typeName(typ)) d.emit("for i := 0; uint32(i) < l; i++ {") d.emit("a := &%s[i]", assignto) d.overflowCheckLoopEntry() codegenType("(*a)", typ.Elem(), obj, xd) d.resetPos() d.emit("}") d.overflowCheckLoopExit("uint64(l)") d.emit("}") } // fixarray|array16+YYYY|array32+ZZZZ // [len]item func (s *sizerM) genSliceHead(path string, typ *types.Slice, obj types.Object) { s.size.AddExpr("msgpack.ArrayHeadSize(len(%s))", path) } func (s *sizerM) genSlice(path string, typ *types.Slice, obj types.Object) { s.genSliceCommon(s, path, typ, obj) } func (e *encoderM) genSliceHead(path string, typ *types.Slice, obj 
types.Object) { e.emit("l := len(%s)", path) e.emit("n := msgpack.PutArrayHead(data[%v:], l)", e.n) e.emit("data = data[%v+n:]", e.n) e.n = 0 } func (e *encoderM) genSlice(path string, typ *types.Slice, obj types.Object) { e.genSliceCommon(e, path, typ, obj) } func (d *decoderM) genSliceHead(assignto string, typ *types.Slice, obj types.Object) { // -> msgp: flush queued overflow checks; put place for next overflow // checks after msgp is done. d.overflowCheck() d.resetPos() defer d.overflowCheck() d.emit("l, tail, err := msgp.ReadArrayHeaderBytes(data)") d.emit("if err != nil {") d.emit(" return 0, mdecodeErr(%q, err)", d.pathName(assignto)) d.emit("}") d.emit("%v += uint64(len(data) - len(tail))", d.var_("nread")) d.emit("data = tail") } func (d *decoderM) genSlice(assignto string, typ *types.Slice, obj types.Object) { d.genSliceCommon(d, assignto, typ, obj) } // generate code to encode/decode map // len u32 // [len](key, value) func (s *sizerN) genMapHead(path string, typ *types.Map, obj types.Object) { s.size.Add(4) } func (s *sizerN) genMap(path string, typ *types.Map, obj types.Object) { s.genMapCommon(s, path, typ, obj) } func (s *sizerCommon) genMapCommon(xs CodeGenCustomize, path string, typ *types.Map, obj types.Object) { xs.genMapHead(path, typ, obj) keySize, keyFixed := typeSizeFixed(xs.encoding(), typ.Key()) elemSize, elemFixed := typeSizeFixed(xs.encoding(), typ.Elem()) if keyFixed && elemFixed { s.size.AddExpr("len(%v) * %v", path, keySize+elemSize) return } curSize := s.size s.size.Reset() // FIXME for map of map gives ...[key][key] => key -> different variables s.emit("for key := range %s {", path) codegenType("key", typ.Key(), obj, xs) codegenType(fmt.Sprintf("%s[key]", path), typ.Elem(), obj, xs) // merge-in size updates s.emit("%v += %v", s.var_("size"), s.size.ExprString()) s.emit("}") if s.size.num != 0 { curSize.AddExpr("len(%v) * %v", path, s.size.num) } s.size = curSize } func (e *encoderN) genMapHead(path string, typ *types.Map, obj types.Object) { e.emit("l := len(%s)", path) e.genBasic("l", types.Typ[types.Uint32], types.Typ[types.Int]) e.emit("data = data[%v:]", e.n) e.n = 0 } func (e *encoderN) genMap(path string, typ *types.Map, obj types.Object) { e.genMapCommon(e, path, typ, obj) } func (e *encoderCommon) genMapCommon(xe CodeGenCustomize, path string, typ *types.Map, obj types.Object) { e.emit("{") xe.genMapHead(path, typ, obj) // output keys in sorted order on the wire // (easier for debugging & deterministic for testing) e.emit("keyv := make([]%s, 0, l)", typeName(typ.Key())) // FIXME do not throw old slice away -> do xslice.Realloc() e.emit("for key := range %s {", path) e.emit(" keyv = append(keyv, key)") e.emit("}") e.emit("sort.Slice(keyv, func (i, j int) bool { return keyv[i] < keyv[j] })") e.emit("for _, key := range keyv {") codegenType("key", typ.Key(), obj, xe) codegenType(fmt.Sprintf("%s[key]", path), typ.Elem(), obj, xe) if e.n != 0 { e.emit("data = data[%v:]", e.n) // XXX wrt map of map? 
e.n = 0 } e.emit("}") e.emit("}") } func (d *decoderN) genMapHead(assignto string, typ *types.Map, obj types.Object) { d.genBasic("l:", types.Typ[types.Uint32], nil) } func (d *decoderN) genMap(assignto string, typ *types.Map, obj types.Object) { d.genMapCommon(d, assignto, typ, obj) } func (d *decoderCommon) genMapCommon(xd CodeGenCustomize, assignto string, typ *types.Map, obj types.Object) { d.emit("{") xd.genMapHead(assignto, typ, obj) d.resetPos() // if size(key,item)==const - check overflow in one go keySize, keyFixed := typeSizeFixed(xd.encoding(), typ.Key()) elemSize, elemFixed := typeSizeFixed(xd.encoding(), typ.Elem()) if keyFixed && elemFixed { d.overflowCheck() d.overflow.AddExpr("uint64(l) * %v", keySize+elemSize) d.overflow.PushChecked(true) defer d.overflow.PopChecked() } d.emit("%v= make(%v, l)", assignto, typeName(typ)) d.emit("m := %v", assignto) d.emit("for i := 0; uint32(i) < l; i++ {") d.overflowCheckLoopEntry() d.emit("var key %s", typeName(typ.Key())) codegenType("key", typ.Key(), obj, xd) switch typ.Elem().Underlying().(type) { // basic types can be directly assigned to map entry case *types.Basic: codegenType("m[key]", typ.Elem(), obj, xd) // otherwise assign via temporary default: d.emit("var mv %v", typeName(typ.Elem())) codegenType("mv", typ.Elem(), obj, xd) d.emit("m[key] = mv") } d.resetPos() d.emit("}") d.overflowCheckLoopExit("uint64(l)") d.emit("}") } // fixmap|map16+YYYY|map32+ZZZZ // [len]key/value func (s *sizerM) genMapHead(path string, typ *types.Map, obj types.Object) { s.size.AddExpr("msgpack.MapHeadSize(len(%s))", path) } func (s *sizerM) genMap(path string, typ *types.Map, obj types.Object) { s.genMapCommon(s, path, typ, obj) } func (e *encoderM) genMapHead(path string, typ *types.Map, obj types.Object) { e.emit("l := len(%s)", path) e.emit("n := msgpack.PutMapHead(data[%v:], l)", e.n) e.emit("data = data[%v+n:]", e.n) e.n = 0 } func (e *encoderM) genMap(path string, typ *types.Map, obj types.Object) { e.genMapCommon(e, path, typ, obj) } func (d *decoderM) genMapHead(assignto string, typ *types.Map, obj types.Object) { // -> msgp: flush queued overflow checks; put place for next overflow // checks after msgp is done. d.overflowCheck() d.resetPos() defer d.overflowCheck() d.emit("l, tail, err := msgp.ReadMapHeaderBytes(data)") d.emit("if err != nil {") d.emit(" return 0, mdecodeErr(%q, err)", d.pathName(assignto)) d.emit("}") d.emit("%v += uint64(len(data) - len(tail))", d.var_("nread")) d.emit("data = tail") } func (d *decoderM) genMap(assignto string, typ *types.Map, obj types.Object) { d.genMapCommon(d, assignto, typ, obj) } // emit code to size/encode/decode custom type func (s *sizerN) genCustomN(path string) { s.size.AddExpr("%s.neoEncodedLenN()", path) } func (e *encoderN) genCustomN(path string) { e.emit("{") e.emit("n := %s.neoEncodeN(data[%v:])", path, e.n) e.emit("data = data[%v + n:]", e.n) e.emit("}") e.n = 0 } func (d *decoderN) genCustomN(path string) { d.resetPos() // make sure we check for overflow previous-code before proceeding to custom decoder. d.overflowCheck() d.emit("{") d.emit("n, ok := %s.neoDecodeN(data)", path) d.emit("if !ok { goto overflow }") d.emit("data = data[n:]") d.emit("%v += n", d.var_("nread")) d.emit("}") // insert overflow checkpoint after custom decoder so that overflow // checks for following code are inserted after custom decoder call. 
d.overflowCheck() } // ---- struct head ---- // N: nothing func (s *sizerN) genStructHead(path string, typ *types.Struct, userType types.Type) {} func (e *encoderN) genStructHead(path string, typ *types.Struct, userType types.Type) {} func (d *decoderN) genStructHead(path string, typ *types.Struct, userType types.Type) {} // M: array<nfields> func (s *sizerM) genStructHead(path string, typ *types.Struct, userType types.Type) { if !strings.Contains(path, ".") && !strings.Contains(path, "*") { path = fmt.Sprintf("(*%s)", path) } s.size.AddExpr("msgpack.ArrayHeadSize(reflect.TypeOf(%s).NumField())", path) } func (e *encoderM) genStructHead(path string, typ *types.Struct, userType types.Type) { if typ.NumFields() > 0x0f { panic("TODO: struct with > 15 elements") } e.emit("data[%v] = byte(msgpack.FixArray_4 | %d)", e.n, typ.NumFields()) e.n += 1 } func (d *decoderM) genStructHead(path string, typ *types.Struct, userType types.Type) { d.resetPos() // we are going to go into msgp - flush previously queued // overflow checks; put place for next overflow check after // msgp is done. d.overflowCheck() defer d.overflowCheck() d.emit("{") d.emit("_, tail, err := msgp.ReadArrayHeaderBytes(data)") d.emit("if err != nil {") d.emit(fmt.Sprintf("return 0, mdecodeErr(%q, err)", d.pathName(path))) d.emit("}") d.emit("%v += uint64(len(data) - len(tail))", d.var_("nread")) d.emit("data = tail") d.emit("}") } // top-level driver for emitting size/encode/decode code for a type // // obj is object that uses this type in source program (so in case of an error // we can point to source location for where it happened) func codegenType(path string, typ types.Type, obj types.Object, codegen CodeGenerator) { // neo.customCodecN ccCustomN, ok := codegen.(interface { genCustomN(path string) }) if ok && (types.Implements(typ, neo_customCodecN) || types.Implements(types.NewPointer(typ), neo_customCodecN)) { ccCustomN.genCustomN(path) return } // mem.Buf if tptr, ok := typ.Underlying().(*types.Pointer); ok && tptr.Elem() == memBuf { codegen.genBuf(path) return } switch u := typ.Underlying().(type) { case *types.Basic: // go puts string into basic, but it is really slice1 if u.Kind() == types.String { codegen.genSlice1(path, u) break } _, ok := basicTypesN[u.Kind()] // ok to check N to see if supported for both N and M if !ok { log.Fatalf("%v: %v: basic type %v not supported", pos(obj), obj.Name(), u) } codegen.genBasic(path, u, typ) case *types.Struct: codegen.genStructHead(path, u, typ) for i := 0; i < u.NumFields(); i++ { v := u.Field(i) codegenType(path+"."+v.Name(), v.Type(), v, codegen) } case *types.Array: // [...]byte or [...]uint8 - just straight copy if isByte(u.Elem()) { codegen.genArray1(path, u) } else { var i int64 for i = 0; i < u.Len(); i++ { codegenType(fmt.Sprintf("%v[%v]", path, i), u.Elem(), obj, codegen) } } case *types.Slice: if isByte(u.Elem()) { codegen.genSlice1(path, u) } else { codegen.genSlice(path, u, obj) } case *types.Map: codegen.genMap(path, u, obj) default: log.Fatalf("%v: %v has unsupported type %v (%v)", pos(obj), obj.Name(), typ, u) } } // generate size/encode/decode functions for a type declaration typespec func generateCodecCode(typespec *ast.TypeSpec, codegen CodeGenerator) string { // type & object which refers to this type typ := typeInfo.Types[typespec.Type].Type obj := typeInfo.Defs[typespec.Name] codegen.setFunc("p", typespec.Name.Name, typ, codegen.encoding()) codegenType("p", typ, obj, codegen) return codegen.generatedCode() } // xlookup looks up <pkg>.<name> object. 
// It is a fatal error if the object cannot be found.
func xlookup(pkg *types.Package, name string) types.Object {
	obj := pkg.Scope().Lookup(name)
	if obj == nil {
		log.Fatalf("cannot find `%s.%s`", pkg.Name(), name)
	}
	return obj
}

// isByte returns whether typ represents byte.
func isByte(typ types.Type) bool {
	t, ok := typ.(*types.Basic)
	return ok && t.Kind() == types.Byte
}
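
// For illustration of the decoder overflow-check batching implemented by
// decoderCommon.overflowCheck above: for a hypothetical message with two
// fixed-size fields
//
//	type XY struct { X int32; Y int64 }
//
// the generated N decoder verifies both fields with a single length test
// (sketch, shown after gofmt):
//
//	func (p *XY) neoMsgDecodeN(data []byte) (int, error) {
//		if len(data) < 12 {
//			goto overflow
//		}
//		p.X = int32(binary.BigEndian.Uint32(data[0 : 0+4]))
//		p.Y = int64(binary.BigEndian.Uint64(data[4 : 4+8]))
//		return 12, nil
//
//	overflow:
//		return 0, ErrDecodeOverflow
//	}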