Commit a7aeb979 authored by Matthew Holt's avatar Matthew Holt

caddytls: Use IP address to find config; vendor: update certmagic

Closes #2356
parent 771dcf3d
...@@ -405,6 +405,8 @@ func groupSiteConfigsByListenAddr(configs []*SiteConfig) (map[string][]*SiteConf ...@@ -405,6 +405,8 @@ func groupSiteConfigsByListenAddr(configs []*SiteConfig) (map[string][]*SiteConf
// parts of an address. The component parts may be // parts of an address. The component parts may be
// updated to the correct values as setup proceeds, // updated to the correct values as setup proceeds,
// but the original value should never be changed. // but the original value should never be changed.
//
// The Host field must be in a normalized form.
type Address struct { type Address struct {
Original, Scheme, Host, Port, Path string Original, Scheme, Host, Port, Path string
} }
...@@ -453,10 +455,17 @@ func (a Address) Normalize() Address { ...@@ -453,10 +455,17 @@ func (a Address) Normalize() Address {
if !CaseSensitivePath { if !CaseSensitivePath {
path = strings.ToLower(path) path = strings.ToLower(path)
} }
// ensure host is normalized if it's an IP address
host := a.Host
if ip := net.ParseIP(host); ip != nil {
host = ip.String()
}
return Address{ return Address{
Original: a.Original, Original: a.Original,
Scheme: strings.ToLower(a.Scheme), Scheme: strings.ToLower(a.Scheme),
Host: strings.ToLower(a.Host), Host: strings.ToLower(host),
Port: a.Port, Port: a.Port,
Path: path, Path: path,
} }
......
...@@ -33,7 +33,9 @@ type Config struct { ...@@ -33,7 +33,9 @@ type Config struct {
// The hostname or class of hostnames this config is // The hostname or class of hostnames this config is
// designated for; can contain wildcard characters // designated for; can contain wildcard characters
// according to RFC 6125 §6.4.3 - this field MUST // according to RFC 6125 §6.4.3 - this field MUST
// be set in order for things to work as expected // be set in order for things to work as expected,
// must be normalized, and if an IP address, must
// be normalized
Hostname string Hostname string
// Whether TLS is enabled // Whether TLS is enabled
...@@ -272,7 +274,7 @@ func MakeTLSConfig(configs []*Config) (*tls.Config, error) { ...@@ -272,7 +274,7 @@ func MakeTLSConfig(configs []*Config) (*tls.Config, error) {
// A tls.Config must have Certificates or GetCertificate // A tls.Config must have Certificates or GetCertificate
// set, in order to be accepted by tls.Listen and quic.Listen. // set, in order to be accepted by tls.Listen and quic.Listen.
// TODO: remove this once the standard library allows a tls.Config with // TODO: remove this once the standard library allows a tls.Config with
// only GetConfigForClient set. // only GetConfigForClient set. https://github.com/mholt/caddy/pull/2404
GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) { GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
return nil, fmt.Errorf("all certificates configured via GetConfigForClient") return nil, fmt.Errorf("all certificates configured via GetConfigForClient")
}, },
......
...@@ -20,6 +20,7 @@ import ( ...@@ -20,6 +20,7 @@ import (
"strings" "strings"
"github.com/mholt/caddy/telemetry" "github.com/mholt/caddy/telemetry"
"github.com/mholt/certmagic"
) )
// configGroup is a type that keys configs by their hostname // configGroup is a type that keys configs by their hostname
...@@ -27,7 +28,7 @@ import ( ...@@ -27,7 +28,7 @@ import (
// method to get a config by matching its hostname). // method to get a config by matching its hostname).
type configGroup map[string]*Config type configGroup map[string]*Config
// getConfig gets the config by the first key match for name. // getConfig gets the config by the first key match for hello.
// In other words, "sub.foo.bar" will get the config for "*.foo.bar" // In other words, "sub.foo.bar" will get the config for "*.foo.bar"
// if that is the closest match. If no match is found, the first // if that is the closest match. If no match is found, the first
// (random) config will be loaded, which will defer any TLS alerts // (random) config will be loaded, which will defer any TLS alerts
...@@ -36,8 +37,8 @@ type configGroup map[string]*Config ...@@ -36,8 +37,8 @@ type configGroup map[string]*Config
// //
// This function follows nearly the same logic to lookup // This function follows nearly the same logic to lookup
// a hostname as the getCertificate function uses. // a hostname as the getCertificate function uses.
func (cg configGroup) getConfig(name string) *Config { func (cg configGroup) getConfig(hello *tls.ClientHelloInfo) *Config {
name = strings.ToLower(name) name := certmagic.CertNameFromClientHello(hello)
// exact match? great, let's use it // exact match? great, let's use it
if config, ok := cg[name]; ok { if config, ok := cg[name]; ok {
...@@ -72,7 +73,7 @@ func (cg configGroup) getConfig(name string) *Config { ...@@ -72,7 +73,7 @@ func (cg configGroup) getConfig(name string) *Config {
// //
// This method is safe for use as a tls.Config.GetConfigForClient callback. // This method is safe for use as a tls.Config.GetConfigForClient callback.
func (cg configGroup) GetConfigForClient(clientHello *tls.ClientHelloInfo) (*tls.Config, error) { func (cg configGroup) GetConfigForClient(clientHello *tls.ClientHelloInfo) (*tls.Config, error) {
config := cg.getConfig(clientHello.ServerName) config := cg.getConfig(clientHello)
if config != nil { if config != nil {
return config.tlsConfig, nil return config.tlsConfig, nil
} }
......
...@@ -369,13 +369,10 @@ func (cfg *Config) preObtainOrRenewChecks(name string, allowPrompts bool) (bool, ...@@ -369,13 +369,10 @@ func (cfg *Config) preObtainOrRenewChecks(name string, allowPrompts bool) (bool,
return true, nil return true, nil
} }
if cfg.Email == "" { err := cfg.getEmail(allowPrompts)
var err error
cfg.Email, err = cfg.getEmail(allowPrompts)
if err != nil { if err != nil {
return false, err return false, err
} }
}
return false, nil return false, nil
} }
......
...@@ -19,6 +19,7 @@ import ( ...@@ -19,6 +19,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"log" "log"
"net"
"strings" "strings"
"sync" "sync"
"sync/atomic" "sync/atomic"
...@@ -66,18 +67,24 @@ func (cfg *Config) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certif ...@@ -66,18 +67,24 @@ func (cfg *Config) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certif
} }
} }
wrapped := wrappedClientHelloInfo{
ClientHelloInfo: clientHello,
serverNameOrIP: CertNameFromClientHello(clientHello),
}
// get the certificate and serve it up // get the certificate and serve it up
cert, err := cfg.getCertDuringHandshake(strings.ToLower(clientHello.ServerName), true, true) cert, err := cfg.getCertDuringHandshake(wrapped, true, true)
if err == nil && cfg.OnEvent != nil { if err == nil && cfg.OnEvent != nil {
cfg.OnEvent("tls_handshake_completed", clientHello) cfg.OnEvent("tls_handshake_completed", clientHello)
} }
return &cert.Certificate, err return &cert.Certificate, err
} }
// getCertificate gets a certificate that matches name (a server name) // getCertificate gets a certificate that matches name from the in-memory
// from the in-memory cache, according to the lookup table associated with // cache, according to the lookup table associated with cfg. The lookup then
// cfg. The lookup then points to a certificate in the Instance certificate // points to a certificate in the Instance certificate cache.
// cache. //
// The name is expected to already be normalized (e.g. lowercased).
// //
// If there is no exact match for name, it will be checked against names of // If there is no exact match for name, it will be checked against names of
// the form '*.example.com' (wildcard certificates) according to RFC 6125. // the form '*.example.com' (wildcard certificates) according to RFC 6125.
...@@ -93,11 +100,6 @@ func (cfg *Config) getCertificate(name string) (cert Certificate, matched, defau ...@@ -93,11 +100,6 @@ func (cfg *Config) getCertificate(name string) (cert Certificate, matched, defau
var certKey string var certKey string
var ok bool var ok bool
// Not going to trim trailing dots here since RFC 3546 says,
// "The hostname is represented ... without a trailing dot."
// Just normalize to lowercase.
name = strings.ToLower(name)
cfg.certCache.mu.RLock() cfg.certCache.mu.RLock()
defer cfg.certCache.mu.RUnlock() defer cfg.certCache.mu.RUnlock()
...@@ -123,10 +125,11 @@ func (cfg *Config) getCertificate(name string) (cert Certificate, matched, defau ...@@ -123,10 +125,11 @@ func (cfg *Config) getCertificate(name string) (cert Certificate, matched, defau
// check the certCache directly to see if the SNI name is // check the certCache directly to see if the SNI name is
// already the key of the certificate it wants; this implies // already the key of the certificate it wants; this implies
// that the SNI can contain the hash of a specific cert // that the SNI can contain the hash of a specific cert
// (chain) it wants and we will still be able to serveit up // (chain) it wants and we will still be able to serve it up
// (this behavior, by the way, could be controversial as to // (this behavior, by the way, could be controversial as to
// whether it complies with RFC 6066 about SNI, but I think // whether it complies with RFC 6066 about SNI, but I think
// it does, soooo...) // it does, soooo...)
// (this is how we solved the former ACME TLS-SNI challenge)
if directCert, ok := cfg.certCache.cache[name]; ok { if directCert, ok := cfg.certCache.cache[name]; ok {
cert = directCert cert = directCert
matched = true matched = true
...@@ -147,9 +150,9 @@ func (cfg *Config) getCertificate(name string) (cert Certificate, matched, defau ...@@ -147,9 +150,9 @@ func (cfg *Config) getCertificate(name string) (cert Certificate, matched, defau
return return
} }
// getCertDuringHandshake will get a certificate for name. It first tries // getCertDuringHandshake will get a certificate for hello. It first tries
// the in-memory cache. If no certificate for name is in the cache, the // the in-memory cache. If no certificate for hello is in the cache, the
// config most closely corresponding to name will be loaded. If that config // config most closely corresponding to hello will be loaded. If that config
// allows it (OnDemand==true) and if loadIfNecessary == true, it goes to disk // allows it (OnDemand==true) and if loadIfNecessary == true, it goes to disk
// to load it into the cache and serve it. If it's not on disk and if // to load it into the cache and serve it. If it's not on disk and if
// obtainIfNecessary == true, the certificate will be obtained from the CA, // obtainIfNecessary == true, the certificate will be obtained from the CA,
...@@ -158,9 +161,9 @@ func (cfg *Config) getCertificate(name string) (cert Certificate, matched, defau ...@@ -158,9 +161,9 @@ func (cfg *Config) getCertificate(name string) (cert Certificate, matched, defau
// certificate is available. // certificate is available.
// //
// This function is safe for concurrent use. // This function is safe for concurrent use.
func (cfg *Config) getCertDuringHandshake(name string, loadIfNecessary, obtainIfNecessary bool) (Certificate, error) { func (cfg *Config) getCertDuringHandshake(hello wrappedClientHelloInfo, loadIfNecessary, obtainIfNecessary bool) (Certificate, error) {
// First check our in-memory cache to see if we've already loaded it // First check our in-memory cache to see if we've already loaded it
cert, matched, defaulted := cfg.getCertificate(name) cert, matched, defaulted := cfg.getCertificate(hello.serverNameOrIP)
if matched { if matched {
return cert, nil return cert, nil
} }
...@@ -169,32 +172,30 @@ func (cfg *Config) getCertDuringHandshake(name string, loadIfNecessary, obtainIf ...@@ -169,32 +172,30 @@ func (cfg *Config) getCertDuringHandshake(name string, loadIfNecessary, obtainIf
// obtain a needed certificate // obtain a needed certificate
if cfg.OnDemand != nil && loadIfNecessary { if cfg.OnDemand != nil && loadIfNecessary {
// Then check to see if we have one on disk // Then check to see if we have one on disk
loadedCert, err := cfg.CacheManagedCertificate(name) loadedCert, err := cfg.CacheManagedCertificate(hello.serverNameOrIP)
if err == nil { if err == nil {
loadedCert, err = cfg.handshakeMaintenance(name, loadedCert) loadedCert, err = cfg.handshakeMaintenance(hello, loadedCert)
if err != nil { if err != nil {
log.Printf("[ERROR] Maintaining newly-loaded certificate for %s: %v", name, err) log.Printf("[ERROR] Maintaining newly-loaded certificate for %s: %v", hello.serverNameOrIP, err)
} }
return loadedCert, nil return loadedCert, nil
} }
if obtainIfNecessary { if obtainIfNecessary {
// By this point, we need to ask the CA for a certificate // By this point, we need to ask the CA for a certificate
name = strings.ToLower(name)
// Make sure the certificate should be obtained based on config // Make sure the certificate should be obtained based on config
err := cfg.checkIfCertShouldBeObtained(name) err := cfg.checkIfCertShouldBeObtained(hello.serverNameOrIP)
if err != nil { if err != nil {
return Certificate{}, err return Certificate{}, err
} }
// Name has to qualify for a certificate // Name has to qualify for a certificate
if !HostQualifies(name) { if !HostQualifies(hello.serverNameOrIP) {
return cert, fmt.Errorf("hostname '%s' does not qualify for certificate", name) return cert, fmt.Errorf("hostname '%s' does not qualify for certificate", hello.serverNameOrIP)
} }
// Obtain certificate from the CA // Obtain certificate from the CA
return cfg.obtainOnDemandCertificate(name) return cfg.obtainOnDemandCertificate(hello)
} }
} }
...@@ -203,7 +204,7 @@ func (cfg *Config) getCertDuringHandshake(name string, loadIfNecessary, obtainIf ...@@ -203,7 +204,7 @@ func (cfg *Config) getCertDuringHandshake(name string, loadIfNecessary, obtainIf
return cert, nil return cert, nil
} }
return Certificate{}, fmt.Errorf("no certificate available for %s", name) return Certificate{}, fmt.Errorf("no certificate available for %s", hello.serverNameOrIP)
} }
// checkIfCertShouldBeObtained checks to see if an on-demand tls certificate // checkIfCertShouldBeObtained checks to see if an on-demand tls certificate
...@@ -216,52 +217,52 @@ func (cfg *Config) checkIfCertShouldBeObtained(name string) error { ...@@ -216,52 +217,52 @@ func (cfg *Config) checkIfCertShouldBeObtained(name string) error {
return cfg.OnDemand.Allowed(name) return cfg.OnDemand.Allowed(name)
} }
// obtainOnDemandCertificate obtains a certificate for name for the given // obtainOnDemandCertificate obtains a certificate for hello.
// name. If another goroutine has already started obtaining a cert for // If another goroutine has already started obtaining a cert for
// name, it will wait and use what the other goroutine obtained. // hello, it will wait and use what the other goroutine obtained.
// //
// This function is safe for use by multiple concurrent goroutines. // This function is safe for use by multiple concurrent goroutines.
func (cfg *Config) obtainOnDemandCertificate(name string) (Certificate, error) { func (cfg *Config) obtainOnDemandCertificate(hello wrappedClientHelloInfo) (Certificate, error) {
// We must protect this process from happening concurrently, so synchronize. // We must protect this process from happening concurrently, so synchronize.
obtainCertWaitChansMu.Lock() obtainCertWaitChansMu.Lock()
wait, ok := obtainCertWaitChans[name] wait, ok := obtainCertWaitChans[hello.serverNameOrIP]
if ok { if ok {
// lucky us -- another goroutine is already obtaining the certificate. // lucky us -- another goroutine is already obtaining the certificate.
// wait for it to finish obtaining the cert and then we'll use it. // wait for it to finish obtaining the cert and then we'll use it.
obtainCertWaitChansMu.Unlock() obtainCertWaitChansMu.Unlock()
<-wait <-wait
return cfg.getCertDuringHandshake(name, true, false) return cfg.getCertDuringHandshake(hello, true, false)
} }
// looks like it's up to us to do all the work and obtain the cert. // looks like it's up to us to do all the work and obtain the cert.
// make a chan others can wait on if needed // make a chan others can wait on if needed
wait = make(chan struct{}) wait = make(chan struct{})
obtainCertWaitChans[name] = wait obtainCertWaitChans[hello.serverNameOrIP] = wait
obtainCertWaitChansMu.Unlock() obtainCertWaitChansMu.Unlock()
// obtain the certificate // obtain the certificate
log.Printf("[INFO] Obtaining new certificate for %s", name) log.Printf("[INFO] Obtaining new certificate for %s", hello.serverNameOrIP)
err := cfg.ObtainCert(name, false) err := cfg.ObtainCert(hello.serverNameOrIP, false)
// immediately unblock anyone waiting for it; doing this in // immediately unblock anyone waiting for it; doing this in
// a defer would risk deadlock because of the recursive call // a defer would risk deadlock because of the recursive call
// to getCertDuringHandshake below when we return! // to getCertDuringHandshake below when we return!
obtainCertWaitChansMu.Lock() obtainCertWaitChansMu.Lock()
close(wait) close(wait)
delete(obtainCertWaitChans, name) delete(obtainCertWaitChans, hello.serverNameOrIP)
obtainCertWaitChansMu.Unlock() obtainCertWaitChansMu.Unlock()
if err != nil { if err != nil {
// Failed to solve challenge, so don't allow another on-demand // Failed to solve challenge, so don't allow another on-demand
// issue for this name to be attempted for a little while. // issue for this name to be attempted for a little while.
failedIssuanceMu.Lock() failedIssuanceMu.Lock()
failedIssuance[name] = time.Now() failedIssuance[hello.serverNameOrIP] = time.Now()
go func(name string) { go func(name string) {
time.Sleep(5 * time.Minute) time.Sleep(5 * time.Minute)
failedIssuanceMu.Lock() failedIssuanceMu.Lock()
delete(failedIssuance, name) delete(failedIssuance, name)
failedIssuanceMu.Unlock() failedIssuanceMu.Unlock()
}(name) }(hello.serverNameOrIP)
failedIssuanceMu.Unlock() failedIssuanceMu.Unlock()
return Certificate{}, err return Certificate{}, err
} }
...@@ -273,19 +274,18 @@ func (cfg *Config) obtainOnDemandCertificate(name string) (Certificate, error) { ...@@ -273,19 +274,18 @@ func (cfg *Config) obtainOnDemandCertificate(name string) (Certificate, error) {
lastIssueTimeMu.Unlock() lastIssueTimeMu.Unlock()
// certificate is already on disk; now just start over to load it and serve it // certificate is already on disk; now just start over to load it and serve it
return cfg.getCertDuringHandshake(name, true, false) return cfg.getCertDuringHandshake(hello, true, false)
} }
// handshakeMaintenance performs a check on cert for expiration and OCSP // handshakeMaintenance performs a check on cert for expiration and OCSP validity.
// validity.
// //
// This function is safe for use by multiple concurrent goroutines. // This function is safe for use by multiple concurrent goroutines.
func (cfg *Config) handshakeMaintenance(name string, cert Certificate) (Certificate, error) { func (cfg *Config) handshakeMaintenance(hello wrappedClientHelloInfo, cert Certificate) (Certificate, error) {
// Check cert expiration // Check cert expiration
timeLeft := cert.NotAfter.Sub(time.Now().UTC()) timeLeft := cert.NotAfter.Sub(time.Now().UTC())
if timeLeft < cfg.RenewDurationBefore { if timeLeft < cfg.RenewDurationBefore {
log.Printf("[INFO] Certificate for %v expires in %v; attempting renewal", cert.Names, timeLeft) log.Printf("[INFO] Certificate for %v expires in %v; attempting renewal", cert.Names, timeLeft)
return cfg.renewDynamicCertificate(name, cert) return cfg.renewDynamicCertificate(hello, cert)
} }
// Check OCSP staple validity // Check OCSP staple validity
...@@ -296,7 +296,7 @@ func (cfg *Config) handshakeMaintenance(name string, cert Certificate) (Certific ...@@ -296,7 +296,7 @@ func (cfg *Config) handshakeMaintenance(name string, cert Certificate) (Certific
if err != nil { if err != nil {
// An error with OCSP stapling is not the end of the world, and in fact, is // An error with OCSP stapling is not the end of the world, and in fact, is
// quite common considering not all certs have issuer URLs that support it. // quite common considering not all certs have issuer URLs that support it.
log.Printf("[ERROR] Getting OCSP for %s: %v", name, err) log.Printf("[ERROR] Getting OCSP for %s: %v", hello.serverNameOrIP, err)
} }
cfg.certCache.mu.Lock() cfg.certCache.mu.Lock()
cfg.certCache.cache[cert.Hash] = cert cfg.certCache.cache[cert.Hash] = cert
...@@ -313,37 +313,38 @@ func (cfg *Config) handshakeMaintenance(name string, cert Certificate) (Certific ...@@ -313,37 +313,38 @@ func (cfg *Config) handshakeMaintenance(name string, cert Certificate) (Certific
// ClientHello. // ClientHello.
// //
// This function is safe for use by multiple concurrent goroutines. // This function is safe for use by multiple concurrent goroutines.
func (cfg *Config) renewDynamicCertificate(name string, currentCert Certificate) (Certificate, error) { func (cfg *Config) renewDynamicCertificate(hello wrappedClientHelloInfo, currentCert Certificate) (Certificate, error) {
obtainCertWaitChansMu.Lock() obtainCertWaitChansMu.Lock()
wait, ok := obtainCertWaitChans[name] wait, ok := obtainCertWaitChans[hello.serverNameOrIP]
if ok { if ok {
// lucky us -- another goroutine is already renewing the certificate. // lucky us -- another goroutine is already renewing the certificate.
// wait for it to finish, then we'll use the new one. // wait for it to finish, then we'll use the new one.
obtainCertWaitChansMu.Unlock() obtainCertWaitChansMu.Unlock()
<-wait <-wait
return cfg.getCertDuringHandshake(name, true, false) return cfg.getCertDuringHandshake(hello, true, false)
} }
// looks like it's up to us to do all the work and renew the cert // looks like it's up to us to do all the work and renew the cert
wait = make(chan struct{}) wait = make(chan struct{})
obtainCertWaitChans[name] = wait obtainCertWaitChans[hello.serverNameOrIP] = wait
obtainCertWaitChansMu.Unlock() obtainCertWaitChansMu.Unlock()
// renew and reload the certificate // renew and reload the certificate
log.Printf("[INFO] Renewing certificate for %s", name) log.Printf("[INFO] Renewing certificate for %s", hello.serverNameOrIP)
err := cfg.RenewCert(name, false) err := cfg.RenewCert(hello.serverNameOrIP, false)
if err == nil { if err == nil {
// even though the recursive nature of the dynamic cert loading // even though the recursive nature of the dynamic cert loading
// would just call this function anyway, we do it here to // would just call this function anyway, we do it here to
// make the replacement as atomic as possible. // make the replacement as atomic as possible.
newCert, err := currentCert.configs[0].CacheManagedCertificate(name) newCert, err := currentCert.configs[0].CacheManagedCertificate(hello.serverNameOrIP)
if err != nil { if err != nil {
log.Printf("[ERROR] loading renewed certificate for %s: %v", name, err) log.Printf("[ERROR] loading renewed certificate for %s: %v", hello.serverNameOrIP, err)
} else { } else {
// replace the old certificate with the new one // replace the old certificate with the new one
err = cfg.certCache.replaceCertificate(currentCert, newCert) err = cfg.certCache.replaceCertificate(currentCert, newCert)
if err != nil { if err != nil {
log.Printf("[ERROR] Replacing certificate for %s: %v", name, err) log.Printf("[ERROR] Replacing certificate for %s: %v", hello.serverNameOrIP, err)
} }
} }
} }
...@@ -353,14 +354,14 @@ func (cfg *Config) renewDynamicCertificate(name string, currentCert Certificate) ...@@ -353,14 +354,14 @@ func (cfg *Config) renewDynamicCertificate(name string, currentCert Certificate)
// to getCertDuringHandshake below when we return! // to getCertDuringHandshake below when we return!
obtainCertWaitChansMu.Lock() obtainCertWaitChansMu.Lock()
close(wait) close(wait)
delete(obtainCertWaitChans, name) delete(obtainCertWaitChans, hello.serverNameOrIP)
obtainCertWaitChansMu.Unlock() obtainCertWaitChansMu.Unlock()
if err != nil { if err != nil {
return Certificate{}, err return Certificate{}, err
} }
return cfg.getCertDuringHandshake(name, true, false) return cfg.getCertDuringHandshake(hello, true, false)
} }
// tryDistributedChallengeSolver is to be called when the clientHello pertains to // tryDistributedChallengeSolver is to be called when the clientHello pertains to
...@@ -395,6 +396,38 @@ func (cfg *Config) tryDistributedChallengeSolver(clientHello *tls.ClientHelloInf ...@@ -395,6 +396,38 @@ func (cfg *Config) tryDistributedChallengeSolver(clientHello *tls.ClientHelloInf
return Certificate{Certificate: *cert}, true, nil return Certificate{Certificate: *cert}, true, nil
} }
// CertNameFromClientHello returns a normalized name for which to
// look up a certificate given this ClientHelloInfo. If the client
// did not send a ServerName value, the connection's local IP is
// assumed.
func CertNameFromClientHello(hello *tls.ClientHelloInfo) string {
// Not going to trim trailing dots here since RFC 3546 says,
// "The hostname is represented ... without a trailing dot."
// Just normalize to lowercase and remove any leading or
// trailing whitespace n case the hello was sloppily made
name := strings.ToLower(strings.TrimSpace(hello.ServerName))
// if SNI is not set, assume IP of listener
if name == "" && hello.Conn != nil {
addr := hello.Conn.LocalAddr().String()
ip, _, err := net.SplitHostPort(addr)
if err == nil {
name = ip
}
}
return name
}
// wrappedClientHelloInfo is a type that allows us to
// attach a name with which to look for a certificate
// to a given ClientHelloInfo, since not all clients
// use SNI and some self-signed certificates use IP.
type wrappedClientHelloInfo struct {
*tls.ClientHelloInfo
serverNameOrIP string
}
// obtainCertWaitChans is used to coordinate obtaining certs for each hostname. // obtainCertWaitChans is used to coordinate obtaining certs for each hostname.
var obtainCertWaitChans = make(map[string]chan struct{}) var obtainCertWaitChans = make(map[string]chan struct{})
var obtainCertWaitChansMu sync.Mutex var obtainCertWaitChansMu sync.Mutex
...@@ -23,12 +23,13 @@ import ( ...@@ -23,12 +23,13 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"net/http"
"os" "os"
"path" "path"
"sort" "sort"
"strings" "strings"
"github.com/xenolf/lego/lego" "github.com/xenolf/lego/acme"
"github.com/xenolf/lego/registration" "github.com/xenolf/lego/registration"
) )
...@@ -71,81 +72,74 @@ func (cfg *Config) newUser(email string) (user, error) { ...@@ -71,81 +72,74 @@ func (cfg *Config) newUser(email string) (user, error) {
// getEmail does everything it can to obtain an email address // getEmail does everything it can to obtain an email address
// from the user within the scope of memory and storage to use // from the user within the scope of memory and storage to use
// for ACME TLS. If it cannot get an email address, it returns // for ACME TLS. If it cannot get an email address, it does nothing
// empty string. (If user is present, it will warn the user of // (If user is prompted, it will warn the user of
// the consequences of an empty email.) This function MAY prompt // the consequences of an empty email.) This function MAY prompt
// the user for input. If userPresent is false, the operator // the user for input. If allowPrompts is false, the user
// will NOT be prompted and an empty email may be returned. // will NOT be prompted and an empty email may be returned.
// If the user is prompted, a new User will be created and func (cfg *Config) getEmail(allowPrompts bool) error {
// stored in storage according to the email address they
// provided (which might be blank).
func (cfg *Config) getEmail(userPresent bool) (string, error) {
// First try memory
leEmail := cfg.Email leEmail := cfg.Email
// First try package default email
if leEmail == "" { if leEmail == "" {
leEmail = Email leEmail = Email
} }
// Then try to get most recent user email from storage // Then try to get most recent user email from storage
if leEmail == "" { if leEmail == "" {
leEmail = cfg.mostRecentUserEmail() leEmail = cfg.mostRecentUserEmail()
cfg.Email = leEmail // save for next time
} }
if leEmail == "" && allowPrompts {
// Looks like there is no email address readily available, // Looks like there is no email address readily available,
// so we will have to ask the user if we can. // so we will have to ask the user if we can.
if leEmail == "" && userPresent { var err error
// evidently, no User data was present in storage; leEmail, err = cfg.promptUserForEmail()
// thus we must make a new User so that we can get
// the Terms of Service URL via our ACME client, phew!
user, err := cfg.newUser("")
if err != nil { if err != nil {
return "", err return err
}
cfg.Agreed = true
} }
// lower-casing the email is important for consistency
cfg.Email = strings.ToLower(leEmail)
return nil
}
// get the agreement URL func (cfg *Config) getAgreementURL() (string, error) {
agreementURL := agreementTestURL if agreementTestURL != "" {
if agreementURL == "" { return agreementTestURL, nil
// we call acme.NewClient directly because newACMEClient }
// would require that we already know the user's email
caURL := CA caURL := CA
if cfg.CA != "" { if cfg.CA != "" {
caURL = cfg.CA caURL = cfg.CA
} }
legoConfig := lego.NewConfig(user) response, err := http.Get(caURL)
legoConfig.CADirURL = caURL
legoConfig.UserAgent = UserAgent
tempClient, err := lego.NewClient(legoConfig)
if err != nil { if err != nil {
return "", fmt.Errorf("making ACME client to get ToS URL: %v", err) return "", err
} }
agreementURL = tempClient.GetToSURL() defer response.Body.Close()
var dir acme.Directory
err = json.NewDecoder(response.Body).Decode(&dir)
if err != nil {
return "", err
} }
return dir.Meta.TermsOfService, nil
}
func (cfg *Config) promptUserForEmail() (string, error) {
agreementURL, err := cfg.getAgreementURL()
if err != nil {
return "", fmt.Errorf("get Agreement URL: %v", err)
}
// prompt the user for an email address and terms agreement // prompt the user for an email address and terms agreement
reader := bufio.NewReader(stdin) reader := bufio.NewReader(stdin)
cfg.promptUserAgreement(agreementURL) cfg.promptUserAgreement(agreementURL)
fmt.Println("Please enter your email address to signify agreement and to be notified") fmt.Println("Please enter your email address to signify agreement and to be notified")
fmt.Println("in case of issues. You can leave it blank, but we don't recommend it.") fmt.Println("in case of issues. You can leave it blank, but we don't recommend it.")
fmt.Print(" Email address: ") fmt.Print(" Email address: ")
leEmail, err = reader.ReadString('\n') leEmail, err := reader.ReadString('\n')
if err != nil && err != io.EOF { if err != nil && err != io.EOF {
return "", fmt.Errorf("reading email address: %v", err) return "", fmt.Errorf("reading email address: %v", err)
} }
leEmail = strings.TrimSpace(leEmail) leEmail = strings.TrimSpace(leEmail)
cfg.Email = leEmail return leEmail, nil
cfg.Agreed = true
// save the new user to preserve this for next time
user.Email = leEmail
err = cfg.saveUser(user)
if err != nil {
return "", err
}
}
// lower-casing the email is important for consistency
return strings.ToLower(leEmail), nil
} }
// getUser loads the user with the given email from disk // getUser loads the user with the given email from disk
......
...@@ -138,7 +138,7 @@ ...@@ -138,7 +138,7 @@
"importpath": "github.com/mholt/certmagic", "importpath": "github.com/mholt/certmagic",
"repository": "https://github.com/mholt/certmagic", "repository": "https://github.com/mholt/certmagic",
"vcs": "git", "vcs": "git",
"revision": "01ffe8b3c7d611483ef936e90845329709721127", "revision": "c1d472b46046ee329c099086d689ada0c44d56b0",
"branch": "master", "branch": "master",
"notests": true "notests": true
}, },
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment