Commit 286d8d1e authored by Mateusz Gajewski's avatar Mateusz Gajewski Committed by Matt Holt

tls: Per-site TLS configs using GetClientConfig, including http2 switch (#1389)

* Remove manual TLS clone method

* WiP tls

* Use GetClientConfig for tls.Config

* gofmt -s -w

* GetConfig

* Handshake

* Removed comment

* Disable HTTP2 on demand

* Remove junk

* Remove http2 enable (no-op)
parent 977a3c32
......@@ -31,6 +31,7 @@ type Server struct {
connTimeout time.Duration // max time to wait for a connection before force stop
tlsGovChan chan struct{} // close to stop the TLS maintenance goroutine
vhosts *vhostTrie
tlsConfig caddytls.ConfigGroup
}
// ensure it satisfies the interface
......@@ -72,16 +73,31 @@ func NewServer(addr string, group []*SiteConfig) (*Server, error) {
}
// Set up TLS configuration
var tlsConfigs []*caddytls.Config
tlsConfigs := make(caddytls.ConfigGroup)
var allConfigs []*caddytls.Config
for _, site := range group {
tlsConfigs = append(tlsConfigs, site.TLS)
if err := site.TLS.Build(tlsConfigs); err != nil {
return nil, err
}
var err error
s.Server.TLSConfig, err = caddytls.MakeTLSConfig(tlsConfigs)
if err != nil {
tlsConfigs[site.TLS.Hostname] = site.TLS
allConfigs = append(allConfigs, site.TLS)
}
// Check if configs are valid
if err := caddytls.CheckConfigs(allConfigs); err != nil {
return nil, err
}
s.tlsConfig = tlsConfigs
s.Server.TLSConfig = &tls.Config{
GetConfigForClient: s.tlsConfig.GetConfigForClient,
GetCertificate: s.tlsConfig.GetCertificate,
}
// As of Go 1.7, HTTP/2 is enabled only if NextProtos includes the string "h2"
if HTTP2 && s.Server.TLSConfig != nil && len(s.Server.TLSConfig.NextProtos) == 0 {
s.Server.TLSConfig.NextProtos = []string{"h2"}
......
......@@ -442,7 +442,7 @@ func newConnHijackerTransport(base http.RoundTripper) *connHijackerTransport {
if b, _ := base.(*http.Transport); b != nil {
tlsClientConfig := b.TLSClientConfig
if tlsClientConfig.NextProtos != nil {
tlsClientConfig = cloneTLSClientConfig(tlsClientConfig)
tlsClientConfig = tlsClientConfig.Clone()
tlsClientConfig.NextProtos = nil
}
......@@ -566,37 +566,6 @@ func (tlsHandshakeTimeoutError) Timeout() bool { return true }
func (tlsHandshakeTimeoutError) Temporary() bool { return true }
func (tlsHandshakeTimeoutError) Error() string { return "net/http: TLS handshake timeout" }
// cloneTLSClientConfig is like cloneTLSConfig but omits
// the fields SessionTicketsDisabled and SessionTicketKey.
// This makes it safe to call cloneTLSClientConfig on a config
// in active use by a server.
func cloneTLSClientConfig(cfg *tls.Config) *tls.Config {
if cfg == nil {
return &tls.Config{}
}
return &tls.Config{
Rand: cfg.Rand,
Time: cfg.Time,
Certificates: cfg.Certificates,
NameToCertificate: cfg.NameToCertificate,
GetCertificate: cfg.GetCertificate,
RootCAs: cfg.RootCAs,
NextProtos: cfg.NextProtos,
ServerName: cfg.ServerName,
ClientAuth: cfg.ClientAuth,
ClientCAs: cfg.ClientCAs,
InsecureSkipVerify: cfg.InsecureSkipVerify,
CipherSuites: cfg.CipherSuites,
PreferServerCipherSuites: cfg.PreferServerCipherSuites,
ClientSessionCache: cfg.ClientSessionCache,
MinVersion: cfg.MinVersion,
MaxVersion: cfg.MaxVersion,
CurvePreferences: cfg.CurvePreferences,
DynamicRecordSizingDisabled: cfg.DynamicRecordSizingDisabled,
Renegotiation: cfg.Renegotiation,
}
}
func requestIsWebsocket(req *http.Request) bool {
return strings.ToLower(req.Header.Get("Upgrade")) == "websocket" && strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade")
}
......
......@@ -108,6 +108,12 @@ type Config struct {
// Add the must staple TLS extension to the CSR generated by lego/acme
MustStaple bool
// Disables HTTP2 completely
DisableHTTP2 bool
// Holds final tls.Config
tlsConfig *tls.Config
}
// OnDemandState contains some state relevant for providing
......@@ -217,47 +223,25 @@ func (c *Config) StorageFor(caURL string) (Storage, error) {
return s, nil
}
// MakeTLSConfig reduces configs into a single tls.Config.
// If TLS is to be disabled, a nil tls.Config will be returned.
func MakeTLSConfig(configs []*Config) (*tls.Config, error) {
if len(configs) == 0 {
return nil, nil
}
config := new(tls.Config)
ciphersAdded := make(map[uint16]struct{})
curvesAdded := make(map[tls.CurveID]struct{})
configMap := make(configGroup)
func (cfg *Config) Build(group ConfigGroup) error {
config, err := cfg.build()
for i, cfg := range configs {
if cfg == nil {
// avoid nil pointer dereference below
configs[i] = new(Config)
continue
if err != nil {
return err
}
// Key this config by its hostname; this
// overwrites configs with the same hostname
configMap[cfg.Hostname] = cfg
cfg.tlsConfig = config
cfg.tlsConfig.GetCertificate = group.GetCertificate
return nil
}
// Can't serve TLS and not-TLS on same port
if i > 0 && cfg.Enabled != configs[i-1].Enabled {
thisConfProto, lastConfProto := "not TLS", "not TLS"
if cfg.Enabled {
thisConfProto = "TLS"
}
if configs[i-1].Enabled {
lastConfProto = "TLS"
}
return nil, fmt.Errorf("cannot multiplex %s (%s) and %s (%s) on same listener",
configs[i-1].Hostname, lastConfProto, cfg.Hostname, thisConfProto)
}
func (cfg *Config) build() (*tls.Config, error) {
config := new(tls.Config)
if !cfg.Enabled {
continue
}
ciphersAdded := make(map[uint16]struct{})
curvesAdded := make(map[tls.CurveID]struct{})
// Union cipher suites
// Add cipher suites
for _, ciph := range cfg.Ciphers {
if _, ok := ciphersAdded[ciph]; !ok {
ciphersAdded[ciph] = struct{}{}
......@@ -265,10 +249,6 @@ func MakeTLSConfig(configs []*Config) (*tls.Config, error) {
}
}
// Can't resolve conflicting PreferServerCipherSuites settings
if i > 0 && cfg.PreferServerCipherSuites != configs[i-1].PreferServerCipherSuites {
return nil, fmt.Errorf("cannot both PreferServerCipherSuites and not prefer them")
}
config.PreferServerCipherSuites = cfg.PreferServerCipherSuites
// Union curves
......@@ -279,48 +259,15 @@ func MakeTLSConfig(configs []*Config) (*tls.Config, error) {
}
}
// Go with the widest range of protocol versions
if config.MinVersion == 0 || cfg.ProtocolMinVersion < config.MinVersion {
config.MinVersion = cfg.ProtocolMinVersion
}
if cfg.ProtocolMaxVersion > config.MaxVersion {
config.MaxVersion = cfg.ProtocolMaxVersion
}
// Go with the strictest ClientAuth type
if cfg.ClientAuth > config.ClientAuth {
config.ClientAuth = cfg.ClientAuth
}
}
// Is TLS disabled? If so, we're done here.
// By now, we know that all configs agree
// whether it is or not, so we can just look
// at the first one.
if len(configs) == 0 || !configs[0].Enabled {
return nil, nil
}
// Default cipher suites
if len(config.CipherSuites) == 0 {
config.CipherSuites = defaultCiphers
}
// For security, ensure TLS_FALLBACK_SCSV is always included first
if len(config.CipherSuites) == 0 || config.CipherSuites[0] != tls.TLS_FALLBACK_SCSV {
config.CipherSuites = append([]uint16{tls.TLS_FALLBACK_SCSV}, config.CipherSuites...)
}
// Default curves
if len(config.CurvePreferences) == 0 {
config.CurvePreferences = defaultCurves
}
// Set up client authentication if enabled
if config.ClientAuth != tls.NoClientCert {
pool := x509.NewCertPool()
clientCertsAdded := make(map[string]struct{})
for _, cfg := range configs {
for _, caFile := range cfg.ClientCerts {
// don't add cert to pool more than once
if _, ok := clientCertsAdded[caFile]; ok {
......@@ -338,16 +285,58 @@ func MakeTLSConfig(configs []*Config) (*tls.Config, error) {
return nil, fmt.Errorf("error loading client certificate '%s': no certificates were successfully parsed", caFile)
}
}
}
config.ClientCAs = pool
}
// Associate the GetCertificate callback, or almost nothing we just did will work
config.GetCertificate = configMap.GetCertificate
// Default cipher suites
if len(config.CipherSuites) == 0 {
config.CipherSuites = defaultCiphers
}
// For security, ensure TLS_FALLBACK_SCSV is always included first
if len(config.CipherSuites) == 0 || config.CipherSuites[0] != tls.TLS_FALLBACK_SCSV {
config.CipherSuites = append([]uint16{tls.TLS_FALLBACK_SCSV}, config.CipherSuites...)
}
if cfg.DisableHTTP2 {
config.NextProtos = []string{}
} else {
config.NextProtos = []string{"h2"}
}
return config, nil
}
// CheckConfigs checks if multiple TLS configs does not collide with each other
func CheckConfigs(configs []*Config) error {
if len(configs) == 0 {
return nil
}
for i, cfg := range configs {
// Can't serve TLS and not-TLS on same port
if i > 0 && cfg.Enabled != configs[i-1].Enabled {
thisConfProto, lastConfProto := "not TLS", "not TLS"
if cfg.Enabled {
thisConfProto = "TLS"
}
if configs[i-1].Enabled {
lastConfProto = "TLS"
}
return fmt.Errorf("cannot multiplex %s (%s) and %s (%s) on same listener",
configs[i-1].Hostname, lastConfProto, cfg.Hostname, thisConfProto)
}
if !cfg.Enabled {
continue
}
}
return nil
}
// ConfigGetter gets a Config keyed by key.
type ConfigGetter func(c *caddy.Controller) *Config
......
......@@ -10,14 +10,12 @@ import (
func TestMakeTLSConfigProtocolVersions(t *testing.T) {
// same min and max protocol versions
configs := []*Config{
{
config := Config{
Enabled: true,
ProtocolMinVersion: tls.VersionTLS12,
ProtocolMaxVersion: tls.VersionTLS12,
},
}
result, err := MakeTLSConfig(configs)
result, err := config.build()
if err != nil {
t.Fatalf("Did not expect an error, but got %v", err)
}
......@@ -31,28 +29,14 @@ func TestMakeTLSConfigProtocolVersions(t *testing.T) {
func TestMakeTLSConfigPreferServerCipherSuites(t *testing.T) {
// prefer server cipher suites
configs := []*Config{{Enabled: true, PreferServerCipherSuites: true}}
result, err := MakeTLSConfig(configs)
config := Config{Enabled: true, PreferServerCipherSuites: true}
result, err := config.build()
if err != nil {
t.Fatalf("Did not expect an error, but got %v", err)
}
if got, want := result.PreferServerCipherSuites, true; got != want {
t.Errorf("Expected PreferServerCipherSuites==%v but got %v", want, got)
}
// make sure we don't get an error if there's a conflict
// when both of the configs have TLS disabled
configs = []*Config{
{Enabled: false, PreferServerCipherSuites: false},
{Enabled: false, PreferServerCipherSuites: true},
}
result, err = MakeTLSConfig(configs)
if err != nil {
t.Fatalf("Did not expect an error when TLS is disabled, but got '%v'", err)
}
if result != nil {
t.Errorf("Expected nil result because TLS disabled, got: %+v", err)
}
}
func TestMakeTLSConfigTLSEnabledDisabled(t *testing.T) {
......@@ -61,20 +45,10 @@ func TestMakeTLSConfigTLSEnabledDisabled(t *testing.T) {
{Enabled: true},
{Enabled: false},
}
_, err := MakeTLSConfig(configs)
err := CheckConfigs(configs)
if err == nil {
t.Fatalf("Expected an error, but got %v", err)
}
// verify that when disabled, a nil pair is returned
configs = []*Config{{}, {}}
result, err := MakeTLSConfig(configs)
if err != nil {
t.Errorf("Did not expect an error, but got %v", err)
}
if result != nil {
t.Errorf("Expected a nil *tls.Config result, got %+v", result)
}
}
func TestMakeTLSConfigCipherSuites(t *testing.T) {
......@@ -83,25 +57,22 @@ func TestMakeTLSConfigCipherSuites(t *testing.T) {
configs := []*Config{
{Enabled: true, Ciphers: []uint16{0xc02c, 0xc030}},
{Enabled: true, Ciphers: []uint16{0xc012, 0xc030, 0xc00a}},
{Enabled: true, Ciphers: nil},
}
result, err := MakeTLSConfig(configs)
if err != nil {
t.Fatalf("Did not expect an error, but got %v", err)
}
expected := []uint16{tls.TLS_FALLBACK_SCSV, 0xc02c, 0xc030, 0xc012, 0xc00a}
if !reflect.DeepEqual(result.CipherSuites, expected) {
t.Errorf("Expected ciphers %v but got %v", expected, result.CipherSuites)
expectedCiphers := [][]uint16{
{tls.TLS_FALLBACK_SCSV, 0xc02c, 0xc030},
{tls.TLS_FALLBACK_SCSV, 0xc012, 0xc030, 0xc00a},
append([]uint16{tls.TLS_FALLBACK_SCSV}, defaultCiphers...),
}
// use default suites if none specified
configs = []*Config{{Enabled: true}}
result, err = MakeTLSConfig(configs)
if err != nil {
t.Fatalf("Did not expect an error, but got %v", err)
for i, config := range configs {
cfg, _ := config.build()
if !reflect.DeepEqual(cfg.CipherSuites, expectedCiphers[i]) {
t.Errorf("Expected ciphers %v but got %v", expectedCiphers[i], cfg.CipherSuites)
}
expected = append([]uint16{tls.TLS_FALLBACK_SCSV}, defaultCiphers...)
if !reflect.DeepEqual(result.CipherSuites, expected) {
t.Errorf("Expected default ciphers %v but got %v", expected, result.CipherSuites)
}
}
......
......@@ -15,7 +15,7 @@ import (
// (hostnames can have wildcard characters; use the getConfig
// method to get a config by matching its hostname). Its
// GetCertificate function can be used with tls.Config.
type configGroup map[string]*Config
type ConfigGroup map[string]*Config
// getConfig gets the config by the first key match for name.
// In other words, "sub.foo.bar" will get the config for "*.foo.bar"
......@@ -24,7 +24,7 @@ type configGroup map[string]*Config
//
// This function follows nearly the same logic to lookup
// a hostname as the getCertificate function uses.
func (cg configGroup) getConfig(name string) *Config {
func (cg ConfigGroup) getConfig(name string) *Config {
name = strings.ToLower(name)
// exact match? great, let's use it
......@@ -58,11 +58,27 @@ func (cg configGroup) getConfig(name string) *Config {
// via ACME.
//
// This method is safe for use as a tls.Config.GetCertificate callback.
func (cg configGroup) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
func (cg ConfigGroup) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
cert, err := cg.getCertDuringHandshake(strings.ToLower(clientHello.ServerName), true, true)
return &cert.Certificate, err
}
// GetConfigForClient gets a TLS configuration satisfying clientHello. In getting
// the configuration, it abides the rules and settings defined in the
// Config that matches clientHello.ServerName.
//
// This method is safe for use as a tls.Config.GetConfigForClient callback.
func (cg ConfigGroup) GetConfigForClient(clientHello *tls.ClientHelloInfo) (*tls.Config, error) {
config := cg.getConfig(clientHello.ServerName)
if config != nil {
return config.tlsConfig, nil
}
return nil, nil
}
// getCertDuringHandshake will get a certificate for name. It first tries
// the in-memory cache. If no certificate for name is in the cache, the
// config most closely corresponding to name will be loaded. If that config
......@@ -74,7 +90,7 @@ func (cg configGroup) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Cer
// certificate is available.
//
// This function is safe for concurrent use.
func (cg configGroup) getCertDuringHandshake(name string, loadIfNecessary, obtainIfNecessary bool) (Certificate, error) {
func (cg ConfigGroup) getCertDuringHandshake(name string, loadIfNecessary, obtainIfNecessary bool) (Certificate, error) {
// First check our in-memory cache to see if we've already loaded it
cert, matched, defaulted := getCertificate(name)
if matched {
......@@ -127,7 +143,7 @@ func (cg configGroup) getCertDuringHandshake(name string, loadIfNecessary, obtai
// now according to mitigating factors we keep track of and preferences the
// user has set. If a non-nil error is returned, do not issue a new certificate
// for name.
func (cg configGroup) checkLimitsForObtainingNewCerts(name string, cfg *Config) error {
func (cg ConfigGroup) checkLimitsForObtainingNewCerts(name string, cfg *Config) error {
// User can set hard limit for number of certs for the process to issue
if cfg.OnDemandState.MaxObtain > 0 &&
atomic.LoadInt32(&cfg.OnDemandState.ObtainedCount) >= cfg.OnDemandState.MaxObtain {
......@@ -160,7 +176,7 @@ func (cg configGroup) checkLimitsForObtainingNewCerts(name string, cfg *Config)
// name, it will wait and use what the other goroutine obtained.
//
// This function is safe for use by multiple concurrent goroutines.
func (cg configGroup) obtainOnDemandCertificate(name string, cfg *Config) (Certificate, error) {
func (cg ConfigGroup) obtainOnDemandCertificate(name string, cfg *Config) (Certificate, error) {
// We must protect this process from happening concurrently, so synchronize.
obtainCertWaitChansMu.Lock()
wait, ok := obtainCertWaitChans[name]
......@@ -219,7 +235,7 @@ func (cg configGroup) obtainOnDemandCertificate(name string, cfg *Config) (Certi
// validity.
//
// This function is safe for use by multiple concurrent goroutines.
func (cg configGroup) handshakeMaintenance(name string, cert Certificate) (Certificate, error) {
func (cg ConfigGroup) handshakeMaintenance(name string, cert Certificate) (Certificate, error) {
// Check cert expiration
timeLeft := cert.NotAfter.Sub(time.Now().UTC())
if timeLeft < RenewDurationBefore {
......@@ -252,7 +268,7 @@ func (cg configGroup) handshakeMaintenance(name string, cert Certificate) (Certi
// usable. name should already be lower-cased before calling this function.
//
// This function is safe for use by multiple concurrent goroutines.
func (cg configGroup) renewDynamicCertificate(name string, cfg *Config) (Certificate, error) {
func (cg ConfigGroup) renewDynamicCertificate(name string, cfg *Config) (Certificate, error) {
obtainCertWaitChansMu.Lock()
wait, ok := obtainCertWaitChans[name]
if ok {
......
......@@ -9,7 +9,7 @@ import (
func TestGetCertificate(t *testing.T) {
defer func() { certCache = make(map[string]Certificate) }()
cg := make(configGroup)
cg := make(ConfigGroup)
hello := &tls.ClientHelloInfo{ServerName: "example.com"}
helloSub := &tls.ClientHelloInfo{ServerName: "sub.example.com"}
......
......@@ -164,6 +164,20 @@ func setupTLS(c *caddy.Controller) error {
return c.Errf("Unsupported Storage provider '%s'", args[0])
}
config.StorageProvider = args[0]
case "http2":
args := c.RemainingArgs()
if len(args) != 1 {
return c.ArgErr()
}
switch args[0] {
case "off":
config.DisableHTTP2 = true
default:
c.ArgErr()
}
case "muststaple":
config.MustStaple = true
default:
......
......@@ -91,6 +91,10 @@ func TestSetupParseBasic(t *testing.T) {
t.Error("Expected PreferServerCipherSuites = true, but was false")
}
if cfg.DisableHTTP2 {
t.Error("Expected HTTP2 to be enabled by default")
}
// Ensure curve count is correct
if len(cfg.CurvePreferences) != len(defaultCurves) {
t.Errorf("Expected %v Curves, got %v", len(defaultCurves), len(cfg.CurvePreferences))
......@@ -118,6 +122,7 @@ func TestSetupParseWithOptionalParams(t *testing.T) {
protocols tls1.0 tls1.2
ciphers RSA-AES256-CBC-SHA ECDHE-RSA-AES128-GCM-SHA256 ECDHE-ECDSA-AES256-GCM-SHA384
muststaple
http2 off
}`
cfg := new(Config)
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
......@@ -141,7 +146,11 @@ func TestSetupParseWithOptionalParams(t *testing.T) {
}
if !cfg.MustStaple {
t.Errorf("Expected must staple to be true")
t.Error("Expected must staple to be true")
}
if !cfg.DisableHTTP2 {
t.Error("Expected HTTP2 to be disabled")
}
}
......@@ -184,7 +193,7 @@ func TestSetupParseWithWrongOptionalParams(t *testing.T) {
c = caddy.NewTestController("", params)
err = setupTLS(c)
if err == nil {
t.Errorf("Expected errors, but no error returned")
t.Error("Expected errors, but no error returned")
}
// Test key_type wrong params
......@@ -196,7 +205,7 @@ func TestSetupParseWithWrongOptionalParams(t *testing.T) {
c = caddy.NewTestController("", params)
err = setupTLS(c)
if err == nil {
t.Errorf("Expected errors, but no error returned")
t.Error("Expected errors, but no error returned")
}
// Test curves wrong params
......@@ -208,7 +217,7 @@ func TestSetupParseWithWrongOptionalParams(t *testing.T) {
c = caddy.NewTestController("", params)
err = setupTLS(c)
if err == nil {
t.Errorf("Expected errors, but no error returned")
t.Error("Expected errors, but no error returned")
}
}
......@@ -222,7 +231,7 @@ func TestSetupParseWithClientAuth(t *testing.T) {
c := caddy.NewTestController("", params)
err := setupTLS(c)
if err == nil {
t.Errorf("Expected an error, but no error returned")
t.Error("Expected an error, but no error returned")
}
noCAs, twoCAs := []string{}, []string{"client_ca.crt", "client2_ca.crt"}
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment