Commit 8658e189 authored by Craig Peterson's avatar Craig Peterson Committed by GitHub

Merge branch 'master' into macros

parents 345b312e 9a22cda1
...@@ -518,6 +518,11 @@ func startWithListenerFds(cdyfile Input, inst *Instance, restartFds map[string]r ...@@ -518,6 +518,11 @@ func startWithListenerFds(cdyfile Input, inst *Instance, restartFds map[string]r
} }
if !Quiet { if !Quiet {
for _, srvln := range inst.servers { for _, srvln := range inst.servers {
// only show FD notice if the listener is not nil.
// This can happen when only serving UDP or TCP
if srvln.listener == nil {
continue
}
if !IsLoopback(srvln.listener.Addr().String()) { if !IsLoopback(srvln.listener.Addr().String()) {
checkFdlimit() checkFdlimit()
break break
......
...@@ -214,6 +214,9 @@ func SameNext(next1, next2 Handler) bool { ...@@ -214,6 +214,9 @@ func SameNext(next1, next2 Handler) bool {
// Context key constants. // Context key constants.
const ( const (
// ReplacerCtxKey is the context key for a per-request replacer.
ReplacerCtxKey caddy.CtxKey = "replacer"
// RemoteUserCtxKey is the key for the remote user of the request, if any (basicauth). // RemoteUserCtxKey is the key for the remote user of the request, if any (basicauth).
RemoteUserCtxKey caddy.CtxKey = "remote_user" RemoteUserCtxKey caddy.CtxKey = "remote_user"
......
...@@ -102,6 +102,18 @@ func (lw *limitWriter) String() string { ...@@ -102,6 +102,18 @@ func (lw *limitWriter) String() string {
// emptyValue should be the string that is used in place // emptyValue should be the string that is used in place
// of empty string (can still be empty string). // of empty string (can still be empty string).
func NewReplacer(r *http.Request, rr *ResponseRecorder, emptyValue string) Replacer { func NewReplacer(r *http.Request, rr *ResponseRecorder, emptyValue string) Replacer {
repl := &replacer{
request: r,
responseRecorder: rr,
emptyValue: emptyValue,
}
// extract customReplacements from a request replacer when present.
if existing, ok := r.Context().Value(ReplacerCtxKey).(*replacer); ok {
repl.requestBody = existing.requestBody
repl.customReplacements = existing.customReplacements
} else {
// if there is no existing replacer, build one from scratch.
rb := newLimitWriter(MaxLogBodySize) rb := newLimitWriter(MaxLogBodySize)
if r.Body != nil { if r.Body != nil {
r.Body = struct { r.Body = struct {
...@@ -109,13 +121,11 @@ func NewReplacer(r *http.Request, rr *ResponseRecorder, emptyValue string) Repla ...@@ -109,13 +121,11 @@ func NewReplacer(r *http.Request, rr *ResponseRecorder, emptyValue string) Repla
io.Closer io.Closer
}{io.TeeReader(r.Body, rb), io.Closer(r.Body)} }{io.TeeReader(r.Body, rb), io.Closer(r.Body)}
} }
return &replacer{ repl.requestBody = rb
request: r, repl.customReplacements = make(map[string]string)
requestBody: rb,
responseRecorder: rr,
customReplacements: make(map[string]string),
emptyValue: emptyValue,
} }
return repl
} }
func canLogRequest(r *http.Request) bool { func canLogRequest(r *http.Request) bool {
......
...@@ -356,6 +356,12 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { ...@@ -356,6 +356,12 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
c := context.WithValue(r.Context(), OriginalURLCtxKey, urlCopy) c := context.WithValue(r.Context(), OriginalURLCtxKey, urlCopy)
r = r.WithContext(c) r = r.WithContext(c)
// Setup a replacer for the request that keeps track of placeholder
// values across plugins.
replacer := NewReplacer(r, nil, "")
c = context.WithValue(r.Context(), ReplacerCtxKey, replacer)
r = r.WithContext(c)
w.Header().Set("Server", caddy.AppName) w.Header().Set("Server", caddy.AppName)
status, _ := s.serveHTTP(w, r) status, _ := s.serveHTTP(w, r)
......
...@@ -83,6 +83,7 @@ type UpstreamHost struct { ...@@ -83,6 +83,7 @@ type UpstreamHost struct {
// reads & writes to this value. The default value of 0 indicates that it // reads & writes to this value. The default value of 0 indicates that it
// is healthy and any non-zero value indicates unhealthy. // is healthy and any non-zero value indicates unhealthy.
Unhealthy int32 Unhealthy int32
HealthCheckResult atomic.Value
} }
// Down checks whether the upstream host is down or not. // Down checks whether the upstream host is down or not.
......
...@@ -26,7 +26,9 @@ ...@@ -26,7 +26,9 @@
package proxy package proxy
import ( import (
"context"
"crypto/tls" "crypto/tls"
"fmt"
"io" "io"
"net" "net"
"net/http" "net/http"
...@@ -91,6 +93,8 @@ type ReverseProxy struct { ...@@ -91,6 +93,8 @@ type ReverseProxy struct {
// response body. // response body.
// If zero, no periodic flushing is done. // If zero, no periodic flushing is done.
FlushInterval time.Duration FlushInterval time.Duration
srvResolver srvResolver
} }
// Though the relevant directive prefix is just "unix:", url.Parse // Though the relevant directive prefix is just "unix:", url.Parse
...@@ -105,6 +109,23 @@ func socketDial(hostName string) func(network, addr string) (conn net.Conn, err ...@@ -105,6 +109,23 @@ func socketDial(hostName string) func(network, addr string) (conn net.Conn, err
} }
} }
func (rp *ReverseProxy) srvDialerFunc(locator string) func(network, addr string) (conn net.Conn, err error) {
service := locator
if strings.HasPrefix(locator, "srv://") {
service = locator[6:]
} else if strings.HasPrefix(locator, "srv+https://") {
service = locator[12:]
}
return func(network, addr string) (conn net.Conn, err error) {
_, addrs, err := rp.srvResolver.LookupSRV(context.Background(), "", "", service)
if err != nil {
return nil, err
}
return net.Dial("tcp", fmt.Sprintf("%s:%d", addrs[0].Target, addrs[0].Port))
}
}
func singleJoiningSlash(a, b string) string { func singleJoiningSlash(a, b string) string {
aslash := strings.HasSuffix(a, "/") aslash := strings.HasSuffix(a, "/")
bslash := strings.HasPrefix(b, "/") bslash := strings.HasPrefix(b, "/")
...@@ -131,6 +152,12 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) * ...@@ -131,6 +152,12 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *
// scheme and host have to be faked // scheme and host have to be faked
req.URL.Scheme = "http" req.URL.Scheme = "http"
req.URL.Host = "socket" req.URL.Host = "socket"
} else if target.Scheme == "srv" {
req.URL.Scheme = "http"
req.URL.Host = target.Host
} else if target.Scheme == "srv+https" {
req.URL.Scheme = "https"
req.URL.Host = target.Host
} else { } else {
req.URL.Scheme = target.Scheme req.URL.Scheme = target.Scheme
req.URL.Host = target.Host req.URL.Host = target.Host
...@@ -199,7 +226,12 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) * ...@@ -199,7 +226,12 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *
} }
} }
rp := &ReverseProxy{Director: director, FlushInterval: 250 * time.Millisecond} // flushing good for streaming & server-sent events rp := &ReverseProxy{
Director: director,
FlushInterval: 250 * time.Millisecond, // flushing good for streaming & server-sent events
srvResolver: net.DefaultResolver,
}
if target.Scheme == "unix" { if target.Scheme == "unix" {
rp.Transport = &http.Transport{ rp.Transport = &http.Transport{
Dial: socketDial(target.String()), Dial: socketDial(target.String()),
...@@ -210,13 +242,15 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) * ...@@ -210,13 +242,15 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *
HandshakeTimeout: defaultCryptoHandshakeTimeout, HandshakeTimeout: defaultCryptoHandshakeTimeout,
}, },
} }
} else if keepalive != http.DefaultMaxIdleConnsPerHost { } else if keepalive != http.DefaultMaxIdleConnsPerHost || strings.HasPrefix(target.Scheme, "srv") {
// if keepalive is equal to the default, dialFunc := defaultDialer.Dial
// just use default transport, to avoid creating if strings.HasPrefix(target.Scheme, "srv") {
// a brand new transport dialFunc = rp.srvDialerFunc(target.String())
}
transport := &http.Transport{ transport := &http.Transport{
Proxy: http.ProxyFromEnvironment, Proxy: http.ProxyFromEnvironment,
Dial: defaultDialer.Dial, Dial: dialFunc,
TLSHandshakeTimeout: defaultCryptoHandshakeTimeout, TLSHandshakeTimeout: defaultCryptoHandshakeTimeout,
ExpectContinueTimeout: 1 * time.Second, ExpectContinueTimeout: 1 * time.Second,
} }
......
// Copyright 2015 Light Code Labs, LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package proxy
import (
"net"
"net/http"
"net/http/httptest"
"net/url"
"strconv"
"testing"
)
const (
expectedResponse = "response from request proxied to upstream"
expectedStatus = http.StatusOK
)
var upstreamHost *httptest.Server
func setupTest() {
upstreamHost = httptest.NewServer(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/test-path" {
w.WriteHeader(expectedStatus)
w.Write([]byte(expectedResponse))
} else {
w.WriteHeader(404)
w.Write([]byte("Not found"))
}
}))
}
func tearDownTest() {
upstreamHost.Close()
}
func TestSingleSRVHostReverseProxy(t *testing.T) {
setupTest()
defer tearDownTest()
target, err := url.Parse("srv://test.upstream.service")
if err != nil {
t.Errorf("Failed to parse target URL. %s", err.Error())
}
upstream, err := url.Parse(upstreamHost.URL)
if err != nil {
t.Errorf("Failed to parse test server URL [%s]. %s", upstreamHost.URL, err.Error())
}
pp, err := strconv.Atoi(upstream.Port())
if err != nil {
t.Errorf("Failed to parse upstream server port [%s]. %s", upstream.Port(), err.Error())
}
port := uint16(pp)
rp := NewSingleHostReverseProxy(target, "", http.DefaultMaxIdleConnsPerHost)
rp.srvResolver = testResolver{
result: []*net.SRV{
{Target: upstream.Hostname(), Port: port, Priority: 1, Weight: 1},
},
}
resp := httptest.NewRecorder()
req, err := http.NewRequest("GET", "http://test.host/test-path", nil)
if err != nil {
t.Errorf("Failed to create new request. %s", err.Error())
}
err = rp.ServeHTTP(resp, req, nil)
if err != nil {
t.Errorf("Failed to perform reverse proxy to upstream host. %s", err.Error())
}
if resp.Body.String() != expectedResponse {
t.Errorf("Unexpected proxy response received. Expected: '%s', Got: '%s'", expectedResponse, resp.Body.String())
}
if resp.Code != expectedStatus {
t.Errorf("Unexpected proxy status. Expected: '%d', Got: '%d'", expectedStatus, resp.Code)
}
}
...@@ -16,6 +16,7 @@ package proxy ...@@ -16,6 +16,7 @@ package proxy
import ( import (
"bytes" "bytes"
"context"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
...@@ -65,6 +66,11 @@ type staticUpstream struct { ...@@ -65,6 +66,11 @@ type staticUpstream struct {
IgnoredSubPaths []string IgnoredSubPaths []string
insecureSkipVerify bool insecureSkipVerify bool
MaxFails int32 MaxFails int32
resolver srvResolver
}
type srvResolver interface {
LookupSRV(context.Context, string, string, string) (string, []*net.SRV, error)
} }
// NewStaticUpstreams parses the configuration input and sets up // NewStaticUpstreams parses the configuration input and sets up
...@@ -86,6 +92,7 @@ func NewStaticUpstreams(c caddyfile.Dispenser, host string) ([]Upstream, error) ...@@ -86,6 +92,7 @@ func NewStaticUpstreams(c caddyfile.Dispenser, host string) ([]Upstream, error)
TryInterval: 250 * time.Millisecond, TryInterval: 250 * time.Millisecond,
MaxConns: 0, MaxConns: 0,
KeepAlive: http.DefaultMaxIdleConnsPerHost, KeepAlive: http.DefaultMaxIdleConnsPerHost,
resolver: net.DefaultResolver,
} }
if !c.Args(&upstream.from) { if !c.Args(&upstream.from) {
...@@ -93,7 +100,21 @@ func NewStaticUpstreams(c caddyfile.Dispenser, host string) ([]Upstream, error) ...@@ -93,7 +100,21 @@ func NewStaticUpstreams(c caddyfile.Dispenser, host string) ([]Upstream, error)
} }
var to []string var to []string
hasSrv := false
for _, t := range c.RemainingArgs() { for _, t := range c.RemainingArgs() {
if len(to) > 0 && hasSrv {
return upstreams, c.Err("only one upstream is supported when using SRV locator")
}
if strings.HasPrefix(t, "srv://") || strings.HasPrefix(t, "srv+https://") {
if len(to) > 0 {
return upstreams, c.Err("service locator upstreams can not be mixed with host names")
}
hasSrv = true
}
parsed, err := parseUpstream(t) parsed, err := parseUpstream(t)
if err != nil { if err != nil {
return upstreams, err return upstreams, err
...@@ -107,13 +128,18 @@ func NewStaticUpstreams(c caddyfile.Dispenser, host string) ([]Upstream, error) ...@@ -107,13 +128,18 @@ func NewStaticUpstreams(c caddyfile.Dispenser, host string) ([]Upstream, error)
if !c.NextArg() { if !c.NextArg() {
return upstreams, c.ArgErr() return upstreams, c.ArgErr()
} }
if hasSrv {
return upstreams, c.Err("upstream directive is not supported when backend is service locator")
}
parsed, err := parseUpstream(c.Val()) parsed, err := parseUpstream(c.Val())
if err != nil { if err != nil {
return upstreams, err return upstreams, err
} }
to = append(to, parsed...) to = append(to, parsed...)
default: default:
if err := parseBlock(&c, upstream); err != nil { if err := parseBlock(&c, upstream, hasSrv); err != nil {
return upstreams, err return upstreams, err
} }
} }
...@@ -165,7 +191,9 @@ func (u *staticUpstream) From() string { ...@@ -165,7 +191,9 @@ func (u *staticUpstream) From() string {
func (u *staticUpstream) NewHost(host string) (*UpstreamHost, error) { func (u *staticUpstream) NewHost(host string) (*UpstreamHost, error) {
if !strings.HasPrefix(host, "http") && if !strings.HasPrefix(host, "http") &&
!strings.HasPrefix(host, "unix:") && !strings.HasPrefix(host, "unix:") &&
!strings.HasPrefix(host, "quic:") { !strings.HasPrefix(host, "quic:") &&
!strings.HasPrefix(host, "srv://") &&
!strings.HasPrefix(host, "srv+https://") {
host = "http://" + host host = "http://" + host
} }
uh := &UpstreamHost{ uh := &UpstreamHost{
...@@ -189,6 +217,7 @@ func (u *staticUpstream) NewHost(host string) (*UpstreamHost, error) { ...@@ -189,6 +217,7 @@ func (u *staticUpstream) NewHost(host string) (*UpstreamHost, error) {
}(u), }(u),
WithoutPathPrefix: u.WithoutPathPrefix, WithoutPathPrefix: u.WithoutPathPrefix,
MaxConns: u.MaxConns, MaxConns: u.MaxConns,
HealthCheckResult: atomic.Value{},
} }
baseURL, err := url.Parse(uh.Name) baseURL, err := url.Parse(uh.Name)
...@@ -205,11 +234,22 @@ func (u *staticUpstream) NewHost(host string) (*UpstreamHost, error) { ...@@ -205,11 +234,22 @@ func (u *staticUpstream) NewHost(host string) (*UpstreamHost, error) {
} }
func parseUpstream(u string) ([]string, error) { func parseUpstream(u string) ([]string, error) {
if !strings.HasPrefix(u, "unix:") { if strings.HasPrefix(u, "unix:") {
return []string{u}, nil
}
isSrv := strings.HasPrefix(u, "srv://") || strings.HasPrefix(u, "srv+https://")
colonIdx := strings.LastIndex(u, ":") colonIdx := strings.LastIndex(u, ":")
protoIdx := strings.Index(u, "://") protoIdx := strings.Index(u, "://")
if colonIdx != -1 && colonIdx != protoIdx { if colonIdx == -1 || colonIdx == protoIdx {
return []string{u}, nil
}
if isSrv {
return nil, fmt.Errorf("service locator %s can not have port specified", u)
}
us := u[:colonIdx] us := u[:colonIdx]
ue := "" ue := ""
portsEnd := len(u) portsEnd := len(u)
...@@ -217,9 +257,18 @@ func parseUpstream(u string) ([]string, error) { ...@@ -217,9 +257,18 @@ func parseUpstream(u string) ([]string, error) {
portsEnd = colonIdx + nextSlash portsEnd = colonIdx + nextSlash
ue = u[portsEnd:] ue = u[portsEnd:]
} }
ports := u[len(us)+1 : portsEnd] ports := u[len(us)+1 : portsEnd]
separators := strings.Count(ports, "-")
if separators == 0 {
return []string{u}, nil
}
if separators > 1 {
return nil, fmt.Errorf("port range [%s] has %d separators", ports, separators)
}
if separators := strings.Count(ports, "-"); separators == 1 {
portsStr := strings.Split(ports, "-") portsStr := strings.Split(ports, "-")
pIni, err := strconv.Atoi(portsStr[0]) pIni, err := strconv.Atoi(portsStr[0])
if err != nil { if err != nil {
...@@ -239,16 +288,11 @@ func parseUpstream(u string) ([]string, error) { ...@@ -239,16 +288,11 @@ func parseUpstream(u string) ([]string, error) {
for p := pIni; p <= pEnd; p++ { for p := pIni; p <= pEnd; p++ {
hosts = append(hosts, fmt.Sprintf("%s:%d%s", us, p, ue)) hosts = append(hosts, fmt.Sprintf("%s:%d%s", us, p, ue))
} }
return hosts, nil
}
}
}
return []string{u}, nil
return hosts, nil
} }
func parseBlock(c *caddyfile.Dispenser, u *staticUpstream) error { func parseBlock(c *caddyfile.Dispenser, u *staticUpstream, hasSrv bool) error {
switch c.Val() { switch c.Val() {
case "policy": case "policy":
if !c.NextArg() { if !c.NextArg() {
...@@ -348,6 +392,11 @@ func parseBlock(c *caddyfile.Dispenser, u *staticUpstream) error { ...@@ -348,6 +392,11 @@ func parseBlock(c *caddyfile.Dispenser, u *staticUpstream) error {
if !c.NextArg() { if !c.NextArg() {
return c.ArgErr() return c.ArgErr()
} }
if hasSrv {
return c.Err("health_check_port directive is not allowed when upstream is SRV locator")
}
port := c.Val() port := c.Val()
n, err := strconv.Atoi(port) n, err := strconv.Atoi(port)
if err != nil { if err != nil {
...@@ -420,11 +469,43 @@ func parseBlock(c *caddyfile.Dispenser, u *staticUpstream) error { ...@@ -420,11 +469,43 @@ func parseBlock(c *caddyfile.Dispenser, u *staticUpstream) error {
return nil return nil
} }
func (u *staticUpstream) resolveHost(h string) ([]string, bool, error) {
names := []string{}
proto := "http"
if !strings.HasPrefix(h, "srv://") && !strings.HasPrefix(h, "srv+https://") {
return []string{h}, false, nil
}
if strings.HasPrefix(h, "srv+https://") {
proto = "https"
}
_, addrs, err := u.resolver.LookupSRV(context.Background(), "", "", h)
if err != nil {
return names, true, err
}
for _, addr := range addrs {
names = append(names, fmt.Sprintf("%s://%s:%d", proto, addr.Target, addr.Port))
}
return names, true, nil
}
func (u *staticUpstream) healthCheck() { func (u *staticUpstream) healthCheck() {
for _, host := range u.Hosts { for _, host := range u.Hosts {
hostURL := host.Name candidates, isSrv, err := u.resolveHost(host.Name)
if u.HealthCheck.Port != "" { if err != nil {
hostURL = replacePort(host.Name, u.HealthCheck.Port) host.HealthCheckResult.Store(err.Error())
atomic.StoreInt32(&host.Unhealthy, 1)
continue
}
unhealthyCount := 0
for _, addr := range candidates {
hostURL := addr
if !isSrv && u.HealthCheck.Port != "" {
hostURL = replacePort(hostURL, u.HealthCheck.Port)
} }
hostURL += u.HealthCheck.Path hostURL += u.HealthCheck.Path
...@@ -464,10 +545,18 @@ func (u *staticUpstream) healthCheck() { ...@@ -464,10 +545,18 @@ func (u *staticUpstream) healthCheck() {
} }
return true return true
}() }()
if unhealthy { if unhealthy {
unhealthyCount++
}
}
if unhealthyCount == len(candidates) {
atomic.StoreInt32(&host.Unhealthy, 1) atomic.StoreInt32(&host.Unhealthy, 1)
host.HealthCheckResult.Store("Failed")
} else { } else {
atomic.StoreInt32(&host.Unhealthy, 0) atomic.StoreInt32(&host.Unhealthy, 0)
host.HealthCheckResult.Store("OK")
} }
} }
} }
......
...@@ -15,10 +15,15 @@ ...@@ -15,10 +15,15 @@
package proxy package proxy
import ( import (
"context"
"errors"
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url"
"reflect"
"strconv"
"strings" "strings"
"sync/atomic" "sync/atomic"
"testing" "testing"
...@@ -187,7 +192,7 @@ func TestParseBlockHealthCheck(t *testing.T) { ...@@ -187,7 +192,7 @@ func TestParseBlockHealthCheck(t *testing.T) {
u := staticUpstream{} u := staticUpstream{}
c := caddyfile.NewDispenser("Testfile", strings.NewReader(test.config)) c := caddyfile.NewDispenser("Testfile", strings.NewReader(test.config))
for c.Next() { for c.Next() {
parseBlock(&c, &u) parseBlock(&c, &u, false)
} }
if u.HealthCheck.Interval.String() != test.interval { if u.HealthCheck.Interval.String() != test.interval {
t.Errorf( t.Errorf(
...@@ -551,3 +556,216 @@ func TestQuicHost(t *testing.T) { ...@@ -551,3 +556,216 @@ func TestQuicHost(t *testing.T) {
} }
} }
} }
func TestParseSRVBlock(t *testing.T) {
tests := []struct {
config string
shouldErr bool
}{
{"proxy / srv://bogus.service", false},
{"proxy / srv://bogus.service:80", true},
{"proxy / srv://bogus.service srv://bogus.service.fallback", true},
{"proxy / srv://bogus.service http://bogus.service.fallback", true},
{"proxy / http://bogus.service srv://bogus.service.fallback", true},
{"proxy / srv://bogus.service bogus.service.fallback", true},
{`proxy / srv://bogus.service {
upstream srv://bogus.service
}`, true},
{"proxy / srv+https://bogus.service", false},
{"proxy / srv+https://bogus.service:80", true},
{"proxy / srv+https://bogus.service srv://bogus.service.fallback", true},
{"proxy / srv+https://bogus.service http://bogus.service.fallback", true},
{"proxy / http://bogus.service srv+https://bogus.service.fallback", true},
{"proxy / srv+https://bogus.service bogus.service.fallback", true},
{`proxy / srv+https://bogus.service {
upstream srv://bogus.service
}`, true},
{`proxy / srv+https://bogus.service {
health_check_port 96
}`, true},
}
for i, test := range tests {
_, err := NewStaticUpstreams(caddyfile.NewDispenser("Testfile", strings.NewReader(test.config)), "")
if err == nil && test.shouldErr {
t.Errorf("Case %d - Expected an error. got nothing", i)
}
if err != nil && !test.shouldErr {
t.Errorf("Case %d - Expected no error. got %s", i, err.Error())
}
}
}
type testResolver struct {
errOn string
result []*net.SRV
}
func (r testResolver) LookupSRV(ctx context.Context, _, _, service string) (string, []*net.SRV, error) {
if service == r.errOn {
return "", nil, errors.New("an error occurred")
}
return "", r.result, nil
}
func TestResolveHost(t *testing.T) {
upstream := &staticUpstream{
resolver: testResolver{
errOn: "srv://problematic.service.name",
result: []*net.SRV{
{Target: "target-1.fqdn", Port: 85, Priority: 1, Weight: 1},
{Target: "target-2.fqdn", Port: 33, Priority: 1, Weight: 1},
{Target: "target-3.fqdn", Port: 94, Priority: 1, Weight: 1},
},
},
}
tests := []struct {
host string
expect []string
isSrv bool
shouldErr bool
}{
// Static DNS records
{"http://subdomain.domain.service",
[]string{"http://subdomain.domain.service"},
false,
false},
{"https://subdomain.domain.service",
[]string{"https://subdomain.domain.service"},
false,
false},
{"http://subdomain.domain.service:76",
[]string{"http://subdomain.domain.service:76"},
false,
false},
{"https://subdomain.domain.service:65",
[]string{"https://subdomain.domain.service:65"},
false,
false},
// SRV lookups
{"srv://service.name", []string{
"http://target-1.fqdn:85",
"http://target-2.fqdn:33",
"http://target-3.fqdn:94",
}, true, false},
{"srv+https://service.name", []string{
"https://target-1.fqdn:85",
"https://target-2.fqdn:33",
"https://target-3.fqdn:94",
}, true, false},
{"srv://problematic.service.name", []string{}, true, true},
}
for i, test := range tests {
results, isSrv, err := upstream.resolveHost(test.host)
if err == nil && test.shouldErr {
t.Errorf("Test %d - expected an error, got none", i)
}
if err != nil && !test.shouldErr {
t.Errorf("Test %d - unexpected error %s", i, err.Error())
}
if test.isSrv && !isSrv {
t.Errorf("Test %d - expecting resolution to be SRV lookup but it isn't", i)
}
if isSrv && !test.isSrv {
t.Errorf("Test %d - expecting resolution to be normal lookup, got SRV", i)
}
if !reflect.DeepEqual(results, test.expect) {
t.Errorf("Test %d - resolution result %#v does not match expected value %#v", i, results, test.expect)
}
}
}
func TestSRVHealthCheck(t *testing.T) {
serverURL, err := url.Parse(workableServer.URL)
if err != nil {
t.Errorf("Failed to parse test server URL: %s", err.Error())
}
pp, err := strconv.Atoi(serverURL.Port())
if err != nil {
t.Errorf("Failed to parse test server port [%s]: %s", serverURL.Port(), err.Error())
}
port := uint16(pp)
allGoodResolver := testResolver{
result: []*net.SRV{
{Target: serverURL.Hostname(), Port: port, Priority: 1, Weight: 1},
},
}
partialFailureResolver := testResolver{
result: []*net.SRV{
{Target: serverURL.Hostname(), Port: port, Priority: 1, Weight: 1},
{Target: "target-2.fqdn", Port: 33, Priority: 1, Weight: 1},
{Target: "target-3.fqdn", Port: 94, Priority: 1, Weight: 1},
},
}
fullFailureResolver := testResolver{
result: []*net.SRV{
{Target: "target-1.fqdn", Port: 876, Priority: 1, Weight: 1},
{Target: "target-2.fqdn", Port: 33, Priority: 1, Weight: 1},
{Target: "target-3.fqdn", Port: 94, Priority: 1, Weight: 1},
},
}
resolutionErrorResolver := testResolver{
errOn: "srv://tag.service.consul",
result: []*net.SRV{},
}
upstream := &staticUpstream{
Hosts: []*UpstreamHost{
{Name: "srv://tag.service.consul"},
},
FailTimeout: 10 * time.Second,
MaxFails: 1,
}
tests := []struct {
resolver testResolver
shouldFail bool
shouldErr bool
}{
{allGoodResolver, false, false},
{partialFailureResolver, false, false},
{fullFailureResolver, true, false},
{resolutionErrorResolver, true, true},
}
for i, test := range tests {
upstream.resolver = test.resolver
upstream.healthCheck()
if upstream.Hosts[0].Down() && !test.shouldFail {
t.Errorf("Test %d - expected all healthchecks to pass, all failing", i)
}
if test.shouldFail && !upstream.Hosts[0].Down() {
t.Errorf("Test %d - expected all healthchecks to fail, all passing", i)
}
status := fmt.Sprintf("%s", upstream.Hosts[0].HealthCheckResult.Load())
if test.shouldFail && !test.shouldErr && status != "Failed" {
t.Errorf("Test %d - Expected health check result to be 'Failed', got '%s'", i, status)
}
if !test.shouldFail && status != "OK" {
t.Errorf("Test %d - Expected health check result to be 'OK', got '%s'", i, status)
}
if test.shouldErr && status != "an error occurred" {
t.Errorf("Test %d - Expected health check result to be 'an error occured', got '%s'", i, status)
}
}
}
...@@ -39,6 +39,7 @@ type ACMEClient struct { ...@@ -39,6 +39,7 @@ type ACMEClient struct {
AllowPrompts bool AllowPrompts bool
config *Config config *Config
acmeClient *acme.Client acmeClient *acme.Client
locker Locker
} }
// newACMEClient creates a new ACMEClient given an email and whether // newACMEClient creates a new ACMEClient given an email and whether
...@@ -120,6 +121,10 @@ var newACMEClient = func(config *Config, allowPrompts bool) (*ACMEClient, error) ...@@ -120,6 +121,10 @@ var newACMEClient = func(config *Config, allowPrompts bool) (*ACMEClient, error)
AllowPrompts: allowPrompts, AllowPrompts: allowPrompts,
config: config, config: config,
acmeClient: client, acmeClient: client,
locker: &syncLock{
nameLocks: make(map[string]*sync.WaitGroup),
nameLocksMu: sync.Mutex{},
},
} }
if config.DNSProvider == "" { if config.DNSProvider == "" {
...@@ -210,7 +215,7 @@ func (c *ACMEClient) Obtain(name string) error { ...@@ -210,7 +215,7 @@ func (c *ACMEClient) Obtain(name string) error {
return err return err
} }
waiter, err := storage.TryLock(name) waiter, err := c.locker.TryLock(name)
if err != nil { if err != nil {
return err return err
} }
...@@ -220,7 +225,7 @@ func (c *ACMEClient) Obtain(name string) error { ...@@ -220,7 +225,7 @@ func (c *ACMEClient) Obtain(name string) error {
return nil // we assume the process with the lock succeeded, rather than hammering this execution path again return nil // we assume the process with the lock succeeded, rather than hammering this execution path again
} }
defer func() { defer func() {
if err := storage.Unlock(name); err != nil { if err := c.locker.Unlock(name); err != nil {
log.Printf("[ERROR] Unable to unlock obtain call for %s: %v", name, err) log.Printf("[ERROR] Unable to unlock obtain call for %s: %v", name, err)
} }
}() }()
...@@ -286,7 +291,7 @@ func (c *ACMEClient) Renew(name string) error { ...@@ -286,7 +291,7 @@ func (c *ACMEClient) Renew(name string) error {
return err return err
} }
waiter, err := storage.TryLock(name) waiter, err := c.locker.TryLock(name)
if err != nil { if err != nil {
return err return err
} }
...@@ -296,7 +301,7 @@ func (c *ACMEClient) Renew(name string) error { ...@@ -296,7 +301,7 @@ func (c *ACMEClient) Renew(name string) error {
return nil // we assume the process with the lock succeeded, rather than hammering this execution path again return nil // we assume the process with the lock succeeded, rather than hammering this execution path again
} }
defer func() { defer func() {
if err := storage.Unlock(name); err != nil { if err := c.locker.Unlock(name); err != nil {
log.Printf("[ERROR] Unable to unlock renew call for %s: %v", name, err) log.Printf("[ERROR] Unable to unlock renew call for %s: %v", name, err)
} }
}() }()
......
...@@ -22,7 +22,6 @@ import ( ...@@ -22,7 +22,6 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
"sync"
"github.com/mholt/caddy" "github.com/mholt/caddy"
) )
...@@ -41,7 +40,6 @@ var storageBasePath = filepath.Join(caddy.AssetsPath(), "acme") ...@@ -41,7 +40,6 @@ var storageBasePath = filepath.Join(caddy.AssetsPath(), "acme")
func NewFileStorage(caURL *url.URL) (Storage, error) { func NewFileStorage(caURL *url.URL) (Storage, error) {
return &FileStorage{ return &FileStorage{
Path: filepath.Join(storageBasePath, caURL.Host), Path: filepath.Join(storageBasePath, caURL.Host),
nameLocks: make(map[string]*sync.WaitGroup),
}, nil }, nil
} }
...@@ -50,8 +48,6 @@ func NewFileStorage(caURL *url.URL) (Storage, error) { ...@@ -50,8 +48,6 @@ func NewFileStorage(caURL *url.URL) (Storage, error) {
// cross-platform way or persisting ACME assets on the file system. // cross-platform way or persisting ACME assets on the file system.
type FileStorage struct { type FileStorage struct {
Path string Path string
nameLocks map[string]*sync.WaitGroup
nameLocksMu sync.Mutex
} }
// sites gets the directory that stores site certificate and keys. // sites gets the directory that stores site certificate and keys.
...@@ -254,36 +250,6 @@ func (s *FileStorage) StoreUser(email string, data *UserData) error { ...@@ -254,36 +250,6 @@ func (s *FileStorage) StoreUser(email string, data *UserData) error {
return nil return nil
} }
// TryLock attempts to get a lock for name, otherwise it returns
// a Waiter value to wait until the other process is finished.
func (s *FileStorage) TryLock(name string) (Waiter, error) {
s.nameLocksMu.Lock()
defer s.nameLocksMu.Unlock()
wg, ok := s.nameLocks[name]
if ok {
// lock already obtained, let caller wait on it
return wg, nil
}
// caller gets lock
wg = new(sync.WaitGroup)
wg.Add(1)
s.nameLocks[name] = wg
return nil, nil
}
// Unlock unlocks name.
func (s *FileStorage) Unlock(name string) error {
s.nameLocksMu.Lock()
defer s.nameLocksMu.Unlock()
wg, ok := s.nameLocks[name]
if !ok {
return fmt.Errorf("FileStorage: no lock to release for %s", name)
}
wg.Done()
delete(s.nameLocks, name)
return nil
}
// MostRecentUserEmail implements Storage.MostRecentUserEmail by finding the // MostRecentUserEmail implements Storage.MostRecentUserEmail by finding the
// most recently written sub directory in the users' directory. It is named // most recently written sub directory in the users' directory. It is named
// after the email address. This corresponds to the most recent call to // after the email address. This corresponds to the most recent call to
......
...@@ -39,24 +39,9 @@ type UserData struct { ...@@ -39,24 +39,9 @@ type UserData struct {
Key []byte Key []byte
} }
// Storage is an interface abstracting all storage used by Caddy's TLS // Locker provides support for mutual exclusion
// subsystem. Implementations of this interface store both site and type Locker interface {
// user data. // TryLock will return immediatedly with or without acquiring the lock.
type Storage interface {
// SiteExists returns true if this site exists in storage.
// Site data is considered present when StoreSite has been called
// successfully (without DeleteSite having been called, of course).
SiteExists(domain string) (bool, error)
// TryLock is called before Caddy attempts to obtain or renew a
// certificate for a certain name and store it. From the perspective
// of this method and its companion Unlock, the actions of
// obtaining/renewing and then storing the certificate are atomic,
// and both should occur within a lock. This prevents multiple
// processes -- maybe distributed ones -- from stepping on each
// other's space in the same shared storage, and from spamming
// certificate providers with multiple, redundant requests.
//
// If a lock could be obtained, (nil, nil) is returned and you may // If a lock could be obtained, (nil, nil) is returned and you may
// continue normally. If not (meaning another process is already // continue normally. If not (meaning another process is already
// working on that name), a Waiter value will be returned upon // working on that name), a Waiter value will be returned upon
...@@ -75,6 +60,16 @@ type Storage interface { ...@@ -75,6 +60,16 @@ type Storage interface {
// the obtain/renew and store are finished, even if there was // the obtain/renew and store are finished, even if there was
// an error (or a timeout). // an error (or a timeout).
Unlock(name string) error Unlock(name string) error
}
// Storage is an interface abstracting all storage used by Caddy's TLS
// subsystem. Implementations of this interface store both site and
// user data.
type Storage interface {
// SiteExists returns true if this site exists in storage.
// Site data is considered present when StoreSite has been called
// successfully (without DeleteSite having been called, of course).
SiteExists(domain string) (bool, error)
// LoadSite obtains the site data from storage for the given domain and // LoadSite obtains the site data from storage for the given domain and
// returns it. If data for the domain does not exist, an error value // returns it. If data for the domain does not exist, an error value
......
// Copyright 2015 Light Code Labs, LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package caddytls
import (
"fmt"
"sync"
)
var _ Locker = &syncLock{}
type syncLock struct {
nameLocks map[string]*sync.WaitGroup
nameLocksMu sync.Mutex
}
// TryLock attempts to get a lock for name, otherwise it returns
// a Waiter value to wait until the other process is finished.
func (s *syncLock) TryLock(name string) (Waiter, error) {
s.nameLocksMu.Lock()
defer s.nameLocksMu.Unlock()
wg, ok := s.nameLocks[name]
if ok {
// lock already obtained, let caller wait on it
return wg, nil
}
// caller gets lock
wg = new(sync.WaitGroup)
wg.Add(1)
s.nameLocks[name] = wg
return nil, nil
}
// Unlock unlocks name.
func (s *syncLock) Unlock(name string) error {
s.nameLocksMu.Lock()
defer s.nameLocksMu.Unlock()
wg, ok := s.nameLocks[name]
if !ok {
return fmt.Errorf("FileStorage: no lock to release for %s", name)
}
wg.Done()
delete(s.nameLocks, name)
return nil
}
...@@ -16,7 +16,6 @@ package caddytls ...@@ -16,7 +16,6 @@ package caddytls
import ( import (
"os" "os"
"sync"
"testing" "testing"
"github.com/xenolf/lego/acme" "github.com/xenolf/lego/acme"
...@@ -94,7 +93,7 @@ func TestQualifiesForManagedTLS(t *testing.T) { ...@@ -94,7 +93,7 @@ func TestQualifiesForManagedTLS(t *testing.T) {
} }
func TestSaveCertResource(t *testing.T) { func TestSaveCertResource(t *testing.T) {
storage := &FileStorage{Path: "./le_test_save", nameLocks: make(map[string]*sync.WaitGroup)} storage := &FileStorage{Path: "./le_test_save"}
defer func() { defer func() {
err := os.RemoveAll(storage.Path) err := os.RemoveAll(storage.Path)
if err != nil { if err != nil {
...@@ -140,7 +139,7 @@ func TestSaveCertResource(t *testing.T) { ...@@ -140,7 +139,7 @@ func TestSaveCertResource(t *testing.T) {
} }
func TestExistingCertAndKey(t *testing.T) { func TestExistingCertAndKey(t *testing.T) {
storage := &FileStorage{Path: "./le_test_existing", nameLocks: make(map[string]*sync.WaitGroup)} storage := &FileStorage{Path: "./le_test_existing"}
defer func() { defer func() {
err := os.RemoveAll(storage.Path) err := os.RemoveAll(storage.Path)
if err != nil { if err != nil {
......
...@@ -21,7 +21,6 @@ import ( ...@@ -21,7 +21,6 @@ import (
"crypto/rand" "crypto/rand"
"io" "io"
"strings" "strings"
"sync"
"testing" "testing"
"time" "time"
...@@ -196,7 +195,7 @@ func TestGetEmail(t *testing.T) { ...@@ -196,7 +195,7 @@ func TestGetEmail(t *testing.T) {
} }
} }
var testStorage = &FileStorage{Path: "./testdata", nameLocks: make(map[string]*sync.WaitGroup)} var testStorage = &FileStorage{Path: "./testdata"}
func (s *FileStorage) clean() error { func (s *FileStorage) clean() error {
return os.RemoveAll(s.Path) return os.RemoveAll(s.Path)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment