Commit 95514da9 authored by Matt Holt's avatar Matt Holt Committed by GitHub

Merge pull request #2072 from mholt/acmev2

tls: Use ACMEv2 and support automatic wildcard certificates
parents a8dfa9f0 18ff8748
...@@ -27,7 +27,7 @@ import ( ...@@ -27,7 +27,7 @@ import (
"gopkg.in/natefinch/lumberjack.v2" "gopkg.in/natefinch/lumberjack.v2"
"github.com/xenolf/lego/acme" "github.com/xenolf/lego/acmev2"
"github.com/mholt/caddy" "github.com/mholt/caddy"
// plug in the HTTP server type // plug in the HTTP server type
...@@ -42,7 +42,7 @@ func init() { ...@@ -42,7 +42,7 @@ func init() {
setVersion() setVersion()
flag.BoolVar(&caddytls.Agreed, "agree", false, "Agree to the CA's Subscriber Agreement") flag.BoolVar(&caddytls.Agreed, "agree", false, "Agree to the CA's Subscriber Agreement")
flag.StringVar(&caddytls.DefaultCAUrl, "ca", "https://acme-v01.api.letsencrypt.org/directory", "URL to certificate authority's ACME server directory") flag.StringVar(&caddytls.DefaultCAUrl, "ca", "https://acme-v02.api.letsencrypt.org/directory", "URL to certificate authority's ACME server directory")
flag.BoolVar(&caddytls.DisableHTTPChallenge, "disable-http-challenge", caddytls.DisableHTTPChallenge, "Disable the ACME HTTP challenge") flag.BoolVar(&caddytls.DisableHTTPChallenge, "disable-http-challenge", caddytls.DisableHTTPChallenge, "Disable the ACME HTTP challenge")
flag.BoolVar(&caddytls.DisableTLSSNIChallenge, "disable-tls-sni-challenge", caddytls.DisableTLSSNIChallenge, "Disable the ACME TLS-SNI challenge") flag.BoolVar(&caddytls.DisableTLSSNIChallenge, "disable-tls-sni-challenge", caddytls.DisableTLSSNIChallenge, "Disable the ACME TLS-SNI challenge")
flag.StringVar(&conf, "conf", "", "Caddyfile to load (default \""+caddy.DefaultConfigFile+"\")") flag.StringVar(&conf, "conf", "", "Caddyfile to load (default \""+caddy.DefaultConfigFile+"\")")
......
...@@ -100,8 +100,8 @@ func enableAutoHTTPS(configs []*SiteConfig, loadCertificates bool) error { ...@@ -100,8 +100,8 @@ func enableAutoHTTPS(configs []*SiteConfig, loadCertificates bool) error {
} }
cfg.TLS.Enabled = true cfg.TLS.Enabled = true
cfg.Addr.Scheme = "https" cfg.Addr.Scheme = "https"
if loadCertificates && caddytls.HostQualifies(cfg.Addr.Host) { if loadCertificates && caddytls.HostQualifies(cfg.TLS.Hostname) {
_, err := cfg.TLS.CacheManagedCertificate(cfg.Addr.Host) _, err := cfg.TLS.CacheManagedCertificate(cfg.TLS.Hostname)
if err != nil { if err != nil {
return err return err
} }
......
...@@ -431,12 +431,26 @@ func (r *replacer) getSubstitution(key string) string { ...@@ -431,12 +431,26 @@ func (r *replacer) getSubstitution(key string) string {
return "UNKNOWN" // this should never happen, but guard in case return "UNKNOWN" // this should never happen, but guard in case
} }
return r.emptyValue return r.emptyValue
default:
// {labelN}
if strings.HasPrefix(key, "{label") {
nStr := key[6 : len(key)-1] // get the integer N in "{labelN}"
n, err := strconv.Atoi(nStr)
if err != nil || n < 1 {
return r.emptyValue
}
labels := strings.Split(r.request.Host, ".")
if n > len(labels) {
return r.emptyValue
}
return labels[n-1]
}
} }
return r.emptyValue return r.emptyValue
} }
//convertToMilliseconds returns the number of milliseconds in the given duration // convertToMilliseconds returns the number of milliseconds in the given duration
func convertToMilliseconds(d time.Duration) int64 { func convertToMilliseconds(d time.Duration) int64 {
return d.Nanoseconds() / 1e6 return d.Nanoseconds() / 1e6
} }
......
...@@ -53,7 +53,7 @@ func TestReplace(t *testing.T) { ...@@ -53,7 +53,7 @@ func TestReplace(t *testing.T) {
recordRequest := NewResponseRecorder(w) recordRequest := NewResponseRecorder(w)
reader := strings.NewReader(`{"username": "dennis"}`) reader := strings.NewReader(`{"username": "dennis"}`)
request, err := http.NewRequest("POST", "http://localhost/?foo=bar", reader) request, err := http.NewRequest("POST", "http://localhost.local/?foo=bar", reader)
if err != nil { if err != nil {
t.Fatalf("Failed to make request: %v", err) t.Fatalf("Failed to make request: %v", err)
} }
...@@ -87,7 +87,7 @@ func TestReplace(t *testing.T) { ...@@ -87,7 +87,7 @@ func TestReplace(t *testing.T) {
expect string expect string
}{ }{
{"This hostname is {hostname}", "This hostname is " + hostname}, {"This hostname is {hostname}", "This hostname is " + hostname},
{"This host is {host}.", "This host is localhost."}, {"This host is {host}.", "This host is localhost.local."},
{"This request method is {method}.", "This request method is POST."}, {"This request method is {method}.", "This request method is POST."},
{"The response status is {status}.", "The response status is 200."}, {"The response status is {status}.", "The response status is 200."},
{"{when}", "02/Jan/2006:15:04:05 +0000"}, {"{when}", "02/Jan/2006:15:04:05 +0000"},
...@@ -97,7 +97,7 @@ func TestReplace(t *testing.T) { ...@@ -97,7 +97,7 @@ func TestReplace(t *testing.T) {
{"The CustomAdd header is {>CustomAdd}.", "The CustomAdd header is caddy."}, {"The CustomAdd header is {>CustomAdd}.", "The CustomAdd header is caddy."},
{"The Custom response header is {<Custom}.", "The Custom response header is CustomResponseHeader."}, {"The Custom response header is {<Custom}.", "The Custom response header is CustomResponseHeader."},
{"Bad {>Custom placeholder", "Bad {>Custom placeholder"}, {"Bad {>Custom placeholder", "Bad {>Custom placeholder"},
{"The request is {request}.", "The request is POST /?foo=bar HTTP/1.1\\r\\nHost: localhost\\r\\n" + {"The request is {request}.", "The request is POST /?foo=bar HTTP/1.1\\r\\nHost: localhost.local\\r\\n" +
"Cookie: foo=bar; taste=delicious\\r\\nCustom: foobarbaz\\r\\nCustomadd: caddy\\r\\n" + "Cookie: foo=bar; taste=delicious\\r\\nCustom: foobarbaz\\r\\nCustomadd: caddy\\r\\n" +
"Shorterval: 1\\r\\n\\r\\n."}, "Shorterval: 1\\r\\n\\r\\n."},
{"The cUsToM header is {>cUsToM}...", "The cUsToM header is foobarbaz..."}, {"The cUsToM header is {>cUsToM}...", "The cUsToM header is foobarbaz..."},
...@@ -112,6 +112,8 @@ func TestReplace(t *testing.T) { ...@@ -112,6 +112,8 @@ func TestReplace(t *testing.T) {
{"Query string is {query}", "Query string is foo=bar"}, {"Query string is {query}", "Query string is foo=bar"},
{"Query string value for foo is {?foo}", "Query string value for foo is bar"}, {"Query string value for foo is {?foo}", "Query string value for foo is bar"},
{"Missing query string argument is {?missing}", "Missing query string argument is "}, {"Missing query string argument is {?missing}", "Missing query string argument is "},
{"{label1} {label2} {label3} {label4}", "localhost local - -"},
{"Label with missing number is {label} or {labelQQ}", "Label with missing number is - or -"},
{"\\{ 'hostname': '{hostname}' \\}", "{ 'hostname': '" + hostname + "' }"}, {"\\{ 'hostname': '{hostname}' \\}", "{ 'hostname': '" + hostname + "' }"},
} }
......
...@@ -26,7 +26,7 @@ import ( ...@@ -26,7 +26,7 @@ import (
"time" "time"
"github.com/mholt/caddy" "github.com/mholt/caddy"
"github.com/xenolf/lego/acme" "github.com/xenolf/lego/acmev2"
) )
// acmeMu ensures that only one ACME challenge occurs at a time. // acmeMu ensures that only one ACME challenge occurs at a time.
...@@ -89,26 +89,21 @@ var newACMEClient = func(config *Config, allowPrompts bool) (*ACMEClient, error) ...@@ -89,26 +89,21 @@ var newACMEClient = func(config *Config, allowPrompts bool) (*ACMEClient, error)
// If not registered, the user must register an account with the CA // If not registered, the user must register an account with the CA
// and agree to terms // and agree to terms
if leUser.Registration == nil { if leUser.Registration == nil {
reg, err := client.Register()
if err != nil {
return nil, errors.New("registration error: " + err.Error())
}
leUser.Registration = reg
if allowPrompts { // can't prompt a user who isn't there if allowPrompts { // can't prompt a user who isn't there
if !Agreed && reg.TosURL == "" { termsURL := client.GetToSURL()
Agreed = promptUserAgreement(saURL, false) // TODO - latest URL if !Agreed && termsURL != "" {
Agreed = askUserAgreement(client.GetToSURL())
} }
if !Agreed && reg.TosURL == "" { if !Agreed && termsURL != "" {
return nil, errors.New("user must agree to terms") return nil, errors.New("user must agree to CA terms (use -agree flag)")
} }
} }
err = client.AgreeToTOS() reg, err := client.Register(Agreed)
if err != nil { if err != nil {
saveUser(storage, leUser) // Might as well try, right? return nil, errors.New("registration error: " + err.Error())
return nil, errors.New("error agreeing to terms: " + err.Error())
} }
leUser.Registration = reg
// save user to the file system // save user to the file system
err = saveUser(storage, leUser) err = saveUser(storage, leUser)
...@@ -136,38 +131,57 @@ var newACMEClient = func(config *Config, allowPrompts bool) (*ACMEClient, error) ...@@ -136,38 +131,57 @@ var newACMEClient = func(config *Config, allowPrompts bool) (*ACMEClient, error)
useHTTPPort = DefaultHTTPAlternatePort useHTTPPort = DefaultHTTPAlternatePort
} }
// TODO: tls-sni challenge was removed in January 2018, but a variant of it might return
// See which port TLS-SNI challenges will be accomplished on // See which port TLS-SNI challenges will be accomplished on
useTLSSNIPort := TLSSNIChallengePort // useTLSSNIPort := TLSSNIChallengePort
if config.AltTLSSNIPort != "" { // if config.AltTLSSNIPort != "" {
useTLSSNIPort = config.AltTLSSNIPort // useTLSSNIPort = config.AltTLSSNIPort
} // }
// err := c.acmeClient.SetTLSAddress(net.JoinHostPort(config.ListenHost, useTLSSNIPort))
// if err != nil {
// return nil, err
// }
// if using file storage, we can distribute the HTTP challenge across
// all instances sharing the acme folder; either way, we must still set
// the address for the default HTTP provider server
var useDistributedHTTPSolver bool
if storage, err := c.config.StorageFor(c.config.CAUrl); err == nil {
if _, ok := storage.(*FileStorage); ok {
useDistributedHTTPSolver = true
}
}
if useDistributedHTTPSolver {
c.acmeClient.SetChallengeProvider(acme.HTTP01, distributedHTTPSolver{
// being careful to respect user's listener bind preferences
httpProviderServer: acme.NewHTTPProviderServer(config.ListenHost, useHTTPPort),
})
} else {
// Always respect user's bind preferences by using config.ListenHost. // Always respect user's bind preferences by using config.ListenHost.
// NOTE(Sep'16): At time of writing, SetHTTPAddress() and SetTLSAddress() // NOTE(Sep'16): At time of writing, SetHTTPAddress() and SetTLSAddress()
// must be called before SetChallengeProvider(), since they reset the // must be called before SetChallengeProvider() (see above), since they reset
// challenge provider back to the default one! // the challenge provider back to the default one! (still true in March 2018)
err := c.acmeClient.SetHTTPAddress(net.JoinHostPort(config.ListenHost, useHTTPPort)) err := c.acmeClient.SetHTTPAddress(net.JoinHostPort(config.ListenHost, useHTTPPort))
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = c.acmeClient.SetTLSAddress(net.JoinHostPort(config.ListenHost, useTLSSNIPort))
if err != nil {
return nil, err
} }
// TODO: tls-sni challenge was removed in January 2018, but a variant of it might return
// See if TLS challenge needs to be handled by our own facilities // See if TLS challenge needs to be handled by our own facilities
if caddy.HasListenerWithAddress(net.JoinHostPort(config.ListenHost, useTLSSNIPort)) { // if caddy.HasListenerWithAddress(net.JoinHostPort(config.ListenHost, useTLSSNIPort)) {
c.acmeClient.SetChallengeProvider(acme.TLSSNI01, tlsSNISolver{certCache: config.certCache}) // c.acmeClient.SetChallengeProvider(acme.TLSSNI01, tlsSNISolver{certCache: config.certCache})
} // }
// Disable any challenges that should not be used // Disable any challenges that should not be used
var disabledChallenges []acme.Challenge var disabledChallenges []acme.Challenge
if DisableHTTPChallenge { if DisableHTTPChallenge {
disabledChallenges = append(disabledChallenges, acme.HTTP01) disabledChallenges = append(disabledChallenges, acme.HTTP01)
} }
if DisableTLSSNIChallenge { // TODO: tls-sni challenge was removed in January 2018, but a variant of it might return
disabledChallenges = append(disabledChallenges, acme.TLSSNI01) // if DisableTLSSNIChallenge {
} // disabledChallenges = append(disabledChallenges, acme.TLSSNI01)
// }
if len(disabledChallenges) > 0 { if len(disabledChallenges) > 0 {
c.acmeClient.ExcludeChallenges(disabledChallenges) c.acmeClient.ExcludeChallenges(disabledChallenges)
} }
...@@ -188,7 +202,9 @@ var newACMEClient = func(config *Config, allowPrompts bool) (*ACMEClient, error) ...@@ -188,7 +202,9 @@ var newACMEClient = func(config *Config, allowPrompts bool) (*ACMEClient, error)
} }
// Use the DNS challenge exclusively // Use the DNS challenge exclusively
c.acmeClient.ExcludeChallenges([]acme.Challenge{acme.HTTP01, acme.TLSSNI01}) // TODO: tls-sni challenge was removed in January 2018, but a variant of it might return
// c.acmeClient.ExcludeChallenges([]acme.Challenge{acme.HTTP01, acme.TLSSNI01})
c.acmeClient.ExcludeChallenges([]acme.Challenge{acme.HTTP01})
c.acmeClient.SetChallengeProvider(acme.DNS01, prov) c.acmeClient.SetChallengeProvider(acme.DNS01, prov)
} }
...@@ -221,7 +237,6 @@ func (c *ACMEClient) Obtain(name string) error { ...@@ -221,7 +237,6 @@ func (c *ACMEClient) Obtain(name string) error {
} }
}() }()
Attempts:
for attempts := 0; attempts < 2; attempts++ { for attempts := 0; attempts < 2; attempts++ {
namesObtaining.Add([]string{name}) namesObtaining.Add([]string{name})
acmeMu.Lock() acmeMu.Lock()
...@@ -230,31 +245,15 @@ Attempts: ...@@ -230,31 +245,15 @@ Attempts:
namesObtaining.Remove([]string{name}) namesObtaining.Remove([]string{name})
if len(failures) > 0 { if len(failures) > 0 {
// Error - try to fix it or report it to the user and abort // Error - try to fix it or report it to the user and abort
var errMsg string // we'll combine all the failures into a single error message
var promptedForAgreement bool // only prompt user for agreement at most once
var errMsg string // combine all the failures into a single error message
for errDomain, obtainErr := range failures { for errDomain, obtainErr := range failures {
if obtainErr == nil { if obtainErr == nil {
continue continue
} }
if tosErr, ok := obtainErr.(acme.TOSError); ok { errMsg += fmt.Sprintf("[%s] failed to get certificate: %v\n", errDomain, obtainErr)
// Terms of Service agreement error; we can probably deal with this
if !Agreed && !promptedForAgreement && c.AllowPrompts {
Agreed = promptUserAgreement(tosErr.Detail, true) // TODO: Use latest URL
promptedForAgreement = true
}
if Agreed || !c.AllowPrompts {
err := c.acmeClient.AgreeToTOS()
if err != nil {
return errors.New("error agreeing to updated terms: " + err.Error())
}
continue Attempts
}
} }
// If user did not agree or it was any other kind of error, just append to the list of errors
errMsg += "[" + errDomain + "] failed to get certificate: " + obtainErr.Error() + "\n"
}
return errors.New(errMsg) return errors.New(errMsg)
} }
...@@ -316,19 +315,9 @@ func (c *ACMEClient) Renew(name string) error { ...@@ -316,19 +315,9 @@ func (c *ACMEClient) Renew(name string) error {
break break
} }
// If the legal terms were updated and need to be // wait a little bit and try again
// agreed to again, we can handle that.
if _, ok := err.(acme.TOSError); ok {
err := c.acmeClient.AgreeToTOS()
if err != nil {
return err
}
continue
}
// For any other kind of error, wait 10s and try again.
wait := 10 * time.Second wait := 10 * time.Second
log.Printf("[ERROR] Renewing: %v; trying again in %s", err, wait) log.Printf("[ERROR] Renewing [%v]: %v; trying again in %s", name, err, wait)
time.Sleep(wait) time.Sleep(wait)
} }
......
...@@ -25,7 +25,7 @@ import ( ...@@ -25,7 +25,7 @@ import (
"github.com/codahale/aesnicheck" "github.com/codahale/aesnicheck"
"github.com/mholt/caddy" "github.com/mholt/caddy"
"github.com/xenolf/lego/acme" "github.com/xenolf/lego/acmev2"
) )
// Config describes how TLS should be configured and used. // Config describes how TLS should be configured and used.
...@@ -190,10 +190,15 @@ func NewConfig(inst *caddy.Instance) *Config { ...@@ -190,10 +190,15 @@ func NewConfig(inst *caddy.Instance) *Config {
// it does not load them into memory. If allowPrompts is true, // it does not load them into memory. If allowPrompts is true,
// the user may be shown a prompt. // the user may be shown a prompt.
func (c *Config) ObtainCert(name string, allowPrompts bool) error { func (c *Config) ObtainCert(name string, allowPrompts bool) error {
if !c.Managed || !HostQualifies(name) { skip, err := c.preObtainOrRenewChecks(name, allowPrompts)
if err != nil {
return err
}
if skip {
return nil return nil
} }
// we expect this to be a new (non-existent) site
storage, err := c.StorageFor(c.CAUrl) storage, err := c.StorageFor(c.CAUrl)
if err != nil { if err != nil {
return err return err
...@@ -205,9 +210,6 @@ func (c *Config) ObtainCert(name string, allowPrompts bool) error { ...@@ -205,9 +210,6 @@ func (c *Config) ObtainCert(name string, allowPrompts bool) error {
if siteExists { if siteExists {
return nil return nil
} }
if c.ACMEEmail == "" {
c.ACMEEmail = getEmail(storage, allowPrompts)
}
client, err := newACMEClient(c, allowPrompts) client, err := newACMEClient(c, allowPrompts)
if err != nil { if err != nil {
...@@ -219,6 +221,14 @@ func (c *Config) ObtainCert(name string, allowPrompts bool) error { ...@@ -219,6 +221,14 @@ func (c *Config) ObtainCert(name string, allowPrompts bool) error {
// RenewCert renews the certificate for name using c. It stows the // RenewCert renews the certificate for name using c. It stows the
// renewed certificate and its assets in storage if successful. // renewed certificate and its assets in storage if successful.
func (c *Config) RenewCert(name string, allowPrompts bool) error { func (c *Config) RenewCert(name string, allowPrompts bool) error {
skip, err := c.preObtainOrRenewChecks(name, allowPrompts)
if err != nil {
return err
}
if skip {
return nil
}
client, err := newACMEClient(c, allowPrompts) client, err := newACMEClient(c, allowPrompts)
if err != nil { if err != nil {
return err return err
...@@ -226,6 +236,33 @@ func (c *Config) RenewCert(name string, allowPrompts bool) error { ...@@ -226,6 +236,33 @@ func (c *Config) RenewCert(name string, allowPrompts bool) error {
return client.Renew(name) return client.Renew(name)
} }
// preObtainOrRenewChecks perform a few simple checks before
// obtaining or renewing a certificate with ACME, and returns
// whether this name should be skipped (like if it's not
// managed TLS) as well as any error. It ensures that the
// config is Managed, that the name qualifies for a certificate,
// and that an email address is available.
func (c *Config) preObtainOrRenewChecks(name string, allowPrompts bool) (bool, error) {
if !c.Managed || !HostQualifies(name) {
return true, nil
}
// wildcard certificates require DNS challenge (as of March 2018)
if strings.Contains(name, "*") && c.DNSProvider == "" {
return false, fmt.Errorf("wildcard domain name (%s) requires DNS challenge; use dns subdirective to configure it", name)
}
if c.ACMEEmail == "" {
var err error
c.ACMEEmail, err = getEmail(c, allowPrompts)
if err != nil {
return false, err
}
}
return false, nil
}
// StorageFor obtains a TLS Storage instance for the given CA URL which should // StorageFor obtains a TLS Storage instance for the given CA URL which should
// be unique for every different ACME CA. If a StorageCreator is set on this // be unique for every different ACME CA. If a StorageCreator is set on this
// Config, it will be used. Otherwise the default file storage implementation // Config, it will be used. Otherwise the default file storage implementation
......
...@@ -42,7 +42,7 @@ import ( ...@@ -42,7 +42,7 @@ import (
"golang.org/x/crypto/ocsp" "golang.org/x/crypto/ocsp"
"github.com/mholt/caddy" "github.com/mholt/caddy"
"github.com/xenolf/lego/acme" "github.com/xenolf/lego/acmev2"
) )
// loadPrivateKey loads a PEM-encoded ECC/RSA private key from an array of bytes. // loadPrivateKey loads a PEM-encoded ECC/RSA private key from an array of bytes.
...@@ -107,7 +107,8 @@ func stapleOCSP(cert *Certificate, pemBundle []byte) error { ...@@ -107,7 +107,8 @@ func stapleOCSP(cert *Certificate, pemBundle []byte) error {
// TODO: Use Storage interface instead of disk directly // TODO: Use Storage interface instead of disk directly
var ocspFileNamePrefix string var ocspFileNamePrefix string
if len(cert.Names) > 0 { if len(cert.Names) > 0 {
ocspFileNamePrefix = cert.Names[0] + "-" firstName := strings.Replace(cert.Names[0], "*", "wildcard_", -1)
ocspFileNamePrefix = firstName + "-"
} }
ocspFileName := ocspFileNamePrefix + fastHash(pemBundle) ocspFileName := ocspFileNamePrefix + fastHash(pemBundle)
ocspCachePath := filepath.Join(ocspFolder, ocspFileName) ocspCachePath := filepath.Join(ocspFolder, ocspFileName)
......
...@@ -30,14 +30,14 @@ func init() { ...@@ -30,14 +30,14 @@ func init() {
RegisterStorageProvider("file", NewFileStorage) RegisterStorageProvider("file", NewFileStorage)
} }
// storageBasePath is the root path in which all TLS/ACME assets are
// stored. Do not change this value during the lifetime of the program.
var storageBasePath = filepath.Join(caddy.AssetsPath(), "acme")
// NewFileStorage is a StorageConstructor function that creates a new // NewFileStorage is a StorageConstructor function that creates a new
// Storage instance backed by the local disk. The resulting Storage // Storage instance backed by the local disk. The resulting Storage
// instance is guaranteed to be non-nil if there is no error. // instance is guaranteed to be non-nil if there is no error.
func NewFileStorage(caURL *url.URL) (Storage, error) { func NewFileStorage(caURL *url.URL) (Storage, error) {
// storageBasePath is the root path in which all TLS/ACME assets are
// stored. Do not change this value during the lifetime of the program.
storageBasePath := filepath.Join(caddy.AssetsPath(), "acme")
storage := &FileStorage{Path: filepath.Join(storageBasePath, caURL.Host)} storage := &FileStorage{Path: filepath.Join(storageBasePath, caURL.Host)}
storage.Locker = &fileStorageLock{caURL: caURL.Host, storage: storage} storage.Locker = &fileStorageLock{caURL: caURL.Host, storage: storage}
return storage, nil return storage, nil
...@@ -58,24 +58,29 @@ func (s *FileStorage) sites() string { ...@@ -58,24 +58,29 @@ func (s *FileStorage) sites() string {
// site returns the path to the folder containing assets for domain. // site returns the path to the folder containing assets for domain.
func (s *FileStorage) site(domain string) string { func (s *FileStorage) site(domain string) string {
// Windows doesn't allow * in filenames, sigh...
domain = strings.Replace(domain, "*", "wildcard_", -1)
domain = strings.ToLower(domain) domain = strings.ToLower(domain)
return filepath.Join(s.sites(), domain) return filepath.Join(s.sites(), domain)
} }
// siteCertFile returns the path to the certificate file for domain. // siteCertFile returns the path to the certificate file for domain.
func (s *FileStorage) siteCertFile(domain string) string { func (s *FileStorage) siteCertFile(domain string) string {
domain = strings.Replace(domain, "*", "wildcard_", -1)
domain = strings.ToLower(domain) domain = strings.ToLower(domain)
return filepath.Join(s.site(domain), domain+".crt") return filepath.Join(s.site(domain), domain+".crt")
} }
// siteKeyFile returns the path to domain's private key file. // siteKeyFile returns the path to domain's private key file.
func (s *FileStorage) siteKeyFile(domain string) string { func (s *FileStorage) siteKeyFile(domain string) string {
domain = strings.Replace(domain, "*", "wildcard_", -1)
domain = strings.ToLower(domain) domain = strings.ToLower(domain)
return filepath.Join(s.site(domain), domain+".key") return filepath.Join(s.site(domain), domain+".key")
} }
// siteMetaFile returns the path to the domain's asset metadata file. // siteMetaFile returns the path to the domain's asset metadata file.
func (s *FileStorage) siteMetaFile(domain string) string { func (s *FileStorage) siteMetaFile(domain string) string {
domain = strings.Replace(domain, "*", "wildcard_", -1)
domain = strings.ToLower(domain) domain = strings.ToLower(domain)
return filepath.Join(s.site(domain), domain+".json") return filepath.Join(s.site(domain), domain+".json")
} }
......
...@@ -16,12 +16,16 @@ package caddytls ...@@ -16,12 +16,16 @@ package caddytls
import ( import (
"crypto/tls" "crypto/tls"
"encoding/json"
"fmt" "fmt"
"log" "log"
"net/http" "net/http"
"net/http/httputil" "net/http/httputil"
"net/url" "net/url"
"os"
"strings" "strings"
"github.com/xenolf/lego/acmev2"
) )
const challengeBasePath = "/.well-known/acme-challenge" const challengeBasePath = "/.well-known/acme-challenge"
...@@ -38,6 +42,13 @@ func HTTPChallengeHandler(w http.ResponseWriter, r *http.Request, listenHost str ...@@ -38,6 +42,13 @@ func HTTPChallengeHandler(w http.ResponseWriter, r *http.Request, listenHost str
if DisableHTTPChallenge { if DisableHTTPChallenge {
return false return false
} }
// see if another instance started the HTTP challenge for this name
if tryDistributedChallengeSolver(w, r) {
return true
}
// otherwise, if we aren't getting the name, then ignore this challenge
if !namesObtaining.Has(r.Host) { if !namesObtaining.Has(r.Host) {
return false return false
} }
...@@ -70,3 +81,40 @@ func HTTPChallengeHandler(w http.ResponseWriter, r *http.Request, listenHost str ...@@ -70,3 +81,40 @@ func HTTPChallengeHandler(w http.ResponseWriter, r *http.Request, listenHost str
return true return true
} }
// tryDistributedChallengeSolver checks to see if this challenge
// request was initiated by another instance that shares file
// storage, and attempts to complete the challenge for it. It
// returns true if the challenge was handled; false otherwise.
func tryDistributedChallengeSolver(w http.ResponseWriter, r *http.Request) bool {
filePath := distributedHTTPSolver{}.challengeTokensPath(r.Host)
f, err := os.Open(filePath)
if err != nil {
if !os.IsNotExist(err) {
log.Printf("[ERROR][%s] Opening distributed challenge token file: %v", r.Host, err)
}
return false
}
defer f.Close()
var chalInfo challengeInfo
err = json.NewDecoder(f).Decode(&chalInfo)
if err != nil {
log.Printf("[ERROR][%s] Decoding challenge token file %s (corrupted?): %v", r.Host, filePath, err)
return false
}
// this part borrowed from xenolf/lego's built-in HTTP-01 challenge solver (March 2018)
challengeReqPath := acme.HTTP01ChallengePath(chalInfo.Token)
if r.URL.Path == challengeReqPath &&
strings.HasPrefix(r.Host, chalInfo.Domain) &&
r.Method == "GET" {
w.Header().Add("Content-Type", "text/plain")
w.Write([]byte(chalInfo.KeyAuth))
r.Close = true
log.Printf("[INFO][%s] Served key authentication", chalInfo.Domain)
return true
}
return false
}
...@@ -207,8 +207,21 @@ func setupTLS(c *caddy.Controller) error { ...@@ -207,8 +207,21 @@ func setupTLS(c *caddy.Controller) error {
} }
case "must_staple": case "must_staple":
config.MustStaple = true config.MustStaple = true
case "wildcard":
if !HostQualifies(config.Hostname) {
return c.Errf("Hostname '%s' does not qualify for managed TLS, so cannot manage wildcard certificate for it", config.Hostname)
}
if strings.Contains(config.Hostname, "*") {
return c.Errf("Cannot convert domain name '%s' to a valid wildcard: already has a wildcard label", config.Hostname)
}
parts := strings.Split(config.Hostname, ".")
if len(parts) < 3 {
return c.Errf("Cannot convert domain name '%s' to a valid wildcard: too few labels", config.Hostname)
}
parts[0] = "*"
config.Hostname = strings.Join(parts, ".")
default: default:
return c.Errf("Unknown keyword '%s'", c.Val()) return c.Errf("Unknown subdirective '%s'", c.Val())
} }
} }
......
...@@ -22,7 +22,7 @@ import ( ...@@ -22,7 +22,7 @@ import (
"testing" "testing"
"github.com/mholt/caddy" "github.com/mholt/caddy"
"github.com/xenolf/lego/acme" "github.com/xenolf/lego/acmev2"
) )
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
......
...@@ -30,26 +30,35 @@ package caddytls ...@@ -30,26 +30,35 @@ package caddytls
import ( import (
"encoding/json" "encoding/json"
"fmt"
"io/ioutil"
"log"
"net" "net"
"os"
"path/filepath"
"strings" "strings"
"github.com/mholt/caddy" "github.com/mholt/caddy"
"github.com/xenolf/lego/acme" "github.com/xenolf/lego/acmev2"
) )
// HostQualifies returns true if the hostname alone // HostQualifies returns true if the hostname alone
// appears eligible for automatic HTTPS. For example, // appears eligible for automatic HTTPS. For example:
// localhost, empty hostname, and IP addresses are // localhost, empty hostname, and IP addresses are
// not eligible because we cannot obtain certificates // not eligible because we cannot obtain certificates
// for those names. // for those names. Wildcard names are allowed, as long
// as they conform to CABF requirements (only one wildcard
// label, and it must be the left-most label).
func HostQualifies(hostname string) bool { func HostQualifies(hostname string) bool {
return hostname != "localhost" && // localhost is ineligible return hostname != "localhost" && // localhost is ineligible
// hostname must not be empty // hostname must not be empty
strings.TrimSpace(hostname) != "" && strings.TrimSpace(hostname) != "" &&
// must not contain wildcard (*) characters (until CA supports it) // only one wildcard label allowed, and it must be left-most
!strings.Contains(hostname, "*") && (!strings.Contains(hostname, "*") ||
(strings.Count(hostname, "*") == 1 &&
strings.HasPrefix(hostname, "*."))) &&
// must not start or end with a dot // must not start or end with a dot
!strings.HasPrefix(hostname, ".") && !strings.HasPrefix(hostname, ".") &&
...@@ -88,39 +97,125 @@ func Revoke(host string) error { ...@@ -88,39 +97,125 @@ func Revoke(host string) error {
return client.Revoke(host) return client.Revoke(host)
} }
// tlsSNISolver is a type that can solve TLS-SNI challenges using // TODO: tls-sni challenge was removed in January 2018, but a variant of it might return
// an existing listener and our custom, in-memory certificate cache. // // tlsSNISolver is a type that can solve TLS-SNI challenges using
type tlsSNISolver struct { // // an existing listener and our custom, in-memory certificate cache.
certCache *certificateCache // type tlsSNISolver struct {
// certCache *certificateCache
// }
// // Present adds the challenge certificate to the cache.
// func (s tlsSNISolver) Present(domain, token, keyAuth string) error {
// cert, acmeDomain, err := acme.TLSSNI01ChallengeCert(keyAuth)
// if err != nil {
// return err
// }
// certHash := hashCertificateChain(cert.Certificate)
// s.certCache.Lock()
// s.certCache.cache[acmeDomain] = Certificate{
// Certificate: cert,
// Names: []string{acmeDomain},
// Hash: certHash, // perhaps not necesssary
// }
// s.certCache.Unlock()
// return nil
// }
// // CleanUp removes the challenge certificate from the cache.
// func (s tlsSNISolver) CleanUp(domain, token, keyAuth string) error {
// _, acmeDomain, err := acme.TLSSNI01ChallengeCert(keyAuth)
// if err != nil {
// return err
// }
// s.certCache.Lock()
// delete(s.certCache.cache, acmeDomain)
// s.certCache.Unlock()
// return nil
// }
// distributedHTTPSolver allows the HTTP-01 challenge to be solved by
// an instance other than the one which initiated it. This is useful
// behind load balancers or in other cluster/fleet configurations.
// The only requirement is that this (the initiating) instance share
// the $CADDYPATH/acme folder with the instance that will complete
// the challenge. Mounting the folder locally should be sufficient.
//
// Obviously, the instance which completes the challenge must be
// serving on the HTTPChallengePort to receive and handle the request.
// The HTTP server which receives it must check if a file exists, e.g.:
// $CADDYPATH/acme/challenge_tokens/example.com.json, and if so,
// decode it and use it to serve up the correct response. Caddy's HTTP
// server does this by default.
//
// So as long as the folder is shared, this will just work. There are
// no other requirements. The instances may be on other machines or
// even other networks, as long as they share the folder as part of
// the local file system.
//
// This solver works by persisting the token and keyauth information
// to disk in the shared folder when the authorization is presented,
// and then deletes it when it is cleaned up.
type distributedHTTPSolver struct {
// The distributed HTTPS solver only works if an instance (either
// this one or another one) is already listening and serving on the
// HTTPChallengePort. If not -- for example: if this is the only
// instance, and it is just starting up and hasn't started serving
// yet -- then we still need a listener open with an HTTP server
// to handle the challenge request. Set this field to have the
// standard HTTPProviderServer open its listener for the duration
// of the challenge. Make sure to configure its listen address
// correctly.
httpProviderServer *acme.HTTPProviderServer
}
type challengeInfo struct {
Domain, Token, KeyAuth string
} }
// Present adds the challenge certificate to the cache. // Present adds the challenge certificate to the cache.
func (s tlsSNISolver) Present(domain, token, keyAuth string) error { func (dhs distributedHTTPSolver) Present(domain, token, keyAuth string) error {
cert, acmeDomain, err := acme.TLSSNI01ChallengeCert(keyAuth) if dhs.httpProviderServer != nil {
err := dhs.httpProviderServer.Present(domain, token, keyAuth)
if err != nil {
return fmt.Errorf("presenting with standard HTTP provider server: %v", err)
}
}
err := os.MkdirAll(dhs.challengeTokensBasePath(), 0755)
if err != nil { if err != nil {
return err return err
} }
certHash := hashCertificateChain(cert.Certificate)
s.certCache.Lock() infoBytes, err := json.Marshal(challengeInfo{
s.certCache.cache[acmeDomain] = Certificate{ Domain: domain,
Certificate: cert, Token: token,
Names: []string{acmeDomain}, KeyAuth: keyAuth,
Hash: certHash, // perhaps not necesssary })
if err != nil {
return err
} }
s.certCache.Unlock()
return nil return ioutil.WriteFile(dhs.challengeTokensPath(domain), infoBytes, 0644)
} }
// CleanUp removes the challenge certificate from the cache. // CleanUp removes the challenge certificate from the cache.
func (s tlsSNISolver) CleanUp(domain, token, keyAuth string) error { func (dhs distributedHTTPSolver) CleanUp(domain, token, keyAuth string) error {
_, acmeDomain, err := acme.TLSSNI01ChallengeCert(keyAuth) if dhs.httpProviderServer != nil {
err := dhs.httpProviderServer.CleanUp(domain, token, keyAuth)
if err != nil { if err != nil {
return err log.Printf("[ERROR] Cleaning up standard HTTP provider server: %v", err)
} }
s.certCache.Lock() }
delete(s.certCache.cache, acmeDomain) return os.Remove(dhs.challengeTokensPath(domain))
s.certCache.Unlock() }
return nil
func (dhs distributedHTTPSolver) challengeTokensPath(domain string) string {
domainFile := strings.Replace(strings.ToLower(domain), "*", "wildcard_", -1)
return filepath.Join(dhs.challengeTokensBasePath(), domainFile+".json")
}
func (dhs distributedHTTPSolver) challengeTokensBasePath() string {
return filepath.Join(caddy.AssetsPath(), "acme", "challenge_tokens")
} }
// ConfigHolder is any type that has a Config; it presumably is // ConfigHolder is any type that has a Config; it presumably is
......
...@@ -18,7 +18,7 @@ import ( ...@@ -18,7 +18,7 @@ import (
"os" "os"
"testing" "testing"
"github.com/xenolf/lego/acme" "github.com/xenolf/lego/acmev2"
) )
func TestHostQualifies(t *testing.T) { func TestHostQualifies(t *testing.T) {
...@@ -37,7 +37,10 @@ func TestHostQualifies(t *testing.T) { ...@@ -37,7 +37,10 @@ func TestHostQualifies(t *testing.T) {
{"0.0.0.0", false}, {"0.0.0.0", false},
{"", false}, {"", false},
{" ", false}, {" ", false},
{"*.example.com", false}, {"*.example.com", true},
{"*.*.example.com", false},
{"sub.*.example.com", false},
{"*sub.example.com", false},
{".com", false}, {".com", false},
{"example.com.", false}, {"example.com.", false},
{"localhost", false}, {"localhost", false},
...@@ -77,7 +80,10 @@ func TestQualifiesForManagedTLS(t *testing.T) { ...@@ -77,7 +80,10 @@ func TestQualifiesForManagedTLS(t *testing.T) {
{holder{host: "localhost", cfg: new(Config)}, false}, {holder{host: "localhost", cfg: new(Config)}, false},
{holder{host: "123.44.3.21", cfg: new(Config)}, false}, {holder{host: "123.44.3.21", cfg: new(Config)}, false},
{holder{host: "example.com", cfg: new(Config)}, true}, {holder{host: "example.com", cfg: new(Config)}, true},
{holder{host: "*.example.com", cfg: new(Config)}, false}, {holder{host: "*.example.com", cfg: new(Config)}, true},
{holder{host: "*.*.example.com", cfg: new(Config)}, false},
{holder{host: "*sub.example.com", cfg: new(Config)}, false},
{holder{host: "sub.*.example.com", cfg: new(Config)}, false},
{holder{host: "example.com", cfg: &Config{Manual: true}}, false}, {holder{host: "example.com", cfg: &Config{Manual: true}}, false},
{holder{host: "example.com", cfg: &Config{ACMEEmail: "off"}}, false}, {holder{host: "example.com", cfg: &Config{ACMEEmail: "off"}}, false},
{holder{host: "example.com", cfg: &Config{ACMEEmail: "foo@bar.com"}}, true}, {holder{host: "example.com", cfg: &Config{ACMEEmail: "foo@bar.com"}}, true},
......
...@@ -27,7 +27,7 @@ import ( ...@@ -27,7 +27,7 @@ import (
"os" "os"
"strings" "strings"
"github.com/xenolf/lego/acme" "github.com/xenolf/lego/acmev2"
) )
// User represents a Let's Encrypt user account. // User represents a Let's Encrypt user account.
...@@ -67,43 +67,82 @@ func newUser(email string) (User, error) { ...@@ -67,43 +67,82 @@ func newUser(email string) (User, error) {
return user, nil return user, nil
} }
// getEmail does everything it can to obtain an email // getEmail does everything it can to obtain an email address
// address from the user within the scope of storage // from the user within the scope of memory and storage to use
// to use for ACME TLS. If it cannot get an email // for ACME TLS. If it cannot get an email address, it returns
// address, it returns empty string. (It will warn the // empty string. (If user is present, it will warn the user of
// user of the consequences of an empty email.) This // the consequences of an empty email.) This function MAY prompt
// function MAY prompt the user for input. If userPresent // the user for input. If userPresent is false, the operator
// is false, the operator will NOT be prompted and an // will NOT be prompted and an empty email may be returned.
// empty email may be returned. // If the user is prompted, a new User will be created and
func getEmail(storage Storage, userPresent bool) string { // stored in storage according to the email address they
// provided (which might be blank).
func getEmail(cfg *Config, userPresent bool) (string, error) {
storage, err := cfg.StorageFor(cfg.CAUrl)
if err != nil {
return "", err
}
// First try memory (command line flag or typed by user previously) // First try memory (command line flag or typed by user previously)
leEmail := DefaultEmail leEmail := DefaultEmail
// Then try to get most recent user email from storage
if leEmail == "" { if leEmail == "" {
// Then try to get most recent user email
leEmail = storage.MostRecentUserEmail() leEmail = storage.MostRecentUserEmail()
// Save for next time DefaultEmail = leEmail // save for next time
DefaultEmail = leEmail
} }
// Looks like there is no email address readily available,
// so we will have to ask the user if we can.
if leEmail == "" && userPresent { if leEmail == "" && userPresent {
// Alas, we must bother the user and ask for an email address; // evidently, no User data was present in storage;
// if they proceed they also agree to the SA. // thus we must make a new User so that we can get
// the Terms of Service URL via our ACME client, phew!
user, err := newUser("")
if err != nil {
return "", err
}
// get the agreement URL
agreementURL := agreementTestURL
if agreementURL == "" {
// we call acme.NewClient directly because newACMEClient
// would require that we already know the user's email
caURL := DefaultCAUrl
if cfg.CAUrl != "" {
caURL = cfg.CAUrl
}
tempClient, err := acme.NewClient(caURL, user, "")
if err != nil {
return "", fmt.Errorf("making ACME client to get ToS URL: %v", err)
}
agreementURL = tempClient.GetToSURL()
}
// prompt the user for an email address and terms agreement
reader := bufio.NewReader(stdin) reader := bufio.NewReader(stdin)
fmt.Println("\nYour sites will be served over HTTPS automatically using Let's Encrypt.") promptUserAgreement(agreementURL)
fmt.Println("By continuing, you agree to the Let's Encrypt Subscriber Agreement at:") fmt.Println("Please enter your email address to signify agreement and to be notified")
fmt.Println(" " + saURL) // TODO: Show current SA link fmt.Println("in case of issues. You can leave it blank, but we don't recommend it.")
fmt.Println("Please enter your email address so you can recover your account if needed.") fmt.Print(" Email address: ")
fmt.Println("You can leave it blank, but you'll lose the ability to recover your account.")
fmt.Print("Email address: ")
var err error
leEmail, err = reader.ReadString('\n') leEmail, err = reader.ReadString('\n')
if err != nil { if err != nil && err != io.EOF {
return "" return "", fmt.Errorf("reading email address: %v", err)
} }
leEmail = strings.TrimSpace(leEmail) leEmail = strings.TrimSpace(leEmail)
DefaultEmail = leEmail DefaultEmail = leEmail
Agreed = true Agreed = true
// save the new user to preserve this for next time
user.Email = leEmail
err = saveUser(storage, user)
if err != nil {
return "", err
}
} }
return strings.ToLower(leEmail)
// lower-casing the email is important for consistency
return strings.ToLower(leEmail), nil
} }
// getUser loads the user with the given email from disk // getUser loads the user with the given email from disk
...@@ -154,18 +193,21 @@ func saveUser(storage Storage, user User) error { ...@@ -154,18 +193,21 @@ func saveUser(storage Storage, user User) error {
return err return err
} }
// promptUserAgreement prompts the user to agree to the agreement // promptUserAgreement simply outputs the standard user
// at agreementURL via stdin. If the agreement has changed, then pass // agreement prompt with the given agreement URL.
// true as the second argument. If this is the user's first time // It outputs a newline after the message.
// agreeing, pass false. It returns whether the user agreed or not. func promptUserAgreement(agreementURL string) {
func promptUserAgreement(agreementURL string, changed bool) bool { const userAgreementPrompt = `Your sites will be served over HTTPS automatically using Let's Encrypt.
if changed { By continuing, you agree to the Let's Encrypt Subscriber Agreement at:`
fmt.Printf("The Let's Encrypt Subscriber Agreement has changed:\n %s\n", agreementURL) fmt.Printf("\n\n%s\n %s\n", userAgreementPrompt, agreementURL)
fmt.Print("Do you agree to the new terms? (y/n): ") }
} else {
fmt.Printf("To continue, you must agree to the Let's Encrypt Subscriber Agreement:\n %s\n", agreementURL) // askUserAgreement prompts the user to agree to the agreement
// at the given agreement URL via stdin. It returns whether the
// user agreed or not.
func askUserAgreement(agreementURL string) bool {
promptUserAgreement(agreementURL)
fmt.Print("Do you agree to the terms? (y/n): ") fmt.Print("Do you agree to the terms? (y/n): ")
}
reader := bufio.NewReader(stdin) reader := bufio.NewReader(stdin)
answer, err := reader.ReadString('\n') answer, err := reader.ReadString('\n')
...@@ -177,14 +219,15 @@ func promptUserAgreement(agreementURL string, changed bool) bool { ...@@ -177,14 +219,15 @@ func promptUserAgreement(agreementURL string, changed bool) bool {
return answer == "y" || answer == "yes" return answer == "y" || answer == "yes"
} }
// agreementTestURL is set during tests to skip requiring
// setting up an entire ACME CA endpoint.
var agreementTestURL string
// stdin is used to read the user's input if prompted; // stdin is used to read the user's input if prompted;
// this is changed by tests during tests. // this is changed by tests during tests.
var stdin = io.ReadWriter(os.Stdin) var stdin = io.ReadWriter(os.Stdin)
// The name of the folder for accounts where the email // The name of the folder for accounts where the email
// address was not provided; default 'username' if you will. // address was not provided; default 'username' if you will,
// but only for local/storage use, not with the CA.
const emptyEmail = "default" const emptyEmail = "default"
// TODO: After Boulder implements the 'meta' field of the directory,
// we can get this link dynamically.
const saURL = "https://acme-v01.api.letsencrypt.org/terms"
...@@ -20,13 +20,14 @@ import ( ...@@ -20,13 +20,14 @@ import (
"crypto/elliptic" "crypto/elliptic"
"crypto/rand" "crypto/rand"
"io" "io"
"path/filepath"
"strings" "strings"
"testing" "testing"
"time" "time"
"os" "os"
"github.com/xenolf/lego/acme" "github.com/xenolf/lego/acmev2"
) )
func TestUser(t *testing.T) { func TestUser(t *testing.T) {
...@@ -135,7 +136,13 @@ func TestGetUserAlreadyExists(t *testing.T) { ...@@ -135,7 +136,13 @@ func TestGetUserAlreadyExists(t *testing.T) {
} }
func TestGetEmail(t *testing.T) { func TestGetEmail(t *testing.T) {
storageBasePath = testStorage.Path // to contain calls that create a new Storage... // ensure storage (via StorageFor) uses the local testdata folder that we delete later
origCaddypath := os.Getenv("CADDYPATH")
os.Setenv("CADDYPATH", "./testdata")
defer os.Setenv("CADDYPATH", origCaddypath)
agreementTestURL = "(none - testing)"
defer func() { agreementTestURL = "" }()
// let's not clutter up the output // let's not clutter up the output
origStdout := os.Stdout origStdout := os.Stdout
...@@ -146,7 +153,10 @@ func TestGetEmail(t *testing.T) { ...@@ -146,7 +153,10 @@ func TestGetEmail(t *testing.T) {
DefaultEmail = "test2@foo.com" DefaultEmail = "test2@foo.com"
// Test1: Use default email from flag (or user previously typing it) // Test1: Use default email from flag (or user previously typing it)
actual := getEmail(testStorage, true) actual, err := getEmail(testConfig, true)
if err != nil {
t.Fatalf("getEmail (1) error: %v", err)
}
if actual != DefaultEmail { if actual != DefaultEmail {
t.Errorf("Did not get correct email from memory; expected '%s' but got '%s'", DefaultEmail, actual) t.Errorf("Did not get correct email from memory; expected '%s' but got '%s'", DefaultEmail, actual)
} }
...@@ -154,16 +164,19 @@ func TestGetEmail(t *testing.T) { ...@@ -154,16 +164,19 @@ func TestGetEmail(t *testing.T) {
// Test2: Get input from user // Test2: Get input from user
DefaultEmail = "" DefaultEmail = ""
stdin = new(bytes.Buffer) stdin = new(bytes.Buffer)
_, err := io.Copy(stdin, strings.NewReader("test3@foo.com\n")) _, err = io.Copy(stdin, strings.NewReader("test3@foo.com\n"))
if err != nil { if err != nil {
t.Fatalf("Could not simulate user input, error: %v", err) t.Fatalf("Could not simulate user input, error: %v", err)
} }
actual = getEmail(testStorage, true) actual, err = getEmail(testConfig, true)
if err != nil {
t.Fatalf("getEmail (2) error: %v", err)
}
if actual != "test3@foo.com" { if actual != "test3@foo.com" {
t.Errorf("Did not get correct email from user input prompt; expected '%s' but got '%s'", "test3@foo.com", actual) t.Errorf("Did not get correct email from user input prompt; expected '%s' but got '%s'", "test3@foo.com", actual)
} }
// Test3: Get most recent email from before // Test3: Get most recent email from before (in storage)
DefaultEmail = "" DefaultEmail = ""
for i, eml := range []string{ for i, eml := range []string{
"TEST4-3@foo.com", // test case insensitivity "TEST4-3@foo.com", // test case insensitivity
...@@ -189,14 +202,20 @@ func TestGetEmail(t *testing.T) { ...@@ -189,14 +202,20 @@ func TestGetEmail(t *testing.T) {
t.Fatalf("Could not change user folder mod time for '%s': %v", eml, err) t.Fatalf("Could not change user folder mod time for '%s': %v", eml, err)
} }
} }
actual = getEmail(testStorage, true) actual, err = getEmail(testConfig, true)
if err != nil {
t.Fatalf("getEmail (3) error: %v", err)
}
if actual != "test4-3@foo.com" { if actual != "test4-3@foo.com" {
t.Errorf("Did not get correct email from storage; expected '%s' but got '%s'", "test4-3@foo.com", actual) t.Errorf("Did not get correct email from storage; expected '%s' but got '%s'", "test4-3@foo.com", actual)
} }
} }
var testStorage = &FileStorage{Path: "./testdata"} var (
testStorageBase = "./testdata" // ephemeral folder that gets deleted after tests finish
testCAHost = "localhost"
testConfig = &Config{CAUrl: "http://" + testCAHost + "/directory", StorageProvider: "file"}
testStorage = &FileStorage{Path: filepath.Join(testStorageBase, "acme", testCAHost)}
)
func (s *FileStorage) clean() error { func (s *FileStorage) clean() error { return os.RemoveAll(testStorageBase) }
return os.RemoveAll(s.Path)
}
package acme
import (
"time"
"gopkg.in/square/go-jose.v1"
)
type directory struct {
NewAuthzURL string `json:"new-authz"`
NewCertURL string `json:"new-cert"`
NewRegURL string `json:"new-reg"`
RevokeCertURL string `json:"revoke-cert"`
}
type registrationMessage struct {
Resource string `json:"resource"`
Contact []string `json:"contact"`
Delete bool `json:"delete,omitempty"`
}
// Registration is returned by the ACME server after the registration
// The client implementation should save this registration somewhere.
type Registration struct {
Resource string `json:"resource,omitempty"`
ID int `json:"id"`
Key jose.JsonWebKey `json:"key"`
Contact []string `json:"contact"`
Agreement string `json:"agreement,omitempty"`
Authorizations string `json:"authorizations,omitempty"`
Certificates string `json:"certificates,omitempty"`
}
// RegistrationResource represents all important informations about a registration
// of which the client needs to keep track itself.
type RegistrationResource struct {
Body Registration `json:"body,omitempty"`
URI string `json:"uri,omitempty"`
NewAuthzURL string `json:"new_authzr_uri,omitempty"`
TosURL string `json:"terms_of_service,omitempty"`
}
type authorizationResource struct {
Body authorization
Domain string
NewCertURL string
AuthURL string
}
type authorization struct {
Resource string `json:"resource,omitempty"`
Identifier identifier `json:"identifier"`
Status string `json:"status,omitempty"`
Expires time.Time `json:"expires,omitempty"`
Challenges []challenge `json:"challenges,omitempty"`
Combinations [][]int `json:"combinations,omitempty"`
}
type identifier struct {
Type string `json:"type"`
Value string `json:"value"`
}
type validationRecord struct {
URI string `json:"url,omitempty"`
Hostname string `json:"hostname,omitempty"`
Port string `json:"port,omitempty"`
ResolvedAddresses []string `json:"addressesResolved,omitempty"`
UsedAddress string `json:"addressUsed,omitempty"`
}
type challenge struct {
Resource string `json:"resource,omitempty"`
Type Challenge `json:"type,omitempty"`
Status string `json:"status,omitempty"`
URI string `json:"uri,omitempty"`
Token string `json:"token,omitempty"`
KeyAuthorization string `json:"keyAuthorization,omitempty"`
TLS bool `json:"tls,omitempty"`
Iterations int `json:"n,omitempty"`
Error RemoteError `json:"error,omitempty"`
ValidationRecords []validationRecord `json:"validationRecord,omitempty"`
}
type csrMessage struct {
Resource string `json:"resource,omitempty"`
Csr string `json:"csr"`
Authorizations []string `json:"authorizations"`
}
type revokeCertMessage struct {
Resource string `json:"resource"`
Certificate string `json:"certificate"`
}
type deactivateAuthMessage struct {
Resource string `json:"resource,omitempty"`
Status string `jsom:"status"`
}
// CertificateResource represents a CA issued certificate.
// PrivateKey, Certificate and IssuerCertificate are all
// already PEM encoded and can be directly written to disk.
// Certificate may be a certificate bundle, depending on the
// options supplied to create it.
type CertificateResource struct {
Domain string `json:"domain"`
CertURL string `json:"certUrl"`
CertStableURL string `json:"certStableUrl"`
AccountRef string `json:"accountRef,omitempty"`
PrivateKey []byte `json:"-"`
Certificate []byte `json:"-"`
IssuerCertificate []byte `json:"-"`
CSR []byte `json:"-"`
}
package acme
import (
"crypto/rsa"
"crypto/sha256"
"crypto/tls"
"encoding/hex"
"fmt"
"log"
)
type tlsSNIChallenge struct {
jws *jws
validate validateFunc
provider ChallengeProvider
}
func (t *tlsSNIChallenge) Solve(chlng challenge, domain string) error {
// FIXME: https://github.com/ietf-wg-acme/acme/pull/22
// Currently we implement this challenge to track boulder, not the current spec!
logf("[INFO][%s] acme: Trying to solve TLS-SNI-01", domain)
// Generate the Key Authorization for the challenge
keyAuth, err := getKeyAuthorization(chlng.Token, t.jws.privKey)
if err != nil {
return err
}
err = t.provider.Present(domain, chlng.Token, keyAuth)
if err != nil {
return fmt.Errorf("[%s] error presenting token: %v", domain, err)
}
defer func() {
err := t.provider.CleanUp(domain, chlng.Token, keyAuth)
if err != nil {
log.Printf("[%s] error cleaning up: %v", domain, err)
}
}()
return t.validate(t.jws, domain, chlng.URI, challenge{Resource: "challenge", Type: chlng.Type, Token: chlng.Token, KeyAuthorization: keyAuth})
}
// TLSSNI01ChallengeCert returns a certificate and target domain for the `tls-sni-01` challenge
func TLSSNI01ChallengeCert(keyAuth string) (tls.Certificate, string, error) {
// generate a new RSA key for the certificates
tempPrivKey, err := generatePrivateKey(RSA2048)
if err != nil {
return tls.Certificate{}, "", err
}
rsaPrivKey := tempPrivKey.(*rsa.PrivateKey)
rsaPrivPEM := pemEncode(rsaPrivKey)
zBytes := sha256.Sum256([]byte(keyAuth))
z := hex.EncodeToString(zBytes[:sha256.Size])
domain := fmt.Sprintf("%s.%s.acme.invalid", z[:32], z[32:])
tempCertPEM, err := generatePemCert(rsaPrivKey, domain)
if err != nil {
return tls.Certificate{}, "", err
}
certificate, err := tls.X509KeyPair(tempCertPEM, rsaPrivPEM)
if err != nil {
return tls.Certificate{}, "", err
}
return certificate, domain, nil
}
package acme
import (
"crypto/tls"
"fmt"
"net"
"net/http"
)
// TLSProviderServer implements ChallengeProvider for `TLS-SNI-01` challenge
// It may be instantiated without using the NewTLSProviderServer function if
// you want only to use the default values.
type TLSProviderServer struct {
iface string
port string
done chan bool
listener net.Listener
}
// NewTLSProviderServer creates a new TLSProviderServer on the selected interface and port.
// Setting iface and / or port to an empty string will make the server fall back to
// the "any" interface and port 443 respectively.
func NewTLSProviderServer(iface, port string) *TLSProviderServer {
return &TLSProviderServer{iface: iface, port: port}
}
// Present makes the keyAuth available as a cert
func (s *TLSProviderServer) Present(domain, token, keyAuth string) error {
if s.port == "" {
s.port = "443"
}
cert, _, err := TLSSNI01ChallengeCert(keyAuth)
if err != nil {
return err
}
tlsConf := new(tls.Config)
tlsConf.Certificates = []tls.Certificate{cert}
s.listener, err = tls.Listen("tcp", net.JoinHostPort(s.iface, s.port), tlsConf)
if err != nil {
return fmt.Errorf("Could not start HTTPS server for challenge -> %v", err)
}
s.done = make(chan bool)
go func() {
http.Serve(s.listener, nil)
s.done <- true
}()
return nil
}
// CleanUp closes the HTTP server.
func (s *TLSProviderServer) CleanUp(domain, token, keyAuth string) error {
if s.listener == nil {
return nil
}
s.listener.Close()
<-s.done
return nil
}
...@@ -7,9 +7,6 @@ const ( ...@@ -7,9 +7,6 @@ const (
// HTTP01 is the "http-01" ACME challenge https://github.com/ietf-wg-acme/acme/blob/master/draft-ietf-acme-acme.md#http // HTTP01 is the "http-01" ACME challenge https://github.com/ietf-wg-acme/acme/blob/master/draft-ietf-acme-acme.md#http
// Note: HTTP01ChallengePath returns the URL path to fulfill this challenge // Note: HTTP01ChallengePath returns the URL path to fulfill this challenge
HTTP01 = Challenge("http-01") HTTP01 = Challenge("http-01")
// TLSSNI01 is the "tls-sni-01" ACME challenge https://github.com/ietf-wg-acme/acme/blob/master/draft-ietf-acme-acme.md#tls-with-server-name-indication-tls-sni
// Note: TLSSNI01ChallengeCert returns a certificate to fulfill this challenge
TLSSNI01 = Challenge("tls-sni-01")
// DNS01 is the "dns-01" ACME challenge https://github.com/ietf-wg-acme/acme/blob/master/draft-ietf-acme-acme.md#dns // DNS01 is the "dns-01" ACME challenge https://github.com/ietf-wg-acme/acme/blob/master/draft-ietf-acme-acme.md#dns
// Note: DNS01Record returns a DNS record which will fulfill this challenge // Note: DNS01Record returns a DNS record which will fulfill this challenge
DNS01 = Challenge("dns-01") DNS01 = Challenge("dns-01")
......
...@@ -17,12 +17,12 @@ import ( ...@@ -17,12 +17,12 @@ import (
"io/ioutil" "io/ioutil"
"math/big" "math/big"
"net/http" "net/http"
"strings"
"time" "time"
"encoding/asn1" "encoding/asn1"
"golang.org/x/crypto/ocsp" "golang.org/x/crypto/ocsp"
jose "gopkg.in/square/go-jose.v2"
) )
// KeyType represents the key algo as well as the key size or curve to use. // KeyType represents the key algo as well as the key size or curve to use.
...@@ -136,9 +136,9 @@ func getKeyAuthorization(token string, key interface{}) (string, error) { ...@@ -136,9 +136,9 @@ func getKeyAuthorization(token string, key interface{}) (string, error) {
} }
// Generate the Key Authorization for the challenge // Generate the Key Authorization for the challenge
jwk := keyAsJWK(publicKey) jwk := &jose.JSONWebKey{Key: publicKey}
if jwk == nil { if jwk == nil {
return "", errors.New("Could not generate JWK from key.") return "", errors.New("Could not generate JWK from key")
} }
thumbBytes, err := jwk.Thumbprint(crypto.SHA256) thumbBytes, err := jwk.Thumbprint(crypto.SHA256)
if err != nil { if err != nil {
...@@ -146,11 +146,7 @@ func getKeyAuthorization(token string, key interface{}) (string, error) { ...@@ -146,11 +146,7 @@ func getKeyAuthorization(token string, key interface{}) (string, error) {
} }
// unpad the base64URL // unpad the base64URL
keyThumb := base64.URLEncoding.EncodeToString(thumbBytes) keyThumb := base64.RawURLEncoding.EncodeToString(thumbBytes)
index := strings.Index(keyThumb, "=")
if index != -1 {
keyThumb = keyThumb[:index]
}
return token + "." + keyThumb, nil return token + "." + keyThumb, nil
} }
...@@ -177,7 +173,7 @@ func parsePEMBundle(bundle []byte) ([]*x509.Certificate, error) { ...@@ -177,7 +173,7 @@ func parsePEMBundle(bundle []byte) ([]*x509.Certificate, error) {
} }
if len(certificates) == 0 { if len(certificates) == 0 {
return nil, errors.New("No certificates were found while parsing the bundle.") return nil, errors.New("No certificates were found while parsing the bundle")
} }
return certificates, nil return certificates, nil
......
...@@ -11,7 +11,6 @@ import ( ...@@ -11,7 +11,6 @@ import (
"time" "time"
"github.com/miekg/dns" "github.com/miekg/dns"
"golang.org/x/net/publicsuffix"
) )
type preCheckDNSFunc func(fqdn, value string) (bool, error) type preCheckDNSFunc func(fqdn, value string) (bool, error)
...@@ -30,6 +29,7 @@ var defaultNameservers = []string{ ...@@ -30,6 +29,7 @@ var defaultNameservers = []string{
"google-public-dns-b.google.com:53", "google-public-dns-b.google.com:53",
} }
// RecursiveNameservers are used to pre-check DNS propagations
var RecursiveNameservers = getNameservers(defaultResolvConf, defaultNameservers) var RecursiveNameservers = getNameservers(defaultResolvConf, defaultNameservers)
// DNSTimeout is used to override the default DNS timeout of 10 seconds. // DNSTimeout is used to override the default DNS timeout of 10 seconds.
...@@ -58,8 +58,7 @@ func getNameservers(path string, defaults []string) []string { ...@@ -58,8 +58,7 @@ func getNameservers(path string, defaults []string) []string {
func DNS01Record(domain, keyAuth string) (fqdn string, value string, ttl int) { func DNS01Record(domain, keyAuth string) (fqdn string, value string, ttl int) {
keyAuthShaBytes := sha256.Sum256([]byte(keyAuth)) keyAuthShaBytes := sha256.Sum256([]byte(keyAuth))
// base64URL encoding without padding // base64URL encoding without padding
keyAuthSha := base64.URLEncoding.EncodeToString(keyAuthShaBytes[:sha256.Size]) value = base64.RawURLEncoding.EncodeToString(keyAuthShaBytes[:sha256.Size])
value = strings.TrimRight(keyAuthSha, "=")
ttl = 120 ttl = 120
fqdn = fmt.Sprintf("_acme-challenge.%s.", domain) fqdn = fmt.Sprintf("_acme-challenge.%s.", domain)
return return
...@@ -115,7 +114,7 @@ func (s *dnsChallenge) Solve(chlng challenge, domain string) error { ...@@ -115,7 +114,7 @@ func (s *dnsChallenge) Solve(chlng challenge, domain string) error {
return err return err
} }
return s.validate(s.jws, domain, chlng.URI, challenge{Resource: "challenge", Type: chlng.Type, Token: chlng.Token, KeyAuthorization: keyAuth}) return s.validate(s.jws, domain, chlng.URL, challenge{Type: chlng.Type, Token: chlng.Token, KeyAuthorization: keyAuth})
} }
// checkDNSPropagation checks if the expected TXT record has been propagated to all authoritative nameservers. // checkDNSPropagation checks if the expected TXT record has been propagated to all authoritative nameservers.
...@@ -194,7 +193,7 @@ func dnsQuery(fqdn string, rtype uint16, nameservers []string, recursive bool) ( ...@@ -194,7 +193,7 @@ func dnsQuery(fqdn string, rtype uint16, nameservers []string, recursive bool) (
if err == dns.ErrTruncated { if err == dns.ErrTruncated {
tcp := &dns.Client{Net: "tcp", Timeout: DNSTimeout} tcp := &dns.Client{Net: "tcp", Timeout: DNSTimeout}
// If the TCP request suceeds, the err will reset to nil // If the TCP request succeeds, the err will reset to nil
in, _, err = tcp.Exchange(m, ns) in, _, err = tcp.Exchange(m, ns)
} }
...@@ -242,10 +241,6 @@ func FindZoneByFqdn(fqdn string, nameservers []string) (string, error) { ...@@ -242,10 +241,6 @@ func FindZoneByFqdn(fqdn string, nameservers []string) (string, error) {
labelIndexes := dns.Split(fqdn) labelIndexes := dns.Split(fqdn)
for _, index := range labelIndexes { for _, index := range labelIndexes {
domain := fqdn[index:] domain := fqdn[index:]
// Give up if we have reached the TLD
if isTLD(domain) {
break
}
in, err := dnsQuery(domain, dns.TypeSOA, nameservers, true) in, err := dnsQuery(domain, dns.TypeSOA, nameservers, true)
if err != nil { if err != nil {
...@@ -260,6 +255,13 @@ func FindZoneByFqdn(fqdn string, nameservers []string) (string, error) { ...@@ -260,6 +255,13 @@ func FindZoneByFqdn(fqdn string, nameservers []string) (string, error) {
// Check if we got a SOA RR in the answer section // Check if we got a SOA RR in the answer section
if in.Rcode == dns.RcodeSuccess { if in.Rcode == dns.RcodeSuccess {
// CNAME records cannot/should not exist at the root of a zone.
// So we skip a domain when a CNAME is found.
if dnsMsgContainsCNAME(in) {
continue
}
for _, ans := range in.Answer { for _, ans := range in.Answer {
if soa, ok := ans.(*dns.SOA); ok { if soa, ok := ans.(*dns.SOA); ok {
zone := soa.Hdr.Name zone := soa.Hdr.Name
...@@ -273,11 +275,13 @@ func FindZoneByFqdn(fqdn string, nameservers []string) (string, error) { ...@@ -273,11 +275,13 @@ func FindZoneByFqdn(fqdn string, nameservers []string) (string, error) {
return "", fmt.Errorf("Could not find the start of authority") return "", fmt.Errorf("Could not find the start of authority")
} }
func isTLD(domain string) bool { // dnsMsgContainsCNAME checks for a CNAME answer in msg
publicsuffix, _ := publicsuffix.PublicSuffix(UnFqdn(domain)) func dnsMsgContainsCNAME(msg *dns.Msg) bool {
if publicsuffix == UnFqdn(domain) { for _, ans := range msg.Answer {
if _, ok := ans.(*dns.CNAME); ok {
return true return true
} }
}
return false return false
} }
......
...@@ -9,8 +9,8 @@ import ( ...@@ -9,8 +9,8 @@ import (
) )
const ( const (
tosAgreementError = "Must agree to subscriber agreement before any further actions" tosAgreementError = "Terms of service have changed"
invalidNonceError = "JWS has invalid anti-replay nonce" invalidNonceError = "urn:ietf:params:acme:error:badNonce"
) )
// RemoteError is the base type for all errors specific to the ACME protocol. // RemoteError is the base type for all errors specific to the ACME protocol.
...@@ -42,27 +42,11 @@ type domainError struct { ...@@ -42,27 +42,11 @@ type domainError struct {
Error error Error error
} }
type challengeError struct {
RemoteError
records []validationRecord
}
func (c challengeError) Error() string {
var errStr string
for _, validation := range c.records {
errStr = errStr + fmt.Sprintf("\tValidation for %s:%s\n\tResolved to:\n\t\t%s\n\tUsed: %s\n\n",
validation.Hostname, validation.Port, strings.Join(validation.ResolvedAddresses, "\n\t\t"), validation.UsedAddress)
}
return fmt.Sprintf("%s\nError Detail:\n%s", c.RemoteError.Error(), errStr)
}
func handleHTTPError(resp *http.Response) error { func handleHTTPError(resp *http.Response) error {
var errorDetail RemoteError var errorDetail RemoteError
contentType := resp.Header.Get("Content-Type") contentType := resp.Header.Get("Content-Type")
if contentType == "application/json" || contentType == "application/problem+json" { if contentType == "application/json" || strings.HasPrefix(contentType, "application/problem+json") {
err := json.NewDecoder(resp.Body).Decode(&errorDetail) err := json.NewDecoder(resp.Body).Decode(&errorDetail)
if err != nil { if err != nil {
return err return err
...@@ -82,7 +66,7 @@ func handleHTTPError(resp *http.Response) error { ...@@ -82,7 +66,7 @@ func handleHTTPError(resp *http.Response) error {
return TOSError{errorDetail} return TOSError{errorDetail}
} }
if errorDetail.StatusCode == http.StatusBadRequest && strings.HasPrefix(errorDetail.Detail, invalidNonceError) { if errorDetail.StatusCode == http.StatusBadRequest && errorDetail.Type == invalidNonceError {
return NonceError{errorDetail} return NonceError{errorDetail}
} }
...@@ -90,5 +74,5 @@ func handleHTTPError(resp *http.Response) error { ...@@ -90,5 +74,5 @@ func handleHTTPError(resp *http.Response) error {
} }
func handleChallengeError(chlng challenge) error { func handleChallengeError(chlng challenge) error {
return challengeError{chlng.Error, chlng.ValidationRecords} return chlng.Error
} }
...@@ -18,6 +18,7 @@ var UserAgent string ...@@ -18,6 +18,7 @@ var UserAgent string
// HTTPClient is an HTTP client with a reasonable timeout value. // HTTPClient is an HTTP client with a reasonable timeout value.
var HTTPClient = http.Client{ var HTTPClient = http.Client{
Transport: &http.Transport{ Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
Dial: (&net.Dialer{ Dial: (&net.Dialer{
Timeout: 30 * time.Second, Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second, KeepAlive: 30 * time.Second,
...@@ -101,7 +102,7 @@ func getJSON(uri string, respBody interface{}) (http.Header, error) { ...@@ -101,7 +102,7 @@ func getJSON(uri string, respBody interface{}) (http.Header, error) {
func postJSON(j *jws, uri string, reqBody, respBody interface{}) (http.Header, error) { func postJSON(j *jws, uri string, reqBody, respBody interface{}) (http.Header, error) {
jsonBytes, err := json.Marshal(reqBody) jsonBytes, err := json.Marshal(reqBody)
if err != nil { if err != nil {
return nil, errors.New("Failed to marshal network message...") return nil, errors.New("Failed to marshal network message")
} }
resp, err := j.post(uri, jsonBytes) resp, err := j.post(uri, jsonBytes)
......
...@@ -37,5 +37,5 @@ func (s *httpChallenge) Solve(chlng challenge, domain string) error { ...@@ -37,5 +37,5 @@ func (s *httpChallenge) Solve(chlng challenge, domain string) error {
} }
}() }()
return s.validate(s.jws, domain, chlng.URI, challenge{Resource: "challenge", Type: chlng.Type, Token: chlng.Token, KeyAuthorization: keyAuth}) return s.validate(s.jws, domain, chlng.URL, challenge{Type: chlng.Type, Token: chlng.Token, KeyAuthorization: keyAuth})
} }
...@@ -10,37 +10,27 @@ import ( ...@@ -10,37 +10,27 @@ import (
"net/http" "net/http"
"sync" "sync"
"gopkg.in/square/go-jose.v1" "gopkg.in/square/go-jose.v2"
) )
type jws struct { type jws struct {
directoryURL string getNonceURL string
privKey crypto.PrivateKey privKey crypto.PrivateKey
kid string
nonces nonceManager nonces nonceManager
} }
func keyAsJWK(key interface{}) *jose.JsonWebKey {
switch k := key.(type) {
case *ecdsa.PublicKey:
return &jose.JsonWebKey{Key: k, Algorithm: "EC"}
case *rsa.PublicKey:
return &jose.JsonWebKey{Key: k, Algorithm: "RSA"}
default:
return nil
}
}
// Posts a JWS signed message to the specified URL. // Posts a JWS signed message to the specified URL.
// It does NOT close the response body, so the caller must // It does NOT close the response body, so the caller must
// do that if no error was returned. // do that if no error was returned.
func (j *jws) post(url string, content []byte) (*http.Response, error) { func (j *jws) post(url string, content []byte) (*http.Response, error) {
signedContent, err := j.signContent(content) signedContent, err := j.signContent(url, content)
if err != nil { if err != nil {
return nil, fmt.Errorf("Failed to sign content -> %s", err.Error()) return nil, fmt.Errorf("Failed to sign content -> %s", err.Error())
} }
resp, err := httpPost(url, "application/jose+json", bytes.NewBuffer([]byte(signedContent.FullSerialize()))) data := bytes.NewBuffer([]byte(signedContent.FullSerialize()))
resp, err := httpPost(url, "application/jose+json", data)
if err != nil { if err != nil {
return nil, fmt.Errorf("Failed to HTTP POST to %s -> %s", url, err.Error()) return nil, fmt.Errorf("Failed to HTTP POST to %s -> %s", url, err.Error())
} }
...@@ -53,7 +43,7 @@ func (j *jws) post(url string, content []byte) (*http.Response, error) { ...@@ -53,7 +43,7 @@ func (j *jws) post(url string, content []byte) (*http.Response, error) {
return resp, nil return resp, nil
} }
func (j *jws) signContent(content []byte) (*jose.JsonWebSignature, error) { func (j *jws) signContent(url string, content []byte) (*jose.JSONWebSignature, error) {
var alg jose.SignatureAlgorithm var alg jose.SignatureAlgorithm
switch k := j.privKey.(type) { switch k := j.privKey.(type) {
...@@ -67,11 +57,28 @@ func (j *jws) signContent(content []byte) (*jose.JsonWebSignature, error) { ...@@ -67,11 +57,28 @@ func (j *jws) signContent(content []byte) (*jose.JsonWebSignature, error) {
} }
} }
signer, err := jose.NewSigner(alg, j.privKey) jsonKey := jose.JSONWebKey{
Key: j.privKey,
KeyID: j.kid,
}
signKey := jose.SigningKey{
Algorithm: alg,
Key: jsonKey,
}
options := jose.SignerOptions{
NonceSource: j,
ExtraHeaders: make(map[jose.HeaderKey]interface{}),
}
options.ExtraHeaders["url"] = url
if j.kid == "" {
options.EmbedJWK = true
}
signer, err := jose.NewSigner(signKey, &options)
if err != nil { if err != nil {
return nil, fmt.Errorf("Failed to create jose signer -> %s", err.Error()) return nil, fmt.Errorf("Failed to create jose signer -> %s", err.Error())
} }
signer.SetNonceSource(j)
signed, err := signer.Sign(content) signed, err := signer.Sign(content)
if err != nil { if err != nil {
...@@ -85,7 +92,7 @@ func (j *jws) Nonce() (string, error) { ...@@ -85,7 +92,7 @@ func (j *jws) Nonce() (string, error) {
return nonce, nil return nonce, nil
} }
return getNonce(j.directoryURL) return getNonce(j.getNonceURL)
} }
type nonceManager struct { type nonceManager struct {
...@@ -124,7 +131,7 @@ func getNonce(url string) (string, error) { ...@@ -124,7 +131,7 @@ func getNonce(url string) (string, error) {
func getNonceFromResponse(resp *http.Response) (string, error) { func getNonceFromResponse(resp *http.Response) (string, error) {
nonce := resp.Header.Get("Replay-Nonce") nonce := resp.Header.Get("Replay-Nonce")
if nonce == "" { if nonce == "" {
return "", fmt.Errorf("Server did not respond with a proper nonce header.") return "", fmt.Errorf("Server did not respond with a proper nonce header")
} }
return nonce, nil return nonce, nil
......
package acme
import (
"time"
)
// RegistrationResource represents all important informations about a registration
// of which the client needs to keep track itself.
type RegistrationResource struct {
Body accountMessage `json:"body,omitempty"`
URI string `json:"uri,omitempty"`
}
type directory struct {
NewNonceURL string `json:"newNonce"`
NewAccountURL string `json:"newAccount"`
NewOrderURL string `json:"newOrder"`
RevokeCertURL string `json:"revokeCert"`
KeyChangeURL string `json:"keyChange"`
Meta struct {
TermsOfService string `json:"termsOfService"`
Website string `json:"website"`
CaaIdentities []string `json:"caaIdentities"`
ExternalAccountRequired bool `json:"externalAccountRequired"`
} `json:"meta"`
}
type accountMessage struct {
Status string `json:"status,omitempty"`
Contact []string `json:"contact,omitempty"`
TermsOfServiceAgreed bool `json:"termsOfServiceAgreed,omitempty"`
Orders string `json:"orders,omitempty"`
OnlyReturnExisting bool `json:"onlyReturnExisting,omitempty"`
}
type orderResource struct {
URL string `json:"url,omitempty"`
orderMessage `json:"body,omitempty"`
}
type orderMessage struct {
Status string `json:"status,omitempty"`
Expires string `json:"expires,omitempty"`
Identifiers []identifier `json:"identifiers"`
NotBefore string `json:"notBefore,omitempty"`
NotAfter string `json:"notAfter,omitempty"`
Authorizations []string `json:"authorizations,omitempty"`
Finalize string `json:"finalize,omitempty"`
Certificate string `json:"certificate,omitempty"`
}
type authorization struct {
Status string `json:"status"`
Expires time.Time `json:"expires"`
Identifier identifier `json:"identifier"`
Challenges []challenge `json:"challenges"`
}
type identifier struct {
Type string `json:"type"`
Value string `json:"value"`
}
type challenge struct {
URL string `json:"url"`
Type string `json:"type"`
Status string `json:"status"`
Token string `json:"token"`
Validated time.Time `json:"validated"`
KeyAuthorization string `json:"keyAuthorization"`
Error RemoteError `json:"error"`
}
type csrMessage struct {
Csr string `json:"csr"`
}
type emptyObjectMessage struct {
}
type revokeCertMessage struct {
Certificate string `json:"certificate"`
}
type deactivateAuthMessage struct {
Status string `jsom:"status"`
}
// CertificateResource represents a CA issued certificate.
// PrivateKey, Certificate and IssuerCertificate are all
// already PEM encoded and can be directly written to disk.
// Certificate may be a certificate bundle, depending on the
// options supplied to create it.
type CertificateResource struct {
Domain string `json:"domain"`
CertURL string `json:"certUrl"`
CertStableURL string `json:"certStableUrl"`
AccountRef string `json:"accountRef,omitempty"`
PrivateKey []byte `json:"-"`
Certificate []byte `json:"-"`
IssuerCertificate []byte `json:"-"`
CSR []byte `json:"-"`
}
This diff is collapsed.
This diff is collapsed.
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package autocert
import (
"context"
"errors"
"io/ioutil"
"os"
"path/filepath"
)
// ErrCacheMiss is returned when a certificate is not found in cache.
var ErrCacheMiss = errors.New("acme/autocert: certificate cache miss")
// Cache is used by Manager to store and retrieve previously obtained certificates
// as opaque data.
//
// The key argument of the methods refers to a domain name but need not be an FQDN.
// Cache implementations should not rely on the key naming pattern.
type Cache interface {
// Get returns a certificate data for the specified key.
// If there's no such key, Get returns ErrCacheMiss.
Get(ctx context.Context, key string) ([]byte, error)
// Put stores the data in the cache under the specified key.
// Underlying implementations may use any data storage format,
// as long as the reverse operation, Get, results in the original data.
Put(ctx context.Context, key string, data []byte) error
// Delete removes a certificate data from the cache under the specified key.
// If there's no such key in the cache, Delete returns nil.
Delete(ctx context.Context, key string) error
}
// DirCache implements Cache using a directory on the local filesystem.
// If the directory does not exist, it will be created with 0700 permissions.
type DirCache string
// Get reads a certificate data from the specified file name.
func (d DirCache) Get(ctx context.Context, name string) ([]byte, error) {
name = filepath.Join(string(d), name)
var (
data []byte
err error
done = make(chan struct{})
)
go func() {
data, err = ioutil.ReadFile(name)
close(done)
}()
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-done:
}
if os.IsNotExist(err) {
return nil, ErrCacheMiss
}
return data, err
}
// Put writes the certificate data to the specified file name.
// The file will be created with 0600 permissions.
func (d DirCache) Put(ctx context.Context, name string, data []byte) error {
if err := os.MkdirAll(string(d), 0700); err != nil {
return err
}
done := make(chan struct{})
var err error
go func() {
defer close(done)
var tmp string
if tmp, err = d.writeTempFile(name, data); err != nil {
return
}
select {
case <-ctx.Done():
// Don't overwrite the file if the context was canceled.
default:
newName := filepath.Join(string(d), name)
err = os.Rename(tmp, newName)
}
}()
select {
case <-ctx.Done():
return ctx.Err()
case <-done:
}
return err
}
// Delete removes the specified file name.
func (d DirCache) Delete(ctx context.Context, name string) error {
name = filepath.Join(string(d), name)
var (
err error
done = make(chan struct{})
)
go func() {
err = os.Remove(name)
close(done)
}()
select {
case <-ctx.Done():
return ctx.Err()
case <-done:
}
if err != nil && !os.IsNotExist(err) {
return err
}
return nil
}
// writeTempFile writes b to a temporary file, closes the file and returns its path.
func (d DirCache) writeTempFile(prefix string, b []byte) (string, error) {
// TempFile uses 0600 permissions
f, err := ioutil.TempFile(string(d), prefix)
if err != nil {
return "", err
}
if _, err := f.Write(b); err != nil {
f.Close()
return "", err
}
return f.Name(), f.Close()
}
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package autocert
import (
"crypto/tls"
"log"
"net"
"os"
"path/filepath"
"runtime"
"time"
)
// NewListener returns a net.Listener that listens on the standard TLS
// port (443) on all interfaces and returns *tls.Conn connections with
// LetsEncrypt certificates for the provided domain or domains.
//
// It enables one-line HTTPS servers:
//
// log.Fatal(http.Serve(autocert.NewListener("example.com"), handler))
//
// NewListener is a convenience function for a common configuration.
// More complex or custom configurations can use the autocert.Manager
// type instead.
//
// Use of this function implies acceptance of the LetsEncrypt Terms of
// Service. If domains is not empty, the provided domains are passed
// to HostWhitelist. If domains is empty, the listener will do
// LetsEncrypt challenges for any requested domain, which is not
// recommended.
//
// Certificates are cached in a "golang-autocert" directory under an
// operating system-specific cache or temp directory. This may not
// be suitable for servers spanning multiple machines.
//
// The returned listener uses a *tls.Config that enables HTTP/2, and
// should only be used with servers that support HTTP/2.
//
// The returned Listener also enables TCP keep-alives on the accepted
// connections. The returned *tls.Conn are returned before their TLS
// handshake has completed.
func NewListener(domains ...string) net.Listener {
m := &Manager{
Prompt: AcceptTOS,
}
if len(domains) > 0 {
m.HostPolicy = HostWhitelist(domains...)
}
dir := cacheDir()
if err := os.MkdirAll(dir, 0700); err != nil {
log.Printf("warning: autocert.NewListener not using a cache: %v", err)
} else {
m.Cache = DirCache(dir)
}
return m.Listener()
}
// Listener listens on the standard TLS port (443) on all interfaces
// and returns a net.Listener returning *tls.Conn connections.
//
// The returned listener uses a *tls.Config that enables HTTP/2, and
// should only be used with servers that support HTTP/2.
//
// The returned Listener also enables TCP keep-alives on the accepted
// connections. The returned *tls.Conn are returned before their TLS
// handshake has completed.
//
// Unlike NewListener, it is the caller's responsibility to initialize
// the Manager m's Prompt, Cache, HostPolicy, and other desired options.
func (m *Manager) Listener() net.Listener {
ln := &listener{
m: m,
conf: &tls.Config{
GetCertificate: m.GetCertificate, // bonus: panic on nil m
NextProtos: []string{"h2", "http/1.1"}, // Enable HTTP/2
},
}
ln.tcpListener, ln.tcpListenErr = net.Listen("tcp", ":443")
return ln
}
type listener struct {
m *Manager
conf *tls.Config
tcpListener net.Listener
tcpListenErr error
}
func (ln *listener) Accept() (net.Conn, error) {
if ln.tcpListenErr != nil {
return nil, ln.tcpListenErr
}
conn, err := ln.tcpListener.Accept()
if err != nil {
return nil, err
}
tcpConn := conn.(*net.TCPConn)
// Because Listener is a convenience function, help out with
// this too. This is not possible for the caller to set once
// we return a *tcp.Conn wrapping an inaccessible net.Conn.
// If callers don't want this, they can do things the manual
// way and tweak as needed. But this is what net/http does
// itself, so copy that. If net/http changes, we can change
// here too.
tcpConn.SetKeepAlive(true)
tcpConn.SetKeepAlivePeriod(3 * time.Minute)
return tls.Server(tcpConn, ln.conf), nil
}
func (ln *listener) Addr() net.Addr {
if ln.tcpListener != nil {
return ln.tcpListener.Addr()
}
// net.Listen failed. Return something non-nil in case callers
// call Addr before Accept:
return &net.TCPAddr{IP: net.IP{0, 0, 0, 0}, Port: 443}
}
func (ln *listener) Close() error {
if ln.tcpListenErr != nil {
return ln.tcpListenErr
}
return ln.tcpListener.Close()
}
func homeDir() string {
if runtime.GOOS == "windows" {
return os.Getenv("HOMEDRIVE") + os.Getenv("HOMEPATH")
}
if h := os.Getenv("HOME"); h != "" {
return h
}
return "/"
}
func cacheDir() string {
const base = "golang-autocert"
switch runtime.GOOS {
case "darwin":
return filepath.Join(homeDir(), "Library", "Caches", base)
case "windows":
for _, ev := range []string{"APPDATA", "CSIDL_APPDATA", "TEMP", "TMP"} {
if v := os.Getenv(ev); v != "" {
return filepath.Join(v, base)
}
}
// Worst case:
return filepath.Join(homeDir(), base)
}
if xdg := os.Getenv("XDG_CACHE_HOME"); xdg != "" {
return filepath.Join(xdg, base)
}
return filepath.Join(homeDir(), ".cache", base)
}
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package autocert
import (
"context"
"crypto"
"sync"
"time"
)
// renewJitter is the maximum deviation from Manager.RenewBefore.
const renewJitter = time.Hour
// domainRenewal tracks the state used by the periodic timers
// renewing a single domain's cert.
type domainRenewal struct {
m *Manager
domain string
key crypto.Signer
timerMu sync.Mutex
timer *time.Timer
}
// start starts a cert renewal timer at the time
// defined by the certificate expiration time exp.
//
// If the timer is already started, calling start is a noop.
func (dr *domainRenewal) start(exp time.Time) {
dr.timerMu.Lock()
defer dr.timerMu.Unlock()
if dr.timer != nil {
return
}
dr.timer = time.AfterFunc(dr.next(exp), dr.renew)
}
// stop stops the cert renewal timer.
// If the timer is already stopped, calling stop is a noop.
func (dr *domainRenewal) stop() {
dr.timerMu.Lock()
defer dr.timerMu.Unlock()
if dr.timer == nil {
return
}
dr.timer.Stop()
dr.timer = nil
}
// renew is called periodically by a timer.
// The first renew call is kicked off by dr.start.
func (dr *domainRenewal) renew() {
dr.timerMu.Lock()
defer dr.timerMu.Unlock()
if dr.timer == nil {
return
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
defer cancel()
// TODO: rotate dr.key at some point?
next, err := dr.do(ctx)
if err != nil {
next = renewJitter / 2
next += time.Duration(pseudoRand.int63n(int64(next)))
}
dr.timer = time.AfterFunc(next, dr.renew)
testDidRenewLoop(next, err)
}
// do is similar to Manager.createCert but it doesn't lock a Manager.state item.
// Instead, it requests a new certificate independently and, upon success,
// replaces dr.m.state item with a new one and updates cache for the given domain.
//
// It may return immediately if the expiration date of the currently cached cert
// is far enough in the future.
//
// The returned value is a time interval after which the renewal should occur again.
func (dr *domainRenewal) do(ctx context.Context) (time.Duration, error) {
// a race is likely unavoidable in a distributed environment
// but we try nonetheless
if tlscert, err := dr.m.cacheGet(ctx, dr.domain); err == nil {
next := dr.next(tlscert.Leaf.NotAfter)
if next > dr.m.renewBefore()+renewJitter {
return next, nil
}
}
der, leaf, err := dr.m.authorizedCert(ctx, dr.key, dr.domain)
if err != nil {
return 0, err
}
state := &certState{
key: dr.key,
cert: der,
leaf: leaf,
}
tlscert, err := state.tlscert()
if err != nil {
return 0, err
}
dr.m.cachePut(ctx, dr.domain, tlscert)
dr.m.stateMu.Lock()
defer dr.m.stateMu.Unlock()
// m.state is guaranteed to be non-nil at this point
dr.m.state[dr.domain] = state
return dr.next(leaf.NotAfter), nil
}
func (dr *domainRenewal) next(expiry time.Time) time.Duration {
d := expiry.Sub(timeNow()) - dr.m.renewBefore()
// add a bit of randomness to renew deadline
n := pseudoRand.int63n(int64(renewJitter))
d -= time.Duration(n)
if d < 0 {
return 0
}
return d
}
var testDidRenewLoop = func(next time.Duration, err error) {}
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package acme
import (
"crypto"
"crypto/ecdsa"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
_ "crypto/sha512" // need for EC keys
"encoding/base64"
"encoding/json"
"fmt"
"math/big"
)
// jwsEncodeJSON signs claimset using provided key and a nonce.
// The result is serialized in JSON format.
// See https://tools.ietf.org/html/rfc7515#section-7.
func jwsEncodeJSON(claimset interface{}, key crypto.Signer, nonce string) ([]byte, error) {
jwk, err := jwkEncode(key.Public())
if err != nil {
return nil, err
}
alg, sha := jwsHasher(key)
if alg == "" || !sha.Available() {
return nil, ErrUnsupportedKey
}
phead := fmt.Sprintf(`{"alg":%q,"jwk":%s,"nonce":%q}`, alg, jwk, nonce)
phead = base64.RawURLEncoding.EncodeToString([]byte(phead))
cs, err := json.Marshal(claimset)
if err != nil {
return nil, err
}
payload := base64.RawURLEncoding.EncodeToString(cs)
hash := sha.New()
hash.Write([]byte(phead + "." + payload))
sig, err := jwsSign(key, sha, hash.Sum(nil))
if err != nil {
return nil, err
}
enc := struct {
Protected string `json:"protected"`
Payload string `json:"payload"`
Sig string `json:"signature"`
}{
Protected: phead,
Payload: payload,
Sig: base64.RawURLEncoding.EncodeToString(sig),
}
return json.Marshal(&enc)
}
// jwkEncode encodes public part of an RSA or ECDSA key into a JWK.
// The result is also suitable for creating a JWK thumbprint.
// https://tools.ietf.org/html/rfc7517
func jwkEncode(pub crypto.PublicKey) (string, error) {
switch pub := pub.(type) {
case *rsa.PublicKey:
// https://tools.ietf.org/html/rfc7518#section-6.3.1
n := pub.N
e := big.NewInt(int64(pub.E))
// Field order is important.
// See https://tools.ietf.org/html/rfc7638#section-3.3 for details.
return fmt.Sprintf(`{"e":"%s","kty":"RSA","n":"%s"}`,
base64.RawURLEncoding.EncodeToString(e.Bytes()),
base64.RawURLEncoding.EncodeToString(n.Bytes()),
), nil
case *ecdsa.PublicKey:
// https://tools.ietf.org/html/rfc7518#section-6.2.1
p := pub.Curve.Params()
n := p.BitSize / 8
if p.BitSize%8 != 0 {
n++
}
x := pub.X.Bytes()
if n > len(x) {
x = append(make([]byte, n-len(x)), x...)
}
y := pub.Y.Bytes()
if n > len(y) {
y = append(make([]byte, n-len(y)), y...)
}
// Field order is important.
// See https://tools.ietf.org/html/rfc7638#section-3.3 for details.
return fmt.Sprintf(`{"crv":"%s","kty":"EC","x":"%s","y":"%s"}`,
p.Name,
base64.RawURLEncoding.EncodeToString(x),
base64.RawURLEncoding.EncodeToString(y),
), nil
}
return "", ErrUnsupportedKey
}
// jwsSign signs the digest using the given key.
// It returns ErrUnsupportedKey if the key type is unknown.
// The hash is used only for RSA keys.
func jwsSign(key crypto.Signer, hash crypto.Hash, digest []byte) ([]byte, error) {
switch key := key.(type) {
case *rsa.PrivateKey:
return key.Sign(rand.Reader, digest, hash)
case *ecdsa.PrivateKey:
r, s, err := ecdsa.Sign(rand.Reader, key, digest)
if err != nil {
return nil, err
}
rb, sb := r.Bytes(), s.Bytes()
size := key.Params().BitSize / 8
if size%8 > 0 {
size++
}
sig := make([]byte, size*2)
copy(sig[size-len(rb):], rb)
copy(sig[size*2-len(sb):], sb)
return sig, nil
}
return nil, ErrUnsupportedKey
}
// jwsHasher indicates suitable JWS algorithm name and a hash function
// to use for signing a digest with the provided key.
// It returns ("", 0) if the key is not supported.
func jwsHasher(key crypto.Signer) (string, crypto.Hash) {
switch key := key.(type) {
case *rsa.PrivateKey:
return "RS256", crypto.SHA256
case *ecdsa.PrivateKey:
switch key.Params().Name {
case "P-256":
return "ES256", crypto.SHA256
case "P-384":
return "ES384", crypto.SHA384
case "P-521":
return "ES512", crypto.SHA512
}
}
return "", 0
}
// JWKThumbprint creates a JWK thumbprint out of pub
// as specified in https://tools.ietf.org/html/rfc7638.
func JWKThumbprint(pub crypto.PublicKey) (string, error) {
jwk, err := jwkEncode(pub)
if err != nil {
return "", err
}
b := sha256.Sum256([]byte(jwk))
return base64.RawURLEncoding.EncodeToString(b[:]), nil
}
This diff is collapsed.
This diff is collapsed.
...@@ -28,7 +28,7 @@ import ( ...@@ -28,7 +28,7 @@ import (
// size may be at most 1<<16 bytes (64 KiB). // size may be at most 1<<16 bytes (64 KiB).
func DeriveECDHES(alg string, apuData, apvData []byte, priv *ecdsa.PrivateKey, pub *ecdsa.PublicKey, size int) []byte { func DeriveECDHES(alg string, apuData, apvData []byte, priv *ecdsa.PrivateKey, pub *ecdsa.PublicKey, size int) []byte {
if size > 1<<16 { if size > 1<<16 {
panic("ECDH-ES output size too large, must be less than 1<<16") panic("ECDH-ES output size too large, must be less than or equal to 1<<16")
} }
// algId, partyUInfo, partyVInfo inputs must be prefixed with the length // algId, partyUInfo, partyVInfo inputs must be prefixed with the length
......
...@@ -17,10 +17,11 @@ ...@@ -17,10 +17,11 @@
/* /*
Package jose aims to provide an implementation of the Javascript Object Signing Package jose aims to provide an implementation of the Javascript Object Signing
and Encryption set of standards. For the moment, it mainly focuses on and Encryption set of standards. It implements encryption and signing based on
encryption and signing based on the JSON Web Encryption and JSON Web Signature the JSON Web Encryption and JSON Web Signature standards, with optional JSON
standards. The library supports both the compact and full serialization Web Token support available in a sub-package. The library supports both the
formats, and has optional support for multiple recipients. compact and full serialization formats, and has optional support for multiple
recipients.
*/ */
package jose // import "gopkg.in/square/go-jose.v1" package jose
...@@ -21,29 +21,14 @@ import ( ...@@ -21,29 +21,14 @@ import (
"compress/flate" "compress/flate"
"encoding/base64" "encoding/base64"
"encoding/binary" "encoding/binary"
"encoding/json"
"io" "io"
"math/big" "math/big"
"regexp" "regexp"
"strings"
"gopkg.in/square/go-jose.v1/json"
) )
var stripWhitespaceRegex = regexp.MustCompile("\\s") var stripWhitespaceRegex = regexp.MustCompile("\\s")
// Url-safe base64 encode that strips padding
func base64URLEncode(data []byte) string {
var result = base64.URLEncoding.EncodeToString(data)
return strings.TrimRight(result, "=")
}
// Url-safe base64 decoder that adds padding
func base64URLDecode(data string) ([]byte, error) {
var missing = (4 - len(data)%4) % 4
data += strings.Repeat("=", missing)
return base64.URLEncoding.DecodeString(data)
}
// Helper function to serialize known-good objects. // Helper function to serialize known-good objects.
// Precondition: value is not a nil pointer. // Precondition: value is not a nil pointer.
func mustSerializeJSON(value interface{}) []byte { func mustSerializeJSON(value interface{}) []byte {
...@@ -162,7 +147,7 @@ func (b *byteBuffer) UnmarshalJSON(data []byte) error { ...@@ -162,7 +147,7 @@ func (b *byteBuffer) UnmarshalJSON(data []byte) error {
return nil return nil
} }
decoded, err := base64URLDecode(encoded) decoded, err := base64.RawURLEncoding.DecodeString(encoded)
if err != nil { if err != nil {
return err return err
} }
...@@ -173,7 +158,7 @@ func (b *byteBuffer) UnmarshalJSON(data []byte) error { ...@@ -173,7 +158,7 @@ func (b *byteBuffer) UnmarshalJSON(data []byte) error {
} }
func (b *byteBuffer) base64() string { func (b *byteBuffer) base64() string {
return base64URLEncode(b.data) return base64.RawURLEncoding.EncodeToString(b.data)
} }
func (b *byteBuffer) bytes() []byte { func (b *byteBuffer) bytes() []byte {
......
...@@ -14,15 +14,32 @@ ...@@ -14,15 +14,32 @@
* limitations under the License. * limitations under the License.
*/ */
package jose package main
import ( import (
"crypto/x509" "crypto/x509"
"encoding/pem" "encoding/pem"
"errors"
"fmt" "fmt"
"gopkg.in/square/go-jose.v2"
) )
// LoadPublicKey loads a public key from PEM/DER-encoded data. func LoadJSONWebKey(json []byte, pub bool) (*jose.JSONWebKey, error) {
var jwk jose.JSONWebKey
err := jwk.UnmarshalJSON(json)
if err != nil {
return nil, err
}
if !jwk.Valid() {
return nil, errors.New("invalid JWK key")
}
if jwk.IsPublic() != pub {
return nil, errors.New("priv/pub JWK key mismatch")
}
return &jwk, nil
}
// LoadPublicKey loads a public key from PEM/DER/JWK-encoded data.
func LoadPublicKey(data []byte) (interface{}, error) { func LoadPublicKey(data []byte) (interface{}, error) {
input := data input := data
...@@ -42,10 +59,15 @@ func LoadPublicKey(data []byte) (interface{}, error) { ...@@ -42,10 +59,15 @@ func LoadPublicKey(data []byte) (interface{}, error) {
return cert.PublicKey, nil return cert.PublicKey, nil
} }
return nil, fmt.Errorf("square/go-jose: parse error, got '%s' and '%s'", err0, err1) jwk, err2 := LoadJSONWebKey(data, true)
if err2 == nil {
return jwk, nil
}
return nil, fmt.Errorf("square/go-jose: parse error, got '%s', '%s' and '%s'", err0, err1, err2)
} }
// LoadPrivateKey loads a private key from PEM/DER-encoded data. // LoadPrivateKey loads a private key from PEM/DER/JWK-encoded data.
func LoadPrivateKey(data []byte) (interface{}, error) { func LoadPrivateKey(data []byte) (interface{}, error) {
input := data input := data
...@@ -70,5 +92,10 @@ func LoadPrivateKey(data []byte) (interface{}, error) { ...@@ -70,5 +92,10 @@ func LoadPrivateKey(data []byte) (interface{}, error) {
return priv, nil return priv, nil
} }
return nil, fmt.Errorf("square/go-jose: parse error, got '%s', '%s' and '%s'", err0, err1, err2) jwk, err3 := LoadJSONWebKey(input, false)
if err3 == nil {
return jwk, nil
}
return nil, fmt.Errorf("square/go-jose: parse error, got '%s', '%s', '%s' and '%s'", err0, err1, err2, err3)
} }
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment