Commit 1adf0596 authored by Kirill Smelkov's avatar Kirill Smelkov

Add PyDict mode

Similarly to StrictUnicode mode (see b28613c2) add new opt-in mode that
requests to decode Python dicts as ogórek.Dict instead of builtin map.
As explained in recent patch "Add custom Dict that mirrors Python dict
behaviour" this is needed to fix decoding issues that can be there due
to different behaviour of Python dict and builtin Go map:

    ---- 8< ----
    Ogórek currently represents unpickled dict via map[any]any, which is
    logical, but also exhibits issues because builtin Go map behaviour is
    different from Python's dict behaviour. For example:

    - Python's dict allows tuples to be used in keys, while Go map
      does not (https://github.com/kisielk/og-rek/issues/50),

    - Python's dict allows both long and int to be used interchangeable as
      keys, while Go map does not handle *big.Int as key with the same
      semantic (https://github.com/kisielk/og-rek/issues/55)

    - Python's dict allows to use numbers interchangeable in keys - all int
      and float, but on Go side int(1) and float64(1.0) are considered by
      builtin map as different keys.

    - In Python world bytestring (str from py2) is considered to be related
      to both unicode (str on py3) and bytes, but builtin map considers all
      string, Bytes and ByteString as different keys.

    - etc...

    All in all there are many differences in behaviour in builtin Python
    dict and Go map that result in generally different semantics when
    decoding pickled data. Those differences can be fixed only if we add
    custom dict implementation that mirrors what Python does.

    -> Do that: add custom Dict that implements key -> value mapping with
       mirroring Python behaviour.

    For now we are only adding the Dict class itself and its tests.
    Later we will use this new Dict to handle decoding dictionaries from the pickles.
    ---- 8< ----

In this patch we add new Decoder option to activate PyDict mode
decoding, teach encoder to also support encoding of Dict and adjust
tests.

The behaviour of new system is explained by the following doc.go
excerpt:

    For dicts there are two modes. In the first, default, mode Python dicts are
    decoded into standard Go map. This mode tries to use builtin Go type, but
    cannot mirror py behaviour fully because e.g. int(1), big.Int(1) and
    float64(1.0) are all treated as different keys by Go, while Python treats
    them as being equal. It also does not support decoding dicts with tuple
    used in keys:

         dict      map[any]any                       PyDict=n mode, default
                 ←  ogórek.Dict

    With PyDict=y mode, however, Python dicts are decoded as ogórek.Dict which
    mirrors behaviour of Python dict with respect to keys equality, and with
    respect to which types are allowed to be used as keys.

         dict      ogórek.Dict                       PyDict=y mode
                 ←  map[any]any
parent 49eb8f31
......@@ -15,7 +15,7 @@ import (
"github.com/aristanetworks/gomap"
)
// Dict represents dict from Python.
// Dict represents dict from Python in PyDict mode.
//
// It mirrors Python with respect to which types are allowed to be used as
// keys, and with respect to keys equality. For example Tuple is allowed to be
......@@ -27,6 +27,8 @@ import (
// underlying content ByteString, because it represents str type from Python2,
// is treated equal to both Bytes and string.
//
// See PyDict mode documentation in top-level package overview for details.
//
// Note: similarly to builtin map Dict is pointer-like type: its zero-value
// represents nil dictionary that is empty and invalid to use Set on.
type Dict struct {
......
......@@ -22,12 +22,29 @@
// long ↔ *big.Int
// float ↔ float64
// float ← floatX
// list ↔ []interface{}
// list ↔ []any
// tuple ↔ ogórek.Tuple
// dict ↔ map[interface{}]interface{}
//
//
// For strings there are two modes. In the first, default, mode both py2/py3
// For dicts there are two modes. In the first, default, mode Python dicts are
// decoded into standard Go map. This mode tries to use builtin Go type, but
// cannot mirror py behaviour fully because e.g. int(1), big.Int(1) and
// float64(1.0) are all treated as different keys by Go, while Python treats
// them as being equal. It also does not support decoding dicts with tuple
// used in keys:
//
// dict ↔ map[any]any PyDict=n mode, default
// ← ogórek.Dict
//
// With PyDict=y mode, however, Python dicts are decoded as ogórek.Dict which
// mirrors behaviour of Python dict with respect to keys equality, and with
// respect to which types are allowed to be used as keys.
//
// dict ↔ ogórek.Dict PyDict=y mode
// ← map[any]any
//
//
// For strings there are also two modes. In the first, default, mode both py2/py3
// str and py2 unicode are decoded into string with py2 str being considered
// as UTF-8 encoded. Correspondingly for protocol ≤ 2 Go string is encoded as
// UTF-8 encoded py2 str, and for protocol ≥ 3 as py3 str / py2 unicode.
......@@ -155,6 +172,11 @@
// Using the helpers fits into Python3 strings/bytes model but also allows to
// handle the data generated from under Python2.
//
// Similarly Dict considers ByteString to be equal to both string and Bytes
// with the same underlying content. This allows programs to access Dict via
// string/bytes keys following Python3 model, while still being able to handle
// dictionaries generated from under Python2.
//
//
// --------
//
......
......@@ -497,6 +497,41 @@ func (e *Encoder) encodeMap(m reflect.Value) error {
return e.emit(opDict)
}
func (e *Encoder) encodeDict(d Dict) error {
l := d.Len()
// protocol >= 1: ø dict -> EMPTY_DICT
if e.config.Protocol >= 1 && l == 0 {
return e.emit(opEmptyDict)
}
// MARK + ... + DICT
// TODO cycles + sort keys (see encodeMap for details)
err := e.emit(opMark)
if err != nil {
return err
}
d.Iter()(func(k, v any) bool {
err = e.encode(reflectValueOf(k))
if err != nil {
return false
}
err = e.encode(reflectValueOf(v))
if err != nil {
return false
}
return true
})
if err != nil {
return err
}
return e.emit(opDict)
}
func (e *Encoder) encodeCall(v *Call) error {
err := e.encodeClass(&v.Callable)
if err != nil {
......@@ -570,6 +605,8 @@ func (e *Encoder) encodeStruct(st reflect.Value) error {
return e.encodeRef(&v)
case big.Int:
return e.encodeLong(&v)
case Dict:
return e.encodeDict(v)
}
structTags := getStructTags(st)
......
......@@ -8,20 +8,24 @@ import (
)
func Fuzz(data []byte) int {
f1 := fuzz(data, false)
f2 := fuzz(data, true)
f := 0
f += fuzz(data, false, false)
f += fuzz(data, false, true)
f += fuzz(data, true, false)
f += fuzz(data, true, true)
f := f1+f2
if f > 1 {
f = 1
}
return f
}
func fuzz(data []byte, strictUnicode bool) int {
func fuzz(data []byte, pyDict, strictUnicode bool) int {
// obj = decode(data) - this tests things like stack overflow in Decoder
buf := bytes.NewBuffer(data)
dec := NewDecoderWithConfig(buf, &DecoderConfig{
PyDict: pyDict,
StrictUnicode: strictUnicode,
})
obj, err := dec.Decode()
......@@ -37,7 +41,7 @@ func fuzz(data []byte, strictUnicode bool) int {
// because obj - as we got it as decoding from input - is known not to
// contain arbitrary Go structs.
for proto := 0; proto <= highestProtocol; proto++ {
subj := fmt.Sprintf("strictUnicode %v: protocol %d", strictUnicode, proto)
subj := fmt.Sprintf("pyDict %v: strictUnicode %v: protocol %d", pyDict, strictUnicode, proto)
buf.Reset()
enc := NewEncoderWithConfig(buf, &EncoderConfig{
......@@ -66,6 +70,7 @@ func fuzz(data []byte, strictUnicode bool) int {
encoded := buf.String()
dec = NewDecoderWithConfig(bytes.NewBufferString(encoded), &DecoderConfig{
PyDict: pyDict,
StrictUnicode: strictUnicode,
})
obj2, err := dec.Decode()
......
......@@ -189,6 +189,11 @@ type DecoderConfig struct {
// decoded into ByteString in this mode. See StrictUnicode mode
// documentation in top-level package overview for details.
StrictUnicode bool
// PyDict, when true, requests to decode Python dicts as ogórek.Dict
// instead of builtin map. See PyDict mode documentation in top-level
// package overview for details.
PyDict bool
}
// NewDecoder returns a new Decoder with the default configuration.
......@@ -996,29 +1001,75 @@ func mapTryAssign(m map[interface{}]interface{}, key, value interface{}) (ok boo
return
}
// dictTryAssign is like mapTryAssign but for Dict.
func dictTryAssign(d Dict, key, value interface{}) (ok bool) {
defer func() {
if r := recover(); r != nil {
ok = false
}
}()
d.Set(key, value)
ok = true
return
}
func (d *Decoder) loadDict() error {
k, err := d.marker()
if err != nil {
return err
}
m := make(map[interface{}]interface{}, 0)
items := d.stack[k+1:]
if len(items) % 2 != 0 {
return fmt.Errorf("pickle: loadDict: odd # of elements")
}
var m interface{}
if d.config.PyDict {
m, err = d.loadDictDict(items)
} else {
m, err = d.loadDictMap(items)
}
if err != nil {
return err
}
d.stack = append(d.stack[:k], m)
return nil
}
func (d *Decoder) loadDictMap(items []interface{}) (map[interface{}]interface{}, error) {
m := make(map[interface{}]interface{}, len(items)/2)
for i := 0; i < len(items); i += 2 {
key := items[i]
if !mapTryAssign(m, key, items[i+1]) {
return fmt.Errorf("pickle: loadDict: invalid key type %T", key)
return nil, fmt.Errorf("pickle: loadDict: map: invalid key type %T", key)
}
}
d.stack = append(d.stack[:k], m)
return nil
return m, nil
}
func (d *Decoder) loadDictDict(items []interface{}) (Dict, error) {
m := NewDictWithSizeHint(len(items)/2)
for i := 0; i < len(items); i += 2 {
key := items[i]
if !dictTryAssign(m, key, items[i+1]) {
return Dict{}, fmt.Errorf("pickle: loadDict: Dict: invalid key type %T", key)
}
}
return m, nil
}
func (d *Decoder) loadEmptyDict() error {
m := make(map[interface{}]interface{}, 0)
var m interface{}
if d.config.PyDict {
m = NewDict()
} else {
m = make(map[interface{}]interface{}, 0)
}
d.push(m)
return nil
}
......@@ -1205,10 +1256,14 @@ func (d *Decoder) loadSetItem() error {
switch m := m.(type) {
case map[interface{}]interface{}:
if !mapTryAssign(m, k, v) {
return fmt.Errorf("pickle: loadSetItem: invalid key type %T", k)
return fmt.Errorf("pickle: loadSetItem: map: invalid key type %T", k)
}
case Dict:
if !dictTryAssign(m, k, v) {
return fmt.Errorf("pickle: loadSetItem: Dict: invalid key type %T", k)
}
default:
return fmt.Errorf("pickle: loadSetItem: expected a map, got %T", m)
return fmt.Errorf("pickle: loadSetItem: expected a map or Dict, got %T", m)
}
return nil
}
......@@ -1221,23 +1276,31 @@ func (d *Decoder) loadSetItems() error {
if k < 1 {
return errStackUnderflow
}
if (len(d.stack) - (k + 1)) % 2 != 0 {
return fmt.Errorf("pickle: loadSetItems: odd # of elements")
}
l := d.stack[k-1]
switch m := l.(type) {
case map[interface{}]interface{}:
if (len(d.stack) - (k + 1)) % 2 != 0 {
return fmt.Errorf("pickle: loadSetItems: odd # of elements")
}
for i := k + 1; i < len(d.stack); i += 2 {
key := d.stack[i]
if !mapTryAssign(m, key, d.stack[i+1]) {
return fmt.Errorf("pickle: loadSetItems: invalid key type %T", key)
return fmt.Errorf("pickle: loadSetItems: map: invalid key type %T", key)
}
}
d.stack = append(d.stack[:k-1], m)
case Dict:
for i := k + 1; i < len(d.stack); i += 2 {
key := d.stack[i]
if !dictTryAssign(m, key, d.stack[i+1]) {
return fmt.Errorf("pickle: loadSetItems: Dict: invalid key type %T", key)
}
}
default:
return fmt.Errorf("pickle: loadSetItems: expected a map, got %T", m)
return fmt.Errorf("pickle: loadSetItems: expected a map or Dict, got %T", m)
}
d.stack = append(d.stack[:k-1], l)
return nil
}
......
......@@ -85,6 +85,9 @@ type TestEntry struct {
strictUnicodeN bool // whether to test with StrictUnicode=n while decoding/encoding
strictUnicodeY bool // whether to test with StrictUnicode=y while decoding/encoding
pyDictN bool // whether to test with PyDict=n while decoding/encoding
pyDictY bool // ----//---- PyDict=y
}
// X, I, P0, P1, P* form a language to describe decode/encode tests:
......@@ -104,7 +107,8 @@ type TestEntry struct {
// the entry is tested under both StrictUnicode=n and StrictUnicode=y modes.
func X(name string, object interface{}, picklev ...TestPickle) TestEntry {
return TestEntry{name: name, objectIn: object, objectOut: object, picklev: picklev,
strictUnicodeN: true, strictUnicodeY: true}
strictUnicodeN: true, strictUnicodeY: true,
pyDictN: true, pyDictY: true}
}
// Xuauto is syntactic sugar to prepare one TestEntry that is tested only under StrictUnicode=n mode.
......@@ -121,6 +125,38 @@ func Xustrict(name string, object interface{}, picklev ...TestPickle) TestEntry
return x
}
// Xdgo is syntactic sugar to prepare one TestEntry that is tested only under PyDict=n mode.
func Xdgo(name string, object interface{}, picklev ...TestPickle) TestEntry {
x := X(name, object, picklev...)
x.pyDictY = false
return x
}
// Xdpy is syntactic sugar to prepare one TestEntry that is tested only under PyDict=y mode.
func Xdpy(name string, object interface{}, picklev ...TestPickle) TestEntry {
x := X(name, object, picklev...)
x.pyDictN = false
return x
}
// Xuauto_dgo is syntactic sugar to prepare one TestEntry that is tested only
// under StrictUnicode=n ^ pyDict=n mode.
func Xuauto_dgo(name string, object interface{}, picklev ...TestPickle) TestEntry {
x := X(name, object, picklev...)
x.strictUnicodeY = false
x.pyDictY = false
return x
}
// Xuauto_dpy is syntactic sugar to prepare one TestEntry that is tested only
// under StrictUnicode=n ^ pyDict=y mode.
func Xuauto_dpy(name string, object interface{}, picklev ...TestPickle) TestEntry {
x := X(name, object, picklev...)
x.strictUnicodeY = false
x.pyDictN = false
return x
}
// Xloosy is syntactic sugar to prepare one TestEntry with loosy encoding.
//
// It should be used only if objectIn contains Go structs.
......@@ -384,19 +420,46 @@ var tests = []TestEntry{
// bytearray(text, encoding); GLOBAL + BINUNICODE + TUPLE + REDUCE
I("c__builtin__\nbytearray\nq\x00(X\x13\x00\x00\x00hello\n\xc3\x90\xc2\xbc\xc3\x90\xc2\xb8\xc3\x91\xc2\x80\x01q\x01X\x07\x00\x00\x00latin-1q\x02tq\x03Rq\x04.")),
// dicts in default PyDict=n mode
Xdgo("dict({})", make(map[interface{}]interface{}),
P0("(d."), // MARK + DICT
P1_("}."), // EMPTY_DICT
I("(dp0\n.")),
Xuauto_dgo("dict({'a': '1'})", map[interface{}]interface{}{"a": "1"},
P0("(S\"a\"\nS\"1\"\nd."), // MARK + STRING + DICT
P12("(U\x01aU\x011d."), // MARK + SHORT_BINSTRING + DICT
P3("(X\x01\x00\x00\x00aX\x01\x00\x00\x001d."), // MARK + BINUNICODE + DICT
P4_("(\x8c\x01a\x8c\x011d.")), // MARK + SHORT_BINUNICODE + DICT
Xuauto_dgo("dict({'a': '1', 'b': '2'})", map[interface{}]interface{}{"a": "1", "b": "2"},
// map iteration order is not stable - test only decoding
I("(S\"a\"\nS\"1\"\nS\"b\"\nS\"2\"\nd."), // P0: MARK + STRING + DICT
I("(U\x01aU\x011U\x01bU\x012d."), // P12: MARK + SHORT_BINSTRING + DICT
// P3: MARK + BINUNICODE + DICT
I("(X\x01\x00\x00\x00aX\x01\x00\x00\x001X\x01\x00\x00\x00bX\x01\x00\x00\x002d."),
I("(\x8c\x01a\x8c\x011\x8c\x01b\x8c\x012d."), // P4_: MARK + SHORT_BINUNICODE + DICT
I("(dS'a'\nS'1'\nsS'b'\nS'2'\ns."), // MARK + DICT + STRING + SETITEM
I("}(U\x01aU\x011U\x01bU\x012u."), // EMPTY_DICT + MARK + SHORT_BINSTRING + SETITEMS
I("(dp0\nS'a'\np1\nS'1'\np2\nsS'b'\np3\nS'2'\np4\ns.")),
// dicts in PyDict=y mode
X("dict({})", make(map[interface{}]interface{}),
Xdpy("dict({})", NewDict(),
P0("(d."), // MARK + DICT
P1_("}."), // EMPTY_DICT
I("(dp0\n.")),
Xuauto("dict({'a': '1'})", map[interface{}]interface{}{"a": "1"},
Xuauto_dpy("dict({'a': '1'})", NewDictWithData("a","1"),
P0("(S\"a\"\nS\"1\"\nd."), // MARK + STRING + DICT
P12("(U\x01aU\x011d."), // MARK + SHORT_BINSTRING + DICT
P3("(X\x01\x00\x00\x00aX\x01\x00\x00\x001d."), // MARK + BINUNICODE + DICT
P4_("(\x8c\x01a\x8c\x011d.")), // MARK + SHORT_BINUNICODE + DICT
Xuauto("dict({'a': '1', 'b': '2'})", map[interface{}]interface{}{"a": "1", "b": "2"},
Xuauto_dpy("dict({'a': '1', 'b': '2'})", NewDictWithData("a","1", "b","2"),
// map iteration order is not stable - test only decoding
I("(S\"a\"\nS\"1\"\nS\"b\"\nS\"2\"\nd."), // P0: MARK + STRING + DICT
I("(U\x01aU\x011U\x01bU\x012d."), // P12: MARK + SHORT_BINSTRING + DICT
......@@ -409,6 +472,21 @@ var tests = []TestEntry{
I("}(U\x01aU\x011U\x01bU\x012u."), // EMPTY_DICT + MARK + SHORT_BINSTRING + SETITEMS
I("(dp0\nS'a'\np1\nS'1'\np2\nsS'b'\np3\nS'2'\np4\ns.")),
Xdpy("dict({123L: 0})", NewDictWithData(bigInt("123"), int64(0)),
P0("(L123L\nI0\nd."), // MARK + LONG + INT + DICT
P1("(L123L\nK\x00d."), // MARK + LONG + BININT1 + DICT
I("(\x8a\x01{K\x00d.")), // MARK + LONG1 + BININT1 + DICT
Xdpy("dict(tuple(): 0)", NewDictWithData(Tuple{}, int64(0)),
P0("((tI0\nd."), // MARK + MARK + TUPLE + INT + DICT
P1_("()K\x00d.")), // MARK + EMPTY_TUPLE + BININT1 + DICT
Xdpy("dict(tuple(1,2): 0)", NewDictWithData(Tuple{int64(1), int64(2)}, int64(0)),
P0("((I1\nI2\ntI0\nd."), // MARK + MARK + INT + INT + TUPLE + INT + DICT
P1("((K\x01K\x02tK\x00d."), // MARK + MARK + BININT1 + BININT1 + TUPLE + BININT1 + DICT
P2_("(K\x01K\x02\x86K\x00d.")), // MARK + BININT1 + BININT1 + TUPLE2 + BININT1 + DICT
Xuauto("foo.bar # global", Class{Module: "foo", Name: "bar"},
P0123("cfoo\nbar\n."), // GLOBAL
P4_("\x8c\x03foo\x8c\x03bar\x93."), // SHORT_BINUNICODE + STACK_GLOBAL
......@@ -448,9 +526,9 @@ var tests = []TestEntry{
X("LONG_BINPUT", []interface{}{int64(17)},
I("(lr0000I17\na.")),
Xuauto("graphite message1", graphiteObject1, graphitePickle1),
Xuauto("graphite message2", graphiteObject2, graphitePickle2),
Xuauto("graphite message3", graphiteObject3, graphitePickle3),
Xuauto_dgo("graphite message1", graphiteObject1, graphitePickle1),
Xuauto_dgo("graphite message2", graphiteObject2, graphitePickle2),
Xuauto_dgo("graphite message3", graphiteObject3, graphitePickle3),
Xuauto("too long line", longLine, I("V" + longLine + "\n.")),
// opcodes from protocol 4
......@@ -487,9 +565,16 @@ type foo struct {
// protocol prefix is always automatically prepended and is always concrete.
var protoPrefixTemplate = string([]byte{opProto, 0xff})
// TestDecode verifies ogórek decoder.
func TestDecode(t *testing.T) {
for _, test := range tests {
// WithEachMode runs f under all decoding/encoding modes covered by test entry.
func (test TestEntry) WithEachMode(t *testing.T, f func(t *testing.T, decConfig DecoderConfig, encConfig EncoderConfig)) {
for _, pyDict := range []bool{false, true} {
if pyDict && !test.pyDictY {
continue
}
if !pyDict && !test.pyDictN {
continue
}
for _, strictUnicode := range []bool{false, true} {
if strictUnicode && !test.strictUnicodeY {
continue
......@@ -497,8 +582,28 @@ func TestDecode(t *testing.T) {
if !strictUnicode && !test.strictUnicodeN {
continue
}
testname := fmt.Sprintf("%s/StrictUnicode=%s", test.name, yn(strictUnicode))
t.Run(fmt.Sprintf("%s/PyDict=%s/StrictUnicode=%s", test.name, yn(pyDict), yn(strictUnicode)),
func(t *testing.T) {
decConfig := DecoderConfig{
PyDict: pyDict,
StrictUnicode: strictUnicode,
}
encConfig := EncoderConfig{
// no PyDict setting for encoder
StrictUnicode: strictUnicode,
}
f(t, decConfig, encConfig)
})
}
}
}
// TestDecode verifies ogórek decoder.
func TestDecode(t *testing.T) {
for _, test := range tests {
test.WithEachMode(t, func(t *testing.T, decConfig DecoderConfig, encConfig EncoderConfig) {
for _, pickle := range test.picklev {
if pickle.err != nil {
continue
......@@ -511,32 +616,24 @@ func TestDecode(t *testing.T) {
data := string([]byte{opProto, byte(proto)}) +
pickle.data[len(protoPrefixTemplate):]
t.Run(fmt.Sprintf("%s/%q/proto=%d", testname, data, proto), func(t *testing.T) {
testDecode(t, strictUnicode, test.objectOut, data)
t.Run(fmt.Sprintf("%q/proto=%d", data, proto), func(t *testing.T) {
testDecode(t, decConfig, test.objectOut, data)
})
}
} else {
t.Run(fmt.Sprintf("%s/%q", testname, pickle.data), func(t *testing.T) {
testDecode(t, strictUnicode, test.objectOut, pickle.data)
t.Run(fmt.Sprintf("%q", pickle.data), func(t *testing.T) {
testDecode(t, decConfig, test.objectOut, pickle.data)
})
}
}
}
})
}
}
// TestEncode verifies ogórek encoder.
func TestEncode(t *testing.T) {
for _, test := range tests {
for _, strictUnicode := range []bool{false, true} {
if strictUnicode && !test.strictUnicodeY {
continue
}
if !strictUnicode && !test.strictUnicodeN {
continue
}
testname := fmt.Sprintf("%s/StrictUnicode=%s", test.name, yn(strictUnicode))
test.WithEachMode(t, func(t *testing.T, decConfig DecoderConfig, encConfig EncoderConfig) {
alreadyTested := make(map[int]bool) // protocols we tested encode with so far
for _, pickle := range test.picklev {
for _, proto := range pickle.protov {
......@@ -546,8 +643,8 @@ func TestEncode(t *testing.T) {
dataOk = string([]byte{opProto, byte(proto)}) + dataOk
}
t.Run(fmt.Sprintf("%s/proto=%d", testname, proto), func(t *testing.T) {
testEncode(t, proto, strictUnicode, test.objectIn, test.objectOut, dataOk, pickle.err)
t.Run(fmt.Sprintf("proto=%d", proto), func(t *testing.T) {
testEncode(t, proto, encConfig, decConfig, test.objectIn, test.objectOut, dataOk, pickle.err)
})
alreadyTested[proto] = true
......@@ -560,11 +657,11 @@ func TestEncode(t *testing.T) {
continue
}
t.Run(fmt.Sprintf("%s/proto=%d(roundtrip)", testname, proto), func(t *testing.T) {
testEncode(t, proto, strictUnicode, test.objectIn, test.objectOut, "", nil)
t.Run(fmt.Sprintf("proto=%d(roundtrip)", proto), func(t *testing.T) {
testEncode(t, proto, encConfig, decConfig, test.objectIn, test.objectOut, "", nil)
})
}
}
})
}
}
......@@ -572,11 +669,9 @@ func TestEncode(t *testing.T) {
//
// It also verifies decoder robustness - via feeding it various kinds of
// corrupt data derived from input.
func testDecode(t *testing.T, strictUnicode bool, object interface{}, input string) {
func testDecode(t *testing.T, decConfig DecoderConfig, object interface{}, input string) {
newDecoder := func(r io.Reader) *Decoder {
return NewDecoderWithConfig(r, &DecoderConfig{
StrictUnicode: strictUnicode,
})
return NewDecoderWithConfig(r, &decConfig)
}
// decode(input) -> expected
......@@ -633,18 +728,16 @@ func testDecode(t *testing.T, strictUnicode bool, object interface{}, input stri
// encode-back tests are still performed.
//
// If errOk != nil, object encoding must produce that error.
func testEncode(t *testing.T, proto int, strictUnicode bool, object, objectDecodedBack interface{}, dataOk string, errOk error) {
func testEncode(t *testing.T, proto int, encConfig EncoderConfig, decConfig DecoderConfig, object, objectDecodedBack interface{}, dataOk string, errOk error) {
newEncoder := func(w io.Writer) *Encoder {
return NewEncoderWithConfig(w, &EncoderConfig{
Protocol: proto,
StrictUnicode: strictUnicode,
})
econf := EncoderConfig{}
econf = encConfig
econf.Protocol = proto
return NewEncoderWithConfig(w, &econf)
}
newDecoder := func(r io.Reader) *Decoder {
return NewDecoderWithConfig(r, &DecoderConfig{
StrictUnicode: strictUnicode,
})
return NewDecoderWithConfig(r, &decConfig)
}
buf := &bytes.Buffer{}
......@@ -980,7 +1073,9 @@ func BenchmarkDecode(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
buf := bytes.NewBuffer(input)
dec := NewDecoder(buf)
dec := NewDecoderWithConfig(buf, &DecoderConfig{
PyDict: true, // so that decoding e.g. {(): 0} does not fail
})
j := 0
for ; ; j++ {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment