Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
C
caddy
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
Analytics
Analytics
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Commits
Issue Boards
Open sidebar
nexedi
caddy
Commits
a1a8d0f6
Commit
a1a8d0f6
authored
Jan 01, 2017
by
Matthew Holt
Browse files
Options
Browse Files
Download
Plain Diff
Merge branch 'master' of github.com:mholt/caddy
parents
5d813a1b
04bee0f3
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
545 additions
and
85 deletions
+545
-85
caddyhttp/httpserver/pathcleaner.go
caddyhttp/httpserver/pathcleaner.go
+76
-0
caddyhttp/httpserver/pathcleaner_test.go
caddyhttp/httpserver/pathcleaner_test.go
+120
-0
caddyhttp/httpserver/server.go
caddyhttp/httpserver/server.go
+1
-2
caddyhttp/proxy/proxy.go
caddyhttp/proxy/proxy.go
+20
-4
caddyhttp/proxy/proxy_test.go
caddyhttp/proxy/proxy_test.go
+96
-10
caddyhttp/proxy/reverseproxy.go
caddyhttp/proxy/reverseproxy.go
+232
-69
No files found.
caddyhttp/httpserver/pathcleaner.go
0 → 100644
View file @
a1a8d0f6
package
httpserver
import
(
"math/rand"
"path"
"strings"
"time"
)
// CleanMaskedPath prevents one or more of the path cleanup operations:
// - collapse multiple slashes into one
// - eliminate "/." (current directory)
// - eliminate "<parent_directory>/.."
// by masking certain patterns in the path with a temporary random string.
// This could be helpful when certain patterns in the path are desired to be preserved
// that would otherwise be changed by path.Clean().
// One such use case is the presence of the double slashes as protocol separator
// (e.g., /api/endpoint/http://example.com).
// This is a common pattern in many applications to allow passing URIs as path argument.
func
CleanMaskedPath
(
reqPath
string
,
masks
...
string
)
string
{
var
replacerVal
string
maskMap
:=
make
(
map
[
string
]
string
)
// Iterate over supplied masks and create temporary replacement strings
// only for the masks that are present in the path, then replace all occurrences
for
_
,
mask
:=
range
masks
{
if
strings
.
Index
(
reqPath
,
mask
)
>=
0
{
replacerVal
=
"/_caddy"
+
generateRandomString
()
+
"__"
maskMap
[
mask
]
=
replacerVal
reqPath
=
strings
.
Replace
(
reqPath
,
mask
,
replacerVal
,
-
1
)
}
}
reqPath
=
path
.
Clean
(
reqPath
)
// Revert the replaced masks after path cleanup
for
mask
,
replacerVal
:=
range
maskMap
{
reqPath
=
strings
.
Replace
(
reqPath
,
replacerVal
,
mask
,
-
1
)
}
return
reqPath
}
// CleanPath calls CleanMaskedPath() with the default mask of "://"
// to preserve double slashes of protocols
// such as "http://", "https://", and "ftp://" etc.
func
CleanPath
(
reqPath
string
)
string
{
return
CleanMaskedPath
(
reqPath
,
"://"
)
}
// An efficient and fast method for random string generation.
// Inspired by http://stackoverflow.com/a/31832326.
const
randomStringLength
=
4
const
letterBytes
=
"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
const
(
letterIdxBits
=
6
letterIdxMask
=
1
<<
letterIdxBits
-
1
letterIdxMax
=
63
/
letterIdxBits
)
var
src
=
rand
.
NewSource
(
time
.
Now
()
.
UnixNano
())
func
generateRandomString
()
string
{
b
:=
make
([]
byte
,
randomStringLength
)
for
i
,
cache
,
remain
:=
randomStringLength
-
1
,
src
.
Int63
(),
letterIdxMax
;
i
>=
0
;
{
if
remain
==
0
{
cache
,
remain
=
src
.
Int63
(),
letterIdxMax
}
if
idx
:=
int
(
cache
&
letterIdxMask
);
idx
<
len
(
letterBytes
)
{
b
[
i
]
=
letterBytes
[
idx
]
i
--
}
cache
>>=
letterIdxBits
remain
--
}
return
string
(
b
)
}
caddyhttp/httpserver/pathcleaner_test.go
0 → 100644
View file @
a1a8d0f6
package
httpserver
import
(
"path"
"testing"
)
var
paths
=
map
[
string
]
map
[
string
]
string
{
"/../a/b/../././/c"
:
{
"preserve_all"
:
"/../a/b/../././/c"
,
"preserve_protocol"
:
"/a/c"
,
"preserve_slashes"
:
"/a//c"
,
"preserve_dots"
:
"/../a/b/../././c"
,
"clean_all"
:
"/a/c"
,
},
"/path/https://www.google.com"
:
{
"preserve_all"
:
"/path/https://www.google.com"
,
"preserve_protocol"
:
"/path/https://www.google.com"
,
"preserve_slashes"
:
"/path/https://www.google.com"
,
"preserve_dots"
:
"/path/https:/www.google.com"
,
"clean_all"
:
"/path/https:/www.google.com"
,
},
"/a/b/../././/c/http://example.com/foo//bar/../blah"
:
{
"preserve_all"
:
"/a/b/../././/c/http://example.com/foo//bar/../blah"
,
"preserve_protocol"
:
"/a/c/http://example.com/foo/blah"
,
"preserve_slashes"
:
"/a//c/http://example.com/foo/blah"
,
"preserve_dots"
:
"/a/b/../././c/http:/example.com/foo/bar/../blah"
,
"clean_all"
:
"/a/c/http:/example.com/foo/blah"
,
},
}
func
assertEqual
(
t
*
testing
.
T
,
expected
,
received
string
)
{
if
expected
!=
received
{
t
.
Errorf
(
"
\t
Expected: %s
\n\t\t\t
Recieved: %s"
,
expected
,
received
)
}
}
func
maskedTestRunner
(
t
*
testing
.
T
,
variation
string
,
masks
...
string
)
{
for
reqPath
,
transformation
:=
range
paths
{
assertEqual
(
t
,
transformation
[
variation
],
CleanMaskedPath
(
reqPath
,
masks
...
))
}
}
// No need to test the built-in path.Clean() function.
// However, it could be useful to cross-examine the test dataset.
func
TestPathClean
(
t
*
testing
.
T
)
{
for
reqPath
,
transformation
:=
range
paths
{
assertEqual
(
t
,
transformation
[
"clean_all"
],
path
.
Clean
(
reqPath
))
}
}
func
TestCleanAll
(
t
*
testing
.
T
)
{
maskedTestRunner
(
t
,
"clean_all"
)
}
func
TestPreserveAll
(
t
*
testing
.
T
)
{
maskedTestRunner
(
t
,
"preserve_all"
,
"//"
,
"/.."
,
"/."
)
}
func
TestPreserveProtocol
(
t
*
testing
.
T
)
{
maskedTestRunner
(
t
,
"preserve_protocol"
,
"://"
)
}
func
TestPreserveSlashes
(
t
*
testing
.
T
)
{
maskedTestRunner
(
t
,
"preserve_slashes"
,
"//"
)
}
func
TestPreserveDots
(
t
*
testing
.
T
)
{
maskedTestRunner
(
t
,
"preserve_dots"
,
"/.."
,
"/."
)
}
func
TestDefaultMask
(
t
*
testing
.
T
)
{
for
reqPath
,
transformation
:=
range
paths
{
assertEqual
(
t
,
transformation
[
"preserve_protocol"
],
CleanPath
(
reqPath
))
}
}
func
maskedBenchmarkRunner
(
b
*
testing
.
B
,
masks
...
string
)
{
for
n
:=
0
;
n
<
b
.
N
;
n
++
{
for
reqPath
:=
range
paths
{
CleanMaskedPath
(
reqPath
,
masks
...
)
}
}
}
func
BenchmarkPathClean
(
b
*
testing
.
B
)
{
for
n
:=
0
;
n
<
b
.
N
;
n
++
{
for
reqPath
:=
range
paths
{
path
.
Clean
(
reqPath
)
}
}
}
func
BenchmarkCleanAll
(
b
*
testing
.
B
)
{
maskedBenchmarkRunner
(
b
)
}
func
BenchmarkPreserveAll
(
b
*
testing
.
B
)
{
maskedBenchmarkRunner
(
b
,
"//"
,
"/.."
,
"/."
)
}
func
BenchmarkPreserveProtocol
(
b
*
testing
.
B
)
{
maskedBenchmarkRunner
(
b
,
"://"
)
}
func
BenchmarkPreserveSlashes
(
b
*
testing
.
B
)
{
maskedBenchmarkRunner
(
b
,
"//"
)
}
func
BenchmarkPreserveDots
(
b
*
testing
.
B
)
{
maskedBenchmarkRunner
(
b
,
"/.."
,
"/."
)
}
func
BenchmarkDefaultMask
(
b
*
testing
.
B
)
{
for
n
:=
0
;
n
<
b
.
N
;
n
++
{
for
reqPath
:=
range
paths
{
CleanPath
(
reqPath
)
}
}
}
caddyhttp/httpserver/server.go
View file @
a1a8d0f6
...
...
@@ -9,7 +9,6 @@ import (
"net"
"net/http"
"os"
"path"
"runtime"
"strings"
"sync"
...
...
@@ -351,7 +350,7 @@ func sanitizePath(r *http.Request) {
if
r
.
URL
.
Path
==
"/"
{
return
}
cleanedPath
:=
path
.
Clean
(
r
.
URL
.
Path
)
cleanedPath
:=
CleanPath
(
r
.
URL
.
Path
)
if
cleanedPath
==
"."
{
r
.
URL
.
Path
=
"/"
}
else
{
...
...
caddyhttp/proxy/proxy.go
View file @
a1a8d0f6
...
...
@@ -247,12 +247,28 @@ func createUpstreamRequest(r *http.Request) *http.Request {
outreq
.
URL
.
Opaque
=
outreq
.
URL
.
RawPath
}
// We are modifying the same underlying map from req (shallow
// copied above) so we only copy it if necessary.
copiedHeaders
:=
false
// Remove hop-by-hop headers listed in the "Connection" header.
// See RFC 2616, section 14.10.
if
c
:=
outreq
.
Header
.
Get
(
"Connection"
);
c
!=
""
{
for
_
,
f
:=
range
strings
.
Split
(
c
,
","
)
{
if
f
=
strings
.
TrimSpace
(
f
);
f
!=
""
{
if
!
copiedHeaders
{
outreq
.
Header
=
make
(
http
.
Header
)
copyHeader
(
outreq
.
Header
,
r
.
Header
)
copiedHeaders
=
true
}
outreq
.
Header
.
Del
(
f
)
}
}
}
// Remove hop-by-hop headers to the backend. Especially
// important is "Connection" because we want a persistent
// connection, regardless of what the client sent to us. This
// is modifying the same underlying map from r (shallow
// copied above) so we only copy it if necessary.
var
copiedHeaders
bool
// connection, regardless of what the client sent to us.
for
_
,
h
:=
range
hopHeaders
{
if
outreq
.
Header
.
Get
(
h
)
!=
""
{
if
!
copiedHeaders
{
...
...
caddyhttp/proxy/proxy_test.go
View file @
a1a8d0f6
...
...
@@ -42,10 +42,32 @@ func TestReverseProxy(t *testing.T) {
log
.
SetOutput
(
ioutil
.
Discard
)
defer
log
.
SetOutput
(
os
.
Stderr
)
verifyHeaders
:=
func
(
headers
http
.
Header
,
trailers
http
.
Header
)
{
if
headers
.
Get
(
"X-Header"
)
!=
"header-value"
{
t
.
Error
(
"Expected header 'X-Header' to be proxied properly"
)
}
if
trailers
==
nil
{
t
.
Error
(
"Expected to receive trailers"
)
}
if
trailers
.
Get
(
"X-Trailer"
)
!=
"trailer-value"
{
t
.
Error
(
"Expected header 'X-Trailer' to be proxied properly"
)
}
}
var
requestReceived
bool
backend
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
// read the body (even if it's empty) to make Go parse trailers
io
.
Copy
(
ioutil
.
Discard
,
r
.
Body
)
verifyHeaders
(
r
.
Header
,
r
.
Trailer
)
requestReceived
=
true
w
.
Header
()
.
Set
(
"Trailer"
,
"X-Trailer"
)
w
.
Header
()
.
Set
(
"X-Header"
,
"header-value"
)
w
.
WriteHeader
(
http
.
StatusOK
)
w
.
Write
([]
byte
(
"Hello, client"
))
w
.
Header
()
.
Set
(
"X-Trailer"
,
"trailer-value"
)
}))
defer
backend
.
Close
()
...
...
@@ -59,12 +81,21 @@ func TestReverseProxy(t *testing.T) {
r
:=
httptest
.
NewRequest
(
"GET"
,
"/"
,
nil
)
w
:=
httptest
.
NewRecorder
()
r
.
ContentLength
=
-
1
// force chunked encoding (required for trailers)
r
.
Header
.
Set
(
"X-Header"
,
"header-value"
)
r
.
Trailer
=
map
[
string
][]
string
{
"X-Trailer"
:
{
"trailer-value"
},
}
p
.
ServeHTTP
(
w
,
r
)
if
!
requestReceived
{
t
.
Error
(
"Expected backend to receive request, but it didn't"
)
}
res
:=
w
.
Result
()
verifyHeaders
(
res
.
Header
,
res
.
Trailer
)
// Make sure {upstream} placeholder is set
rr
:=
httpserver
.
NewResponseRecorder
(
httptest
.
NewRecorder
())
rr
.
Replacer
=
httpserver
.
NewReplacer
(
r
,
rr
,
"-"
)
...
...
@@ -123,7 +154,7 @@ func TestWebSocketReverseProxyNonHijackerPanic(t *testing.T) {
defer
wsNop
.
Close
()
// Get proxy to use for the test
p
:=
newWebSocketTestProxy
(
wsNop
.
URL
)
p
:=
newWebSocketTestProxy
(
wsNop
.
URL
,
false
)
// Create client request
r
:=
httptest
.
NewRequest
(
"GET"
,
"/"
,
nil
)
...
...
@@ -148,7 +179,7 @@ func TestWebSocketReverseProxyServeHTTPHandler(t *testing.T) {
defer
wsNop
.
Close
()
// Get proxy to use for the test
p
:=
newWebSocketTestProxy
(
wsNop
.
URL
)
p
:=
newWebSocketTestProxy
(
wsNop
.
URL
,
false
)
// Create client request
r
:=
httptest
.
NewRequest
(
"GET"
,
"/"
,
nil
)
...
...
@@ -189,7 +220,7 @@ func TestWebSocketReverseProxyFromWSClient(t *testing.T) {
defer
wsEcho
.
Close
()
// Get proxy to use for the test
p
:=
newWebSocketTestProxy
(
wsEcho
.
URL
)
p
:=
newWebSocketTestProxy
(
wsEcho
.
URL
,
false
)
// This is a full end-end test, so the proxy handler
// has to be part of a server listening on a port. Our
...
...
@@ -228,6 +259,52 @@ func TestWebSocketReverseProxyFromWSClient(t *testing.T) {
}
}
func
TestWebSocketReverseProxyFromWSSClient
(
t
*
testing
.
T
)
{
wsEcho
:=
newTLSServer
(
websocket
.
Handler
(
func
(
ws
*
websocket
.
Conn
)
{
io
.
Copy
(
ws
,
ws
)
}))
defer
wsEcho
.
Close
()
p
:=
newWebSocketTestProxy
(
wsEcho
.
URL
,
true
)
echoProxy
:=
newTLSServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
p
.
ServeHTTP
(
w
,
r
)
}))
defer
echoProxy
.
Close
()
// Set up WebSocket client
url
:=
strings
.
Replace
(
echoProxy
.
URL
,
"https://"
,
"wss://"
,
1
)
wsCfg
,
err
:=
websocket
.
NewConfig
(
url
,
echoProxy
.
URL
)
if
err
!=
nil
{
t
.
Fatal
(
err
)
}
wsCfg
.
TlsConfig
=
&
tls
.
Config
{
InsecureSkipVerify
:
true
}
ws
,
err
:=
websocket
.
DialConfig
(
wsCfg
)
if
err
!=
nil
{
t
.
Fatal
(
err
)
}
defer
ws
.
Close
()
// Send test message
trialMsg
:=
"Is it working?"
if
sendErr
:=
websocket
.
Message
.
Send
(
ws
,
trialMsg
);
sendErr
!=
nil
{
t
.
Fatal
(
sendErr
)
}
// It should be echoed back to us
var
actualMsg
string
if
rcvErr
:=
websocket
.
Message
.
Receive
(
ws
,
&
actualMsg
);
rcvErr
!=
nil
{
t
.
Fatal
(
rcvErr
)
}
if
actualMsg
!=
trialMsg
{
t
.
Errorf
(
"Expected '%s' but got '%s' instead"
,
trialMsg
,
actualMsg
)
}
}
func
TestUnixSocketProxy
(
t
*
testing
.
T
)
{
if
runtime
.
GOOS
==
"windows"
{
return
...
...
@@ -264,7 +341,7 @@ func TestUnixSocketProxy(t *testing.T) {
defer
ts
.
Close
()
url
:=
strings
.
Replace
(
ts
.
URL
,
"http://"
,
"unix:"
,
1
)
p
:=
newWebSocketTestProxy
(
url
)
p
:=
newWebSocketTestProxy
(
url
,
false
)
echoProxy
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
p
.
ServeHTTP
(
w
,
r
)
...
...
@@ -982,10 +1059,14 @@ func (u *fakeUpstream) GetTryInterval() time.Duration { return 250 * time.
// redirect to the specified backendAddr. The function
// also sets up the rules/environment for testing WebSocket
// proxy.
func
newWebSocketTestProxy
(
backendAddr
string
)
*
Proxy
{
func
newWebSocketTestProxy
(
backendAddr
string
,
insecure
bool
)
*
Proxy
{
return
&
Proxy
{
Next
:
httpserver
.
EmptyNext
,
// prevents panic in some cases when test fails
Upstreams
:
[]
Upstream
{
&
fakeWsUpstream
{
name
:
backendAddr
,
without
:
""
}},
Next
:
httpserver
.
EmptyNext
,
// prevents panic in some cases when test fails
Upstreams
:
[]
Upstream
{
&
fakeWsUpstream
{
name
:
backendAddr
,
without
:
""
,
insecure
:
insecure
,
}},
}
}
...
...
@@ -997,8 +1078,9 @@ func newPrefixedWebSocketTestProxy(backendAddr string, prefix string) *Proxy {
}
type
fakeWsUpstream
struct
{
name
string
without
string
name
string
without
string
insecure
bool
}
func
(
u
*
fakeWsUpstream
)
From
()
string
{
...
...
@@ -1007,13 +1089,17 @@ func (u *fakeWsUpstream) From() string {
func
(
u
*
fakeWsUpstream
)
Select
(
r
*
http
.
Request
)
*
UpstreamHost
{
uri
,
_
:=
url
.
Parse
(
u
.
name
)
return
&
UpstreamHost
{
host
:=
&
UpstreamHost
{
Name
:
u
.
name
,
ReverseProxy
:
NewSingleHostReverseProxy
(
uri
,
u
.
without
,
http
.
DefaultMaxIdleConnsPerHost
),
UpstreamHeaders
:
http
.
Header
{
"Connection"
:
{
"{>Connection}"
},
"Upgrade"
:
{
"{>Upgrade}"
}},
}
if
u
.
insecure
{
host
.
ReverseProxy
.
UseInsecureTransport
()
}
return
host
}
func
(
u
*
fakeWsUpstream
)
AllowedPath
(
requestPath
string
)
bool
{
return
true
}
...
...
caddyhttp/proxy/reverseproxy.go
View file @
a1a8d0f6
...
...
@@ -27,10 +27,28 @@ import (
"github.com/mholt/caddy/caddyhttp/httpserver"
)
var
bufferPool
=
sync
.
Pool
{
New
:
createBuffer
}
var
(
defaultDialer
=
&
net
.
Dialer
{
Timeout
:
30
*
time
.
Second
,
KeepAlive
:
30
*
time
.
Second
,
}
bufferPool
=
sync
.
Pool
{
New
:
createBuffer
}
)
func
createBuffer
()
interface
{}
{
return
make
([]
byte
,
32
*
1024
)
return
make
([]
byte
,
0
,
32
*
1024
)
}
func
pooledIoCopy
(
dst
io
.
Writer
,
src
io
.
Reader
)
{
buf
:=
bufferPool
.
Get
()
.
([]
byte
)
defer
bufferPool
.
Put
(
buf
)
// CopyBuffer only uses buf up to its length and panics if it's 0.
// Due to that we extend buf's length to its capacity here and
// ensure it's always non-zero.
bufCap
:=
cap
(
buf
)
io
.
CopyBuffer
(
dst
,
src
,
buf
[
0
:
bufCap
:
bufCap
])
}
// onExitFlushLoop is a callback set by tests to detect the state of the
...
...
@@ -135,11 +153,8 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *
// just use default transport, to avoid creating
// a brand new transport
transport
:=
&
http
.
Transport
{
Proxy
:
http
.
ProxyFromEnvironment
,
Dial
:
(
&
net
.
Dialer
{
Timeout
:
30
*
time
.
Second
,
KeepAlive
:
30
*
time
.
Second
,
})
.
Dial
,
Proxy
:
http
.
ProxyFromEnvironment
,
Dial
:
defaultDialer
.
Dial
,
TLSHandshakeTimeout
:
10
*
time
.
Second
,
ExpectContinueTimeout
:
1
*
time
.
Second
,
}
...
...
@@ -148,7 +163,9 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *
}
else
{
transport
.
MaxIdleConnsPerHost
=
keepalive
}
http2
.
ConfigureTransport
(
transport
)
if
httpserver
.
HTTP2
{
http2
.
ConfigureTransport
(
transport
)
}
rp
.
Transport
=
transport
}
return
rp
...
...
@@ -160,18 +177,20 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *
func
(
rp
*
ReverseProxy
)
UseInsecureTransport
()
{
if
rp
.
Transport
==
nil
{
transport
:=
&
http
.
Transport
{
Proxy
:
http
.
ProxyFromEnvironment
,
Dial
:
(
&
net
.
Dialer
{
Timeout
:
30
*
time
.
Second
,
KeepAlive
:
30
*
time
.
Second
,
})
.
Dial
,
Proxy
:
http
.
ProxyFromEnvironment
,
Dial
:
defaultDialer
.
Dial
,
TLSHandshakeTimeout
:
10
*
time
.
Second
,
TLSClientConfig
:
&
tls
.
Config
{
InsecureSkipVerify
:
true
},
}
http2
.
ConfigureTransport
(
transport
)
if
httpserver
.
HTTP2
{
http2
.
ConfigureTransport
(
transport
)
}
rp
.
Transport
=
transport
}
else
if
transport
,
ok
:=
rp
.
Transport
.
(
*
http
.
Transport
);
ok
{
transport
.
TLSClientConfig
=
&
tls
.
Config
{
InsecureSkipVerify
:
true
}
// No http2.ConfigureTransport() here.
// For now this is only added in places where
// an http.Transport is actually created.
}
}
...
...
@@ -186,20 +205,33 @@ func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request,
}
rp
.
Director
(
outreq
)
outreq
.
Proto
=
"HTTP/1.1"
outreq
.
ProtoMajor
=
1
outreq
.
ProtoMinor
=
1
outreq
.
Close
=
false
res
,
err
:=
transport
.
RoundTrip
(
outreq
)
if
err
!=
nil
{
return
err
}
isWebsocket
:=
res
.
StatusCode
==
http
.
StatusSwitchingProtocols
&&
strings
.
ToLower
(
res
.
Header
.
Get
(
"Upgrade"
))
==
"websocket"
// Remove hop-by-hop headers listed in the
// "Connection" header of the response.
if
c
:=
res
.
Header
.
Get
(
"Connection"
);
c
!=
""
{
for
_
,
f
:=
range
strings
.
Split
(
c
,
","
)
{
if
f
=
strings
.
TrimSpace
(
f
);
f
!=
""
{
res
.
Header
.
Del
(
f
)
}
}
}
for
_
,
h
:=
range
hopHeaders
{
res
.
Header
.
Del
(
h
)
}
if
respUpdateFn
!=
nil
{
respUpdateFn
(
res
)
}
if
res
.
StatusCode
==
http
.
StatusSwitchingProtocols
&&
strings
.
ToLower
(
res
.
Header
.
Get
(
"Upgrade"
))
==
"websocket"
{
if
isWebsocket
{
res
.
Body
.
Close
()
hj
,
ok
:=
rw
.
(
http
.
Hijacker
)
if
!
ok
{
...
...
@@ -228,27 +260,39 @@ func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request,
}
defer
backendConn
.
Close
()
go
func
()
{
io
.
Copy
(
backendConn
,
conn
)
// write tcp stream to backend.
}()
io
.
Copy
(
conn
,
backendConn
)
// read tcp stream from backend.
go
pooledIoCopy
(
backendConn
,
conn
)
// write tcp stream to backend
pooledIoCopy
(
conn
,
backendConn
)
// read tcp stream from backend
}
else
{
defer
res
.
Body
.
Close
()
for
_
,
h
:=
range
hopHeaders
{
res
.
Header
.
Del
(
h
)
}
copyHeader
(
rw
.
Header
(),
res
.
Header
)
// The "Trailer" header isn't included in the Transport's response,
// at least for *http.Transport. Build it up from Trailer.
if
len
(
res
.
Trailer
)
>
0
{
trailerKeys
:=
make
([]
string
,
0
,
len
(
res
.
Trailer
))
for
k
:=
range
res
.
Trailer
{
trailerKeys
=
append
(
trailerKeys
,
k
)
}
rw
.
Header
()
.
Add
(
"Trailer"
,
strings
.
Join
(
trailerKeys
,
", "
))
}
rw
.
WriteHeader
(
res
.
StatusCode
)
if
len
(
res
.
Trailer
)
>
0
{
// Force chunking if we saw a response trailer.
// This prevents net/http from calculating the length for short
// bodies and adding a Content-Length.
if
fl
,
ok
:=
rw
.
(
http
.
Flusher
);
ok
{
fl
.
Flush
()
}
}
rp
.
copyResponse
(
rw
,
res
.
Body
)
res
.
Body
.
Close
()
// close now, instead of defer, to populate res.Trailer
copyHeader
(
rw
.
Header
(),
res
.
Trailer
)
}
return
nil
}
func
(
rp
*
ReverseProxy
)
copyResponse
(
dst
io
.
Writer
,
src
io
.
Reader
)
{
buf
:=
bufferPool
.
Get
()
defer
bufferPool
.
Put
(
buf
)
if
rp
.
FlushInterval
!=
0
{
if
wf
,
ok
:=
dst
.
(
writeFlusher
);
ok
{
mlw
:=
&
maxLatencyWriter
{
...
...
@@ -261,7 +305,7 @@ func (rp *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) {
dst
=
mlw
}
}
io
.
CopyBuffer
(
dst
,
src
,
buf
.
([]
byte
)
)
pooledIoCopy
(
dst
,
src
)
}
// skip these headers if they already exist.
...
...
@@ -295,16 +339,17 @@ func copyHeader(dst, src http.Header) {
// Hop-by-hop headers. These are removed when sent to the backend.
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
var
hopHeaders
=
[]
string
{
"Alt-Svc"
,
"Alternate-Protocol"
,
"Connection"
,
"Keep-Alive"
,
"Proxy-Authenticate"
,
"Proxy-Authorization"
,
"Te"
,
// canonicalized version of "TE"
"Trailers"
,
"Proxy-Connection"
,
// non-standard but still sent by libcurl and rejected by e.g. google
"Te"
,
// canonicalized version of "TE"
"Trailer"
,
// not Trailers per URL above; http://www.rfc-editor.org/errata_search.php?eid=4522
"Transfer-Encoding"
,
"Upgrade"
,
"Alternate-Protocol"
,
"Alt-Svc"
,
}
type
respUpdateFn
func
(
resp
*
http
.
Response
)
...
...
@@ -331,51 +376,169 @@ type connHijackerTransport struct {
}
func
newConnHijackerTransport
(
base
http
.
RoundTripper
)
*
connHijackerTransport
{
transport
:=
&
http
.
Transport
{
Proxy
:
http
.
ProxyFromEnvironment
,
Dial
:
(
&
net
.
Dialer
{
Timeout
:
30
*
time
.
Second
,
KeepAlive
:
30
*
time
.
Second
,
})
.
Dial
,
TLSHandshakeTimeout
:
10
*
time
.
Second
,
t
:=
&
http
.
Transport
{
MaxIdleConnsPerHost
:
-
1
,
}
if
base
!=
nil
{
if
baseTransport
,
ok
:=
base
.
(
*
http
.
Transport
);
ok
{
transport
.
Proxy
=
baseTransport
.
Proxy
transport
.
TLSClientConfig
=
baseTransport
.
TLSClientConfig
transport
.
TLSHandshakeTimeout
=
baseTransport
.
TLSHandshakeTimeout
transport
.
Dial
=
baseTransport
.
Dial
transport
.
DialTLS
=
baseTransport
.
DialTLS
transport
.
MaxIdleConnsPerHost
=
-
1
if
b
,
_
:=
base
.
(
*
http
.
Transport
);
b
!=
nil
{
tlsClientConfig
:=
b
.
TLSClientConfig
if
tlsClientConfig
.
NextProtos
!=
nil
{
tlsClientConfig
=
cloneTLSClientConfig
(
tlsClientConfig
)
tlsClientConfig
.
NextProtos
=
nil
}
t
.
Proxy
=
b
.
Proxy
t
.
TLSClientConfig
=
tlsClientConfig
t
.
TLSHandshakeTimeout
=
b
.
TLSHandshakeTimeout
t
.
Dial
=
b
.
Dial
t
.
DialTLS
=
b
.
DialTLS
}
else
{
t
.
Proxy
=
http
.
ProxyFromEnvironment
t
.
TLSHandshakeTimeout
=
10
*
time
.
Second
}
hj
Transport
:=
&
connHijackerTransport
{
transpor
t
,
nil
,
bufferPool
.
Get
()
.
([]
byte
)[
:
0
]}
oldDial
:=
transport
.
Dial
oldDialTLS
:=
transport
.
DialTLS
if
oldDial
==
nil
{
oldDial
=
(
&
net
.
Dialer
{
Timeout
:
30
*
time
.
Second
,
KeepAlive
:
30
*
time
.
Second
,
})
.
Dial
hj
:=
&
connHijackerTransport
{
t
,
nil
,
bufferPool
.
Get
()
.
([]
byte
)[
:
0
]}
dial
:=
getTransportDial
(
t
)
dialTLS
:=
getTransportDialTLS
(
t
)
t
.
Dial
=
func
(
network
,
addr
string
)
(
net
.
Conn
,
error
)
{
c
,
err
:=
dial
(
network
,
addr
)
hj
.
Conn
=
c
return
&
hijackedConn
{
c
,
hj
},
err
}
hjTransport
.
Dial
=
func
(
network
,
addr
string
)
(
net
.
Conn
,
error
)
{
c
,
err
:=
oldDial
(
network
,
addr
)
hj
Transport
.
Conn
=
c
return
&
hijackedConn
{
c
,
hj
Transport
},
err
t
.
DialTLS
=
func
(
network
,
addr
string
)
(
net
.
Conn
,
error
)
{
c
,
err
:=
dialTLS
(
network
,
addr
)
hj
.
Conn
=
c
return
&
hijackedConn
{
c
,
hj
},
err
}
if
oldDialTLS
!=
nil
{
hjTransport
.
DialTLS
=
func
(
network
,
addr
string
)
(
net
.
Conn
,
error
)
{
c
,
err
:=
oldDialTLS
(
network
,
addr
)
hjTransport
.
Conn
=
c
return
&
hijackedConn
{
c
,
hjTransport
},
err
return
hj
}
// getTransportDial always returns a plain Dialer
// and defaults to the existing t.Dial.
func
getTransportDial
(
t
*
http
.
Transport
)
func
(
network
,
addr
string
)
(
net
.
Conn
,
error
)
{
if
t
.
Dial
!=
nil
{
return
t
.
Dial
}
return
defaultDialer
.
Dial
}
// getTransportDial always returns a TLS Dialer
// and defaults to the existing t.DialTLS.
func
getTransportDialTLS
(
t
*
http
.
Transport
)
func
(
network
,
addr
string
)
(
net
.
Conn
,
error
)
{
if
t
.
DialTLS
!=
nil
{
return
t
.
DialTLS
}
// newConnHijackerTransport will modify t.Dial after calling this method
// => Create a backup reference.
plainDial
:=
getTransportDial
(
t
)
// The following DialTLS implementation stems from the Go stdlib and
// is identical to what happens if DialTLS is not provided.
// Source: https://github.com/golang/go/blob/230a376b5a67f0e9341e1fa47e670ff762213c83/src/net/http/transport.go#L1018-L1051
return
func
(
network
,
addr
string
)
(
net
.
Conn
,
error
)
{
plainConn
,
err
:=
plainDial
(
network
,
addr
)
if
err
!=
nil
{
return
nil
,
err
}
tlsClientConfig
:=
t
.
TLSClientConfig
if
tlsClientConfig
==
nil
{
tlsClientConfig
=
&
tls
.
Config
{}
}
if
!
tlsClientConfig
.
InsecureSkipVerify
&&
tlsClientConfig
.
ServerName
==
""
{
tlsClientConfig
.
ServerName
=
stripPort
(
addr
)
}
tlsConn
:=
tls
.
Client
(
plainConn
,
tlsClientConfig
)
errc
:=
make
(
chan
error
,
2
)
var
timer
*
time
.
Timer
if
d
:=
t
.
TLSHandshakeTimeout
;
d
!=
0
{
timer
=
time
.
AfterFunc
(
d
,
func
()
{
errc
<-
tlsHandshakeTimeoutError
{}
})
}
go
func
()
{
err
:=
tlsConn
.
Handshake
()
if
timer
!=
nil
{
timer
.
Stop
()
}
errc
<-
err
}()
if
err
:=
<-
errc
;
err
!=
nil
{
plainConn
.
Close
()
return
nil
,
err
}
if
!
tlsClientConfig
.
InsecureSkipVerify
{
hostname
:=
tlsClientConfig
.
ServerName
if
hostname
==
""
{
hostname
=
stripPort
(
addr
)
}
if
err
:=
tlsConn
.
VerifyHostname
(
hostname
);
err
!=
nil
{
plainConn
.
Close
()
return
nil
,
err
}
}
return
tlsConn
,
nil
}
}
// stripPort returns address without its port if it has one and
// works with IP addresses as well as hostnames formatted as host:port.
//
// IPv6 addresses (excluding the port) must be enclosed in
// square brackets similar to the requirements of Go's stdlib.
func
stripPort
(
address
string
)
string
{
// Keep in mind that the address might be a IPv6 address
// and thus contain a colon, but not have a port.
portIdx
:=
strings
.
LastIndex
(
address
,
":"
)
ipv6Idx
:=
strings
.
LastIndex
(
address
,
"]"
)
if
portIdx
>
ipv6Idx
{
address
=
address
[
:
portIdx
]
}
return
address
}
type
tlsHandshakeTimeoutError
struct
{}
func
(
tlsHandshakeTimeoutError
)
Timeout
()
bool
{
return
true
}
func
(
tlsHandshakeTimeoutError
)
Temporary
()
bool
{
return
true
}
func
(
tlsHandshakeTimeoutError
)
Error
()
string
{
return
"net/http: TLS handshake timeout"
}
// cloneTLSClientConfig is like cloneTLSConfig but omits
// the fields SessionTicketsDisabled and SessionTicketKey.
// This makes it safe to call cloneTLSClientConfig on a config
// in active use by a server.
func
cloneTLSClientConfig
(
cfg
*
tls
.
Config
)
*
tls
.
Config
{
if
cfg
==
nil
{
return
&
tls
.
Config
{}
}
return
&
tls
.
Config
{
Rand
:
cfg
.
Rand
,
Time
:
cfg
.
Time
,
Certificates
:
cfg
.
Certificates
,
NameToCertificate
:
cfg
.
NameToCertificate
,
GetCertificate
:
cfg
.
GetCertificate
,
RootCAs
:
cfg
.
RootCAs
,
NextProtos
:
cfg
.
NextProtos
,
ServerName
:
cfg
.
ServerName
,
ClientAuth
:
cfg
.
ClientAuth
,
ClientCAs
:
cfg
.
ClientCAs
,
InsecureSkipVerify
:
cfg
.
InsecureSkipVerify
,
CipherSuites
:
cfg
.
CipherSuites
,
PreferServerCipherSuites
:
cfg
.
PreferServerCipherSuites
,
ClientSessionCache
:
cfg
.
ClientSessionCache
,
MinVersion
:
cfg
.
MinVersion
,
MaxVersion
:
cfg
.
MaxVersion
,
CurvePreferences
:
cfg
.
CurvePreferences
,
DynamicRecordSizingDisabled
:
cfg
.
DynamicRecordSizingDisabled
,
Renegotiation
:
cfg
.
Renegotiation
,
}
return
hjTransport
}
func
requestIsWebsocket
(
req
*
http
.
Request
)
bool
{
return
!
(
strings
.
ToLower
(
req
.
Header
.
Get
(
"Upgrade"
))
!=
"websocket"
||
!
strings
.
Contains
(
strings
.
ToLower
(
req
.
Header
.
Get
(
"Connection"
)),
"upgrade"
)
)
return
strings
.
ToLower
(
req
.
Header
.
Get
(
"Upgrade"
))
==
"websocket"
&&
strings
.
Contains
(
strings
.
ToLower
(
req
.
Header
.
Get
(
"Connection"
)),
"upgrade"
)
}
type
writeFlusher
interface
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment