Commit d05f8929 authored by Matthew Holt's avatar Matthew Holt

https: Minor refactoring and some new tests

parent 1ef7f3c4
...@@ -50,23 +50,21 @@ type Certificate struct { ...@@ -50,23 +50,21 @@ type Certificate struct {
OCSP *ocsp.Response OCSP *ocsp.Response
} }
// getCertificate gets a certificate from the in-memory cache that // getCertificate gets a certificate that matches name (a server name)
// matches name (a certificate name). Note that if name does not have // from the in-memory cache. If there is no exact match for name, it
// an exact match, it will be checked against names of the form // will be checked against names of the form '*.example.com' (wildcard
// '*.example.com' (wildcard certificates) according to RFC 6125. // certificates) according to RFC 6125. If a match is found, matched will
// // be true. If no matches are found, matched will be false and a default
// If cert was found by matching name, matched will be returned true. // certificate will be returned with defaulted set to true. If no default
// If no match is found, the default certificate will be returned and // certificate is set, defaulted will be set to false.
// matched will be returned as false. (The default certificate is the
// first one that entered the cache.) If the cache is empty (or there
// is no default certificate for some reason), matched will still be
// false, but cert.Certificate will be nil.
// //
// The logic in this function is adapted from the Go standard library, // The logic in this function is adapted from the Go standard library,
// which is by the Go Authors. // which is by the Go Authors.
// //
// This function is safe for concurrent use. // This function is safe for concurrent use.
func getCertificate(name string) (cert Certificate, matched bool) { func getCertificate(name string) (cert Certificate, matched, defaulted bool) {
var ok bool
// Not going to trim trailing dots here since RFC 3546 says, // Not going to trim trailing dots here since RFC 3546 says,
// "The hostname is represented ... without a trailing dot." // "The hostname is represented ... without a trailing dot."
// Just normalize to lowercase. // Just normalize to lowercase.
...@@ -76,8 +74,9 @@ func getCertificate(name string) (cert Certificate, matched bool) { ...@@ -76,8 +74,9 @@ func getCertificate(name string) (cert Certificate, matched bool) {
defer certCacheMu.RUnlock() defer certCacheMu.RUnlock()
// exact match? great, let's use it // exact match? great, let's use it
if cert, ok := certCache[name]; ok { if cert, ok = certCache[name]; ok {
return cert, true matched = true
return
} }
// try replacing labels in the name with wildcards until we get a match // try replacing labels in the name with wildcards until we get a match
...@@ -85,14 +84,15 @@ func getCertificate(name string) (cert Certificate, matched bool) { ...@@ -85,14 +84,15 @@ func getCertificate(name string) (cert Certificate, matched bool) {
for i := range labels { for i := range labels {
labels[i] = "*" labels[i] = "*"
candidate := strings.Join(labels, ".") candidate := strings.Join(labels, ".")
if cert, ok := certCache[candidate]; ok { if cert, ok = certCache[candidate]; ok {
return cert, true matched = true
return
} }
} }
// if nothing matches, return the default certificate // if nothing matches, use the default certificate or bust
cert = certCache[""] cert, defaulted = certCache[""]
return cert, false return
} }
// cacheManagedCertificate loads the certificate for domain into the // cacheManagedCertificate loads the certificate for domain into the
...@@ -214,8 +214,8 @@ func cacheCertificate(cert Certificate) { ...@@ -214,8 +214,8 @@ func cacheCertificate(cert Certificate) {
certCacheMu.Lock() certCacheMu.Lock()
if _, ok := certCache[""]; !ok { if _, ok := certCache[""]; !ok {
// use as default // use as default
certCache[""] = cert
cert.Names = append(cert.Names, "") cert.Names = append(cert.Names, "")
certCache[""] = cert
} }
for len(certCache)+len(cert.Names) > 10000 { for len(certCache)+len(cert.Names) > 10000 {
// for simplicity, just remove random elements // for simplicity, just remove random elements
......
package https
import "testing"
func TestUnexportedGetCertificate(t *testing.T) {
defer func() { certCache = make(map[string]Certificate) }()
// When cache is empty
if _, matched, defaulted := getCertificate("example.com"); matched || defaulted {
t.Errorf("Got a certificate when cache was empty; matched=%v, defaulted=%v", matched, defaulted)
}
// When cache has one certificate in it (also is default)
defaultCert := Certificate{Names: []string{"example.com", ""}}
certCache[""] = defaultCert
certCache["example.com"] = defaultCert
if cert, matched, defaulted := getCertificate("Example.com"); !matched || defaulted || cert.Names[0] != "example.com" {
t.Errorf("Didn't get a cert for 'Example.com' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted)
}
if cert, matched, defaulted := getCertificate(""); !matched || defaulted || cert.Names[0] != "example.com" {
t.Errorf("Didn't get a cert for '' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted)
}
// When retrieving wildcard certificate
certCache["*.example.com"] = Certificate{Names: []string{"*.example.com"}}
if cert, matched, defaulted := getCertificate("sub.example.com"); !matched || defaulted || cert.Names[0] != "*.example.com" {
t.Errorf("Didn't get wildcard cert for 'sub.example.com' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted)
}
// When no certificate matches, the default is returned
if cert, matched, defaulted := getCertificate("nomatch"); matched || !defaulted {
t.Errorf("Expected matched=false, defaulted=true; but got matched=%v, defaulted=%v (cert: %v)", matched, defaulted, cert)
} else if cert.Names[0] != "example.com" {
t.Errorf("Expected default cert, got: %v", cert)
}
}
func TestCacheCertificate(t *testing.T) {
defer func() { certCache = make(map[string]Certificate) }()
cacheCertificate(Certificate{Names: []string{"example.com", "sub.example.com"}})
if _, ok := certCache["example.com"]; !ok {
t.Error("Expected first cert to be cached by key 'example.com', but it wasn't")
}
if _, ok := certCache["sub.example.com"]; !ok {
t.Error("Expected first cert to be cached by key 'sub.exmaple.com', but it wasn't")
}
if cert, ok := certCache[""]; !ok || cert.Names[2] != "" {
t.Error("Expected first cert to be cached additionally as the default certificate with empty name added, but it wasn't")
}
cacheCertificate(Certificate{Names: []string{"example2.com"}})
if _, ok := certCache["example2.com"]; !ok {
t.Error("Expected second cert to be cached by key 'exmaple2.com', but it wasn't")
}
if cert, ok := certCache[""]; ok && cert.Names[0] == "example2.com" {
t.Error("Expected second cert to NOT be cached as default, but it was")
}
}
...@@ -39,31 +39,30 @@ func GetOrObtainCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, ...@@ -39,31 +39,30 @@ func GetOrObtainCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate,
} }
// getCertDuringHandshake will get a certificate for name. It first tries // getCertDuringHandshake will get a certificate for name. It first tries
// the in-memory cache. If no certificate for name is in the cach and if // the in-memory cache. If no certificate for name is in the cache and if
// loadIfNecessary == true, it goes to disk to load it into the cache and // loadIfNecessary == true, it goes to disk to load it into the cache and
// serve it. If it's not on disk and if obtainIfNecessary == true, the // serve it. If it's not on disk and if obtainIfNecessary == true, the
// certificate will be obtained from the CA, cached, and served. If // certificate will be obtained from the CA, cached, and served. If
// obtainIfNecessary is true, then loadIfNecessary must also be set to true. // obtainIfNecessary is true, then loadIfNecessary must also be set to true.
// An error will be returned if and only if no certificate is available.
// //
// This function is safe for concurrent use. // This function is safe for concurrent use.
func getCertDuringHandshake(name string, loadIfNecessary, obtainIfNecessary bool) (Certificate, error) { func getCertDuringHandshake(name string, loadIfNecessary, obtainIfNecessary bool) (Certificate, error) {
// First check our in-memory cache to see if we've already loaded it // First check our in-memory cache to see if we've already loaded it
cert, ok := getCertificate(name) cert, matched, defaulted := getCertificate(name)
if ok { if matched {
return cert, nil return cert, nil
} }
if loadIfNecessary { if loadIfNecessary {
var err error
// Then check to see if we have one on disk // Then check to see if we have one on disk
cert, err = cacheManagedCertificate(name, true) loadedCert, err := cacheManagedCertificate(name, true)
if err == nil { if err == nil {
cert, err = handshakeMaintenance(name, cert) loadedCert, err = handshakeMaintenance(name, loadedCert)
if err != nil { if err != nil {
log.Printf("[ERROR] Maintaining newly-loaded certificate for %s: %v", name, err) log.Printf("[ERROR] Maintaining newly-loaded certificate for %s: %v", name, err)
} }
return cert, nil return loadedCert, nil
} }
if obtainIfNecessary { if obtainIfNecessary {
...@@ -87,7 +86,11 @@ func getCertDuringHandshake(name string, loadIfNecessary, obtainIfNecessary bool ...@@ -87,7 +86,11 @@ func getCertDuringHandshake(name string, loadIfNecessary, obtainIfNecessary bool
} }
} }
return Certificate{}, nil if defaulted {
return cert, nil
}
return Certificate{}, errors.New("no certificate for " + name)
} }
// checkLimitsForObtainingNewCerts checks to see if name can be issued right // checkLimitsForObtainingNewCerts checks to see if name can be issued right
......
package https
import (
"crypto/tls"
"crypto/x509"
"testing"
)
func TestGetCertificate(t *testing.T) {
defer func() { certCache = make(map[string]Certificate) }()
hello := &tls.ClientHelloInfo{ServerName: "example.com"}
helloSub := &tls.ClientHelloInfo{ServerName: "sub.example.com"}
helloNoSNI := &tls.ClientHelloInfo{}
helloNoMatch := &tls.ClientHelloInfo{ServerName: "nomatch"}
// When cache is empty
if cert, err := GetCertificate(hello); err == nil {
t.Errorf("GetCertificate should return error when cache is empty, got: %v", cert)
}
if cert, err := GetCertificate(helloNoSNI); err == nil {
t.Errorf("GetCertificate should return error when cache is empty even if server name is blank, got: %v", cert)
}
// When cache has one certificate in it (also is default)
defaultCert := Certificate{Names: []string{"example.com", ""}, Certificate: tls.Certificate{Leaf: &x509.Certificate{DNSNames: []string{"example.com"}}}}
certCache[""] = defaultCert
certCache["example.com"] = defaultCert
if cert, err := GetCertificate(hello); err != nil {
t.Errorf("Got an error but shouldn't have, when cert exists in cache: %v", err)
} else if cert.Leaf.DNSNames[0] != "example.com" {
t.Errorf("Got wrong certificate with exact match; expected 'example.com', got: %v", cert)
}
if cert, err := GetCertificate(helloNoSNI); err != nil {
t.Errorf("Got an error with no SNI but shouldn't have, when cert exists in cache: %v", err)
} else if cert.Leaf.DNSNames[0] != "example.com" {
t.Errorf("Got wrong certificate for no SNI; expected 'example.com' as default, got: %v", cert)
}
// When retrieving wildcard certificate
certCache["*.example.com"] = Certificate{Names: []string{"*.example.com"}, Certificate: tls.Certificate{Leaf: &x509.Certificate{DNSNames: []string{"*.example.com"}}}}
if cert, err := GetCertificate(helloSub); err != nil {
t.Errorf("Didn't get wildcard cert, got: cert=%v, err=%v ", cert, err)
} else if cert.Leaf.DNSNames[0] != "*.example.com" {
t.Errorf("Got wrong certificate, expected wildcard: %v", cert)
}
// When no certificate matches, the default is returned
if cert, err := GetCertificate(helloNoMatch); err != nil {
t.Errorf("Expected default certificate with no error when no matches, got err: %v", err)
} else if cert.Leaf.DNSNames[0] != "example.com" {
t.Errorf("Expected default cert with no matches, got: %v", cert)
}
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment