Commit f1ba7fa3 authored by Matt Holt's avatar Matt Holt

Merge pull request #467 from eiszfuchs/feature/proxy-socket

proxy: Support unix sockets
parents 57ffe5a6 7091a209
......@@ -63,7 +63,6 @@ var tryDuration = 60 * time.Second
// ServeHTTP satisfies the middleware.Handler interface.
func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
for _, upstream := range p.Upstreams {
if middleware.Path(r.URL.Path).Matches(upstream.From()) && upstream.IsAllowedPath(r.URL.Path) {
var replacer middleware.Replacer
......
......@@ -3,6 +3,7 @@ package proxy
import (
"bufio"
"bytes"
"fmt"
"io"
"io/ioutil"
"log"
......@@ -13,7 +14,9 @@ import (
"os"
"strings"
"testing"
"runtime"
"time"
"path/filepath"
"golang.org/x/net/websocket"
)
......@@ -160,6 +163,69 @@ func TestWebSocketReverseProxyFromWSClient(t *testing.T) {
}
}
func TestUnixSocketProxy(t *testing.T) {
if runtime.GOOS == "windows" {
return
}
trialMsg := "Is it working?"
var proxySuccess bool
// This is our fake "application" we want to proxy to
ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Request was proxied when this is called
proxySuccess = true
fmt.Fprint(w, trialMsg)
}))
// Get absolute path for unix: socket
socketPath, err := filepath.Abs("./test_socket")
if err != nil {
t.Fatalf("Unable to get absolute path: %v", err)
}
// Change httptest.Server listener to listen to unix: socket
ln, err := net.Listen("unix", socketPath)
if err != nil {
t.Fatalf("Unable to listen: %v", err)
}
ts.Listener = ln
ts.Start()
defer ts.Close()
url := strings.Replace(ts.URL, "http://", "unix:", 1)
p := newWebSocketTestProxy(url)
echoProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
p.ServeHTTP(w, r)
}))
defer echoProxy.Close()
res, err := http.Get(echoProxy.URL)
if err != nil {
t.Fatalf("Unable to GET: %v", err)
}
greeting, err := ioutil.ReadAll(res.Body)
res.Body.Close()
if err != nil {
t.Fatalf("Unable to GET: %v", err)
}
actualMsg := fmt.Sprintf("%s", greeting)
if !proxySuccess {
t.Errorf("Expected request to be proxied, but it wasn't")
}
if actualMsg != trialMsg {
t.Errorf("Expected '%s' but got '%s' instead", trialMsg, actualMsg)
}
}
func newFakeUpstream(name string, insecure bool) *fakeUpstream {
uri, _ := url.Parse(name)
u := &fakeUpstream{
......
......@@ -59,6 +59,18 @@ func singleJoiningSlash(a, b string) string {
return a + b
}
// Though the relevant directive prefix is just "unix:", url.Parse
// will - assuming the regular URL scheme - add additional slashes
// as if "unix" was a request protocol.
// What we need is just the path, so if "unix:/var/run/www.socket"
// was the proxy directive, the parsed hostName would be
// "unix:///var/run/www.socket", hence the ambiguous trimming.
func socketDial(hostName string) func(network, addr string) (conn net.Conn, err error) {
return func(network, addr string) (conn net.Conn, err error) {
return net.Dial("unix", hostName[len("unix://"):])
}
}
// NewSingleHostReverseProxy returns a new ReverseProxy that rewrites
// URLs to the scheme, host, and base path provided in target. If the
// target's path is "/base" and the incoming request was for "/dir",
......@@ -68,8 +80,15 @@ func singleJoiningSlash(a, b string) string {
func NewSingleHostReverseProxy(target *url.URL, without string) *ReverseProxy {
targetQuery := target.RawQuery
director := func(req *http.Request) {
req.URL.Scheme = target.Scheme
req.URL.Host = target.Host
if target.Scheme == "unix" {
// to make Dial work with unix URL,
// scheme and host have to be faked
req.URL.Scheme = "http"
req.URL.Host = "socket"
} else {
req.URL.Scheme = target.Scheme
req.URL.Host = target.Host
}
req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path)
if targetQuery == "" || req.URL.RawQuery == "" {
req.URL.RawQuery = targetQuery + req.URL.RawQuery
......@@ -80,7 +99,13 @@ func NewSingleHostReverseProxy(target *url.URL, without string) *ReverseProxy {
req.URL.Path = strings.TrimPrefix(req.URL.Path, without)
}
}
return &ReverseProxy{Director: director}
rp := &ReverseProxy{Director: director}
if target.Scheme == "unix" {
rp.Transport = &http.Transport{
Dial: socketDial(target.String()),
}
}
return rp
}
func copyHeader(dst, src http.Header) {
......
......@@ -65,7 +65,8 @@ func NewStaticUpstreams(c parse.Dispenser) ([]Upstream, error) {
upstream.Hosts = make([]*UpstreamHost, len(to))
for i, host := range to {
if !strings.HasPrefix(host, "http") {
if !strings.HasPrefix(host, "http") &&
!strings.HasPrefix(host, "unix:") {
host = "http://" + host
}
uh := &UpstreamHost{
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment