Commit e0f1a02c authored by Matthew Holt's avatar Matthew Holt

Extract most of caddytls core code into external CertMagic package

All code relating to a caddytls.Config and setting it up from the
Caddyfile is still intact; only the certificate management-related
code was removed into a separate package.

I don't expect this to build in CI successfully; updating dependencies
and vendor is coming next.

I've also removed the ad-hoc, half-baked storage plugins that we need
to finish making first-class Caddy plugins (they were never documented
anyway). The new certmagic package has a much better storage interface,
and we can finally move toward making a new storage plugin type, but
it shouldn't be configurable in the Caddyfile, I think, since it doesn't
make sense for a Caddy instance to use more than one storage config...

We also have the option of eliminating DNS provider plugins and just
shipping all of lego's DNS providers by using a lego package (the
caddytls/setup.go file has a comment describing how) -- but it doubles
Caddy's binary size by 100% from about 19 MB to around 40 MB...!
parent 8f583dcf
...@@ -108,12 +108,12 @@ type Instance struct { ...@@ -108,12 +108,12 @@ type Instance struct {
servers []ServerListener servers []ServerListener
// these callbacks execute when certain events occur // these callbacks execute when certain events occur
onFirstStartup []func() error // starting, not as part of a restart OnFirstStartup []func() error // starting, not as part of a restart
onStartup []func() error // starting, even as part of a restart OnStartup []func() error // starting, even as part of a restart
onRestart []func() error // before restart commences OnRestart []func() error // before restart commences
onRestartFailed []func() error // if restart failed OnRestartFailed []func() error // if restart failed
onShutdown []func() error // stopping, even as part of a restart OnShutdown []func() error // stopping, even as part of a restart
onFinalShutdown []func() error // stopping, not as part of a restart OnFinalShutdown []func() error // stopping, not as part of a restart
// storing values on an instance is preferable to // storing values on an instance is preferable to
// global state because these will get garbage- // global state because these will get garbage-
...@@ -163,13 +163,13 @@ func (i *Instance) Stop() error { ...@@ -163,13 +163,13 @@ func (i *Instance) Stop() error {
// the rest. All the non-nil errors will be returned. // the rest. All the non-nil errors will be returned.
func (i *Instance) ShutdownCallbacks() []error { func (i *Instance) ShutdownCallbacks() []error {
var errs []error var errs []error
for _, shutdownFunc := range i.onShutdown { for _, shutdownFunc := range i.OnShutdown {
err := shutdownFunc() err := shutdownFunc()
if err != nil { if err != nil {
errs = append(errs, err) errs = append(errs, err)
} }
} }
for _, finalShutdownFunc := range i.onFinalShutdown { for _, finalShutdownFunc := range i.OnFinalShutdown {
err := finalShutdownFunc() err := finalShutdownFunc()
if err != nil { if err != nil {
errs = append(errs, err) errs = append(errs, err)
...@@ -192,7 +192,7 @@ func (i *Instance) Restart(newCaddyfile Input) (*Instance, error) { ...@@ -192,7 +192,7 @@ func (i *Instance) Restart(newCaddyfile Input) (*Instance, error) {
defer func() { defer func() {
r := recover() r := recover()
if err != nil || r != nil { if err != nil || r != nil {
for _, fn := range i.onRestartFailed { for _, fn := range i.OnRestartFailed {
err = fn() err = fn()
if err != nil { if err != nil {
log.Printf("[ERROR] restart failed: %v", err) log.Printf("[ERROR] restart failed: %v", err)
...@@ -205,7 +205,7 @@ func (i *Instance) Restart(newCaddyfile Input) (*Instance, error) { ...@@ -205,7 +205,7 @@ func (i *Instance) Restart(newCaddyfile Input) (*Instance, error) {
}() }()
// run restart callbacks // run restart callbacks
for _, fn := range i.onRestart { for _, fn := range i.OnRestart {
err = fn() err = fn()
if err != nil { if err != nil {
return i, err return i, err
...@@ -252,7 +252,7 @@ func (i *Instance) Restart(newCaddyfile Input) (*Instance, error) { ...@@ -252,7 +252,7 @@ func (i *Instance) Restart(newCaddyfile Input) (*Instance, error) {
if err != nil { if err != nil {
return i, err return i, err
} }
for _, shutdownFunc := range i.onShutdown { for _, shutdownFunc := range i.OnShutdown {
err = shutdownFunc() err = shutdownFunc()
if err != nil { if err != nil {
return i, err return i, err
...@@ -274,42 +274,6 @@ func (i *Instance) SaveServer(s Server, ln net.Listener) { ...@@ -274,42 +274,6 @@ func (i *Instance) SaveServer(s Server, ln net.Listener) {
i.servers = append(i.servers, ServerListener{server: s, listener: ln}) i.servers = append(i.servers, ServerListener{server: s, listener: ln})
} }
// HasListenerWithAddress returns whether this package is
// tracking a server using a listener with the address
// addr.
func HasListenerWithAddress(addr string) bool {
instancesMu.Lock()
defer instancesMu.Unlock()
for _, inst := range instances {
for _, sln := range inst.servers {
if listenerAddrEqual(sln.listener, addr) {
return true
}
}
}
return false
}
// listenerAddrEqual compares a listener's address with
// addr. Extra care is taken to match addresses with an
// empty hostname portion, as listeners tend to report
// [::]:80, for example, when the matching address that
// created the listener might be simply :80.
func listenerAddrEqual(ln net.Listener, addr string) bool {
lnAddr := ln.Addr().String()
hostname, port, err := net.SplitHostPort(addr)
if err != nil {
return lnAddr == addr
}
if lnAddr == net.JoinHostPort("::", port) {
return true
}
if lnAddr == net.JoinHostPort("0.0.0.0", port) {
return true
}
return hostname != "" && lnAddr == addr
}
// TCPServer is a type that can listen and serve connections. // TCPServer is a type that can listen and serve connections.
// A TCPServer must associate with exactly zero or one net.Listeners. // A TCPServer must associate with exactly zero or one net.Listeners.
type TCPServer interface { type TCPServer interface {
...@@ -551,14 +515,14 @@ func startWithListenerFds(cdyfile Input, inst *Instance, restartFds map[string]r ...@@ -551,14 +515,14 @@ func startWithListenerFds(cdyfile Input, inst *Instance, restartFds map[string]r
// run startup callbacks // run startup callbacks
if !IsUpgrade() && restartFds == nil { if !IsUpgrade() && restartFds == nil {
// first startup means not a restart or upgrade // first startup means not a restart or upgrade
for _, firstStartupFunc := range inst.onFirstStartup { for _, firstStartupFunc := range inst.OnFirstStartup {
err = firstStartupFunc() err = firstStartupFunc()
if err != nil { if err != nil {
return err return err
} }
} }
} }
for _, startupFunc := range inst.onStartup { for _, startupFunc := range inst.OnStartup {
err = startupFunc() err = startupFunc()
if err != nil { if err != nil {
return err return err
......
...@@ -33,8 +33,8 @@ import ( ...@@ -33,8 +33,8 @@ import (
"github.com/mholt/caddy" "github.com/mholt/caddy"
"github.com/mholt/caddy/caddytls" "github.com/mholt/caddy/caddytls"
"github.com/mholt/caddy/telemetry" "github.com/mholt/caddy/telemetry"
"github.com/xenolf/lego/acme" "github.com/mholt/certmagic"
"gopkg.in/natefinch/lumberjack.v2" lumberjack "gopkg.in/natefinch/lumberjack.v2"
_ "github.com/mholt/caddy/caddyhttp" // plug in the HTTP server type _ "github.com/mholt/caddy/caddyhttp" // plug in the HTTP server type
// This is where other plugins get plugged in (imported) // This is where other plugins get plugged in (imported)
...@@ -44,17 +44,17 @@ func init() { ...@@ -44,17 +44,17 @@ func init() {
caddy.TrapSignals() caddy.TrapSignals()
setVersion() setVersion()
flag.BoolVar(&caddytls.Agreed, "agree", false, "Agree to the CA's Subscriber Agreement") flag.BoolVar(&certmagic.Agreed, "agree", false, "Agree to the CA's Subscriber Agreement")
flag.StringVar(&caddytls.DefaultCAUrl, "ca", "https://acme-v02.api.letsencrypt.org/directory", "URL to certificate authority's ACME server directory") flag.StringVar(&certmagic.CA, "ca", certmagic.CA, "URL to certificate authority's ACME server directory")
flag.BoolVar(&caddytls.DisableHTTPChallenge, "disable-http-challenge", caddytls.DisableHTTPChallenge, "Disable the ACME HTTP challenge") flag.BoolVar(&certmagic.DisableHTTPChallenge, "disable-http-challenge", certmagic.DisableHTTPChallenge, "Disable the ACME HTTP challenge")
flag.BoolVar(&caddytls.DisableTLSALPNChallenge, "disable-tls-alpn-challenge", caddytls.DisableTLSALPNChallenge, "Disable the ACME TLS-ALPN challenge") flag.BoolVar(&certmagic.DisableTLSALPNChallenge, "disable-tls-alpn-challenge", certmagic.DisableTLSALPNChallenge, "Disable the ACME TLS-ALPN challenge")
flag.StringVar(&disabledMetrics, "disabled-metrics", "", "Comma-separated list of telemetry metrics to disable") flag.StringVar(&disabledMetrics, "disabled-metrics", "", "Comma-separated list of telemetry metrics to disable")
flag.StringVar(&conf, "conf", "", "Caddyfile to load (default \""+caddy.DefaultConfigFile+"\")") flag.StringVar(&conf, "conf", "", "Caddyfile to load (default \""+caddy.DefaultConfigFile+"\")")
flag.StringVar(&cpu, "cpu", "100%", "CPU cap") flag.StringVar(&cpu, "cpu", "100%", "CPU cap")
flag.StringVar(&envFile, "env", "", "Path to file with environment variables to load in KEY=VALUE format") flag.StringVar(&envFile, "env", "", "Path to file with environment variables to load in KEY=VALUE format")
flag.BoolVar(&plugins, "plugins", false, "List installed plugins") flag.BoolVar(&plugins, "plugins", false, "List installed plugins")
flag.StringVar(&caddytls.DefaultEmail, "email", "", "Default ACME CA account email address") flag.StringVar(&certmagic.Email, "email", "", "Default ACME CA account email address")
flag.DurationVar(&acme.HTTPClient.Timeout, "catimeout", acme.HTTPClient.Timeout, "Default ACME CA HTTP timeout") flag.DurationVar(&certmagic.HTTPTimeout, "catimeout", certmagic.HTTPTimeout, "Default ACME CA HTTP timeout")
flag.StringVar(&logfile, "log", "", "Process log file") flag.StringVar(&logfile, "log", "", "Process log file")
flag.StringVar(&caddy.PidFile, "pidfile", "", "Path to write pid file") flag.StringVar(&caddy.PidFile, "pidfile", "", "Path to write pid file")
flag.BoolVar(&caddy.Quiet, "quiet", false, "Quiet mode (no initialization output)") flag.BoolVar(&caddy.Quiet, "quiet", false, "Quiet mode (no initialization output)")
...@@ -73,7 +73,7 @@ func Run() { ...@@ -73,7 +73,7 @@ func Run() {
caddy.AppName = appName caddy.AppName = appName
caddy.AppVersion = appVersion caddy.AppVersion = appVersion
acme.UserAgent = appName + "/" + appVersion certmagic.UserAgent = appName + "/" + appVersion
// Set up process log before anything bad happens // Set up process log before anything bad happens
switch logfile { switch logfile {
......
...@@ -32,7 +32,7 @@ func setupBind(c *caddy.Controller) error { ...@@ -32,7 +32,7 @@ func setupBind(c *caddy.Controller) error {
if !c.Args(&config.ListenHost) { if !c.Args(&config.ListenHost) {
return c.ArgErr() return c.ArgErr()
} }
config.TLS.ListenHost = config.ListenHost // necessary for ACME challenges, see issue #309 config.TLS.Manager.ListenHost = config.ListenHost // necessary for ACME challenges, see issue #309
} }
return nil return nil
} }
...@@ -32,7 +32,7 @@ func TestSetupBind(t *testing.T) { ...@@ -32,7 +32,7 @@ func TestSetupBind(t *testing.T) {
if got, want := cfg.ListenHost, "1.2.3.4"; got != want { if got, want := cfg.ListenHost, "1.2.3.4"; got != want {
t.Errorf("Expected the config's ListenHost to be %s, was %s", want, got) t.Errorf("Expected the config's ListenHost to be %s, was %s", want, got)
} }
if got, want := cfg.TLS.ListenHost, "1.2.3.4"; got != want { if got, want := cfg.TLS.Manager.ListenHost, "1.2.3.4"; got != want {
t.Errorf("Expected the TLS config's ListenHost to be %s, was %s", want, got) t.Errorf("Expected the TLS config's ListenHost to be %s, was %s", want, got)
} }
} }
...@@ -21,6 +21,7 @@ import ( ...@@ -21,6 +21,7 @@ import (
"github.com/mholt/caddy" "github.com/mholt/caddy"
"github.com/mholt/caddy/caddytls" "github.com/mholt/caddy/caddytls"
"github.com/mholt/certmagic"
) )
func activateHTTPS(cctx caddy.Context) error { func activateHTTPS(cctx caddy.Context) error {
...@@ -37,10 +38,10 @@ func activateHTTPS(cctx caddy.Context) error { ...@@ -37,10 +38,10 @@ func activateHTTPS(cctx caddy.Context) error {
// place certificates and keys on disk // place certificates and keys on disk
for _, c := range ctx.siteConfigs { for _, c := range ctx.siteConfigs {
if c.TLS.OnDemand { if c.TLS.Manager.OnDemand != nil {
continue // obtain these certificates on-demand instead continue // obtain these certificates on-demand instead
} }
err := c.TLS.ObtainCert(c.TLS.Hostname, operatorPresent) err := c.TLS.Manager.ObtainCert(c.TLS.Hostname, operatorPresent)
if err != nil { if err != nil {
return err return err
} }
...@@ -62,9 +63,14 @@ func activateHTTPS(cctx caddy.Context) error { ...@@ -62,9 +63,14 @@ func activateHTTPS(cctx caddy.Context) error {
// on the ports we'd need to do ACME before we finish starting; parent process // on the ports we'd need to do ACME before we finish starting; parent process
// already running renewal ticker, so renewal won't be missed anyway.) // already running renewal ticker, so renewal won't be missed anyway.)
if !caddy.IsUpgrade() { if !caddy.IsUpgrade() {
err = caddytls.RenewManagedCertificates(true) ctx.instance.StorageMu.RLock()
if err != nil { certCache, ok := ctx.instance.Storage[caddytls.CertCacheInstStorageKey].(*certmagic.Cache)
return err ctx.instance.StorageMu.RUnlock()
if ok && certCache != nil {
err = certCache.RenewManagedCertificates(operatorPresent)
if err != nil {
return err
}
} }
} }
...@@ -95,13 +101,13 @@ func markQualifiedForAutoHTTPS(configs []*SiteConfig) { ...@@ -95,13 +101,13 @@ func markQualifiedForAutoHTTPS(configs []*SiteConfig) {
// value will always be nil. // value will always be nil.
func enableAutoHTTPS(configs []*SiteConfig, loadCertificates bool) error { func enableAutoHTTPS(configs []*SiteConfig, loadCertificates bool) error {
for _, cfg := range configs { for _, cfg := range configs {
if cfg == nil || cfg.TLS == nil || !cfg.TLS.Managed || cfg.TLS.OnDemand { if cfg == nil || cfg.TLS == nil || !cfg.TLS.Managed || cfg.TLS.Manager.OnDemand != nil {
continue continue
} }
cfg.TLS.Enabled = true cfg.TLS.Enabled = true
cfg.Addr.Scheme = "https" cfg.Addr.Scheme = "https"
if loadCertificates && caddytls.HostQualifies(cfg.TLS.Hostname) { if loadCertificates && certmagic.HostQualifies(cfg.TLS.Hostname) {
_, err := cfg.TLS.CacheManagedCertificate(cfg.TLS.Hostname) _, err := cfg.TLS.Manager.CacheManagedCertificate(cfg.TLS.Hostname)
if err != nil { if err != nil {
return err return err
} }
...@@ -113,7 +119,7 @@ func enableAutoHTTPS(configs []*SiteConfig, loadCertificates bool) error { ...@@ -113,7 +119,7 @@ func enableAutoHTTPS(configs []*SiteConfig, loadCertificates bool) error {
// Set default port of 443 if not explicitly set // Set default port of 443 if not explicitly set
if cfg.Addr.Port == "" && if cfg.Addr.Port == "" &&
cfg.TLS.Enabled && cfg.TLS.Enabled &&
(!cfg.TLS.Manual || cfg.TLS.OnDemand) && (!cfg.TLS.Manual || cfg.TLS.Manager.OnDemand != nil) &&
cfg.Addr.Host != "localhost" { cfg.Addr.Host != "localhost" {
cfg.Addr.Port = HTTPSPort cfg.Addr.Port = HTTPSPort
} }
...@@ -207,7 +213,7 @@ func redirPlaintextHost(cfg *SiteConfig) *SiteConfig { ...@@ -207,7 +213,7 @@ func redirPlaintextHost(cfg *SiteConfig) *SiteConfig {
Addr: Address{Original: addr, Host: host, Port: port}, Addr: Address{Original: addr, Host: host, Port: port},
ListenHost: cfg.ListenHost, ListenHost: cfg.ListenHost,
middleware: []Middleware{redirMiddleware}, middleware: []Middleware{redirMiddleware},
TLS: &caddytls.Config{AltHTTPPort: cfg.TLS.AltHTTPPort, AltTLSALPNPort: cfg.TLS.AltTLSALPNPort}, TLS: &caddytls.Config{Manager: cfg.TLS.Manager},
Timeouts: cfg.Timeouts, Timeouts: cfg.Timeouts,
} }
} }
...@@ -22,6 +22,7 @@ import ( ...@@ -22,6 +22,7 @@ import (
"testing" "testing"
"github.com/mholt/caddy/caddytls" "github.com/mholt/caddy/caddytls"
"github.com/mholt/certmagic"
) )
func TestRedirPlaintextHost(t *testing.T) { func TestRedirPlaintextHost(t *testing.T) {
...@@ -150,18 +151,18 @@ func TestHostHasOtherPort(t *testing.T) { ...@@ -150,18 +151,18 @@ func TestHostHasOtherPort(t *testing.T) {
func TestMakePlaintextRedirects(t *testing.T) { func TestMakePlaintextRedirects(t *testing.T) {
configs := []*SiteConfig{ configs := []*SiteConfig{
// Happy path = standard redirect from 80 to 443 // Happy path = standard redirect from 80 to 443
{Addr: Address{Host: "example.com"}, TLS: &caddytls.Config{Managed: true}}, {Addr: Address{Host: "example.com"}, TLS: &caddytls.Config{Manager: &certmagic.Config{Managed: true}}},
// Host on port 80 already defined; don't change it (no redirect) // Host on port 80 already defined; don't change it (no redirect)
{Addr: Address{Host: "sub1.example.com", Port: "80", Scheme: "http"}, TLS: new(caddytls.Config)}, {Addr: Address{Host: "sub1.example.com", Port: "80", Scheme: "http"}, TLS: new(caddytls.Config)},
{Addr: Address{Host: "sub1.example.com"}, TLS: &caddytls.Config{Managed: true}}, {Addr: Address{Host: "sub1.example.com"}, TLS: &caddytls.Config{Manager: &certmagic.Config{Managed: true}}},
// Redirect from port 80 to port 5000 in this case // Redirect from port 80 to port 5000 in this case
{Addr: Address{Host: "sub2.example.com", Port: "5000"}, TLS: &caddytls.Config{Managed: true}}, {Addr: Address{Host: "sub2.example.com", Port: "5000"}, TLS: &caddytls.Config{Manager: &certmagic.Config{Managed: true}}},
// Can redirect from 80 to either 443 or 5001, but choose 443 // Can redirect from 80 to either 443 or 5001, but choose 443
{Addr: Address{Host: "sub3.example.com", Port: "443"}, TLS: &caddytls.Config{Managed: true}}, {Addr: Address{Host: "sub3.example.com", Port: "443"}, TLS: &caddytls.Config{Manager: &certmagic.Config{Managed: true}}},
{Addr: Address{Host: "sub3.example.com", Port: "5001", Scheme: "https"}, TLS: &caddytls.Config{Managed: true}}, {Addr: Address{Host: "sub3.example.com", Port: "5001", Scheme: "https"}, TLS: &caddytls.Config{Manager: &certmagic.Config{Managed: true}}},
} }
result := makePlaintextRedirects(configs) result := makePlaintextRedirects(configs)
...@@ -175,7 +176,7 @@ func TestMakePlaintextRedirects(t *testing.T) { ...@@ -175,7 +176,7 @@ func TestMakePlaintextRedirects(t *testing.T) {
func TestEnableAutoHTTPS(t *testing.T) { func TestEnableAutoHTTPS(t *testing.T) {
configs := []*SiteConfig{ configs := []*SiteConfig{
{Addr: Address{Host: "example.com"}, TLS: &caddytls.Config{Managed: true}}, {Addr: Address{Host: "example.com"}, TLS: &caddytls.Config{Manager: &certmagic.Config{Managed: true}}},
{}, // not managed - no changes! {}, // not managed - no changes!
} }
...@@ -215,7 +216,7 @@ func TestMarkQualifiedForAutoHTTPS(t *testing.T) { ...@@ -215,7 +216,7 @@ func TestMarkQualifiedForAutoHTTPS(t *testing.T) {
count := 0 count := 0
for _, cfg := range configs { for _, cfg := range configs {
if cfg.TLS.Managed { if cfg.TLS.Manager.Managed {
count++ count++
} }
} }
......
...@@ -23,6 +23,7 @@ import ( ...@@ -23,6 +23,7 @@ import (
"net/url" "net/url"
"os" "os"
"path/filepath" "path/filepath"
"strconv"
"strings" "strings"
"time" "time"
...@@ -31,6 +32,7 @@ import ( ...@@ -31,6 +32,7 @@ import (
"github.com/mholt/caddy/caddyhttp/staticfiles" "github.com/mholt/caddy/caddyhttp/staticfiles"
"github.com/mholt/caddy/caddytls" "github.com/mholt/caddy/caddytls"
"github.com/mholt/caddy/telemetry" "github.com/mholt/caddy/telemetry"
"github.com/mholt/certmagic"
) )
const serverType = "http" const serverType = "http"
...@@ -169,12 +171,20 @@ func (h *httpContext) InspectServerBlocks(sourceFile string, serverBlocks []cadd ...@@ -169,12 +171,20 @@ func (h *httpContext) InspectServerBlocks(sourceFile string, serverBlocks []cadd
// If default HTTP or HTTPS ports have been customized, // If default HTTP or HTTPS ports have been customized,
// make sure the ACME challenge ports match // make sure the ACME challenge ports match
var altHTTPPort, altTLSALPNPort string var altHTTPPort, altTLSALPNPort int
if HTTPPort != DefaultHTTPPort { if HTTPPort != DefaultHTTPPort {
altHTTPPort = HTTPPort portInt, err := strconv.Atoi(HTTPPort)
if err != nil {
return nil, err
}
altHTTPPort = portInt
} }
if HTTPSPort != DefaultHTTPSPort { if HTTPSPort != DefaultHTTPSPort {
altTLSALPNPort = HTTPSPort portInt, err := strconv.Atoi(HTTPSPort)
if err != nil {
return nil, err
}
altTLSALPNPort = portInt
} }
// Make our caddytls.Config, which has a pointer to the // Make our caddytls.Config, which has a pointer to the
...@@ -182,8 +192,8 @@ func (h *httpContext) InspectServerBlocks(sourceFile string, serverBlocks []cadd ...@@ -182,8 +192,8 @@ func (h *httpContext) InspectServerBlocks(sourceFile string, serverBlocks []cadd
// to use automatic HTTPS when the time comes // to use automatic HTTPS when the time comes
caddytlsConfig := caddytls.NewConfig(h.instance) caddytlsConfig := caddytls.NewConfig(h.instance)
caddytlsConfig.Hostname = addr.Host caddytlsConfig.Hostname = addr.Host
caddytlsConfig.AltHTTPPort = altHTTPPort caddytlsConfig.Manager.AltHTTPPort = altHTTPPort
caddytlsConfig.AltTLSALPNPort = altTLSALPNPort caddytlsConfig.Manager.AltTLSALPNPort = altTLSALPNPort
// Save the config to our master list, and key it for lookups // Save the config to our master list, and key it for lookups
cfg := &SiteConfig{ cfg := &SiteConfig{
...@@ -221,7 +231,7 @@ func (h *httpContext) MakeServers() ([]caddy.Server, error) { ...@@ -221,7 +231,7 @@ func (h *httpContext) MakeServers() ([]caddy.Server, error) {
// trusted CA (obviously not a perfect hueristic) // trusted CA (obviously not a perfect hueristic)
var looksLikeProductionCA bool var looksLikeProductionCA bool
for _, publicCAEndpoint := range caddytls.KnownACMECAs { for _, publicCAEndpoint := range caddytls.KnownACMECAs {
if strings.Contains(caddytls.DefaultCAUrl, publicCAEndpoint) { if strings.Contains(certmagic.CA, publicCAEndpoint) {
looksLikeProductionCA = true looksLikeProductionCA = true
break break
} }
...@@ -243,7 +253,7 @@ func (h *httpContext) MakeServers() ([]caddy.Server, error) { ...@@ -243,7 +253,7 @@ func (h *httpContext) MakeServers() ([]caddy.Server, error) {
if !caddy.IsLoopback(cfg.Addr.Host) && if !caddy.IsLoopback(cfg.Addr.Host) &&
!caddy.IsLoopback(cfg.ListenHost) && !caddy.IsLoopback(cfg.ListenHost) &&
(caddytls.QualifiesForManagedTLS(cfg) || (caddytls.QualifiesForManagedTLS(cfg) ||
caddytls.HostQualifies(cfg.Addr.Host)) { certmagic.HostQualifies(cfg.Addr.Host)) {
atLeastOneSiteLooksLikeProduction = true atLeastOneSiteLooksLikeProduction = true
} }
} }
...@@ -264,7 +274,7 @@ func (h *httpContext) MakeServers() ([]caddy.Server, error) { ...@@ -264,7 +274,7 @@ func (h *httpContext) MakeServers() ([]caddy.Server, error) {
// is incorrect for this site. // is incorrect for this site.
cfg.Addr.Scheme = "https" cfg.Addr.Scheme = "https"
} }
if cfg.Addr.Port == "" && ((!cfg.TLS.Manual && !cfg.TLS.SelfSigned) || cfg.TLS.OnDemand) { if cfg.Addr.Port == "" && ((!cfg.TLS.Manual && !cfg.TLS.SelfSigned) || cfg.TLS.Manager.OnDemand != nil) {
// this is vital, otherwise the function call below that // this is vital, otherwise the function call below that
// sets the listener address will use the default port // sets the listener address will use the default port
// instead of 443 because it doesn't know about TLS. // instead of 443 because it doesn't know about TLS.
......
...@@ -402,24 +402,26 @@ func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) (int, error) ...@@ -402,24 +402,26 @@ func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) (int, error)
if vhost == nil { if vhost == nil {
// check for ACME challenge even if vhost is nil; // check for ACME challenge even if vhost is nil;
// could be a new host coming online soon // could be a new host coming online soon - choose any
if caddytls.HTTPChallengeHandler(w, r, "localhost") { // vhost's cert manager configuration, I guess
if len(s.sites) > 0 && s.sites[0].TLS.Manager.HandleHTTPChallenge(w, r) {
return 0, nil return 0, nil
} }
// otherwise, log the error and write a message to the client // otherwise, log the error and write a message to the client
remoteHost, _, err := net.SplitHostPort(r.RemoteAddr) remoteHost, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil { if err != nil {
remoteHost = r.RemoteAddr remoteHost = r.RemoteAddr
} }
WriteSiteNotFound(w, r) // don't add headers outside of this function WriteSiteNotFound(w, r) // don't add headers outside of this function (http.forwardproxy)
log.Printf("[INFO] %s - No such site at %s (Remote: %s, Referer: %s)", log.Printf("[INFO] %s - No such site at %s (Remote: %s, Referer: %s)",
hostname, s.Server.Addr, remoteHost, r.Header.Get("Referer")) hostname, s.Server.Addr, remoteHost, r.Header.Get("Referer"))
return 0, nil return 0, nil
} }
// we still check for ACME challenge if the vhost exists, // we still check for ACME challenge if the vhost exists,
// because we must apply its HTTP challenge config settings // because the HTTP challenge might be disabled by its config
if caddytls.HTTPChallengeHandler(w, r, vhost.ListenHost) { if vhost.TLS.Manager.HandleHTTPChallenge(w, r) {
return 0, nil return 0, nil
} }
......
This diff is collapsed.
// Copyright 2015 Light Code Labs, LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package caddytls
import "testing"
func TestUnexportedGetCertificate(t *testing.T) {
certCache := &certificateCache{cache: make(map[string]Certificate)}
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
// When cache is empty
if _, matched, defaulted := cfg.getCertificate("example.com"); matched || defaulted {
t.Errorf("Got a certificate when cache was empty; matched=%v, defaulted=%v", matched, defaulted)
}
// When cache has one certificate in it
firstCert := Certificate{Names: []string{"example.com"}}
certCache.cache["0xdeadbeef"] = firstCert
cfg.Certificates["example.com"] = "0xdeadbeef"
if cert, matched, defaulted := cfg.getCertificate("Example.com"); !matched || defaulted || cert.Names[0] != "example.com" {
t.Errorf("Didn't get a cert for 'Example.com' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted)
}
if cert, matched, defaulted := cfg.getCertificate("example.com"); !matched || defaulted || cert.Names[0] != "example.com" {
t.Errorf("Didn't get a cert for 'example.com' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted)
}
// When retrieving wildcard certificate
certCache.cache["0xb01dface"] = Certificate{Names: []string{"*.example.com"}}
cfg.Certificates["*.example.com"] = "0xb01dface"
if cert, matched, defaulted := cfg.getCertificate("sub.example.com"); !matched || defaulted || cert.Names[0] != "*.example.com" {
t.Errorf("Didn't get wildcard cert for 'sub.example.com' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted)
}
// When no certificate matches and SNI is provided, return no certificate (should be TLS alert)
if cert, matched, defaulted := cfg.getCertificate("nomatch"); matched || defaulted {
t.Errorf("Expected matched=false, defaulted=false; but got matched=%v, defaulted=%v (cert: %v)", matched, defaulted, cert)
}
}
func TestCacheCertificate(t *testing.T) {
certCache := &certificateCache{cache: make(map[string]Certificate)}
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
cfg.cacheCertificate(Certificate{Names: []string{"example.com", "sub.example.com"}, Hash: "foobar"})
if len(certCache.cache) != 1 {
t.Errorf("Expected length of certificate cache to be 1")
}
if _, ok := certCache.cache["foobar"]; !ok {
t.Error("Expected first cert to be cached by key 'foobar', but it wasn't")
}
if _, ok := cfg.Certificates["example.com"]; !ok {
t.Error("Expected first cert to be keyed by 'example.com', but it wasn't")
}
if _, ok := cfg.Certificates["sub.example.com"]; !ok {
t.Error("Expected first cert to be keyed by 'sub.example.com', but it wasn't")
}
// different config, but using same cache; and has cert with overlapping name,
// but different hash
cfg2 := &Config{Certificates: make(map[string]string), certCache: certCache}
cfg2.cacheCertificate(Certificate{Names: []string{"example.com"}, Hash: "barbaz"})
if _, ok := certCache.cache["barbaz"]; !ok {
t.Error("Expected second cert to be cached by key 'barbaz.com', but it wasn't")
}
if hash, ok := cfg2.Certificates["example.com"]; !ok {
t.Error("Expected second cert to be keyed by 'example.com', but it wasn't")
} else if hash != "barbaz" {
t.Errorf("Expected second cert to map to 'barbaz' but it was %s instead", hash)
}
}
This diff is collapsed.
// Copyright 2015 Light Code Labs, LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package caddytls
// TODO
This diff is collapsed.
...@@ -16,8 +16,6 @@ package caddytls ...@@ -16,8 +16,6 @@ package caddytls
import ( import (
"crypto/tls" "crypto/tls"
"errors"
"net/url"
"reflect" "reflect"
"testing" "testing"
...@@ -110,120 +108,3 @@ func TestGetPreferredDefaultCiphers(t *testing.T) { ...@@ -110,120 +108,3 @@ func TestGetPreferredDefaultCiphers(t *testing.T) {
} }
} }
} }
func TestStorageForNoURL(t *testing.T) {
c := &Config{}
if _, err := c.StorageFor(""); err == nil {
t.Fatal("Expected error on empty URL")
}
}
func TestStorageForLowercasesAndPrefixesScheme(t *testing.T) {
resultStr := ""
RegisterStorageProvider("fake-TestStorageForLowercasesAndPrefixesScheme", func(caURL *url.URL) (Storage, error) {
resultStr = caURL.String()
return nil, nil
})
c := &Config{
StorageProvider: "fake-TestStorageForLowercasesAndPrefixesScheme",
}
if _, err := c.StorageFor("EXAMPLE.COM/BLAH"); err != nil {
t.Fatal(err)
}
if resultStr != "https://example.com/blah" {
t.Fatalf("Unexpected CA URL string: %v", resultStr)
}
}
func TestStorageForBadURL(t *testing.T) {
c := &Config{}
if _, err := c.StorageFor("http://192.168.0.%31/"); err == nil {
t.Fatal("Expected error for bad URL")
}
}
func TestStorageForDefault(t *testing.T) {
c := &Config{}
s, err := c.StorageFor("example.com")
if err != nil {
t.Fatal(err)
}
if _, ok := s.(*FileStorage); !ok {
t.Fatalf("Unexpected storage type: %#v", s)
}
}
func TestStorageForCustom(t *testing.T) {
storage := fakeStorage("fake-TestStorageForCustom")
RegisterStorageProvider("fake-TestStorageForCustom", func(caURL *url.URL) (Storage, error) { return storage, nil })
c := &Config{
StorageProvider: "fake-TestStorageForCustom",
}
s, err := c.StorageFor("example.com")
if err != nil {
t.Fatal(err)
}
if s != storage {
t.Fatal("Unexpected storage")
}
}
func TestStorageForCustomError(t *testing.T) {
RegisterStorageProvider("fake-TestStorageForCustomError", func(caURL *url.URL) (Storage, error) { return nil, errors.New("some error") })
c := &Config{
StorageProvider: "fake-TestStorageForCustomError",
}
if _, err := c.StorageFor("example.com"); err == nil {
t.Fatal("Expecting error")
}
}
func TestStorageForCustomNil(t *testing.T) {
// Should fall through to the default
c := &Config{StorageProvider: ""}
s, err := c.StorageFor("example.com")
if err != nil {
t.Fatal(err)
}
if _, ok := s.(*FileStorage); !ok {
t.Fatalf("Unexpected storage type: %#v", s)
}
}
type fakeStorage string
func (s fakeStorage) SiteExists(domain string) (bool, error) {
panic("no impl")
}
func (s fakeStorage) LoadSite(domain string) (*SiteData, error) {
panic("no impl")
}
func (s fakeStorage) StoreSite(domain string, data *SiteData) error {
panic("no impl")
}
func (s fakeStorage) DeleteSite(domain string) error {
panic("no impl")
}
func (s fakeStorage) TryLock(domain string) (Waiter, error) {
panic("no impl")
}
func (s fakeStorage) Unlock(domain string) error {
panic("no impl")
}
func (s fakeStorage) LoadUser(email string) (*UserData, error) {
panic("no impl")
}
func (s fakeStorage) StoreUser(email string, data *UserData) error {
panic("no impl")
}
func (s fakeStorage) MostRecentUserEmail() string {
panic("no impl")
}
...@@ -15,265 +15,20 @@ ...@@ -15,265 +15,20 @@
package caddytls package caddytls
import ( import (
"bytes"
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand" "crypto/rand"
"crypto/rsa"
"crypto/tls" "crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"errors"
"fmt"
"hash/fnv"
"io" "io"
"io/ioutil"
"log"
"math/big"
"net"
"os"
"path/filepath"
"strings"
"sync" "sync"
"time" "time"
"golang.org/x/crypto/ocsp"
"github.com/mholt/caddy"
"github.com/xenolf/lego/acme"
) )
// loadPrivateKey loads a PEM-encoded ECC/RSA private key from an array of bytes.
func loadPrivateKey(keyBytes []byte) (crypto.PrivateKey, error) {
keyBlock, _ := pem.Decode(keyBytes)
switch keyBlock.Type {
case "RSA PRIVATE KEY":
return x509.ParsePKCS1PrivateKey(keyBlock.Bytes)
case "EC PRIVATE KEY":
return x509.ParseECPrivateKey(keyBlock.Bytes)
}
return nil, errors.New("unknown private key type")
}
// savePrivateKey saves a PEM-encoded ECC/RSA private key to an array of bytes.
func savePrivateKey(key crypto.PrivateKey) ([]byte, error) {
var pemType string
var keyBytes []byte
switch key := key.(type) {
case *ecdsa.PrivateKey:
var err error
pemType = "EC"
keyBytes, err = x509.MarshalECPrivateKey(key)
if err != nil {
return nil, err
}
case *rsa.PrivateKey:
pemType = "RSA"
keyBytes = x509.MarshalPKCS1PrivateKey(key)
}
pemKey := pem.Block{Type: pemType + " PRIVATE KEY", Bytes: keyBytes}
return pem.EncodeToMemory(&pemKey), nil
}
// stapleOCSP staples OCSP information to cert for hostname name.
// If you have it handy, you should pass in the PEM-encoded certificate
// bundle; otherwise the DER-encoded cert will have to be PEM-encoded.
// If you don't have the PEM blocks already, just pass in nil.
//
// Errors here are not necessarily fatal, it could just be that the
// certificate doesn't have an issuer URL.
func stapleOCSP(cert *Certificate, pemBundle []byte) error {
if pemBundle == nil {
// The function in the acme package that gets OCSP requires a PEM-encoded cert
bundle := new(bytes.Buffer)
for _, derBytes := range cert.Certificate.Certificate {
pem.Encode(bundle, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
}
pemBundle = bundle.Bytes()
}
var ocspBytes []byte
var ocspResp *ocsp.Response
var ocspErr error
var gotNewOCSP bool
// First try to load OCSP staple from storage and see if
// we can still use it.
// TODO: Use Storage interface instead of disk directly
var ocspFileNamePrefix string
if len(cert.Names) > 0 {
firstName := strings.Replace(cert.Names[0], "*", "wildcard_", -1)
ocspFileNamePrefix = firstName + "-"
}
ocspFileName := ocspFileNamePrefix + fastHash(pemBundle)
ocspCachePath := filepath.Join(ocspFolder, ocspFileName)
cachedOCSP, err := ioutil.ReadFile(ocspCachePath)
if err == nil {
resp, err := ocsp.ParseResponse(cachedOCSP, nil)
if err == nil {
if freshOCSP(resp) {
// staple is still fresh; use it
ocspBytes = cachedOCSP
ocspResp = resp
}
} else {
// invalid contents; delete the file
// (we do this independently of the maintenance routine because
// in this case we know for sure this should be a staple file
// because we loaded it by name, whereas the maintenance routine
// just iterates the list of files, even if somehow a non-staple
// file gets in the folder. in this case we are sure it is corrupt.)
err := os.Remove(ocspCachePath)
if err != nil {
log.Printf("[WARNING] Unable to delete invalid OCSP staple file: %v", err)
}
}
}
// If we couldn't get a fresh staple by reading the cache,
// then we need to request it from the OCSP responder
if ocspResp == nil || len(ocspBytes) == 0 {
ocspBytes, ocspResp, ocspErr = acme.GetOCSPForCert(pemBundle)
if ocspErr != nil {
// An error here is not a problem because a certificate may simply
// not contain a link to an OCSP server. But we should log it anyway.
// There's nothing else we can do to get OCSP for this certificate,
// so we can return here with the error.
return fmt.Errorf("no OCSP stapling for %v: %v", cert.Names, ocspErr)
}
gotNewOCSP = true
}
// By now, we should have a response. If good, staple it to
// the certificate. If the OCSP response was not loaded from
// storage, we persist it for next time.
if ocspResp.Status == ocsp.Good {
if ocspResp.NextUpdate.After(cert.NotAfter) {
// uh oh, this OCSP response expires AFTER the certificate does, that's kinda bogus.
// it was the reason a lot of Symantec-validated sites (not Caddy) went down
// in October 2017. https://twitter.com/mattiasgeniar/status/919432824708648961
return fmt.Errorf("invalid: OCSP response for %v valid after certificate expiration (%s)",
cert.Names, cert.NotAfter.Sub(ocspResp.NextUpdate))
}
cert.Certificate.OCSPStaple = ocspBytes
cert.OCSP = ocspResp
if gotNewOCSP {
err := os.MkdirAll(filepath.Join(caddy.AssetsPath(), "ocsp"), 0700)
if err != nil {
return fmt.Errorf("unable to make OCSP staple path for %v: %v", cert.Names, err)
}
err = ioutil.WriteFile(ocspCachePath, ocspBytes, 0644)
if err != nil {
return fmt.Errorf("unable to write OCSP staple file for %v: %v", cert.Names, err)
}
}
}
return nil
}
func makeSelfSignedCertWithCustomSAN(sans []string, config *Config) (Certificate, error) {
// start by generating private key
var privKey interface{}
var err error
switch config.KeyType {
case "", acme.EC256:
privKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
case acme.EC384:
privKey, err = ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
case acme.RSA2048:
privKey, err = rsa.GenerateKey(rand.Reader, 2048)
case acme.RSA4096:
privKey, err = rsa.GenerateKey(rand.Reader, 4096)
case acme.RSA8192:
privKey, err = rsa.GenerateKey(rand.Reader, 8192)
default:
return Certificate{}, fmt.Errorf("cannot generate private key; unknown key type %v", config.KeyType)
}
if err != nil {
return Certificate{}, fmt.Errorf("failed to generate private key: %v", err)
}
// create certificate structure with proper values
notBefore := time.Now()
notAfter := notBefore.Add(24 * time.Hour * 7)
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
if err != nil {
return Certificate{}, fmt.Errorf("failed to generate serial number: %v", err)
}
cert := &x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{Organization: []string{"Caddy Self-Signed"}},
NotBefore: notBefore,
NotAfter: notAfter,
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
}
if len(sans) == 0 {
sans = []string{""}
}
var names []string
for _, san := range sans {
if ip := net.ParseIP(san); ip != nil {
names = append(names, strings.ToLower(ip.String()))
cert.IPAddresses = append(cert.IPAddresses, ip)
} else {
names = append(names, strings.ToLower(san))
cert.DNSNames = append(cert.DNSNames, strings.ToLower(san))
}
}
publicKey := func(privKey interface{}) interface{} {
switch k := privKey.(type) {
case *rsa.PrivateKey:
return &k.PublicKey
case *ecdsa.PrivateKey:
return &k.PublicKey
default:
return errors.New("unknown key type")
}
}
derBytes, err := x509.CreateCertificate(rand.Reader, cert, cert, publicKey(privKey), privKey)
if err != nil {
return Certificate{}, fmt.Errorf("could not create certificate: %v", err)
}
chain := [][]byte{derBytes}
return Certificate{
Certificate: tls.Certificate{
Certificate: chain,
PrivateKey: privKey,
Leaf: cert,
},
Names: names,
NotAfter: cert.NotAfter,
Hash: hashCertificateChain(chain),
}, nil
}
// makeSelfSignedCertForConfig makes a self-signed certificate according
// to the parameters in config and caches the new cert in config directly.
func makeSelfSignedCertForConfig(config *Config) error {
cert, err := makeSelfSignedCertWithCustomSAN([]string{config.Hostname}, config)
if err != nil {
return err
}
config.cacheCertificate(cert)
return nil
}
// RotateSessionTicketKeys rotates the TLS session ticket keys // RotateSessionTicketKeys rotates the TLS session ticket keys
// on cfg every TicketRotateInterval. It spawns a new goroutine so // on cfg every TicketRotateInterval. It spawns a new goroutine so
// this function does NOT block. It returns a channel you should // this function does NOT block. It returns a channel you should
// close when you are ready to stop the key rotation, like when the // close when you are ready to stop the key rotation, like when the
// server using cfg is no longer running. // server using cfg is no longer running.
//
// TODO: See about moving this into CertMagic and using its Storage
func RotateSessionTicketKeys(cfg *tls.Config) chan struct{} { func RotateSessionTicketKeys(cfg *tls.Config) chan struct{} {
ch := make(chan struct{}) ch := make(chan struct{})
ticker := time.NewTicker(TicketRotateInterval) ticker := time.NewTicker(TicketRotateInterval)
...@@ -347,15 +102,6 @@ func standaloneTLSTicketKeyRotation(c *tls.Config, ticker *time.Ticker, exitChan ...@@ -347,15 +102,6 @@ func standaloneTLSTicketKeyRotation(c *tls.Config, ticker *time.Ticker, exitChan
} }
} }
// fastHash hashes input using a hashing algorithm that
// is fast, and returns the hash as a hex-encoded string.
// Do not use this for cryptographic purposes.
func fastHash(input []byte) string {
h := fnv.New32a()
h.Write(input)
return fmt.Sprintf("%x", h.Sum32())
}
const ( const (
// NumTickets is how many tickets to hold and consider // NumTickets is how many tickets to hold and consider
// to decrypt TLS sessions. // to decrypt TLS sessions.
......
...@@ -15,83 +15,11 @@ ...@@ -15,83 +15,11 @@
package caddytls package caddytls
import ( import (
"bytes"
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/tls" "crypto/tls"
"crypto/x509"
"testing" "testing"
"time" "time"
) )
func TestSaveAndLoadRSAPrivateKey(t *testing.T) {
privateKey, err := rsa.GenerateKey(rand.Reader, 128) // make tests faster; small key size OK for testing
if err != nil {
t.Fatal(err)
}
// test save
savedBytes, err := savePrivateKey(privateKey)
if err != nil {
t.Fatal("error saving private key:", err)
}
// test load
loadedKey, err := loadPrivateKey(savedBytes)
if err != nil {
t.Error("error loading private key:", err)
}
// verify loaded key is correct
if !PrivateKeysSame(privateKey, loadedKey) {
t.Error("Expected key bytes to be the same, but they weren't")
}
}
func TestSaveAndLoadECCPrivateKey(t *testing.T) {
privateKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
if err != nil {
t.Fatal(err)
}
// test save
savedBytes, err := savePrivateKey(privateKey)
if err != nil {
t.Fatal("error saving private key:", err)
}
// test load
loadedKey, err := loadPrivateKey(savedBytes)
if err != nil {
t.Error("error loading private key:", err)
}
// verify loaded key is correct
if !PrivateKeysSame(privateKey, loadedKey) {
t.Error("Expected key bytes to be the same, but they weren't")
}
}
// PrivateKeysSame compares the bytes of a and b and returns true if they are the same.
func PrivateKeysSame(a, b crypto.PrivateKey) bool {
return bytes.Equal(PrivateKeyBytes(a), PrivateKeyBytes(b))
}
// PrivateKeyBytes returns the bytes of DER-encoded key.
func PrivateKeyBytes(key crypto.PrivateKey) []byte {
var keyBytes []byte
switch key := key.(type) {
case *rsa.PrivateKey:
keyBytes = x509.MarshalPKCS1PrivateKey(key)
case *ecdsa.PrivateKey:
keyBytes, _ = x509.MarshalECPrivateKey(key)
}
return keyBytes
}
func TestStandaloneTLSTicketKeyRotation(t *testing.T) { func TestStandaloneTLSTicketKeyRotation(t *testing.T) {
type syncPkt struct { type syncPkt struct {
ticketKey [32]byte ticketKey [32]byte
......
// Copyright 2015 Light Code Labs, LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package caddytls
import (
"fmt"
"io/ioutil"
"log"
"net/url"
"os"
"path/filepath"
"strings"
"github.com/mholt/caddy"
)
func init() {
RegisterStorageProvider("file", NewFileStorage)
}
// NewFileStorage is a StorageConstructor function that creates a new
// Storage instance backed by the local disk. The resulting Storage
// instance is guaranteed to be non-nil if there is no error.
func NewFileStorage(caURL *url.URL) (Storage, error) {
// storageBasePath is the root path in which all TLS/ACME assets are
// stored. Do not change this value during the lifetime of the program.
storageBasePath := filepath.Join(caddy.AssetsPath(), "acme")
storage := &FileStorage{Path: filepath.Join(storageBasePath, caURL.Host)}
storage.Locker = &fileStorageLock{caURL: caURL.Host, storage: storage}
return storage, nil
}
// FileStorage facilitates forming file paths derived from a root
// directory. It is used to get file paths in a consistent,
// cross-platform way or persisting ACME assets on the file system.
type FileStorage struct {
Path string
Locker
}
// sites gets the directory that stores site certificate and keys.
func (s *FileStorage) sites() string {
return filepath.Join(s.Path, "sites")
}
// site returns the path to the folder containing assets for domain.
func (s *FileStorage) site(domain string) string {
domain = fileSafe(domain)
return filepath.Join(s.sites(), domain)
}
// siteCertFile returns the path to the certificate file for domain.
func (s *FileStorage) siteCertFile(domain string) string {
domain = fileSafe(domain)
return filepath.Join(s.site(domain), domain+".crt")
}
// siteKeyFile returns the path to domain's private key file.
func (s *FileStorage) siteKeyFile(domain string) string {
domain = fileSafe(domain)
return filepath.Join(s.site(domain), domain+".key")
}
// siteMetaFile returns the path to the domain's asset metadata file.
func (s *FileStorage) siteMetaFile(domain string) string {
domain = fileSafe(domain)
return filepath.Join(s.site(domain), domain+".json")
}
// users gets the directory that stores account folders.
func (s *FileStorage) users() string {
return filepath.Join(s.Path, "users")
}
// user gets the account folder for the user with email
func (s *FileStorage) user(email string) string {
if email == "" {
email = emptyEmail
}
email = fileSafe(email)
return filepath.Join(s.users(), email)
}
// emailUsername returns the username portion of an email address (part before
// '@') or the original input if it can't find the "@" symbol.
func emailUsername(email string) string {
at := strings.Index(email, "@")
if at == -1 {
return email
} else if at == 0 {
return email[1:]
}
return email[:at]
}
// userRegFile gets the path to the registration file for the user with the
// given email address.
func (s *FileStorage) userRegFile(email string) string {
if email == "" {
email = emptyEmail
}
email = strings.ToLower(email)
fileName := emailUsername(email)
if fileName == "" {
fileName = "registration"
}
fileName = fileSafe(fileName)
return filepath.Join(s.user(email), fileName+".json")
}
// userKeyFile gets the path to the private key file for the user with the
// given email address.
func (s *FileStorage) userKeyFile(email string) string {
if email == "" {
email = emptyEmail
}
email = strings.ToLower(email)
fileName := emailUsername(email)
if fileName == "" {
fileName = "private"
}
fileName = fileSafe(fileName)
return filepath.Join(s.user(email), fileName+".key")
}
// readFile abstracts a simple ioutil.ReadFile, making sure to return an
// ErrNotExist instance when the file is not found.
func (s *FileStorage) readFile(file string) ([]byte, error) {
b, err := ioutil.ReadFile(file)
if os.IsNotExist(err) {
return nil, ErrNotExist(err)
}
return b, err
}
// SiteExists implements Storage.SiteExists by checking for the presence of
// cert and key files.
func (s *FileStorage) SiteExists(domain string) (bool, error) {
_, err := os.Stat(s.siteCertFile(domain))
if os.IsNotExist(err) {
return false, nil
} else if err != nil {
return false, err
}
_, err = os.Stat(s.siteKeyFile(domain))
if err != nil {
return false, err
}
return true, nil
}
// LoadSite implements Storage.LoadSite by loading it from disk. If it is not
// present, an instance of ErrNotExist is returned.
func (s *FileStorage) LoadSite(domain string) (*SiteData, error) {
var err error
siteData := new(SiteData)
siteData.Cert, err = s.readFile(s.siteCertFile(domain))
if err != nil {
return nil, err
}
siteData.Key, err = s.readFile(s.siteKeyFile(domain))
if err != nil {
return nil, err
}
siteData.Meta, err = s.readFile(s.siteMetaFile(domain))
if err != nil {
return nil, err
}
return siteData, nil
}
// StoreSite implements Storage.StoreSite by writing it to disk. The base
// directories needed for the file are automatically created as needed.
func (s *FileStorage) StoreSite(domain string, data *SiteData) error {
err := os.MkdirAll(s.site(domain), 0700)
if err != nil {
return fmt.Errorf("making site directory: %v", err)
}
err = ioutil.WriteFile(s.siteCertFile(domain), data.Cert, 0600)
if err != nil {
return fmt.Errorf("writing certificate file: %v", err)
}
err = ioutil.WriteFile(s.siteKeyFile(domain), data.Key, 0600)
if err != nil {
return fmt.Errorf("writing key file: %v", err)
}
err = ioutil.WriteFile(s.siteMetaFile(domain), data.Meta, 0600)
if err != nil {
return fmt.Errorf("writing cert meta file: %v", err)
}
log.Printf("[INFO][%v] Certificate written to disk: %v", domain, s.siteCertFile(domain))
return nil
}
// DeleteSite implements Storage.DeleteSite by deleting just the cert from
// disk. If it is not present, an instance of ErrNotExist is returned.
func (s *FileStorage) DeleteSite(domain string) error {
err := os.Remove(s.siteCertFile(domain))
if err != nil {
if os.IsNotExist(err) {
return ErrNotExist(err)
}
return err
}
return nil
}
// LoadUser implements Storage.LoadUser by loading it from disk. If it is not
// present, an instance of ErrNotExist is returned.
func (s *FileStorage) LoadUser(email string) (*UserData, error) {
var err error
userData := new(UserData)
userData.Reg, err = s.readFile(s.userRegFile(email))
if err != nil {
return nil, err
}
userData.Key, err = s.readFile(s.userKeyFile(email))
if err != nil {
return nil, err
}
return userData, nil
}
// StoreUser implements Storage.StoreUser by writing it to disk. The base
// directories needed for the file are automatically created as needed.
func (s *FileStorage) StoreUser(email string, data *UserData) error {
err := os.MkdirAll(s.user(email), 0700)
if err != nil {
return fmt.Errorf("making user directory: %v", err)
}
err = ioutil.WriteFile(s.userRegFile(email), data.Reg, 0600)
if err != nil {
return fmt.Errorf("writing user registration file: %v", err)
}
err = ioutil.WriteFile(s.userKeyFile(email), data.Key, 0600)
if err != nil {
return fmt.Errorf("writing user key file: %v", err)
}
return nil
}
// MostRecentUserEmail implements Storage.MostRecentUserEmail by finding the
// most recently written sub directory in the users' directory. It is named
// after the email address. This corresponds to the most recent call to
// StoreUser.
func (s *FileStorage) MostRecentUserEmail() string {
userDirs, err := ioutil.ReadDir(s.users())
if err != nil {
return ""
}
var mostRecent os.FileInfo
for _, dir := range userDirs {
if !dir.IsDir() {
continue
}
if mostRecent == nil || dir.ModTime().After(mostRecent.ModTime()) {
mostRecent = dir
}
}
if mostRecent != nil {
return mostRecent.Name()
}
return ""
}
// fileSafe standardizes and sanitizes str for use in a file path.
func fileSafe(str string) string {
str = strings.ToLower(str)
str = strings.TrimSpace(str)
repl := strings.NewReplacer(
"..", "",
"/", "",
"\\", "",
// TODO: Consider also replacing "@" with "_at_" (but migrate existing accounts...)
"+", "_plus_",
"*", "wildcard_",
"%", "",
"$", "",
"`", "",
"~", "",
":", "",
";", "",
"=", "",
"!", "",
"#", "",
"&", "",
"|", "",
`"`, "",
"'", "")
return repl.Replace(str)
}
// Copyright 2015 Light Code Labs, LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package caddytls
import (
"path/filepath"
"testing"
)
// *********************************** NOTE ********************************
// Due to circular package dependencies with the storagetest sub package and
// the fact that we want to use that harness to test file storage, most of
// the tests for file storage are done in the storagetest package.
func TestPathBuilders(t *testing.T) {
fs := FileStorage{Path: "test"}
for i, testcase := range []struct {
in, folder, certFile, keyFile, metaFile string
}{
{
in: "example.com",
folder: filepath.Join("test", "sites", "example.com"),
certFile: filepath.Join("test", "sites", "example.com", "example.com.crt"),
keyFile: filepath.Join("test", "sites", "example.com", "example.com.key"),
metaFile: filepath.Join("test", "sites", "example.com", "example.com.json"),
},
{
in: "*.example.com",
folder: filepath.Join("test", "sites", "wildcard_.example.com"),
certFile: filepath.Join("test", "sites", "wildcard_.example.com", "wildcard_.example.com.crt"),
keyFile: filepath.Join("test", "sites", "wildcard_.example.com", "wildcard_.example.com.key"),
metaFile: filepath.Join("test", "sites", "wildcard_.example.com", "wildcard_.example.com.json"),
},
{
// prevent directory traversal! very important, esp. with on-demand TLS
// see issue #2092
in: "a/../../../foo",
folder: filepath.Join("test", "sites", "afoo"),
certFile: filepath.Join("test", "sites", "afoo", "afoo.crt"),
keyFile: filepath.Join("test", "sites", "afoo", "afoo.key"),
metaFile: filepath.Join("test", "sites", "afoo", "afoo.json"),
},
{
in: "b\\..\\..\\..\\foo",
folder: filepath.Join("test", "sites", "bfoo"),
certFile: filepath.Join("test", "sites", "bfoo", "bfoo.crt"),
keyFile: filepath.Join("test", "sites", "bfoo", "bfoo.key"),
metaFile: filepath.Join("test", "sites", "bfoo", "bfoo.json"),
},
{
in: "c/foo",
folder: filepath.Join("test", "sites", "cfoo"),
certFile: filepath.Join("test", "sites", "cfoo", "cfoo.crt"),
keyFile: filepath.Join("test", "sites", "cfoo", "cfoo.key"),
metaFile: filepath.Join("test", "sites", "cfoo", "cfoo.json"),
},
} {
if actual := fs.site(testcase.in); actual != testcase.folder {
t.Errorf("Test %d: site folder: Expected '%s' but got '%s'", i, testcase.folder, actual)
}
if actual := fs.siteCertFile(testcase.in); actual != testcase.certFile {
t.Errorf("Test %d: site cert file: Expected '%s' but got '%s'", i, testcase.certFile, actual)
}
if actual := fs.siteKeyFile(testcase.in); actual != testcase.keyFile {
t.Errorf("Test %d: site key file: Expected '%s' but got '%s'", i, testcase.keyFile, actual)
}
if actual := fs.siteMetaFile(testcase.in); actual != testcase.metaFile {
t.Errorf("Test %d: site meta file: Expected '%s' but got '%s'", i, testcase.metaFile, actual)
}
}
}
// Copyright 2015 Light Code Labs, LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package caddytls
import (
"fmt"
"os"
"sync"
"time"
"github.com/mholt/caddy"
)
func init() {
// be sure to remove lock files when exiting the process!
caddy.OnProcessExit = append(caddy.OnProcessExit, func() {
fileStorageNameLocksMu.Lock()
defer fileStorageNameLocksMu.Unlock()
for key, fw := range fileStorageNameLocks {
os.Remove(fw.filename)
delete(fileStorageNameLocks, key)
}
})
}
// fileStorageLock facilitates ACME-related locking by using
// the associated FileStorage, so multiple processes can coordinate
// renewals on the certificates on a shared file system.
type fileStorageLock struct {
caURL string
storage *FileStorage
}
// TryLock attempts to get a lock for name, otherwise it returns
// a Waiter value to wait until the other process is finished.
func (s *fileStorageLock) TryLock(name string) (Waiter, error) {
fileStorageNameLocksMu.Lock()
defer fileStorageNameLocksMu.Unlock()
// see if lock already exists within this process
fw, ok := fileStorageNameLocks[s.caURL+name]
if ok {
// lock already created within process, let caller wait on it
return fw, nil
}
// attempt to persist lock to disk by creating lock file
fw = &fileWaiter{
filename: s.storage.siteCertFile(name) + ".lock",
wg: new(sync.WaitGroup),
}
// parent dir must exist
if err := os.MkdirAll(s.storage.site(name), 0700); err != nil {
return nil, err
}
lf, err := os.OpenFile(fw.filename, os.O_CREATE|os.O_EXCL, 0644)
if err != nil {
if os.IsExist(err) {
// another process has the lock; use it to wait
return fw, nil
}
// otherwise, this was some unexpected error
return nil, err
}
lf.Close()
// looks like we get the lock
fw.wg.Add(1)
fileStorageNameLocks[s.caURL+name] = fw
return nil, nil
}
// Unlock unlocks name.
func (s *fileStorageLock) Unlock(name string) error {
fileStorageNameLocksMu.Lock()
defer fileStorageNameLocksMu.Unlock()
fw, ok := fileStorageNameLocks[s.caURL+name]
if !ok {
return fmt.Errorf("FileStorage: no lock to release for %s", name)
}
// remove lock file
os.Remove(fw.filename)
// if parent folder is now empty, remove it too to keep it tidy
lockParentFolder := s.storage.site(name)
dir, err := os.Open(lockParentFolder)
if err == nil {
items, _ := dir.Readdirnames(3) // OK to ignore error here
if len(items) == 0 {
os.Remove(lockParentFolder)
}
dir.Close()
}
fw.wg.Done()
delete(fileStorageNameLocks, s.caURL+name)
return nil
}
// fileWaiter waits for a file to disappear; it polls
// the file system to check for the existence of a file.
// It also has a WaitGroup which will be faster than
// polling, for when locking need only happen within this
// process.
type fileWaiter struct {
filename string
wg *sync.WaitGroup
}
// Wait waits until the lock is released.
func (fw *fileWaiter) Wait() {
start := time.Now()
fw.wg.Wait()
for time.Since(start) < 1*time.Hour {
_, err := os.Stat(fw.filename)
if os.IsNotExist(err) {
return
}
time.Sleep(1 * time.Second)
}
}
var fileStorageNameLocks = make(map[string]*fileWaiter) // keyed by CA + name
var fileStorageNameLocksMu sync.Mutex
var _ Locker = &fileStorageLock{}
var _ Waiter = &fileWaiter{}
This diff is collapsed.
// Copyright 2015 Light Code Labs, LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package caddytls
import (
"crypto/tls"
"crypto/x509"
"testing"
)
func TestGetCertificate(t *testing.T) {
certCache := &certificateCache{cache: make(map[string]Certificate)}
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
hello := &tls.ClientHelloInfo{ServerName: "example.com"}
helloSub := &tls.ClientHelloInfo{ServerName: "sub.example.com"}
helloNoSNI := &tls.ClientHelloInfo{}
helloNoMatch := &tls.ClientHelloInfo{ServerName: "nomatch"} // TODO (see below)
// When cache is empty
if cert, err := cfg.GetCertificate(hello); err == nil {
t.Errorf("GetCertificate should return error when cache is empty, got: %v", cert)
}
if cert, err := cfg.GetCertificate(helloNoSNI); err == nil {
t.Errorf("GetCertificate should return error when cache is empty even if server name is blank, got: %v", cert)
}
// When cache has one certificate in it
firstCert := Certificate{Names: []string{"example.com"}, Certificate: tls.Certificate{Leaf: &x509.Certificate{DNSNames: []string{"example.com"}}}}
cfg.cacheCertificate(firstCert)
if cert, err := cfg.GetCertificate(hello); err != nil {
t.Errorf("Got an error but shouldn't have, when cert exists in cache: %v", err)
} else if cert.Leaf.DNSNames[0] != "example.com" {
t.Errorf("Got wrong certificate with exact match; expected 'example.com', got: %v", cert)
}
if _, err := cfg.GetCertificate(helloNoSNI); err != nil {
t.Errorf("Got an error with no SNI but shouldn't have, when cert exists in cache: %v", err)
}
// When retrieving wildcard certificate
wildcardCert := Certificate{
Names: []string{"*.example.com"},
Certificate: tls.Certificate{Leaf: &x509.Certificate{DNSNames: []string{"*.example.com"}}},
Hash: "(don't overwrite the first one)",
}
cfg.cacheCertificate(wildcardCert)
if cert, err := cfg.GetCertificate(helloSub); err != nil {
t.Errorf("Didn't get wildcard cert, got: cert=%v, err=%v ", cert, err)
} else if cert.Leaf.DNSNames[0] != "*.example.com" {
t.Errorf("Got wrong certificate, expected wildcard: %v", cert)
}
// When cache is NOT empty but there's no SNI
if cert, err := cfg.GetCertificate(helloNoSNI); err != nil {
t.Errorf("Expected random certificate with no error when no SNI, got err: %v", err)
} else if cert == nil || len(cert.Leaf.DNSNames) == 0 {
t.Errorf("Expected random cert with no matches, got: %v", cert)
}
// When no certificate matches, raise an alert
if _, err := cfg.GetCertificate(helloNoMatch); err == nil {
t.Errorf("Expected an error when no certificate matched the SNI, got: %v", err)
}
}
// Copyright 2015 Light Code Labs, LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package caddytls
import (
"crypto/tls"
"encoding/json"
"fmt"
"log"
"net/http"
"net/http/httputil"
"net/url"
"os"
"strings"
"github.com/xenolf/lego/acme"
)
const challengeBasePath = "/.well-known/acme-challenge"
// HTTPChallengeHandler proxies challenge requests to ACME client if the
// request path starts with challengeBasePath, if the HTTP challenge is not
// disabled, and if we are known to be obtaining a certificate for the name.
// It returns true if it handled the request and no more needs to be done;
// it returns false if this call was a no-op and the request still needs handling.
func HTTPChallengeHandler(w http.ResponseWriter, r *http.Request, listenHost string) bool {
if !strings.HasPrefix(r.URL.Path, challengeBasePath) {
return false
}
if DisableHTTPChallenge {
return false
}
// see if another instance started the HTTP challenge for this name
if tryDistributedChallengeSolver(w, r) {
return true
}
// otherwise, if we aren't getting the name, then ignore this challenge
if !namesObtaining.Has(r.Host) {
return false
}
scheme := "http"
if r.TLS != nil {
scheme = "https"
}
if listenHost == "" {
listenHost = "localhost"
}
// always proxy to the DefaultHTTPAlternatePort because obviously the
// ACME challenge request already got into one of our HTTP handlers, so
// it means we must have started a HTTP listener on the alternate
// port instead; which is only accessible via listenHost
upstream, err := url.Parse(fmt.Sprintf("%s://%s:%s", scheme, listenHost, DefaultHTTPAlternatePort))
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
log.Printf("[ERROR] ACME proxy handler: %v", err)
return true
}
proxy := httputil.NewSingleHostReverseProxy(upstream)
proxy.Transport = &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
proxy.ServeHTTP(w, r)
return true
}
// tryDistributedChallengeSolver checks to see if this challenge
// request was initiated by another instance that shares file
// storage, and attempts to complete the challenge for it. It
// returns true if the challenge was handled; false otherwise.
func tryDistributedChallengeSolver(w http.ResponseWriter, r *http.Request) bool {
filePath := distributedSolver{}.challengeTokensPath(r.Host)
f, err := os.Open(filePath)
if err != nil {
if !os.IsNotExist(err) {
log.Printf("[ERROR][%s] Opening distributed challenge token file: %v", r.Host, err)
}
return false
}
defer f.Close()
var chalInfo challengeInfo
err = json.NewDecoder(f).Decode(&chalInfo)
if err != nil {
log.Printf("[ERROR][%s] Decoding challenge token file %s (corrupted?): %v", r.Host, filePath, err)
return false
}
// this part borrowed from xenolf/lego's built-in HTTP-01 challenge solver (March 2018)
challengeReqPath := acme.HTTP01ChallengePath(chalInfo.Token)
if r.URL.Path == challengeReqPath &&
strings.HasPrefix(r.Host, chalInfo.Domain) &&
r.Method == "GET" {
w.Header().Add("Content-Type", "text/plain")
w.Write([]byte(chalInfo.KeyAuth))
r.Close = true
log.Printf("[INFO][%s] Served key authentication (distributed)", chalInfo.Domain)
return true
}
return false
}
// Copyright 2015 Light Code Labs, LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package caddytls
import (
"net"
"net/http"
"net/http/httptest"
"testing"
)
func TestHTTPChallengeHandlerNoOp(t *testing.T) {
namesObtaining.Add([]string{"localhost"})
// try base paths and host names that aren't
// handled by this handler
for _, url := range []string{
"http://localhost/",
"http://localhost/foo.html",
"http://localhost/.git",
"http://localhost/.well-known/",
"http://localhost/.well-known/acme-challenging",
"http://other/.well-known/acme-challenge/foo",
} {
req, err := http.NewRequest("GET", url, nil)
if err != nil {
t.Fatalf("Could not craft request, got error: %v", err)
}
rw := httptest.NewRecorder()
if HTTPChallengeHandler(rw, req, "") {
t.Errorf("Got true with this URL, but shouldn't have: %s", url)
}
}
}
func TestHTTPChallengeHandlerSuccess(t *testing.T) {
expectedPath := challengeBasePath + "/asdf"
// Set up fake acme handler backend to make sure proxying succeeds
var proxySuccess bool
ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
proxySuccess = true
if r.URL.Path != expectedPath {
t.Errorf("Expected path '%s' but got '%s' instead", expectedPath, r.URL.Path)
}
}))
// Custom listener that uses the port we expect
ln, err := net.Listen("tcp", "127.0.0.1:"+DefaultHTTPAlternatePort)
if err != nil {
t.Fatalf("Unable to start test server listener: %v", err)
}
ts.Listener = ln
// Tell this package that we are handling a challenge for 127.0.0.1
namesObtaining.Add([]string{"127.0.0.1"})
// Start our engines and run the test
ts.Start()
defer ts.Close()
req, err := http.NewRequest("GET", "http://127.0.0.1:"+DefaultHTTPAlternatePort+expectedPath, nil)
if err != nil {
t.Fatalf("Could not craft request, got error: %v", err)
}
rw := httptest.NewRecorder()
HTTPChallengeHandler(rw, req, "")
if !proxySuccess {
t.Fatal("Expected request to be proxied, but it wasn't")
}
}
This diff is collapsed.
package caddytls
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"fmt"
"math/big"
"net"
"strings"
"time"
"github.com/xenolf/lego/certcrypto"
)
// newSelfSignedCertificate returns a new self-signed certificate.
func newSelfSignedCertificate(ssconfig selfSignedConfig) (tls.Certificate, error) {
// start by generating private key
var privKey interface{}
var err error
switch ssconfig.KeyType {
case "", certcrypto.EC256:
privKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
case certcrypto.EC384:
privKey, err = ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
case certcrypto.RSA2048:
privKey, err = rsa.GenerateKey(rand.Reader, 2048)
case certcrypto.RSA4096:
privKey, err = rsa.GenerateKey(rand.Reader, 4096)
case certcrypto.RSA8192:
privKey, err = rsa.GenerateKey(rand.Reader, 8192)
default:
return tls.Certificate{}, fmt.Errorf("cannot generate private key; unknown key type %v", ssconfig.KeyType)
}
if err != nil {
return tls.Certificate{}, fmt.Errorf("failed to generate private key: %v", err)
}
// create certificate structure with proper values
notBefore := time.Now()
notAfter := ssconfig.Expire
if notAfter.IsZero() || notAfter.Before(notBefore) {
notAfter = notBefore.Add(24 * time.Hour * 7)
}
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
if err != nil {
return tls.Certificate{}, fmt.Errorf("failed to generate serial number: %v", err)
}
cert := &x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{Organization: []string{"Caddy Self-Signed"}},
NotBefore: notBefore,
NotAfter: notAfter,
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
}
if len(ssconfig.SAN) == 0 {
ssconfig.SAN = []string{""}
}
var names []string
for _, san := range ssconfig.SAN {
if ip := net.ParseIP(san); ip != nil {
names = append(names, strings.ToLower(ip.String()))
cert.IPAddresses = append(cert.IPAddresses, ip)
} else {
names = append(names, strings.ToLower(san))
cert.DNSNames = append(cert.DNSNames, strings.ToLower(san))
}
}
// generate the associated public key
publicKey := func(privKey interface{}) interface{} {
switch k := privKey.(type) {
case *rsa.PrivateKey:
return &k.PublicKey
case *ecdsa.PrivateKey:
return &k.PublicKey
default:
return fmt.Errorf("unknown key type")
}
}
derBytes, err := x509.CreateCertificate(rand.Reader, cert, cert, publicKey(privKey), privKey)
if err != nil {
return tls.Certificate{}, fmt.Errorf("could not create certificate: %v", err)
}
chain := [][]byte{derBytes}
return tls.Certificate{
Certificate: chain,
PrivateKey: privKey,
Leaf: cert,
}, nil
}
// selfSignedConfig configures a self-signed certificate.
type selfSignedConfig struct {
SAN []string
KeyType certcrypto.KeyType
Expire time.Time
}
...@@ -29,17 +29,20 @@ import ( ...@@ -29,17 +29,20 @@ import (
"github.com/mholt/caddy" "github.com/mholt/caddy"
"github.com/mholt/caddy/telemetry" "github.com/mholt/caddy/telemetry"
"github.com/mholt/certmagic"
) )
func init() { func init() {
caddy.RegisterPlugin("tls", caddy.Plugin{Action: setupTLS}) caddy.RegisterPlugin("tls", caddy.Plugin{Action: setupTLS})
// ensure TLS assets are stored and accessed from the CADDYPATH
certmagic.DefaultStorage = certmagic.FileStorage{Path: caddy.AssetsPath()}
} }
// setupTLS sets up the TLS configuration and installs certificates that // setupTLS sets up the TLS configuration and installs certificates that
// are specified by the user in the config file. All the automatic HTTPS // are specified by the user in the config file. All the automatic HTTPS
// stuff comes later outside of this function. // stuff comes later outside of this function.
func setupTLS(c *caddy.Controller) error { func setupTLS(c *caddy.Controller) error {
// obtain the configGetter, which loads the config we're, uh, configuring
configGetter, ok := configGetters[c.ServerType()] configGetter, ok := configGetters[c.ServerType()]
if !ok { if !ok {
return fmt.Errorf("no caddytls.ConfigGetter for %s server type; must call RegisterConfigGetter", c.ServerType()) return fmt.Errorf("no caddytls.ConfigGetter for %s server type; must call RegisterConfigGetter", c.ServerType())
...@@ -49,18 +52,68 @@ func setupTLS(c *caddy.Controller) error { ...@@ -49,18 +52,68 @@ func setupTLS(c *caddy.Controller) error {
return fmt.Errorf("no caddytls.Config to set up for %s", c.Key) return fmt.Errorf("no caddytls.Config to set up for %s", c.Key)
} }
// the certificate cache is tied to the current caddy.Instance; get a pointer to it config.Enabled = true
certCache, ok := c.Get(CertCacheInstStorageKey).(*certificateCache)
// a single certificate cache is used by the whole caddy.Instance; get a pointer to it
certCache, ok := c.Get(CertCacheInstStorageKey).(*certmagic.Cache)
if !ok || certCache == nil { if !ok || certCache == nil {
certCache = &certificateCache{cache: make(map[string]Certificate)} certCache = certmagic.NewCache(certmagic.FileStorage{Path: caddy.AssetsPath()})
c.OnShutdown(func() error {
certCache.Stop()
return nil
})
c.Set(CertCacheInstStorageKey, certCache) c.Set(CertCacheInstStorageKey, certCache)
} }
config.certCache = certCache config.Manager = certmagic.NewWithCache(certCache, certmagic.Config{})
// we use certmagic events to collect metrics for telemetry
config.Manager.OnEvent = func(event string, data interface{}) {
switch event {
case "tls_handshake_started":
clientHello := data.(*tls.ClientHelloInfo)
if ClientHelloTelemetry && len(clientHello.SupportedVersions) > 0 {
// If no other plugin (such as the HTTP server type) is implementing ClientHello telemetry, we do it.
// NOTE: The values in the Go standard lib's ClientHelloInfo aren't guaranteed to be in order.
info := ClientHelloInfo{
Version: clientHello.SupportedVersions[0], // report the highest
CipherSuites: clientHello.CipherSuites,
ExtensionsUnknown: true, // no extension info... :(
CompressionMethodsUnknown: true, // no compression methods... :(
Curves: clientHello.SupportedCurves,
Points: clientHello.SupportedPoints,
// We also have, but do not yet use: SignatureSchemes, ServerName, and SupportedProtos (ALPN)
// because the standard lib parses some extensions, but our MITM detector generally doesn't.
}
go telemetry.SetNested("tls_client_hello", info.Key(), info)
}
config.Enabled = true case "tls_handshake_completed":
// TODO: This is a "best guess" for now - at this point, we only gave a
// certificate to the client; we need something listener-level to be sure
go telemetry.Increment("tls_handshake_count")
case "acme_cert_obtained":
go telemetry.Increment("tls_acme_certs_obtained")
case "acme_cert_renewed":
name := data.(string)
caddy.EmitEvent(caddy.CertRenewEvent, name)
go telemetry.Increment("tls_acme_certs_renewed")
case "acme_cert_revoked":
telemetry.Increment("acme_certs_revoked")
case "cached_managed_cert":
telemetry.Increment("tls_managed_cert_count")
case "cached_unmanaged_cert":
telemetry.Increment("tls_unmanaged_cert_count")
}
}
for c.Next() { for c.Next() {
var certificateFile, keyFile, loadDir, maxCerts, askURL string var certificateFile, keyFile, loadDir, maxCerts, askURL string
var onDemand bool
args := c.RemainingArgs() args := c.RemainingArgs()
switch len(args) { switch len(args) {
...@@ -96,14 +149,14 @@ func setupTLS(c *caddy.Controller) error { ...@@ -96,14 +149,14 @@ func setupTLS(c *caddy.Controller) error {
if len(arg) != 1 { if len(arg) != 1 {
return c.ArgErr() return c.ArgErr()
} }
config.CAUrl = arg[0] config.Manager.CA = arg[0]
case "key_type": case "key_type":
arg := c.RemainingArgs() arg := c.RemainingArgs()
value, ok := supportedKeyTypes[strings.ToUpper(arg[0])] value, ok := supportedKeyTypes[strings.ToUpper(arg[0])]
if !ok { if !ok {
return c.Errf("Wrong key type name or key type not supported: '%s'", c.Val()) return c.Errf("Wrong key type name or key type not supported: '%s'", c.Val())
} }
config.KeyType = value config.Manager.KeyType = value
case "protocols": case "protocols":
args := c.RemainingArgs() args := c.RemainingArgs()
if len(args) == 1 { if len(args) == 1 {
...@@ -111,7 +164,6 @@ func setupTLS(c *caddy.Controller) error { ...@@ -111,7 +164,6 @@ func setupTLS(c *caddy.Controller) error {
if !ok { if !ok {
return c.Errf("Wrong protocol name or protocol not supported: '%s'", args[0]) return c.Errf("Wrong protocol name or protocol not supported: '%s'", args[0])
} }
config.ProtocolMinVersion, config.ProtocolMaxVersion = value, value config.ProtocolMinVersion, config.ProtocolMaxVersion = value, value
} else { } else {
value, ok := SupportedProtocols[strings.ToLower(args[0])] value, ok := SupportedProtocols[strings.ToLower(args[0])]
...@@ -174,32 +226,44 @@ func setupTLS(c *caddy.Controller) error { ...@@ -174,32 +226,44 @@ func setupTLS(c *caddy.Controller) error {
config.Manual = true config.Manual = true
case "max_certs": case "max_certs":
c.Args(&maxCerts) c.Args(&maxCerts)
config.OnDemand = true onDemand = true
telemetry.Increment("tls_on_demand_count")
case "ask": case "ask":
c.Args(&askURL) c.Args(&askURL)
config.OnDemand = true onDemand = true
telemetry.Increment("tls_on_demand_count")
case "dns": case "dns":
args := c.RemainingArgs() args := c.RemainingArgs()
if len(args) != 1 { if len(args) != 1 {
return c.ArgErr() return c.ArgErr()
} }
// TODO: we can get rid of DNS provider plugins with this one line
// of code; however, currently (Dec. 2018) this adds about 20 MB
// of bloat to the Caddy binary, doubling its size to ~40 MB...!
// dnsProv, err := dns.NewDNSChallengeProviderByName(args[0])
// if err != nil {
// return c.Errf("Configuring DNS provider '%s': %v", args[0], err)
// }
dnsProvName := args[0] dnsProvName := args[0]
if _, ok := dnsProviders[dnsProvName]; !ok { dnsProvConstructor, ok := dnsProviders[dnsProvName]
return c.Errf("Unsupported DNS provider '%s'", args[0]) if !ok {
} return c.Errf("Unknown DNS provider by name '%s'", dnsProvName)
config.DNSProvider = args[0]
case "storage":
args := c.RemainingArgs()
if len(args) != 1 {
return c.ArgErr()
} }
storageProvName := args[0] dnsProv, err := dnsProvConstructor()
if _, ok := storageProviders[storageProvName]; !ok { if err != nil {
return c.Errf("Unsupported Storage provider '%s'", args[0]) return c.Errf("Setting up DNS provider '%s': %v", dnsProvName, err)
} }
config.StorageProvider = args[0] config.Manager.DNSProvider = dnsProv
// TODO
// case "storage":
// args := c.RemainingArgs()
// if len(args) != 1 {
// return c.ArgErr()
// }
// storageProvName := args[0]
// storageProvConstr, ok := storageProviders[storageProvName]
// if !ok {
// return c.Errf("Unsupported Storage provider '%s'", args[0])
// }
// config.Manager.Storage = storageProvConstr
case "alpn": case "alpn":
args := c.RemainingArgs() args := c.RemainingArgs()
if len(args) == 0 { if len(args) == 0 {
...@@ -209,9 +273,9 @@ func setupTLS(c *caddy.Controller) error { ...@@ -209,9 +273,9 @@ func setupTLS(c *caddy.Controller) error {
config.ALPN = append(config.ALPN, arg) config.ALPN = append(config.ALPN, arg)
} }
case "must_staple": case "must_staple":
config.MustStaple = true config.Manager.MustStaple = true
case "wildcard": case "wildcard":
if !HostQualifies(config.Hostname) { if !certmagic.HostQualifies(config.Hostname) {
return c.Errf("Hostname '%s' does not qualify for managed TLS, so cannot manage wildcard certificate for it", config.Hostname) return c.Errf("Hostname '%s' does not qualify for managed TLS, so cannot manage wildcard certificate for it", config.Hostname)
} }
if strings.Contains(config.Hostname, "*") { if strings.Contains(config.Hostname, "*") {
...@@ -233,26 +297,26 @@ func setupTLS(c *caddy.Controller) error { ...@@ -233,26 +297,26 @@ func setupTLS(c *caddy.Controller) error {
return c.ArgErr() return c.ArgErr()
} }
// set certificate limit if on-demand TLS is enabled // configure on-demand TLS, if enabled
if maxCerts != "" { if onDemand {
maxCertsNum, err := strconv.Atoi(maxCerts) config.Manager.OnDemand = new(certmagic.OnDemandConfig)
if err != nil || maxCertsNum < 1 { if maxCerts != "" {
return c.Err("max_certs must be a positive integer") maxCertsNum, err := strconv.Atoi(maxCerts)
} if err != nil || maxCertsNum < 1 {
config.OnDemandState.MaxObtain = int32(maxCertsNum) return c.Err("max_certs must be a positive integer")
} }
config.Manager.OnDemand.MaxObtain = int32(maxCertsNum)
if askURL != "" {
parsedURL, err := url.Parse(askURL)
if err != nil {
return c.Err("ask must be a valid url")
} }
if askURL != "" {
if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" { parsedURL, err := url.Parse(askURL)
return c.Err("ask URL must use http or https") if err != nil {
return c.Err("ask must be a valid url")
}
if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" {
return c.Err("ask URL must use http or https")
}
config.Manager.OnDemand.AskURL = parsedURL
} }
config.OnDemandState.AskURL = parsedURL
} }
// don't try to load certificates unless we're supposed to // don't try to load certificates unless we're supposed to
...@@ -262,7 +326,7 @@ func setupTLS(c *caddy.Controller) error { ...@@ -262,7 +326,7 @@ func setupTLS(c *caddy.Controller) error {
// load a single certificate and key, if specified // load a single certificate and key, if specified
if certificateFile != "" && keyFile != "" { if certificateFile != "" && keyFile != "" {
err := config.cacheUnmanagedCertificatePEMFile(certificateFile, keyFile) err := config.Manager.CacheUnmanagedCertificatePEMFile(certificateFile, keyFile)
if err != nil { if err != nil {
return c.Errf("Unable to load certificate and key files for '%s': %v", c.Key, err) return c.Errf("Unable to load certificate and key files for '%s': %v", c.Key, err)
} }
...@@ -282,7 +346,14 @@ func setupTLS(c *caddy.Controller) error { ...@@ -282,7 +346,14 @@ func setupTLS(c *caddy.Controller) error {
// generate self-signed cert if needed // generate self-signed cert if needed
if config.SelfSigned { if config.SelfSigned {
err := makeSelfSignedCertForConfig(config) ssCert, err := newSelfSignedCertificate(selfSignedConfig{
SAN: []string{config.Hostname},
KeyType: config.Manager.KeyType,
})
if err != nil {
return fmt.Errorf("self-signed certificate generation: %v", err)
}
err = config.Manager.CacheUnmanagedTLSCertificate(ssCert)
if err != nil { if err != nil {
return fmt.Errorf("self-signed: %v", err) return fmt.Errorf("self-signed: %v", err)
} }
...@@ -362,7 +433,7 @@ func loadCertsInDir(cfg *Config, c *caddy.Controller, dir string) error { ...@@ -362,7 +433,7 @@ func loadCertsInDir(cfg *Config, c *caddy.Controller, dir string) error {
return c.Errf("%s: no private key block found", path) return c.Errf("%s: no private key block found", path)
} }
err = cfg.cacheUnmanagedCertificatePEMBytes(certPEMBytes, keyPEMBytes) err = cfg.Manager.CacheUnmanagedCertificatePEMBytes(certPEMBytes, keyPEMBytes)
if err != nil { if err != nil {
return c.Errf("%s: failed to load cert and key for '%s': %v", path, c.Key, err) return c.Errf("%s: failed to load cert and key for '%s': %v", path, c.Key, err)
} }
......
...@@ -22,7 +22,8 @@ import ( ...@@ -22,7 +22,8 @@ import (
"testing" "testing"
"github.com/mholt/caddy" "github.com/mholt/caddy"
"github.com/xenolf/lego/acme" "github.com/mholt/certmagic"
"github.com/xenolf/lego/certcrypto"
) )
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
...@@ -46,8 +47,7 @@ func TestMain(m *testing.M) { ...@@ -46,8 +47,7 @@ func TestMain(m *testing.M) {
} }
func TestSetupParseBasic(t *testing.T) { func TestSetupParseBasic(t *testing.T) {
certCache := &certificateCache{cache: make(map[string]Certificate)} cfg, certCache := testConfigForTLSSetup()
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg }) RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
c := caddy.NewTestController("", `tls `+certFile+` `+keyFile+``) c := caddy.NewTestController("", `tls `+certFile+` `+keyFile+``)
...@@ -127,8 +127,7 @@ func TestSetupParseWithOptionalParams(t *testing.T) { ...@@ -127,8 +127,7 @@ func TestSetupParseWithOptionalParams(t *testing.T) {
must_staple must_staple
alpn http/1.1 alpn http/1.1
}` }`
certCache := &certificateCache{cache: make(map[string]Certificate)} cfg, certCache := testConfigForTLSSetup()
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg }) RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
c := caddy.NewTestController("", params) c := caddy.NewTestController("", params)
...@@ -151,7 +150,7 @@ func TestSetupParseWithOptionalParams(t *testing.T) { ...@@ -151,7 +150,7 @@ func TestSetupParseWithOptionalParams(t *testing.T) {
t.Errorf("Expected 3 Ciphers (not including TLS_FALLBACK_SCSV), got %v", len(cfg.Ciphers)-1) t.Errorf("Expected 3 Ciphers (not including TLS_FALLBACK_SCSV), got %v", len(cfg.Ciphers)-1)
} }
if !cfg.MustStaple { if !cfg.Manager.MustStaple {
t.Error("Expected must staple to be true") t.Error("Expected must staple to be true")
} }
...@@ -164,8 +163,7 @@ func TestSetupDefaultWithOptionalParams(t *testing.T) { ...@@ -164,8 +163,7 @@ func TestSetupDefaultWithOptionalParams(t *testing.T) {
params := `tls { params := `tls {
ciphers RSA-3DES-EDE-CBC-SHA ciphers RSA-3DES-EDE-CBC-SHA
}` }`
certCache := &certificateCache{cache: make(map[string]Certificate)} cfg, certCache := testConfigForTLSSetup()
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg }) RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
c := caddy.NewTestController("", params) c := caddy.NewTestController("", params)
c.Set(CertCacheInstStorageKey, certCache) c.Set(CertCacheInstStorageKey, certCache)
...@@ -184,8 +182,7 @@ func TestSetupParseWithWrongOptionalParams(t *testing.T) { ...@@ -184,8 +182,7 @@ func TestSetupParseWithWrongOptionalParams(t *testing.T) {
params := `tls ` + certFile + ` ` + keyFile + ` { params := `tls ` + certFile + ` ` + keyFile + ` {
protocols ssl tls protocols ssl tls
}` }`
certCache := &certificateCache{cache: make(map[string]Certificate)} cfg, certCache := testConfigForTLSSetup()
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg }) RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
c := caddy.NewTestController("", params) c := caddy.NewTestController("", params)
c.Set(CertCacheInstStorageKey, certCache) c.Set(CertCacheInstStorageKey, certCache)
...@@ -239,8 +236,7 @@ func TestSetupParseWithClientAuth(t *testing.T) { ...@@ -239,8 +236,7 @@ func TestSetupParseWithClientAuth(t *testing.T) {
params := `tls ` + certFile + ` ` + keyFile + ` { params := `tls ` + certFile + ` ` + keyFile + ` {
clients clients
}` }`
certCache := &certificateCache{cache: make(map[string]Certificate)} cfg, _ := testConfigForTLSSetup()
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg }) RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
c := caddy.NewTestController("", params) c := caddy.NewTestController("", params)
err := setupTLS(c) err := setupTLS(c)
...@@ -273,8 +269,8 @@ func TestSetupParseWithClientAuth(t *testing.T) { ...@@ -273,8 +269,8 @@ func TestSetupParseWithClientAuth(t *testing.T) {
clients verify_if_given clients verify_if_given
}`, tls.VerifyClientCertIfGiven, true, noCAs}, }`, tls.VerifyClientCertIfGiven, true, noCAs},
} { } {
certCache := &certificateCache{cache: make(map[string]Certificate)} certCache := certmagic.NewCache(certmagic.DefaultStorage)
cfg := &Config{Certificates: make(map[string]string), certCache: certCache} cfg := &Config{Manager: certmagic.NewWithCache(certCache, certmagic.Config{})}
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg }) RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
c := caddy.NewTestController("", caseData.params) c := caddy.NewTestController("", caseData.params)
c.Set(CertCacheInstStorageKey, certCache) c.Set(CertCacheInstStorageKey, certCache)
...@@ -327,8 +323,8 @@ func TestSetupParseWithCAUrl(t *testing.T) { ...@@ -327,8 +323,8 @@ func TestSetupParseWithCAUrl(t *testing.T) {
ca 1 2 ca 1 2
}`, true, ""}, }`, true, ""},
} { } {
certCache := &certificateCache{cache: make(map[string]Certificate)} certCache := certmagic.NewCache(certmagic.DefaultStorage)
cfg := &Config{Certificates: make(map[string]string), certCache: certCache} cfg := &Config{Manager: certmagic.NewWithCache(certCache, certmagic.Config{})}
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg }) RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
c := caddy.NewTestController("", caseData.params) c := caddy.NewTestController("", caseData.params)
c.Set(CertCacheInstStorageKey, certCache) c.Set(CertCacheInstStorageKey, certCache)
...@@ -343,8 +339,8 @@ func TestSetupParseWithCAUrl(t *testing.T) { ...@@ -343,8 +339,8 @@ func TestSetupParseWithCAUrl(t *testing.T) {
t.Errorf("In case %d: Expected no errors, got: %v", caseNumber, err) t.Errorf("In case %d: Expected no errors, got: %v", caseNumber, err)
} }
if cfg.CAUrl != caseData.expectedCAUrl { if cfg.Manager.CA != caseData.expectedCAUrl {
t.Errorf("Expected '%v' as CAUrl, got %#v", caseData.expectedCAUrl, cfg.CAUrl) t.Errorf("Expected '%v' as CAUrl, got %#v", caseData.expectedCAUrl, cfg.Manager.CA)
} }
} }
} }
...@@ -353,8 +349,7 @@ func TestSetupParseWithKeyType(t *testing.T) { ...@@ -353,8 +349,7 @@ func TestSetupParseWithKeyType(t *testing.T) {
params := `tls { params := `tls {
key_type p384 key_type p384
}` }`
certCache := &certificateCache{cache: make(map[string]Certificate)} cfg, certCache := testConfigForTLSSetup()
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg }) RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
c := caddy.NewTestController("", params) c := caddy.NewTestController("", params)
c.Set(CertCacheInstStorageKey, certCache) c.Set(CertCacheInstStorageKey, certCache)
...@@ -364,8 +359,8 @@ func TestSetupParseWithKeyType(t *testing.T) { ...@@ -364,8 +359,8 @@ func TestSetupParseWithKeyType(t *testing.T) {
t.Errorf("Expected no errors, got: %v", err) t.Errorf("Expected no errors, got: %v", err)
} }
if cfg.KeyType != acme.EC384 { if cfg.Manager.KeyType != certcrypto.EC384 {
t.Errorf("Expected 'P384' as KeyType, got %#v", cfg.KeyType) t.Errorf("Expected 'P384' as KeyType, got %#v", cfg.Manager.KeyType)
} }
} }
...@@ -373,8 +368,7 @@ func TestSetupParseWithCurves(t *testing.T) { ...@@ -373,8 +368,7 @@ func TestSetupParseWithCurves(t *testing.T) {
params := `tls { params := `tls {
curves x25519 p256 p384 p521 curves x25519 p256 p384 p521
}` }`
certCache := &certificateCache{cache: make(map[string]Certificate)} cfg, certCache := testConfigForTLSSetup()
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg }) RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
c := caddy.NewTestController("", params) c := caddy.NewTestController("", params)
c.Set(CertCacheInstStorageKey, certCache) c.Set(CertCacheInstStorageKey, certCache)
...@@ -402,8 +396,7 @@ func TestSetupParseWithOneTLSProtocol(t *testing.T) { ...@@ -402,8 +396,7 @@ func TestSetupParseWithOneTLSProtocol(t *testing.T) {
params := `tls { params := `tls {
protocols tls1.2 protocols tls1.2
}` }`
certCache := &certificateCache{cache: make(map[string]Certificate)} cfg, certCache := testConfigForTLSSetup()
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg }) RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
c := caddy.NewTestController("", params) c := caddy.NewTestController("", params)
c.Set(CertCacheInstStorageKey, certCache) c.Set(CertCacheInstStorageKey, certCache)
...@@ -422,6 +415,14 @@ func TestSetupParseWithOneTLSProtocol(t *testing.T) { ...@@ -422,6 +415,14 @@ func TestSetupParseWithOneTLSProtocol(t *testing.T) {
} }
} }
func testConfigForTLSSetup() (*Config, *certmagic.Cache) {
certCache := certmagic.NewCache(nil)
certCache.Stop()
return &Config{
Manager: certmagic.NewWithCache(certCache, certmagic.Config{}),
}, certCache
}
const ( const (
certFile = "test_cert.pem" certFile = "test_cert.pem"
keyFile = "test_key.pem" keyFile = "test_key.pem"
......
// Copyright 2015 Light Code Labs, LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package caddytls
import "net/url"
// StorageConstructor is a function type that is used in the Config to
// instantiate a new Storage instance. This function can return a nil
// Storage even without an error.
type StorageConstructor func(caURL *url.URL) (Storage, error)
// SiteData contains persisted items pertaining to an individual site.
type SiteData struct {
// Cert is the public cert byte array.
Cert []byte
// Key is the private key byte array.
Key []byte
// Meta is metadata about the site used by Caddy.
Meta []byte
}
// UserData contains persisted items pertaining to a user.
type UserData struct {
// Reg is the user registration byte array.
Reg []byte
// Key is the user key byte array.
Key []byte
}
// Locker provides support for mutual exclusion
type Locker interface {
// TryLock will return immediatedly with or without acquiring the lock.
// If a lock could be obtained, (nil, nil) is returned and you may
// continue normally. If not (meaning another process is already
// working on that name), a Waiter value will be returned upon
// which you can Wait() until it is finished, and then return
// when it unblocks. If waiting, do not unlock!
//
// To prevent deadlocks, all implementations (where this concern
// is relevant) should put a reasonable expiration on the lock in
// case Unlock is unable to be called due to some sort of storage
// system failure or crash.
TryLock(name string) (Waiter, error)
// Unlock unlocks the mutex for name. Only callers of TryLock who
// successfully obtained the lock (no Waiter value was returned)
// should call this method, and it should be called only after
// the obtain/renew and store are finished, even if there was
// an error (or a timeout). Unlock should also clean up any
// unused resources allocated during TryLock.
Unlock(name string) error
}
// Storage is an interface abstracting all storage used by Caddy's TLS
// subsystem. Implementations of this interface store both site and
// user data.
type Storage interface {
// SiteExists returns true if this site exists in storage.
// Site data is considered present when StoreSite has been called
// successfully (without DeleteSite having been called, of course).
SiteExists(domain string) (bool, error)
// LoadSite obtains the site data from storage for the given domain and
// returns it. If data for the domain does not exist, an error value
// of type ErrNotExist is returned. For multi-server storage, care
// should be taken to make this load atomic to prevent race conditions
// that happen with multiple data loads.
LoadSite(domain string) (*SiteData, error)
// StoreSite persists the given site data for the given domain in
// storage. For multi-server storage, care should be taken to make this
// call atomic to prevent half-written data on failure of an internal
// intermediate storage step. Implementers can trust that at runtime
// this function will only be invoked after LockRegister and before
// UnlockRegister of the same domain.
StoreSite(domain string, data *SiteData) error
// DeleteSite deletes the site for the given domain from storage.
// Multi-server implementations should attempt to make this atomic. If
// the site does not exist, an error value of type ErrNotExist is returned.
DeleteSite(domain string) error
// LoadUser obtains user data from storage for the given email and
// returns it. If data for the email does not exist, an error value
// of type ErrNotExist is returned. Multi-server implementations
// should take care to make this operation atomic for all loaded
// data items.
LoadUser(email string) (*UserData, error)
// StoreUser persists the given user data for the given email in
// storage. Multi-server implementations should take care to make this
// operation atomic for all stored data items.
StoreUser(email string, data *UserData) error
// MostRecentUserEmail provides the most recently used email parameter
// in StoreUser. The result is an empty string if there are no
// persisted users in storage.
MostRecentUserEmail() string
// Locker is necessary because synchronizing certificate maintenance
// depends on how storage is implemented.
Locker
}
// ErrNotExist is returned by Storage implementations when
// a resource is not found. It is similar to os.ErrNotExist
// except this is a type, not a variable.
type ErrNotExist interface {
error
}
// Waiter is a type that can block until a storage lock is released.
type Waiter interface {
Wait()
}
// Copyright 2015 Light Code Labs, LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package storagetest
import (
"errors"
"net/url"
"sync"
"github.com/mholt/caddy/caddytls"
)
// memoryMutex is a mutex used to control access to memoryStoragesByCAURL.
var memoryMutex sync.Mutex
// memoryStoragesByCAURL is a map keyed by a CA URL string with values of
// instantiated memory stores. Do not access this directly, it is used by
// InMemoryStorageCreator.
var memoryStoragesByCAURL = make(map[string]*InMemoryStorage)
// InMemoryStorageCreator is a caddytls.Storage.StorageCreator to create
// InMemoryStorage instances for testing.
func InMemoryStorageCreator(caURL *url.URL) (caddytls.Storage, error) {
urlStr := caURL.String()
memoryMutex.Lock()
defer memoryMutex.Unlock()
storage := memoryStoragesByCAURL[urlStr]
if storage == nil {
storage = NewInMemoryStorage()
memoryStoragesByCAURL[urlStr] = storage
}
return storage, nil
}
// InMemoryStorage is a caddytls.Storage implementation for use in testing.
// It simply stores information in runtime memory.
type InMemoryStorage struct {
// Sites are exposed for testing purposes.
Sites map[string]*caddytls.SiteData
// Users are exposed for testing purposes.
Users map[string]*caddytls.UserData
// LastUserEmail is exposed for testing purposes.
LastUserEmail string
}
// NewInMemoryStorage constructs an InMemoryStorage instance. For use with
// caddytls, the InMemoryStorageCreator should be used instead.
func NewInMemoryStorage() *InMemoryStorage {
return &InMemoryStorage{
Sites: make(map[string]*caddytls.SiteData),
Users: make(map[string]*caddytls.UserData),
}
}
// SiteExists implements caddytls.Storage.SiteExists in memory.
func (s *InMemoryStorage) SiteExists(domain string) (bool, error) {
_, siteExists := s.Sites[domain]
return siteExists, nil
}
// Clear completely clears all values associated with this storage.
func (s *InMemoryStorage) Clear() {
s.Sites = make(map[string]*caddytls.SiteData)
s.Users = make(map[string]*caddytls.UserData)
s.LastUserEmail = ""
}
// LoadSite implements caddytls.Storage.LoadSite in memory.
func (s *InMemoryStorage) LoadSite(domain string) (*caddytls.SiteData, error) {
siteData, ok := s.Sites[domain]
if !ok {
return nil, caddytls.ErrNotExist(errors.New("not found"))
}
return siteData, nil
}
func copyBytes(from []byte) []byte {
copiedBytes := make([]byte, len(from))
copy(copiedBytes, from)
return copiedBytes
}
// StoreSite implements caddytls.Storage.StoreSite in memory.
func (s *InMemoryStorage) StoreSite(domain string, data *caddytls.SiteData) error {
copiedData := new(caddytls.SiteData)
copiedData.Cert = copyBytes(data.Cert)
copiedData.Key = copyBytes(data.Key)
copiedData.Meta = copyBytes(data.Meta)
s.Sites[domain] = copiedData
return nil
}
// DeleteSite implements caddytls.Storage.DeleteSite in memory.
func (s *InMemoryStorage) DeleteSite(domain string) error {
if _, ok := s.Sites[domain]; !ok {
return caddytls.ErrNotExist(errors.New("not found"))
}
delete(s.Sites, domain)
return nil
}
// TryLock implements Storage.TryLock by returning nil values because it
// is not a multi-server storage implementation.
func (s *InMemoryStorage) TryLock(domain string) (caddytls.Waiter, error) {
return nil, nil
}
// Unlock implements Storage.Unlock as a no-op because it is
// not a multi-server storage implementation.
func (s *InMemoryStorage) Unlock(domain string) error {
return nil
}
// LoadUser implements caddytls.Storage.LoadUser in memory.
func (s *InMemoryStorage) LoadUser(email string) (*caddytls.UserData, error) {
userData, ok := s.Users[email]
if !ok {
return nil, caddytls.ErrNotExist(errors.New("not found"))
}
return userData, nil
}
// StoreUser implements caddytls.Storage.StoreUser in memory.
func (s *InMemoryStorage) StoreUser(email string, data *caddytls.UserData) error {
copiedData := new(caddytls.UserData)
copiedData.Reg = copyBytes(data.Reg)
copiedData.Key = copyBytes(data.Key)
s.Users[email] = copiedData
s.LastUserEmail = email
return nil
}
// MostRecentUserEmail implements caddytls.Storage.MostRecentUserEmail in memory.
func (s *InMemoryStorage) MostRecentUserEmail() string {
return s.LastUserEmail
}
// Copyright 2015 Light Code Labs, LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package storagetest
import "testing"
func TestMemoryStorage(t *testing.T) {
storage := NewInMemoryStorage()
storageTest := &StorageTest{
Storage: storage,
PostTest: storage.Clear,
}
storageTest.Test(t, false)
}
// Copyright 2015 Light Code Labs, LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package storagetest provides utilities to assist in testing caddytls.Storage
// implementations.
package storagetest
import (
"bytes"
"errors"
"fmt"
"testing"
"github.com/mholt/caddy/caddytls"
)
// StorageTest is a test harness that contains tests to execute all exposed
// parts of a Storage implementation.
type StorageTest struct {
// Storage is the implementation to use during tests. This must be
// present.
caddytls.Storage
// PreTest, if present, is called before every test. Any error returned
// is returned from the test and the test does not continue.
PreTest func() error
// PostTest, if present, is executed after every test via defer which
// means it executes even on failure of the test (but not on failure of
// PreTest).
PostTest func()
// AfterUserEmailStore, if present, is invoked during
// TestMostRecentUserEmail after each storage just in case anything
// needs to be mocked.
AfterUserEmailStore func(email string) error
}
// TestFunc holds information about a test.
type TestFunc struct {
// Name is the friendly name of the test.
Name string
// Fn is the function that is invoked for the test.
Fn func() error
}
// runPreTest runs the PreTest function if present.
func (s *StorageTest) runPreTest() error {
if s.PreTest != nil {
return s.PreTest()
}
return nil
}
// runPostTest runs the PostTest function if present.
func (s *StorageTest) runPostTest() {
if s.PostTest != nil {
s.PostTest()
}
}
// AllFuncs returns all test functions that are part of this harness.
func (s *StorageTest) AllFuncs() []TestFunc {
return []TestFunc{
{"TestSiteInfoExists", s.TestSiteExists},
{"TestSite", s.TestSite},
{"TestUser", s.TestUser},
{"TestMostRecentUserEmail", s.TestMostRecentUserEmail},
}
}
// Test executes the entire harness using the testing package. Failures are
// reported via T.Fatal. If eagerFail is true, the first failure causes all
// testing to stop immediately.
func (s *StorageTest) Test(t *testing.T, eagerFail bool) {
if errs := s.TestAll(eagerFail); len(errs) > 0 {
ifaces := make([]interface{}, len(errs))
for i, err := range errs {
ifaces[i] = err
}
t.Fatal(ifaces...)
}
}
// TestAll executes the entire harness and returns the results as an array of
// errors. If eagerFail is true, the first failure causes all testing to stop
// immediately.
func (s *StorageTest) TestAll(eagerFail bool) (errs []error) {
for _, fn := range s.AllFuncs() {
if err := fn.Fn(); err != nil {
errs = append(errs, fmt.Errorf("%v failed: %v", fn.Name, err))
if eagerFail {
return
}
}
}
return
}
var simpleSiteData = &caddytls.SiteData{
Cert: []byte("foo"),
Key: []byte("bar"),
Meta: []byte("baz"),
}
var simpleSiteDataAlt = &caddytls.SiteData{
Cert: []byte("qux"),
Key: []byte("quux"),
Meta: []byte("corge"),
}
// TestSiteExists tests Storage.SiteExists.
func (s *StorageTest) TestSiteExists() error {
if err := s.runPreTest(); err != nil {
return err
}
defer s.runPostTest()
// Should not exist at first
siteExists, err := s.SiteExists("example.com")
if err != nil {
return err
}
if siteExists {
return errors.New("Site should not exist")
}
// Should exist after we store it
if err := s.StoreSite("example.com", simpleSiteData); err != nil {
return err
}
siteExists, err = s.SiteExists("example.com")
if err != nil {
return err
}
if !siteExists {
return errors.New("Expected site to exist")
}
// Site should no longer exist after we delete it
if err := s.DeleteSite("example.com"); err != nil {
return err
}
siteExists, err = s.SiteExists("example.com")
if err != nil {
return err
}
if siteExists {
return errors.New("Site should not exist after delete")
}
return nil
}
// TestSite tests Storage.LoadSite, Storage.StoreSite, and Storage.DeleteSite.
func (s *StorageTest) TestSite() error {
if err := s.runPreTest(); err != nil {
return err
}
defer s.runPostTest()
// Should be a not-found error at first
_, err := s.LoadSite("example.com")
if _, ok := err.(caddytls.ErrNotExist); !ok {
return fmt.Errorf("Expected caddytls.ErrNotExist from load, got %T: %v", err, err)
}
// Delete should also be a not-found error at first
err = s.DeleteSite("example.com")
if _, ok := err.(caddytls.ErrNotExist); !ok {
return fmt.Errorf("Expected ErrNotExist from delete, got: %v", err)
}
// Should store successfully and then load just fine
if err := s.StoreSite("example.com", simpleSiteData); err != nil {
return err
}
if siteData, err := s.LoadSite("example.com"); err != nil {
return err
} else if !bytes.Equal(siteData.Cert, simpleSiteData.Cert) {
return errors.New("Unexpected cert returned after store")
} else if !bytes.Equal(siteData.Key, simpleSiteData.Key) {
return errors.New("Unexpected key returned after store")
} else if !bytes.Equal(siteData.Meta, simpleSiteData.Meta) {
return errors.New("Unexpected meta returned after store")
}
// Overwrite should work just fine
if err := s.StoreSite("example.com", simpleSiteDataAlt); err != nil {
return err
}
if siteData, err := s.LoadSite("example.com"); err != nil {
return err
} else if !bytes.Equal(siteData.Cert, simpleSiteDataAlt.Cert) {
return errors.New("Unexpected cert returned after overwrite")
}
// It should delete fine and then not be there
if err := s.DeleteSite("example.com"); err != nil {
return err
}
_, err = s.LoadSite("example.com")
if _, ok := err.(caddytls.ErrNotExist); !ok {
return fmt.Errorf("Expected caddytls.ErrNotExist after delete, got %T: %v", err, err)
}
return nil
}
var simpleUserData = &caddytls.UserData{
Reg: []byte("foo"),
Key: []byte("bar"),
}
var simpleUserDataAlt = &caddytls.UserData{
Reg: []byte("baz"),
Key: []byte("qux"),
}
// TestUser tests Storage.LoadUser and Storage.StoreUser.
func (s *StorageTest) TestUser() error {
if err := s.runPreTest(); err != nil {
return err
}
defer s.runPostTest()
// Should be a not-found error at first
_, err := s.LoadUser("foo@example.com")
if _, ok := err.(caddytls.ErrNotExist); !ok {
return fmt.Errorf("Expected caddytls.ErrNotExist from load, got %T: %v", err, err)
}
// Should store successfully and then load just fine
if err := s.StoreUser("foo@example.com", simpleUserData); err != nil {
return err
}
if userData, err := s.LoadUser("foo@example.com"); err != nil {
return err
} else if !bytes.Equal(userData.Reg, simpleUserData.Reg) {
return errors.New("Unexpected reg returned after store")
} else if !bytes.Equal(userData.Key, simpleUserData.Key) {
return errors.New("Unexpected key returned after store")
}
// Overwrite should work just fine
if err := s.StoreUser("foo@example.com", simpleUserDataAlt); err != nil {
return err
}
if userData, err := s.LoadUser("foo@example.com"); err != nil {
return err
} else if !bytes.Equal(userData.Reg, simpleUserDataAlt.Reg) {
return errors.New("Unexpected reg returned after overwrite")
}
return nil
}
// TestMostRecentUserEmail tests Storage.MostRecentUserEmail.
func (s *StorageTest) TestMostRecentUserEmail() error {
if err := s.runPreTest(); err != nil {
return err
}
defer s.runPostTest()
// Should be empty on first run
if e := s.MostRecentUserEmail(); e != "" {
return fmt.Errorf("Expected empty most recent user on first run, got: %v", e)
}
// If we store user, then that one should be returned
if err := s.StoreUser("foo1@example.com", simpleUserData); err != nil {
return err
}
if s.AfterUserEmailStore != nil {
s.AfterUserEmailStore("foo1@example.com")
}
if e := s.MostRecentUserEmail(); e != "foo1@example.com" {
return fmt.Errorf("Unexpected most recent email after first store: %v", e)
}
// If we store another user, then that one should be returned
if err := s.StoreUser("foo2@example.com", simpleUserDataAlt); err != nil {
return err
}
if s.AfterUserEmailStore != nil {
s.AfterUserEmailStore("foo2@example.com")
}
if e := s.MostRecentUserEmail(); e != "foo2@example.com" {
return fmt.Errorf("Unexpected most recent email after user key: %v", e)
}
return nil
}
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -71,37 +71,37 @@ func (c *Controller) ServerType() string { ...@@ -71,37 +71,37 @@ func (c *Controller) ServerType() string {
// OnFirstStartup adds fn to the list of callback functions to execute // OnFirstStartup adds fn to the list of callback functions to execute
// when the server is about to be started NOT as part of a restart. // when the server is about to be started NOT as part of a restart.
func (c *Controller) OnFirstStartup(fn func() error) { func (c *Controller) OnFirstStartup(fn func() error) {
c.instance.onFirstStartup = append(c.instance.onFirstStartup, fn) c.instance.OnFirstStartup = append(c.instance.OnFirstStartup, fn)
} }
// OnStartup adds fn to the list of callback functions to execute // OnStartup adds fn to the list of callback functions to execute
// when the server is about to be started (including restarts). // when the server is about to be started (including restarts).
func (c *Controller) OnStartup(fn func() error) { func (c *Controller) OnStartup(fn func() error) {
c.instance.onStartup = append(c.instance.onStartup, fn) c.instance.OnStartup = append(c.instance.OnStartup, fn)
} }
// OnRestart adds fn to the list of callback functions to execute // OnRestart adds fn to the list of callback functions to execute
// when the server is about to be restarted. // when the server is about to be restarted.
func (c *Controller) OnRestart(fn func() error) { func (c *Controller) OnRestart(fn func() error) {
c.instance.onRestart = append(c.instance.onRestart, fn) c.instance.OnRestart = append(c.instance.OnRestart, fn)
} }
// OnRestartFailed adds fn to the list of callback functions to execute // OnRestartFailed adds fn to the list of callback functions to execute
// if the server failed to restart. // if the server failed to restart.
func (c *Controller) OnRestartFailed(fn func() error) { func (c *Controller) OnRestartFailed(fn func() error) {
c.instance.onRestartFailed = append(c.instance.onRestartFailed, fn) c.instance.OnRestartFailed = append(c.instance.OnRestartFailed, fn)
} }
// OnShutdown adds fn to the list of callback functions to execute // OnShutdown adds fn to the list of callback functions to execute
// when the server is about to be shut down (including restarts). // when the server is about to be shut down (including restarts).
func (c *Controller) OnShutdown(fn func() error) { func (c *Controller) OnShutdown(fn func() error) {
c.instance.onShutdown = append(c.instance.onShutdown, fn) c.instance.OnShutdown = append(c.instance.OnShutdown, fn)
} }
// OnFinalShutdown adds fn to the list of callback functions to execute // OnFinalShutdown adds fn to the list of callback functions to execute
// when the server is about to be shut down NOT as part of a restart. // when the server is about to be shut down NOT as part of a restart.
func (c *Controller) OnFinalShutdown(fn func() error) { func (c *Controller) OnFinalShutdown(fn func() error) {
c.instance.onFinalShutdown = append(c.instance.onFinalShutdown, fn) c.instance.OnFinalShutdown = append(c.instance.OnFinalShutdown, fn)
} }
// Context gets the context associated with the instance associated with c. // Context gets the context associated with the instance associated with c.
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment