Commit 986d4ffe authored by Matt Holt's avatar Matt Holt Committed by GitHub

Merge pull request #2015 from mholt/cert-cache

tls: Restructure and improve certificate management
parents 6f4cf7ee a03eba6f
...@@ -77,8 +77,18 @@ var ( ...@@ -77,8 +77,18 @@ var (
mu sync.Mutex mu sync.Mutex
) )
func init() {
OnProcessExit = append(OnProcessExit, func() {
if PidFile != "" {
os.Remove(PidFile)
}
})
}
// Instance contains the state of servers created as a result of // Instance contains the state of servers created as a result of
// calling Start and can be used to access or control those servers. // calling Start and can be used to access or control those servers.
// It is literally an instance of a server type. Instance values
// should NOT be copied. Use *Instance for safety.
type Instance struct { type Instance struct {
// serverType is the name of the instance's server type // serverType is the name of the instance's server type
serverType string serverType string
...@@ -89,10 +99,11 @@ type Instance struct { ...@@ -89,10 +99,11 @@ type Instance struct {
// wg is used to wait for all servers to shut down // wg is used to wait for all servers to shut down
wg *sync.WaitGroup wg *sync.WaitGroup
// context is the context created for this instance. // context is the context created for this instance,
// used to coordinate the setting up of the server type
context Context context Context
// servers is the list of servers with their listeners. // servers is the list of servers with their listeners
servers []ServerListener servers []ServerListener
// these callbacks execute when certain events occur // these callbacks execute when certain events occur
...@@ -101,6 +112,18 @@ type Instance struct { ...@@ -101,6 +112,18 @@ type Instance struct {
onRestart []func() error // before restart commences onRestart []func() error // before restart commences
onShutdown []func() error // stopping, even as part of a restart onShutdown []func() error // stopping, even as part of a restart
onFinalShutdown []func() error // stopping, not as part of a restart onFinalShutdown []func() error // stopping, not as part of a restart
// storing values on an instance is preferable to
// global state because these will get garbage-
// collected after in-process reloads when the
// old instances are destroyed; use StorageMu
// to access this value safely
Storage map[interface{}]interface{}
StorageMu sync.RWMutex
}
func Instances() []*Instance {
return instances
} }
// Servers returns the ServerListeners in i. // Servers returns the ServerListeners in i.
...@@ -196,7 +219,7 @@ func (i *Instance) Restart(newCaddyfile Input) (*Instance, error) { ...@@ -196,7 +219,7 @@ func (i *Instance) Restart(newCaddyfile Input) (*Instance, error) {
} }
// create new instance; if the restart fails, it is simply discarded // create new instance; if the restart fails, it is simply discarded
newInst := &Instance{serverType: newCaddyfile.ServerType(), wg: i.wg} newInst := &Instance{serverType: newCaddyfile.ServerType(), wg: i.wg, Storage: make(map[interface{}]interface{})}
// attempt to start new instance // attempt to start new instance
err := startWithListenerFds(newCaddyfile, newInst, restartFds) err := startWithListenerFds(newCaddyfile, newInst, restartFds)
...@@ -455,7 +478,7 @@ func (i *Instance) Caddyfile() Input { ...@@ -455,7 +478,7 @@ func (i *Instance) Caddyfile() Input {
// //
// This function blocks until all the servers are listening. // This function blocks until all the servers are listening.
func Start(cdyfile Input) (*Instance, error) { func Start(cdyfile Input) (*Instance, error) {
inst := &Instance{serverType: cdyfile.ServerType(), wg: new(sync.WaitGroup)} inst := &Instance{serverType: cdyfile.ServerType(), wg: new(sync.WaitGroup), Storage: make(map[interface{}]interface{})}
err := startWithListenerFds(cdyfile, inst, nil) err := startWithListenerFds(cdyfile, inst, nil)
if err != nil { if err != nil {
return inst, err return inst, err
...@@ -468,11 +491,34 @@ func Start(cdyfile Input) (*Instance, error) { ...@@ -468,11 +491,34 @@ func Start(cdyfile Input) (*Instance, error) {
} }
func startWithListenerFds(cdyfile Input, inst *Instance, restartFds map[string]restartTriple) error { func startWithListenerFds(cdyfile Input, inst *Instance, restartFds map[string]restartTriple) error {
// save this instance in the list now so that
// plugins can access it if need be, for example
// the caddytls package, so it can perform cert
// renewals while starting up; we just have to
// remove the instance from the list later if
// it fails
instancesMu.Lock()
instances = append(instances, inst)
instancesMu.Unlock()
var err error
defer func() {
if err != nil {
instancesMu.Lock()
for i, otherInst := range instances {
if otherInst == inst {
instances = append(instances[:i], instances[i+1:]...)
break
}
}
instancesMu.Unlock()
}
}()
if cdyfile == nil { if cdyfile == nil {
cdyfile = CaddyfileInput{} cdyfile = CaddyfileInput{}
} }
err := ValidateAndExecuteDirectives(cdyfile, inst, false) err = ValidateAndExecuteDirectives(cdyfile, inst, false)
if err != nil { if err != nil {
return err return err
} }
...@@ -504,10 +550,6 @@ func startWithListenerFds(cdyfile Input, inst *Instance, restartFds map[string]r ...@@ -504,10 +550,6 @@ func startWithListenerFds(cdyfile Input, inst *Instance, restartFds map[string]r
return err return err
} }
instancesMu.Lock()
instances = append(instances, inst)
instancesMu.Unlock()
// run any AfterStartup callbacks if this is not // run any AfterStartup callbacks if this is not
// part of a restart; then show file descriptor notice // part of a restart; then show file descriptor notice
if restartFds == nil { if restartFds == nil {
...@@ -546,7 +588,7 @@ func startWithListenerFds(cdyfile Input, inst *Instance, restartFds map[string]r ...@@ -546,7 +588,7 @@ func startWithListenerFds(cdyfile Input, inst *Instance, restartFds map[string]r
func ValidateAndExecuteDirectives(cdyfile Input, inst *Instance, justValidate bool) error { func ValidateAndExecuteDirectives(cdyfile Input, inst *Instance, justValidate bool) error {
// If parsing only inst will be nil, create an instance for this function call only. // If parsing only inst will be nil, create an instance for this function call only.
if justValidate { if justValidate {
inst = &Instance{serverType: cdyfile.ServerType(), wg: new(sync.WaitGroup)} inst = &Instance{serverType: cdyfile.ServerType(), wg: new(sync.WaitGroup), Storage: make(map[interface{}]interface{})}
} }
stypeName := cdyfile.ServerType() stypeName := cdyfile.ServerType()
...@@ -563,14 +605,14 @@ func ValidateAndExecuteDirectives(cdyfile Input, inst *Instance, justValidate bo ...@@ -563,14 +605,14 @@ func ValidateAndExecuteDirectives(cdyfile Input, inst *Instance, justValidate bo
return err return err
} }
inst.context = stype.NewContext() inst.context = stype.NewContext(inst)
if inst.context == nil { if inst.context == nil {
return fmt.Errorf("server type %s produced a nil Context", stypeName) return fmt.Errorf("server type %s produced a nil Context", stypeName)
} }
sblocks, err = inst.context.InspectServerBlocks(cdyfile.Path(), sblocks) sblocks, err = inst.context.InspectServerBlocks(cdyfile.Path(), sblocks)
if err != nil { if err != nil {
return err return fmt.Errorf("error inspecting server blocks: %v", err)
} }
return executeDirectives(inst, cdyfile.Path(), stype.Directives(), sblocks, justValidate) return executeDirectives(inst, cdyfile.Path(), stype.Directives(), sblocks, justValidate)
......
...@@ -27,7 +27,7 @@ func activateHTTPS(cctx caddy.Context) error { ...@@ -27,7 +27,7 @@ func activateHTTPS(cctx caddy.Context) error {
operatorPresent := !caddy.Started() operatorPresent := !caddy.Started()
if !caddy.Quiet && operatorPresent { if !caddy.Quiet && operatorPresent {
fmt.Print("Activating privacy features...") fmt.Print("Activating privacy features... ")
} }
ctx := cctx.(*httpContext) ctx := cctx.(*httpContext)
...@@ -69,7 +69,7 @@ func activateHTTPS(cctx caddy.Context) error { ...@@ -69,7 +69,7 @@ func activateHTTPS(cctx caddy.Context) error {
} }
if !caddy.Quiet && operatorPresent { if !caddy.Quiet && operatorPresent {
fmt.Println(" done.") fmt.Println("done.")
} }
return nil return nil
...@@ -160,23 +160,37 @@ func hostHasOtherPort(allConfigs []*SiteConfig, thisConfigIdx int, otherPort str ...@@ -160,23 +160,37 @@ func hostHasOtherPort(allConfigs []*SiteConfig, thisConfigIdx int, otherPort str
// to listen on HTTPPort. The TLS field of cfg must not be nil. // to listen on HTTPPort. The TLS field of cfg must not be nil.
func redirPlaintextHost(cfg *SiteConfig) *SiteConfig { func redirPlaintextHost(cfg *SiteConfig) *SiteConfig {
redirPort := cfg.Addr.Port redirPort := cfg.Addr.Port
if redirPort == DefaultHTTPSPort { if redirPort == HTTPSPort {
redirPort = "" // default port is redundant // By default, HTTPSPort should be DefaultHTTPSPort,
// which of course doesn't need to be explicitly stated
// in the Location header. Even if HTTPSPort is changed
// so that it is no longer DefaultHTTPSPort, we shouldn't
// append it to the URL in the Location because changing
// the HTTPS port is assumed to be an internal-only change
// (in other words, we assume port forwarding is going on);
// but redirects go back to a presumably-external client.
// (If redirect clients are also internal, that is more
// advanced, and the user should configure HTTP->HTTPS
// redirects themselves.)
redirPort = ""
} }
redirMiddleware := func(next Handler) Handler { redirMiddleware := func(next Handler) Handler {
return HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) { return HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) {
// Construct the URL to which to redirect. Note that the Host in a request might // Construct the URL to which to redirect. Note that the Host in a
// contain a port, but we just need the hostname; we'll set the port if needed. // request might contain a port, but we just need the hostname from
// it; and we'll set the port if needed.
toURL := "https://" toURL := "https://"
requestHost, _, err := net.SplitHostPort(r.Host) requestHost, _, err := net.SplitHostPort(r.Host)
if err != nil { if err != nil {
requestHost = r.Host // Host did not contain a port; great requestHost = r.Host // Host did not contain a port, so use the whole value
} }
if redirPort == "" { if redirPort == "" {
toURL += requestHost toURL += requestHost
} else { } else {
toURL += net.JoinHostPort(requestHost, redirPort) toURL += net.JoinHostPort(requestHost, redirPort)
} }
toURL += r.URL.RequestURI() toURL += r.URL.RequestURI()
w.Header().Set("Connection", "close") w.Header().Set("Connection", "close")
...@@ -184,9 +198,11 @@ func redirPlaintextHost(cfg *SiteConfig) *SiteConfig { ...@@ -184,9 +198,11 @@ func redirPlaintextHost(cfg *SiteConfig) *SiteConfig {
return 0, nil return 0, nil
}) })
} }
host := cfg.Addr.Host host := cfg.Addr.Host
port := HTTPPort port := HTTPPort
addr := net.JoinHostPort(host, port) addr := net.JoinHostPort(host, port)
return &SiteConfig{ return &SiteConfig{
Addr: Address{Original: addr, Host: host, Port: port}, Addr: Address{Original: addr, Host: host, Port: port},
ListenHost: cfg.ListenHost, ListenHost: cfg.ListenHost,
......
...@@ -53,7 +53,7 @@ func TestRedirPlaintextHost(t *testing.T) { ...@@ -53,7 +53,7 @@ func TestRedirPlaintextHost(t *testing.T) {
}, },
{ {
Host: "foohost", Host: "foohost",
Port: "443", // since this is the default HTTPS port, should not be included in Location value Port: HTTPSPort, // since this is the 'default' HTTPS port, should not be included in Location value
}, },
{ {
Host: "*.example.com", Host: "*.example.com",
......
...@@ -91,11 +91,13 @@ func hideCaddyfile(cctx caddy.Context) error { ...@@ -91,11 +91,13 @@ func hideCaddyfile(cctx caddy.Context) error {
return nil return nil
} }
func newContext() caddy.Context { func newContext(inst *caddy.Instance) caddy.Context {
return &httpContext{keysToSiteConfigs: make(map[string]*SiteConfig)} return &httpContext{instance: inst, keysToSiteConfigs: make(map[string]*SiteConfig)}
} }
type httpContext struct { type httpContext struct {
instance *caddy.Instance
// keysToSiteConfigs maps an address at the top of a // keysToSiteConfigs maps an address at the top of a
// server block (a "key") to its SiteConfig. Not all // server block (a "key") to its SiteConfig. Not all
// SiteConfigs will be represented here, only ones // SiteConfigs will be represented here, only ones
...@@ -115,12 +117,14 @@ func (h *httpContext) saveConfig(key string, cfg *SiteConfig) { ...@@ -115,12 +117,14 @@ func (h *httpContext) saveConfig(key string, cfg *SiteConfig) {
// executing directives and otherwise prepares the directives to // executing directives and otherwise prepares the directives to
// be parsed and executed. // be parsed and executed.
func (h *httpContext) InspectServerBlocks(sourceFile string, serverBlocks []caddyfile.ServerBlock) ([]caddyfile.ServerBlock, error) { func (h *httpContext) InspectServerBlocks(sourceFile string, serverBlocks []caddyfile.ServerBlock) ([]caddyfile.ServerBlock, error) {
siteAddrs := make(map[string]string)
// For each address in each server block, make a new config // For each address in each server block, make a new config
for _, sb := range serverBlocks { for _, sb := range serverBlocks {
for _, key := range sb.Keys { for _, key := range sb.Keys {
key = strings.ToLower(key) key = strings.ToLower(key)
if _, dup := h.keysToSiteConfigs[key]; dup { if _, dup := h.keysToSiteConfigs[key]; dup {
return serverBlocks, fmt.Errorf("duplicate site address: %s", key) return serverBlocks, fmt.Errorf("duplicate site key: %s", key)
} }
addr, err := standardizeAddress(key) addr, err := standardizeAddress(key)
if err != nil { if err != nil {
...@@ -136,6 +140,23 @@ func (h *httpContext) InspectServerBlocks(sourceFile string, serverBlocks []cadd ...@@ -136,6 +140,23 @@ func (h *httpContext) InspectServerBlocks(sourceFile string, serverBlocks []cadd
addr.Port = Port addr.Port = Port
} }
// Make sure the adjusted site address is distinct
addrCopy := addr // make copy so we don't disturb the original, carefully-parsed address struct
if addrCopy.Port == "" && Port == DefaultPort {
addrCopy.Port = Port
}
addrStr := strings.ToLower(addrCopy.String())
if otherSiteKey, dup := siteAddrs[addrStr]; dup {
err := fmt.Errorf("duplicate site address: %s", addrStr)
if (addrCopy.Host == Host && Host != DefaultHost) ||
(addrCopy.Port == Port && Port != DefaultPort) {
err = fmt.Errorf("site defined as %s is a duplicate of %s because of modified "+
"default host and/or port values (usually via -host or -port flags)", key, otherSiteKey)
}
return serverBlocks, err
}
siteAddrs[addrStr] = key
// If default HTTP or HTTPS ports have been customized, // If default HTTP or HTTPS ports have been customized,
// make sure the ACME challenge ports match // make sure the ACME challenge ports match
var altHTTPPort, altTLSSNIPort string var altHTTPPort, altTLSSNIPort string
...@@ -146,15 +167,19 @@ func (h *httpContext) InspectServerBlocks(sourceFile string, serverBlocks []cadd ...@@ -146,15 +167,19 @@ func (h *httpContext) InspectServerBlocks(sourceFile string, serverBlocks []cadd
altTLSSNIPort = HTTPSPort altTLSSNIPort = HTTPSPort
} }
// Make our caddytls.Config, which has a pointer to the
// instance's certificate cache and enough information
// to use automatic HTTPS when the time comes
caddytlsConfig := caddytls.NewConfig(h.instance)
caddytlsConfig.Hostname = addr.Host
caddytlsConfig.AltHTTPPort = altHTTPPort
caddytlsConfig.AltTLSSNIPort = altTLSSNIPort
// Save the config to our master list, and key it for lookups // Save the config to our master list, and key it for lookups
cfg := &SiteConfig{ cfg := &SiteConfig{
Addr: addr, Addr: addr,
Root: Root, Root: Root,
TLS: &caddytls.Config{ TLS: caddytlsConfig,
Hostname: addr.Host,
AltHTTPPort: altHTTPPort,
AltTLSSNIPort: altTLSSNIPort,
},
originCaddyfile: sourceFile, originCaddyfile: sourceFile,
IndexPages: staticfiles.DefaultIndexPages, IndexPages: staticfiles.DefaultIndexPages,
} }
......
...@@ -137,7 +137,7 @@ func TestAddressString(t *testing.T) { ...@@ -137,7 +137,7 @@ func TestAddressString(t *testing.T) {
func TestInspectServerBlocksWithCustomDefaultPort(t *testing.T) { func TestInspectServerBlocksWithCustomDefaultPort(t *testing.T) {
Port = "9999" Port = "9999"
filename := "Testfile" filename := "Testfile"
ctx := newContext().(*httpContext) ctx := newContext(&caddy.Instance{Storage: make(map[interface{}]interface{})}).(*httpContext)
input := strings.NewReader(`localhost`) input := strings.NewReader(`localhost`)
sblocks, err := caddyfile.Parse(filename, input, nil) sblocks, err := caddyfile.Parse(filename, input, nil)
if err != nil { if err != nil {
...@@ -153,9 +153,26 @@ func TestInspectServerBlocksWithCustomDefaultPort(t *testing.T) { ...@@ -153,9 +153,26 @@ func TestInspectServerBlocksWithCustomDefaultPort(t *testing.T) {
} }
} }
// See discussion on PR #2015
func TestInspectServerBlocksWithAdjustedAddress(t *testing.T) {
Port = DefaultPort
Host = "example.com"
filename := "Testfile"
ctx := newContext(&caddy.Instance{Storage: make(map[interface{}]interface{})}).(*httpContext)
input := strings.NewReader("example.com {\n}\n:2015 {\n}")
sblocks, err := caddyfile.Parse(filename, input, nil)
if err != nil {
t.Fatalf("Expected no error setting up test, got: %v", err)
}
_, err = ctx.InspectServerBlocks(filename, sblocks)
if err == nil {
t.Fatalf("Expected an error because site definitions should overlap, got: %v", err)
}
}
func TestInspectServerBlocksCaseInsensitiveKey(t *testing.T) { func TestInspectServerBlocksCaseInsensitiveKey(t *testing.T) {
filename := "Testfile" filename := "Testfile"
ctx := newContext().(*httpContext) ctx := newContext(&caddy.Instance{Storage: make(map[interface{}]interface{})}).(*httpContext)
input := strings.NewReader("localhost {\n}\nLOCALHOST {\n}") input := strings.NewReader("localhost {\n}\nLOCALHOST {\n}")
sblocks, err := caddyfile.Parse(filename, input, nil) sblocks, err := caddyfile.Parse(filename, input, nil)
if err != nil { if err != nil {
...@@ -207,7 +224,7 @@ func TestDirectivesList(t *testing.T) { ...@@ -207,7 +224,7 @@ func TestDirectivesList(t *testing.T) {
} }
func TestContextSaveConfig(t *testing.T) { func TestContextSaveConfig(t *testing.T) {
ctx := newContext().(*httpContext) ctx := newContext(&caddy.Instance{Storage: make(map[interface{}]interface{})}).(*httpContext)
ctx.saveConfig("foo", new(SiteConfig)) ctx.saveConfig("foo", new(SiteConfig))
if _, ok := ctx.keysToSiteConfigs["foo"]; !ok { if _, ok := ctx.keysToSiteConfigs["foo"]; !ok {
t.Error("Expected config to be saved, but it wasn't") t.Error("Expected config to be saved, but it wasn't")
...@@ -226,7 +243,7 @@ func TestContextSaveConfig(t *testing.T) { ...@@ -226,7 +243,7 @@ func TestContextSaveConfig(t *testing.T) {
// Test to make sure we are correctly hiding the Caddyfile // Test to make sure we are correctly hiding the Caddyfile
func TestHideCaddyfile(t *testing.T) { func TestHideCaddyfile(t *testing.T) {
ctx := newContext().(*httpContext) ctx := newContext(&caddy.Instance{Storage: make(map[interface{}]interface{})}).(*httpContext)
ctx.saveConfig("test", &SiteConfig{ ctx.saveConfig("test", &SiteConfig{
Root: Root, Root: Root,
originCaddyfile: "Testfile", originCaddyfile: "Testfile",
......
...@@ -389,7 +389,7 @@ func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) (int, error) ...@@ -389,7 +389,7 @@ func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) (int, error)
if vhost == nil { if vhost == nil {
// check for ACME challenge even if vhost is nil; // check for ACME challenge even if vhost is nil;
// could be a new host coming online soon // could be a new host coming online soon
if caddytls.HTTPChallengeHandler(w, r, "localhost", caddytls.DefaultHTTPAlternatePort) { if caddytls.HTTPChallengeHandler(w, r, "localhost") {
return 0, nil return 0, nil
} }
// otherwise, log the error and write a message to the client // otherwise, log the error and write a message to the client
...@@ -405,7 +405,7 @@ func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) (int, error) ...@@ -405,7 +405,7 @@ func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) (int, error)
// we still check for ACME challenge if the vhost exists, // we still check for ACME challenge if the vhost exists,
// because we must apply its HTTP challenge config settings // because we must apply its HTTP challenge config settings
if s.proxyHTTPChallenge(vhost, w, r) { if caddytls.HTTPChallengeHandler(w, r, vhost.ListenHost) {
return 0, nil return 0, nil
} }
...@@ -422,24 +422,6 @@ func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) (int, error) ...@@ -422,24 +422,6 @@ func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) (int, error)
return vhost.middlewareChain.ServeHTTP(w, r) return vhost.middlewareChain.ServeHTTP(w, r)
} }
// proxyHTTPChallenge solves the ACME HTTP challenge if r is the HTTP
// request for the challenge. If it is, and if the request has been
// fulfilled (response written), true is returned; false otherwise.
// If you don't have a vhost, just call the challenge handler directly.
func (s *Server) proxyHTTPChallenge(vhost *SiteConfig, w http.ResponseWriter, r *http.Request) bool {
if vhost.Addr.Port != caddytls.HTTPChallengePort {
return false
}
if vhost.TLS != nil && vhost.TLS.Manual {
return false
}
altPort := caddytls.DefaultHTTPAlternatePort
if vhost.TLS != nil && vhost.TLS.AltHTTPPort != "" {
altPort = vhost.TLS.AltHTTPPort
}
return caddytls.HTTPChallengeHandler(w, r, vhost.ListenHost, altPort)
}
// Address returns the address s was assigned to listen on. // Address returns the address s was assigned to listen on.
func (s *Server) Address() string { func (s *Server) Address() string {
return s.Server.Addr return s.Server.Addr
......
This diff is collapsed.
...@@ -17,57 +17,71 @@ package caddytls ...@@ -17,57 +17,71 @@ package caddytls
import "testing" import "testing"
func TestUnexportedGetCertificate(t *testing.T) { func TestUnexportedGetCertificate(t *testing.T) {
defer func() { certCache = make(map[string]Certificate) }() certCache := &certificateCache{cache: make(map[string]Certificate)}
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
// When cache is empty // When cache is empty
if _, matched, defaulted := getCertificate("example.com"); matched || defaulted { if _, matched, defaulted := cfg.getCertificate("example.com"); matched || defaulted {
t.Errorf("Got a certificate when cache was empty; matched=%v, defaulted=%v", matched, defaulted) t.Errorf("Got a certificate when cache was empty; matched=%v, defaulted=%v", matched, defaulted)
} }
// When cache has one certificate in it (also is default) // When cache has one certificate in it
defaultCert := Certificate{Names: []string{"example.com", ""}} firstCert := Certificate{Names: []string{"example.com"}}
certCache[""] = defaultCert certCache.cache["0xdeadbeef"] = firstCert
certCache["example.com"] = defaultCert cfg.Certificates["example.com"] = "0xdeadbeef"
if cert, matched, defaulted := getCertificate("Example.com"); !matched || defaulted || cert.Names[0] != "example.com" { if cert, matched, defaulted := cfg.getCertificate("Example.com"); !matched || defaulted || cert.Names[0] != "example.com" {
t.Errorf("Didn't get a cert for 'Example.com' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted) t.Errorf("Didn't get a cert for 'Example.com' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted)
} }
if cert, matched, defaulted := getCertificate(""); !matched || defaulted || cert.Names[0] != "example.com" { if cert, matched, defaulted := cfg.getCertificate("example.com"); !matched || defaulted || cert.Names[0] != "example.com" {
t.Errorf("Didn't get a cert for '' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted) t.Errorf("Didn't get a cert for 'example.com' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted)
} }
// When retrieving wildcard certificate // When retrieving wildcard certificate
certCache["*.example.com"] = Certificate{Names: []string{"*.example.com"}} certCache.cache["0xb01dface"] = Certificate{Names: []string{"*.example.com"}}
if cert, matched, defaulted := getCertificate("sub.example.com"); !matched || defaulted || cert.Names[0] != "*.example.com" { cfg.Certificates["*.example.com"] = "0xb01dface"
if cert, matched, defaulted := cfg.getCertificate("sub.example.com"); !matched || defaulted || cert.Names[0] != "*.example.com" {
t.Errorf("Didn't get wildcard cert for 'sub.example.com' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted) t.Errorf("Didn't get wildcard cert for 'sub.example.com' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted)
} }
// When no certificate matches, the default is returned // When no certificate matches and SNI is provided, return no certificate (should be TLS alert)
if cert, matched, defaulted := getCertificate("nomatch"); matched || !defaulted { if cert, matched, defaulted := cfg.getCertificate("nomatch"); matched || defaulted {
t.Errorf("Expected matched=false, defaulted=false; but got matched=%v, defaulted=%v (cert: %v)", matched, defaulted, cert)
}
// When no certificate matches and SNI is NOT provided, a random is returned
if cert, matched, defaulted := cfg.getCertificate(""); matched || !defaulted {
t.Errorf("Expected matched=false, defaulted=true; but got matched=%v, defaulted=%v (cert: %v)", matched, defaulted, cert) t.Errorf("Expected matched=false, defaulted=true; but got matched=%v, defaulted=%v (cert: %v)", matched, defaulted, cert)
} else if cert.Names[0] != "example.com" {
t.Errorf("Expected default cert, got: %v", cert)
} }
} }
func TestCacheCertificate(t *testing.T) { func TestCacheCertificate(t *testing.T) {
defer func() { certCache = make(map[string]Certificate) }() certCache := &certificateCache{cache: make(map[string]Certificate)}
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
cacheCertificate(Certificate{Names: []string{"example.com", "sub.example.com"}}) cfg.cacheCertificate(Certificate{Names: []string{"example.com", "sub.example.com"}, Hash: "foobar"})
if _, ok := certCache["example.com"]; !ok { if len(certCache.cache) != 1 {
t.Error("Expected first cert to be cached by key 'example.com', but it wasn't") t.Errorf("Expected length of certificate cache to be 1")
}
if _, ok := certCache.cache["foobar"]; !ok {
t.Error("Expected first cert to be cached by key 'foobar', but it wasn't")
} }
if _, ok := certCache["sub.example.com"]; !ok { if _, ok := cfg.Certificates["example.com"]; !ok {
t.Error("Expected first cert to be cached by key 'sub.example.com', but it wasn't") t.Error("Expected first cert to be keyed by 'example.com', but it wasn't")
} }
if cert, ok := certCache[""]; !ok || cert.Names[2] != "" { if _, ok := cfg.Certificates["sub.example.com"]; !ok {
t.Error("Expected first cert to be cached additionally as the default certificate with empty name added, but it wasn't") t.Error("Expected first cert to be keyed by 'sub.example.com', but it wasn't")
} }
cacheCertificate(Certificate{Names: []string{"example2.com"}}) // different config, but using same cache; and has cert with overlapping name,
if _, ok := certCache["example2.com"]; !ok { // but different hash
t.Error("Expected second cert to be cached by key 'exmaple2.com', but it wasn't") cfg2 := &Config{Certificates: make(map[string]string), certCache: certCache}
cfg2.cacheCertificate(Certificate{Names: []string{"example.com"}, Hash: "barbaz"})
if _, ok := certCache.cache["barbaz"]; !ok {
t.Error("Expected second cert to be cached by key 'barbaz.com', but it wasn't")
} }
if cert, ok := certCache[""]; ok && cert.Names[0] == "example2.com" { if hash, ok := cfg2.Certificates["example.com"]; !ok {
t.Error("Expected second cert to NOT be cached as default, but it was") t.Error("Expected second cert to be keyed by 'example.com', but it wasn't")
} else if hash != "barbaz" {
t.Errorf("Expected second cert to map to 'barbaz' but it was %s instead", hash)
} }
} }
...@@ -39,7 +39,7 @@ type ACMEClient struct { ...@@ -39,7 +39,7 @@ type ACMEClient struct {
AllowPrompts bool AllowPrompts bool
config *Config config *Config
acmeClient *acme.Client acmeClient *acme.Client
locker Locker storage Storage
} }
// newACMEClient creates a new ACMEClient given an email and whether // newACMEClient creates a new ACMEClient given an email and whether
...@@ -121,10 +121,7 @@ var newACMEClient = func(config *Config, allowPrompts bool) (*ACMEClient, error) ...@@ -121,10 +121,7 @@ var newACMEClient = func(config *Config, allowPrompts bool) (*ACMEClient, error)
AllowPrompts: allowPrompts, AllowPrompts: allowPrompts,
config: config, config: config,
acmeClient: client, acmeClient: client,
locker: &syncLock{ storage: storage,
nameLocks: make(map[string]*sync.WaitGroup),
nameLocksMu: sync.Mutex{},
},
} }
if config.DNSProvider == "" { if config.DNSProvider == "" {
...@@ -160,7 +157,7 @@ var newACMEClient = func(config *Config, allowPrompts bool) (*ACMEClient, error) ...@@ -160,7 +157,7 @@ var newACMEClient = func(config *Config, allowPrompts bool) (*ACMEClient, error)
// See if TLS challenge needs to be handled by our own facilities // See if TLS challenge needs to be handled by our own facilities
if caddy.HasListenerWithAddress(net.JoinHostPort(config.ListenHost, useTLSSNIPort)) { if caddy.HasListenerWithAddress(net.JoinHostPort(config.ListenHost, useTLSSNIPort)) {
c.acmeClient.SetChallengeProvider(acme.TLSSNI01, tlsSniSolver{}) c.acmeClient.SetChallengeProvider(acme.TLSSNI01, tlsSNISolver{certCache: config.certCache})
} }
// Disable any challenges that should not be used // Disable any challenges that should not be used
...@@ -209,13 +206,7 @@ var newACMEClient = func(config *Config, allowPrompts bool) (*ACMEClient, error) ...@@ -209,13 +206,7 @@ var newACMEClient = func(config *Config, allowPrompts bool) (*ACMEClient, error)
// Callers who have access to a Config value should use the ObtainCert // Callers who have access to a Config value should use the ObtainCert
// method on that instead of this lower-level method. // method on that instead of this lower-level method.
func (c *ACMEClient) Obtain(name string) error { func (c *ACMEClient) Obtain(name string) error {
// Get access to ACME storage waiter, err := c.storage.TryLock(name)
storage, err := c.config.StorageFor(c.config.CAUrl)
if err != nil {
return err
}
waiter, err := c.locker.TryLock(name)
if err != nil { if err != nil {
return err return err
} }
...@@ -225,7 +216,7 @@ func (c *ACMEClient) Obtain(name string) error { ...@@ -225,7 +216,7 @@ func (c *ACMEClient) Obtain(name string) error {
return nil // we assume the process with the lock succeeded, rather than hammering this execution path again return nil // we assume the process with the lock succeeded, rather than hammering this execution path again
} }
defer func() { defer func() {
if err := c.locker.Unlock(name); err != nil { if err := c.storage.Unlock(name); err != nil {
log.Printf("[ERROR] Unable to unlock obtain call for %s: %v", name, err) log.Printf("[ERROR] Unable to unlock obtain call for %s: %v", name, err)
} }
}() }()
...@@ -268,7 +259,7 @@ Attempts: ...@@ -268,7 +259,7 @@ Attempts:
} }
// Success - immediately save the certificate resource // Success - immediately save the certificate resource
err = saveCertResource(storage, certificate) err = saveCertResource(c.storage, certificate)
if err != nil { if err != nil {
return fmt.Errorf("error saving assets for %v: %v", name, err) return fmt.Errorf("error saving assets for %v: %v", name, err)
} }
...@@ -279,35 +270,30 @@ Attempts: ...@@ -279,35 +270,30 @@ Attempts:
return nil return nil
} }
// Renew renews the managed certificate for name. This function is // Renew renews the managed certificate for name. It puts the renewed
// safe for concurrent use. // certificate into storage (not the cache). This function is safe for
// concurrent use.
// //
// Callers who have access to a Config value should use the RenewCert // Callers who have access to a Config value should use the RenewCert
// method on that instead of this lower-level method. // method on that instead of this lower-level method.
func (c *ACMEClient) Renew(name string) error { func (c *ACMEClient) Renew(name string) error {
// Get access to ACME storage waiter, err := c.storage.TryLock(name)
storage, err := c.config.StorageFor(c.config.CAUrl)
if err != nil {
return err
}
waiter, err := c.locker.TryLock(name)
if err != nil { if err != nil {
return err return err
} }
if waiter != nil { if waiter != nil {
log.Printf("[INFO] Certificate for %s is already being renewed elsewhere and stored; waiting", name) log.Printf("[INFO] Certificate for %s is already being renewed elsewhere and stored; waiting", name)
waiter.Wait() waiter.Wait()
return nil // we assume the process with the lock succeeded, rather than hammering this execution path again return nil // assume that the worker that renewed the cert succeeded; avoid hammering this path over and over
} }
defer func() { defer func() {
if err := c.locker.Unlock(name); err != nil { if err := c.storage.Unlock(name); err != nil {
log.Printf("[ERROR] Unable to unlock renew call for %s: %v", name, err) log.Printf("[ERROR] Unable to unlock renew call for %s: %v", name, err)
} }
}() }()
// Prepare for renewal (load PEM cert, key, and meta) // Prepare for renewal (load PEM cert, key, and meta)
siteData, err := storage.LoadSite(name) siteData, err := c.storage.LoadSite(name)
if err != nil { if err != nil {
return err return err
} }
...@@ -350,21 +336,15 @@ func (c *ACMEClient) Renew(name string) error { ...@@ -350,21 +336,15 @@ func (c *ACMEClient) Renew(name string) error {
return errors.New("too many renewal attempts; last error: " + err.Error()) return errors.New("too many renewal attempts; last error: " + err.Error())
} }
// Executes Cert renew events
caddy.EmitEvent(caddy.CertRenewEvent, name) caddy.EmitEvent(caddy.CertRenewEvent, name)
return saveCertResource(storage, newCertMeta) return saveCertResource(c.storage, newCertMeta)
} }
// Revoke revokes the certificate for name and deltes // Revoke revokes the certificate for name and deletes
// it from storage. // it from storage.
func (c *ACMEClient) Revoke(name string) error { func (c *ACMEClient) Revoke(name string) error {
storage, err := c.config.StorageFor(c.config.CAUrl) siteExists, err := c.storage.SiteExists(name)
if err != nil {
return err
}
siteExists, err := storage.SiteExists(name)
if err != nil { if err != nil {
return err return err
} }
...@@ -373,7 +353,7 @@ func (c *ACMEClient) Revoke(name string) error { ...@@ -373,7 +353,7 @@ func (c *ACMEClient) Revoke(name string) error {
return errors.New("no certificate and key for " + name) return errors.New("no certificate and key for " + name)
} }
siteData, err := storage.LoadSite(name) siteData, err := c.storage.LoadSite(name)
if err != nil { if err != nil {
return err return err
} }
...@@ -383,7 +363,7 @@ func (c *ACMEClient) Revoke(name string) error { ...@@ -383,7 +363,7 @@ func (c *ACMEClient) Revoke(name string) error {
return err return err
} }
err = storage.DeleteSite(name) err = c.storage.DeleteSite(name)
if err != nil { if err != nil {
return errors.New("certificate revoked, but unable to delete certificate file: " + err.Error()) return errors.New("certificate revoked, but unable to delete certificate file: " + err.Error())
} }
......
...@@ -93,16 +93,17 @@ type Config struct { ...@@ -93,16 +93,17 @@ type Config struct {
// an ACME challenge // an ACME challenge
ListenHost string ListenHost string
// The alternate port (ONLY port, not host) // The alternate port (ONLY port, not host) to
// to use for the ACME HTTP challenge; this // use for the ACME HTTP challenge; if non-empty,
// port will be used if we proxy challenges // this port will be used instead of
// coming in on port 80 to this alternate port // HTTPChallengePort to spin up a listener for
// the HTTP challenge
AltHTTPPort string AltHTTPPort string
// The alternate port (ONLY port, not host) // The alternate port (ONLY port, not host)
// to use for the ACME TLS-SNI challenge. // to use for the ACME TLS-SNI challenge.
// The system must forward the standard port // The system must forward TLSSNIChallengePort
// for the TLS-SNI challenge to this port. // to this port for challenge to succeed
AltTLSSNIPort string AltTLSSNIPort string
// The string identifier of the DNS provider // The string identifier of the DNS provider
...@@ -134,7 +135,12 @@ type Config struct { ...@@ -134,7 +135,12 @@ type Config struct {
// Protocol Negotiation (ALPN). // Protocol Negotiation (ALPN).
ALPN []string ALPN []string
tlsConfig *tls.Config // the final tls.Config created with buildStandardTLSConfig() // The map of hostname to certificate hash. This is used to complete
// handshakes and serve the right certificate given the SNI.
Certificates map[string]string
certCache *certificateCache // pointer to the Instance's certificate store
tlsConfig *tls.Config // the final tls.Config created with buildStandardTLSConfig()
} }
// OnDemandState contains some state relevant for providing // OnDemandState contains some state relevant for providing
...@@ -155,6 +161,25 @@ type OnDemandState struct { ...@@ -155,6 +161,25 @@ type OnDemandState struct {
AskURL *url.URL AskURL *url.URL
} }
// NewConfig returns a new Config with a pointer to the instance's
// certificate cache. You will usually need to set Other fields on
// the returned Config for successful practical use.
func NewConfig(inst *caddy.Instance) *Config {
inst.StorageMu.RLock()
certCache, ok := inst.Storage[CertCacheInstStorageKey].(*certificateCache)
inst.StorageMu.RUnlock()
if !ok || certCache == nil {
certCache = &certificateCache{cache: make(map[string]Certificate)}
inst.StorageMu.Lock()
inst.Storage[CertCacheInstStorageKey] = certCache
inst.StorageMu.Unlock()
}
cfg := new(Config)
cfg.Certificates = make(map[string]string)
cfg.certCache = certCache
return cfg
}
// ObtainCert obtains a certificate for name using c, as long // ObtainCert obtains a certificate for name using c, as long
// as a certificate does not already exist in storage for that // as a certificate does not already exist in storage for that
// name. The name must qualify and c must be flagged as Managed. // name. The name must qualify and c must be flagged as Managed.
...@@ -330,7 +355,9 @@ func (c *Config) buildStandardTLSConfig() error { ...@@ -330,7 +355,9 @@ func (c *Config) buildStandardTLSConfig() error {
// MakeTLSConfig makes a tls.Config from configs. The returned // MakeTLSConfig makes a tls.Config from configs. The returned
// tls.Config is programmed to load the matching caddytls.Config // tls.Config is programmed to load the matching caddytls.Config
// based on the hostname in SNI, but that's all. // based on the hostname in SNI, but that's all. This is used
// to create a single TLS configuration for a listener (a group
// of sites).
func MakeTLSConfig(configs []*Config) (*tls.Config, error) { func MakeTLSConfig(configs []*Config) (*tls.Config, error) {
if len(configs) == 0 { if len(configs) == 0 {
return nil, nil return nil, nil
...@@ -358,15 +385,28 @@ func MakeTLSConfig(configs []*Config) (*tls.Config, error) { ...@@ -358,15 +385,28 @@ func MakeTLSConfig(configs []*Config) (*tls.Config, error) {
configs[i-1].Hostname, lastConfProto, cfg.Hostname, thisConfProto) configs[i-1].Hostname, lastConfProto, cfg.Hostname, thisConfProto)
} }
// convert each caddytls.Config into a tls.Config // convert this caddytls.Config into a tls.Config
if err := cfg.buildStandardTLSConfig(); err != nil { if err := cfg.buildStandardTLSConfig(); err != nil {
return nil, err return nil, err
} }
// Key this config by its hostname (overwriting // if an existing config with this hostname was already
// configs with the same hostname pattern); during // configured, then they must be identical (or at least
// TLS handshakes, configs are loaded based on // compatible), otherwise that is a configuration error
// the hostname pattern, according to client's SNI. if otherConfig, ok := configMap[cfg.Hostname]; ok {
if err := assertConfigsCompatible(cfg, otherConfig); err != nil {
return nil, fmt.Errorf("incompabile TLS configurations for the same SNI "+
"name (%s) on the same listener: %v",
cfg.Hostname, err)
}
}
// key this config by its hostname (overwrites
// configs with the same hostname pattern; should
// be OK since we already asserted they are roughly
// the same); during TLS handshakes, configs are
// loaded based on the hostname pattern, according
// to client's SNI
configMap[cfg.Hostname] = cfg configMap[cfg.Hostname] = cfg
} }
...@@ -383,6 +423,63 @@ func MakeTLSConfig(configs []*Config) (*tls.Config, error) { ...@@ -383,6 +423,63 @@ func MakeTLSConfig(configs []*Config) (*tls.Config, error) {
}, nil }, nil
} }
// assertConfigsCompatible returns an error if the two Configs
// do not have the same (or roughly compatible) configurations.
// If one of the tlsConfig pointers on either Config is nil,
// an error will be returned. If both are nil, no error.
func assertConfigsCompatible(cfg1, cfg2 *Config) error {
c1, c2 := cfg1.tlsConfig, cfg2.tlsConfig
if (c1 == nil && c2 != nil) || (c1 != nil && c2 == nil) {
return fmt.Errorf("one config is not made")
}
if c1 == nil && c2 == nil {
return nil
}
if len(c1.CipherSuites) != len(c2.CipherSuites) {
return fmt.Errorf("different number of allowed cipher suites")
}
for i, ciph := range c1.CipherSuites {
if c2.CipherSuites[i] != ciph {
return fmt.Errorf("different cipher suites or different order")
}
}
if len(c1.CurvePreferences) != len(c2.CurvePreferences) {
return fmt.Errorf("different number of allowed cipher suites")
}
for i, curve := range c1.CurvePreferences {
if c2.CurvePreferences[i] != curve {
return fmt.Errorf("different curve preferences or different order")
}
}
if len(c1.NextProtos) != len(c2.NextProtos) {
return fmt.Errorf("different number of ALPN (NextProtos) values")
}
for i, proto := range c1.NextProtos {
if c2.NextProtos[i] != proto {
return fmt.Errorf("different ALPN (NextProtos) values or different order")
}
}
if c1.PreferServerCipherSuites != c2.PreferServerCipherSuites {
return fmt.Errorf("one prefers server cipher suites, the other does not")
}
if c1.MinVersion != c2.MinVersion {
return fmt.Errorf("minimum TLS version mismatch")
}
if c1.MaxVersion != c2.MaxVersion {
return fmt.Errorf("maximum TLS version mismatch")
}
if c1.ClientAuth != c2.ClientAuth {
return fmt.Errorf("client authentication policy mismatch")
}
return nil
}
// ConfigGetter gets a Config keyed by key. // ConfigGetter gets a Config keyed by key.
type ConfigGetter func(c *caddy.Controller) *Config type ConfigGetter func(c *caddy.Controller) *Config
...@@ -522,7 +619,7 @@ var supportedCurvesMap = map[string]tls.CurveID{ ...@@ -522,7 +619,7 @@ var supportedCurvesMap = map[string]tls.CurveID{
"P521": tls.CurveP521, "P521": tls.CurveP521,
} }
// List of all the curves we want to use by default // List of all the curves we want to use by default.
// //
// This list should only include curves which are fast by design (e.g. X25519) // This list should only include curves which are fast by design (e.g. X25519)
// and those for which an optimized assembly implementation exists (e.g. P256). // and those for which an optimized assembly implementation exists (e.g. P256).
...@@ -548,4 +645,8 @@ const ( ...@@ -548,4 +645,8 @@ const (
// be capable of proxying or forwarding the request to this // be capable of proxying or forwarding the request to this
// alternate port. // alternate port.
DefaultHTTPAlternatePort = "5033" DefaultHTTPAlternatePort = "5033"
// CertCacheInstStorageKey is the name of the key for
// accessing the certificate storage on the *caddy.Instance.
CertCacheInstStorageKey = "tls_cert_cache"
) )
...@@ -237,15 +237,17 @@ func makeSelfSignedCert(config *Config) error { ...@@ -237,15 +237,17 @@ func makeSelfSignedCert(config *Config) error {
return fmt.Errorf("could not create certificate: %v", err) return fmt.Errorf("could not create certificate: %v", err)
} }
cacheCertificate(Certificate{ chain := [][]byte{derBytes}
config.cacheCertificate(Certificate{
Certificate: tls.Certificate{ Certificate: tls.Certificate{
Certificate: [][]byte{derBytes}, Certificate: chain,
PrivateKey: privKey, PrivateKey: privKey,
Leaf: cert, Leaf: cert,
}, },
Names: cert.DNSNames, Names: cert.DNSNames,
NotAfter: cert.NotAfter, NotAfter: cert.NotAfter,
Config: config, Hash: hashCertificateChain(chain),
}) })
return nil return nil
......
...@@ -38,9 +38,9 @@ var storageBasePath = filepath.Join(caddy.AssetsPath(), "acme") ...@@ -38,9 +38,9 @@ var storageBasePath = filepath.Join(caddy.AssetsPath(), "acme")
// Storage instance backed by the local disk. The resulting Storage // Storage instance backed by the local disk. The resulting Storage
// instance is guaranteed to be non-nil if there is no error. // instance is guaranteed to be non-nil if there is no error.
func NewFileStorage(caURL *url.URL) (Storage, error) { func NewFileStorage(caURL *url.URL) (Storage, error) {
return &FileStorage{ storage := &FileStorage{Path: filepath.Join(storageBasePath, caURL.Host)}
Path: filepath.Join(storageBasePath, caURL.Host), storage.Locker = &fileStorageLock{caURL: caURL.Host, storage: storage}
}, nil return storage, nil
} }
// FileStorage facilitates forming file paths derived from a root // FileStorage facilitates forming file paths derived from a root
...@@ -48,6 +48,7 @@ func NewFileStorage(caURL *url.URL) (Storage, error) { ...@@ -48,6 +48,7 @@ func NewFileStorage(caURL *url.URL) (Storage, error) {
// cross-platform way or persisting ACME assets on the file system. // cross-platform way or persisting ACME assets on the file system.
type FileStorage struct { type FileStorage struct {
Path string Path string
Locker
} }
// sites gets the directory that stores site certificate and keys. // sites gets the directory that stores site certificate and keys.
......
// Copyright 2015 Light Code Labs, LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package caddytls
import (
"fmt"
"os"
"sync"
"time"
"github.com/mholt/caddy"
)
func init() {
// be sure to remove lock files when exiting the process!
caddy.OnProcessExit = append(caddy.OnProcessExit, func() {
fileStorageNameLocksMu.Lock()
defer fileStorageNameLocksMu.Unlock()
for key, fw := range fileStorageNameLocks {
os.Remove(fw.filename)
delete(fileStorageNameLocks, key)
}
})
}
// fileStorageLock facilitates ACME-related locking by using
// the associated FileStorage, so multiple processes can coordinate
// renewals on the certificates on a shared file system.
type fileStorageLock struct {
caURL string
storage *FileStorage
}
// TryLock attempts to get a lock for name, otherwise it returns
// a Waiter value to wait until the other process is finished.
func (s *fileStorageLock) TryLock(name string) (Waiter, error) {
fileStorageNameLocksMu.Lock()
defer fileStorageNameLocksMu.Unlock()
// see if lock already exists within this process
fw, ok := fileStorageNameLocks[s.caURL+name]
if ok {
// lock already created within process, let caller wait on it
return fw, nil
}
// attempt to persist lock to disk by creating lock file
fw = &fileWaiter{
filename: s.storage.siteCertFile(name) + ".lock",
wg: new(sync.WaitGroup),
}
// parent dir must exist
if err := os.MkdirAll(s.storage.site(name), 0700); err != nil {
return nil, err
}
lf, err := os.OpenFile(fw.filename, os.O_CREATE|os.O_EXCL, 0644)
if err != nil {
if os.IsExist(err) {
// another process has the lock; use it to wait
return fw, nil
}
// otherwise, this was some unexpected error
return nil, err
}
lf.Close()
// looks like we get the lock
fw.wg.Add(1)
fileStorageNameLocks[s.caURL+name] = fw
return nil, nil
}
// Unlock unlocks name.
func (s *fileStorageLock) Unlock(name string) error {
fileStorageNameLocksMu.Lock()
defer fileStorageNameLocksMu.Unlock()
fw, ok := fileStorageNameLocks[s.caURL+name]
if !ok {
return fmt.Errorf("FileStorage: no lock to release for %s", name)
}
os.Remove(fw.filename)
fw.wg.Done()
delete(fileStorageNameLocks, s.caURL+name)
return nil
}
// fileWaiter waits for a file to disappear; it polls
// the file system to check for the existence of a file.
// It also has a WaitGroup which will be faster than
// polling, for when locking need only happen within this
// process.
type fileWaiter struct {
filename string
wg *sync.WaitGroup
}
// Wait waits until the lock is released.
func (fw *fileWaiter) Wait() {
start := time.Now()
fw.wg.Wait()
for time.Since(start) < 1*time.Hour {
_, err := os.Stat(fw.filename)
if os.IsNotExist(err) {
return
}
time.Sleep(1 * time.Second)
}
}
var fileStorageNameLocks = make(map[string]*fileWaiter) // keyed by CA + name
var fileStorageNameLocksMu sync.Mutex
var _ Locker = &fileStorageLock{}
var _ Waiter = &fileWaiter{}
...@@ -59,15 +59,15 @@ func (cg configGroup) getConfig(name string) *Config { ...@@ -59,15 +59,15 @@ func (cg configGroup) getConfig(name string) *Config {
} }
} }
// as a fallback, try a config that serves all names // try a config that serves all names (this
// is basically the same as a config defined
// for "*" -- I think -- but the above loop
// doesn't try an empty string)
if config, ok := cg[""]; ok { if config, ok := cg[""]; ok {
return config return config
} }
// as a last resort, use a random config // no matches, so just serve up a random config
// (even if the config isn't for that hostname,
// it should help us serve clients without SNI
// or at least defer TLS alerts to the cert)
for _, config := range cg { for _, config := range cg {
return config return config
} }
...@@ -102,6 +102,86 @@ func (cfg *Config) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certif ...@@ -102,6 +102,86 @@ func (cfg *Config) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certif
return &cert.Certificate, err return &cert.Certificate, err
} }
// getCertificate gets a certificate that matches name (a server name)
// from the in-memory cache, according to the lookup table associated with
// cfg. The lookup then points to a certificate in the Instance certificate
// cache.
//
// If there is no exact match for name, it will be checked against names of
// the form '*.example.com' (wildcard certificates) according to RFC 6125.
// If a match is found, matched will be true. If no matches are found, matched
// will be false and a "default" certificate will be returned with defaulted
// set to true. If defaulted is false, then no certificates were available.
//
// The logic in this function is adapted from the Go standard library,
// which is by the Go Authors.
//
// This function is safe for concurrent use.
func (cfg *Config) getCertificate(name string) (cert Certificate, matched, defaulted bool) {
var certKey string
var ok bool
// Not going to trim trailing dots here since RFC 3546 says,
// "The hostname is represented ... without a trailing dot."
// Just normalize to lowercase.
name = strings.ToLower(name)
cfg.certCache.RLock()
defer cfg.certCache.RUnlock()
// exact match? great, let's use it
if certKey, ok = cfg.Certificates[name]; ok {
cert = cfg.certCache.cache[certKey]
matched = true
return
}
// try replacing labels in the name with wildcards until we get a match
labels := strings.Split(name, ".")
for i := range labels {
labels[i] = "*"
candidate := strings.Join(labels, ".")
if certKey, ok = cfg.Certificates[candidate]; ok {
cert = cfg.certCache.cache[certKey]
matched = true
return
}
}
// check the certCache directly to see if the SNI name is
// already the key of the certificate it wants! this is vital
// for supporting the TLS-SNI challenge, since the tlsSNISolver
// just puts the temporary certificate in the instance cache,
// with no regard for configs; this also means that the SNI
// can contain the hash of a specific cert (chain) it wants
// and we will still be able to serve it up
// (this behavior, by the way, could be controversial as to
// whether it complies with RFC 6066 about SNI, but I think
// it does soooo...)
// NOTE/TODO: TLS-SNI challenge is changing, as of Jan. 2018
// but what will be different, if it ever returns, is unclear
if directCert, ok := cfg.certCache.cache[name]; ok {
cert = directCert
matched = true
return
}
// if nothing matches and SNI was not provided, use a random
// certificate; at least there's a chance this older client
// can connect, and in the future we won't need this provision
// (if SNI is present, it's probably best to just raise a TLS
// alert by not serving a certificate)
if name == "" {
for _, certKey := range cfg.Certificates {
defaulted = true
cert = cfg.certCache.cache[certKey]
return
}
}
return
}
// getCertDuringHandshake will get a certificate for name. It first tries // getCertDuringHandshake will get a certificate for name. It first tries
// the in-memory cache. If no certificate for name is in the cache, the // the in-memory cache. If no certificate for name is in the cache, the
// config most closely corresponding to name will be loaded. If that config // config most closely corresponding to name will be loaded. If that config
...@@ -115,7 +195,7 @@ func (cfg *Config) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certif ...@@ -115,7 +195,7 @@ func (cfg *Config) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certif
// This function is safe for concurrent use. // This function is safe for concurrent use.
func (cfg *Config) getCertDuringHandshake(name string, loadIfNecessary, obtainIfNecessary bool) (Certificate, error) { func (cfg *Config) getCertDuringHandshake(name string, loadIfNecessary, obtainIfNecessary bool) (Certificate, error) {
// First check our in-memory cache to see if we've already loaded it // First check our in-memory cache to see if we've already loaded it
cert, matched, defaulted := getCertificate(name) cert, matched, defaulted := cfg.getCertificate(name)
if matched { if matched {
return cert, nil return cert, nil
} }
...@@ -258,7 +338,7 @@ func (cfg *Config) obtainOnDemandCertificate(name string) (Certificate, error) { ...@@ -258,7 +338,7 @@ func (cfg *Config) obtainOnDemandCertificate(name string) (Certificate, error) {
obtainCertWaitChans[name] = wait obtainCertWaitChans[name] = wait
obtainCertWaitChansMu.Unlock() obtainCertWaitChansMu.Unlock()
// do the obtain // obtain the certificate
log.Printf("[INFO] Obtaining new certificate for %s", name) log.Printf("[INFO] Obtaining new certificate for %s", name)
err := cfg.ObtainCert(name, false) err := cfg.ObtainCert(name, false)
...@@ -317,9 +397,9 @@ func (cfg *Config) handshakeMaintenance(name string, cert Certificate) (Certific ...@@ -317,9 +397,9 @@ func (cfg *Config) handshakeMaintenance(name string, cert Certificate) (Certific
// quite common considering not all certs have issuer URLs that support it. // quite common considering not all certs have issuer URLs that support it.
log.Printf("[ERROR] Getting OCSP for %s: %v", name, err) log.Printf("[ERROR] Getting OCSP for %s: %v", name, err)
} }
certCacheMu.Lock() cfg.certCache.Lock()
certCache[name] = cert cfg.certCache.cache[cert.Hash] = cert
certCacheMu.Unlock() cfg.certCache.Unlock()
} }
} }
...@@ -348,29 +428,22 @@ func (cfg *Config) renewDynamicCertificate(name string, currentCert Certificate) ...@@ -348,29 +428,22 @@ func (cfg *Config) renewDynamicCertificate(name string, currentCert Certificate)
obtainCertWaitChans[name] = wait obtainCertWaitChans[name] = wait
obtainCertWaitChansMu.Unlock() obtainCertWaitChansMu.Unlock()
// do the renew and reload the certificate // renew and reload the certificate
log.Printf("[INFO] Renewing certificate for %s", name) log.Printf("[INFO] Renewing certificate for %s", name)
err := cfg.RenewCert(name, false) err := cfg.RenewCert(name, false)
if err == nil { if err == nil {
// immediately flush this certificate from the cache so
// the name doesn't overlap when we try to replace it,
// which would fail, because overlapping existing cert
// names isn't allowed
certCacheMu.Lock()
for _, certName := range currentCert.Names {
delete(certCache, certName)
}
certCacheMu.Unlock()
// even though the recursive nature of the dynamic cert loading // even though the recursive nature of the dynamic cert loading
// would just call this function anyway, we do it here to // would just call this function anyway, we do it here to
// make the replacement as atomic as possible. (TODO: similar // make the replacement as atomic as possible.
// to the note in maintain.go, it'd be nice if the clearing of newCert, err := currentCert.configs[0].CacheManagedCertificate(name)
// the cache entries above and this load function were truly
// atomic...)
_, err := currentCert.Config.CacheManagedCertificate(name)
if err != nil { if err != nil {
log.Printf("[ERROR] loading renewed certificate: %v", err) log.Printf("[ERROR] loading renewed certificate for %s: %v", name, err)
} else {
// replace the old certificate with the new one
err = cfg.certCache.replaceCertificate(currentCert, newCert)
if err != nil {
log.Printf("[ERROR] Replacing certificate for %s: %v", name, err)
}
} }
} }
......
...@@ -21,9 +21,8 @@ import ( ...@@ -21,9 +21,8 @@ import (
) )
func TestGetCertificate(t *testing.T) { func TestGetCertificate(t *testing.T) {
defer func() { certCache = make(map[string]Certificate) }() certCache := &certificateCache{cache: make(map[string]Certificate)}
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
cfg := new(Config)
hello := &tls.ClientHelloInfo{ServerName: "example.com"} hello := &tls.ClientHelloInfo{ServerName: "example.com"}
helloSub := &tls.ClientHelloInfo{ServerName: "sub.example.com"} helloSub := &tls.ClientHelloInfo{ServerName: "sub.example.com"}
...@@ -38,33 +37,40 @@ func TestGetCertificate(t *testing.T) { ...@@ -38,33 +37,40 @@ func TestGetCertificate(t *testing.T) {
t.Errorf("GetCertificate should return error when cache is empty even if server name is blank, got: %v", cert) t.Errorf("GetCertificate should return error when cache is empty even if server name is blank, got: %v", cert)
} }
// When cache has one certificate in it (also is default) // When cache has one certificate in it
defaultCert := Certificate{Names: []string{"example.com", ""}, Certificate: tls.Certificate{Leaf: &x509.Certificate{DNSNames: []string{"example.com"}}}} firstCert := Certificate{Names: []string{"example.com"}, Certificate: tls.Certificate{Leaf: &x509.Certificate{DNSNames: []string{"example.com"}}}}
certCache[""] = defaultCert cfg.cacheCertificate(firstCert)
certCache["example.com"] = defaultCert
if cert, err := cfg.GetCertificate(hello); err != nil { if cert, err := cfg.GetCertificate(hello); err != nil {
t.Errorf("Got an error but shouldn't have, when cert exists in cache: %v", err) t.Errorf("Got an error but shouldn't have, when cert exists in cache: %v", err)
} else if cert.Leaf.DNSNames[0] != "example.com" { } else if cert.Leaf.DNSNames[0] != "example.com" {
t.Errorf("Got wrong certificate with exact match; expected 'example.com', got: %v", cert) t.Errorf("Got wrong certificate with exact match; expected 'example.com', got: %v", cert)
} }
if cert, err := cfg.GetCertificate(helloNoSNI); err != nil { if _, err := cfg.GetCertificate(helloNoSNI); err != nil {
t.Errorf("Got an error with no SNI but shouldn't have, when cert exists in cache: %v", err) t.Errorf("Got an error with no SNI but shouldn't have, when cert exists in cache: %v", err)
} else if cert.Leaf.DNSNames[0] != "example.com" {
t.Errorf("Got wrong certificate for no SNI; expected 'example.com' as default, got: %v", cert)
} }
// When retrieving wildcard certificate // When retrieving wildcard certificate
certCache["*.example.com"] = Certificate{Names: []string{"*.example.com"}, Certificate: tls.Certificate{Leaf: &x509.Certificate{DNSNames: []string{"*.example.com"}}}} wildcardCert := Certificate{
Names: []string{"*.example.com"},
Certificate: tls.Certificate{Leaf: &x509.Certificate{DNSNames: []string{"*.example.com"}}},
Hash: "(don't overwrite the first one)",
}
cfg.cacheCertificate(wildcardCert)
if cert, err := cfg.GetCertificate(helloSub); err != nil { if cert, err := cfg.GetCertificate(helloSub); err != nil {
t.Errorf("Didn't get wildcard cert, got: cert=%v, err=%v ", cert, err) t.Errorf("Didn't get wildcard cert, got: cert=%v, err=%v ", cert, err)
} else if cert.Leaf.DNSNames[0] != "*.example.com" { } else if cert.Leaf.DNSNames[0] != "*.example.com" {
t.Errorf("Got wrong certificate, expected wildcard: %v", cert) t.Errorf("Got wrong certificate, expected wildcard: %v", cert)
} }
// When no certificate matches, the default is returned // When cache is NOT empty but there's no SNI
if cert, err := cfg.GetCertificate(helloNoMatch); err != nil { if cert, err := cfg.GetCertificate(helloNoSNI); err != nil {
t.Errorf("Expected default certificate with no error when no matches, got err: %v", err) t.Errorf("Expected random certificate with no error when no SNI, got err: %v", err)
} else if cert.Leaf.DNSNames[0] != "example.com" { } else if cert == nil || len(cert.Leaf.DNSNames) == 0 {
t.Errorf("Expected default cert with no matches, got: %v", cert) t.Errorf("Expected random cert with no matches, got: %v", cert)
}
// When no certificate matches, raise an alert
if _, err := cfg.GetCertificate(helloNoMatch); err == nil {
t.Errorf("Expected an error when no certificate matched the SNI, got: %v", err)
} }
} }
...@@ -27,10 +27,11 @@ import ( ...@@ -27,10 +27,11 @@ import (
const challengeBasePath = "/.well-known/acme-challenge" const challengeBasePath = "/.well-known/acme-challenge"
// HTTPChallengeHandler proxies challenge requests to ACME client if the // HTTPChallengeHandler proxies challenge requests to ACME client if the
// request path starts with challengeBasePath. It returns true if it // request path starts with challengeBasePath, if the HTTP challenge is not
// handled the request and no more needs to be done; it returns false // disabled, and if we are known to be obtaining a certificate for the name.
// if this call was a no-op and the request still needs handling. // It returns true if it handled the request and no more needs to be done;
func HTTPChallengeHandler(w http.ResponseWriter, r *http.Request, listenHost, altPort string) bool { // it returns false if this call was a no-op and the request still needs handling.
func HTTPChallengeHandler(w http.ResponseWriter, r *http.Request, listenHost string) bool {
if !strings.HasPrefix(r.URL.Path, challengeBasePath) { if !strings.HasPrefix(r.URL.Path, challengeBasePath) {
return false return false
} }
...@@ -50,7 +51,11 @@ func HTTPChallengeHandler(w http.ResponseWriter, r *http.Request, listenHost, al ...@@ -50,7 +51,11 @@ func HTTPChallengeHandler(w http.ResponseWriter, r *http.Request, listenHost, al
listenHost = "localhost" listenHost = "localhost"
} }
upstream, err := url.Parse(fmt.Sprintf("%s://%s:%s", scheme, listenHost, altPort)) // always proxy to the DefaultHTTPAlternatePort because obviously the
// ACME challenge request already got into one of our HTTP handlers, so
// it means we must have started a HTTP listener on the alternate
// port instead; which is only accessible via listenHost
upstream, err := url.Parse(fmt.Sprintf("%s://%s:%s", scheme, listenHost, DefaultHTTPAlternatePort))
if err != nil { if err != nil {
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
log.Printf("[ERROR] ACME proxy handler: %v", err) log.Printf("[ERROR] ACME proxy handler: %v", err)
......
...@@ -39,7 +39,7 @@ func TestHTTPChallengeHandlerNoOp(t *testing.T) { ...@@ -39,7 +39,7 @@ func TestHTTPChallengeHandlerNoOp(t *testing.T) {
t.Fatalf("Could not craft request, got error: %v", err) t.Fatalf("Could not craft request, got error: %v", err)
} }
rw := httptest.NewRecorder() rw := httptest.NewRecorder()
if HTTPChallengeHandler(rw, req, "", DefaultHTTPAlternatePort) { if HTTPChallengeHandler(rw, req, "") {
t.Errorf("Got true with this URL, but shouldn't have: %s", url) t.Errorf("Got true with this URL, but shouldn't have: %s", url)
} }
} }
...@@ -76,7 +76,7 @@ func TestHTTPChallengeHandlerSuccess(t *testing.T) { ...@@ -76,7 +76,7 @@ func TestHTTPChallengeHandlerSuccess(t *testing.T) {
} }
rw := httptest.NewRecorder() rw := httptest.NewRecorder()
HTTPChallengeHandler(rw, req, "", DefaultHTTPAlternatePort) HTTPChallengeHandler(rw, req, "")
if !proxySuccess { if !proxySuccess {
t.Fatal("Expected request to be proxied, but it wasn't") t.Fatal("Expected request to be proxied, but it wasn't")
......
This diff is collapsed.
...@@ -38,6 +38,7 @@ func init() { ...@@ -38,6 +38,7 @@ func init() {
// are specified by the user in the config file. All the automatic HTTPS // are specified by the user in the config file. All the automatic HTTPS
// stuff comes later outside of this function. // stuff comes later outside of this function.
func setupTLS(c *caddy.Controller) error { func setupTLS(c *caddy.Controller) error {
// obtain the configGetter, which loads the config we're, uh, configuring
configGetter, ok := configGetters[c.ServerType()] configGetter, ok := configGetters[c.ServerType()]
if !ok { if !ok {
return fmt.Errorf("no caddytls.ConfigGetter for %s server type; must call RegisterConfigGetter", c.ServerType()) return fmt.Errorf("no caddytls.ConfigGetter for %s server type; must call RegisterConfigGetter", c.ServerType())
...@@ -47,6 +48,14 @@ func setupTLS(c *caddy.Controller) error { ...@@ -47,6 +48,14 @@ func setupTLS(c *caddy.Controller) error {
return fmt.Errorf("no caddytls.Config to set up for %s", c.Key) return fmt.Errorf("no caddytls.Config to set up for %s", c.Key)
} }
// the certificate cache is tied to the current caddy.Instance; get a pointer to it
certCache, ok := c.Get(CertCacheInstStorageKey).(*certificateCache)
if !ok || certCache == nil {
certCache = &certificateCache{cache: make(map[string]Certificate)}
c.Set(CertCacheInstStorageKey, certCache)
}
config.certCache = certCache
config.Enabled = true config.Enabled = true
for c.Next() { for c.Next() {
...@@ -237,7 +246,7 @@ func setupTLS(c *caddy.Controller) error { ...@@ -237,7 +246,7 @@ func setupTLS(c *caddy.Controller) error {
// load a single certificate and key, if specified // load a single certificate and key, if specified
if certificateFile != "" && keyFile != "" { if certificateFile != "" && keyFile != "" {
err := cacheUnmanagedCertificatePEMFile(certificateFile, keyFile) err := config.cacheUnmanagedCertificatePEMFile(certificateFile, keyFile)
if err != nil { if err != nil {
return c.Errf("Unable to load certificate and key files for '%s': %v", c.Key, err) return c.Errf("Unable to load certificate and key files for '%s': %v", c.Key, err)
} }
...@@ -246,7 +255,7 @@ func setupTLS(c *caddy.Controller) error { ...@@ -246,7 +255,7 @@ func setupTLS(c *caddy.Controller) error {
// load a directory of certificates, if specified // load a directory of certificates, if specified
if loadDir != "" { if loadDir != "" {
err := loadCertsInDir(c, loadDir) err := loadCertsInDir(config, c, loadDir)
if err != nil { if err != nil {
return err return err
} }
...@@ -273,7 +282,7 @@ func setupTLS(c *caddy.Controller) error { ...@@ -273,7 +282,7 @@ func setupTLS(c *caddy.Controller) error {
// https://cbonte.github.io/haproxy-dconv/configuration-1.5.html#5.1-crt // https://cbonte.github.io/haproxy-dconv/configuration-1.5.html#5.1-crt
// //
// This function may write to the log as it walks the directory tree. // This function may write to the log as it walks the directory tree.
func loadCertsInDir(c *caddy.Controller, dir string) error { func loadCertsInDir(cfg *Config, c *caddy.Controller, dir string) error {
return filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { return filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
if err != nil { if err != nil {
log.Printf("[WARNING] Unable to traverse into %s; skipping", path) log.Printf("[WARNING] Unable to traverse into %s; skipping", path)
...@@ -336,7 +345,7 @@ func loadCertsInDir(c *caddy.Controller, dir string) error { ...@@ -336,7 +345,7 @@ func loadCertsInDir(c *caddy.Controller, dir string) error {
return c.Errf("%s: no private key block found", path) return c.Errf("%s: no private key block found", path)
} }
err = cacheUnmanagedCertificatePEMBytes(certPEMBytes, keyPEMBytes) err = cfg.cacheUnmanagedCertificatePEMBytes(certPEMBytes, keyPEMBytes)
if err != nil { if err != nil {
return c.Errf("%s: failed to load cert and key for '%s': %v", path, c.Key, err) return c.Errf("%s: failed to load cert and key for '%s': %v", path, c.Key, err)
} }
......
...@@ -46,9 +46,12 @@ func TestMain(m *testing.M) { ...@@ -46,9 +46,12 @@ func TestMain(m *testing.M) {
} }
func TestSetupParseBasic(t *testing.T) { func TestSetupParseBasic(t *testing.T) {
cfg := new(Config) certCache := &certificateCache{cache: make(map[string]Certificate)}
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg }) RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
c := caddy.NewTestController("", `tls `+certFile+` `+keyFile+``) c := caddy.NewTestController("", `tls `+certFile+` `+keyFile+``)
c.Set(CertCacheInstStorageKey, certCache)
err := setupTLS(c) err := setupTLS(c)
if err != nil { if err != nil {
...@@ -124,9 +127,12 @@ func TestSetupParseWithOptionalParams(t *testing.T) { ...@@ -124,9 +127,12 @@ func TestSetupParseWithOptionalParams(t *testing.T) {
must_staple must_staple
alpn http/1.1 alpn http/1.1
}` }`
cfg := new(Config) certCache := &certificateCache{cache: make(map[string]Certificate)}
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg }) RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
c := caddy.NewTestController("", params) c := caddy.NewTestController("", params)
c.Set(CertCacheInstStorageKey, certCache)
err := setupTLS(c) err := setupTLS(c)
if err != nil { if err != nil {
...@@ -158,9 +164,11 @@ func TestSetupDefaultWithOptionalParams(t *testing.T) { ...@@ -158,9 +164,11 @@ func TestSetupDefaultWithOptionalParams(t *testing.T) {
params := `tls { params := `tls {
ciphers RSA-3DES-EDE-CBC-SHA ciphers RSA-3DES-EDE-CBC-SHA
}` }`
cfg := new(Config) certCache := &certificateCache{cache: make(map[string]Certificate)}
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg }) RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
c := caddy.NewTestController("", params) c := caddy.NewTestController("", params)
c.Set(CertCacheInstStorageKey, certCache)
err := setupTLS(c) err := setupTLS(c)
if err != nil { if err != nil {
...@@ -176,9 +184,12 @@ func TestSetupParseWithWrongOptionalParams(t *testing.T) { ...@@ -176,9 +184,12 @@ func TestSetupParseWithWrongOptionalParams(t *testing.T) {
params := `tls ` + certFile + ` ` + keyFile + ` { params := `tls ` + certFile + ` ` + keyFile + ` {
protocols ssl tls protocols ssl tls
}` }`
cfg := new(Config) certCache := &certificateCache{cache: make(map[string]Certificate)}
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg }) RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
c := caddy.NewTestController("", params) c := caddy.NewTestController("", params)
c.Set(CertCacheInstStorageKey, certCache)
err := setupTLS(c) err := setupTLS(c)
if err == nil { if err == nil {
t.Errorf("Expected errors, but no error returned") t.Errorf("Expected errors, but no error returned")
...@@ -191,6 +202,7 @@ func TestSetupParseWithWrongOptionalParams(t *testing.T) { ...@@ -191,6 +202,7 @@ func TestSetupParseWithWrongOptionalParams(t *testing.T) {
cfg = new(Config) cfg = new(Config)
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg }) RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
c = caddy.NewTestController("", params) c = caddy.NewTestController("", params)
c.Set(CertCacheInstStorageKey, certCache)
err = setupTLS(c) err = setupTLS(c)
if err == nil { if err == nil {
t.Error("Expected errors, but no error returned") t.Error("Expected errors, but no error returned")
...@@ -215,6 +227,7 @@ func TestSetupParseWithWrongOptionalParams(t *testing.T) { ...@@ -215,6 +227,7 @@ func TestSetupParseWithWrongOptionalParams(t *testing.T) {
cfg = new(Config) cfg = new(Config)
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg }) RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
c = caddy.NewTestController("", params) c = caddy.NewTestController("", params)
c.Set(CertCacheInstStorageKey, certCache)
err = setupTLS(c) err = setupTLS(c)
if err == nil { if err == nil {
t.Error("Expected errors, but no error returned") t.Error("Expected errors, but no error returned")
...@@ -226,7 +239,8 @@ func TestSetupParseWithClientAuth(t *testing.T) { ...@@ -226,7 +239,8 @@ func TestSetupParseWithClientAuth(t *testing.T) {
params := `tls ` + certFile + ` ` + keyFile + ` { params := `tls ` + certFile + ` ` + keyFile + ` {
clients clients
}` }`
cfg := new(Config) certCache := &certificateCache{cache: make(map[string]Certificate)}
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg }) RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
c := caddy.NewTestController("", params) c := caddy.NewTestController("", params)
err := setupTLS(c) err := setupTLS(c)
...@@ -259,9 +273,11 @@ func TestSetupParseWithClientAuth(t *testing.T) { ...@@ -259,9 +273,11 @@ func TestSetupParseWithClientAuth(t *testing.T) {
clients verify_if_given clients verify_if_given
}`, tls.VerifyClientCertIfGiven, true, noCAs}, }`, tls.VerifyClientCertIfGiven, true, noCAs},
} { } {
cfg := new(Config) certCache := &certificateCache{cache: make(map[string]Certificate)}
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg }) RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
c := caddy.NewTestController("", caseData.params) c := caddy.NewTestController("", caseData.params)
c.Set(CertCacheInstStorageKey, certCache)
err := setupTLS(c) err := setupTLS(c)
if caseData.expectedErr { if caseData.expectedErr {
if err == nil { if err == nil {
...@@ -311,9 +327,11 @@ func TestSetupParseWithCAUrl(t *testing.T) { ...@@ -311,9 +327,11 @@ func TestSetupParseWithCAUrl(t *testing.T) {
ca 1 2 ca 1 2
}`, true, ""}, }`, true, ""},
} { } {
cfg := new(Config) certCache := &certificateCache{cache: make(map[string]Certificate)}
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg }) RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
c := caddy.NewTestController("", caseData.params) c := caddy.NewTestController("", caseData.params)
c.Set(CertCacheInstStorageKey, certCache)
err := setupTLS(c) err := setupTLS(c)
if caseData.expectedErr { if caseData.expectedErr {
if err == nil { if err == nil {
...@@ -335,9 +353,11 @@ func TestSetupParseWithKeyType(t *testing.T) { ...@@ -335,9 +353,11 @@ func TestSetupParseWithKeyType(t *testing.T) {
params := `tls { params := `tls {
key_type p384 key_type p384
}` }`
cfg := new(Config) certCache := &certificateCache{cache: make(map[string]Certificate)}
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg }) RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
c := caddy.NewTestController("", params) c := caddy.NewTestController("", params)
c.Set(CertCacheInstStorageKey, certCache)
err := setupTLS(c) err := setupTLS(c)
if err != nil { if err != nil {
...@@ -353,9 +373,11 @@ func TestSetupParseWithCurves(t *testing.T) { ...@@ -353,9 +373,11 @@ func TestSetupParseWithCurves(t *testing.T) {
params := `tls { params := `tls {
curves x25519 p256 p384 p521 curves x25519 p256 p384 p521
}` }`
cfg := new(Config) certCache := &certificateCache{cache: make(map[string]Certificate)}
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg }) RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
c := caddy.NewTestController("", params) c := caddy.NewTestController("", params)
c.Set(CertCacheInstStorageKey, certCache)
err := setupTLS(c) err := setupTLS(c)
if err != nil { if err != nil {
...@@ -380,9 +402,11 @@ func TestSetupParseWithOneTLSProtocol(t *testing.T) { ...@@ -380,9 +402,11 @@ func TestSetupParseWithOneTLSProtocol(t *testing.T) {
params := `tls { params := `tls {
protocols tls1.2 protocols tls1.2
}` }`
cfg := new(Config) certCache := &certificateCache{cache: make(map[string]Certificate)}
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg }) RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
c := caddy.NewTestController("", params) c := caddy.NewTestController("", params)
c.Set(CertCacheInstStorageKey, certCache)
err := setupTLS(c) err := setupTLS(c)
if err != nil { if err != nil {
......
...@@ -107,6 +107,10 @@ type Storage interface { ...@@ -107,6 +107,10 @@ type Storage interface {
// in StoreUser. The result is an empty string if there are no // in StoreUser. The result is an empty string if there are no
// persisted users in storage. // persisted users in storage.
MostRecentUserEmail() string MostRecentUserEmail() string
// Locker is necessary because synchronizing certificate maintenance
// depends on how storage is implemented.
Locker
} }
// ErrNotExist is returned by Storage implementations when // ErrNotExist is returned by Storage implementations when
......
// Copyright 2015 Light Code Labs, LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package caddytls
import (
"fmt"
"sync"
)
var _ Locker = &syncLock{}
type syncLock struct {
nameLocks map[string]*sync.WaitGroup
nameLocksMu sync.Mutex
}
// TryLock attempts to get a lock for name, otherwise it returns
// a Waiter value to wait until the other process is finished.
func (s *syncLock) TryLock(name string) (Waiter, error) {
s.nameLocksMu.Lock()
defer s.nameLocksMu.Unlock()
wg, ok := s.nameLocks[name]
if ok {
// lock already obtained, let caller wait on it
return wg, nil
}
// caller gets lock
wg = new(sync.WaitGroup)
wg.Add(1)
s.nameLocks[name] = wg
return nil, nil
}
// Unlock unlocks name.
func (s *syncLock) Unlock(name string) error {
s.nameLocksMu.Lock()
defer s.nameLocksMu.Unlock()
wg, ok := s.nameLocks[name]
if !ok {
return fmt.Errorf("FileStorage: no lock to release for %s", name)
}
wg.Done()
delete(s.nameLocks, name)
return nil
}
...@@ -88,30 +88,38 @@ func Revoke(host string) error { ...@@ -88,30 +88,38 @@ func Revoke(host string) error {
return client.Revoke(host) return client.Revoke(host)
} }
// tlsSniSolver is a type that can solve tls-sni challenges using // tlsSNISolver is a type that can solve TLS-SNI challenges using
// an existing listener and our custom, in-memory certificate cache. // an existing listener and our custom, in-memory certificate cache.
type tlsSniSolver struct{} type tlsSNISolver struct {
certCache *certificateCache
}
// Present adds the challenge certificate to the cache. // Present adds the challenge certificate to the cache.
func (s tlsSniSolver) Present(domain, token, keyAuth string) error { func (s tlsSNISolver) Present(domain, token, keyAuth string) error {
cert, acmeDomain, err := acme.TLSSNI01ChallengeCert(keyAuth) cert, acmeDomain, err := acme.TLSSNI01ChallengeCert(keyAuth)
if err != nil { if err != nil {
return err return err
} }
cacheCertificate(Certificate{ certHash := hashCertificateChain(cert.Certificate)
s.certCache.Lock()
s.certCache.cache[acmeDomain] = Certificate{
Certificate: cert, Certificate: cert,
Names: []string{acmeDomain}, Names: []string{acmeDomain},
}) Hash: certHash, // perhaps not necesssary
}
s.certCache.Unlock()
return nil return nil
} }
// CleanUp removes the challenge certificate from the cache. // CleanUp removes the challenge certificate from the cache.
func (s tlsSniSolver) CleanUp(domain, token, keyAuth string) error { func (s tlsSNISolver) CleanUp(domain, token, keyAuth string) error {
_, acmeDomain, err := acme.TLSSNI01ChallengeCert(keyAuth) _, acmeDomain, err := acme.TLSSNI01ChallengeCert(keyAuth)
if err != nil { if err != nil {
return err return err
} }
uncacheCertificate(acmeDomain) s.certCache.Lock()
delete(s.certCache.cache, acmeDomain)
s.certCache.Unlock()
return nil return nil
} }
......
...@@ -103,6 +103,20 @@ func (c *Controller) Context() Context { ...@@ -103,6 +103,20 @@ func (c *Controller) Context() Context {
return c.instance.context return c.instance.context
} }
// Get safely gets a value from the Instance's storage.
func (c *Controller) Get(key interface{}) interface{} {
c.instance.StorageMu.RLock()
defer c.instance.StorageMu.RUnlock()
return c.instance.Storage[key]
}
// Set safely sets a value on the Instance's storage.
func (c *Controller) Set(key, val interface{}) {
c.instance.StorageMu.Lock()
c.instance.Storage[key] = val
c.instance.StorageMu.Unlock()
}
// NewTestController creates a new Controller for // NewTestController creates a new Controller for
// the server type and input specified. The filename // the server type and input specified. The filename
// is "Testfile". If the server type is not empty and // is "Testfile". If the server type is not empty and
...@@ -113,12 +127,12 @@ func (c *Controller) Context() Context { ...@@ -113,12 +127,12 @@ func (c *Controller) Context() Context {
// Used only for testing, but exported so plugins can // Used only for testing, but exported so plugins can
// use this for convenience. // use this for convenience.
func NewTestController(serverType, input string) *Controller { func NewTestController(serverType, input string) *Controller {
var ctx Context testInst := &Instance{serverType: serverType, Storage: make(map[interface{}]interface{})}
if stype, err := getServerType(serverType); err == nil { if stype, err := getServerType(serverType); err == nil {
ctx = stype.NewContext() testInst.context = stype.NewContext(testInst)
} }
return &Controller{ return &Controller{
instance: &Instance{serverType: serverType, context: ctx}, instance: testInst,
Dispenser: caddyfile.NewDispenser("Testfile", strings.NewReader(input)), Dispenser: caddyfile.NewDispenser("Testfile", strings.NewReader(input)),
OncePerServerBlock: func(f func() error) error { return f() }, OncePerServerBlock: func(f func() error) error { return f() },
} }
......
...@@ -195,7 +195,7 @@ type ServerType struct { ...@@ -195,7 +195,7 @@ type ServerType struct {
// startup phases before this one. It's a way to keep // startup phases before this one. It's a way to keep
// each set of server instances separate and to reduce // each set of server instances separate and to reduce
// the amount of global state you need. // the amount of global state you need.
NewContext func() Context NewContext func(inst *Instance) Context
} }
// Plugin is a type which holds information about a plugin. // Plugin is a type which holds information about a plugin.
...@@ -387,6 +387,14 @@ func loadCaddyfileInput(serverType string) (Input, error) { ...@@ -387,6 +387,14 @@ func loadCaddyfileInput(serverType string) (Input, error) {
return caddyfileToUse, nil return caddyfileToUse, nil
} }
// OnProcessExit is a list of functions to run when the process
// exits -- they are ONLY for cleanup and should not block,
// return errors, or do anything fancy. They will be run with
// every signal, even if "shutdown callbacks" are not executed.
// This variable must only be modified in the main goroutine
// from init() functions.
var OnProcessExit []func()
// caddyfileLoader pairs the name of a loader to the loader. // caddyfileLoader pairs the name of a loader to the loader.
type caddyfileLoader struct { type caddyfileLoader struct {
name string name string
......
...@@ -44,16 +44,17 @@ func trapSignalsCrossPlatform() { ...@@ -44,16 +44,17 @@ func trapSignalsCrossPlatform() {
if i > 0 { if i > 0 {
log.Println("[INFO] SIGINT: Force quit") log.Println("[INFO] SIGINT: Force quit")
if PidFile != "" { for _, f := range OnProcessExit {
os.Remove(PidFile) f() // important cleanup actions only
} }
os.Exit(2) os.Exit(2)
} }
log.Println("[INFO] SIGINT: Shutting down") log.Println("[INFO] SIGINT: Shutting down")
if PidFile != "" { // important cleanup actions before shutdown callbacks
os.Remove(PidFile) for _, f := range OnProcessExit {
f()
} }
go func() { go func() {
......
...@@ -33,22 +33,22 @@ func trapSignalsPosix() { ...@@ -33,22 +33,22 @@ func trapSignalsPosix() {
switch sig { switch sig {
case syscall.SIGQUIT: case syscall.SIGQUIT:
log.Println("[INFO] SIGQUIT: Quitting process immediately") log.Println("[INFO] SIGQUIT: Quitting process immediately")
if PidFile != "" { for _, f := range OnProcessExit {
os.Remove(PidFile) f() // only perform important cleanup actions
} }
os.Exit(0) os.Exit(0)
case syscall.SIGTERM: case syscall.SIGTERM:
log.Println("[INFO] SIGTERM: Shutting down servers then terminating") log.Println("[INFO] SIGTERM: Shutting down servers then terminating")
exitCode := executeShutdownCallbacks("SIGTERM") exitCode := executeShutdownCallbacks("SIGTERM")
for _, f := range OnProcessExit {
f() // only perform important cleanup actions
}
err := Stop() err := Stop()
if err != nil { if err != nil {
log.Printf("[ERROR] SIGTERM stop: %v", err) log.Printf("[ERROR] SIGTERM stop: %v", err)
exitCode = 3 exitCode = 3
} }
if PidFile != "" {
os.Remove(PidFile)
}
os.Exit(exitCode) os.Exit(exitCode)
case syscall.SIGUSR1: case syscall.SIGUSR1:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment