Commit 1cfd960f authored by Matthew Holt's avatar Matthew Holt

Bug fixes and other improvements to TLS functions

Now attempt to staple OCSP even for certs that don't have an existing staple (issue #605). "tls off" short-circuits tls setup function. Now we call getEmail() when setting up an acme.Client that does renewals, rather than making a new account with empty email address. Check certificate expiry every 12 hours, and OCSP every hour.
parent 2dba4432
...@@ -24,7 +24,7 @@ var certCacheMu sync.RWMutex ...@@ -24,7 +24,7 @@ var certCacheMu sync.RWMutex
// we can be more efficient by extracting the metadata once so it's // we can be more efficient by extracting the metadata once so it's
// just there, ready to use. // just there, ready to use.
type Certificate struct { type Certificate struct {
*tls.Certificate tls.Certificate
// Names is the list of names this certificate is written for. // Names is the list of names this certificate is written for.
// The first is the CommonName (if any), the rest are SAN. // The first is the CommonName (if any), the rest are SAN.
...@@ -170,7 +170,6 @@ func makeCertificate(certPEMBlock, keyPEMBlock []byte) (Certificate, error) { ...@@ -170,7 +170,6 @@ func makeCertificate(certPEMBlock, keyPEMBlock []byte) (Certificate, error) {
if len(tlsCert.Certificate) == 0 { if len(tlsCert.Certificate) == 0 {
return cert, errors.New("certificate is empty") return cert, errors.New("certificate is empty")
} }
cert.Certificate = &tlsCert
// Parse leaf certificate and extract relevant metadata // Parse leaf certificate and extract relevant metadata
leaf, err := x509.ParseCertificate(tlsCert.Certificate[0]) leaf, err := x509.ParseCertificate(tlsCert.Certificate[0])
...@@ -198,6 +197,7 @@ func makeCertificate(certPEMBlock, keyPEMBlock []byte) (Certificate, error) { ...@@ -198,6 +197,7 @@ func makeCertificate(certPEMBlock, keyPEMBlock []byte) (Certificate, error) {
cert.OCSP = ocspResp cert.OCSP = ocspResp
} }
cert.Certificate = tlsCert
return cert, nil return cert, nil
} }
...@@ -213,7 +213,9 @@ func makeCertificate(certPEMBlock, keyPEMBlock []byte) (Certificate, error) { ...@@ -213,7 +213,9 @@ func makeCertificate(certPEMBlock, keyPEMBlock []byte) (Certificate, error) {
func cacheCertificate(cert Certificate) { func cacheCertificate(cert Certificate) {
certCacheMu.Lock() certCacheMu.Lock()
if _, ok := certCache[""]; !ok { if _, ok := certCache[""]; !ok {
certCache[""] = cert // use as default // use as default
certCache[""] = cert
cert.Names = append(cert.Names, "")
} }
for len(certCache)+len(cert.Names) > 10000 { for len(certCache)+len(cert.Names) > 10000 {
// for simplicity, just remove random elements // for simplicity, just remove random elements
......
...@@ -23,7 +23,7 @@ import ( ...@@ -23,7 +23,7 @@ import (
// This function is safe for use as a tls.Config.GetCertificate callback. // This function is safe for use as a tls.Config.GetCertificate callback.
func GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { func GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
cert, err := getCertDuringHandshake(clientHello.ServerName, false, false) cert, err := getCertDuringHandshake(clientHello.ServerName, false, false)
return cert.Certificate, err return &cert.Certificate, err
} }
// GetOrObtainCertificate will get a certificate to satisfy clientHello, even // GetOrObtainCertificate will get a certificate to satisfy clientHello, even
...@@ -35,7 +35,7 @@ func GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) ...@@ -35,7 +35,7 @@ func GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error)
// This function is safe for use as a tls.Config.GetCertificate callback. // This function is safe for use as a tls.Config.GetCertificate callback.
func GetOrObtainCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { func GetOrObtainCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
cert, err := getCertDuringHandshake(clientHello.ServerName, true, true) cert, err := getCertDuringHandshake(clientHello.ServerName, true, true)
return cert.Certificate, err return &cert.Certificate, err
} }
// getCertDuringHandshake will get a certificate for name. It first tries // getCertDuringHandshake will get a certificate for name. It first tries
...@@ -122,8 +122,8 @@ func checkLimitsForObtainingNewCerts(name string) error { ...@@ -122,8 +122,8 @@ func checkLimitsForObtainingNewCerts(name string) error {
} }
// obtainOnDemandCertificate obtains a certificate for name for the given // obtainOnDemandCertificate obtains a certificate for name for the given
// clientHello. If another goroutine has already started obtaining a cert // name. If another goroutine has already started obtaining a cert for
// for name, it will wait and use what the other goroutine obtained. // name, it will wait and use what the other goroutine obtained.
// //
// This function is safe for use by multiple concurrent goroutines. // This function is safe for use by multiple concurrent goroutines.
func obtainOnDemandCertificate(name string) (Certificate, error) { func obtainOnDemandCertificate(name string) (Certificate, error) {
...@@ -248,7 +248,7 @@ func renewDynamicCertificate(name string) (Certificate, error) { ...@@ -248,7 +248,7 @@ func renewDynamicCertificate(name string) (Certificate, error) {
log.Printf("[INFO] Renewing certificate for %s", name) log.Printf("[INFO] Renewing certificate for %s", name)
client, err := NewACMEClient("", false) // renewals don't use email client, err := NewACMEClientGetEmail(server.Config{}, false)
if err != nil { if err != nil {
return Certificate{}, err return Certificate{}, err
} }
...@@ -295,7 +295,7 @@ var obtainCertWaitChansMu sync.Mutex ...@@ -295,7 +295,7 @@ var obtainCertWaitChansMu sync.Mutex
// OnDemandIssuedCount is the number of certificates that have been issued // OnDemandIssuedCount is the number of certificates that have been issued
// on-demand by this process. It is only safe to modify this count atomically. // on-demand by this process. It is only safe to modify this count atomically.
// If it reaches max_certs, on-demand issuances will fail. // If it reaches onDemandMaxIssue, on-demand issuances will fail.
var OnDemandIssuedCount = new(int32) var OnDemandIssuedCount = new(int32)
// onDemandMaxIssue is set based on max_certs in tls config. It specifies the // onDemandMaxIssue is set based on max_certs in tls config. It specifies the
......
...@@ -12,7 +12,6 @@ import ( ...@@ -12,7 +12,6 @@ import (
"net/http" "net/http"
"os" "os"
"strings" "strings"
"time"
"github.com/mholt/caddy/middleware" "github.com/mholt/caddy/middleware"
"github.com/mholt/caddy/middleware/redirect" "github.com/mholt/caddy/middleware/redirect"
...@@ -215,7 +214,7 @@ func hostHasOtherPort(allConfigs []server.Config, thisConfigIdx int, otherPort s ...@@ -215,7 +214,7 @@ func hostHasOtherPort(allConfigs []server.Config, thisConfigIdx int, otherPort s
// all configs. // all configs.
func MakePlaintextRedirects(allConfigs []server.Config) []server.Config { func MakePlaintextRedirects(allConfigs []server.Config) []server.Config {
for i, cfg := range allConfigs { for i, cfg := range allConfigs {
if (cfg.TLS.Managed || cfg.TLS.OnDemand) && if cfg.TLS.Managed &&
!hostHasOtherPort(allConfigs, i, "80") && !hostHasOtherPort(allConfigs, i, "80") &&
(cfg.Port == "443" || !hostHasOtherPort(allConfigs, i, "443")) { (cfg.Port == "443" || !hostHasOtherPort(allConfigs, i, "443")) {
allConfigs = append(allConfigs, redirPlaintextHost(cfg)) allConfigs = append(allConfigs, redirPlaintextHost(cfg))
...@@ -233,15 +232,16 @@ func MakePlaintextRedirects(allConfigs []server.Config) []server.Config { ...@@ -233,15 +232,16 @@ func MakePlaintextRedirects(allConfigs []server.Config) []server.Config {
// setting up the config may make it look like it // setting up the config may make it look like it
// doesn't qualify even though it originally did. // doesn't qualify even though it originally did.
func ConfigQualifies(cfg server.Config) bool { func ConfigQualifies(cfg server.Config) bool {
return !cfg.TLS.Manual && // user can provide own cert and key return (!cfg.TLS.Manual || cfg.TLS.OnDemand) && // user might provide own cert and key
// user can force-disable automatic HTTPS for this host // user can force-disable automatic HTTPS for this host
cfg.Scheme != "http" && cfg.Scheme != "http" &&
cfg.Port != "80" && cfg.Port != "80" &&
cfg.TLS.LetsEncryptEmail != "off" && cfg.TLS.LetsEncryptEmail != "off" &&
// we get can't certs for some kinds of hostnames // we get can't certs for some kinds of hostnames, but
HostQualifies(cfg.Host) // on-demand TLS allows empty hostnames at startup
(HostQualifies(cfg.Host) || cfg.TLS.OnDemand)
} }
// HostQualifies returns true if the hostname alone // HostQualifies returns true if the hostname alone
...@@ -387,20 +387,11 @@ var ( ...@@ -387,20 +387,11 @@ var (
CAUrl string CAUrl string
) )
// Some essential values related to the Let's Encrypt process // AlternatePort is the port on which the acme client will open a
const ( // listener and solve the CA's challenges. If this alternate port
// AlternatePort is the port on which the acme client will open a // is used instead of the default port (80 or 443), then the
// listener and solve the CA's challenges. If this alternate port // default port for the challenge must be forwarded to this one.
// is used instead of the default port (80 or 443), then the const AlternatePort = "5033"
// default port for the challenge must be forwarded to this one.
AlternatePort = "5033"
// RenewInterval is how often to check certificates for renewal.
RenewInterval = 6 * time.Hour
// OCSPInterval is how often to check if OCSP stapling needs updating.
OCSPInterval = 1 * time.Hour
)
// KeySize represents the length of a key in bits. // KeySize represents the length of a key in bits.
type KeySize int type KeySize int
......
...@@ -4,9 +4,19 @@ import ( ...@@ -4,9 +4,19 @@ import (
"log" "log"
"time" "time"
"github.com/mholt/caddy/server"
"golang.org/x/crypto/ocsp" "golang.org/x/crypto/ocsp"
) )
const (
// RenewInterval is how often to check certificates for renewal.
RenewInterval = 12 * time.Hour
// OCSPInterval is how often to check if OCSP stapling needs updating.
OCSPInterval = 1 * time.Hour
)
// maintainAssets is a permanently-blocking function // maintainAssets is a permanently-blocking function
// that loops indefinitely and, on a regular schedule, checks // that loops indefinitely and, on a regular schedule, checks
// certificates for expiration and initiates a renewal of certs // certificates for expiration and initiates a renewal of certs
...@@ -28,7 +38,7 @@ func maintainAssets(stopChan chan struct{}) { ...@@ -28,7 +38,7 @@ func maintainAssets(stopChan chan struct{}) {
log.Println("[INFO] Done checking certificates") log.Println("[INFO] Done checking certificates")
case <-ocspTicker.C: case <-ocspTicker.C:
log.Println("[INFO] Scanning for stale OCSP staples") log.Println("[INFO] Scanning for stale OCSP staples")
updatePreloadedOCSPStaples() updateOCSPStaples()
log.Println("[INFO] Done checking OCSP staples") log.Println("[INFO] Done checking OCSP staples")
case <-stopChan: case <-stopChan:
renewalTicker.Stop() renewalTicker.Stop()
...@@ -70,7 +80,7 @@ func renewManagedCertificates(allowPrompts bool) (err error) { ...@@ -70,7 +80,7 @@ func renewManagedCertificates(allowPrompts bool) (err error) {
log.Printf("[INFO] Certificate for %v expires in %v; attempting renewal", cert.Names, timeLeft) log.Printf("[INFO] Certificate for %v expires in %v; attempting renewal", cert.Names, timeLeft)
if client == nil { if client == nil {
client, err = NewACMEClient("", allowPrompts) // renewals don't use email client, err = NewACMEClientGetEmail(server.Config{}, allowPrompts)
if err != nil { if err != nil {
return err return err
} }
...@@ -116,42 +126,66 @@ func renewManagedCertificates(allowPrompts bool) (err error) { ...@@ -116,42 +126,66 @@ func renewManagedCertificates(allowPrompts bool) (err error) {
return nil return nil
} }
func updatePreloadedOCSPStaples() { func updateOCSPStaples() {
// Create a temporary place to store updates // Create a temporary place to store updates
// until we release the potentially slow read // until we release the potentially long-lived
// lock so we can use a quick write lock. // read lock and use a short-lived write lock.
type ocspUpdate struct { type ocspUpdate struct {
rawBytes []byte rawBytes []byte
parsedResponse *ocsp.Response parsed *ocsp.Response
} }
updated := make(map[string]ocspUpdate) updated := make(map[string]ocspUpdate)
// A single SAN certificate maps to multiple names, so we use this
// set to make sure we don't waste cycles checking OCSP for the same
// certificate multiple times.
visited := make(map[string]struct{})
certCacheMu.RLock() certCacheMu.RLock()
for name, cert := range certCache { for name, cert := range certCache {
// we update OCSP for managed and un-managed certs here, but only // skip this certificate if we've already visited it,
// if it has OCSP stapled and only for pre-loaded certificates // and if not, mark all the names as visited
if cert.OnDemand || cert.OCSP == nil { if _, ok := visited[name]; ok {
continue
}
for _, n := range cert.Names {
visited[n] = struct{}{}
}
// no point in updating OCSP for expired certificates
if time.Now().After(cert.NotAfter) {
continue continue
} }
var lastNextUpdate time.Time
if cert.OCSP != nil {
// start checking OCSP staple about halfway through validity period for good measure // start checking OCSP staple about halfway through validity period for good measure
oldNextUpdate := cert.OCSP.NextUpdate lastNextUpdate = cert.OCSP.NextUpdate
refreshTime := cert.OCSP.ThisUpdate.Add(oldNextUpdate.Sub(cert.OCSP.ThisUpdate) / 2) refreshTime := cert.OCSP.ThisUpdate.Add(lastNextUpdate.Sub(cert.OCSP.ThisUpdate) / 2)
// since OCSP is already stapled, we need only check if we're in that "refresh window"
if time.Now().Before(refreshTime) {
continue
}
}
// only check for updated OCSP validity window if the refresh time is
// in the past and the certificate is not expired
if time.Now().After(refreshTime) && time.Now().Before(cert.NotAfter) {
err := stapleOCSP(&cert, nil) err := stapleOCSP(&cert, nil)
if err != nil { if err != nil {
if cert.OCSP != nil {
// if it was no staple before, that's fine, otherwise we should log the error
log.Printf("[ERROR] Checking OCSP for %s: %v", name, err) log.Printf("[ERROR] Checking OCSP for %s: %v", name, err)
}
continue continue
} }
// if the OCSP response has been updated, we use it // By this point, we've obtained the latest OCSP response.
if oldNextUpdate != cert.OCSP.NextUpdate { // If there was no staple before, or if the response is updated, make
log.Printf("[INFO] Moving validity period of OCSP staple for %s from %v to %v", // sure we apply the update to all names on the certificate.
name, oldNextUpdate, cert.OCSP.NextUpdate) if lastNextUpdate.IsZero() || lastNextUpdate != cert.OCSP.NextUpdate {
updated[name] = ocspUpdate{rawBytes: cert.Certificate.OCSPStaple, parsedResponse: cert.OCSP} log.Printf("[INFO] Advancing OCSP staple for %v from %s to %s",
cert.Names, lastNextUpdate, cert.OCSP.NextUpdate)
for _, n := range cert.Names {
updated[n] = ocspUpdate{rawBytes: cert.Certificate.OCSPStaple, parsed: cert.OCSP}
} }
} }
} }
...@@ -161,7 +195,7 @@ func updatePreloadedOCSPStaples() { ...@@ -161,7 +195,7 @@ func updatePreloadedOCSPStaples() {
certCacheMu.Lock() certCacheMu.Lock()
for name, update := range updated { for name, update := range updated {
cert := certCache[name] cert := certCache[name]
cert.OCSP = update.parsedResponse cert.OCSP = update.parsed
cert.Certificate.OCSPStaple = update.rawBytes cert.Certificate.OCSPStaple = update.rawBytes
certCache[name] = cert certCache[name] = cert
} }
......
...@@ -20,12 +20,12 @@ import ( ...@@ -20,12 +20,12 @@ import (
// are specified by the user in the config file. All the automatic HTTPS // are specified by the user in the config file. All the automatic HTTPS
// stuff comes later outside of this function. // stuff comes later outside of this function.
func Setup(c *setup.Controller) (middleware.Middleware, error) { func Setup(c *setup.Controller) (middleware.Middleware, error) {
if c.Scheme == "http" { if c.Port == "80" || c.Scheme == "http" {
c.TLS.Enabled = false c.TLS.Enabled = false
log.Printf("[WARNING] TLS disabled for %s://%s.", c.Scheme, c.Address()) log.Printf("[WARNING] TLS disabled for %s://%s.", c.Scheme, c.Address())
} else { return nil, nil
c.TLS.Enabled = true
} }
c.TLS.Enabled = true
for c.Next() { for c.Next() {
var certificateFile, keyFile, loadDir, maxCerts string var certificateFile, keyFile, loadDir, maxCerts string
...@@ -38,6 +38,7 @@ func Setup(c *setup.Controller) (middleware.Middleware, error) { ...@@ -38,6 +38,7 @@ func Setup(c *setup.Controller) (middleware.Middleware, error) {
// user can force-disable managed TLS this way // user can force-disable managed TLS this way
if c.TLS.LetsEncryptEmail == "off" { if c.TLS.LetsEncryptEmail == "off" {
c.TLS.Enabled = false c.TLS.Enabled = false
return nil, nil
} }
case 2: case 2:
certificateFile = args[0] certificateFile = args[0]
...@@ -120,9 +121,28 @@ func Setup(c *setup.Controller) (middleware.Middleware, error) { ...@@ -120,9 +121,28 @@ func Setup(c *setup.Controller) (middleware.Middleware, error) {
} }
// load a directory of certificates, if specified // load a directory of certificates, if specified
// modeled after haproxy: https://cbonte.github.io/haproxy-dconv/configuration-1.5.html#5.1-crt
if loadDir != "" { if loadDir != "" {
err := filepath.Walk(loadDir, func(path string, info os.FileInfo, err error) error { err := loadCertsInDir(c, loadDir)
if err != nil {
return nil, err
}
}
}
setDefaultTLSParams(c.Config)
return nil, nil
}
// loadCertsInDir loads all the certificates/keys in dir, as long as
// the file ends with .pem. This method of loading certificates is
// modeled after haproxy, which expects the certificate and key to
// be bundled into the same file:
// https://cbonte.github.io/haproxy-dconv/configuration-1.5.html#5.1-crt
//
// This function may write to the log as it walks the directory tree.
func loadCertsInDir(c *setup.Controller, dir string) error {
return filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
if err != nil { if err != nil {
log.Printf("[WARNING] Unable to traverse into %s; skipping", path) log.Printf("[WARNING] Unable to traverse into %s; skipping", path)
return nil return nil
...@@ -132,7 +152,7 @@ func Setup(c *setup.Controller) (middleware.Middleware, error) { ...@@ -132,7 +152,7 @@ func Setup(c *setup.Controller) (middleware.Middleware, error) {
} }
if strings.HasSuffix(strings.ToLower(info.Name()), ".pem") { if strings.HasSuffix(strings.ToLower(info.Name()), ".pem") {
certBuilder, keyBuilder := new(bytes.Buffer), new(bytes.Buffer) certBuilder, keyBuilder := new(bytes.Buffer), new(bytes.Buffer)
var foundKey bool var foundKey bool // use only the first key in the file
bundle, err := ioutil.ReadFile(path) bundle, err := ioutil.ReadFile(path)
if err != nil { if err != nil {
...@@ -151,8 +171,8 @@ func Setup(c *setup.Controller) (middleware.Middleware, error) { ...@@ -151,8 +171,8 @@ func Setup(c *setup.Controller) (middleware.Middleware, error) {
// Re-encode certificate as PEM, appending to certificate chain // Re-encode certificate as PEM, appending to certificate chain
pem.Encode(certBuilder, derBlock) pem.Encode(certBuilder, derBlock)
} else if derBlock.Type == "EC PARAMETERS" { } else if derBlock.Type == "EC PARAMETERS" {
// EC keys are composed of two blocks: parameters and key // EC keys generated from openssl can be composed of two blocks:
// (parameter block should come first) // parameters and key (parameter block should come first)
if !foundKey { if !foundKey {
// Encode parameters // Encode parameters
pem.Encode(keyBuilder, derBlock) pem.Encode(keyBuilder, derBlock)
...@@ -192,15 +212,6 @@ func Setup(c *setup.Controller) (middleware.Middleware, error) { ...@@ -192,15 +212,6 @@ func Setup(c *setup.Controller) (middleware.Middleware, error) {
} }
return nil return nil
}) })
if err != nil {
return nil, err
}
}
}
setDefaultTLSParams(c.Config)
return nil, nil
} }
// setDefaultTLSParams sets the default TLS cipher suites, protocol versions, // setDefaultTLSParams sets the default TLS cipher suites, protocol versions,
...@@ -231,7 +242,7 @@ func setDefaultTLSParams(c *server.Config) { ...@@ -231,7 +242,7 @@ func setDefaultTLSParams(c *server.Config) {
// Default TLS port is 443; only use if port is not manually specified, // Default TLS port is 443; only use if port is not manually specified,
// TLS is enabled, and the host is not localhost // TLS is enabled, and the host is not localhost
if c.Port == "" && c.TLS.Enabled && !c.TLS.Manual && c.Host != "localhost" { if c.Port == "" && c.TLS.Enabled && (!c.TLS.Manual || c.TLS.OnDemand) && c.Host != "localhost" {
c.Port = "443" c.Port = "443"
} }
} }
......
...@@ -68,7 +68,7 @@ type TLSConfig struct { ...@@ -68,7 +68,7 @@ type TLSConfig struct {
Enabled bool // will be set to true if TLS is enabled Enabled bool // will be set to true if TLS is enabled
LetsEncryptEmail string LetsEncryptEmail string
Manual bool // will be set to true if user provides own certs and keys Manual bool // will be set to true if user provides own certs and keys
Managed bool // will be set to true if config qualifies for automatic/managed HTTPS Managed bool // will be set to true if config qualifies for implicit automatic/managed HTTPS
OnDemand bool // will be set to true if user enables on-demand TLS (obtain certs during handshakes) OnDemand bool // will be set to true if user enables on-demand TLS (obtain certs during handshakes)
Ciphers []uint16 Ciphers []uint16
ProtocolMinVersion uint16 ProtocolMinVersion uint16
......
...@@ -63,15 +63,7 @@ func New(addr string, configs []Config, gracefulTimeout time.Duration) (*Server, ...@@ -63,15 +63,7 @@ func New(addr string, configs []Config, gracefulTimeout time.Duration) (*Server,
var useTLS, useOnDemandTLS bool var useTLS, useOnDemandTLS bool
if len(configs) > 0 { if len(configs) > 0 {
useTLS = configs[0].TLS.Enabled useTLS = configs[0].TLS.Enabled
if useTLS { useOnDemandTLS = configs[0].TLS.OnDemand
host, _, err := net.SplitHostPort(addr)
if err != nil {
host = addr
}
if host == "" && configs[0].TLS.OnDemand {
useOnDemandTLS = true
}
}
} }
s := &Server{ s := &Server{
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment