Commit 11103bd8 authored by Matthew Holt's avatar Matthew Holt

Major refactor of all HTTPS/TLS/ACME code

Biggest change is no longer using standard library's tls.Config.getCertificate function to get a certificate during TLS handshake. Implemented our own cache which can be changed dynamically at runtime, even during TLS handshakes. As such, restarts are no longer required after certificate renewals or OCSP updates.

We also allow loading multiple certificates and keys per host, even by specifying a directory (tls got a new 'load' command for that).

Renamed the letsencrypt package to https in a gradual effort to become more generic; and https is more fitting for what the package does now.

There are still some known bugs, e.g. reloading where a new certificate is required but port 80 isn't currently listening, will cause the challenge to fail. There's still plenty of cleanup to do and tests to write. It is especially confusing right now how we enable "on-demand" TLS during setup and keep track of that. But this change should basically work so far.
parent f1b2637d
...@@ -28,7 +28,7 @@ import ( ...@@ -28,7 +28,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/mholt/caddy/caddy/letsencrypt" "github.com/mholt/caddy/caddy/https"
"github.com/mholt/caddy/server" "github.com/mholt/caddy/server"
) )
...@@ -44,7 +44,7 @@ var ( ...@@ -44,7 +44,7 @@ var (
Quiet bool Quiet bool
// HTTP2 indicates whether HTTP2 is enabled or not. // HTTP2 indicates whether HTTP2 is enabled or not.
HTTP2 bool // TODO: temporary flag until http2 is standard HTTP2 bool
// PidFile is the path to the pidfile to create. // PidFile is the path to the pidfile to create.
PidFile string PidFile string
...@@ -191,9 +191,13 @@ func startServers(groupings bindingGroup) error { ...@@ -191,9 +191,13 @@ func startServers(groupings bindingGroup) error {
if err != nil { if err != nil {
return err return err
} }
s.HTTP2 = HTTP2 // TODO: This setting is temporary s.HTTP2 = HTTP2
s.ReqCallback = letsencrypt.RequestCallback // ensures we can solve ACME challenges while running s.ReqCallback = https.RequestCallback // ensures we can solve ACME challenges while running
s.SNICallback = letsencrypt.GetCertificateDuringHandshake // TLS on demand -- awesome! if s.OnDemandTLS {
s.TLSConfig.GetCertificate = https.GetOrObtainCertificate // TLS on demand -- awesome!
} else {
s.TLSConfig.GetCertificate = https.GetCertificate
}
var ln server.ListenerFile var ln server.ListenerFile
if IsRestart() { if IsRestart() {
...@@ -278,7 +282,7 @@ func startServers(groupings bindingGroup) error { ...@@ -278,7 +282,7 @@ func startServers(groupings bindingGroup) error {
// It does NOT execute shutdown callbacks that may have been // It does NOT execute shutdown callbacks that may have been
// configured by middleware (they must be executed separately). // configured by middleware (they must be executed separately).
func Stop() error { func Stop() error {
letsencrypt.Deactivate() https.Deactivate()
serversMu.Lock() serversMu.Lock()
for _, s := range servers { for _, s := range servers {
......
...@@ -4,10 +4,21 @@ import ( ...@@ -4,10 +4,21 @@ import (
"net/http" "net/http"
"testing" "testing"
"time" "time"
"github.com/mholt/caddy/caddy/https"
"github.com/xenolf/lego/acme"
) )
func TestCaddyStartStop(t *testing.T) { func TestCaddyStartStop(t *testing.T) {
caddyfile := "localhost:1984\ntls off" // Use fake ACME clients for testing
https.NewACMEClient = func(email string, allowPrompts bool) (*https.ACMEClient, error) {
return &https.ACMEClient{
Client: new(acme.Client),
AllowPrompts: allowPrompts,
}, nil
}
caddyfile := "localhost:1984"
for i := 0; i < 2; i++ { for i := 0; i < 2; i++ {
err := Start(CaddyfileInput{Contents: []byte(caddyfile)}) err := Start(CaddyfileInput{Contents: []byte(caddyfile)})
......
...@@ -8,7 +8,7 @@ import ( ...@@ -8,7 +8,7 @@ import (
"net" "net"
"sync" "sync"
"github.com/mholt/caddy/caddy/letsencrypt" "github.com/mholt/caddy/caddy/https"
"github.com/mholt/caddy/caddy/parse" "github.com/mholt/caddy/caddy/parse"
"github.com/mholt/caddy/caddy/setup" "github.com/mholt/caddy/caddy/setup"
"github.com/mholt/caddy/middleware" "github.com/mholt/caddy/middleware"
...@@ -128,7 +128,7 @@ func loadConfigs(filename string, input io.Reader) ([]server.Config, error) { ...@@ -128,7 +128,7 @@ func loadConfigs(filename string, input io.Reader) ([]server.Config, error) {
if !IsRestart() && !Quiet { if !IsRestart() && !Quiet {
fmt.Print("Activating privacy features...") fmt.Print("Activating privacy features...")
} }
configs, err = letsencrypt.Activate(configs) configs, err = https.Activate(configs)
if err != nil { if err != nil {
return nil, err return nil, err
} else if !IsRestart() && !Quiet { } else if !IsRestart() && !Quiet {
...@@ -318,7 +318,7 @@ func validDirective(d string) bool { ...@@ -318,7 +318,7 @@ func validDirective(d string) bool {
// root. // root.
func DefaultInput() CaddyfileInput { func DefaultInput() CaddyfileInput {
port := Port port := Port
if letsencrypt.HostQualifies(Host) && port == DefaultPort { if https.HostQualifies(Host) && port == DefaultPort {
port = "443" port = "443"
} }
return CaddyfileInput{ return CaddyfileInput{
......
package caddy package caddy
import ( import (
"github.com/mholt/caddy/caddy/https"
"github.com/mholt/caddy/caddy/parse" "github.com/mholt/caddy/caddy/parse"
"github.com/mholt/caddy/caddy/setup" "github.com/mholt/caddy/caddy/setup"
"github.com/mholt/caddy/middleware" "github.com/mholt/caddy/middleware"
...@@ -43,7 +44,7 @@ var directiveOrder = []directive{ ...@@ -43,7 +44,7 @@ var directiveOrder = []directive{
// Essential directives that initialize vital configuration settings // Essential directives that initialize vital configuration settings
{"root", setup.Root}, {"root", setup.Root},
{"bind", setup.BindHost}, {"bind", setup.BindHost},
{"tls", setup.TLS}, // letsencrypt is set up just after tls {"tls", https.Setup},
// Other directives that don't create HTTP handlers // Other directives that don't create HTTP handlers
{"startup", setup.Startup}, {"startup", setup.Startup},
......
...@@ -11,14 +11,8 @@ import ( ...@@ -11,14 +11,8 @@ import (
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"github.com/mholt/caddy/caddy/letsencrypt"
) )
func init() {
letsencrypt.OnChange = func() error { return Restart(nil) }
}
// isLocalhost returns true if host looks explicitly like a localhost address. // isLocalhost returns true if host looks explicitly like a localhost address.
func isLocalhost(host string) bool { func isLocalhost(host string) bool {
return host == "localhost" || host == "::1" || strings.HasPrefix(host, "127.") return host == "localhost" || host == "::1" || strings.HasPrefix(host, "127.")
......
package https
import (
"crypto/tls"
"crypto/x509"
"errors"
"io/ioutil"
"log"
"strings"
"sync"
"time"
"github.com/xenolf/lego/acme"
"golang.org/x/crypto/ocsp"
)
// certCache stores certificates in memory,
// keying certificates by name.
var certCache = make(map[string]Certificate)
var certCacheMu sync.RWMutex
// Certificate is a tls.Certificate with associated metadata tacked on.
// Even if the metadata can be obtained by parsing the certificate,
// we can be more efficient by extracting the metadata once so it's
// just there, ready to use.
type Certificate struct {
*tls.Certificate
// Names is the list of names this certificate is written for.
// The first is the CommonName (if any), the rest are SAN.
Names []string
// NotAfter is when the certificate expires.
NotAfter time.Time
// Managed certificates are certificates that Caddy is managing,
// as opposed to the user specifying a certificate and key file
// or directory and managing the certificate resources themselves.
Managed bool
// OnDemand certificates are obtained or loaded on-demand during TLS
// handshakes (as opposed to preloaded certificates, which are loaded
// at startup). If OnDemand is true, Managed must necessarily be true.
// OnDemand certificates are maintained in the background just like
// preloaded ones, however, if an OnDemand certificate fails to renew,
// it is removed from the in-memory cache.
OnDemand bool
// OCSP contains the certificate's parsed OCSP response.
OCSP *ocsp.Response
}
// getCertificate gets a certificate from the in-memory cache that
// matches name (a certificate name). Note that if name does not have
// an exact match, it will be checked against names of the form
// '*.example.com' (wildcard certificates) according to RFC 6125.
//
// If cert was found by matching name, matched will be returned true.
// If no match is found, the default certificate will be returned and
// matched will be returned as false. (The default certificate is the
// first one that entered the cache.) If the cache is empty (or there
// is no default certificate for some reason), matched will still be
// false, but cert.Certificate will be nil.
//
// The logic in this function is adapted from the Go standard library,
// which is by the Go Authors.
//
// This function is safe for concurrent use.
func getCertificate(name string) (cert Certificate, matched bool) {
// Not going to trim trailing dots here since RFC 3546 says,
// "The hostname is represented ... without a trailing dot."
// Just normalize to lowercase.
name = strings.ToLower(name)
certCacheMu.RLock()
defer certCacheMu.RUnlock()
// exact match? great, let's use it
if cert, ok := certCache[name]; ok {
return cert, true
}
// try replacing labels in the name with wildcards until we get a match
labels := strings.Split(name, ".")
for i := range labels {
labels[i] = "*"
candidate := strings.Join(labels, ".")
if cert, ok := certCache[candidate]; ok {
return cert, true
}
}
// if nothing matches, return the default certificate
cert = certCache[""]
return cert, false
}
// cacheManagedCertificate loads the certificate for domain into the
// cache, flagging it as Managed and, if onDemand is true, as OnDemand
// (meaning that it was obtained or loaded during a TLS handshake).
//
// This function is safe for concurrent use.
func cacheManagedCertificate(domain string, onDemand bool) (Certificate, error) {
cert, err := makeCertificateFromDisk(storage.SiteCertFile(domain), storage.SiteKeyFile(domain))
if err != nil {
return cert, err
}
cert.Managed = true
cert.OnDemand = onDemand
cacheCertificate(cert)
return cert, nil
}
// cacheUnmanagedCertificatePEMFile loads a certificate for host using certFile
// and keyFile, which must be in PEM format. It stores the certificate in
// memory. The Managed and OnDemand flags of the certificate will be set to
// false.
//
// This function is safe for concurrent use.
func cacheUnmanagedCertificatePEMFile(certFile, keyFile string) error {
cert, err := makeCertificateFromDisk(certFile, keyFile)
if err != nil {
return err
}
cacheCertificate(cert)
return nil
}
// cacheUnmanagedCertificatePEMBytes makes a certificate out of the PEM bytes
// of the certificate and key, then caches it in memory.
//
// This function is safe for concurrent use.
func cacheUnmanagedCertificatePEMBytes(certBytes, keyBytes []byte) error {
cert, err := makeCertificate(certBytes, keyBytes)
if err != nil {
return err
}
cacheCertificate(cert)
return nil
}
// makeCertificateFromDisk makes a Certificate by loading the
// certificate and key files. It fills out all the fields in
// the certificate except for the Managed and OnDemand flags.
// (It is up to the caller to set those.)
func makeCertificateFromDisk(certFile, keyFile string) (Certificate, error) {
certPEMBlock, err := ioutil.ReadFile(certFile)
if err != nil {
return Certificate{}, err
}
keyPEMBlock, err := ioutil.ReadFile(keyFile)
if err != nil {
return Certificate{}, err
}
return makeCertificate(certPEMBlock, keyPEMBlock)
}
// makeCertificate turns a certificate PEM bundle and a key PEM block into
// a Certificate, with OCSP and other relevant metadata tagged with it,
// except for the OnDemand and Managed flags. It is up to the caller to
// set those properties.
func makeCertificate(certPEMBlock, keyPEMBlock []byte) (Certificate, error) {
var cert Certificate
// Convert to a tls.Certificate
tlsCert, err := tls.X509KeyPair(certPEMBlock, keyPEMBlock)
if err != nil {
return cert, err
}
if len(tlsCert.Certificate) == 0 {
return cert, errors.New("certificate is empty")
}
cert.Certificate = &tlsCert
// Parse leaf certificate and extract relevant metadata
leaf, err := x509.ParseCertificate(tlsCert.Certificate[0])
if err != nil {
return cert, err
}
if leaf.Subject.CommonName != "" {
cert.Names = []string{strings.ToLower(leaf.Subject.CommonName)}
}
for _, name := range leaf.DNSNames {
if name != leaf.Subject.CommonName {
cert.Names = append(cert.Names, strings.ToLower(name))
}
}
cert.NotAfter = leaf.NotAfter
// Staple OCSP
ocspBytes, ocspResp, err := acme.GetOCSPForCert(certPEMBlock)
if err != nil {
// An error here is not a problem because a certificate may simply
// not contain a link to an OCSP server. But we should log it anyway.
log.Printf("[WARNING] No OCSP stapling for %v: %v", cert.Names, err)
} else if ocspResp.Status == ocsp.Good {
tlsCert.OCSPStaple = ocspBytes
cert.OCSP = ocspResp
}
return cert, nil
}
// cacheCertificate adds cert to the in-memory cache. If the cache is
// empty, cert will be used as the default certificate. If the cache is
// full, random entries are deleted until there is room to map all the
// names on the certificate.
//
// This certificate will be keyed to the names in cert.Names. Any name
// that is already a key in the cache will be replaced with this cert.
//
// This function is safe for concurrent use.
func cacheCertificate(cert Certificate) {
certCacheMu.Lock()
if _, ok := certCache[""]; !ok {
certCache[""] = cert // use as default
}
for len(certCache)+len(cert.Names) > 10000 {
// for simplicity, just remove random elements
for key := range certCache {
if key == "" { // ... but not the default cert
continue
}
delete(certCache, key)
break
}
}
for _, name := range cert.Names {
certCache[name] = cert
}
certCacheMu.Unlock()
}
package https
import (
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net"
"sync"
"time"
"github.com/mholt/caddy/server"
"github.com/xenolf/lego/acme"
)
// acmeMu ensures that only one ACME challenge occurs at a time.
var acmeMu sync.Mutex
// ACMEClient is an acme.Client with custom state attached.
type ACMEClient struct {
*acme.Client
AllowPrompts bool // if false, we assume AlternatePort must be used
}
// NewACMEClient creates a new ACMEClient given an email and whether
// prompting the user is allowed. Clients should not be kept and
// re-used over long periods of time, but immediate re-use is more
// efficient than re-creating on every iteration.
var NewACMEClient = func(email string, allowPrompts bool) (*ACMEClient, error) {
// Look up or create the LE user account
leUser, err := getUser(email)
if err != nil {
return nil, err
}
// The client facilitates our communication with the CA server.
client, err := acme.NewClient(CAUrl, &leUser, rsaKeySizeToUse)
if err != nil {
return nil, err
}
// If not registered, the user must register an account with the CA
// and agree to terms
if leUser.Registration == nil {
reg, err := client.Register()
if err != nil {
return nil, errors.New("registration error: " + err.Error())
}
leUser.Registration = reg
if allowPrompts { // can't prompt a user who isn't there
if !Agreed && reg.TosURL == "" {
Agreed = promptUserAgreement(saURL, false) // TODO - latest URL
}
if !Agreed && reg.TosURL == "" {
return nil, errors.New("user must agree to terms")
}
}
err = client.AgreeToTOS()
if err != nil {
saveUser(leUser) // Might as well try, right?
return nil, errors.New("error agreeing to terms: " + err.Error())
}
// save user to the file system
err = saveUser(leUser)
if err != nil {
return nil, errors.New("could not save user: " + err.Error())
}
}
return &ACMEClient{
Client: client,
AllowPrompts: allowPrompts,
}, nil
}
// NewACMEClientGetEmail creates a new ACMEClient and gets an email
// address at the same time (a server config is required, since it
// may contain an email address in it).
func NewACMEClientGetEmail(config server.Config, allowPrompts bool) (*ACMEClient, error) {
return NewACMEClient(getEmail(config, allowPrompts), allowPrompts)
}
// Configure configures c according to bindHost, which is the host (not
// whole address) to bind the listener to in solving the http and tls-sni
// challenges.
func (c *ACMEClient) Configure(bindHost string) {
// If we allow prompts, operator must be present. In our case,
// that is synonymous with saying the server is not already
// started. So if the user is still there, we don't use
// AlternatePort because we don't need to proxy the challenges.
// Conversely, if the operator is not there, the server has
// already started and we need to proxy the challenge.
if c.AllowPrompts {
// Operator is present; server is not already listening
c.SetHTTPAddress(net.JoinHostPort(bindHost, ""))
c.SetTLSAddress(net.JoinHostPort(bindHost, ""))
//c.ExcludeChallenges([]acme.Challenge{acme.DNS01})
} else {
// Operator is not present; server is started, so proxy challenges
c.SetHTTPAddress(net.JoinHostPort(bindHost, AlternatePort))
c.SetTLSAddress(net.JoinHostPort(bindHost, AlternatePort))
//c.ExcludeChallenges([]acme.Challenge{acme.TLSSNI01, acme.DNS01})
}
c.ExcludeChallenges([]acme.Challenge{acme.TLSSNI01, acme.DNS01}) // TODO: can we proxy TLS challenges? and we should support DNS...
}
// Obtain obtains a single certificate for names. It stores the certificate
// on the disk if successful.
func (c *ACMEClient) Obtain(names []string) error {
Attempts:
for attempts := 0; attempts < 2; attempts++ {
acmeMu.Lock()
certificate, failures := c.ObtainCertificate(names, true, nil)
acmeMu.Unlock()
if len(failures) > 0 {
// Error - try to fix it or report it to the user and abort
var errMsg string // we'll combine all the failures into a single error message
var promptedForAgreement bool // only prompt user for agreement at most once
for errDomain, obtainErr := range failures {
// TODO: Double-check, will obtainErr ever be nil?
if tosErr, ok := obtainErr.(acme.TOSError); ok {
// Terms of Service agreement error; we can probably deal with this
if !Agreed && !promptedForAgreement && c.AllowPrompts {
Agreed = promptUserAgreement(tosErr.Detail, true) // TODO: Use latest URL
promptedForAgreement = true
}
if Agreed || !c.AllowPrompts {
err := c.AgreeToTOS()
if err != nil {
return errors.New("error agreeing to updated terms: " + err.Error())
}
continue Attempts
}
}
// If user did not agree or it was any other kind of error, just append to the list of errors
errMsg += "[" + errDomain + "] failed to get certificate: " + obtainErr.Error() + "\n"
}
return errors.New(errMsg)
}
// Success - immediately save the certificate resource
err := saveCertResource(certificate)
if err != nil {
return fmt.Errorf("error saving assets for %v: %v", names, err)
}
break
}
return nil
}
// Renew renews the managed certificate for name. Right now our storage
// mechanism only supports one name per certificate, so this function only
// accepts one domain as input. It can be easily modified to support SAN
// certificates if, one day, they become desperately needed enough that our
// storage mechanism is upgraded to be more complex to support SAN certs.
//
// Anyway, this function is safe for concurrent use.
func (c *ACMEClient) Renew(name string) error {
// Prepare for renewal (load PEM cert, key, and meta)
certBytes, err := ioutil.ReadFile(storage.SiteCertFile(name))
if err != nil {
return err
}
keyBytes, err := ioutil.ReadFile(storage.SiteKeyFile(name))
if err != nil {
return err
}
metaBytes, err := ioutil.ReadFile(storage.SiteMetaFile(name))
if err != nil {
return err
}
var certMeta acme.CertificateResource
err = json.Unmarshal(metaBytes, &certMeta)
certMeta.Certificate = certBytes
certMeta.PrivateKey = keyBytes
// Perform renewal and retry if necessary, but not too many times.
var newCertMeta acme.CertificateResource
var success bool
for attempts := 0; attempts < 2; attempts++ {
acmeMu.Lock()
newCertMeta, err = c.RenewCertificate(certMeta, true)
acmeMu.Unlock()
if err == nil {
success = true
break
}
// If the legal terms changed and need to be agreed to again,
// we can handle that.
if _, ok := err.(acme.TOSError); ok {
err := c.AgreeToTOS()
if err != nil {
return err
}
continue
}
// For any other kind of error, wait 10s and try again.
time.Sleep(10 * time.Second)
}
if !success {
return errors.New("too many renewal attempts; last error: " + err.Error())
}
return saveCertResource(newCertMeta)
}
package letsencrypt package https
import ( import (
"crypto/rsa" "crypto/rsa"
......
package letsencrypt package https
import ( import (
"bytes" "bytes"
......
package letsencrypt package https
import ( import (
"crypto/tls" "crypto/tls"
......
package letsencrypt package https
import ( import (
"net" "net"
......
package https
import (
"bytes"
"crypto/tls"
"encoding/pem"
"errors"
"fmt"
"log"
"sync"
"time"
"github.com/mholt/caddy/server"
"github.com/xenolf/lego/acme"
)
// GetCertificate gets a certificate to satisfy clientHello as long as
// the certificate is already cached in memory.
//
// This function is safe for use as a tls.Config.GetCertificate callback.
func GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
cert, err := getCertDuringHandshake(clientHello.ServerName, false)
return cert.Certificate, err
}
// GetOrObtainCertificate will get a certificate to satisfy clientHello, even
// if that means obtaining a new certificate from a CA during the handshake.
// It first checks the in-memory cache, then accesses disk, then accesses the
// network if it must. An obtained certificate will be stored on disk and
// cached in memory.
//
// This function is safe for use as a tls.Config.GetCertificate callback.
func GetOrObtainCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
cert, err := getCertDuringHandshake(clientHello.ServerName, true)
return cert.Certificate, err
}
// getCertDuringHandshake will get a certificate for name. It first tries
// the in-memory cache, then, if obtainIfNecessary is true, it goes to disk,
// then asks the CA for a certificate if necessary.
//
// This function is safe for concurrent use.
func getCertDuringHandshake(name string, obtainIfNecessary bool) (Certificate, error) {
// First check our in-memory cache to see if we've already loaded it
cert, ok := getCertificate(name)
if ok {
return cert, nil
}
if obtainIfNecessary {
// TODO: Mitigate abuse!
var err error
// Then check to see if we have one on disk
cert, err := cacheManagedCertificate(name, true)
if err != nil {
return cert, err
} else if cert.Certificate != nil {
cert, err := handshakeMaintenance(name, cert)
if err != nil {
log.Printf("[ERROR] Maintaining newly-loaded certificate for %s: %v", name, err)
}
return cert, err
}
// Only option left is to get one from LE, but the name has to qualify first
if !HostQualifies(name) {
return cert, errors.New("hostname '" + name + "' does not qualify for certificate")
}
// By this point, we need to obtain one from the CA.
return obtainOnDemandCertificate(name)
}
return Certificate{}, nil
}
// obtainOnDemandCertificate obtains a certificate for name for the given
// clientHello. If another goroutine has already started obtaining a cert
// for name, it will wait and use what the other goroutine obtained.
//
// This function is safe for use by multiple concurrent goroutines.
func obtainOnDemandCertificate(name string) (Certificate, error) {
// We must protect this process from happening concurrently, so synchronize.
obtainCertWaitChansMu.Lock()
wait, ok := obtainCertWaitChans[name]
if ok {
// lucky us -- another goroutine is already obtaining the certificate.
// wait for it to finish obtaining the cert and then we'll use it.
obtainCertWaitChansMu.Unlock()
<-wait
return getCertDuringHandshake(name, false) // passing in true might result in infinite loop if obtain failed
}
// looks like it's up to us to do all the work and obtain the cert
wait = make(chan struct{})
obtainCertWaitChans[name] = wait
obtainCertWaitChansMu.Unlock()
// Unblock waiters and delete waitgroup when we return
defer func() {
obtainCertWaitChansMu.Lock()
close(wait)
delete(obtainCertWaitChans, name)
obtainCertWaitChansMu.Unlock()
}()
log.Printf("[INFO] Obtaining new certificate for %s", name)
// obtain cert
client, err := NewACMEClientGetEmail(server.Config{}, false)
if err != nil {
return Certificate{}, errors.New("error creating client: " + err.Error())
}
client.Configure("") // TODO: which BindHost?
err = client.Obtain([]string{name})
if err != nil {
return Certificate{}, err
}
// The certificate is on disk; now just start over to load it and serve it
return getCertDuringHandshake(name, false) // pass in false as a fail-safe from infinite-looping
}
// handshakeMaintenance performs a check on cert for expiration and OCSP
// validity.
//
// This function is safe for use by multiple concurrent goroutines.
func handshakeMaintenance(name string, cert Certificate) (Certificate, error) {
// fmt.Println("ON-DEMAND CERT?", cert.OnDemand)
// if !cert.OnDemand {
// return cert, nil
// }
fmt.Println("Checking expiration of cert; on-demand:", cert.OnDemand)
// Check cert expiration
timeLeft := cert.NotAfter.Sub(time.Now().UTC())
if timeLeft < renewDurationBefore {
log.Printf("[INFO] Certificate for %v expires in %v; attempting renewal", cert.Names, timeLeft)
return renewDynamicCertificate(name)
}
// Check OCSP staple validity
if cert.OCSP != nil {
refreshTime := cert.OCSP.ThisUpdate.Add(cert.OCSP.NextUpdate.Sub(cert.OCSP.ThisUpdate) / 2)
if time.Now().After(refreshTime) {
err := stapleOCSP(&cert, nil)
if err != nil {
// An error with OCSP stapling is not the end of the world, and in fact, is
// quite common considering not all certs have issuer URLs that support it.
log.Printf("[ERROR] Getting OCSP for %s: %v", name, err)
}
certCacheMu.Lock()
certCache[name] = cert
certCacheMu.Unlock()
}
}
return cert, nil
}
// renewDynamicCertificate renews currentCert using the clientHello. It returns the
// certificate to use and an error, if any. currentCert may be returned even if an
// error occurs, since we perform renewals before they expire and it may still be
// usable. name should already be lower-cased before calling this function.
//
// This function is safe for use by multiple concurrent goroutines.
func renewDynamicCertificate(name string) (Certificate, error) {
obtainCertWaitChansMu.Lock()
wait, ok := obtainCertWaitChans[name]
if ok {
// lucky us -- another goroutine is already renewing the certificate.
// wait for it to finish, then we'll use the new one.
obtainCertWaitChansMu.Unlock()
<-wait
return getCertDuringHandshake(name, false)
}
// looks like it's up to us to do all the work and renew the cert
wait = make(chan struct{})
obtainCertWaitChans[name] = wait
obtainCertWaitChansMu.Unlock()
// unblock waiters and delete waitgroup when we return
defer func() {
obtainCertWaitChansMu.Lock()
close(wait)
delete(obtainCertWaitChans, name)
obtainCertWaitChansMu.Unlock()
}()
log.Printf("[INFO] Renewing certificate for %s", name)
client, err := NewACMEClient("", false) // renewals don't use email
if err != nil {
return Certificate{}, err
}
client.Configure("") // TODO: Bind address of relevant listener, yuck
err = client.Renew(name)
if err != nil {
return Certificate{}, err
}
return getCertDuringHandshake(name, false)
}
// stapleOCSP staples OCSP information to cert for hostname name.
// If you have it handy, you should pass in the PEM-encoded certificate
// bundle; otherwise the DER-encoded cert will have to be PEM-encoded.
// If you don't have the PEM blocks handy, just pass in nil.
//
// Errors here are not necessarily fatal, it could just be that the
// certificate doesn't have an issuer URL.
func stapleOCSP(cert *Certificate, pemBundle []byte) error {
if pemBundle == nil {
// The function in the acme package that gets OCSP requires a PEM-encoded cert
bundle := new(bytes.Buffer)
for _, derBytes := range cert.Certificate.Certificate {
pem.Encode(bundle, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
}
pemBundle = bundle.Bytes()
}
ocspBytes, ocspResp, err := acme.GetOCSPForCert(pemBundle)
if err != nil {
return err
}
cert.Certificate.OCSPStaple = ocspBytes
cert.OCSP = ocspResp
return nil
}
// obtainCertWaitChans is used to coordinate obtaining certs for each hostname.
var obtainCertWaitChans = make(map[string]chan struct{})
var obtainCertWaitChansMu sync.Mutex
package letsencrypt package https
import ( import (
"io/ioutil" "io/ioutil"
...@@ -48,9 +48,9 @@ func TestConfigQualifies(t *testing.T) { ...@@ -48,9 +48,9 @@ func TestConfigQualifies(t *testing.T) {
}{ }{
{server.Config{Host: ""}, true}, {server.Config{Host: ""}, true},
{server.Config{Host: "localhost"}, false}, {server.Config{Host: "localhost"}, false},
{server.Config{Host: "123.44.3.21"}, false},
{server.Config{Host: "example.com"}, true}, {server.Config{Host: "example.com"}, true},
{server.Config{Host: "example.com", TLS: server.TLSConfig{Certificate: "cert.pem"}}, false}, {server.Config{Host: "example.com", TLS: server.TLSConfig{Manual: true}}, false},
{server.Config{Host: "example.com", TLS: server.TLSConfig{Key: "key.pem"}}, false},
{server.Config{Host: "example.com", TLS: server.TLSConfig{LetsEncryptEmail: "off"}}, false}, {server.Config{Host: "example.com", TLS: server.TLSConfig{LetsEncryptEmail: "off"}}, false},
{server.Config{Host: "example.com", TLS: server.TLSConfig{LetsEncryptEmail: "foo@bar.com"}}, true}, {server.Config{Host: "example.com", TLS: server.TLSConfig{LetsEncryptEmail: "foo@bar.com"}}, true},
{server.Config{Host: "example.com", Scheme: "http"}, false}, {server.Config{Host: "example.com", Scheme: "http"}, false},
...@@ -257,27 +257,14 @@ func TestEnableTLS(t *testing.T) { ...@@ -257,27 +257,14 @@ func TestEnableTLS(t *testing.T) {
server.Config{}, // not managed - no changes! server.Config{}, // not managed - no changes!
} }
EnableTLS(configs) EnableTLS(configs, false)
if !configs[0].TLS.Enabled { if !configs[0].TLS.Enabled {
t.Errorf("Expected config 0 to have TLS.Enabled == true, but it was false") t.Errorf("Expected config 0 to have TLS.Enabled == true, but it was false")
} }
if configs[0].TLS.Certificate == "" {
t.Errorf("Expected config 0 to have TLS.Certificate set, but it was empty")
}
if configs[0].TLS.Key == "" {
t.Errorf("Expected config 0 to have TLS.Key set, but it was empty")
}
if configs[1].TLS.Enabled { if configs[1].TLS.Enabled {
t.Errorf("Expected config 1 to have TLS.Enabled == false, but it was true") t.Errorf("Expected config 1 to have TLS.Enabled == false, but it was true")
} }
if configs[1].TLS.Certificate != "" {
t.Errorf("Expected config 1 to have TLS.Certificate empty, but it was: %s", configs[1].TLS.Certificate)
}
if configs[1].TLS.Key != "" {
t.Errorf("Expected config 1 to have TLS.Key empty, but it was: %s", configs[1].TLS.Key)
}
} }
func TestGroupConfigsByEmail(t *testing.T) { func TestGroupConfigsByEmail(t *testing.T) {
...@@ -316,9 +303,9 @@ func TestMarkQualified(t *testing.T) { ...@@ -316,9 +303,9 @@ func TestMarkQualified(t *testing.T) {
// TODO: TestConfigQualifies and this test share the same config list... // TODO: TestConfigQualifies and this test share the same config list...
configs := []server.Config{ configs := []server.Config{
{Host: "localhost"}, {Host: "localhost"},
{Host: "123.44.3.21"},
{Host: "example.com"}, {Host: "example.com"},
{Host: "example.com", TLS: server.TLSConfig{Certificate: "cert.pem"}}, {Host: "example.com", TLS: server.TLSConfig{Manual: true}},
{Host: "example.com", TLS: server.TLSConfig{Key: "key.pem"}},
{Host: "example.com", TLS: server.TLSConfig{LetsEncryptEmail: "off"}}, {Host: "example.com", TLS: server.TLSConfig{LetsEncryptEmail: "off"}},
{Host: "example.com", TLS: server.TLSConfig{LetsEncryptEmail: "foo@bar.com"}}, {Host: "example.com", TLS: server.TLSConfig{LetsEncryptEmail: "foo@bar.com"}},
{Host: "example.com", Scheme: "http"}, {Host: "example.com", Scheme: "http"},
......
package https
import (
"log"
"time"
"golang.org/x/crypto/ocsp"
)
// maintainAssets is a permanently-blocking function
// that loops indefinitely and, on a regular schedule, checks
// certificates for expiration and initiates a renewal of certs
// that are expiring soon. It also updates OCSP stapling and
// performs other maintenance of assets.
//
// You must pass in the channel which you'll close when
// maintenance should stop, to allow this goroutine to clean up
// after itself and unblock.
func maintainAssets(stopChan chan struct{}) {
renewalTicker := time.NewTicker(RenewInterval)
ocspTicker := time.NewTicker(OCSPInterval)
for {
select {
case <-renewalTicker.C:
log.Println("[INFO] Scanning for expiring certificates")
client, err := NewACMEClient("", false) // renewals don't use email
if err != nil {
log.Printf("[ERROR] Creating client for renewals: %v", err)
continue
}
client.Configure("") // TODO: Bind address of relevant listener, yuck
renewManagedCertificates(client)
log.Println("[INFO] Done checking certificates")
case <-ocspTicker.C:
log.Println("[INFO] Scanning for stale OCSP staples")
updatePreloadedOCSPStaples()
log.Println("[INFO] Done checking OCSP staples")
case <-stopChan:
renewalTicker.Stop()
ocspTicker.Stop()
log.Println("[INFO] Stopped background maintenance routine")
return
}
}
}
func renewManagedCertificates(client *ACMEClient) error {
var renewed, deleted []Certificate
visitedNames := make(map[string]struct{})
certCacheMu.RLock()
for name, cert := range certCache {
if !cert.Managed {
continue
}
// the list of names on this cert should never be empty...
if cert.Names == nil || len(cert.Names) == 0 {
log.Printf("[WARNING] Certificate keyed by '%s' has no names: %v", name, cert.Names)
deleted = append(deleted, cert)
continue
}
// skip names whose certificate we've already renewed
if _, ok := visitedNames[name]; ok {
continue
}
for _, name := range cert.Names {
visitedNames[name] = struct{}{}
}
timeLeft := cert.NotAfter.Sub(time.Now().UTC())
if timeLeft < renewDurationBefore {
log.Printf("[INFO] Certificate for %v expires in %v; attempting renewal", cert.Names, timeLeft)
err := client.Renew(cert.Names[0]) // managed certs better have only one name
if err != nil {
if client.AllowPrompts {
// User is present, so stop immediately and report the error
certCacheMu.RUnlock()
return err
}
log.Printf("[ERROR] %v", err)
if cert.OnDemand {
deleted = append(deleted, cert)
}
} else {
renewed = append(renewed, cert)
}
}
}
certCacheMu.RUnlock()
// Apply changes to the cache
for _, cert := range renewed {
_, err := cacheManagedCertificate(cert.Names[0], cert.OnDemand)
if err != nil {
if client.AllowPrompts {
return err // operator is present, so report error immediately
}
log.Printf("[ERROR] %v", err)
}
}
for _, cert := range deleted {
certCacheMu.Lock()
for _, name := range cert.Names {
delete(certCache, name)
}
certCacheMu.Unlock()
}
return nil
}
func updatePreloadedOCSPStaples() {
// Create a temporary place to store updates
// until we release the potentially slow read
// lock so we can use a quick write lock.
type ocspUpdate struct {
rawBytes []byte
parsedResponse *ocsp.Response
}
updated := make(map[string]ocspUpdate)
certCacheMu.RLock()
for name, cert := range certCache {
// we update OCSP for managed and un-managed certs here, but only
// if it has OCSP stapled and only for pre-loaded certificates
if cert.OnDemand || cert.OCSP == nil {
continue
}
// start checking OCSP staple about halfway through validity period for good measure
oldNextUpdate := cert.OCSP.NextUpdate
refreshTime := cert.OCSP.ThisUpdate.Add(oldNextUpdate.Sub(cert.OCSP.ThisUpdate) / 2)
// only check for updated OCSP validity window if the refresh time is
// in the past and the certificate is not expired
if time.Now().After(refreshTime) && time.Now().Before(cert.NotAfter) {
err := stapleOCSP(&cert, nil)
if err != nil {
log.Printf("[ERROR] Checking OCSP for %s: %v", name, err)
continue
}
// if the OCSP response has been updated, we use it
if oldNextUpdate != cert.OCSP.NextUpdate {
log.Printf("[INFO] Moving validity period of OCSP staple for %s from %v to %v",
name, oldNextUpdate, cert.OCSP.NextUpdate)
updated[name] = ocspUpdate{rawBytes: cert.Certificate.OCSPStaple, parsedResponse: cert.OCSP}
}
}
}
certCacheMu.RUnlock()
// This write lock should be brief since we have all the info we need now.
certCacheMu.Lock()
for name, update := range updated {
cert := certCache[name]
cert.OCSP = update.parsedResponse
cert.Certificate.OCSPStaple = update.rawBytes
certCache[name] = cert
}
certCacheMu.Unlock()
}
// renewDurationBefore is how long before expiration to renew certificates.
const renewDurationBefore = (24 * time.Hour) * 30
package setup package https
import ( import (
"bytes"
"crypto/tls" "crypto/tls"
"encoding/pem"
"io/ioutil"
"log" "log"
"os"
"path/filepath"
"strings" "strings"
"github.com/mholt/caddy/caddy/setup"
"github.com/mholt/caddy/middleware" "github.com/mholt/caddy/middleware"
"github.com/mholt/caddy/server" "github.com/mholt/caddy/server"
) )
// TLS sets up the TLS configuration (but does not activate Let's Encrypt; that is handled elsewhere). // Setup sets up the TLS configuration and installs certificates that
func TLS(c *Controller) (middleware.Middleware, error) { // are specified by the user in the config file. All the automatic HTTPS
// stuff comes later outside of this function.
func Setup(c *setup.Controller) (middleware.Middleware, error) {
if c.Scheme == "http" { if c.Scheme == "http" {
c.TLS.Enabled = false c.TLS.Enabled = false
log.Printf("[WARNING] TLS disabled for %s://%s.", c.Scheme, c.Address()) log.Printf("[WARNING] TLS disabled for %s://%s.", c.Scheme, c.Address())
...@@ -19,18 +27,21 @@ func TLS(c *Controller) (middleware.Middleware, error) { ...@@ -19,18 +27,21 @@ func TLS(c *Controller) (middleware.Middleware, error) {
} }
for c.Next() { for c.Next() {
var certificateFile, keyFile, loadDir string
args := c.RemainingArgs() args := c.RemainingArgs()
switch len(args) { switch len(args) {
case 1: case 1:
c.TLS.LetsEncryptEmail = args[0] c.TLS.LetsEncryptEmail = args[0]
// user can force-disable LE activation this way // user can force-disable managed TLS this way
if c.TLS.LetsEncryptEmail == "off" { if c.TLS.LetsEncryptEmail == "off" {
c.TLS.Enabled = false c.TLS.Enabled = false
} }
case 2: case 2:
c.TLS.Certificate = args[0] certificateFile = args[0]
c.TLS.Key = args[1] keyFile = args[1]
c.TLS.Manual = true
} }
// Optional block with extra parameters // Optional block with extra parameters
...@@ -66,9 +77,9 @@ func TLS(c *Controller) (middleware.Middleware, error) { ...@@ -66,9 +77,9 @@ func TLS(c *Controller) (middleware.Middleware, error) {
if len(c.TLS.ClientCerts) == 0 { if len(c.TLS.ClientCerts) == 0 {
return nil, c.ArgErr() return nil, c.ArgErr()
} }
// TODO: Allow this? It's a bad idea to allow HTTP. If we do this, make sure invoking tls at all (even manually) also sets up a redirect if possible? case "load":
// case "allow_http": c.Args(&loadDir)
// c.TLS.DisableHTTPRedir = true c.TLS.Manual = true
default: default:
return nil, c.Errf("Unknown keyword '%s'", c.Val()) return nil, c.Errf("Unknown keyword '%s'", c.Val())
} }
...@@ -78,18 +89,112 @@ func TLS(c *Controller) (middleware.Middleware, error) { ...@@ -78,18 +89,112 @@ func TLS(c *Controller) (middleware.Middleware, error) {
if len(args) == 0 && !hadBlock { if len(args) == 0 && !hadBlock {
return nil, c.ArgErr() return nil, c.ArgErr()
} }
// don't load certificates unless we're supposed to
if !c.TLS.Enabled || !c.TLS.Manual {
continue
}
// load a single certificate and key, if specified
if certificateFile != "" && keyFile != "" {
err := cacheUnmanagedCertificatePEMFile(certificateFile, keyFile)
if err != nil {
return nil, c.Errf("Unable to load certificate and key files for %s: %v", c.Host, err)
}
log.Printf("[INFO] Successfully loaded TLS assets from %s and %s", certificateFile, keyFile)
}
// load a directory of certificates, if specified
// modeled after haproxy: https://cbonte.github.io/haproxy-dconv/configuration-1.5.html#5.1-crt
if loadDir != "" {
err := filepath.Walk(loadDir, func(path string, info os.FileInfo, err error) error {
if err != nil {
log.Printf("[WARNING] Unable to traverse into %s; skipping", path)
return nil
}
if info.IsDir() {
return nil
}
if strings.HasSuffix(strings.ToLower(info.Name()), ".pem") {
certBuilder, keyBuilder := new(bytes.Buffer), new(bytes.Buffer)
var foundKey bool
bundle, err := ioutil.ReadFile(path)
if err != nil {
return err
}
for {
// Decode next block so we can see what type it is
var derBlock *pem.Block
derBlock, bundle = pem.Decode(bundle)
if derBlock == nil {
break
}
if derBlock.Type == "CERTIFICATE" {
// Re-encode certificate as PEM, appending to certificate chain
pem.Encode(certBuilder, derBlock)
} else if derBlock.Type == "EC PARAMETERS" {
// EC keys are composed of two blocks: parameters and key
// (parameter block should come first)
if !foundKey {
// Encode parameters
pem.Encode(keyBuilder, derBlock)
// Key must immediately follow
derBlock, bundle = pem.Decode(bundle)
if derBlock == nil || derBlock.Type != "EC PRIVATE KEY" {
return c.Errf("%s: expected elliptic private key to immediately follow EC parameters", path)
}
pem.Encode(keyBuilder, derBlock)
foundKey = true
}
} else if derBlock.Type == "PRIVATE KEY" || strings.HasSuffix(derBlock.Type, " PRIVATE KEY") {
// RSA key
if !foundKey {
pem.Encode(keyBuilder, derBlock)
foundKey = true
}
} else {
return c.Errf("%s: unrecognized PEM block type: %s", path, derBlock.Type)
}
}
certPEMBytes, keyPEMBytes := certBuilder.Bytes(), keyBuilder.Bytes()
if len(certPEMBytes) == 0 {
return c.Errf("%s: failed to parse PEM data", path)
}
if len(keyPEMBytes) == 0 {
return c.Errf("%s: no private key block found", path)
}
err = cacheUnmanagedCertificatePEMBytes(certPEMBytes, keyPEMBytes)
if err != nil {
return c.Errf("%s: failed to load cert and key for %s: %v", path, c.Host, err)
}
log.Printf("[INFO] Successfully loaded TLS assets from %s", path)
}
return nil
})
if err != nil {
return nil, err
}
}
} }
SetDefaultTLSParams(c.Config) setDefaultTLSParams(c.Config)
return nil, nil return nil, nil
} }
// SetDefaultTLSParams sets the default TLS cipher suites, protocol versions, // setDefaultTLSParams sets the default TLS cipher suites, protocol versions,
// and server preferences of a server.Config if they were not previously set // and server preferences of a server.Config if they were not previously set
// (it does not overwrite; only fills in missing values). // (it does not overwrite; only fills in missing values). It will also set the
func SetDefaultTLSParams(c *server.Config) { // port to 443 if not already set, TLS is enabled, TLS is manual, and the host
// If no ciphers provided, use all that Caddy supports for the protocol // does not equal localhost.
func setDefaultTLSParams(c *server.Config) {
// If no ciphers provided, use default list
if len(c.TLS.Ciphers) == 0 { if len(c.TLS.Ciphers) == 0 {
c.TLS.Ciphers = defaultCiphers c.TLS.Ciphers = defaultCiphers
} }
...@@ -111,14 +216,14 @@ func SetDefaultTLSParams(c *server.Config) { ...@@ -111,14 +216,14 @@ func SetDefaultTLSParams(c *server.Config) {
// Default TLS port is 443; only use if port is not manually specified, // Default TLS port is 443; only use if port is not manually specified,
// TLS is enabled, and the host is not localhost // TLS is enabled, and the host is not localhost
if c.Port == "" && c.TLS.Enabled && c.Host != "localhost" { if c.Port == "" && c.TLS.Enabled && !c.TLS.Manual && c.Host != "localhost" {
c.Port = "443" c.Port = "443"
} }
} }
// Map of supported protocols // Map of supported protocols.
// SSLv3 will be not supported in future release // SSLv3 will be not supported in future release.
// HTTP/2 only supports TLS 1.2 and higher // HTTP/2 only supports TLS 1.2 and higher.
var supportedProtocols = map[string]uint16{ var supportedProtocols = map[string]uint16{
"ssl3.0": tls.VersionSSL30, "ssl3.0": tls.VersionSSL30,
"tls1.0": tls.VersionTLS10, "tls1.0": tls.VersionTLS10,
......
package setup package https
import ( import (
"crypto/tls" "crypto/tls"
"io/ioutil"
"log"
"os"
"testing" "testing"
"github.com/mholt/caddy/caddy/setup"
) )
func TestTLSParseBasic(t *testing.T) { func TestMain(m *testing.M) {
c := NewTestController(`tls cert.pem key.pem`) // Write test certificates to disk before tests, and clean up
// when we're done.
err := ioutil.WriteFile(certFile, testCert, 0644)
if err != nil {
log.Fatal(err)
}
err = ioutil.WriteFile(keyFile, testKey, 0644)
if err != nil {
os.Remove(certFile)
log.Fatal(err)
}
result := m.Run()
_, err := TLS(c) os.Remove(certFile)
os.Remove(keyFile)
os.Exit(result)
}
func TestSetupParseBasic(t *testing.T) {
c := setup.NewTestController(`tls ` + certFile + ` ` + keyFile + ``)
_, err := Setup(c)
if err != nil { if err != nil {
t.Errorf("Expected no errors, got: %v", err) t.Errorf("Expected no errors, got: %v", err)
} }
// Basic checks // Basic checks
if c.TLS.Certificate != "cert.pem" { if !c.TLS.Manual {
t.Errorf("Expected certificate arg to be 'cert.pem', was '%s'", c.TLS.Certificate) t.Error("Expected TLS Manual=true, but was false")
}
if c.TLS.Key != "key.pem" {
t.Errorf("Expected key arg to be 'key.pem', was '%s'", c.TLS.Key)
} }
if !c.TLS.Enabled { if !c.TLS.Enabled {
t.Error("Expected TLS Enabled=true, but was false") t.Error("Expected TLS Enabled=true, but was false")
...@@ -63,23 +85,23 @@ func TestTLSParseBasic(t *testing.T) { ...@@ -63,23 +85,23 @@ func TestTLSParseBasic(t *testing.T) {
} }
} }
func TestTLSParseIncompleteParams(t *testing.T) { func TestSetupParseIncompleteParams(t *testing.T) {
// Using tls without args is an error because it's unnecessary. // Using tls without args is an error because it's unnecessary.
c := NewTestController(`tls`) c := setup.NewTestController(`tls`)
_, err := TLS(c) _, err := Setup(c)
if err == nil { if err == nil {
t.Error("Expected an error, but didn't get one") t.Error("Expected an error, but didn't get one")
} }
} }
func TestTLSParseWithOptionalParams(t *testing.T) { func TestSetupParseWithOptionalParams(t *testing.T) {
params := `tls cert.crt cert.key { params := `tls ` + certFile + ` ` + keyFile + ` {
protocols ssl3.0 tls1.2 protocols ssl3.0 tls1.2
ciphers RSA-3DES-EDE-CBC-SHA RSA-AES256-CBC-SHA ECDHE-RSA-AES128-GCM-SHA256 ciphers RSA-3DES-EDE-CBC-SHA RSA-AES256-CBC-SHA ECDHE-RSA-AES128-GCM-SHA256
}` }`
c := NewTestController(params) c := setup.NewTestController(params)
_, err := TLS(c) _, err := Setup(c)
if err != nil { if err != nil {
t.Errorf("Expected no errors, got: %v", err) t.Errorf("Expected no errors, got: %v", err)
} }
...@@ -97,13 +119,13 @@ func TestTLSParseWithOptionalParams(t *testing.T) { ...@@ -97,13 +119,13 @@ func TestTLSParseWithOptionalParams(t *testing.T) {
} }
} }
func TestTLSDefaultWithOptionalParams(t *testing.T) { func TestSetupDefaultWithOptionalParams(t *testing.T) {
params := `tls { params := `tls {
ciphers RSA-3DES-EDE-CBC-SHA ciphers RSA-3DES-EDE-CBC-SHA
}` }`
c := NewTestController(params) c := setup.NewTestController(params)
_, err := TLS(c) _, err := Setup(c)
if err != nil { if err != nil {
t.Errorf("Expected no errors, got: %v", err) t.Errorf("Expected no errors, got: %v", err)
} }
...@@ -113,7 +135,7 @@ func TestTLSDefaultWithOptionalParams(t *testing.T) { ...@@ -113,7 +135,7 @@ func TestTLSDefaultWithOptionalParams(t *testing.T) {
} }
// TODO: If we allow this... but probably not a good idea. // TODO: If we allow this... but probably not a good idea.
// func TestTLSDisableHTTPRedirect(t *testing.T) { // func TestSetupDisableHTTPRedirect(t *testing.T) {
// c := NewTestController(`tls { // c := NewTestController(`tls {
// allow_http // allow_http
// }`) // }`)
...@@ -126,34 +148,34 @@ func TestTLSDefaultWithOptionalParams(t *testing.T) { ...@@ -126,34 +148,34 @@ func TestTLSDefaultWithOptionalParams(t *testing.T) {
// } // }
// } // }
func TestTLSParseWithWrongOptionalParams(t *testing.T) { func TestSetupParseWithWrongOptionalParams(t *testing.T) {
// Test protocols wrong params // Test protocols wrong params
params := `tls cert.crt cert.key { params := `tls ` + certFile + ` ` + keyFile + ` {
protocols ssl tls protocols ssl tls
}` }`
c := NewTestController(params) c := setup.NewTestController(params)
_, err := TLS(c) _, err := Setup(c)
if err == nil { if err == nil {
t.Errorf("Expected errors, but no error returned") t.Errorf("Expected errors, but no error returned")
} }
// Test ciphers wrong params // Test ciphers wrong params
params = `tls cert.crt cert.key { params = `tls ` + certFile + ` ` + keyFile + ` {
ciphers not-valid-cipher ciphers not-valid-cipher
}` }`
c = NewTestController(params) c = setup.NewTestController(params)
_, err = TLS(c) _, err = Setup(c)
if err == nil { if err == nil {
t.Errorf("Expected errors, but no error returned") t.Errorf("Expected errors, but no error returned")
} }
} }
func TestTLSParseWithClientAuth(t *testing.T) { func TestSetupParseWithClientAuth(t *testing.T) {
params := `tls cert.crt cert.key { params := `tls ` + certFile + ` ` + keyFile + ` {
clients client_ca.crt client2_ca.crt clients client_ca.crt client2_ca.crt
}` }`
c := NewTestController(params) c := setup.NewTestController(params)
_, err := TLS(c) _, err := Setup(c)
if err != nil { if err != nil {
t.Errorf("Expected no errors, got: %v", err) t.Errorf("Expected no errors, got: %v", err)
} }
...@@ -169,12 +191,40 @@ func TestTLSParseWithClientAuth(t *testing.T) { ...@@ -169,12 +191,40 @@ func TestTLSParseWithClientAuth(t *testing.T) {
} }
// Test missing client cert file // Test missing client cert file
params = `tls cert.crt cert.key { params = `tls ` + certFile + ` ` + keyFile + ` {
clients clients
}` }`
c = NewTestController(params) c = setup.NewTestController(params)
_, err = TLS(c) _, err = Setup(c)
if err == nil { if err == nil {
t.Errorf("Expected an error, but no error returned") t.Errorf("Expected an error, but no error returned")
} }
} }
const (
certFile = "test_cert.pem"
keyFile = "test_key.pem"
)
var testCert = []byte(`-----BEGIN CERTIFICATE-----
MIIBkjCCATmgAwIBAgIJANfFCBcABL6LMAkGByqGSM49BAEwFDESMBAGA1UEAxMJ
bG9jYWxob3N0MB4XDTE2MDIxMDIyMjAyNFoXDTE4MDIwOTIyMjAyNFowFDESMBAG
A1UEAxMJbG9jYWxob3N0MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEs22MtnG7
9K1mvIyjEO9GLx7BFD0tBbGnwQ0VPsuCxC6IeVuXbQDLSiVQvFZ6lUszTlczNxVk
pEfqrM6xAupB7qN1MHMwHQYDVR0OBBYEFHxYDvAxUwL4XrjPev6qZ/BiLDs5MEQG
A1UdIwQ9MDuAFHxYDvAxUwL4XrjPev6qZ/BiLDs5oRikFjAUMRIwEAYDVQQDEwls
b2NhbGhvc3SCCQDXxQgXAAS+izAMBgNVHRMEBTADAQH/MAkGByqGSM49BAEDSAAw
RQIgRvBqbyJM2JCJqhA1FmcoZjeMocmhxQHTt1c+1N2wFUgCIQDtvrivbBPA688N
Qh3sMeAKNKPsx5NxYdoWuu9KWcKz9A==
-----END CERTIFICATE-----
`)
var testKey = []byte(`-----BEGIN EC PARAMETERS-----
BggqhkjOPQMBBw==
-----END EC PARAMETERS-----
-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIGLtRmwzYVcrH3J0BnzYbGPdWVF10i9p6mxkA4+b2fURoAoGCCqGSM49
AwEHoUQDQgAEs22MtnG79K1mvIyjEO9GLx7BFD0tBbGnwQ0VPsuCxC6IeVuXbQDL
SiVQvFZ6lUszTlczNxVkpEfqrM6xAupB7g==
-----END EC PRIVATE KEY-----
`)
package letsencrypt package https
import ( import (
"path/filepath" "path/filepath"
......
package letsencrypt package https
import ( import (
"path/filepath" "path/filepath"
......
package letsencrypt package https
import ( import (
"bufio" "bufio"
...@@ -41,7 +41,7 @@ func (u User) GetPrivateKey() *rsa.PrivateKey { ...@@ -41,7 +41,7 @@ func (u User) GetPrivateKey() *rsa.PrivateKey {
// getUser loads the user with the given email from disk. // getUser loads the user with the given email from disk.
// If the user does not exist, it will create a new one, // If the user does not exist, it will create a new one,
// but it does NOT save new users to the disk or register // but it does NOT save new users to the disk or register
// them via ACME. // them via ACME. It does NOT prompt the user.
func getUser(email string) (User, error) { func getUser(email string) (User, error) {
var user User var user User
...@@ -72,7 +72,8 @@ func getUser(email string) (User, error) { ...@@ -72,7 +72,8 @@ func getUser(email string) (User, error) {
} }
// saveUser persists a user's key and account registration // saveUser persists a user's key and account registration
// to the file system. It does NOT register the user via ACME. // to the file system. It does NOT register the user via ACME
// or prompt the user.
func saveUser(user User) error { func saveUser(user User) error {
// make user account folder // make user account folder
err := os.MkdirAll(storage.User(user.Email), 0700) err := os.MkdirAll(storage.User(user.Email), 0700)
...@@ -99,7 +100,7 @@ func saveUser(user User) error { ...@@ -99,7 +100,7 @@ func saveUser(user User) error {
// with a new private key. This function does NOT save the // with a new private key. This function does NOT save the
// user to disk or register it via ACME. If you want to use // user to disk or register it via ACME. If you want to use
// a user account that might already exist, call getUser // a user account that might already exist, call getUser
// instead. // instead. It does NOT prompt the user.
func newUser(email string) (User, error) { func newUser(email string) (User, error) {
user := User{Email: email} user := User{Email: email}
privateKey, err := rsa.GenerateKey(rand.Reader, rsaKeySizeToUse) privateKey, err := rsa.GenerateKey(rand.Reader, rsaKeySizeToUse)
...@@ -114,10 +115,10 @@ func newUser(email string) (User, error) { ...@@ -114,10 +115,10 @@ func newUser(email string) (User, error) {
// address from the user to use for TLS for cfg. If it // address from the user to use for TLS for cfg. If it
// cannot get an email address, it returns empty string. // cannot get an email address, it returns empty string.
// (It will warn the user of the consequences of an // (It will warn the user of the consequences of an
// empty email.) If skipPrompt is true, the user will // empty email.) This function MAY prompt the user for
// NOT be prompted and an empty email will be returned // input. If userPresent is false, the operator will
// instead. // NOT be prompted and an empty email may be returned.
func getEmail(cfg server.Config, skipPrompt bool) string { func getEmail(cfg server.Config, userPresent bool) string {
// First try the tls directive from the Caddyfile // First try the tls directive from the Caddyfile
leEmail := cfg.TLS.LetsEncryptEmail leEmail := cfg.TLS.LetsEncryptEmail
if leEmail == "" { if leEmail == "" {
...@@ -135,11 +136,12 @@ func getEmail(cfg server.Config, skipPrompt bool) string { ...@@ -135,11 +136,12 @@ func getEmail(cfg server.Config, skipPrompt bool) string {
} }
if mostRecent == nil || dir.ModTime().After(mostRecent.ModTime()) { if mostRecent == nil || dir.ModTime().After(mostRecent.ModTime()) {
leEmail = dir.Name() leEmail = dir.Name()
DefaultEmail = leEmail // save for next time
} }
} }
} }
} }
if leEmail == "" && !skipPrompt { if leEmail == "" && userPresent {
// Alas, we must bother the user and ask for an email address; // Alas, we must bother the user and ask for an email address;
// if they proceed they also agree to the SA. // if they proceed they also agree to the SA.
reader := bufio.NewReader(stdin) reader := bufio.NewReader(stdin)
......
package letsencrypt package https
import ( import (
"bytes" "bytes"
...@@ -140,13 +140,13 @@ func TestGetEmail(t *testing.T) { ...@@ -140,13 +140,13 @@ func TestGetEmail(t *testing.T) {
LetsEncryptEmail: "test1@foo.com", LetsEncryptEmail: "test1@foo.com",
}, },
} }
actual := getEmail(config, false) actual := getEmail(config, true)
if actual != "test1@foo.com" { if actual != "test1@foo.com" {
t.Errorf("Did not get correct email from config; expected '%s' but got '%s'", "test1@foo.com", actual) t.Errorf("Did not get correct email from config; expected '%s' but got '%s'", "test1@foo.com", actual)
} }
// Test2: Use default email from flag (or user previously typing it) // Test2: Use default email from flag (or user previously typing it)
actual = getEmail(server.Config{}, false) actual = getEmail(server.Config{}, true)
if actual != DefaultEmail { if actual != DefaultEmail {
t.Errorf("Did not get correct email from config; expected '%s' but got '%s'", DefaultEmail, actual) t.Errorf("Did not get correct email from config; expected '%s' but got '%s'", DefaultEmail, actual)
} }
...@@ -158,7 +158,7 @@ func TestGetEmail(t *testing.T) { ...@@ -158,7 +158,7 @@ func TestGetEmail(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("Could not simulate user input, error: %v", err) t.Fatalf("Could not simulate user input, error: %v", err)
} }
actual = getEmail(server.Config{}, false) actual = getEmail(server.Config{}, true)
if actual != "test3@foo.com" { if actual != "test3@foo.com" {
t.Errorf("Did not get correct email from user input prompt; expected '%s' but got '%s'", "test3@foo.com", actual) t.Errorf("Did not get correct email from user input prompt; expected '%s' but got '%s'", "test3@foo.com", actual)
} }
...@@ -189,7 +189,7 @@ func TestGetEmail(t *testing.T) { ...@@ -189,7 +189,7 @@ func TestGetEmail(t *testing.T) {
t.Fatalf("Could not change user folder mod time for '%s': %v", eml, err) t.Fatalf("Could not change user folder mod time for '%s': %v", eml, err)
} }
} }
actual = getEmail(server.Config{}, false) actual = getEmail(server.Config{}, true)
if actual != "test4-3@foo.com" { if actual != "test4-3@foo.com" {
t.Errorf("Did not get correct email from storage; expected '%s' but got '%s'", "test4-3@foo.com", actual) t.Errorf("Did not get correct email from storage; expected '%s' but got '%s'", "test4-3@foo.com", actual)
} }
......
package letsencrypt
import (
"crypto/tls"
"errors"
"strings"
"sync"
"github.com/mholt/caddy/server"
)
// GetCertificateDuringHandshake is a function that gets a certificate during a TLS handshake.
// It first checks an in-memory cache in case the cert was requested before, then tries to load
// a certificate in the storage folder from disk. If it can't find an existing certificate, it
// will try to obtain one using ACME, which will then be stored on disk and cached in memory.
//
// This function is safe for use by multiple concurrent goroutines.
func GetCertificateDuringHandshake(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
// Utility function to help us load a cert from disk and put it in the cache if successful
loadCertFromDisk := func(domain string) *tls.Certificate {
cert, err := tls.LoadX509KeyPair(storage.SiteCertFile(domain), storage.SiteKeyFile(domain))
if err == nil {
certCacheMu.Lock()
if len(certCache) < 10000 { // limit size of cache to prevent a ridiculous, unusual kind of attack
certCache[domain] = &cert
}
certCacheMu.Unlock()
return &cert
}
return nil
}
// First check our in-memory cache to see if we've already loaded it
certCacheMu.RLock()
cert := server.GetCertificateFromCache(clientHello, certCache)
certCacheMu.RUnlock()
if cert != nil {
return cert, nil
}
// Then check to see if we already have one on disk; if we do, add it to cache and use it
name := strings.ToLower(clientHello.ServerName)
cert = loadCertFromDisk(name)
if cert != nil {
return cert, nil
}
// Only option left is to get one from LE, but the name has to qualify first
if !HostQualifies(name) {
return nil, nil
}
// By this point, we need to obtain one from the CA. We must protect this process
// from happening concurrently, so synchronize.
obtainCertWaitGroupsMutex.Lock()
wg, ok := obtainCertWaitGroups[name]
if ok {
// lucky us -- another goroutine is already obtaining the certificate.
// wait for it to finish obtaining the cert and then we'll use it.
obtainCertWaitGroupsMutex.Unlock()
wg.Wait()
return GetCertificateDuringHandshake(clientHello)
}
// looks like it's up to us to do all the work and obtain the cert
wg = new(sync.WaitGroup)
wg.Add(1)
obtainCertWaitGroups[name] = wg
obtainCertWaitGroupsMutex.Unlock()
// Unblock waiters and delete waitgroup when we return
defer func() {
obtainCertWaitGroupsMutex.Lock()
wg.Done()
delete(obtainCertWaitGroups, name)
obtainCertWaitGroupsMutex.Unlock()
}()
// obtain cert
client, err := newClientPort(DefaultEmail, AlternatePort)
if err != nil {
return nil, errors.New("error creating client: " + err.Error())
}
err = clientObtain(client, []string{name}, false)
if err != nil {
return nil, err
}
// load certificate into memory and return it
return loadCertFromDisk(name), nil
}
// obtainCertWaitGroups is used to coordinate obtaining certs for each hostname.
var obtainCertWaitGroups = make(map[string]*sync.WaitGroup)
var obtainCertWaitGroupsMutex sync.Mutex
// certCache stores certificates that have been obtained in memory.
var certCache = make(map[string]*tls.Certificate)
var certCacheMu sync.RWMutex
package letsencrypt
import (
"encoding/json"
"io/ioutil"
"log"
"time"
"github.com/mholt/caddy/server"
"github.com/xenolf/lego/acme"
)
// OnChange is a callback function that will be used to restart
// the application or the part of the application that uses
// the certificates maintained by this package. When at least
// one certificate is renewed or an OCSP status changes, this
// function will be called.
var OnChange func() error
// maintainAssets is a permanently-blocking function
// that loops indefinitely and, on a regular schedule, checks
// certificates for expiration and initiates a renewal of certs
// that are expiring soon. It also updates OCSP stapling and
// performs other maintenance of assets.
//
// You must pass in the server configs to maintain and the channel
// which you'll close when maintenance should stop, to allow this
// goroutine to clean up after itself and unblock.
func maintainAssets(configs []server.Config, stopChan chan struct{}) {
renewalTicker := time.NewTicker(RenewInterval)
ocspTicker := time.NewTicker(OCSPInterval)
for {
select {
case <-renewalTicker.C:
n, errs := renewCertificates(configs, true)
if len(errs) > 0 {
for _, err := range errs {
log.Printf("[ERROR] Certificate renewal: %v", err)
}
}
// even if there was an error, some renewals may have succeeded
if n > 0 && OnChange != nil {
err := OnChange()
if err != nil {
log.Printf("[ERROR] OnChange after cert renewal: %v", err)
}
}
case <-ocspTicker.C:
for bundle, oldResp := range ocspCache {
// start checking OCSP staple about halfway through validity period for good measure
refreshTime := oldResp.ThisUpdate.Add(oldResp.NextUpdate.Sub(oldResp.ThisUpdate) / 2)
// only check for updated OCSP validity window if refreshTime is in the past
if time.Now().After(refreshTime) {
_, newResp, err := acme.GetOCSPForCert(*bundle)
if err != nil {
log.Printf("[ERROR] Checking OCSP for bundle: %v", err)
continue
}
// we're not looking for different status, just a more future expiration
if newResp.NextUpdate != oldResp.NextUpdate {
if OnChange != nil {
log.Printf("[INFO] Updating OCSP stapling to extend validity period to %v", newResp.NextUpdate)
err := OnChange()
if err != nil {
log.Printf("[ERROR] OnChange after OCSP trigger: %v", err)
}
break
}
}
}
}
case <-stopChan:
renewalTicker.Stop()
ocspTicker.Stop()
return
}
}
}
// renewCertificates loops through all configured site and
// looks for certificates to renew. Nothing is mutated
// through this function; all changes happen directly on disk.
// It returns the number of certificates renewed and any errors
// that occurred. It only performs a renewal if necessary.
// If useCustomPort is true, a custom port will be used, and
// whatever is listening at 443 better proxy ACME requests to it.
// Otherwise, the acme package will create its own listener on 443.
func renewCertificates(configs []server.Config, useCustomPort bool) (int, []error) {
log.Printf("[INFO] Checking certificates for %d hosts", len(configs))
var errs []error
var n int
for _, cfg := range configs {
// Host must be TLS-enabled and have existing assets managed by LE
if !cfg.TLS.Enabled || !existingCertAndKey(cfg.Host) {
continue
}
// Read the certificate and get the NotAfter time.
certBytes, err := ioutil.ReadFile(storage.SiteCertFile(cfg.Host))
if err != nil {
errs = append(errs, err)
continue // still have to check other certificates
}
expTime, err := acme.GetPEMCertExpiration(certBytes)
if err != nil {
errs = append(errs, err)
continue
}
// The time returned from the certificate is always in UTC.
// So calculate the time left with local time as UTC.
// Directly convert it to days for the following checks.
daysLeft := int(expTime.Sub(time.Now().UTC()).Hours() / 24)
// Renew if getting close to expiration.
if daysLeft <= renewDaysBefore {
log.Printf("[INFO] Certificate for %s has %d days remaining; attempting renewal", cfg.Host, daysLeft)
var client *acme.Client
if useCustomPort {
client, err = newClientPort("", AlternatePort) // email not used for renewal
} else {
client, err = newClient("")
}
if err != nil {
errs = append(errs, err)
continue
}
// Read and set up cert meta, required for renewal
metaBytes, err := ioutil.ReadFile(storage.SiteMetaFile(cfg.Host))
if err != nil {
errs = append(errs, err)
continue
}
privBytes, err := ioutil.ReadFile(storage.SiteKeyFile(cfg.Host))
if err != nil {
errs = append(errs, err)
continue
}
var certMeta acme.CertificateResource
err = json.Unmarshal(metaBytes, &certMeta)
certMeta.Certificate = certBytes
certMeta.PrivateKey = privBytes
// Renew certificate
Renew:
newCertMeta, err := client.RenewCertificate(certMeta, true)
if err != nil {
if _, ok := err.(acme.TOSError); ok {
err := client.AgreeToTOS()
if err != nil {
errs = append(errs, err)
}
goto Renew
}
time.Sleep(10 * time.Second)
newCertMeta, err = client.RenewCertificate(certMeta, true)
if err != nil {
errs = append(errs, err)
continue
}
}
saveCertResource(newCertMeta)
n++
} else if daysLeft <= renewDaysBefore+7 && daysLeft >= renewDaysBefore+6 {
log.Printf("[WARNING] Certificate for %s has %d days remaining; will automatically renew when %d days remain\n", cfg.Host, daysLeft, renewDaysBefore)
}
}
return n, errs
}
// renewDaysBefore is how many days before expiration to renew certificates.
const renewDaysBefore = 14
...@@ -12,7 +12,7 @@ import ( ...@@ -12,7 +12,7 @@ import (
"os/exec" "os/exec"
"path" "path"
"github.com/mholt/caddy/caddy/letsencrypt" "github.com/mholt/caddy/caddy/https"
) )
func init() { func init() {
...@@ -133,13 +133,15 @@ func getCertsForNewCaddyfile(newCaddyfile Input) error { ...@@ -133,13 +133,15 @@ func getCertsForNewCaddyfile(newCaddyfile Input) error {
} }
// first mark the configs that are qualified for managed TLS // first mark the configs that are qualified for managed TLS
letsencrypt.MarkQualified(configs) https.MarkQualified(configs)
// we must make sure port is set before we group by bind address // since we group by bind address to obtain certs, we must call
letsencrypt.EnableTLS(configs) // EnableTLS to make sure the port is set properly first
// (can ignore error since we aren't actually using the certs)
https.EnableTLS(configs, false)
// place certs on the disk // place certs on the disk
err = letsencrypt.ObtainCerts(configs, letsencrypt.AlternatePort) err = https.ObtainCerts(configs, false)
if err != nil { if err != nil {
return errors.New("obtaining certs: " + err.Error()) return errors.New("obtaining certs: " + err.Error())
} }
......
...@@ -13,7 +13,7 @@ import ( ...@@ -13,7 +13,7 @@ import (
"time" "time"
"github.com/mholt/caddy/caddy" "github.com/mholt/caddy/caddy"
"github.com/mholt/caddy/caddy/letsencrypt" "github.com/mholt/caddy/caddy/https"
"github.com/xenolf/lego/acme" "github.com/xenolf/lego/acme"
) )
...@@ -32,14 +32,14 @@ const ( ...@@ -32,14 +32,14 @@ const (
func init() { func init() {
caddy.TrapSignals() caddy.TrapSignals()
flag.BoolVar(&letsencrypt.Agreed, "agree", false, "Agree to Let's Encrypt Subscriber Agreement") flag.BoolVar(&https.Agreed, "agree", false, "Agree to Let's Encrypt Subscriber Agreement")
flag.StringVar(&letsencrypt.CAUrl, "ca", "https://acme-v01.api.letsencrypt.org/directory", "Certificate authority ACME server") flag.StringVar(&https.CAUrl, "ca", "https://acme-v01.api.letsencrypt.org/directory", "Certificate authority ACME server")
flag.StringVar(&conf, "conf", "", "Configuration file to use (default="+caddy.DefaultConfigFile+")") flag.StringVar(&conf, "conf", "", "Configuration file to use (default="+caddy.DefaultConfigFile+")")
flag.StringVar(&cpu, "cpu", "100%", "CPU cap") flag.StringVar(&cpu, "cpu", "100%", "CPU cap")
flag.StringVar(&letsencrypt.DefaultEmail, "email", "", "Default Let's Encrypt account email address") flag.StringVar(&https.DefaultEmail, "email", "", "Default Let's Encrypt account email address")
flag.DurationVar(&caddy.GracefulTimeout, "grace", 5*time.Second, "Maximum duration of graceful shutdown") flag.DurationVar(&caddy.GracefulTimeout, "grace", 5*time.Second, "Maximum duration of graceful shutdown")
flag.StringVar(&caddy.Host, "host", caddy.DefaultHost, "Default host") flag.StringVar(&caddy.Host, "host", caddy.DefaultHost, "Default host")
flag.BoolVar(&caddy.HTTP2, "http2", true, "HTTP/2 support") // TODO: temporary flag until http2 merged into std lib flag.BoolVar(&caddy.HTTP2, "http2", true, "HTTP/2 support")
flag.StringVar(&logfile, "log", "", "Process log file") flag.StringVar(&logfile, "log", "", "Process log file")
flag.StringVar(&caddy.PidFile, "pidfile", "", "Path to write pid file") flag.StringVar(&caddy.PidFile, "pidfile", "", "Path to write pid file")
flag.StringVar(&caddy.Port, "port", caddy.DefaultPort, "Default port") flag.StringVar(&caddy.Port, "port", caddy.DefaultPort, "Default port")
...@@ -73,7 +73,7 @@ func main() { ...@@ -73,7 +73,7 @@ func main() {
} }
if revoke != "" { if revoke != "" {
err := letsencrypt.Revoke(revoke) err := https.Revoke(revoke)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
......
...@@ -65,13 +65,10 @@ func (c Config) Address() string { ...@@ -65,13 +65,10 @@ func (c Config) Address() string {
// TLSConfig describes how TLS should be configured and used. // TLSConfig describes how TLS should be configured and used.
type TLSConfig struct { type TLSConfig struct {
Enabled bool Enabled bool
Certificate string LetsEncryptEmail string
Key string Managed bool // will be set to true if config qualifies for automatic, managed TLS
LetsEncryptEmail string Manual bool // will be set to true if user provides the cert and key files
Managed bool // will be set to true if config qualifies for automatic, managed TLS
//DisableHTTPRedir bool // TODO: not a good idea - should we really allow it?
OCSPStaple []byte
Ciphers []uint16 Ciphers []uint16
ProtocolMinVersion uint16 ProtocolMinVersion uint16
ProtocolMaxVersion uint16 ProtocolMaxVersion uint16
......
...@@ -13,7 +13,6 @@ import ( ...@@ -13,7 +13,6 @@ import (
"net/http" "net/http"
"os" "os"
"runtime" "runtime"
"strings"
"sync" "sync"
"time" "time"
) )
...@@ -25,8 +24,9 @@ import ( ...@@ -25,8 +24,9 @@ import (
// graceful termination (POSIX only). // graceful termination (POSIX only).
type Server struct { type Server struct {
*http.Server *http.Server
HTTP2 bool // temporary while http2 is not in std lib (TODO: remove flag when part of std lib) HTTP2 bool // whether to enable HTTP/2
tls bool // whether this server is serving all HTTPS hosts or not tls bool // whether this server is serving all HTTPS hosts or not
OnDemandTLS bool // whether this server supports on-demand TLS (load certs at handshake-time)
vhosts map[string]virtualHost // virtual hosts keyed by their address vhosts map[string]virtualHost // virtual hosts keyed by their address
listener ListenerFile // the listener which is bound to the socket listener ListenerFile // the listener which is bound to the socket
listenerMu sync.Mutex // protects listener listenerMu sync.Mutex // protects listener
...@@ -60,20 +60,29 @@ type OptionalCallback func(http.ResponseWriter, *http.Request) bool ...@@ -60,20 +60,29 @@ type OptionalCallback func(http.ResponseWriter, *http.Request) bool
// as it stands, you should dispose of a server after stopping it. // as it stands, you should dispose of a server after stopping it.
// The behavior of serving with a spent server is undefined. // The behavior of serving with a spent server is undefined.
func New(addr string, configs []Config, gracefulTimeout time.Duration) (*Server, error) { func New(addr string, configs []Config, gracefulTimeout time.Duration) (*Server, error) {
var tls bool var useTLS, useOnDemandTLS bool
if len(configs) > 0 { if len(configs) > 0 {
tls = configs[0].TLS.Enabled useTLS = configs[0].TLS.Enabled
host, _, err := net.SplitHostPort(addr)
if err != nil {
host = addr
}
if useTLS && host == "" && !configs[0].TLS.Manual {
useOnDemandTLS = true
}
} }
s := &Server{ s := &Server{
Server: &http.Server{ Server: &http.Server{
Addr: addr, Addr: addr,
TLSConfig: new(tls.Config),
// TODO: Make these values configurable? // TODO: Make these values configurable?
// ReadTimeout: 2 * time.Minute, // ReadTimeout: 2 * time.Minute,
// WriteTimeout: 2 * time.Minute, // WriteTimeout: 2 * time.Minute,
// MaxHeaderBytes: 1 << 16, // MaxHeaderBytes: 1 << 16,
}, },
tls: tls, tls: useTLS,
OnDemandTLS: useOnDemandTLS,
vhosts: make(map[string]virtualHost), vhosts: make(map[string]virtualHost),
startChan: make(chan struct{}), startChan: make(chan struct{}),
connTimeout: gracefulTimeout, connTimeout: gracefulTimeout,
...@@ -168,7 +177,7 @@ func (s *Server) serve(ln ListenerFile) error { ...@@ -168,7 +177,7 @@ func (s *Server) serve(ln ListenerFile) error {
for _, vh := range s.vhosts { for _, vh := range s.vhosts {
tlsConfigs = append(tlsConfigs, vh.config.TLS) tlsConfigs = append(tlsConfigs, vh.config.TLS)
} }
return serveTLSWithSNI(s, s.listener, tlsConfigs) return serveTLS(s, s.listener, tlsConfigs)
} }
close(s.startChan) // unblock anyone waiting for this to start listening close(s.startChan) // unblock anyone waiting for this to start listening
...@@ -196,106 +205,32 @@ func (s *Server) setup() error { ...@@ -196,106 +205,32 @@ func (s *Server) setup() error {
return nil return nil
} }
// serveTLSWithSNI serves TLS with Server Name Indication (SNI) support, which allows // serveTLS serves TLS with SNI and client auth support if s has them enabled. It
// multiple sites (different hostnames) to be served from the same address. It also // blocks until s quits.
// supports client authentication if srv has it enabled. It blocks until s quits. func serveTLS(s *Server, ln net.Listener, tlsConfigs []TLSConfig) error {
//
// This method is adapted from the std lib's net/http ServeTLS function, which was written
// by the Go Authors. It has been modified to support multiple certificate/key pairs,
// client authentication, and our custom Server type.
func serveTLSWithSNI(s *Server, ln net.Listener, tlsConfigs []TLSConfig) error {
config := cloneTLSConfig(s.TLSConfig)
// Here we diverge from the stdlib a bit by loading multiple certs/key pairs
// then we map the server names to their certs
for _, tlsConfig := range tlsConfigs {
if tlsConfig.Certificate == "" || tlsConfig.Key == "" {
continue
}
cert, err := tls.LoadX509KeyPair(tlsConfig.Certificate, tlsConfig.Key)
if err != nil {
defer close(s.startChan)
return fmt.Errorf("loading certificate and key pair: %v", err)
}
cert.OCSPStaple = tlsConfig.OCSPStaple
config.Certificates = append(config.Certificates, cert)
}
if len(config.Certificates) > 0 {
config.BuildNameToCertificate()
}
config.GetCertificate = func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
// TODO: When Caddy starts, if it is to issue certs dynamically, we need
// terms agreement and an email address. make sure this is enforced at server
// start if the Caddyfile enables dynamic certificate issuance!
// Check NameToCertificate like the std lib does in "getCertificate" (unexported, bah)
cert := GetCertificateFromCache(clientHello, config.NameToCertificate)
if cert != nil {
return cert, nil
}
if s.SNICallback != nil {
return s.SNICallback(clientHello)
}
return nil, nil
}
// Customize our TLS configuration // Customize our TLS configuration
config.MinVersion = tlsConfigs[0].ProtocolMinVersion s.TLSConfig.MinVersion = tlsConfigs[0].ProtocolMinVersion
config.MaxVersion = tlsConfigs[0].ProtocolMaxVersion s.TLSConfig.MaxVersion = tlsConfigs[0].ProtocolMaxVersion
config.CipherSuites = tlsConfigs[0].Ciphers s.TLSConfig.CipherSuites = tlsConfigs[0].Ciphers
config.PreferServerCipherSuites = tlsConfigs[0].PreferServerCipherSuites s.TLSConfig.PreferServerCipherSuites = tlsConfigs[0].PreferServerCipherSuites
// TLS client authentication, if user enabled it // TLS client authentication, if user enabled it
err := setupClientAuth(tlsConfigs, config) err := setupClientAuth(tlsConfigs, s.TLSConfig)
if err != nil { if err != nil {
defer close(s.startChan) defer close(s.startChan)
return err return err
} }
s.TLSConfig = config
// Create TLS listener - note that we do not replace s.listener // Create TLS listener - note that we do not replace s.listener
// with this TLS listener; tls.listener is unexported and does // with this TLS listener; tls.listener is unexported and does
// not implement the File() method we need for graceful restarts // not implement the File() method we need for graceful restarts
// on POSIX systems. // on POSIX systems.
ln = tls.NewListener(ln, config) ln = tls.NewListener(ln, s.TLSConfig)
close(s.startChan) // unblock anyone waiting for this to start listening close(s.startChan) // unblock anyone waiting for this to start listening
return s.Server.Serve(ln) return s.Server.Serve(ln)
} }
// Borrowed from the Go standard library, crypto/tls pacakge, common.go.
// It has been modified to fit this program.
// Original license:
//
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
func GetCertificateFromCache(clientHello *tls.ClientHelloInfo, cache map[string]*tls.Certificate) *tls.Certificate {
name := strings.ToLower(clientHello.ServerName)
for len(name) > 0 && name[len(name)-1] == '.' {
name = name[:len(name)-1]
}
// exact match? great! use it
if cert, ok := cache[name]; ok {
return cert
}
// try replacing labels in the name with wildcards until we get a match.
labels := strings.Split(name, ".")
for i := range labels {
labels[i] = "*"
candidate := strings.Join(labels, ".")
if cert, ok := cache[candidate]; ok {
return cert
}
}
return nil
}
// Stop stops the server. It blocks until the server is // Stop stops the server. It blocks until the server is
// totally stopped. On POSIX systems, it will wait for // totally stopped. On POSIX systems, it will wait for
// connections to close (up to a max timeout of a few // connections to close (up to a max timeout of a few
...@@ -482,6 +417,8 @@ func (ln tcpKeepAliveListener) File() (*os.File, error) { ...@@ -482,6 +417,8 @@ func (ln tcpKeepAliveListener) File() (*os.File, error) {
} }
// copied from net/http/transport.go // copied from net/http/transport.go
/*
TODO - remove - not necessary?
func cloneTLSConfig(cfg *tls.Config) *tls.Config { func cloneTLSConfig(cfg *tls.Config) *tls.Config {
if cfg == nil { if cfg == nil {
return &tls.Config{} return &tls.Config{}
...@@ -507,7 +444,7 @@ func cloneTLSConfig(cfg *tls.Config) *tls.Config { ...@@ -507,7 +444,7 @@ func cloneTLSConfig(cfg *tls.Config) *tls.Config {
MaxVersion: cfg.MaxVersion, MaxVersion: cfg.MaxVersion,
CurvePreferences: cfg.CurvePreferences, CurvePreferences: cfg.CurvePreferences,
} }
} }*/
// ShutdownCallbacks executes all the shutdown callbacks // ShutdownCallbacks executes all the shutdown callbacks
// for all the virtualhosts in servers, and returns all the // for all the virtualhosts in servers, and returns all the
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment