Commit 4fd6be93 authored by Kirill Smelkov's avatar Kirill Smelkov

encoder: Adjust it so that decode(encode(v)) == v

That is so that decode·encode becomes identity.

For the other way

	encode(decode(pickle)) == pickle

it is generally not possible to do, because pickle machine is almost
general machine, and it is possible to e.g. have some prefix in program
which is NOP or result of which is no longer used, use different opcodes
to build list or tuple etc.

Adjustments to encoder are:

- teach it to encode big.Int back to opLong (not some struct)
- teach it to encode Call back to opReduce (not some struct)

Tests are adjusted to verify that `decode(encode(v)) == v` holds for
all inputs in test vector.
parent fd5b2c0c
...@@ -5,6 +5,7 @@ import ( ...@@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"io" "io"
"math" "math"
"math/big"
"reflect" "reflect"
) )
...@@ -195,6 +196,12 @@ func (e *Encoder) encodeInt(k reflect.Kind, i int64) error { ...@@ -195,6 +196,12 @@ func (e *Encoder) encodeInt(k reflect.Kind, i int64) error {
return err return err
} }
func (e *Encoder) encodeLong(b *big.Int) error {
// TODO if e.protocol >= 2 use opLong1 & opLong4
_, err := fmt.Fprintf(e.w, "%c%dL\n", opLong, b)
return err
}
func (e *Encoder) encodeMap(m reflect.Value) error { func (e *Encoder) encodeMap(m reflect.Value) error {
keys := m.MapKeys() keys := m.MapKeys()
...@@ -238,14 +245,32 @@ func (e *Encoder) encodeString(s string) error { ...@@ -238,14 +245,32 @@ func (e *Encoder) encodeString(s string) error {
return e.encodeBytes([]byte(s)) return e.encodeBytes([]byte(s))
} }
func (e *Encoder) encodeCall(v *Call) error {
_, err := fmt.Fprintf(e.w, "%c%s\n%s\n", opGlobal, v.Callable.Module, v.Callable.Name)
if err != nil {
return err
}
err = e.encode(reflectValueOf(v.Args))
if err != nil {
return err
}
_, err = e.w.Write([]byte{opReduce})
return err
}
func (e *Encoder) encodeStruct(st reflect.Value) error { func (e *Encoder) encodeStruct(st reflect.Value) error {
typ := st.Type() typ := st.Type()
// first test if it's one of our internal python structs // first test if it's one of our internal python structs
if _, ok := st.Interface().(None); ok { switch v := st.Interface().(type) {
case None:
_, err := e.w.Write([]byte{opNone}) _, err := e.w.Write([]byte{opNone})
return err return err
case Call:
return e.encodeCall(&v)
case big.Int:
return e.encodeLong(&v)
} }
structTags := getStructTags(st) structTags := getStructTags(st)
......
...@@ -71,6 +71,7 @@ func TestDecode(t *testing.T) { ...@@ -71,6 +71,7 @@ func TestDecode(t *testing.T) {
{"SHORTBINUNICODE opcode", "\x8c\t\xe6\x97\xa5\xe6\x9c\xac\xe8\xaa\x9e\x94.", "日本語"}, {"SHORTBINUNICODE opcode", "\x8c\t\xe6\x97\xa5\xe6\x9c\xac\xe8\xaa\x9e\x94.", "日本語"},
} }
for _, test := range tests { for _, test := range tests {
// decode(input) -> expected
buf := bytes.NewBufferString(test.input) buf := bytes.NewBufferString(test.input)
dec := NewDecoder(buf) dec := NewDecoder(buf)
v, err := dec.Decode() v, err := dec.Decode()
...@@ -82,11 +83,30 @@ func TestDecode(t *testing.T) { ...@@ -82,11 +83,30 @@ func TestDecode(t *testing.T) {
t.Errorf("%s: decode:\nhave: %#v\nwant: %#v", test.name, v, test.expected) t.Errorf("%s: decode:\nhave: %#v\nwant: %#v", test.name, v, test.expected)
} }
// decode more -> EOF
v, err = dec.Decode() v, err = dec.Decode()
if !(v == nil && err == io.EOF) { if !(v == nil && err == io.EOF) {
t.Errorf("%s: decode: no EOF at end: v = %#v err = %#v", test.name, v, err) t.Errorf("%s: decode: no EOF at end: v = %#v err = %#v", test.name, v, err)
} }
// expected (= decoded(input)) -> encode -> decode = identity
buf.Reset()
enc := NewEncoder(buf)
err = enc.Encode(test.expected)
if err != nil {
t.Errorf("%s: encode(expected): %v", test.name, err)
} else {
dec := NewDecoder(buf)
v, err := dec.Decode()
if err != nil {
t.Error(err)
}
if !reflect.DeepEqual(v, test.expected) {
t.Errorf("%s: expected -> decode -> encode != identity\nhave: %#v\nwant: %#v", test.name, v, test.expected)
}
}
// for truncated input io.ErrUnexpectedEOF must be returned // for truncated input io.ErrUnexpectedEOF must be returned
for l := len(test.input) - 1; l > 0; l-- { for l := len(test.input) - 1; l > 0; l-- {
buf := bytes.NewBufferString(test.input[:l]) buf := bytes.NewBufferString(test.input[:l])
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment