Commit 12035895 authored by Kamil Trzciński's avatar Kamil Trzciński Committed by Nick Thomas

Allow to access remote archive

parent dea4edff
......@@ -5,8 +5,15 @@ import (
"flag"
"fmt"
"io"
"net"
"net/http"
"os"
"strings"
"time"
"github.com/jfbus/httprs"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/zipartifacts"
)
......@@ -16,6 +23,67 @@ var Version = "unknown"
var printVersion = flag.Bool("version", false, "Print version and exit")
var httpClient = &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 10 * time.Second,
}).DialContext,
IdleConnTimeout: 30 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 10 * time.Second,
ResponseHeaderTimeout: 30 * time.Second,
},
}
func isURL(path string) bool {
return strings.HasPrefix(path, "http://") || strings.HasPrefix(path, "https://")
}
func openHTTPArchive(archivePath string) (*zip.Reader, func()) {
scrubbedArchivePath := helper.ScrubURLParams(archivePath)
resp, err := httpClient.Get(archivePath)
if err != nil {
fatalError(fmt.Errorf("HTTP GET %q: %v", scrubbedArchivePath, err))
} else if resp.StatusCode == http.StatusNotFound {
notFoundError(fmt.Errorf("HTTP GET %q: not found", scrubbedArchivePath))
} else if resp.StatusCode != http.StatusOK {
fatalError(fmt.Errorf("HTTP GET %q: %d: %v", scrubbedArchivePath, resp.StatusCode, resp.Status))
}
rs := httprs.NewHttpReadSeeker(resp, httpClient)
archive, err := zip.NewReader(rs, resp.ContentLength)
if err != nil {
notFoundError(fmt.Errorf("open %q: %v", scrubbedArchivePath, err))
}
return archive, func() {
resp.Body.Close()
rs.Close()
}
}
func openFileArchive(archivePath string) (*zip.Reader, func()) {
archive, err := zip.OpenReader(archivePath)
if err != nil {
notFoundError(fmt.Errorf("open %q: %v", archivePath, err))
}
return &archive.Reader, func() {
archive.Close()
}
}
func openArchive(archivePath string) (*zip.Reader, func()) {
if isURL(archivePath) {
return openHTTPArchive(archivePath)
}
return openFileArchive(archivePath)
}
func main() {
flag.Parse()
......@@ -25,32 +93,34 @@ func main() {
os.Exit(0)
}
if len(os.Args) != 3 {
fmt.Fprintf(os.Stderr, "Usage: %s FILE.ZIP ENTRY\n", progName)
archivePath := os.Getenv("ARCHIVE_PATH")
encodedFileName := os.Getenv("ENCODED_FILE_NAME")
if len(os.Args) != 1 || archivePath == "" || encodedFileName == "" {
fmt.Fprintf(os.Stderr, "Usage: %s\n", progName)
fmt.Fprintf(os.Stderr, "Env: ARCHIVE_PATH=https://path.to/archive.zip or /path/to/archive.zip\n")
fmt.Fprintf(os.Stderr, "Env: ENCODED_FILE_NAME=base64-encoded-file-name\n")
os.Exit(1)
}
archiveFileName := os.Args[1]
scrubbedArchivePath := helper.ScrubURLParams(archivePath)
fileName, err := zipartifacts.DecodeFileEntry(os.Args[2])
fileName, err := zipartifacts.DecodeFileEntry(encodedFileName)
if err != nil {
fatalError(fmt.Errorf("decode entry %q: %v", os.Args[2], err))
fatalError(fmt.Errorf("decode entry %q: %v", encodedFileName, err))
}
archive, err := zip.OpenReader(archiveFileName)
if err != nil {
notFoundError(fmt.Errorf("open %q: %v", archiveFileName, err))
}
defer archive.Close()
archive, cleanFn := openArchive(archivePath)
defer cleanFn()
file := findFileInZip(fileName, &archive.Reader)
file := findFileInZip(fileName, archive)
if file == nil {
notFoundError(fmt.Errorf("find %q in %q: not found", fileName, archiveFileName))
notFoundError(fmt.Errorf("find %q in %q: not found", fileName, scrubbedArchivePath))
}
// Start decompressing the file
reader, err := file.Open()
if err != nil {
fatalError(fmt.Errorf("open %q in %q: %v", fileName, archiveFileName, err))
fatalError(fmt.Errorf("open %q in %q: %v", fileName, scrubbedArchivePath, err))
}
defer reader.Close()
......@@ -59,7 +129,7 @@ func main() {
}
if _, err := io.Copy(os.Stdout, reader); err != nil {
fatalError(fmt.Errorf("write %q from %q to stdout: %v", fileName, archiveFileName, err))
fatalError(fmt.Errorf("write %q from %q to stdout: %v", fileName, scrubbedArchivePath, err))
}
}
......
......@@ -55,13 +55,17 @@ func detectFileContentType(fileName string) string {
return contentType
}
func unpackFileFromZip(archiveFileName, encodedFilename string, headers http.Header, output io.Writer) error {
func unpackFileFromZip(archivePath, encodedFilename string, headers http.Header, output io.Writer) error {
fileName, err := zipartifacts.DecodeFileEntry(encodedFilename)
if err != nil {
return err
}
catFile := exec.Command("gitlab-zip-cat", archiveFileName, encodedFilename)
catFile := exec.Command("gitlab-zip-cat")
catFile.Env = []string{
"ARCHIVE_PATH=" + archivePath,
"ENCODED_FILE_NAME=" + encodedFilename,
}
catFile.Stderr = os.Stderr
catFile.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
stdout, err := catFile.StdoutPipe()
......
......@@ -8,17 +8,18 @@ import (
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/require"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper"
)
func testEntryServer(t *testing.T, archive string, entry string) *httptest.ResponseRecorder {
mux := http.NewServeMux()
mux.HandleFunc("/url/path", func(w http.ResponseWriter, r *http.Request) {
if r.Method != "GET" {
t.Fatal("Expected GET request")
}
require.Equal(t, "GET", r.Method)
encodedEntry := base64.StdEncoding.EncodeToString([]byte(entry))
jsonParams := fmt.Sprintf(`{"Archive":"%s","Entry":"%s"}`, archive, encodedEntry)
......@@ -28,9 +29,7 @@ func testEntryServer(t *testing.T, archive string, entry string) *httptest.Respo
})
httpRequest, err := http.NewRequest("GET", "/url/path", nil)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
response := httptest.NewRecorder()
mux.ServeHTTP(response, httpRequest)
return response
......@@ -38,18 +37,14 @@ func testEntryServer(t *testing.T, archive string, entry string) *httptest.Respo
func TestDownloadingFromValidArchive(t *testing.T) {
tempFile, err := ioutil.TempFile("", "uploads")
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
defer tempFile.Close()
defer os.Remove(tempFile.Name())
archive := zip.NewWriter(tempFile)
defer archive.Close()
fileInArchive, err := archive.Create("test.txt")
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
fmt.Fprint(fileInArchive, "testtest")
archive.Close()
......@@ -67,11 +62,43 @@ func TestDownloadingFromValidArchive(t *testing.T) {
testhelper.AssertResponseBody(t, response, "testtest")
}
func TestDownloadingFromValidHTTPArchive(t *testing.T) {
tempDir, err := ioutil.TempDir("", "uploads")
require.NoError(t, err)
defer os.RemoveAll(tempDir)
f, err := os.Create(filepath.Join(tempDir, "archive.zip"))
require.NoError(t, err)
defer f.Close()
archive := zip.NewWriter(f)
defer archive.Close()
fileInArchive, err := archive.Create("test.txt")
require.NoError(t, err)
fmt.Fprint(fileInArchive, "testtest")
archive.Close()
f.Close()
fileServer := httptest.NewServer(http.FileServer(http.Dir(tempDir)))
defer fileServer.Close()
response := testEntryServer(t, fileServer.URL+"/archive.zip", "test.txt")
testhelper.AssertResponseCode(t, response, 200)
testhelper.AssertResponseWriterHeader(t, response,
"Content-Type",
"text/plain; charset=utf-8")
testhelper.AssertResponseWriterHeader(t, response,
"Content-Disposition",
"attachment; filename=\"test.txt\"")
testhelper.AssertResponseBody(t, response, "testtest")
}
func TestDownloadingNonExistingFile(t *testing.T) {
tempFile, err := ioutil.TempFile("", "uploads")
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
defer tempFile.Close()
defer os.Remove(tempFile.Name())
......@@ -92,3 +119,16 @@ func TestIncompleteApiResponse(t *testing.T) {
response := testEntryServer(t, "", "")
testhelper.AssertResponseCode(t, response, 500)
}
func TestDownloadingFromNonExistingHTTPArchive(t *testing.T) {
tempDir, err := ioutil.TempDir("", "uploads")
require.NoError(t, err)
defer os.RemoveAll(tempDir)
fileServer := httptest.NewServer(http.FileServer(http.Dir(tempDir)))
defer fileServer.Close()
response := testEntryServer(t, fileServer.URL+"/not-existing-archive-file.zip", "test.txt")
testhelper.AssertResponseCode(t, response, 404)
}
......@@ -18,7 +18,7 @@ import (
const NginxResponseBufferHeader = "X-Accel-Buffering"
var scrubRegexp = regexp.MustCompile(`([\?&](?:private|authenticity|rss)[\-_]token)=[^&]*`)
var scrubRegexp = regexp.MustCompile(`(?i)([\?&]((?:private|authenticity|rss)[\-_]token)|X-AMZ-Signature)=[^&]*`)
func Fail500(w http.ResponseWriter, r *http.Request, err error) {
http.Error(w, "Internal server error", 500)
......
......@@ -127,6 +127,9 @@ func TestScrubURLParams(t *testing.T) {
"?private-token=&authenticity_token=&bar": "?private-token=[FILTERED]&authenticity_token=[FILTERED]&bar",
"?private-token=foo&authenticity_token=bar": "?private-token=[FILTERED]&authenticity_token=[FILTERED]",
"?private_token=foo&authenticity-token=bar": "?private_token=[FILTERED]&authenticity-token=[FILTERED]",
"?X-AMZ-Signature=foo": "?X-AMZ-Signature=[FILTERED]",
"&X-AMZ-Signature=foo": "&X-AMZ-Signature=[FILTERED]",
"?x-amz-signature=foo": "?x-amz-signature=[FILTERED]",
} {
after := ScrubURLParams(before)
assert.Equal(t, expected, after, "Scrubbing %q", before)
......
Copyright (c) 2015 Jean-François Bustarret
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
\ No newline at end of file
# httprs
A ReadSeeker for http.Response.Body
[![wercker status](https://app.wercker.com/status/b8ab18faefae7d1f88f9f23d642f0847/s/master "wercker status")](https://app.wercker.com/project/bykey/b8ab18faefae7d1f88f9f23d642f0847)
## Usage
```
import "github.com/jfbus/httprs"
resp, err := http.Get(url)
rs := httprs.NewHttpReadSeeker(resp)
defer rs.Close()
io.ReadFull(rs, buf) // reads the first bytes from the response
rs.Seek(1024, 0) // moves the position
io.ReadFull(rs, buf) // does an additional range request and reads the first bytes from the second response
```
if you use a specific http.Client :
```
rs := httprs.NewHttpReadSeeker(resp, client)
```
## Doc
See http://godoc.org/github.com/jfbus/httprs
## LICENSE
MIT - See LICENSE
\ No newline at end of file
/*
Package httprs provides a ReadSeeker for http.Response.Body.
Usage :
resp, err := http.Get(url)
rs := httprs.NewHttpReadSeeker(resp)
defer rs.Close()
io.ReadFull(rs, buf) // reads the first bytes from the response body
rs.Seek(1024, 0) // moves the position, but does no range request
io.ReadFull(rs, buf) // does a range request and reads from the response body
If you want use a specific http.Client for additional range requests :
rs := httprs.NewHttpReadSeeker(resp, client)
*/
package httprs
import (
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
"github.com/mitchellh/copystructure"
)
const shortSeekBytes = 1024
// A HttpReadSeeker reads from a http.Response.Body. It can Seek
// by doing range requests.
type HttpReadSeeker struct {
c *http.Client
req *http.Request
res *http.Response
r io.ReadCloser
pos int64
canSeek bool
Requests int
}
var _ io.ReadCloser = (*HttpReadSeeker)(nil)
var _ io.Seeker = (*HttpReadSeeker)(nil)
var (
// ErrNoContentLength is returned by Seek when the initial http response did not include a Content-Length header
ErrNoContentLength = errors.New("Content-Length was not set")
// ErrRangeRequestsNotSupported is returned by Seek and Read
// when the remote server does not allow range requests (Accept-Ranges was not set)
ErrRangeRequestsNotSupported = errors.New("Range requests are not supported by the remote server")
// ErrInvalidRange is returned by Read when trying to read past the end of the file
ErrInvalidRange = errors.New("Invalid range")
// ErrContentHasChanged is returned by Read when the content has changed since the first request
ErrContentHasChanged = errors.New("Content has changed since first request")
)
// NewHttpReadSeeker returns a HttpReadSeeker, using the http.Response and, optionaly, the http.Client
// that needs to be used for future range requests. If no http.Client is given, http.DefaultClient will
// be used.
//
// res.Request will be reused for range requests, headers may be added/removed
func NewHttpReadSeeker(res *http.Response, client ...*http.Client) *HttpReadSeeker {
r := &HttpReadSeeker{
req: res.Request,
res: res,
r: res.Body,
canSeek: (res.Header.Get("Accept-Ranges") == "bytes"),
}
if len(client) > 0 {
r.c = client[0]
} else {
r.c = http.DefaultClient
}
return r
}
// Clone clones the reader to enable parallel downloads of ranges
func (r *HttpReadSeeker) Clone() (*HttpReadSeeker, error) {
req, err := copystructure.Copy(r.req)
if err != nil {
return nil, err
}
return &HttpReadSeeker{
req: req.(*http.Request),
res: r.res,
r: nil,
canSeek: r.canSeek,
c: r.c,
}, nil
}
// Read reads from the response body. It does a range request if Seek was called before.
//
// May return ErrRangeRequestsNotSupported, ErrInvalidRange or ErrContentHasChanged
func (r *HttpReadSeeker) Read(p []byte) (n int, err error) {
if r.r == nil {
err = r.rangeRequest()
}
if r.r != nil {
n, err = r.r.Read(p)
r.pos += int64(n)
}
return
}
// ReadAt reads from the response body starting at offset off.
//
// May return ErrRangeRequestsNotSupported, ErrInvalidRange or ErrContentHasChanged
func (r *HttpReadSeeker) ReadAt(p []byte, off int64) (n int, err error) {
r.Seek(off, 0)
return r.Read(p)
}
// Close closes the response body
func (r *HttpReadSeeker) Close() error {
if r.r != nil {
return r.r.Close()
}
return nil
}
// Seek moves the reader position to a new offset.
//
// It does not send http requests, allowing for multiple seeks without overhead.
// The http request will be sent by the next Read call.
//
// May return ErrNoContentLength or ErrRangeRequestsNotSupported
func (r *HttpReadSeeker) Seek(offset int64, whence int) (int64, error) {
if !r.canSeek {
return 0, ErrRangeRequestsNotSupported
}
var err error
switch whence {
case 0:
case 1:
offset += r.pos
case 2:
if r.res.ContentLength <= 0 {
return 0, ErrNoContentLength
}
offset = r.res.ContentLength - offset
}
if r.r != nil {
// Try to read, which is cheaper than doing a request
if r.pos < offset && offset-r.pos <= shortSeekBytes {
_, err := io.CopyN(ioutil.Discard, r, offset-r.pos)
if err != nil {
return 0, err
}
}
if r.pos != offset {
err = r.r.Close()
r.r = nil
}
}
r.pos = offset
return r.pos, err
}
func (r *HttpReadSeeker) rangeRequest() error {
r.req.Header.Set("Range", fmt.Sprintf("bytes=%d-", r.pos))
etag, last := r.res.Header.Get("ETag"), r.res.Header.Get("Last-Modified")
switch {
case last != "":
r.req.Header.Set("If-Range", last)
case etag != "":
r.req.Header.Set("If-Range", etag)
}
r.Requests++
res, err := r.c.Do(r.req)
if err != nil {
return err
}
switch res.StatusCode {
case http.StatusRequestedRangeNotSatisfiable:
return ErrInvalidRange
case http.StatusOK:
return ErrContentHasChanged
case http.StatusPartialContent:
r.r = res.Body
return nil
}
return ErrRangeRequestsNotSupported
}
box: wercker/golang
build:
steps:
- setup-go-workspace
- script:
name: Install goconvey
code: go get github.com/smartystreets/goconvey/convey
- script:
name: Go get
code: go get -v ./...
- script:
name: Go test
code: go test -p 1 -v ./...
The MIT License (MIT)
Copyright (c) 2014 Mitchell Hashimoto
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
# copystructure
copystructure is a Go library for deep copying values in Go.
This allows you to copy Go values that may contain reference values
such as maps, slices, or pointers, and copy their data as well instead
of just their references.
## Installation
Standard `go get`:
```
$ go get github.com/mitchellh/copystructure
```
## Usage & Example
For usage and examples see the [Godoc](http://godoc.org/github.com/mitchellh/copystructure).
The `Copy` function has examples associated with it there.
package copystructure
import (
"reflect"
"time"
)
func init() {
Copiers[reflect.TypeOf(time.Time{})] = timeCopier
}
func timeCopier(v interface{}) (interface{}, error) {
// Just... copy it.
return v.(time.Time), nil
}
package copystructure
import (
"errors"
"reflect"
"sync"
"github.com/mitchellh/reflectwalk"
)
// Copy returns a deep copy of v.
func Copy(v interface{}) (interface{}, error) {
return Config{}.Copy(v)
}
// CopierFunc is a function that knows how to deep copy a specific type.
// Register these globally with the Copiers variable.
type CopierFunc func(interface{}) (interface{}, error)
// Copiers is a map of types that behave specially when they are copied.
// If a type is found in this map while deep copying, this function
// will be called to copy it instead of attempting to copy all fields.
//
// The key should be the type, obtained using: reflect.TypeOf(value with type).
//
// It is unsafe to write to this map after Copies have started. If you
// are writing to this map while also copying, wrap all modifications to
// this map as well as to Copy in a mutex.
var Copiers map[reflect.Type]CopierFunc = make(map[reflect.Type]CopierFunc)
// Must is a helper that wraps a call to a function returning
// (interface{}, error) and panics if the error is non-nil. It is intended
// for use in variable initializations and should only be used when a copy
// error should be a crashing case.
func Must(v interface{}, err error) interface{} {
if err != nil {
panic("copy error: " + err.Error())
}
return v
}
var errPointerRequired = errors.New("Copy argument must be a pointer when Lock is true")
type Config struct {
// Lock any types that are a sync.Locker and are not a mutex while copying.
// If there is an RLocker method, use that to get the sync.Locker.
Lock bool
// Copiers is a map of types associated with a CopierFunc. Use the global
// Copiers map if this is nil.
Copiers map[reflect.Type]CopierFunc
}
func (c Config) Copy(v interface{}) (interface{}, error) {
if c.Lock && reflect.ValueOf(v).Kind() != reflect.Ptr {
return nil, errPointerRequired
}
w := new(walker)
if c.Lock {
w.useLocks = true
}
if c.Copiers == nil {
c.Copiers = Copiers
}
err := reflectwalk.Walk(v, w)
if err != nil {
return nil, err
}
// Get the result. If the result is nil, then we want to turn it
// into a typed nil if we can.
result := w.Result
if result == nil {
val := reflect.ValueOf(v)
result = reflect.Indirect(reflect.New(val.Type())).Interface()
}
return result, nil
}
// Return the key used to index interfaces types we've seen. Store the number
// of pointers in the upper 32bits, and the depth in the lower 32bits. This is
// easy to calculate, easy to match a key with our current depth, and we don't
// need to deal with initializing and cleaning up nested maps or slices.
func ifaceKey(pointers, depth int) uint64 {
return uint64(pointers)<<32 | uint64(depth)
}
type walker struct {
Result interface{}
depth int
ignoreDepth int
vals []reflect.Value
cs []reflect.Value
// This stores the number of pointers we've walked over, indexed by depth.
ps []int
// If an interface is indirected by a pointer, we need to know the type of
// interface to create when creating the new value. Store the interface
// types here, indexed by both the walk depth and the number of pointers
// already seen at that depth. Use ifaceKey to calculate the proper uint64
// value.
ifaceTypes map[uint64]reflect.Type
// any locks we've taken, indexed by depth
locks []sync.Locker
// take locks while walking the structure
useLocks bool
}
func (w *walker) Enter(l reflectwalk.Location) error {
w.depth++
// ensure we have enough elements to index via w.depth
for w.depth >= len(w.locks) {
w.locks = append(w.locks, nil)
}
for len(w.ps) < w.depth+1 {
w.ps = append(w.ps, 0)
}
return nil
}
func (w *walker) Exit(l reflectwalk.Location) error {
locker := w.locks[w.depth]
w.locks[w.depth] = nil
if locker != nil {
defer locker.Unlock()
}
// clear out pointers and interfaces as we exit the stack
w.ps[w.depth] = 0
for k := range w.ifaceTypes {
mask := uint64(^uint32(0))
if k&mask == uint64(w.depth) {
delete(w.ifaceTypes, k)
}
}
w.depth--
if w.ignoreDepth > w.depth {
w.ignoreDepth = 0
}
if w.ignoring() {
return nil
}
switch l {
case reflectwalk.Array:
fallthrough
case reflectwalk.Map:
fallthrough
case reflectwalk.Slice:
w.replacePointerMaybe()
// Pop map off our container
w.cs = w.cs[:len(w.cs)-1]
case reflectwalk.MapValue:
// Pop off the key and value
mv := w.valPop()
mk := w.valPop()
m := w.cs[len(w.cs)-1]
// If mv is the zero value, SetMapIndex deletes the key form the map,
// or in this case never adds it. We need to create a properly typed
// zero value so that this key can be set.
if !mv.IsValid() {
mv = reflect.Zero(m.Elem().Type().Elem())
}
m.Elem().SetMapIndex(mk, mv)
case reflectwalk.ArrayElem:
// Pop off the value and the index and set it on the array
v := w.valPop()
i := w.valPop().Interface().(int)
if v.IsValid() {
a := w.cs[len(w.cs)-1]
ae := a.Elem().Index(i) // storing array as pointer on stack - so need Elem() call
if ae.CanSet() {
ae.Set(v)
}
}
case reflectwalk.SliceElem:
// Pop off the value and the index and set it on the slice
v := w.valPop()
i := w.valPop().Interface().(int)
if v.IsValid() {
s := w.cs[len(w.cs)-1]
se := s.Elem().Index(i)
if se.CanSet() {
se.Set(v)
}
}
case reflectwalk.Struct:
w.replacePointerMaybe()
// Remove the struct from the container stack
w.cs = w.cs[:len(w.cs)-1]
case reflectwalk.StructField:
// Pop off the value and the field
v := w.valPop()
f := w.valPop().Interface().(reflect.StructField)
if v.IsValid() {
s := w.cs[len(w.cs)-1]
sf := reflect.Indirect(s).FieldByName(f.Name)
if sf.CanSet() {
sf.Set(v)
}
}
case reflectwalk.WalkLoc:
// Clear out the slices for GC
w.cs = nil
w.vals = nil
}
return nil
}
func (w *walker) Map(m reflect.Value) error {
if w.ignoring() {
return nil
}
w.lock(m)
// Create the map. If the map itself is nil, then just make a nil map
var newMap reflect.Value
if m.IsNil() {
newMap = reflect.New(m.Type())
} else {
newMap = wrapPtr(reflect.MakeMap(m.Type()))
}
w.cs = append(w.cs, newMap)
w.valPush(newMap)
return nil
}
func (w *walker) MapElem(m, k, v reflect.Value) error {
return nil
}
func (w *walker) PointerEnter(v bool) error {
if v {
w.ps[w.depth]++
}
return nil
}
func (w *walker) PointerExit(v bool) error {
if v {
w.ps[w.depth]--
}
return nil
}
func (w *walker) Interface(v reflect.Value) error {
if !v.IsValid() {
return nil
}
if w.ifaceTypes == nil {
w.ifaceTypes = make(map[uint64]reflect.Type)
}
w.ifaceTypes[ifaceKey(w.ps[w.depth], w.depth)] = v.Type()
return nil
}
func (w *walker) Primitive(v reflect.Value) error {
if w.ignoring() {
return nil
}
w.lock(v)
// IsValid verifies the v is non-zero and CanInterface verifies
// that we're allowed to read this value (unexported fields).
var newV reflect.Value
if v.IsValid() && v.CanInterface() {
newV = reflect.New(v.Type())
newV.Elem().Set(v)
}
w.valPush(newV)
w.replacePointerMaybe()
return nil
}
func (w *walker) Slice(s reflect.Value) error {
if w.ignoring() {
return nil
}
w.lock(s)
var newS reflect.Value
if s.IsNil() {
newS = reflect.New(s.Type())
} else {
newS = wrapPtr(reflect.MakeSlice(s.Type(), s.Len(), s.Cap()))
}
w.cs = append(w.cs, newS)
w.valPush(newS)
return nil
}
func (w *walker) SliceElem(i int, elem reflect.Value) error {
if w.ignoring() {
return nil
}
// We don't write the slice here because elem might still be
// arbitrarily complex. Just record the index and continue on.
w.valPush(reflect.ValueOf(i))
return nil
}
func (w *walker) Array(a reflect.Value) error {
if w.ignoring() {
return nil
}
w.lock(a)
newA := reflect.New(a.Type())
w.cs = append(w.cs, newA)
w.valPush(newA)
return nil
}
func (w *walker) ArrayElem(i int, elem reflect.Value) error {
if w.ignoring() {
return nil
}
// We don't write the array here because elem might still be
// arbitrarily complex. Just record the index and continue on.
w.valPush(reflect.ValueOf(i))
return nil
}
func (w *walker) Struct(s reflect.Value) error {
if w.ignoring() {
return nil
}
w.lock(s)
var v reflect.Value
if c, ok := Copiers[s.Type()]; ok {
// We have a Copier for this struct, so we use that copier to
// get the copy, and we ignore anything deeper than this.
w.ignoreDepth = w.depth
dup, err := c(s.Interface())
if err != nil {
return err
}
v = reflect.ValueOf(dup)
} else {
// No copier, we copy ourselves and allow reflectwalk to guide
// us deeper into the structure for copying.
v = reflect.New(s.Type())
}
// Push the value onto the value stack for setting the struct field,
// and add the struct itself to the containers stack in case we walk
// deeper so that its own fields can be modified.
w.valPush(v)
w.cs = append(w.cs, v)
return nil
}
func (w *walker) StructField(f reflect.StructField, v reflect.Value) error {
if w.ignoring() {
return nil
}
// If PkgPath is non-empty, this is a private (unexported) field.
// We do not set this unexported since the Go runtime doesn't allow us.
if f.PkgPath != "" {
return reflectwalk.SkipEntry
}
// Push the field onto the stack, we'll handle it when we exit
// the struct field in Exit...
w.valPush(reflect.ValueOf(f))
return nil
}
// ignore causes the walker to ignore any more values until we exit this on
func (w *walker) ignore() {
w.ignoreDepth = w.depth
}
func (w *walker) ignoring() bool {
return w.ignoreDepth > 0 && w.depth >= w.ignoreDepth
}
func (w *walker) pointerPeek() bool {
return w.ps[w.depth] > 0
}
func (w *walker) valPop() reflect.Value {
result := w.vals[len(w.vals)-1]
w.vals = w.vals[:len(w.vals)-1]
// If we're out of values, that means we popped everything off. In
// this case, we reset the result so the next pushed value becomes
// the result.
if len(w.vals) == 0 {
w.Result = nil
}
return result
}
func (w *walker) valPush(v reflect.Value) {
w.vals = append(w.vals, v)
// If we haven't set the result yet, then this is the result since
// it is the first (outermost) value we're seeing.
if w.Result == nil && v.IsValid() {
w.Result = v.Interface()
}
}
func (w *walker) replacePointerMaybe() {
// Determine the last pointer value. If it is NOT a pointer, then
// we need to push that onto the stack.
if !w.pointerPeek() {
w.valPush(reflect.Indirect(w.valPop()))
return
}
v := w.valPop()
// If the expected type is a pointer to an interface of any depth,
// such as *interface{}, **interface{}, etc., then we need to convert
// the value "v" from *CONCRETE to *interface{} so types match for
// Set.
//
// Example if v is type *Foo where Foo is a struct, v would become
// *interface{} instead. This only happens if we have an interface expectation
// at this depth.
//
// For more info, see GH-16
if iType, ok := w.ifaceTypes[ifaceKey(w.ps[w.depth], w.depth)]; ok && iType.Kind() == reflect.Interface {
y := reflect.New(iType) // Create *interface{}
y.Elem().Set(reflect.Indirect(v)) // Assign "Foo" to interface{} (dereferenced)
v = y // v is now typed *interface{} (where *v = Foo)
}
for i := 1; i < w.ps[w.depth]; i++ {
if iType, ok := w.ifaceTypes[ifaceKey(w.ps[w.depth]-i, w.depth)]; ok {
iface := reflect.New(iType).Elem()
iface.Set(v)
v = iface
}
p := reflect.New(v.Type())
p.Elem().Set(v)
v = p
}
w.valPush(v)
}
// if this value is a Locker, lock it and add it to the locks slice
func (w *walker) lock(v reflect.Value) {
if !w.useLocks {
return
}
if !v.IsValid() || !v.CanInterface() {
return
}
type rlocker interface {
RLocker() sync.Locker
}
var locker sync.Locker
// We can't call Interface() on a value directly, since that requires
// a copy. This is OK, since the pointer to a value which is a sync.Locker
// is also a sync.Locker.
if v.Kind() == reflect.Ptr {
switch l := v.Interface().(type) {
case rlocker:
// don't lock a mutex directly
if _, ok := l.(*sync.RWMutex); !ok {
locker = l.RLocker()
}
case sync.Locker:
locker = l
}
} else if v.CanAddr() {
switch l := v.Addr().Interface().(type) {
case rlocker:
// don't lock a mutex directly
if _, ok := l.(*sync.RWMutex); !ok {
locker = l.RLocker()
}
case sync.Locker:
locker = l
}
}
// still no callable locker
if locker == nil {
return
}
// don't lock a mutex directly
switch locker.(type) {
case *sync.Mutex, *sync.RWMutex:
return
}
locker.Lock()
w.locks[w.depth] = locker
}
// wrapPtr is a helper that takes v and always make it *v. copystructure
// stores things internally as pointers until the last moment before unwrapping
func wrapPtr(v reflect.Value) reflect.Value {
if !v.IsValid() {
return v
}
vPtr := reflect.New(v.Type())
vPtr.Elem().Set(v)
return vPtr
}
The MIT License (MIT)
Copyright (c) 2013 Mitchell Hashimoto
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
# reflectwalk
reflectwalk is a Go library for "walking" a value in Go using reflection,
in the same way a directory tree can be "walked" on the filesystem. Walking
a complex structure can allow you to do manipulations on unknown structures
such as those decoded from JSON.
package reflectwalk
//go:generate stringer -type=Location location.go
type Location uint
const (
None Location = iota
Map
MapKey
MapValue
Slice
SliceElem
Array
ArrayElem
Struct
StructField
WalkLoc
)
// generated by stringer -type=Location location.go; DO NOT EDIT
package reflectwalk
import "fmt"
const _Location_name = "NoneMapMapKeyMapValueSliceSliceElemStructStructFieldWalkLoc"
var _Location_index = [...]uint8{0, 4, 7, 13, 21, 26, 35, 41, 52, 59}
func (i Location) String() string {
if i+1 >= Location(len(_Location_index)) {
return fmt.Sprintf("Location(%d)", i)
}
return _Location_name[_Location_index[i]:_Location_index[i+1]]
}
// reflectwalk is a package that allows you to "walk" complex structures
// similar to how you may "walk" a filesystem: visiting every element one
// by one and calling callback functions allowing you to handle and manipulate
// those elements.
package reflectwalk
import (
"errors"
"reflect"
)
// PrimitiveWalker implementations are able to handle primitive values
// within complex structures. Primitive values are numbers, strings,
// booleans, funcs, chans.
//
// These primitive values are often members of more complex
// structures (slices, maps, etc.) that are walkable by other interfaces.
type PrimitiveWalker interface {
Primitive(reflect.Value) error
}
// InterfaceWalker implementations are able to handle interface values as they
// are encountered during the walk.
type InterfaceWalker interface {
Interface(reflect.Value) error
}
// MapWalker implementations are able to handle individual elements
// found within a map structure.
type MapWalker interface {
Map(m reflect.Value) error
MapElem(m, k, v reflect.Value) error
}
// SliceWalker implementations are able to handle slice elements found
// within complex structures.
type SliceWalker interface {
Slice(reflect.Value) error
SliceElem(int, reflect.Value) error
}
// ArrayWalker implementations are able to handle array elements found
// within complex structures.
type ArrayWalker interface {
Array(reflect.Value) error
ArrayElem(int, reflect.Value) error
}
// StructWalker is an interface that has methods that are called for
// structs when a Walk is done.
type StructWalker interface {
Struct(reflect.Value) error
StructField(reflect.StructField, reflect.Value) error
}
// EnterExitWalker implementations are notified before and after
// they walk deeper into complex structures (into struct fields,
// into slice elements, etc.)
type EnterExitWalker interface {
Enter(Location) error
Exit(Location) error
}
// PointerWalker implementations are notified when the value they're
// walking is a pointer or not. Pointer is called for _every_ value whether
// it is a pointer or not.
type PointerWalker interface {
PointerEnter(bool) error
PointerExit(bool) error
}
// SkipEntry can be returned from walk functions to skip walking
// the value of this field. This is only valid in the following functions:
//
// - Struct: skips all fields from being walked
// - StructField: skips walking the struct value
//
var SkipEntry = errors.New("skip this entry")
// Walk takes an arbitrary value and an interface and traverses the
// value, calling callbacks on the interface if they are supported.
// The interface should implement one or more of the walker interfaces
// in this package, such as PrimitiveWalker, StructWalker, etc.
func Walk(data, walker interface{}) (err error) {
v := reflect.ValueOf(data)
ew, ok := walker.(EnterExitWalker)
if ok {
err = ew.Enter(WalkLoc)
}
if err == nil {
err = walk(v, walker)
}
if ok && err == nil {
err = ew.Exit(WalkLoc)
}
return
}
func walk(v reflect.Value, w interface{}) (err error) {
// Determine if we're receiving a pointer and if so notify the walker.
// The logic here is convoluted but very important (tests will fail if
// almost any part is changed). I will try to explain here.
//
// First, we check if the value is an interface, if so, we really need
// to check the interface's VALUE to see whether it is a pointer.
//
// Check whether the value is then a pointer. If so, then set pointer
// to true to notify the user.
//
// If we still have a pointer or an interface after the indirections, then
// we unwrap another level
//
// At this time, we also set "v" to be the dereferenced value. This is
// because once we've unwrapped the pointer we want to use that value.
pointer := false
pointerV := v
for {
if pointerV.Kind() == reflect.Interface {
if iw, ok := w.(InterfaceWalker); ok {
if err = iw.Interface(pointerV); err != nil {
return
}
}
pointerV = pointerV.Elem()
}
if pointerV.Kind() == reflect.Ptr {
pointer = true
v = reflect.Indirect(pointerV)
}
if pw, ok := w.(PointerWalker); ok {
if err = pw.PointerEnter(pointer); err != nil {
return
}
defer func(pointer bool) {
if err != nil {
return
}
err = pw.PointerExit(pointer)
}(pointer)
}
if pointer {
pointerV = v
}
pointer = false
// If we still have a pointer or interface we have to indirect another level.
switch pointerV.Kind() {
case reflect.Ptr, reflect.Interface:
continue
}
break
}
// We preserve the original value here because if it is an interface
// type, we want to pass that directly into the walkPrimitive, so that
// we can set it.
originalV := v
if v.Kind() == reflect.Interface {
v = v.Elem()
}
k := v.Kind()
if k >= reflect.Int && k <= reflect.Complex128 {
k = reflect.Int
}
switch k {
// Primitives
case reflect.Bool, reflect.Chan, reflect.Func, reflect.Int, reflect.String, reflect.Invalid:
err = walkPrimitive(originalV, w)
return
case reflect.Map:
err = walkMap(v, w)
return
case reflect.Slice:
err = walkSlice(v, w)
return
case reflect.Struct:
err = walkStruct(v, w)
return
case reflect.Array:
err = walkArray(v, w)
return
default:
panic("unsupported type: " + k.String())
}
}
func walkMap(v reflect.Value, w interface{}) error {
ew, ewok := w.(EnterExitWalker)
if ewok {
ew.Enter(Map)
}
if mw, ok := w.(MapWalker); ok {
if err := mw.Map(v); err != nil {
return err
}
}
for _, k := range v.MapKeys() {
kv := v.MapIndex(k)
if mw, ok := w.(MapWalker); ok {
if err := mw.MapElem(v, k, kv); err != nil {
return err
}
}
ew, ok := w.(EnterExitWalker)
if ok {
ew.Enter(MapKey)
}
if err := walk(k, w); err != nil {
return err
}
if ok {
ew.Exit(MapKey)
ew.Enter(MapValue)
}
if err := walk(kv, w); err != nil {
return err
}
if ok {
ew.Exit(MapValue)
}
}
if ewok {
ew.Exit(Map)
}
return nil
}
func walkPrimitive(v reflect.Value, w interface{}) error {
if pw, ok := w.(PrimitiveWalker); ok {
return pw.Primitive(v)
}
return nil
}
func walkSlice(v reflect.Value, w interface{}) (err error) {
ew, ok := w.(EnterExitWalker)
if ok {
ew.Enter(Slice)
}
if sw, ok := w.(SliceWalker); ok {
if err := sw.Slice(v); err != nil {
return err
}
}
for i := 0; i < v.Len(); i++ {
elem := v.Index(i)
if sw, ok := w.(SliceWalker); ok {
if err := sw.SliceElem(i, elem); err != nil {
return err
}
}
ew, ok := w.(EnterExitWalker)
if ok {
ew.Enter(SliceElem)
}
if err := walk(elem, w); err != nil {
return err
}
if ok {
ew.Exit(SliceElem)
}
}
ew, ok = w.(EnterExitWalker)
if ok {
ew.Exit(Slice)
}
return nil
}
func walkArray(v reflect.Value, w interface{}) (err error) {
ew, ok := w.(EnterExitWalker)
if ok {
ew.Enter(Array)
}
if aw, ok := w.(ArrayWalker); ok {
if err := aw.Array(v); err != nil {
return err
}
}
for i := 0; i < v.Len(); i++ {
elem := v.Index(i)
if aw, ok := w.(ArrayWalker); ok {
if err := aw.ArrayElem(i, elem); err != nil {
return err
}
}
ew, ok := w.(EnterExitWalker)
if ok {
ew.Enter(ArrayElem)
}
if err := walk(elem, w); err != nil {
return err
}
if ok {
ew.Exit(ArrayElem)
}
}
ew, ok = w.(EnterExitWalker)
if ok {
ew.Exit(Array)
}
return nil
}
func walkStruct(v reflect.Value, w interface{}) (err error) {
ew, ewok := w.(EnterExitWalker)
if ewok {
ew.Enter(Struct)
}
skip := false
if sw, ok := w.(StructWalker); ok {
err = sw.Struct(v)
if err == SkipEntry {
skip = true
err = nil
}
if err != nil {
return
}
}
if !skip {
vt := v.Type()
for i := 0; i < vt.NumField(); i++ {
sf := vt.Field(i)
f := v.FieldByIndex([]int{i})
if sw, ok := w.(StructWalker); ok {
err = sw.StructField(sf, f)
// SkipEntry just pretends this field doesn't even exist
if err == SkipEntry {
continue
}
if err != nil {
return
}
}
ew, ok := w.(EnterExitWalker)
if ok {
ew.Enter(StructField)
}
err = walk(f, w)
if err != nil {
return
}
if ok {
ew.Exit(StructField)
}
}
}
if ewok {
ew.Exit(Struct)
}
return nil
}
......@@ -53,6 +53,12 @@
"path": "github.com/gorilla/websocket",
"revision": "e8f0f8aaa98dfb6586cbdf2978d511e3199a960a"
},
{
"checksumSHA1": "g46WnPAlsmkyNDUFdObDgP6DP+s=",
"path": "github.com/jfbus/httprs",
"revision": "e879ae6bf984ca6c009eb44696e9f400ec3ed2d9",
"revisionTime": "2017-07-03T13:54:35Z"
},
{
"checksumSHA1": "oIkoHb8+rM5Etur5HhZVY/sDQKQ=",
"path": "github.com/jpillora/backoff",
......@@ -65,6 +71,18 @@
"path": "github.com/matttproud/golang_protobuf_extensions/pbutil",
"revision": "c12348ce28de40eed0136aa2b644d0ee0650e56c"
},
{
"checksumSHA1": "kSmDazz+cokgcHQT7q56Na+IBe0=",
"path": "github.com/mitchellh/copystructure",
"revision": "f81071c9d77b7931f78c90b416a074ecdc50e959",
"revisionTime": "2017-01-16T00:44:49Z"
},
{
"checksumSHA1": "KqsMqI+Y+3EFYPhyzafpIneaVCM=",
"path": "github.com/mitchellh/reflectwalk",
"revision": "8d802ff4ae93611b807597f639c19f76074df5c6",
"revisionTime": "2017-05-08T17:38:06Z"
},
{
"checksumSHA1": "LuFv4/jlrmFNnDb/5SCSEPAM9vU=",
"path": "github.com/pmezard/go-difflib/difflib",
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment